Compare commits

...

10 Commits

Author SHA1 Message Date
ychhabria
aedd076b8c Raise standalone file search limit to 50 2026-03-23 23:18:39 -07:00
ychhabria
55be8ca802 Refresh codex apps MCP headers at tool call time 2026-03-18 18:01:31 -07:00
nicholasclark-openai
1a2536d12b codex: fix MCP header CI regressions
Add the new request_headers test fixture state, update the streamable HTTP
recovery test to pass the new client constructor argument, and satisfy the
argument comment lint for clearing turn-scoped MCP request headers.

Co-authored-by: Codex <noreply@openai.com>
2026-03-18 17:12:13 -07:00
nicholasclark-openai
732d7ac81f codex: scope MCP request headers to turn lifecycle
Set Codex Apps MCP request headers once per active turn and clear them on turn end,
instead of threading request-scoped headers through every tool call. Keep RMCP
header injection limited to streamable HTTP tools/call requests so list/init
paths stay unchanged and concurrent tool calls on the same client are not
serialized.

Co-authored-by: Codex <noreply@openai.com>
2026-03-18 16:42:10 -07:00
nicholasclark-openai
1f3ec4172c Merge branch 'main' into nicholasclark/tool-call-task-headers 2026-03-18 12:11:32 -07:00
nicholasclark-openai
52fdfbcfb8 Merge branch 'main' into nicholasclark/tool-call-task-headers 2026-03-18 10:44:04 -07:00
nicholasclark-openai
2321532dec Merge branch 'main' into nicholasclark/tool-call-task-headers 2026-03-18 09:41:54 -07:00
nicholasclark-openai
de98643403 codex: fix CI failure on PR #15011
Co-authored-by: Codex <noreply@openai.com>
2026-03-18 09:41:06 -07:00
nicholasclark-openai
fb27c20581 codex: fix CI failure on PR #15011
Co-authored-by: Codex <noreply@openai.com>
2026-03-18 09:12:10 -07:00
nicholasclark-openai
bccce0f2d8 Forward tool call task headers to MCP HTTP requests
Co-authored-by: Codex <noreply@openai.com>
2026-03-17 18:14:32 -07:00
11 changed files with 515 additions and 12 deletions

View File

