mirror of
https://github.com/openai/codex.git
synced 2026-05-02 10:26:45 +00:00
Compare commits
10 Commits
codex-fix/
...
ychhabria/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aedd076b8c | ||
|
|
55be8ca802 | ||
|
|
1a2536d12b | ||
|
|
732d7ac81f | ||
|
|
1f3ec4172c | ||
|
|
52fdfbcfb8 | ||
|
|
2321532dec | ||
|
|
de98643403 | ||
|
|
fb27c20581 | ||
|
|
bccce0f2d8 |
@@ -125,6 +125,8 @@ use futures::future::BoxFuture;
|
||||
use futures::future::Shared;
|
||||
use futures::prelude::*;
|
||||
use futures::stream::FuturesOrdered;
|
||||
use reqwest::header::HeaderMap;
|
||||
use reqwest::header::HeaderValue;
|
||||
use rmcp::model::ListResourceTemplatesResult;
|
||||
use rmcp::model::ListResourcesResult;
|
||||
use rmcp::model::PaginatedRequestParams;
|
||||
@@ -3943,6 +3945,12 @@ impl Session {
|
||||
arguments: Option<serde_json::Value>,
|
||||
meta: Option<serde_json::Value>,
|
||||
) -> anyhow::Result<CallToolResult> {
|
||||
if server == CODEX_APPS_MCP_SERVER_NAME
|
||||
&& let Some((turn_context, _)) = self.active_turn_context_and_cancellation_token().await
|
||||
{
|
||||
self.sync_mcp_request_headers_for_turn(turn_context.as_ref())
|
||||
.await;
|
||||
}
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
@@ -3951,6 +3959,45 @@ impl Session {
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn sync_mcp_request_headers_for_turn(&self, turn_context: &TurnContext) {
|
||||
let mut request_headers = HeaderMap::new();
|
||||
let session_id = self.conversation_id.to_string();
|
||||
if let Ok(value) = HeaderValue::from_str(&session_id) {
|
||||
request_headers.insert("session_id", value.clone());
|
||||
request_headers.insert("x-client-request-id", value);
|
||||
}
|
||||
if let Some(turn_metadata) = turn_context.turn_metadata_state.current_header_value()
|
||||
&& let Ok(value) = HeaderValue::from_str(&turn_metadata)
|
||||
{
|
||||
request_headers.insert(crate::X_CODEX_TURN_METADATA_HEADER, value);
|
||||
}
|
||||
|
||||
let request_headers = if request_headers.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(request_headers)
|
||||
};
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.set_request_headers_for_server(
|
||||
crate::mcp::CODEX_APPS_MCP_SERVER_NAME,
|
||||
request_headers,
|
||||
);
|
||||
}
|
||||
|
||||
pub(crate) async fn clear_mcp_request_headers(&self) {
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.set_request_headers_for_server(
|
||||
crate::mcp::CODEX_APPS_MCP_SERVER_NAME,
|
||||
/*request_headers*/ None,
|
||||
);
|
||||
}
|
||||
|
||||
pub(crate) async fn parse_mcp_tool_name(
|
||||
&self,
|
||||
name: &str,
|
||||
|
||||
@@ -26,6 +26,8 @@ use codex_protocol::protocol::ReadOnlyAccess;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use codex_protocol::request_permissions::PermissionGrantScope;
|
||||
use codex_protocol::request_permissions::RequestPermissionProfile;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use reqwest::header::HeaderValue;
|
||||
use tracing::Span;
|
||||
|
||||
use crate::protocol::CompactedItem;
|
||||
@@ -56,6 +58,7 @@ use crate::tools::handlers::UnifiedExecHandler;
|
||||
use crate::tools::registry::ToolHandler;
|
||||
use crate::tools::router::ToolCallSource;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use async_trait::async_trait;
|
||||
use codex_app_server_protocol::AppInfo;
|
||||
use codex_execpolicy::Decision;
|
||||
use codex_execpolicy::NetworkRuleProtocol;
|
||||
@@ -78,7 +81,10 @@ use opentelemetry::trace::TracerProvider as _;
|
||||
use opentelemetry_sdk::trace::SdkTracerProvider;
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::time::sleep;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::AbortOnDropHandle;
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
@@ -2716,6 +2722,99 @@ async fn request_permissions_is_auto_denied_when_granular_policy_blocks_tool_req
|
||||
);
|
||||
}
|
||||
|
||||
struct NoopTask;
|
||||
|
||||
#[async_trait]
|
||||
impl SessionTask for NoopTask {
|
||||
fn kind(&self) -> TaskKind {
|
||||
TaskKind::Regular
|
||||
}
|
||||
|
||||
fn span_name(&self) -> &'static str {
|
||||
"noop"
|
||||
}
|
||||
|
||||
async fn run(
|
||||
self: Arc<Self>,
|
||||
_ctx: Arc<SessionTaskContext>,
|
||||
_turn_context: Arc<TurnContext>,
|
||||
_input: Vec<UserInput>,
|
||||
_cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn call_tool_refreshes_codex_apps_request_headers_from_active_turn() {
|
||||
let (session, turn_context) = make_session_and_context().await;
|
||||
let session = Arc::new(session);
|
||||
let turn_context = Arc::new(turn_context);
|
||||
|
||||
session
|
||||
.services
|
||||
.mcp_connection_manager
|
||||
.write()
|
||||
.await
|
||||
.register_test_server_for_request_headers(CODEX_APPS_MCP_SERVER_NAME);
|
||||
|
||||
session
|
||||
.sync_mcp_request_headers_for_turn(turn_context.as_ref())
|
||||
.await;
|
||||
let base_header = turn_context
|
||||
.turn_metadata_state
|
||||
.current_header_value()
|
||||
.expect("base turn metadata header");
|
||||
|
||||
let updated_header = serde_json::json!({
|
||||
"turn_id": turn_context.sub_id,
|
||||
"sandbox": "test",
|
||||
"workspaces": {
|
||||
"/tmp/repo": {
|
||||
"has_changes": true
|
||||
}
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
turn_context
|
||||
.turn_metadata_state
|
||||
.set_enriched_header_for_tests(Some(updated_header.clone()));
|
||||
|
||||
let mut active_turn = ActiveTurn::default();
|
||||
let handle = tokio::spawn(async {});
|
||||
active_turn.add_task(crate::state::RunningTask {
|
||||
done: Arc::new(Notify::new()),
|
||||
kind: TaskKind::Regular,
|
||||
task: Arc::new(NoopTask),
|
||||
cancellation_token: CancellationToken::new(),
|
||||
handle: Arc::new(AbortOnDropHandle::new(handle)),
|
||||
turn_context: Arc::clone(&turn_context),
|
||||
_timer: None,
|
||||
});
|
||||
*session.active_turn.lock().await = Some(active_turn);
|
||||
|
||||
let _err = session
|
||||
.call_tool(CODEX_APPS_MCP_SERVER_NAME, "echo", None, None)
|
||||
.await
|
||||
.expect_err("test server is not initialized");
|
||||
|
||||
let headers = session
|
||||
.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.request_headers_for_server(CODEX_APPS_MCP_SERVER_NAME)
|
||||
.expect("request headers should be tracked for codex apps");
|
||||
assert_eq!(
|
||||
headers.get(crate::X_CODEX_TURN_METADATA_HEADER),
|
||||
Some(&HeaderValue::from_str(&updated_header).expect("valid enriched header")),
|
||||
);
|
||||
assert_ne!(
|
||||
headers.get(crate::X_CODEX_TURN_METADATA_HEADER),
|
||||
Some(&HeaderValue::from_str(&base_header).expect("valid base header")),
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn submit_with_id_captures_current_span_trace_context() {
|
||||
let (session, _turn_context) = make_session_and_context().await;
|
||||
|
||||
@@ -423,6 +423,7 @@ impl ManagedClient {
|
||||
#[derive(Clone)]
|
||||
struct AsyncManagedClient {
|
||||
client: Shared<BoxFuture<'static, Result<ManagedClient, StartupOutcomeError>>>,
|
||||
request_headers: Arc<StdMutex<Option<reqwest::header::HeaderMap>>>,
|
||||
startup_snapshot: Option<Vec<ToolInfo>>,
|
||||
startup_complete: Arc<AtomicBool>,
|
||||
tool_plugin_provenance: Arc<ToolPluginProvenance>,
|
||||
@@ -448,17 +449,26 @@ impl AsyncManagedClient {
|
||||
codex_apps_tools_cache_context.as_ref(),
|
||||
)
|
||||
.map(|tools| filter_tools(tools, &tool_filter));
|
||||
let request_headers = Arc::new(StdMutex::new(None));
|
||||
let startup_tool_filter = tool_filter;
|
||||
let startup_complete = Arc::new(AtomicBool::new(false));
|
||||
let startup_complete_for_fut = Arc::clone(&startup_complete);
|
||||
let request_headers_for_client = Arc::clone(&request_headers);
|
||||
let fut = async move {
|
||||
let outcome = async {
|
||||
if let Err(error) = validate_mcp_server_name(&server_name) {
|
||||
return Err(error.into());
|
||||
}
|
||||
|
||||
let client =
|
||||
Arc::new(make_rmcp_client(&server_name, config.transport, store_mode).await?);
|
||||
let client = Arc::new(
|
||||
make_rmcp_client(
|
||||
&server_name,
|
||||
config.transport,
|
||||
store_mode,
|
||||
request_headers_for_client,
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
match start_server_task(
|
||||
server_name,
|
||||
client,
|
||||
@@ -495,6 +505,7 @@ impl AsyncManagedClient {
|
||||
|
||||
Self {
|
||||
client,
|
||||
request_headers,
|
||||
startup_snapshot,
|
||||
startup_complete,
|
||||
tool_plugin_provenance,
|
||||
@@ -576,6 +587,14 @@ impl AsyncManagedClient {
|
||||
let managed = self.client().await?;
|
||||
managed.notify_sandbox_state_change(sandbox_state).await
|
||||
}
|
||||
|
||||
fn set_request_headers(&self, request_headers: Option<reqwest::header::HeaderMap>) {
|
||||
let mut guard = self
|
||||
.request_headers
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
*guard = request_headers;
|
||||
}
|
||||
}
|
||||
|
||||
pub const MCP_SANDBOX_STATE_CAPABILITY: &str = "codex/sandbox-state";
|
||||
@@ -617,6 +636,40 @@ impl McpConnectionManager {
|
||||
Self::new_uninitialized(approval_policy)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn register_test_server_for_request_headers(&mut self, server_name: &str) {
|
||||
let failed_client = futures::future::ready::<Result<ManagedClient, StartupOutcomeError>>(
|
||||
Err(StartupOutcomeError::Failed {
|
||||
error: "test request headers stub".to_string(),
|
||||
}),
|
||||
)
|
||||
.boxed()
|
||||
.shared();
|
||||
self.clients.insert(
|
||||
server_name.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: failed_client,
|
||||
request_headers: Arc::new(StdMutex::new(None)),
|
||||
startup_snapshot: None,
|
||||
startup_complete: Arc::new(AtomicBool::new(true)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn request_headers_for_server(
|
||||
&self,
|
||||
server_name: &str,
|
||||
) -> Option<reqwest::header::HeaderMap> {
|
||||
let client = self.clients.get(server_name)?;
|
||||
client
|
||||
.request_headers
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn has_servers(&self) -> bool {
|
||||
!self.clients.is_empty()
|
||||
}
|
||||
@@ -1046,6 +1099,16 @@ impl McpConnectionManager {
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn set_request_headers_for_server(
|
||||
&self,
|
||||
server_name: &str,
|
||||
request_headers: Option<reqwest::header::HeaderMap>,
|
||||
) {
|
||||
if let Some(client) = self.clients.get(server_name) {
|
||||
client.set_request_headers(request_headers);
|
||||
}
|
||||
}
|
||||
|
||||
/// List resources from the specified server.
|
||||
pub async fn list_resources(
|
||||
&self,
|
||||
@@ -1429,6 +1492,7 @@ async fn make_rmcp_client(
|
||||
server_name: &str,
|
||||
transport: McpServerTransportConfig,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
request_headers: Arc<StdMutex<Option<reqwest::header::HeaderMap>>>,
|
||||
) -> Result<RmcpClient, StartupOutcomeError> {
|
||||
match transport {
|
||||
McpServerTransportConfig::Stdio {
|
||||
@@ -1462,6 +1526,7 @@ async fn make_rmcp_client(
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
store_mode,
|
||||
request_headers,
|
||||
)
|
||||
.await
|
||||
.map_err(StartupOutcomeError::from)
|
||||
|
||||
@@ -4,6 +4,7 @@ use codex_protocol::protocol::McpAuthStatus;
|
||||
use rmcp::model::JsonObject;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use tempfile::tempdir;
|
||||
|
||||
fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo {
|
||||
@@ -413,6 +414,7 @@ async fn list_all_tools_uses_startup_snapshot_while_client_is_pending() {
|
||||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: pending_client,
|
||||
request_headers: Arc::new(StdMutex::new(None)),
|
||||
startup_snapshot: Some(startup_tools),
|
||||
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
@@ -438,6 +440,7 @@ async fn list_all_tools_blocks_while_client_is_pending_without_startup_snapshot(
|
||||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: pending_client,
|
||||
request_headers: Arc::new(StdMutex::new(None)),
|
||||
startup_snapshot: None,
|
||||
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
@@ -460,6 +463,7 @@ async fn list_all_tools_does_not_block_when_startup_snapshot_cache_hit_is_empty(
|
||||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: pending_client,
|
||||
request_headers: Arc::new(StdMutex::new(None)),
|
||||
startup_snapshot: Some(Vec::new()),
|
||||
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
@@ -492,6 +496,7 @@ async fn list_all_tools_uses_startup_snapshot_when_client_startup_fails() {
|
||||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: failed_client,
|
||||
request_headers: Arc::new(StdMutex::new(None)),
|
||||
startup_snapshot: Some(startup_tools),
|
||||
startup_complete,
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
|
||||
@@ -153,6 +153,8 @@ impl Session {
|
||||
) {
|
||||
self.abort_all_tasks(TurnAbortReason::Replaced).await;
|
||||
self.clear_connector_selection().await;
|
||||
self.sync_mcp_request_headers_for_turn(turn_context.as_ref())
|
||||
.await;
|
||||
|
||||
let task: Arc<dyn SessionTask> = Arc::new(task);
|
||||
let task_kind = task.kind();
|
||||
@@ -233,6 +235,7 @@ impl Session {
|
||||
// in-flight approval wait can surface as a model-visible rejection before TurnAborted.
|
||||
active_turn.clear_pending().await;
|
||||
}
|
||||
self.clear_mcp_request_headers().await;
|
||||
}
|
||||
|
||||
pub async fn on_task_finished(
|
||||
@@ -262,6 +265,9 @@ impl Session {
|
||||
*active = None;
|
||||
}
|
||||
drop(active);
|
||||
if should_clear_active_turn {
|
||||
self.clear_mcp_request_headers().await;
|
||||
}
|
||||
if !pending_input.is_empty() {
|
||||
for pending_input_item in pending_input {
|
||||
match inspect_pending_input(self, &turn_context, pending_input_item).await {
|
||||
|
||||
@@ -217,6 +217,14 @@ impl TurnMetadataState {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn set_enriched_header_for_tests(&self, header: Option<String>) {
|
||||
*self
|
||||
.enriched_header
|
||||
.write()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner) = header;
|
||||
}
|
||||
|
||||
async fn fetch_workspace_git_metadata(&self) -> WorkspaceGitMetadata {
|
||||
let (latest_git_commit_hash, associated_remote_urls, has_changes) = tokio::join!(
|
||||
get_head_commit_hash(&self.cwd),
|
||||
|
||||
@@ -30,6 +30,8 @@ use core_test_support::wait_for_event;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use std::fs;
|
||||
use std::process::Command;
|
||||
|
||||
const SEARCH_TOOL_DESCRIPTION_SNIPPETS: [&str; 2] = [
|
||||
"You have access to all the tools of the following apps/connectors",
|
||||
@@ -86,6 +88,15 @@ fn tool_search_output_tools(request: &ResponsesRequest, call_id: &str) -> Vec<Va
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn json_rpc_method(request: &wiremock::Request) -> Option<String> {
|
||||
request
|
||||
.body_json::<Value>()
|
||||
.ok()?
|
||||
.get("method")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
}
|
||||
|
||||
fn configure_apps(config: &mut Config, apps_base_url: &str) {
|
||||
config
|
||||
.features
|
||||
@@ -499,5 +510,195 @@ async fn tool_search_returns_deferred_tools_without_follow_up_tool_injection() -
|
||||
"post-tool follow-up should still rely on tool_search_output history, not tool injection: {third_request_tools:?}"
|
||||
);
|
||||
|
||||
let mcp_requests = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("failed to fetch recorded requests");
|
||||
let tools_list_request = mcp_requests
|
||||
.iter()
|
||||
.find(|request| json_rpc_method(request).as_deref() == Some("tools/list"))
|
||||
.expect("tools/list MCP request");
|
||||
assert!(
|
||||
tools_list_request
|
||||
.headers
|
||||
.get("x-codex-turn-metadata")
|
||||
.is_none(),
|
||||
"tools/list should not include per-turn MCP headers"
|
||||
);
|
||||
|
||||
let tools_call_request = mcp_requests
|
||||
.iter()
|
||||
.find(|request| json_rpc_method(request).as_deref() == Some("tools/call"))
|
||||
.expect("tools/call MCP request");
|
||||
let session_id_header = tools_call_request
|
||||
.headers
|
||||
.get("session_id")
|
||||
.expect("tools/call session_id header");
|
||||
let request_id_header = tools_call_request
|
||||
.headers
|
||||
.get("x-client-request-id")
|
||||
.expect("tools/call x-client-request-id header");
|
||||
let turn_metadata_header = tools_call_request
|
||||
.headers
|
||||
.get("x-codex-turn-metadata")
|
||||
.expect("tools/call turn metadata header");
|
||||
assert_eq!(
|
||||
session_id_header
|
||||
.to_str()
|
||||
.expect("session_id header to be utf8"),
|
||||
request_id_header
|
||||
.to_str()
|
||||
.expect("x-client-request-id header to be utf8")
|
||||
);
|
||||
assert!(
|
||||
turn_metadata_header
|
||||
.to_str()
|
||||
.expect("turn metadata header to be utf8")
|
||||
.contains("\"turn_id\""),
|
||||
"expected turn metadata header to contain serialized turn metadata"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn apps_mcp_tool_call_uses_enriched_turn_metadata_header() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let apps_server = AppsTestServer::mount_searchable(&server).await?;
|
||||
let call_id = "tool-search-git-metadata";
|
||||
let mock = mount_sse_sequence(
|
||||
&server,
|
||||
vec![
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_tool_search_call(
|
||||
call_id,
|
||||
&json!({
|
||||
"query": "create calendar event",
|
||||
"limit": 1,
|
||||
}),
|
||||
),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_response_created("resp-2"),
|
||||
json!({
|
||||
"type": "response.output_item.done",
|
||||
"item": {
|
||||
"type": "function_call",
|
||||
"call_id": "calendar-call-git-metadata",
|
||||
"name": SEARCH_CALENDAR_CREATE_TOOL,
|
||||
"namespace": SEARCH_CALENDAR_NAMESPACE,
|
||||
"arguments": serde_json::to_string(&json!({
|
||||
"title": "Lunch",
|
||||
"starts_at": "2026-03-10T12:00:00Z"
|
||||
})).expect("serialize calendar args")
|
||||
}
|
||||
}),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_response_created("resp-3"),
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-3"),
|
||||
]),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut builder = configured_builder(apps_server.chatgpt_base_url.clone());
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let cwd = test.cwd_path().to_path_buf();
|
||||
assert!(
|
||||
Command::new("git")
|
||||
.arg("init")
|
||||
.current_dir(&cwd)
|
||||
.status()?
|
||||
.success()
|
||||
);
|
||||
assert!(
|
||||
Command::new("git")
|
||||
.args(["config", "user.name", "Codex Test"])
|
||||
.current_dir(&cwd)
|
||||
.status()?
|
||||
.success()
|
||||
);
|
||||
assert!(
|
||||
Command::new("git")
|
||||
.args(["config", "user.email", "codex@example.com"])
|
||||
.current_dir(&cwd)
|
||||
.status()?
|
||||
.success()
|
||||
);
|
||||
assert!(
|
||||
Command::new("git")
|
||||
.args(["remote", "add", "origin", "https://example.test/repo.git"])
|
||||
.current_dir(&cwd)
|
||||
.status()?
|
||||
.success()
|
||||
);
|
||||
for idx in 0..400 {
|
||||
fs::write(
|
||||
cwd.join(format!("file-{idx:04}.txt")),
|
||||
format!("fixture file {idx}\n"),
|
||||
)?;
|
||||
}
|
||||
assert!(
|
||||
Command::new("git")
|
||||
.args(["add", "."])
|
||||
.current_dir(&cwd)
|
||||
.status()?
|
||||
.success()
|
||||
);
|
||||
assert!(
|
||||
Command::new("git")
|
||||
.args(["commit", "-m", "init"])
|
||||
.current_dir(&cwd)
|
||||
.status()?
|
||||
.success()
|
||||
);
|
||||
|
||||
test.codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![UserInput::Text {
|
||||
text: "Find the calendar create tool".to_string(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&test.codex, |event| {
|
||||
matches!(event, EventMsg::TurnComplete(_))
|
||||
})
|
||||
.await;
|
||||
|
||||
let _requests = mock.requests();
|
||||
let mcp_requests = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("failed to fetch recorded requests");
|
||||
let tools_call_request = mcp_requests
|
||||
.iter()
|
||||
.find(|request| json_rpc_method(request).as_deref() == Some("tools/call"))
|
||||
.expect("tools/call MCP request");
|
||||
let turn_metadata_header = tools_call_request
|
||||
.headers
|
||||
.get("x-codex-turn-metadata")
|
||||
.expect("tools/call turn metadata header")
|
||||
.to_str()
|
||||
.expect("turn metadata header to be utf8");
|
||||
let parsed: Value = serde_json::from_str(turn_metadata_header)?;
|
||||
assert!(
|
||||
parsed
|
||||
.get("workspaces")
|
||||
.and_then(Value::as_object)
|
||||
.is_some_and(|workspaces| !workspaces.is_empty()),
|
||||
"expected enriched MCP turn metadata header with workspace git metadata, got {parsed:#?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ use std::io;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
@@ -22,6 +23,7 @@ use reqwest::header::HeaderMap;
|
||||
use reqwest::header::WWW_AUTHENTICATE;
|
||||
use rmcp::model::CallToolRequestParams;
|
||||
use rmcp::model::CallToolResult;
|
||||
use rmcp::model::ClientJsonRpcMessage;
|
||||
use rmcp::model::ClientNotification;
|
||||
use rmcp::model::ClientRequest;
|
||||
use rmcp::model::CreateElicitationRequestParams;
|
||||
@@ -83,14 +85,45 @@ const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id";
|
||||
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
|
||||
const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192;
|
||||
|
||||
fn message_uses_request_scoped_headers(message: &ClientJsonRpcMessage) -> bool {
|
||||
matches!(
|
||||
message,
|
||||
ClientJsonRpcMessage::Request(request)
|
||||
if request.request.method() == "tools/call"
|
||||
)
|
||||
}
|
||||
|
||||
fn apply_request_scoped_headers(
|
||||
mut request: reqwest::RequestBuilder,
|
||||
request_headers_state: &Arc<StdMutex<Option<HeaderMap>>>,
|
||||
) -> reqwest::RequestBuilder {
|
||||
let extra_headers = request_headers_state
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.clone();
|
||||
if let Some(extra_headers) = extra_headers {
|
||||
for (name, value) in &extra_headers {
|
||||
request = request.header(name, value.clone());
|
||||
}
|
||||
}
|
||||
request
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct StreamableHttpResponseClient {
|
||||
inner: reqwest::Client,
|
||||
request_headers_state: Arc<StdMutex<Option<HeaderMap>>>,
|
||||
}
|
||||
|
||||
impl StreamableHttpResponseClient {
|
||||
fn new(inner: reqwest::Client) -> Self {
|
||||
Self { inner }
|
||||
fn new(
|
||||
inner: reqwest::Client,
|
||||
request_headers_state: Arc<StdMutex<Option<HeaderMap>>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
request_headers_state,
|
||||
}
|
||||
}
|
||||
|
||||
fn reqwest_error(
|
||||
@@ -133,6 +166,9 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
|
||||
if let Some(session_id_value) = session_id.as_ref() {
|
||||
request = request.header(HEADER_SESSION_ID, session_id_value.as_ref());
|
||||
}
|
||||
if message_uses_request_scoped_headers(&message) {
|
||||
request = apply_request_scoped_headers(request, &self.request_headers_state);
|
||||
}
|
||||
|
||||
let response = request
|
||||
.json(&message)
|
||||
@@ -472,6 +508,7 @@ pub struct RmcpClient {
|
||||
transport_recipe: TransportRecipe,
|
||||
initialize_context: Mutex<Option<InitializeContext>>,
|
||||
session_recovery_lock: Mutex<()>,
|
||||
request_headers: Option<Arc<StdMutex<Option<HeaderMap>>>>,
|
||||
}
|
||||
|
||||
impl RmcpClient {
|
||||
@@ -489,9 +526,10 @@ impl RmcpClient {
|
||||
env_vars: env_vars.to_vec(),
|
||||
cwd,
|
||||
};
|
||||
let transport = Self::create_pending_transport(&transport_recipe)
|
||||
.await
|
||||
.map_err(io::Error::other)?;
|
||||
let transport =
|
||||
Self::create_pending_transport(&transport_recipe, /*request_headers*/ None)
|
||||
.await
|
||||
.map_err(io::Error::other)?;
|
||||
|
||||
Ok(Self {
|
||||
state: Mutex::new(ClientState::Connecting {
|
||||
@@ -500,6 +538,7 @@ impl RmcpClient {
|
||||
transport_recipe,
|
||||
initialize_context: Mutex::new(None),
|
||||
session_recovery_lock: Mutex::new(()),
|
||||
request_headers: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -511,6 +550,7 @@ impl RmcpClient {
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
request_headers: Arc<StdMutex<Option<HeaderMap>>>,
|
||||
) -> Result<Self> {
|
||||
let transport_recipe = TransportRecipe::StreamableHttp {
|
||||
server_name: server_name.to_string(),
|
||||
@@ -520,7 +560,9 @@ impl RmcpClient {
|
||||
env_http_headers,
|
||||
store_mode,
|
||||
};
|
||||
let transport = Self::create_pending_transport(&transport_recipe).await?;
|
||||
let transport =
|
||||
Self::create_pending_transport(&transport_recipe, Some(Arc::clone(&request_headers)))
|
||||
.await?;
|
||||
Ok(Self {
|
||||
state: Mutex::new(ClientState::Connecting {
|
||||
transport: Some(transport),
|
||||
@@ -528,6 +570,7 @@ impl RmcpClient {
|
||||
transport_recipe,
|
||||
initialize_context: Mutex::new(None),
|
||||
session_recovery_lock: Mutex::new(()),
|
||||
request_headers: Some(request_headers),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -830,6 +873,7 @@ impl RmcpClient {
|
||||
|
||||
async fn create_pending_transport(
|
||||
transport_recipe: &TransportRecipe,
|
||||
request_headers: Option<Arc<StdMutex<Option<HeaderMap>>>>,
|
||||
) -> Result<PendingTransport> {
|
||||
match transport_recipe {
|
||||
TransportRecipe::Stdio {
|
||||
@@ -946,7 +990,12 @@ impl RmcpClient {
|
||||
.auth_header(access_token);
|
||||
let http_client = build_http_client(&default_headers)?;
|
||||
let transport = StreamableHttpClientTransport::with_client(
|
||||
StreamableHttpResponseClient::new(http_client),
|
||||
StreamableHttpResponseClient::new(
|
||||
http_client,
|
||||
request_headers
|
||||
.clone()
|
||||
.unwrap_or_else(|| Arc::new(StdMutex::new(None))),
|
||||
),
|
||||
http_config,
|
||||
);
|
||||
Ok(PendingTransport::StreamableHttp { transport })
|
||||
@@ -963,7 +1012,12 @@ impl RmcpClient {
|
||||
let http_client = build_http_client(&default_headers)?;
|
||||
|
||||
let transport = StreamableHttpClientTransport::with_client(
|
||||
StreamableHttpResponseClient::new(http_client),
|
||||
StreamableHttpResponseClient::new(
|
||||
http_client,
|
||||
request_headers
|
||||
.clone()
|
||||
.unwrap_or_else(|| Arc::new(StdMutex::new(None))),
|
||||
),
|
||||
http_config,
|
||||
);
|
||||
Ok(PendingTransport::StreamableHttp { transport })
|
||||
@@ -1111,7 +1165,9 @@ impl RmcpClient {
|
||||
.await
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow!("MCP client cannot recover before initialize succeeds"))?;
|
||||
let pending_transport = Self::create_pending_transport(&self.transport_recipe).await?;
|
||||
let pending_transport =
|
||||
Self::create_pending_transport(&self.transport_recipe, self.request_headers.clone())
|
||||
.await?;
|
||||
let (service, oauth_persistor, process_group_guard) = Self::connect_pending_transport(
|
||||
pending_transport,
|
||||
initialize_context.handler,
|
||||
@@ -1166,7 +1222,10 @@ async fn create_oauth_transport_and_runtime(
|
||||
}
|
||||
};
|
||||
|
||||
let auth_client = AuthClient::new(StreamableHttpResponseClient::new(http_client), manager);
|
||||
let auth_client = AuthClient::new(
|
||||
StreamableHttpResponseClient::new(http_client, Arc::new(StdMutex::new(None))),
|
||||
manager,
|
||||
);
|
||||
let auth_manager = auth_client.auth_manager.clone();
|
||||
|
||||
let transport = StreamableHttpClientTransport::with_client(
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use std::net::TcpListener;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
@@ -77,6 +79,7 @@ async fn create_client(base_url: &str) -> anyhow::Result<RmcpClient> {
|
||||
None,
|
||||
None,
|
||||
OAuthCredentialsStoreMode::File,
|
||||
Arc::new(StdMutex::new(None)),
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
//! on every keystroke, and drops the session when the query becomes empty.
|
||||
|
||||
use codex_file_search as file_search;
|
||||
use std::num::NonZero;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
@@ -13,6 +14,8 @@ use std::sync::Mutex;
|
||||
use crate::app_event::AppEvent;
|
||||
use crate::app_event_sender::AppEventSender;
|
||||
|
||||
const FILE_SEARCH_LIMIT: usize = 50;
|
||||
|
||||
pub(crate) struct FileSearchManager {
|
||||
state: Arc<Mutex<SearchState>>,
|
||||
search_dir: PathBuf,
|
||||
@@ -83,6 +86,8 @@ impl FileSearchManager {
|
||||
let session = file_search::create_session(
|
||||
vec![self.search_dir.clone()],
|
||||
file_search::FileSearchOptions {
|
||||
#[expect(clippy::unwrap_used)]
|
||||
limit: NonZero::new(FILE_SEARCH_LIMIT).unwrap(),
|
||||
compute_indices: true,
|
||||
..Default::default()
|
||||
},
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
//! on every keystroke, and drops the session when the query becomes empty.
|
||||
|
||||
use codex_file_search as file_search;
|
||||
use std::num::NonZero;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
@@ -13,6 +14,8 @@ use std::sync::Mutex;
|
||||
use crate::app_event::AppEvent;
|
||||
use crate::app_event_sender::AppEventSender;
|
||||
|
||||
const FILE_SEARCH_LIMIT: usize = 50;
|
||||
|
||||
pub(crate) struct FileSearchManager {
|
||||
state: Arc<Mutex<SearchState>>,
|
||||
search_dir: PathBuf,
|
||||
@@ -83,6 +86,8 @@ impl FileSearchManager {
|
||||
let session = file_search::create_session(
|
||||
vec![self.search_dir.clone()],
|
||||
file_search::FileSearchOptions {
|
||||
#[expect(clippy::unwrap_used)]
|
||||
limit: NonZero::new(FILE_SEARCH_LIMIT).unwrap(),
|
||||
compute_indices: true,
|
||||
..Default::default()
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user