mirror of
https://github.com/openai/codex.git
synced 2026-03-24 07:33:50 +00:00
Compare commits
11 Commits
dev/cc/ref
...
codex-exp-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c2fef72f12 | ||
|
|
3f36ea91e3 | ||
|
|
55be8ca802 | ||
|
|
1a2536d12b | ||
|
|
732d7ac81f | ||
|
|
1f3ec4172c | ||
|
|
52fdfbcfb8 | ||
|
|
2321532dec | ||
|
|
de98643403 | ||
|
|
fb27c20581 | ||
|
|
bccce0f2d8 |
@@ -8,6 +8,10 @@ license.workspace = true
|
||||
name = "codex"
|
||||
path = "src/main.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "codex-exp-2"
|
||||
path = "src/main.rs"
|
||||
|
||||
[lib]
|
||||
name = "codex_cli"
|
||||
path = "src/lib.rs"
|
||||
|
||||
@@ -31,6 +31,7 @@ use codex_tui::update_action::UpdateAction;
|
||||
use codex_utils_cli::CliConfigOverrides;
|
||||
use owo_colors::OwoColorize;
|
||||
use std::io::IsTerminal;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use supports_color::Stream;
|
||||
|
||||
@@ -84,6 +85,16 @@ struct MultitoolCli {
|
||||
subcommand: Option<Subcommand>,
|
||||
}
|
||||
|
||||
const CODEX_EXP_2_BIN_NAME: &str = "codex-exp-2";
|
||||
// Benchmarked on a 16-core M4 Max with a 20-task cross-repo workload.
|
||||
const CODEX_EXP_2_DEFAULT_OVERRIDES: [&str; 5] = [
|
||||
r#"model="gpt-5.4""#,
|
||||
r#"model_reasoning_effort="xhigh""#,
|
||||
r#"service_tier="fast""#,
|
||||
"agents.max_threads=8",
|
||||
"agents.max_depth=3",
|
||||
];
|
||||
|
||||
#[derive(Debug, clap::Subcommand)]
|
||||
enum Subcommand {
|
||||
/// Run Codex non-interactively.
|
||||
@@ -607,6 +618,11 @@ async fn cli_main(arg0_paths: Arg0DispatchPaths) -> anyhow::Result<()> {
|
||||
subcommand,
|
||||
} = MultitoolCli::parse();
|
||||
|
||||
prepend_bin_tuned_defaults(
|
||||
&mut root_config_overrides,
|
||||
current_invocation_name().as_deref(),
|
||||
);
|
||||
|
||||
// Fold --enable/--disable into config overrides so they flow to all subcommands.
|
||||
let toggle_overrides = feature_toggles.to_overrides()?;
|
||||
root_config_overrides.raw_overrides.extend(toggle_overrides);
|
||||
@@ -1045,6 +1061,28 @@ fn prepend_config_flags(
|
||||
.splice(0..0, cli_config_overrides.raw_overrides);
|
||||
}
|
||||
|
||||
fn current_invocation_name() -> Option<String> {
|
||||
std::env::args_os()
|
||||
.next()
|
||||
.and_then(|arg0| Path::new(&arg0).file_stem()?.to_str().map(str::to_owned))
|
||||
}
|
||||
|
||||
fn prepend_bin_tuned_defaults(
|
||||
root_config_overrides: &mut CliConfigOverrides,
|
||||
invocation_name: Option<&str>,
|
||||
) {
|
||||
if invocation_name != Some(CODEX_EXP_2_BIN_NAME) {
|
||||
return;
|
||||
}
|
||||
|
||||
root_config_overrides.raw_overrides.splice(
|
||||
0..0,
|
||||
CODEX_EXP_2_DEFAULT_OVERRIDES
|
||||
.into_iter()
|
||||
.map(ToString::to_string),
|
||||
);
|
||||
}
|
||||
|
||||
fn reject_remote_mode_for_subcommand(remote: Option<&str>, subcommand: &str) -> anyhow::Result<()> {
|
||||
if let Some(remote) = remote {
|
||||
anyhow::bail!(
|
||||
@@ -1774,4 +1812,34 @@ mod tests {
|
||||
.expect_err("feature should be rejected");
|
||||
assert_eq!(err.to_string(), "Unknown feature flag: does_not_exist");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exp_2_defaults_apply_only_for_exp_2_bin() {
|
||||
let mut overrides = CliConfigOverrides::default();
|
||||
prepend_bin_tuned_defaults(&mut overrides, Some("codex-exp-2"));
|
||||
assert_eq!(
|
||||
overrides.raw_overrides,
|
||||
CODEX_EXP_2_DEFAULT_OVERRIDES.map(str::to_string).to_vec()
|
||||
);
|
||||
|
||||
let mut regular_overrides = CliConfigOverrides::default();
|
||||
prepend_bin_tuned_defaults(&mut regular_overrides, Some("codex"));
|
||||
assert!(regular_overrides.raw_overrides.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exp_2_defaults_stay_lower_precedence_than_user_overrides() {
|
||||
let mut overrides = CliConfigOverrides {
|
||||
raw_overrides: vec!["agents.max_threads=99".to_string()],
|
||||
};
|
||||
prepend_bin_tuned_defaults(&mut overrides, Some("codex-exp-2"));
|
||||
assert_eq!(
|
||||
overrides.raw_overrides,
|
||||
CODEX_EXP_2_DEFAULT_OVERRIDES
|
||||
.map(str::to_string)
|
||||
.into_iter()
|
||||
.chain(std::iter::once("agents.max_threads=99".to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,6 +125,8 @@ use futures::future::BoxFuture;
|
||||
use futures::future::Shared;
|
||||
use futures::prelude::*;
|
||||
use futures::stream::FuturesOrdered;
|
||||
use reqwest::header::HeaderMap;
|
||||
use reqwest::header::HeaderValue;
|
||||
use rmcp::model::ListResourceTemplatesResult;
|
||||
use rmcp::model::ListResourcesResult;
|
||||
use rmcp::model::PaginatedRequestParams;
|
||||
@@ -3943,6 +3945,12 @@ impl Session {
|
||||
arguments: Option<serde_json::Value>,
|
||||
meta: Option<serde_json::Value>,
|
||||
) -> anyhow::Result<CallToolResult> {
|
||||
if server == CODEX_APPS_MCP_SERVER_NAME
|
||||
&& let Some((turn_context, _)) = self.active_turn_context_and_cancellation_token().await
|
||||
{
|
||||
self.sync_mcp_request_headers_for_turn(turn_context.as_ref())
|
||||
.await;
|
||||
}
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
@@ -3951,6 +3959,45 @@ impl Session {
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn sync_mcp_request_headers_for_turn(&self, turn_context: &TurnContext) {
|
||||
let mut request_headers = HeaderMap::new();
|
||||
let session_id = self.conversation_id.to_string();
|
||||
if let Ok(value) = HeaderValue::from_str(&session_id) {
|
||||
request_headers.insert("session_id", value.clone());
|
||||
request_headers.insert("x-client-request-id", value);
|
||||
}
|
||||
if let Some(turn_metadata) = turn_context.turn_metadata_state.current_header_value()
|
||||
&& let Ok(value) = HeaderValue::from_str(&turn_metadata)
|
||||
{
|
||||
request_headers.insert(crate::X_CODEX_TURN_METADATA_HEADER, value);
|
||||
}
|
||||
|
||||
let request_headers = if request_headers.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(request_headers)
|
||||
};
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.set_request_headers_for_server(
|
||||
crate::mcp::CODEX_APPS_MCP_SERVER_NAME,
|
||||
request_headers,
|
||||
);
|
||||
}
|
||||
|
||||
pub(crate) async fn clear_mcp_request_headers(&self) {
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.set_request_headers_for_server(
|
||||
crate::mcp::CODEX_APPS_MCP_SERVER_NAME,
|
||||
/*request_headers*/ None,
|
||||
);
|
||||
}
|
||||
|
||||
pub(crate) async fn parse_mcp_tool_name(
|
||||
&self,
|
||||
name: &str,
|
||||
|
||||
@@ -26,6 +26,8 @@ use codex_protocol::protocol::ReadOnlyAccess;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use codex_protocol::request_permissions::PermissionGrantScope;
|
||||
use codex_protocol::request_permissions::RequestPermissionProfile;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use reqwest::header::HeaderValue;
|
||||
use tracing::Span;
|
||||
|
||||
use crate::protocol::CompactedItem;
|
||||
@@ -56,6 +58,7 @@ use crate::tools::handlers::UnifiedExecHandler;
|
||||
use crate::tools::registry::ToolHandler;
|
||||
use crate::tools::router::ToolCallSource;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use async_trait::async_trait;
|
||||
use codex_app_server_protocol::AppInfo;
|
||||
use codex_execpolicy::Decision;
|
||||
use codex_execpolicy::NetworkRuleProtocol;
|
||||
@@ -78,7 +81,10 @@ use opentelemetry::trace::TracerProvider as _;
|
||||
use opentelemetry_sdk::trace::SdkTracerProvider;
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::time::sleep;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::AbortOnDropHandle;
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
@@ -2716,6 +2722,99 @@ async fn request_permissions_is_auto_denied_when_granular_policy_blocks_tool_req
|
||||
);
|
||||
}
|
||||
|
||||
struct NoopTask;
|
||||
|
||||
#[async_trait]
|
||||
impl SessionTask for NoopTask {
|
||||
fn kind(&self) -> TaskKind {
|
||||
TaskKind::Regular
|
||||
}
|
||||
|
||||
fn span_name(&self) -> &'static str {
|
||||
"noop"
|
||||
}
|
||||
|
||||
async fn run(
|
||||
self: Arc<Self>,
|
||||
_ctx: Arc<SessionTaskContext>,
|
||||
_turn_context: Arc<TurnContext>,
|
||||
_input: Vec<UserInput>,
|
||||
_cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn call_tool_refreshes_codex_apps_request_headers_from_active_turn() {
|
||||
let (session, turn_context) = make_session_and_context().await;
|
||||
let session = Arc::new(session);
|
||||
let turn_context = Arc::new(turn_context);
|
||||
|
||||
session
|
||||
.services
|
||||
.mcp_connection_manager
|
||||
.write()
|
||||
.await
|
||||
.register_test_server_for_request_headers(CODEX_APPS_MCP_SERVER_NAME);
|
||||
|
||||
session
|
||||
.sync_mcp_request_headers_for_turn(turn_context.as_ref())
|
||||
.await;
|
||||
let base_header = turn_context
|
||||
.turn_metadata_state
|
||||
.current_header_value()
|
||||
.expect("base turn metadata header");
|
||||
|
||||
let updated_header = serde_json::json!({
|
||||
"turn_id": turn_context.sub_id,
|
||||
"sandbox": "test",
|
||||
"workspaces": {
|
||||
"/tmp/repo": {
|
||||
"has_changes": true
|
||||
}
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
turn_context
|
||||
.turn_metadata_state
|
||||
.set_enriched_header_for_tests(Some(updated_header.clone()));
|
||||
|
||||
let mut active_turn = ActiveTurn::default();
|
||||
let handle = tokio::spawn(async {});
|
||||
active_turn.add_task(crate::state::RunningTask {
|
||||
done: Arc::new(Notify::new()),
|
||||
kind: TaskKind::Regular,
|
||||
task: Arc::new(NoopTask),
|
||||
cancellation_token: CancellationToken::new(),
|
||||
handle: Arc::new(AbortOnDropHandle::new(handle)),
|
||||
turn_context: Arc::clone(&turn_context),
|
||||
_timer: None,
|
||||
});
|
||||
*session.active_turn.lock().await = Some(active_turn);
|
||||
|
||||
let _err = session
|
||||
.call_tool(CODEX_APPS_MCP_SERVER_NAME, "echo", None, None)
|
||||
.await
|
||||
.expect_err("test server is not initialized");
|
||||
|
||||
let headers = session
|
||||
.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.request_headers_for_server(CODEX_APPS_MCP_SERVER_NAME)
|
||||
.expect("request headers should be tracked for codex apps");
|
||||
assert_eq!(
|
||||
headers.get(crate::X_CODEX_TURN_METADATA_HEADER),
|
||||
Some(&HeaderValue::from_str(&updated_header).expect("valid enriched header")),
|
||||
);
|
||||
assert_ne!(
|
||||
headers.get(crate::X_CODEX_TURN_METADATA_HEADER),
|
||||
Some(&HeaderValue::from_str(&base_header).expect("valid base header")),
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn submit_with_id_captures_current_span_trace_context() {
|
||||
let (session, _turn_context) = make_session_and_context().await;
|
||||
|
||||
@@ -423,6 +423,7 @@ impl ManagedClient {
|
||||
#[derive(Clone)]
|
||||
struct AsyncManagedClient {
|
||||
client: Shared<BoxFuture<'static, Result<ManagedClient, StartupOutcomeError>>>,
|
||||
request_headers: Arc<StdMutex<Option<reqwest::header::HeaderMap>>>,
|
||||
startup_snapshot: Option<Vec<ToolInfo>>,
|
||||
startup_complete: Arc<AtomicBool>,
|
||||
tool_plugin_provenance: Arc<ToolPluginProvenance>,
|
||||
@@ -448,17 +449,26 @@ impl AsyncManagedClient {
|
||||
codex_apps_tools_cache_context.as_ref(),
|
||||
)
|
||||
.map(|tools| filter_tools(tools, &tool_filter));
|
||||
let request_headers = Arc::new(StdMutex::new(None));
|
||||
let startup_tool_filter = tool_filter;
|
||||
let startup_complete = Arc::new(AtomicBool::new(false));
|
||||
let startup_complete_for_fut = Arc::clone(&startup_complete);
|
||||
let request_headers_for_client = Arc::clone(&request_headers);
|
||||
let fut = async move {
|
||||
let outcome = async {
|
||||
if let Err(error) = validate_mcp_server_name(&server_name) {
|
||||
return Err(error.into());
|
||||
}
|
||||
|
||||
let client =
|
||||
Arc::new(make_rmcp_client(&server_name, config.transport, store_mode).await?);
|
||||
let client = Arc::new(
|
||||
make_rmcp_client(
|
||||
&server_name,
|
||||
config.transport,
|
||||
store_mode,
|
||||
request_headers_for_client,
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
match start_server_task(
|
||||
server_name,
|
||||
client,
|
||||
@@ -495,6 +505,7 @@ impl AsyncManagedClient {
|
||||
|
||||
Self {
|
||||
client,
|
||||
request_headers,
|
||||
startup_snapshot,
|
||||
startup_complete,
|
||||
tool_plugin_provenance,
|
||||
@@ -576,6 +587,14 @@ impl AsyncManagedClient {
|
||||
let managed = self.client().await?;
|
||||
managed.notify_sandbox_state_change(sandbox_state).await
|
||||
}
|
||||
|
||||
fn set_request_headers(&self, request_headers: Option<reqwest::header::HeaderMap>) {
|
||||
let mut guard = self
|
||||
.request_headers
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
*guard = request_headers;
|
||||
}
|
||||
}
|
||||
|
||||
pub const MCP_SANDBOX_STATE_CAPABILITY: &str = "codex/sandbox-state";
|
||||
@@ -617,6 +636,40 @@ impl McpConnectionManager {
|
||||
Self::new_uninitialized(approval_policy)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn register_test_server_for_request_headers(&mut self, server_name: &str) {
|
||||
let failed_client = futures::future::ready::<Result<ManagedClient, StartupOutcomeError>>(
|
||||
Err(StartupOutcomeError::Failed {
|
||||
error: "test request headers stub".to_string(),
|
||||
}),
|
||||
)
|
||||
.boxed()
|
||||
.shared();
|
||||
self.clients.insert(
|
||||
server_name.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: failed_client,
|
||||
request_headers: Arc::new(StdMutex::new(None)),
|
||||
startup_snapshot: None,
|
||||
startup_complete: Arc::new(AtomicBool::new(true)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn request_headers_for_server(
|
||||
&self,
|
||||
server_name: &str,
|
||||
) -> Option<reqwest::header::HeaderMap> {
|
||||
let client = self.clients.get(server_name)?;
|
||||
client
|
||||
.request_headers
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn has_servers(&self) -> bool {
|
||||
!self.clients.is_empty()
|
||||
}
|
||||
@@ -1046,6 +1099,16 @@ impl McpConnectionManager {
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn set_request_headers_for_server(
|
||||
&self,
|
||||
server_name: &str,
|
||||
request_headers: Option<reqwest::header::HeaderMap>,
|
||||
) {
|
||||
if let Some(client) = self.clients.get(server_name) {
|
||||
client.set_request_headers(request_headers);
|
||||
}
|
||||
}
|
||||
|
||||
/// List resources from the specified server.
|
||||
pub async fn list_resources(
|
||||
&self,
|
||||
@@ -1429,6 +1492,7 @@ async fn make_rmcp_client(
|
||||
server_name: &str,
|
||||
transport: McpServerTransportConfig,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
request_headers: Arc<StdMutex<Option<reqwest::header::HeaderMap>>>,
|
||||
) -> Result<RmcpClient, StartupOutcomeError> {
|
||||
match transport {
|
||||
McpServerTransportConfig::Stdio {
|
||||
@@ -1462,6 +1526,7 @@ async fn make_rmcp_client(
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
store_mode,
|
||||
request_headers,
|
||||
)
|
||||
.await
|
||||
.map_err(StartupOutcomeError::from)
|
||||
|
||||
@@ -4,6 +4,7 @@ use codex_protocol::protocol::McpAuthStatus;
|
||||
use rmcp::model::JsonObject;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use tempfile::tempdir;
|
||||
|
||||
fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo {
|
||||
@@ -413,6 +414,7 @@ async fn list_all_tools_uses_startup_snapshot_while_client_is_pending() {
|
||||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: pending_client,
|
||||
request_headers: Arc::new(StdMutex::new(None)),
|
||||
startup_snapshot: Some(startup_tools),
|
||||
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
@@ -438,6 +440,7 @@ async fn list_all_tools_blocks_while_client_is_pending_without_startup_snapshot(
|
||||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: pending_client,
|
||||
request_headers: Arc::new(StdMutex::new(None)),
|
||||
startup_snapshot: None,
|
||||
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
@@ -460,6 +463,7 @@ async fn list_all_tools_does_not_block_when_startup_snapshot_cache_hit_is_empty(
|
||||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: pending_client,
|
||||
request_headers: Arc::new(StdMutex::new(None)),
|
||||
startup_snapshot: Some(Vec::new()),
|
||||
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
@@ -492,6 +496,7 @@ async fn list_all_tools_uses_startup_snapshot_when_client_startup_fails() {
|
||||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: failed_client,
|
||||
request_headers: Arc::new(StdMutex::new(None)),
|
||||
startup_snapshot: Some(startup_tools),
|
||||
startup_complete,
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
|
||||
@@ -153,6 +153,8 @@ impl Session {
|
||||
) {
|
||||
self.abort_all_tasks(TurnAbortReason::Replaced).await;
|
||||
self.clear_connector_selection().await;
|
||||
self.sync_mcp_request_headers_for_turn(turn_context.as_ref())
|
||||
.await;
|
||||
|
||||
let task: Arc<dyn SessionTask> = Arc::new(task);
|
||||
let task_kind = task.kind();
|
||||
@@ -233,6 +235,7 @@ impl Session {
|
||||
// in-flight approval wait can surface as a model-visible rejection before TurnAborted.
|
||||
active_turn.clear_pending().await;
|
||||
}
|
||||
self.clear_mcp_request_headers().await;
|
||||
}
|
||||
|
||||
pub async fn on_task_finished(
|
||||
@@ -262,6 +265,9 @@ impl Session {
|
||||
*active = None;
|
||||
}
|
||||
drop(active);
|
||||
if should_clear_active_turn {
|
||||
self.clear_mcp_request_headers().await;
|
||||
}
|
||||
if !pending_input.is_empty() {
|
||||
for pending_input_item in pending_input {
|
||||
match inspect_pending_input(self, &turn_context, pending_input_item).await {
|
||||
|
||||
@@ -119,43 +119,21 @@ impl ToolHandler for Handler {
|
||||
}
|
||||
}
|
||||
|
||||
let statuses = if !initial_final_statuses.is_empty() {
|
||||
initial_final_statuses
|
||||
} else {
|
||||
let mut futures = FuturesUnordered::new();
|
||||
for (id, rx) in status_rxs.into_iter() {
|
||||
let session = session.clone();
|
||||
futures.push(wait_for_final_status(session, id, rx));
|
||||
}
|
||||
let mut results = Vec::new();
|
||||
let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64);
|
||||
loop {
|
||||
match timeout_at(deadline, futures.next()).await {
|
||||
Ok(Some(Some(result))) => {
|
||||
results.push(result);
|
||||
break;
|
||||
}
|
||||
Ok(Some(None)) => continue,
|
||||
Ok(None) | Err(_) => break,
|
||||
}
|
||||
}
|
||||
if !results.is_empty() {
|
||||
loop {
|
||||
match futures.next().now_or_never() {
|
||||
Some(Some(Some(result))) => results.push(result),
|
||||
Some(Some(None)) => continue,
|
||||
Some(None) | None => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
results
|
||||
};
|
||||
|
||||
let statuses_map = statuses.clone().into_iter().collect::<HashMap<_, _>>();
|
||||
let expected_status_count = receiver_thread_ids.len();
|
||||
let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64);
|
||||
let statuses_map = wait_for_requested_statuses(
|
||||
session.clone(),
|
||||
status_rxs,
|
||||
initial_final_statuses,
|
||||
expected_status_count,
|
||||
args.wait_for_all,
|
||||
deadline,
|
||||
)
|
||||
.await;
|
||||
let agent_statuses = build_wait_agent_statuses(&statuses_map, &receiver_agents);
|
||||
let result = WaitAgentResult {
|
||||
status: statuses_map.clone(),
|
||||
timed_out: statuses.is_empty(),
|
||||
timed_out: wait_timed_out(statuses_map.len(), expected_status_count, args.wait_for_all),
|
||||
};
|
||||
|
||||
session
|
||||
@@ -179,6 +157,8 @@ impl ToolHandler for Handler {
|
||||
struct WaitArgs {
|
||||
ids: Vec<String>,
|
||||
timeout_ms: Option<i64>,
|
||||
#[serde(default)]
|
||||
wait_for_all: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
|
||||
@@ -226,3 +206,77 @@ async fn wait_for_final_status(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_for_requested_statuses(
|
||||
session: Arc<Session>,
|
||||
status_rxs: Vec<(ThreadId, Receiver<AgentStatus>)>,
|
||||
initial_statuses: Vec<(ThreadId, AgentStatus)>,
|
||||
expected_count: usize,
|
||||
wait_for_all: bool,
|
||||
deadline: Instant,
|
||||
) -> HashMap<ThreadId, AgentStatus> {
|
||||
let mut statuses = initial_statuses.into_iter().collect::<HashMap<_, _>>();
|
||||
if wait_condition_satisfied(statuses.len(), expected_count, wait_for_all) {
|
||||
return statuses;
|
||||
}
|
||||
|
||||
let mut futures = FuturesUnordered::new();
|
||||
for (id, rx) in status_rxs {
|
||||
if statuses.contains_key(&id) {
|
||||
continue;
|
||||
}
|
||||
let session = session.clone();
|
||||
futures.push(wait_for_final_status(session, id, rx));
|
||||
}
|
||||
|
||||
loop {
|
||||
if wait_condition_satisfied(statuses.len(), expected_count, wait_for_all) {
|
||||
break;
|
||||
}
|
||||
|
||||
match timeout_at(deadline, futures.next()).await {
|
||||
Ok(Some(Some((id, status)))) => {
|
||||
statuses.insert(id, status);
|
||||
if !wait_for_all {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(Some(None)) => continue,
|
||||
Ok(None) | Err(_) => break,
|
||||
}
|
||||
}
|
||||
|
||||
drain_ready_final_statuses(&mut futures, &mut statuses);
|
||||
statuses
|
||||
}
|
||||
|
||||
fn drain_ready_final_statuses(
|
||||
futures: &mut FuturesUnordered<impl futures::Future<Output = Option<(ThreadId, AgentStatus)>>>,
|
||||
statuses: &mut HashMap<ThreadId, AgentStatus>,
|
||||
) {
|
||||
loop {
|
||||
match futures.next().now_or_never() {
|
||||
Some(Some(Some((id, status)))) => {
|
||||
statuses.insert(id, status);
|
||||
}
|
||||
Some(Some(None)) => continue,
|
||||
Some(None) | None => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn wait_condition_satisfied(
|
||||
observed_count: usize,
|
||||
expected_count: usize,
|
||||
wait_for_all: bool,
|
||||
) -> bool {
|
||||
if wait_for_all {
|
||||
observed_count >= expected_count
|
||||
} else {
|
||||
observed_count > 0
|
||||
}
|
||||
}
|
||||
|
||||
fn wait_timed_out(observed_count: usize, expected_count: usize, wait_for_all: bool) -> bool {
|
||||
!wait_condition_satisfied(observed_count, expected_count, wait_for_all)
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::time::sleep;
|
||||
use tokio::time::timeout;
|
||||
|
||||
fn invocation(
|
||||
@@ -970,6 +971,105 @@ async fn wait_agent_returns_final_status_without_timeout() {
|
||||
assert_eq!(success, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wait_agent_wait_for_all_does_not_return_after_first_final_status() {
|
||||
let (mut session, turn) = make_session_and_context().await;
|
||||
let manager = thread_manager();
|
||||
session.services.agent_control = manager.agent_control();
|
||||
let config = turn.config.as_ref().clone();
|
||||
let thread_a = manager
|
||||
.start_thread(config.clone())
|
||||
.await
|
||||
.expect("start thread");
|
||||
let thread_b = manager.start_thread(config).await.expect("start thread");
|
||||
let id_a = thread_a.thread_id;
|
||||
let thread_a_handle = Arc::clone(&thread_a.thread);
|
||||
|
||||
tokio::spawn(async move {
|
||||
sleep(Duration::from_millis(20)).await;
|
||||
let _ = thread_a_handle.submit(Op::Shutdown {}).await;
|
||||
});
|
||||
|
||||
let invocation = invocation(
|
||||
Arc::new(session),
|
||||
Arc::new(turn),
|
||||
"wait_agent",
|
||||
function_payload(json!({
|
||||
"ids": [id_a.to_string(), thread_b.thread_id.to_string()],
|
||||
"wait_for_all": true,
|
||||
"timeout_ms": 1000
|
||||
})),
|
||||
);
|
||||
|
||||
let early = timeout(
|
||||
Duration::from_millis(80),
|
||||
WaitAgentHandler.handle(invocation),
|
||||
)
|
||||
.await;
|
||||
assert!(
|
||||
early.is_err(),
|
||||
"wait_agent(wait_for_all=true) should keep waiting after the first agent finishes"
|
||||
);
|
||||
|
||||
let _ = thread_b
|
||||
.thread
|
||||
.submit(Op::Shutdown {})
|
||||
.await
|
||||
.expect("shutdown should submit");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wait_agent_wait_for_all_returns_all_final_statuses() {
|
||||
let (mut session, turn) = make_session_and_context().await;
|
||||
let manager = thread_manager();
|
||||
session.services.agent_control = manager.agent_control();
|
||||
let config = turn.config.as_ref().clone();
|
||||
let thread_a = manager
|
||||
.start_thread(config.clone())
|
||||
.await
|
||||
.expect("start thread");
|
||||
let thread_b = manager.start_thread(config).await.expect("start thread");
|
||||
let id_a = thread_a.thread_id;
|
||||
let id_b = thread_b.thread_id;
|
||||
let thread_a_handle = Arc::clone(&thread_a.thread);
|
||||
let thread_b_handle = Arc::clone(&thread_b.thread);
|
||||
|
||||
tokio::spawn(async move {
|
||||
sleep(Duration::from_millis(20)).await;
|
||||
let _ = thread_a_handle.submit(Op::Shutdown {}).await;
|
||||
});
|
||||
tokio::spawn(async move {
|
||||
sleep(Duration::from_millis(40)).await;
|
||||
let _ = thread_b_handle.submit(Op::Shutdown {}).await;
|
||||
});
|
||||
|
||||
let invocation = invocation(
|
||||
Arc::new(session),
|
||||
Arc::new(turn),
|
||||
"wait_agent",
|
||||
function_payload(json!({
|
||||
"ids": [id_a.to_string(), id_b.to_string()],
|
||||
"wait_for_all": true,
|
||||
"timeout_ms": 1000
|
||||
})),
|
||||
);
|
||||
let output = WaitAgentHandler
|
||||
.handle(invocation)
|
||||
.await
|
||||
.expect("wait_agent should succeed");
|
||||
let (content, success) = expect_text_output(output);
|
||||
let result: wait::WaitAgentResult =
|
||||
serde_json::from_str(&content).expect("wait_agent result should be json");
|
||||
assert_eq!(
|
||||
result,
|
||||
wait::WaitAgentResult {
|
||||
status: HashMap::from([(id_a, AgentStatus::Shutdown), (id_b, AgentStatus::Shutdown)]),
|
||||
timed_out: false
|
||||
}
|
||||
);
|
||||
assert_eq!(success, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn close_agent_submits_shutdown_and_returns_previous_status() {
|
||||
let (mut session, turn) = make_session_and_context().await;
|
||||
|
||||
@@ -178,12 +178,12 @@ fn wait_output_schema() -> JsonValue {
|
||||
"properties": {
|
||||
"status": {
|
||||
"type": "object",
|
||||
"description": "Final statuses keyed by agent id for agents that finished before the timeout.",
|
||||
"description": "Final statuses keyed by agent id for agents that finished before the timeout or before the requested wait condition was satisfied.",
|
||||
"additionalProperties": agent_status_output_schema()
|
||||
},
|
||||
"timed_out": {
|
||||
"type": "boolean",
|
||||
"description": "Whether the wait call returned due to timeout before any agent reached a final status."
|
||||
"description": "Whether the wait call returned due to timeout before the requested wait condition was satisfied. With wait_for_all=true, partial statuses may still be returned."
|
||||
}
|
||||
},
|
||||
"required": ["status", "timed_out"],
|
||||
@@ -1114,6 +1114,7 @@ fn create_spawn_agent_tool(config: &ToolsConfig) -> ToolSpec {
|
||||
|
||||
### After you delegate
|
||||
- Call wait_agent very sparingly. Only call wait_agent when you need the result immediately for the next critical-path step and you are blocked until it returns.
|
||||
- If you launch a fixed batch and need every result before the next step, spawn the whole batch first and use one wait_agent call with wait_for_all=true and a sufficiently long timeout instead of a short wait loop.
|
||||
- Do not redo delegated subagent tasks yourself; focus on integrating results or tackling non-overlapping work.
|
||||
- While the subagent is running in the background, do meaningful non-overlapping work immediately.
|
||||
- Do not repeatedly wait by reflex.
|
||||
@@ -1370,7 +1371,16 @@ fn create_wait_agent_tool() -> ToolSpec {
|
||||
JsonSchema::Array {
|
||||
items: Box::new(JsonSchema::String { description: None }),
|
||||
description: Some(
|
||||
"Agent ids to wait on. Pass multiple ids to wait for whichever finishes first."
|
||||
"Agent ids to wait on. Pass multiple ids to wait for whichever finishes first unless wait_for_all is true."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"wait_for_all".to_string(),
|
||||
JsonSchema::Boolean {
|
||||
description: Some(
|
||||
"When true, wait until every requested agent reaches a final status or the timeout expires. When false (default), return after the first requested agent reaches a final status."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
@@ -1386,7 +1396,7 @@ fn create_wait_agent_tool() -> ToolSpec {
|
||||
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "wait_agent".to_string(),
|
||||
description: "Wait for agents to reach a final status. Completed statuses may include the agent's final message. Returns empty status when timed out. Once the agent reaches a final status, a notification message will be received containing the same completed status."
|
||||
description: "Wait for agents to reach a final status. By default this returns when any requested agent finishes; set wait_for_all=true to wait for the whole batch. Completed statuses may include the agent's final message. Returns empty status when timed out before any requested agent finishes; with wait_for_all=true, partial statuses may still be returned on timeout."
|
||||
.to_string(),
|
||||
strict: false,
|
||||
defer_loading: None,
|
||||
|
||||
@@ -217,6 +217,14 @@ impl TurnMetadataState {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn set_enriched_header_for_tests(&self, header: Option<String>) {
|
||||
*self
|
||||
.enriched_header
|
||||
.write()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner) = header;
|
||||
}
|
||||
|
||||
async fn fetch_workspace_git_metadata(&self) -> WorkspaceGitMetadata {
|
||||
let (latest_git_commit_hash, associated_remote_urls, has_changes) = tokio::join!(
|
||||
get_head_commit_hash(&self.cwd),
|
||||
|
||||
@@ -100,7 +100,8 @@ Sub-agents are their to make you go fast and time is a big constraint so leverag
|
||||
|
||||
## Flow
|
||||
1. Understand the task.
|
||||
2. Spawn the optimal necessary sub-agents.
|
||||
2. Spawn the optimal necessary sub-agents. When multiple independent agents are needed, launch the whole batch up front before waiting.
|
||||
3. Coordinate them via wait_agent / send_input.
|
||||
If you need every agent result before the next step, prefer one wait_agent call with `wait_for_all=true` and a long timeout instead of repeatedly waiting for whichever agent finishes first.
|
||||
4. Iterate on this. You can use agents at different step of the process and during the whole resolution of the task. Never forget to use them.
|
||||
5. Ask the user before shutting sub-agents down unless you need to because you reached the agent limit.
|
||||
|
||||
@@ -12,4 +12,5 @@ This feature must be used wisely. For simple or straightforward tasks, you don't
|
||||
* Running tests or some config commands can output a large amount of logs. In order to optimize your own context, you can spawn an agent and ask it to do it for you. In such cases, you must tell this agent that it can't spawn another agent himself (to prevent infinite recursion)
|
||||
* When you're done with a sub-agent, don't forget to close it using `close_agent`.
|
||||
* Be careful on the `timeout_ms` parameter you choose for `wait_agent`. It should be wisely scaled.
|
||||
* If you launch a fixed batch and need every result before the next step, spawn the whole batch first and use one `wait_agent(..., wait_for_all=true)` call with a long timeout instead of a short wait loop.
|
||||
* Sub-agents have access to the same set of tools as you do so you must tell them if they are allowed to spawn sub-agents themselves or not.
|
||||
|
||||
@@ -30,6 +30,8 @@ use core_test_support::wait_for_event;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use std::fs;
|
||||
use std::process::Command;
|
||||
|
||||
const SEARCH_TOOL_DESCRIPTION_SNIPPETS: [&str; 2] = [
|
||||
"You have access to all the tools of the following apps/connectors",
|
||||
@@ -86,6 +88,15 @@ fn tool_search_output_tools(request: &ResponsesRequest, call_id: &str) -> Vec<Va
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn json_rpc_method(request: &wiremock::Request) -> Option<String> {
|
||||
request
|
||||
.body_json::<Value>()
|
||||
.ok()?
|
||||
.get("method")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
}
|
||||
|
||||
fn configure_apps(config: &mut Config, apps_base_url: &str) {
|
||||
config
|
||||
.features
|
||||
@@ -499,5 +510,195 @@ async fn tool_search_returns_deferred_tools_without_follow_up_tool_injection() -
|
||||
"post-tool follow-up should still rely on tool_search_output history, not tool injection: {third_request_tools:?}"
|
||||
);
|
||||
|
||||
let mcp_requests = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("failed to fetch recorded requests");
|
||||
let tools_list_request = mcp_requests
|
||||
.iter()
|
||||
.find(|request| json_rpc_method(request).as_deref() == Some("tools/list"))
|
||||
.expect("tools/list MCP request");
|
||||
assert!(
|
||||
tools_list_request
|
||||
.headers
|
||||
.get("x-codex-turn-metadata")
|
||||
.is_none(),
|
||||
"tools/list should not include per-turn MCP headers"
|
||||
);
|
||||
|
||||
let tools_call_request = mcp_requests
|
||||
.iter()
|
||||
.find(|request| json_rpc_method(request).as_deref() == Some("tools/call"))
|
||||
.expect("tools/call MCP request");
|
||||
let session_id_header = tools_call_request
|
||||
.headers
|
||||
.get("session_id")
|
||||
.expect("tools/call session_id header");
|
||||
let request_id_header = tools_call_request
|
||||
.headers
|
||||
.get("x-client-request-id")
|
||||
.expect("tools/call x-client-request-id header");
|
||||
let turn_metadata_header = tools_call_request
|
||||
.headers
|
||||
.get("x-codex-turn-metadata")
|
||||
.expect("tools/call turn metadata header");
|
||||
assert_eq!(
|
||||
session_id_header
|
||||
.to_str()
|
||||
.expect("session_id header to be utf8"),
|
||||
request_id_header
|
||||
.to_str()
|
||||
.expect("x-client-request-id header to be utf8")
|
||||
);
|
||||
assert!(
|
||||
turn_metadata_header
|
||||
.to_str()
|
||||
.expect("turn metadata header to be utf8")
|
||||
.contains("\"turn_id\""),
|
||||
"expected turn metadata header to contain serialized turn metadata"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn apps_mcp_tool_call_uses_enriched_turn_metadata_header() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let apps_server = AppsTestServer::mount_searchable(&server).await?;
|
||||
let call_id = "tool-search-git-metadata";
|
||||
let mock = mount_sse_sequence(
|
||||
&server,
|
||||
vec![
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_tool_search_call(
|
||||
call_id,
|
||||
&json!({
|
||||
"query": "create calendar event",
|
||||
"limit": 1,
|
||||
}),
|
||||
),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_response_created("resp-2"),
|
||||
json!({
|
||||
"type": "response.output_item.done",
|
||||
"item": {
|
||||
"type": "function_call",
|
||||
"call_id": "calendar-call-git-metadata",
|
||||
"name": SEARCH_CALENDAR_CREATE_TOOL,
|
||||
"namespace": SEARCH_CALENDAR_NAMESPACE,
|
||||
"arguments": serde_json::to_string(&json!({
|
||||
"title": "Lunch",
|
||||
"starts_at": "2026-03-10T12:00:00Z"
|
||||
})).expect("serialize calendar args")
|
||||
}
|
||||
}),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_response_created("resp-3"),
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-3"),
|
||||
]),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut builder = configured_builder(apps_server.chatgpt_base_url.clone());
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let cwd = test.cwd_path().to_path_buf();
|
||||
assert!(
|
||||
Command::new("git")
|
||||
.arg("init")
|
||||
.current_dir(&cwd)
|
||||
.status()?
|
||||
.success()
|
||||
);
|
||||
assert!(
|
||||
Command::new("git")
|
||||
.args(["config", "user.name", "Codex Test"])
|
||||
.current_dir(&cwd)
|
||||
.status()?
|
||||
.success()
|
||||
);
|
||||
assert!(
|
||||
Command::new("git")
|
||||
.args(["config", "user.email", "codex@example.com"])
|
||||
.current_dir(&cwd)
|
||||
.status()?
|
||||
.success()
|
||||
);
|
||||
assert!(
|
||||
Command::new("git")
|
||||
.args(["remote", "add", "origin", "https://example.test/repo.git"])
|
||||
.current_dir(&cwd)
|
||||
.status()?
|
||||
.success()
|
||||
);
|
||||
for idx in 0..400 {
|
||||
fs::write(
|
||||
cwd.join(format!("file-{idx:04}.txt")),
|
||||
format!("fixture file {idx}\n"),
|
||||
)?;
|
||||
}
|
||||
assert!(
|
||||
Command::new("git")
|
||||
.args(["add", "."])
|
||||
.current_dir(&cwd)
|
||||
.status()?
|
||||
.success()
|
||||
);
|
||||
assert!(
|
||||
Command::new("git")
|
||||
.args(["commit", "-m", "init"])
|
||||
.current_dir(&cwd)
|
||||
.status()?
|
||||
.success()
|
||||
);
|
||||
|
||||
test.codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![UserInput::Text {
|
||||
text: "Find the calendar create tool".to_string(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&test.codex, |event| {
|
||||
matches!(event, EventMsg::TurnComplete(_))
|
||||
})
|
||||
.await;
|
||||
|
||||
let _requests = mock.requests();
|
||||
let mcp_requests = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("failed to fetch recorded requests");
|
||||
let tools_call_request = mcp_requests
|
||||
.iter()
|
||||
.find(|request| json_rpc_method(request).as_deref() == Some("tools/call"))
|
||||
.expect("tools/call MCP request");
|
||||
let turn_metadata_header = tools_call_request
|
||||
.headers
|
||||
.get("x-codex-turn-metadata")
|
||||
.expect("tools/call turn metadata header")
|
||||
.to_str()
|
||||
.expect("turn metadata header to be utf8");
|
||||
let parsed: Value = serde_json::from_str(turn_metadata_header)?;
|
||||
assert!(
|
||||
parsed
|
||||
.get("workspaces")
|
||||
.and_then(Value::as_object)
|
||||
.is_some_and(|workspaces| !workspaces.is_empty()),
|
||||
"expected enriched MCP turn metadata header with workspace git metadata, got {parsed:#?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ use std::io;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
@@ -22,6 +23,7 @@ use reqwest::header::HeaderMap;
|
||||
use reqwest::header::WWW_AUTHENTICATE;
|
||||
use rmcp::model::CallToolRequestParams;
|
||||
use rmcp::model::CallToolResult;
|
||||
use rmcp::model::ClientJsonRpcMessage;
|
||||
use rmcp::model::ClientNotification;
|
||||
use rmcp::model::ClientRequest;
|
||||
use rmcp::model::CreateElicitationRequestParams;
|
||||
@@ -83,14 +85,45 @@ const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id";
|
||||
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
|
||||
const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192;
|
||||
|
||||
fn message_uses_request_scoped_headers(message: &ClientJsonRpcMessage) -> bool {
|
||||
matches!(
|
||||
message,
|
||||
ClientJsonRpcMessage::Request(request)
|
||||
if request.request.method() == "tools/call"
|
||||
)
|
||||
}
|
||||
|
||||
fn apply_request_scoped_headers(
|
||||
mut request: reqwest::RequestBuilder,
|
||||
request_headers_state: &Arc<StdMutex<Option<HeaderMap>>>,
|
||||
) -> reqwest::RequestBuilder {
|
||||
let extra_headers = request_headers_state
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.clone();
|
||||
if let Some(extra_headers) = extra_headers {
|
||||
for (name, value) in &extra_headers {
|
||||
request = request.header(name, value.clone());
|
||||
}
|
||||
}
|
||||
request
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct StreamableHttpResponseClient {
|
||||
inner: reqwest::Client,
|
||||
request_headers_state: Arc<StdMutex<Option<HeaderMap>>>,
|
||||
}
|
||||
|
||||
impl StreamableHttpResponseClient {
|
||||
fn new(inner: reqwest::Client) -> Self {
|
||||
Self { inner }
|
||||
fn new(
|
||||
inner: reqwest::Client,
|
||||
request_headers_state: Arc<StdMutex<Option<HeaderMap>>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
request_headers_state,
|
||||
}
|
||||
}
|
||||
|
||||
fn reqwest_error(
|
||||
@@ -133,6 +166,9 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
|
||||
if let Some(session_id_value) = session_id.as_ref() {
|
||||
request = request.header(HEADER_SESSION_ID, session_id_value.as_ref());
|
||||
}
|
||||
if message_uses_request_scoped_headers(&message) {
|
||||
request = apply_request_scoped_headers(request, &self.request_headers_state);
|
||||
}
|
||||
|
||||
let response = request
|
||||
.json(&message)
|
||||
@@ -472,6 +508,7 @@ pub struct RmcpClient {
|
||||
transport_recipe: TransportRecipe,
|
||||
initialize_context: Mutex<Option<InitializeContext>>,
|
||||
session_recovery_lock: Mutex<()>,
|
||||
request_headers: Option<Arc<StdMutex<Option<HeaderMap>>>>,
|
||||
}
|
||||
|
||||
impl RmcpClient {
|
||||
@@ -489,9 +526,10 @@ impl RmcpClient {
|
||||
env_vars: env_vars.to_vec(),
|
||||
cwd,
|
||||
};
|
||||
let transport = Self::create_pending_transport(&transport_recipe)
|
||||
.await
|
||||
.map_err(io::Error::other)?;
|
||||
let transport =
|
||||
Self::create_pending_transport(&transport_recipe, /*request_headers*/ None)
|
||||
.await
|
||||
.map_err(io::Error::other)?;
|
||||
|
||||
Ok(Self {
|
||||
state: Mutex::new(ClientState::Connecting {
|
||||
@@ -500,6 +538,7 @@ impl RmcpClient {
|
||||
transport_recipe,
|
||||
initialize_context: Mutex::new(None),
|
||||
session_recovery_lock: Mutex::new(()),
|
||||
request_headers: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -511,6 +550,7 @@ impl RmcpClient {
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
request_headers: Arc<StdMutex<Option<HeaderMap>>>,
|
||||
) -> Result<Self> {
|
||||
let transport_recipe = TransportRecipe::StreamableHttp {
|
||||
server_name: server_name.to_string(),
|
||||
@@ -520,7 +560,9 @@ impl RmcpClient {
|
||||
env_http_headers,
|
||||
store_mode,
|
||||
};
|
||||
let transport = Self::create_pending_transport(&transport_recipe).await?;
|
||||
let transport =
|
||||
Self::create_pending_transport(&transport_recipe, Some(Arc::clone(&request_headers)))
|
||||
.await?;
|
||||
Ok(Self {
|
||||
state: Mutex::new(ClientState::Connecting {
|
||||
transport: Some(transport),
|
||||
@@ -528,6 +570,7 @@ impl RmcpClient {
|
||||
transport_recipe,
|
||||
initialize_context: Mutex::new(None),
|
||||
session_recovery_lock: Mutex::new(()),
|
||||
request_headers: Some(request_headers),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -830,6 +873,7 @@ impl RmcpClient {
|
||||
|
||||
async fn create_pending_transport(
|
||||
transport_recipe: &TransportRecipe,
|
||||
request_headers: Option<Arc<StdMutex<Option<HeaderMap>>>>,
|
||||
) -> Result<PendingTransport> {
|
||||
match transport_recipe {
|
||||
TransportRecipe::Stdio {
|
||||
@@ -946,7 +990,12 @@ impl RmcpClient {
|
||||
.auth_header(access_token);
|
||||
let http_client = build_http_client(&default_headers)?;
|
||||
let transport = StreamableHttpClientTransport::with_client(
|
||||
StreamableHttpResponseClient::new(http_client),
|
||||
StreamableHttpResponseClient::new(
|
||||
http_client,
|
||||
request_headers
|
||||
.clone()
|
||||
.unwrap_or_else(|| Arc::new(StdMutex::new(None))),
|
||||
),
|
||||
http_config,
|
||||
);
|
||||
Ok(PendingTransport::StreamableHttp { transport })
|
||||
@@ -963,7 +1012,12 @@ impl RmcpClient {
|
||||
let http_client = build_http_client(&default_headers)?;
|
||||
|
||||
let transport = StreamableHttpClientTransport::with_client(
|
||||
StreamableHttpResponseClient::new(http_client),
|
||||
StreamableHttpResponseClient::new(
|
||||
http_client,
|
||||
request_headers
|
||||
.clone()
|
||||
.unwrap_or_else(|| Arc::new(StdMutex::new(None))),
|
||||
),
|
||||
http_config,
|
||||
);
|
||||
Ok(PendingTransport::StreamableHttp { transport })
|
||||
@@ -1111,7 +1165,9 @@ impl RmcpClient {
|
||||
.await
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow!("MCP client cannot recover before initialize succeeds"))?;
|
||||
let pending_transport = Self::create_pending_transport(&self.transport_recipe).await?;
|
||||
let pending_transport =
|
||||
Self::create_pending_transport(&self.transport_recipe, self.request_headers.clone())
|
||||
.await?;
|
||||
let (service, oauth_persistor, process_group_guard) = Self::connect_pending_transport(
|
||||
pending_transport,
|
||||
initialize_context.handler,
|
||||
@@ -1166,7 +1222,10 @@ async fn create_oauth_transport_and_runtime(
|
||||
}
|
||||
};
|
||||
|
||||
let auth_client = AuthClient::new(StreamableHttpResponseClient::new(http_client), manager);
|
||||
let auth_client = AuthClient::new(
|
||||
StreamableHttpResponseClient::new(http_client, Arc::new(StdMutex::new(None))),
|
||||
manager,
|
||||
);
|
||||
let auth_manager = auth_client.auth_manager.clone();
|
||||
|
||||
let transport = StreamableHttpClientTransport::with_client(
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use std::net::TcpListener;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
@@ -77,6 +79,7 @@ async fn create_client(base_url: &str) -> anyhow::Result<RmcpClient> {
|
||||
None,
|
||||
None,
|
||||
OAuthCredentialsStoreMode::File,
|
||||
Arc::new(StdMutex::new(None)),
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user