@@ -125,6 +125,8 @@ use futures::future::BoxFuture;
use futures::future::Shared;
use futures::prelude::*;
use futures::stream::FuturesOrdered;
use reqwest::header::HeaderMap;
use reqwest::header::HeaderValue;
use rmcp::model::ListResourceTemplatesResult;
use rmcp::model::ListResourcesResult;
use rmcp::model::PaginatedRequestParams;
@@ -3943,6 +3945,12 @@ impl Session {
arguments: Option<serde_json::Value>,
meta: Option<serde_json::Value>,
) -> anyhow::Result<CallToolResult> {
if server == CODEX_APPS_MCP_SERVER_NAME
&& let Some((turn_context, _)) = self.active_turn_context_and_cancellation_token().await
{
self.sync_mcp_request_headers_for_turn(turn_context.as_ref())
.await;
}
self.services
.mcp_connection_manager
.read()
@@ -3951,6 +3959,45 @@ impl Session {
.await
}
pub(crate) async fn sync_mcp_request_headers_for_turn(&self, turn_context: &TurnContext) {
let mut request_headers = HeaderMap::new();
let session_id = self.conversation_id.to_string();
if let Ok(value) = HeaderValue::from_str(&session_id) {
request_headers.insert("session_id", value.clone());
request_headers.insert("x-client-request-id", value);
}
if let Some(turn_metadata) = turn_context.turn_metadata_state.current_header_value()
&& let Ok(value) = HeaderValue::from_str(&turn_metadata)
{
request_headers.insert(crate::X_CODEX_TURN_METADATA_HEADER, value);
}
let request_headers = if request_headers.is_empty() {
None
} else {
Some(request_headers)
};
self.services
.mcp_connection_manager
.read()
.await
.set_request_headers_for_server(
crate::mcp::CODEX_APPS_MCP_SERVER_NAME,
request_headers,
);
}
pub(crate) async fn clear_mcp_request_headers(&self) {
self.services
.mcp_connection_manager
.read()
.await
.set_request_headers_for_server(
crate::mcp::CODEX_APPS_MCP_SERVER_NAME,
/*request_headers*/ None,
);
}
pub(crate) async fn parse_mcp_tool_name(
&self,
name: &str,

View File

@@ -26,6 +26,8 @@ use codex_protocol::protocol::ReadOnlyAccess;
use codex_protocol::protocol::SandboxPolicy;
use codex_protocol::request_permissions::PermissionGrantScope;
use codex_protocol::request_permissions::RequestPermissionProfile;
use codex_protocol::user_input::UserInput;
use reqwest::header::HeaderValue;
use tracing::Span;
use crate::protocol::CompactedItem;
@@ -56,6 +58,7 @@ use crate::tools::handlers::UnifiedExecHandler;
use crate::tools::registry::ToolHandler;
use crate::tools::router::ToolCallSource;
use crate::turn_diff_tracker::TurnDiffTracker;
use async_trait::async_trait;
use codex_app_server_protocol::AppInfo;
use codex_execpolicy::Decision;
use codex_execpolicy::NetworkRuleProtocol;
@@ -78,7 +81,10 @@ use opentelemetry::trace::TracerProvider as _;
use opentelemetry_sdk::trace::SdkTracerProvider;
use std::path::Path;
use std::time::Duration;
use tokio::sync::Notify;
use tokio::time::sleep;
use tokio_util::sync::CancellationToken;
use tokio_util::task::AbortOnDropHandle;
use tracing_opentelemetry::OpenTelemetrySpanExt;
use tracing_subscriber::prelude::*;
@@ -2716,6 +2722,99 @@ async fn request_permissions_is_auto_denied_when_granular_policy_blocks_tool_req
);
}
struct NoopTask;
#[async_trait]
impl SessionTask for NoopTask {
fn kind(&self) -> TaskKind {
TaskKind::Regular
}
fn span_name(&self) -> &'static str {
"noop"
}
async fn run(
self: Arc<Self>,
_ctx: Arc<SessionTaskContext>,
_turn_context: Arc<TurnContext>,
_input: Vec<UserInput>,
_cancellation_token: CancellationToken,
) -> Option<String> {
None
}
}
#[tokio::test]
async fn call_tool_refreshes_codex_apps_request_headers_from_active_turn() {
let (session, turn_context) = make_session_and_context().await;
let session = Arc::new(session);
let turn_context = Arc::new(turn_context);
session
.services
.mcp_connection_manager
.write()
.await
.register_test_server_for_request_headers(CODEX_APPS_MCP_SERVER_NAME);
session
.sync_mcp_request_headers_for_turn(turn_context.as_ref())
.await;
let base_header = turn_context
.turn_metadata_state
.current_header_value()
.expect("base turn metadata header");
let updated_header = serde_json::json!({
"turn_id": turn_context.sub_id,
"sandbox": "test",
"workspaces": {
"/tmp/repo": {
"has_changes": true
}
}
})
.to_string();
turn_context
.turn_metadata_state
.set_enriched_header_for_tests(Some(updated_header.clone()));
let mut active_turn = ActiveTurn::default();
let handle = tokio::spawn(async {});
active_turn.add_task(crate::state::RunningTask {
done: Arc::new(Notify::new()),
kind: TaskKind::Regular,
task: Arc::new(NoopTask),
cancellation_token: CancellationToken::new(),
handle: Arc::new(AbortOnDropHandle::new(handle)),
turn_context: Arc::clone(&turn_context),
_timer: None,
});
*session.active_turn.lock().await = Some(active_turn);
let _err = session
.call_tool(CODEX_APPS_MCP_SERVER_NAME, "echo", None, None)
.await
.expect_err("test server is not initialized");
let headers = session
.services
.mcp_connection_manager
.read()
.await
.request_headers_for_server(CODEX_APPS_MCP_SERVER_NAME)
.expect("request headers should be tracked for codex apps");
assert_eq!(
headers.get(crate::X_CODEX_TURN_METADATA_HEADER),
Some(&HeaderValue::from_str(&updated_header).expect("valid enriched header")),
);
assert_ne!(
headers.get(crate::X_CODEX_TURN_METADATA_HEADER),
Some(&HeaderValue::from_str(&base_header).expect("valid base header")),
);
}
#[tokio::test]
async fn submit_with_id_captures_current_span_trace_context() {
let (session, _turn_context) = make_session_and_context().await;

View File

@@ -423,6 +423,7 @@ impl ManagedClient {
#[derive(Clone)]
struct AsyncManagedClient {
client: Shared<BoxFuture<'static, Result<ManagedClient, StartupOutcomeError>>>,
request_headers: Arc<StdMutex<Option<reqwest::header::HeaderMap>>>,
startup_snapshot: Option<Vec<ToolInfo>>,
startup_complete: Arc<AtomicBool>,
tool_plugin_provenance: Arc<ToolPluginProvenance>,
@@ -448,17 +449,26 @@ impl AsyncManagedClient {
codex_apps_tools_cache_context.as_ref(),
)
.map(|tools| filter_tools(tools, &tool_filter));
let request_headers = Arc::new(StdMutex::new(None));
let startup_tool_filter = tool_filter;
let startup_complete = Arc::new(AtomicBool::new(false));
let startup_complete_for_fut = Arc::clone(&startup_complete);
let request_headers_for_client = Arc::clone(&request_headers);
let fut = async move {
let outcome = async {
if let Err(error) = validate_mcp_server_name(&server_name) {
return Err(error.into());
}
let client =
Arc::new(make_rmcp_client(&server_name, config.transport, store_mode).await?);
let client = Arc::new(
make_rmcp_client(
&server_name,
config.transport,
store_mode,
request_headers_for_client,
)
.await?,
);
match start_server_task(
server_name,
client,
@@ -495,6 +505,7 @@ impl AsyncManagedClient {
Self {
client,
request_headers,
startup_snapshot,
startup_complete,
tool_plugin_provenance,
@@ -576,6 +587,14 @@ impl AsyncManagedClient {
let managed = self.client().await?;
managed.notify_sandbox_state_change(sandbox_state).await
}
fn set_request_headers(&self, request_headers: Option<reqwest::header::HeaderMap>) {
let mut guard = self
.request_headers
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
*guard = request_headers;
}
}
pub const MCP_SANDBOX_STATE_CAPABILITY: &str = "codex/sandbox-state";
@@ -617,6 +636,40 @@ impl McpConnectionManager {
Self::new_uninitialized(approval_policy)
}
#[cfg(test)]
pub(crate) fn register_test_server_for_request_headers(&mut self, server_name: &str) {
let failed_client = futures::future::ready::<Result<ManagedClient, StartupOutcomeError>>(
Err(StartupOutcomeError::Failed {
error: "test request headers stub".to_string(),
}),
)
.boxed()
.shared();
self.clients.insert(
server_name.to_string(),
AsyncManagedClient {
client: failed_client,
request_headers: Arc::new(StdMutex::new(None)),
startup_snapshot: None,
startup_complete: Arc::new(AtomicBool::new(true)),
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
},
);
}
#[cfg(test)]
pub(crate) fn request_headers_for_server(
&self,
server_name: &str,
) -> Option<reqwest::header::HeaderMap> {
let client = self.clients.get(server_name)?;
client
.request_headers
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.clone()
}
pub(crate) fn has_servers(&self) -> bool {
!self.clients.is_empty()
}
@@ -1046,6 +1099,16 @@ impl McpConnectionManager {
})
}
pub(crate) fn set_request_headers_for_server(
&self,
server_name: &str,
request_headers: Option<reqwest::header::HeaderMap>,
) {
if let Some(client) = self.clients.get(server_name) {
client.set_request_headers(request_headers);
}
}
/// List resources from the specified server.
pub async fn list_resources(
&self,
@@ -1429,6 +1492,7 @@ async fn make_rmcp_client(
server_name: &str,
transport: McpServerTransportConfig,
store_mode: OAuthCredentialsStoreMode,
request_headers: Arc<StdMutex<Option<reqwest::header::HeaderMap>>>,
) -> Result<RmcpClient, StartupOutcomeError> {
match transport {
McpServerTransportConfig::Stdio {
@@ -1462,6 +1526,7 @@ async fn make_rmcp_client(
http_headers,
env_http_headers,
store_mode,
request_headers,
)
.await
.map_err(StartupOutcomeError::from)

View File

@@ -4,6 +4,7 @@ use codex_protocol::protocol::McpAuthStatus;
use rmcp::model::JsonObject;
use std::collections::HashSet;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use tempfile::tempdir;
fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo {
@@ -413,6 +414,7 @@ async fn list_all_tools_uses_startup_snapshot_while_client_is_pending() {
CODEX_APPS_MCP_SERVER_NAME.to_string(),
AsyncManagedClient {
client: pending_client,
request_headers: Arc::new(StdMutex::new(None)),
startup_snapshot: Some(startup_tools),
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
@@ -438,6 +440,7 @@ async fn list_all_tools_blocks_while_client_is_pending_without_startup_snapshot(
CODEX_APPS_MCP_SERVER_NAME.to_string(),
AsyncManagedClient {
client: pending_client,
request_headers: Arc::new(StdMutex::new(None)),
startup_snapshot: None,
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
@@ -460,6 +463,7 @@ async fn list_all_tools_does_not_block_when_startup_snapshot_cache_hit_is_empty(
CODEX_APPS_MCP_SERVER_NAME.to_string(),
AsyncManagedClient {
client: pending_client,
request_headers: Arc::new(StdMutex::new(None)),
startup_snapshot: Some(Vec::new()),
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
@@ -492,6 +496,7 @@ async fn list_all_tools_uses_startup_snapshot_when_client_startup_fails() {
CODEX_APPS_MCP_SERVER_NAME.to_string(),
AsyncManagedClient {
client: failed_client,
request_headers: Arc::new(StdMutex::new(None)),
startup_snapshot: Some(startup_tools),
startup_complete,
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),

View File

@@ -153,6 +153,8 @@ impl Session {
) {
self.abort_all_tasks(TurnAbortReason::Replaced).await;
self.clear_connector_selection().await;
self.sync_mcp_request_headers_for_turn(turn_context.as_ref())
.await;
let task: Arc<dyn SessionTask> = Arc::new(task);
let task_kind = task.kind();
@@ -233,6 +235,7 @@ impl Session {
// in-flight approval wait can surface as a model-visible rejection before TurnAborted.
active_turn.clear_pending().await;
}
self.clear_mcp_request_headers().await;
}
pub async fn on_task_finished(
@@ -262,6 +265,9 @@ impl Session {
*active = None;
}
drop(active);
if should_clear_active_turn {
self.clear_mcp_request_headers().await;
}
if !pending_input.is_empty() {
for pending_input_item in pending_input {
match inspect_pending_input(self, &turn_context, pending_input_item).await {

View File

@@ -217,6 +217,14 @@ impl TurnMetadataState {
}
}
#[cfg(test)]
pub(crate) fn set_enriched_header_for_tests(&self, header: Option<String>) {
*self
.enriched_header
.write()
.unwrap_or_else(std::sync::PoisonError::into_inner) = header;
}
async fn fetch_workspace_git_metadata(&self) -> WorkspaceGitMetadata {
let (latest_git_commit_hash, associated_remote_urls, has_changes) = tokio::join!(
get_head_commit_hash(&self.cwd),

View File

@@ -30,6 +30,8 @@ use core_test_support::wait_for_event;
use pretty_assertions::assert_eq;
use serde_json::Value;
use serde_json::json;
use std::fs;
use std::process::Command;
const SEARCH_TOOL_DESCRIPTION_SNIPPETS: [&str; 2] = [
"You have access to all the tools of the following apps/connectors",
@@ -86,6 +88,15 @@ fn tool_search_output_tools(request: &ResponsesRequest, call_id: &str) -> Vec<Va
.unwrap_or_default()
}
fn json_rpc_method(request: &wiremock::Request) -> Option<String> {
request
.body_json::<Value>()
.ok()?
.get("method")
.and_then(Value::as_str)
.map(str::to_string)
}
fn configure_apps(config: &mut Config, apps_base_url: &str) {
config
.features
@@ -499,5 +510,195 @@ async fn tool_search_returns_deferred_tools_without_follow_up_tool_injection() -
"post-tool follow-up should still rely on tool_search_output history, not tool injection: {third_request_tools:?}"
);
let mcp_requests = server
.received_requests()
.await
.expect("failed to fetch recorded requests");
let tools_list_request = mcp_requests
.iter()
.find(|request| json_rpc_method(request).as_deref() == Some("tools/list"))
.expect("tools/list MCP request");
assert!(
tools_list_request
.headers
.get("x-codex-turn-metadata")
.is_none(),
"tools/list should not include per-turn MCP headers"
);
let tools_call_request = mcp_requests
.iter()
.find(|request| json_rpc_method(request).as_deref() == Some("tools/call"))
.expect("tools/call MCP request");
let session_id_header = tools_call_request
.headers
.get("session_id")
.expect("tools/call session_id header");
let request_id_header = tools_call_request
.headers
.get("x-client-request-id")
.expect("tools/call x-client-request-id header");
let turn_metadata_header = tools_call_request
.headers
.get("x-codex-turn-metadata")
.expect("tools/call turn metadata header");
assert_eq!(
session_id_header
.to_str()
.expect("session_id header to be utf8"),
request_id_header
.to_str()
.expect("x-client-request-id header to be utf8")
);
assert!(
turn_metadata_header
.to_str()
.expect("turn metadata header to be utf8")
.contains("\"turn_id\""),
"expected turn metadata header to contain serialized turn metadata"
);
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn apps_mcp_tool_call_uses_enriched_turn_metadata_header() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = start_mock_server().await;
let apps_server = AppsTestServer::mount_searchable(&server).await?;
let call_id = "tool-search-git-metadata";
let mock = mount_sse_sequence(
&server,
vec![
sse(vec![
ev_response_created("resp-1"),
ev_tool_search_call(
call_id,
&json!({
"query": "create calendar event",
"limit": 1,
}),
),
ev_completed("resp-1"),
]),
sse(vec![
ev_response_created("resp-2"),
json!({
"type": "response.output_item.done",
"item": {
"type": "function_call",
"call_id": "calendar-call-git-metadata",
"name": SEARCH_CALENDAR_CREATE_TOOL,
"namespace": SEARCH_CALENDAR_NAMESPACE,
"arguments": serde_json::to_string(&json!({
"title": "Lunch",
"starts_at": "2026-03-10T12:00:00Z"
})).expect("serialize calendar args")
}
}),
ev_completed("resp-2"),
]),
sse(vec![
ev_response_created("resp-3"),
ev_assistant_message("msg-1", "done"),
ev_completed("resp-3"),
]),
],
)
.await;
let mut builder = configured_builder(apps_server.chatgpt_base_url.clone());
let test = builder.build(&server).await?;
let cwd = test.cwd_path().to_path_buf();
assert!(
Command::new("git")
.arg("init")
.current_dir(&cwd)
.status()?
.success()
);
assert!(
Command::new("git")
.args(["config", "user.name", "Codex Test"])
.current_dir(&cwd)
.status()?
.success()
);
assert!(
Command::new("git")
.args(["config", "user.email", "codex@example.com"])
.current_dir(&cwd)
.status()?
.success()
);
assert!(
Command::new("git")
.args(["remote", "add", "origin", "https://example.test/repo.git"])
.current_dir(&cwd)
.status()?
.success()
);
for idx in 0..400 {
fs::write(
cwd.join(format!("file-{idx:04}.txt")),
format!("fixture file {idx}\n"),
)?;
}
assert!(
Command::new("git")
.args(["add", "."])
.current_dir(&cwd)
.status()?
.success()
);
assert!(
Command::new("git")
.args(["commit", "-m", "init"])
.current_dir(&cwd)
.status()?
.success()
);
test.codex
.submit(Op::UserInput {
items: vec![UserInput::Text {
text: "Find the calendar create tool".to_string(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
})
.await?;
wait_for_event(&test.codex, |event| {
matches!(event, EventMsg::TurnComplete(_))
})
.await;
let _requests = mock.requests();
let mcp_requests = server
.received_requests()
.await
.expect("failed to fetch recorded requests");
let tools_call_request = mcp_requests
.iter()
.find(|request| json_rpc_method(request).as_deref() == Some("tools/call"))
.expect("tools/call MCP request");
let turn_metadata_header = tools_call_request
.headers
.get("x-codex-turn-metadata")
.expect("tools/call turn metadata header")
.to_str()
.expect("turn metadata header to be utf8");
let parsed: Value = serde_json::from_str(turn_metadata_header)?;
assert!(
parsed
.get("workspaces")
.and_then(Value::as_object)
.is_some_and(|workspaces| !workspaces.is_empty()),
"expected enriched MCP turn metadata header with workspace git metadata, got {parsed:#?}"
);
Ok(())
}

View File

@@ -5,6 +5,7 @@ use std::io;
use std::path::PathBuf;
use std::process::Stdio;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::time::Duration;
use anyhow::Result;
@@ -22,6 +23,7 @@ use reqwest::header::HeaderMap;
use reqwest::header::WWW_AUTHENTICATE;
use rmcp::model::CallToolRequestParams;
use rmcp::model::CallToolResult;
use rmcp::model::ClientJsonRpcMessage;
use rmcp::model::ClientNotification;
use rmcp::model::ClientRequest;
use rmcp::model::CreateElicitationRequestParams;
@@ -83,14 +85,45 @@ const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id";
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192;
fn message_uses_request_scoped_headers(message: &ClientJsonRpcMessage) -> bool {
matches!(
message,
ClientJsonRpcMessage::Request(request)
if request.request.method() == "tools/call"
)
}
fn apply_request_scoped_headers(
mut request: reqwest::RequestBuilder,
request_headers_state: &Arc<StdMutex<Option<HeaderMap>>>,
) -> reqwest::RequestBuilder {
let extra_headers = request_headers_state
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.clone();
if let Some(extra_headers) = extra_headers {
for (name, value) in &extra_headers {
request = request.header(name, value.clone());
}
}
request
}
#[derive(Clone)]
struct StreamableHttpResponseClient {
inner: reqwest::Client,
request_headers_state: Arc<StdMutex<Option<HeaderMap>>>,
}
impl StreamableHttpResponseClient {
fn new(inner: reqwest::Client) -> Self {
Self { inner }
fn new(
inner: reqwest::Client,
request_headers_state: Arc<StdMutex<Option<HeaderMap>>>,
) -> Self {
Self {
inner,
request_headers_state,
}
}
fn reqwest_error(
@@ -133,6 +166,9 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
if let Some(session_id_value) = session_id.as_ref() {
request = request.header(HEADER_SESSION_ID, session_id_value.as_ref());
}
if message_uses_request_scoped_headers(&message) {
request = apply_request_scoped_headers(request, &self.request_headers_state);
}
let response = request
.json(&message)
@@ -472,6 +508,7 @@ pub struct RmcpClient {
transport_recipe: TransportRecipe,
initialize_context: Mutex<Option<InitializeContext>>,
session_recovery_lock: Mutex<()>,
request_headers: Option<Arc<StdMutex<Option<HeaderMap>>>>,
}
impl RmcpClient {
@@ -489,9 +526,10 @@ impl RmcpClient {
env_vars: env_vars.to_vec(),
cwd,
};
let transport = Self::create_pending_transport(&transport_recipe)
.await
.map_err(io::Error::other)?;
let transport =
Self::create_pending_transport(&transport_recipe, /*request_headers*/ None)
.await
.map_err(io::Error::other)?;
Ok(Self {
state: Mutex::new(ClientState::Connecting {
@@ -500,6 +538,7 @@ impl RmcpClient {
transport_recipe,
initialize_context: Mutex::new(None),
session_recovery_lock: Mutex::new(()),
request_headers: None,
})
}
@@ -511,6 +550,7 @@ impl RmcpClient {
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
store_mode: OAuthCredentialsStoreMode,
request_headers: Arc<StdMutex<Option<HeaderMap>>>,
) -> Result<Self> {
let transport_recipe = TransportRecipe::StreamableHttp {
server_name: server_name.to_string(),
@@ -520,7 +560,9 @@ impl RmcpClient {
env_http_headers,
store_mode,
};
let transport = Self::create_pending_transport(&transport_recipe).await?;
let transport =
Self::create_pending_transport(&transport_recipe, Some(Arc::clone(&request_headers)))
.await?;
Ok(Self {
state: Mutex::new(ClientState::Connecting {
transport: Some(transport),
@@ -528,6 +570,7 @@ impl RmcpClient {
transport_recipe,
initialize_context: Mutex::new(None),
session_recovery_lock: Mutex::new(()),
request_headers: Some(request_headers),
})
}
@@ -830,6 +873,7 @@ impl RmcpClient {
async fn create_pending_transport(
transport_recipe: &TransportRecipe,
request_headers: Option<Arc<StdMutex<Option<HeaderMap>>>>,
) -> Result<PendingTransport> {
match transport_recipe {
TransportRecipe::Stdio {
@@ -946,7 +990,12 @@ impl RmcpClient {
.auth_header(access_token);
let http_client = build_http_client(&default_headers)?;
let transport = StreamableHttpClientTransport::with_client(
StreamableHttpResponseClient::new(http_client),
StreamableHttpResponseClient::new(
http_client,
request_headers
.clone()
.unwrap_or_else(|| Arc::new(StdMutex::new(None))),
),
http_config,
);
Ok(PendingTransport::StreamableHttp { transport })
@@ -963,7 +1012,12 @@ impl RmcpClient {
let http_client = build_http_client(&default_headers)?;
let transport = StreamableHttpClientTransport::with_client(
StreamableHttpResponseClient::new(http_client),
StreamableHttpResponseClient::new(
http_client,
request_headers
.clone()
.unwrap_or_else(|| Arc::new(StdMutex::new(None))),
),
http_config,
);
Ok(PendingTransport::StreamableHttp { transport })
@@ -1111,7 +1165,9 @@ impl RmcpClient {
.await
.clone()
.ok_or_else(|| anyhow!("MCP client cannot recover before initialize succeeds"))?;
let pending_transport = Self::create_pending_transport(&self.transport_recipe).await?;
let pending_transport =
Self::create_pending_transport(&self.transport_recipe, self.request_headers.clone())
.await?;
let (service, oauth_persistor, process_group_guard) = Self::connect_pending_transport(
pending_transport,
initialize_context.handler,
@@ -1166,7 +1222,10 @@ async fn create_oauth_transport_and_runtime(
}
};
let auth_client = AuthClient::new(StreamableHttpResponseClient::new(http_client), manager);
let auth_client = AuthClient::new(
StreamableHttpResponseClient::new(http_client, Arc::new(StdMutex::new(None))),
manager,
);
let auth_manager = auth_client.auth_manager.clone();
let transport = StreamableHttpClientTransport::with_client(

View File

@@ -1,5 +1,7 @@
use std::net::TcpListener;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::time::Duration;
use std::time::Instant;
@@ -77,6 +79,7 @@ async fn create_client(base_url: &str) -> anyhow::Result<RmcpClient> {
None,
None,
OAuthCredentialsStoreMode::File,
Arc::new(StdMutex::new(None)),
)
.await?;

View File

@@ -6,6 +6,7 @@
//! on every keystroke, and drops the session when the query becomes empty.
use codex_file_search as file_search;
use std::num::NonZero;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
@@ -13,6 +14,8 @@ use std::sync::Mutex;
use crate::app_event::AppEvent;
use crate::app_event_sender::AppEventSender;
const FILE_SEARCH_LIMIT: usize = 50;
pub(crate) struct FileSearchManager {
state: Arc<Mutex<SearchState>>,
search_dir: PathBuf,
@@ -83,6 +86,8 @@ impl FileSearchManager {
let session = file_search::create_session(
vec![self.search_dir.clone()],
file_search::FileSearchOptions {
#[expect(clippy::unwrap_used)]
limit: NonZero::new(FILE_SEARCH_LIMIT).unwrap(),
compute_indices: true,
..Default::default()
},

View File

@@ -6,6 +6,7 @@
//! on every keystroke, and drops the session when the query becomes empty.
use codex_file_search as file_search;
use std::num::NonZero;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
@@ -13,6 +14,8 @@ use std::sync::Mutex;
use crate::app_event::AppEvent;
use crate::app_event_sender::AppEventSender;
const FILE_SEARCH_LIMIT: usize = 50;
pub(crate) struct FileSearchManager {
state: Arc<Mutex<SearchState>>,
search_dir: PathBuf,
@@ -83,6 +86,8 @@ impl FileSearchManager {
let session = file_search::create_session(
vec![self.search_dir.clone()],
file_search::FileSearchOptions {
#[expect(clippy::unwrap_used)]
limit: NonZero::new(FILE_SEARCH_LIMIT).unwrap(),
compute_indices: true,
..Default::default()
},