mirror of
https://github.com/openai/codex.git
synced 2026-04-24 06:35:50 +00:00
codex: scope MCP request headers to turn lifecycle
Set Codex Apps MCP request headers once per active turn and clear them on turn end, instead of threading request-scoped headers through every tool call. Keep RMCP header injection limited to streamable HTTP tools/call requests so list/init paths stay unchanged and concurrent tool calls on the same client are not serialized. Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
@@ -125,6 +125,8 @@ use futures::future::BoxFuture;
|
||||
use futures::future::Shared;
|
||||
use futures::prelude::*;
|
||||
use futures::stream::FuturesOrdered;
|
||||
use reqwest::header::HeaderMap;
|
||||
use reqwest::header::HeaderValue;
|
||||
use rmcp::model::ListResourceTemplatesResult;
|
||||
use rmcp::model::ListResourcesResult;
|
||||
use rmcp::model::PaginatedRequestParams;
|
||||
@@ -3942,16 +3944,51 @@ impl Session {
|
||||
tool: &str,
|
||||
arguments: Option<serde_json::Value>,
|
||||
meta: Option<serde_json::Value>,
|
||||
request_headers: Option<reqwest::header::HeaderMap>,
|
||||
) -> anyhow::Result<CallToolResult> {
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.call_tool(server, tool, arguments, meta, request_headers)
|
||||
.call_tool(server, tool, arguments, meta)
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn sync_mcp_request_headers_for_turn(&self, turn_context: &TurnContext) {
|
||||
let mut request_headers = HeaderMap::new();
|
||||
let session_id = self.conversation_id.to_string();
|
||||
if let Ok(value) = HeaderValue::from_str(&session_id) {
|
||||
request_headers.insert("session_id", value.clone());
|
||||
request_headers.insert("x-client-request-id", value);
|
||||
}
|
||||
if let Some(turn_metadata) = turn_context.turn_metadata_state.current_header_value()
|
||||
&& let Ok(value) = HeaderValue::from_str(&turn_metadata)
|
||||
{
|
||||
request_headers.insert(crate::X_CODEX_TURN_METADATA_HEADER, value);
|
||||
}
|
||||
|
||||
let request_headers = if request_headers.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(request_headers)
|
||||
};
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.set_request_headers_for_server(
|
||||
crate::mcp::CODEX_APPS_MCP_SERVER_NAME,
|
||||
request_headers,
|
||||
);
|
||||
}
|
||||
|
||||
pub(crate) async fn clear_mcp_request_headers(&self) {
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.set_request_headers_for_server(crate::mcp::CODEX_APPS_MCP_SERVER_NAME, None);
|
||||
}
|
||||
|
||||
pub(crate) async fn parse_mcp_tool_name(
|
||||
&self,
|
||||
name: &str,
|
||||
|
||||
@@ -670,7 +670,6 @@ async fn maybe_auto_review_mcp_request_user_input(
|
||||
parent_ctx.as_ref(),
|
||||
&invocation.server,
|
||||
&invocation.tool,
|
||||
/*request_headers*/ None,
|
||||
)
|
||||
.await;
|
||||
let review_cancel = cancel_token.child_token();
|
||||
|
||||
@@ -378,7 +378,7 @@ struct ManagedClient {
|
||||
}
|
||||
|
||||
impl ManagedClient {
|
||||
fn listed_tools(&self, _request_headers: Option<reqwest::header::HeaderMap>) -> Vec<ToolInfo> {
|
||||
fn listed_tools(&self) -> Vec<ToolInfo> {
|
||||
let total_start = Instant::now();
|
||||
if let Some(cache_context) = self.codex_apps_tools_cache_context.as_ref()
|
||||
&& let CachedCodexAppsToolsLoad::Hit(tools) =
|
||||
@@ -423,44 +423,12 @@ impl ManagedClient {
|
||||
#[derive(Clone)]
|
||||
struct AsyncManagedClient {
|
||||
client: Shared<BoxFuture<'static, Result<ManagedClient, StartupOutcomeError>>>,
|
||||
request_headers: Arc<StdMutex<Option<reqwest::header::HeaderMap>>>,
|
||||
startup_snapshot: Option<Vec<ToolInfo>>,
|
||||
startup_complete: Arc<AtomicBool>,
|
||||
startup_request_headers: Arc<StdMutex<Option<reqwest::header::HeaderMap>>>,
|
||||
tool_plugin_provenance: Arc<ToolPluginProvenance>,
|
||||
}
|
||||
|
||||
struct StartupRequestHeadersGuard {
|
||||
state: Arc<StdMutex<Option<reqwest::header::HeaderMap>>>,
|
||||
previous: Option<reqwest::header::HeaderMap>,
|
||||
}
|
||||
|
||||
impl StartupRequestHeadersGuard {
|
||||
fn set(
|
||||
state: Arc<StdMutex<Option<reqwest::header::HeaderMap>>>,
|
||||
headers: Option<reqwest::header::HeaderMap>,
|
||||
) -> Self {
|
||||
let previous = {
|
||||
let mut guard = state
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let previous = guard.clone();
|
||||
*guard = headers;
|
||||
previous
|
||||
};
|
||||
Self { state, previous }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for StartupRequestHeadersGuard {
|
||||
fn drop(&mut self) {
|
||||
let mut guard = self
|
||||
.state
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
*guard = self.previous.clone();
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncManagedClient {
|
||||
// Keep this constructor flat so the startup inputs remain readable at the
|
||||
// single call site instead of introducing a one-off params wrapper.
|
||||
@@ -481,19 +449,26 @@ impl AsyncManagedClient {
|
||||
codex_apps_tools_cache_context.as_ref(),
|
||||
)
|
||||
.map(|tools| filter_tools(tools, &tool_filter));
|
||||
let request_headers = Arc::new(StdMutex::new(None));
|
||||
let startup_tool_filter = tool_filter;
|
||||
let startup_complete = Arc::new(AtomicBool::new(false));
|
||||
let startup_complete_for_fut = Arc::clone(&startup_complete);
|
||||
let startup_request_headers = Arc::new(StdMutex::new(None));
|
||||
let startup_request_headers_for_fut = Arc::clone(&startup_request_headers);
|
||||
let request_headers_for_client = Arc::clone(&request_headers);
|
||||
let fut = async move {
|
||||
let outcome = async {
|
||||
if let Err(error) = validate_mcp_server_name(&server_name) {
|
||||
return Err(error.into());
|
||||
}
|
||||
|
||||
let client =
|
||||
Arc::new(make_rmcp_client(&server_name, config.transport, store_mode).await?);
|
||||
let client = Arc::new(
|
||||
make_rmcp_client(
|
||||
&server_name,
|
||||
config.transport,
|
||||
store_mode,
|
||||
request_headers_for_client,
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
match start_server_task(
|
||||
server_name,
|
||||
client,
|
||||
@@ -506,7 +481,6 @@ impl AsyncManagedClient {
|
||||
tx_event,
|
||||
elicitation_requests,
|
||||
codex_apps_tools_cache_context,
|
||||
startup_request_headers: startup_request_headers_for_fut,
|
||||
},
|
||||
)
|
||||
.or_cancel(&cancel_token)
|
||||
@@ -531,19 +505,14 @@ impl AsyncManagedClient {
|
||||
|
||||
Self {
|
||||
client,
|
||||
request_headers,
|
||||
startup_snapshot,
|
||||
startup_complete,
|
||||
startup_request_headers,
|
||||
tool_plugin_provenance,
|
||||
}
|
||||
}
|
||||
|
||||
async fn client(
|
||||
&self,
|
||||
request_headers: Option<reqwest::header::HeaderMap>,
|
||||
) -> Result<ManagedClient, StartupOutcomeError> {
|
||||
let _request_headers_guard =
|
||||
StartupRequestHeadersGuard::set(self.startup_request_headers.clone(), request_headers);
|
||||
async fn client(&self) -> Result<ManagedClient, StartupOutcomeError> {
|
||||
self.client.clone().await
|
||||
}
|
||||
|
||||
@@ -554,10 +523,7 @@ impl AsyncManagedClient {
|
||||
None
|
||||
}
|
||||
|
||||
async fn listed_tools(
|
||||
&self,
|
||||
request_headers: Option<reqwest::header::HeaderMap>,
|
||||
) -> Option<Vec<ToolInfo>> {
|
||||
async fn listed_tools(&self) -> Option<Vec<ToolInfo>> {
|
||||
let annotate_tools = |tools: Vec<ToolInfo>| {
|
||||
let mut tools = tools;
|
||||
for tool in &mut tools {
|
||||
@@ -609,8 +575,8 @@ impl AsyncManagedClient {
|
||||
let tools = if let Some(startup_tools) = self.startup_snapshot_while_initializing() {
|
||||
Some(startup_tools)
|
||||
} else {
|
||||
match self.client(request_headers.clone()).await {
|
||||
Ok(client) => Some(client.listed_tools(request_headers)),
|
||||
match self.client().await {
|
||||
Ok(client) => Some(client.listed_tools()),
|
||||
Err(_) => self.startup_snapshot.clone(),
|
||||
}
|
||||
};
|
||||
@@ -618,9 +584,17 @@ impl AsyncManagedClient {
|
||||
}
|
||||
|
||||
async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> {
|
||||
let managed = self.client(/*request_headers*/ None).await?;
|
||||
let managed = self.client().await?;
|
||||
managed.notify_sandbox_state_change(sandbox_state).await
|
||||
}
|
||||
|
||||
fn set_request_headers(&self, request_headers: Option<reqwest::header::HeaderMap>) {
|
||||
let mut guard = self
|
||||
.request_headers
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
*guard = request_headers;
|
||||
}
|
||||
}
|
||||
|
||||
pub const MCP_SANDBOX_STATE_CAPABILITY: &str = "codex/sandbox-state";
|
||||
@@ -731,7 +705,7 @@ impl McpConnectionManager {
|
||||
let auth_entry = auth_entries.get(&server_name).cloned();
|
||||
let sandbox_state = initial_sandbox_state.clone();
|
||||
join_set.spawn(async move {
|
||||
let outcome = async_managed_client.client(/*request_headers*/ None).await;
|
||||
let outcome = async_managed_client.client().await;
|
||||
if cancel_token.is_cancelled() {
|
||||
return (server_name, Err(StartupOutcomeError::Cancelled));
|
||||
}
|
||||
@@ -800,15 +774,11 @@ impl McpConnectionManager {
|
||||
(manager, cancel_token)
|
||||
}
|
||||
|
||||
async fn client_by_name(
|
||||
&self,
|
||||
name: &str,
|
||||
request_headers: Option<reqwest::header::HeaderMap>,
|
||||
) -> Result<ManagedClient> {
|
||||
async fn client_by_name(&self, name: &str) -> Result<ManagedClient> {
|
||||
self.clients
|
||||
.get(name)
|
||||
.ok_or_else(|| anyhow!("unknown MCP server '{name}'"))?
|
||||
.client(request_headers)
|
||||
.client()
|
||||
.await
|
||||
.context("failed to get client")
|
||||
}
|
||||
@@ -829,12 +799,7 @@ impl McpConnectionManager {
|
||||
return false;
|
||||
};
|
||||
|
||||
match tokio::time::timeout(
|
||||
timeout,
|
||||
async_managed_client.client(/*request_headers*/ None),
|
||||
)
|
||||
.await
|
||||
{
|
||||
match tokio::time::timeout(timeout, async_managed_client.client()).await {
|
||||
Ok(Ok(_)) => true,
|
||||
Ok(Err(_)) | Err(_) => false,
|
||||
}
|
||||
@@ -854,7 +819,7 @@ impl McpConnectionManager {
|
||||
continue;
|
||||
};
|
||||
|
||||
match async_managed_client.client(/*request_headers*/ None).await {
|
||||
match async_managed_client.client().await {
|
||||
Ok(_) => {}
|
||||
Err(error) => failures.push(McpStartupFailure {
|
||||
server: server_name.clone(),
|
||||
@@ -869,19 +834,9 @@ impl McpConnectionManager {
|
||||
/// fully-qualified name for the tool.
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn list_all_tools(&self) -> HashMap<String, ToolInfo> {
|
||||
self.list_all_tools_with_request_headers(/*request_headers*/ None)
|
||||
.await
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn list_all_tools_with_request_headers(
|
||||
&self,
|
||||
request_headers: Option<reqwest::header::HeaderMap>,
|
||||
) -> HashMap<String, ToolInfo> {
|
||||
let mut tools = HashMap::new();
|
||||
for managed_client in self.clients.values() {
|
||||
let Some(server_tools) = managed_client.listed_tools(request_headers.clone()).await
|
||||
else {
|
||||
let Some(server_tools) = managed_client.listed_tools().await else {
|
||||
continue;
|
||||
};
|
||||
tools.extend(qualify_tools(server_tools));
|
||||
@@ -899,7 +854,7 @@ impl McpConnectionManager {
|
||||
.clients
|
||||
.get(CODEX_APPS_MCP_SERVER_NAME)
|
||||
.ok_or_else(|| anyhow!("unknown MCP server '{CODEX_APPS_MCP_SERVER_NAME}'"))?
|
||||
.client(/*request_headers*/ None)
|
||||
.client()
|
||||
.await
|
||||
.context("failed to get client")?;
|
||||
|
||||
@@ -909,7 +864,6 @@ impl McpConnectionManager {
|
||||
CODEX_APPS_MCP_SERVER_NAME,
|
||||
&managed_client.client,
|
||||
managed_client.tool_timeout,
|
||||
/*request_headers*/ None,
|
||||
)
|
||||
.await
|
||||
.with_context(|| {
|
||||
@@ -946,8 +900,7 @@ impl McpConnectionManager {
|
||||
|
||||
for (server_name, async_managed_client) in clients_snapshot {
|
||||
let server_name = server_name.clone();
|
||||
let Ok(managed_client) = async_managed_client.client(/*request_headers*/ None).await
|
||||
else {
|
||||
let Ok(managed_client) = async_managed_client.client().await else {
|
||||
continue;
|
||||
};
|
||||
let timeout = managed_client.tool_timeout;
|
||||
@@ -1013,8 +966,7 @@ impl McpConnectionManager {
|
||||
|
||||
for (server_name, async_managed_client) in clients_snapshot {
|
||||
let server_name_cloned = server_name.clone();
|
||||
let Ok(managed_client) = async_managed_client.client(/*request_headers*/ None).await
|
||||
else {
|
||||
let Ok(managed_client) = async_managed_client.client().await else {
|
||||
continue;
|
||||
};
|
||||
let client = managed_client.client.clone();
|
||||
@@ -1082,9 +1034,8 @@ impl McpConnectionManager {
|
||||
tool: &str,
|
||||
arguments: Option<serde_json::Value>,
|
||||
meta: Option<serde_json::Value>,
|
||||
request_headers: Option<reqwest::header::HeaderMap>,
|
||||
) -> Result<CallToolResult> {
|
||||
let client = self.client_by_name(server, request_headers.clone()).await?;
|
||||
let client = self.client_by_name(server).await?;
|
||||
if !client.tool_filter.allows(tool) {
|
||||
return Err(anyhow!(
|
||||
"tool '{tool}' is disabled for MCP server '{server}'"
|
||||
@@ -1093,13 +1044,7 @@ impl McpConnectionManager {
|
||||
|
||||
let result: rmcp::model::CallToolResult = client
|
||||
.client
|
||||
.call_tool(
|
||||
tool.to_string(),
|
||||
arguments,
|
||||
meta,
|
||||
client.tool_timeout,
|
||||
request_headers,
|
||||
)
|
||||
.call_tool(tool.to_string(), arguments, meta, client.tool_timeout)
|
||||
.await
|
||||
.with_context(|| format!("tool call failed for `{server}/{tool}`"))?;
|
||||
|
||||
@@ -1120,15 +1065,23 @@ impl McpConnectionManager {
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn set_request_headers_for_server(
|
||||
&self,
|
||||
server_name: &str,
|
||||
request_headers: Option<reqwest::header::HeaderMap>,
|
||||
) {
|
||||
if let Some(client) = self.clients.get(server_name) {
|
||||
client.set_request_headers(request_headers);
|
||||
}
|
||||
}
|
||||
|
||||
/// List resources from the specified server.
|
||||
pub async fn list_resources(
|
||||
&self,
|
||||
server: &str,
|
||||
params: Option<PaginatedRequestParams>,
|
||||
) -> Result<ListResourcesResult> {
|
||||
let managed = self
|
||||
.client_by_name(server, /*request_headers*/ None)
|
||||
.await?;
|
||||
let managed = self.client_by_name(server).await?;
|
||||
let timeout = managed.tool_timeout;
|
||||
|
||||
managed
|
||||
@@ -1144,9 +1097,7 @@ impl McpConnectionManager {
|
||||
server: &str,
|
||||
params: Option<PaginatedRequestParams>,
|
||||
) -> Result<ListResourceTemplatesResult> {
|
||||
let managed = self
|
||||
.client_by_name(server, /*request_headers*/ None)
|
||||
.await?;
|
||||
let managed = self.client_by_name(server).await?;
|
||||
let client = managed.client.clone();
|
||||
let timeout = managed.tool_timeout;
|
||||
|
||||
@@ -1162,9 +1113,7 @@ impl McpConnectionManager {
|
||||
server: &str,
|
||||
params: ReadResourceRequestParams,
|
||||
) -> Result<ReadResourceResult> {
|
||||
let managed = self
|
||||
.client_by_name(server, /*request_headers*/ None)
|
||||
.await?;
|
||||
let managed = self.client_by_name(server).await?;
|
||||
let client = managed.client.clone();
|
||||
let timeout = managed.tool_timeout;
|
||||
let uri = params.uri.clone();
|
||||
@@ -1424,7 +1373,6 @@ async fn start_server_task(
|
||||
tx_event,
|
||||
elicitation_requests,
|
||||
codex_apps_tools_cache_context,
|
||||
startup_request_headers,
|
||||
} = params;
|
||||
let elicitation = elicitation_capability_for_server(&server_name);
|
||||
let params = InitializeRequestParams {
|
||||
@@ -1450,34 +1398,16 @@ async fn start_server_task(
|
||||
|
||||
let send_elicitation = elicitation_requests.make_sender(server_name.clone(), tx_event);
|
||||
|
||||
let initialize_request_headers = startup_request_headers
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.clone();
|
||||
let initialize_result = client
|
||||
.initialize(
|
||||
params,
|
||||
startup_timeout,
|
||||
send_elicitation,
|
||||
initialize_request_headers,
|
||||
)
|
||||
.initialize(params, startup_timeout, send_elicitation)
|
||||
.await
|
||||
.map_err(StartupOutcomeError::from)?;
|
||||
|
||||
let list_start = Instant::now();
|
||||
let fetch_start = Instant::now();
|
||||
let list_request_headers = startup_request_headers
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.clone();
|
||||
let tools = list_tools_for_client_uncached(
|
||||
&server_name,
|
||||
&client,
|
||||
startup_timeout,
|
||||
list_request_headers,
|
||||
)
|
||||
.await
|
||||
.map_err(StartupOutcomeError::from)?;
|
||||
let tools = list_tools_for_client_uncached(&server_name, &client, startup_timeout)
|
||||
.await
|
||||
.map_err(StartupOutcomeError::from)?;
|
||||
emit_duration(
|
||||
MCP_TOOLS_FETCH_UNCACHED_DURATION_METRIC,
|
||||
fetch_start.elapsed(),
|
||||
@@ -1522,13 +1452,13 @@ struct StartServerTaskParams {
|
||||
tx_event: Sender<Event>,
|
||||
elicitation_requests: ElicitationRequestManager,
|
||||
codex_apps_tools_cache_context: Option<CodexAppsToolsCacheContext>,
|
||||
startup_request_headers: Arc<StdMutex<Option<reqwest::header::HeaderMap>>>,
|
||||
}
|
||||
|
||||
async fn make_rmcp_client(
|
||||
server_name: &str,
|
||||
transport: McpServerTransportConfig,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
request_headers: Arc<StdMutex<Option<reqwest::header::HeaderMap>>>,
|
||||
) -> Result<RmcpClient, StartupOutcomeError> {
|
||||
match transport {
|
||||
McpServerTransportConfig::Stdio {
|
||||
@@ -1562,6 +1492,7 @@ async fn make_rmcp_client(
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
store_mode,
|
||||
request_headers,
|
||||
)
|
||||
.await
|
||||
.map_err(StartupOutcomeError::from)
|
||||
@@ -1684,10 +1615,9 @@ async fn list_tools_for_client_uncached(
|
||||
server_name: &str,
|
||||
client: &Arc<RmcpClient>,
|
||||
timeout: Option<Duration>,
|
||||
request_headers: Option<reqwest::header::HeaderMap>,
|
||||
) -> Result<Vec<ToolInfo>> {
|
||||
let resp = client
|
||||
.list_tools_with_connector_ids(/*params*/ None, timeout, request_headers)
|
||||
.list_tools_with_connector_ids(/*params*/ None, timeout)
|
||||
.await?;
|
||||
let tools = resp
|
||||
.tools
|
||||
|
||||
@@ -4,7 +4,6 @@ use codex_protocol::protocol::McpAuthStatus;
|
||||
use rmcp::model::JsonObject;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use tempfile::tempdir;
|
||||
|
||||
fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo {
|
||||
@@ -416,7 +415,6 @@ async fn list_all_tools_uses_startup_snapshot_while_client_is_pending() {
|
||||
client: pending_client,
|
||||
startup_snapshot: Some(startup_tools),
|
||||
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
startup_request_headers: Arc::new(StdMutex::new(None)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
},
|
||||
);
|
||||
@@ -442,7 +440,6 @@ async fn list_all_tools_blocks_while_client_is_pending_without_startup_snapshot(
|
||||
client: pending_client,
|
||||
startup_snapshot: None,
|
||||
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
startup_request_headers: Arc::new(StdMutex::new(None)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
},
|
||||
);
|
||||
@@ -465,7 +462,6 @@ async fn list_all_tools_does_not_block_when_startup_snapshot_cache_hit_is_empty(
|
||||
client: pending_client,
|
||||
startup_snapshot: Some(Vec::new()),
|
||||
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
startup_request_headers: Arc::new(StdMutex::new(None)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
},
|
||||
);
|
||||
@@ -498,7 +494,6 @@ async fn list_all_tools_uses_startup_snapshot_when_client_startup_fails() {
|
||||
client: failed_client,
|
||||
startup_snapshot: Some(startup_tools),
|
||||
startup_complete,
|
||||
startup_request_headers: Arc::new(StdMutex::new(None)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
},
|
||||
);
|
||||
|
||||
@@ -45,8 +45,6 @@ use codex_protocol::request_user_input::RequestUserInputQuestionOption;
|
||||
use codex_protocol::request_user_input::RequestUserInputResponse;
|
||||
use codex_rmcp_client::ElicitationAction;
|
||||
use codex_rmcp_client::ElicitationResponse;
|
||||
use reqwest::header::HeaderMap;
|
||||
use reqwest::header::HeaderValue;
|
||||
use rmcp::model::ToolAnnotations;
|
||||
use serde::Serialize;
|
||||
use std::path::Path;
|
||||
@@ -83,15 +81,8 @@ pub(crate) async fn handle_mcp_tool_call(
|
||||
arguments: arguments_value.clone(),
|
||||
};
|
||||
|
||||
let request_headers = build_mcp_request_headers(sess.as_ref(), turn_context.as_ref(), &server);
|
||||
let metadata = lookup_mcp_tool_metadata(
|
||||
sess.as_ref(),
|
||||
turn_context.as_ref(),
|
||||
&server,
|
||||
&tool_name,
|
||||
request_headers.clone(),
|
||||
)
|
||||
.await;
|
||||
let metadata =
|
||||
lookup_mcp_tool_metadata(sess.as_ref(), turn_context.as_ref(), &server, &tool_name).await;
|
||||
let app_tool_policy = if server == CODEX_APPS_MCP_SERVER_NAME {
|
||||
connectors::app_tool_policy(
|
||||
&turn_context.config,
|
||||
@@ -159,7 +150,6 @@ pub(crate) async fn handle_mcp_tool_call(
|
||||
&tool_name,
|
||||
arguments_value.clone(),
|
||||
request_meta.clone(),
|
||||
request_headers.clone(),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| format!("tool call error: {e:?}"));
|
||||
@@ -190,7 +180,6 @@ pub(crate) async fn handle_mcp_tool_call(
|
||||
turn_context.as_ref(),
|
||||
&server,
|
||||
&tool_name,
|
||||
request_headers.clone(),
|
||||
)
|
||||
.await;
|
||||
result
|
||||
@@ -247,13 +236,7 @@ pub(crate) async fn handle_mcp_tool_call(
|
||||
let start = Instant::now();
|
||||
// Perform the tool call.
|
||||
let result = sess
|
||||
.call_tool(
|
||||
&server,
|
||||
&tool_name,
|
||||
arguments_value.clone(),
|
||||
request_meta,
|
||||
request_headers.clone(),
|
||||
)
|
||||
.call_tool(&server, &tool_name, arguments_value.clone(), request_meta)
|
||||
.await
|
||||
.map_err(|e| format!("tool call error: {e:?}"));
|
||||
let result = sanitize_mcp_tool_result_for_model(
|
||||
@@ -279,14 +262,7 @@ pub(crate) async fn handle_mcp_tool_call(
|
||||
tool_call_end_event.clone(),
|
||||
)
|
||||
.await;
|
||||
maybe_track_codex_app_used(
|
||||
sess.as_ref(),
|
||||
turn_context.as_ref(),
|
||||
&server,
|
||||
&tool_name,
|
||||
request_headers.clone(),
|
||||
)
|
||||
.await;
|
||||
maybe_track_codex_app_used(sess.as_ref(), turn_context.as_ref(), &server, &tool_name).await;
|
||||
|
||||
let status = if result.is_ok() { "ok" } else { "error" };
|
||||
turn_context
|
||||
@@ -357,12 +333,11 @@ async fn maybe_track_codex_app_used(
|
||||
turn_context: &TurnContext,
|
||||
server: &str,
|
||||
tool_name: &str,
|
||||
request_headers: Option<HeaderMap>,
|
||||
) {
|
||||
if server != CODEX_APPS_MCP_SERVER_NAME {
|
||||
return;
|
||||
}
|
||||
let metadata = lookup_mcp_app_usage_metadata(sess, server, tool_name, request_headers).await;
|
||||
let metadata = lookup_mcp_app_usage_metadata(sess, server, tool_name).await;
|
||||
let (connector_id, app_name) = metadata
|
||||
.map(|metadata| (metadata.connector_id, metadata.app_name))
|
||||
.unwrap_or((None, None));
|
||||
@@ -414,34 +389,6 @@ pub(crate) struct McpToolApprovalMetadata {
|
||||
|
||||
const MCP_TOOL_CODEX_APPS_META_KEY: &str = "_codex_apps";
|
||||
|
||||
fn build_mcp_request_headers(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
server: &str,
|
||||
) -> Option<HeaderMap> {
|
||||
if server != CODEX_APPS_MCP_SERVER_NAME {
|
||||
return None;
|
||||
}
|
||||
|
||||
let session_id = sess.conversation_id.to_string();
|
||||
let mut headers = HeaderMap::new();
|
||||
if let Ok(value) = HeaderValue::from_str(&session_id) {
|
||||
headers.insert("session_id", value.clone());
|
||||
headers.insert("x-client-request-id", value);
|
||||
}
|
||||
if let Some(turn_metadata) = turn_context.turn_metadata_state.current_header_value()
|
||||
&& let Ok(value) = HeaderValue::from_str(&turn_metadata)
|
||||
{
|
||||
headers.insert(crate::X_CODEX_TURN_METADATA_HEADER, value);
|
||||
}
|
||||
|
||||
if headers.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(headers)
|
||||
}
|
||||
}
|
||||
|
||||
fn build_mcp_tool_call_request_meta(
|
||||
server: &str,
|
||||
metadata: Option<&McpToolApprovalMetadata>,
|
||||
@@ -794,14 +741,13 @@ pub(crate) async fn lookup_mcp_tool_metadata(
|
||||
turn_context: &TurnContext,
|
||||
server: &str,
|
||||
tool_name: &str,
|
||||
request_headers: Option<HeaderMap>,
|
||||
) -> Option<McpToolApprovalMetadata> {
|
||||
let tools = sess
|
||||
.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.list_all_tools_with_request_headers(request_headers.clone())
|
||||
.list_all_tools()
|
||||
.await;
|
||||
|
||||
let tool_info = tools
|
||||
@@ -852,14 +798,13 @@ async fn lookup_mcp_app_usage_metadata(
|
||||
sess: &Session,
|
||||
server: &str,
|
||||
tool_name: &str,
|
||||
request_headers: Option<HeaderMap>,
|
||||
) -> Option<McpAppUsageMetadata> {
|
||||
let tools = sess
|
||||
.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.list_all_tools_with_request_headers(request_headers)
|
||||
.list_all_tools()
|
||||
.await;
|
||||
|
||||
tools.into_values().find_map(|tool_info| {
|
||||
|
||||
@@ -153,6 +153,8 @@ impl Session {
|
||||
) {
|
||||
self.abort_all_tasks(TurnAbortReason::Replaced).await;
|
||||
self.clear_connector_selection().await;
|
||||
self.sync_mcp_request_headers_for_turn(turn_context.as_ref())
|
||||
.await;
|
||||
|
||||
let task: Arc<dyn SessionTask> = Arc::new(task);
|
||||
let task_kind = task.kind();
|
||||
@@ -233,6 +235,7 @@ impl Session {
|
||||
// in-flight approval wait can surface as a model-visible rejection before TurnAborted.
|
||||
active_turn.clear_pending().await;
|
||||
}
|
||||
self.clear_mcp_request_headers().await;
|
||||
}
|
||||
|
||||
pub async fn on_task_finished(
|
||||
@@ -262,6 +265,9 @@ impl Session {
|
||||
*active = None;
|
||||
}
|
||||
drop(active);
|
||||
if should_clear_active_turn {
|
||||
self.clear_mcp_request_headers().await;
|
||||
}
|
||||
if !pending_input.is_empty() {
|
||||
for pending_input_item in pending_input {
|
||||
match inspect_pending_input(self, &turn_context, pending_input_item).await {
|
||||
|
||||
@@ -23,6 +23,7 @@ use reqwest::header::HeaderMap;
|
||||
use reqwest::header::WWW_AUTHENTICATE;
|
||||
use rmcp::model::CallToolRequestParams;
|
||||
use rmcp::model::CallToolResult;
|
||||
use rmcp::model::ClientJsonRpcMessage;
|
||||
use rmcp::model::ClientNotification;
|
||||
use rmcp::model::ClientRequest;
|
||||
use rmcp::model::CreateElicitationRequestParams;
|
||||
@@ -83,6 +84,15 @@ const JSON_MIME_TYPE: &str = "application/json";
|
||||
const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id";
|
||||
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
|
||||
const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192;
|
||||
|
||||
fn message_uses_request_scoped_headers(message: &ClientJsonRpcMessage) -> bool {
|
||||
matches!(
|
||||
message,
|
||||
ClientJsonRpcMessage::Request(request)
|
||||
if request.request.method() == "tools/call"
|
||||
)
|
||||
}
|
||||
|
||||
fn apply_request_scoped_headers(
|
||||
mut request: reqwest::RequestBuilder,
|
||||
request_headers_state: &Arc<StdMutex<Option<HeaderMap>>>,
|
||||
@@ -156,7 +166,9 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
|
||||
if let Some(session_id_value) = session_id.as_ref() {
|
||||
request = request.header(HEADER_SESSION_ID, session_id_value.as_ref());
|
||||
}
|
||||
request = apply_request_scoped_headers(request, &self.request_headers_state);
|
||||
if message_uses_request_scoped_headers(&message) {
|
||||
request = apply_request_scoped_headers(request, &self.request_headers_state);
|
||||
}
|
||||
|
||||
let response = request
|
||||
.json(&message)
|
||||
@@ -252,8 +264,6 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
|
||||
if let Some(auth_header) = auth_token {
|
||||
request_builder = request_builder.bearer_auth(auth_header);
|
||||
}
|
||||
request_builder =
|
||||
apply_request_scoped_headers(request_builder, &self.request_headers_state);
|
||||
let response = request_builder
|
||||
.header(HEADER_SESSION_ID, session.as_ref())
|
||||
.send()
|
||||
@@ -291,8 +301,6 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
|
||||
if let Some(auth_header) = auth_token {
|
||||
request_builder = request_builder.bearer_auth(auth_header);
|
||||
}
|
||||
request_builder =
|
||||
apply_request_scoped_headers(request_builder, &self.request_headers_state);
|
||||
|
||||
let response = request_builder
|
||||
.send()
|
||||
@@ -503,38 +511,6 @@ pub struct RmcpClient {
|
||||
request_headers: Option<Arc<StdMutex<Option<HeaderMap>>>>,
|
||||
}
|
||||
|
||||
struct RequestHeadersGuard {
|
||||
state: Option<Arc<StdMutex<Option<HeaderMap>>>>,
|
||||
previous: Option<HeaderMap>,
|
||||
}
|
||||
|
||||
impl RequestHeadersGuard {
|
||||
fn set(state: Option<Arc<StdMutex<Option<HeaderMap>>>>, headers: Option<HeaderMap>) -> Self {
|
||||
let previous = if let Some(state_ref) = state.as_ref() {
|
||||
let mut guard = state_ref
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let previous = guard.clone();
|
||||
*guard = headers;
|
||||
previous
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Self { state, previous }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for RequestHeadersGuard {
|
||||
fn drop(&mut self) {
|
||||
if let Some(state) = self.state.as_ref() {
|
||||
let mut guard = state
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
*guard = self.previous.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RmcpClient {
|
||||
pub async fn new_stdio_client(
|
||||
program: OsString,
|
||||
@@ -574,6 +550,7 @@ impl RmcpClient {
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
request_headers: Arc<StdMutex<Option<HeaderMap>>>,
|
||||
) -> Result<Self> {
|
||||
let transport_recipe = TransportRecipe::StreamableHttp {
|
||||
server_name: server_name.to_string(),
|
||||
@@ -583,9 +560,9 @@ impl RmcpClient {
|
||||
env_http_headers,
|
||||
store_mode,
|
||||
};
|
||||
let request_headers = Some(Arc::new(StdMutex::new(None)));
|
||||
let transport =
|
||||
Self::create_pending_transport(&transport_recipe, request_headers.clone()).await?;
|
||||
Self::create_pending_transport(&transport_recipe, Some(Arc::clone(&request_headers)))
|
||||
.await?;
|
||||
Ok(Self {
|
||||
state: Mutex::new(ClientState::Connecting {
|
||||
transport: Some(transport),
|
||||
@@ -593,7 +570,7 @@ impl RmcpClient {
|
||||
transport_recipe,
|
||||
initialize_context: Mutex::new(None),
|
||||
session_recovery_lock: Mutex::new(()),
|
||||
request_headers,
|
||||
request_headers: Some(request_headers),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -604,7 +581,6 @@ impl RmcpClient {
|
||||
params: InitializeRequestParams,
|
||||
timeout: Option<Duration>,
|
||||
send_elicitation: SendElicitation,
|
||||
request_headers: Option<HeaderMap>,
|
||||
) -> Result<InitializeResult> {
|
||||
let client_handler = LoggingClientHandler::new(params.clone(), send_elicitation);
|
||||
let pending_transport = {
|
||||
@@ -618,8 +594,6 @@ impl RmcpClient {
|
||||
}
|
||||
};
|
||||
|
||||
let _request_headers_guard =
|
||||
RequestHeadersGuard::set(self.request_headers.clone(), request_headers);
|
||||
let (service, oauth_persistor, process_group_guard) =
|
||||
Self::connect_pending_transport(pending_transport, client_handler.clone(), timeout)
|
||||
.await?;
|
||||
@@ -676,11 +650,8 @@ impl RmcpClient {
|
||||
&self,
|
||||
params: Option<PaginatedRequestParams>,
|
||||
timeout: Option<Duration>,
|
||||
request_headers: Option<HeaderMap>,
|
||||
) -> Result<ListToolsWithConnectorIdResult> {
|
||||
self.refresh_oauth_if_needed().await;
|
||||
let _request_headers_guard =
|
||||
RequestHeadersGuard::set(self.request_headers.clone(), request_headers);
|
||||
let result = self
|
||||
.run_service_operation("tools/list", timeout, move |service| {
|
||||
let params = params.clone();
|
||||
@@ -774,7 +745,6 @@ impl RmcpClient {
|
||||
arguments: Option<serde_json::Value>,
|
||||
meta: Option<serde_json::Value>,
|
||||
timeout: Option<Duration>,
|
||||
request_headers: Option<HeaderMap>,
|
||||
) -> Result<CallToolResult> {
|
||||
self.refresh_oauth_if_needed().await;
|
||||
let arguments = match arguments {
|
||||
@@ -801,8 +771,6 @@ impl RmcpClient {
|
||||
arguments,
|
||||
task: None,
|
||||
};
|
||||
let _request_headers_guard =
|
||||
RequestHeadersGuard::set(self.request_headers.clone(), request_headers);
|
||||
let result = self
|
||||
.run_service_operation("tools/call", timeout, move |service| {
|
||||
let rmcp_params = rmcp_params.clone();
|
||||
|
||||
@@ -78,7 +78,6 @@ async fn rmcp_client_can_list_and_read_resources() -> anyhow::Result<()> {
|
||||
}
|
||||
.boxed()
|
||||
}),
|
||||
/*request_headers*/ None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
||||
@@ -94,7 +94,6 @@ async fn create_client(base_url: &str) -> anyhow::Result<RmcpClient> {
|
||||
}
|
||||
.boxed()
|
||||
}),
|
||||
/*request_headers*/ None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -108,7 +107,6 @@ async fn call_echo_tool(client: &RmcpClient, message: &str) -> anyhow::Result<Ca
|
||||
Some(json!({ "message": message })),
|
||||
None,
|
||||
Some(Duration::from_secs(5)),
|
||||
/*request_headers*/ None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user