Files
codex/codex-rs/core/tests/suite/sqlite_state.rs
josiah-openai 937dd3812d Add supports_parallel_tool_calls flag to included mcps (#17667)
## Why

For more advanced MCP usage, we want the model to be able to emit
parallel MCP tool calls and have Codex execute eligible ones
concurrently, instead of forcing all MCP calls through the serial block.

The main design choice was where to thread the config. I made this
server-level because parallel safety depends on the MCP server
implementation. Codex reads the flag from `mcp_servers`, threads the
opted-in server names into `ToolRouter`, and checks the parsed
`ToolPayload::Mcp { server, .. }` at execution time. That avoids relying
on model-visible tool names, which can be incomplete in
deferred/search-tool paths or ambiguous for similarly named
servers/tools.

## What was added

Added `supports_parallel_tool_calls` for MCP servers.

Before:

```toml
[mcp_servers.docs]
command = "docs-server"
```

After:

```toml
[mcp_servers.docs]
command = "docs-server"
supports_parallel_tool_calls = true
```

MCP calls remain serial by default. Only tools from opted-in servers are
eligible to run in parallel. Docs also now warn to enable this only when
the server’s tools are safe to run concurrently, especially around
shared state or read/write races.

## Testing

Tested with a local stdio MCP server exposing real delay tools. The
model/Responses side was mocked only to deterministically emit two MCP
calls in the same turn.

Each test called `query_with_delay` and `query_with_delay_2` with `{
"seconds": 25 }`.

| Build/config | Observed | Wall time |
| --- | --- | --- |
| main with flag enabled | serial | `58.79s` |
| PR with flag enabled | parallel | `31.73s` |
| PR without flag | serial | `56.70s` |

PR with flag enabled showed both tools start before either completed;
main and PR-without-flag completed the first delay before starting the
second.

Also added an integration test.

Additional checks:

- `cargo test -p codex-tools` passed
- `cargo test -p codex-core
mcp_parallel_support_uses_exact_payload_server` passed
- `git diff --check` passed
2026-04-13 15:16:34 -07:00

511 lines
17 KiB
Rust

use anyhow::Result;
use codex_config::types::McpServerConfig;
use codex_config::types::McpServerTransportConfig;
use codex_features::Feature;
use codex_protocol::ThreadId;
use codex_protocol::dynamic_tools::DynamicToolSpec;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::Op;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::RolloutLine;
use codex_protocol::protocol::SandboxPolicy;
use codex_protocol::protocol::SessionMeta;
use codex_protocol::protocol::SessionMetaLine;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::UserMessageEvent;
use codex_protocol::user_input::UserInput;
use core_test_support::responses;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::ev_web_search_call_done;
use core_test_support::responses::mount_sse_once;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::stdio_server_bin;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use core_test_support::wait_for_event_match;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::collections::HashMap;
use std::fs;
use tokio::time::Duration;
use tracing_subscriber::prelude::*;
use uuid::Uuid;
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn new_thread_is_recorded_in_state_db() -> Result<()> {
let server = start_mock_server().await;
let mut builder = test_codex().with_config(|config| {
config
.features
.enable(Feature::Sqlite)
.expect("test config should allow feature update");
});
let test = builder.build(&server).await?;
let thread_id = test.session_configured.session_id;
let rollout_path = test.codex.rollout_path().expect("rollout path");
let db_path = codex_state::state_db_path(test.config.sqlite_home.as_path());
for _ in 0..100 {
if tokio::fs::try_exists(&db_path).await.unwrap_or(false) {
break;
}
tokio::time::sleep(Duration::from_millis(25)).await;
}
let db = test.codex.state_db().expect("state db enabled");
assert!(
!rollout_path.exists(),
"fresh thread rollout should not be materialized before first user message"
);
let initial_metadata = db.get_thread(thread_id).await?;
assert!(
initial_metadata.is_none(),
"fresh thread should not be recorded in state db before first user message"
);
test.submit_turn("materialize rollout").await?;
let mut metadata = None;
for _ in 0..100 {
metadata = db.get_thread(thread_id).await?;
if metadata.is_some() {
break;
}
tokio::time::sleep(Duration::from_millis(25)).await;
}
let metadata = metadata.expect("thread should exist in state db");
assert_eq!(metadata.id, thread_id);
assert_eq!(metadata.rollout_path, rollout_path.to_path_buf());
assert!(
rollout_path.exists(),
"rollout should be materialized after first user message"
);
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn backfill_scans_existing_rollouts() -> Result<()> {
let server = start_mock_server().await;
let uuid = Uuid::now_v7();
let thread_id = ThreadId::from_string(&uuid.to_string())?;
let rollout_rel_path = format!("sessions/2026/01/27/rollout-2026-01-27T12-00-00-{uuid}.jsonl");
let rollout_rel_path_for_hook = rollout_rel_path.clone();
let dynamic_tools = vec![
DynamicToolSpec {
name: "geo_lookup".to_string(),
description: "lookup a city".to_string(),
input_schema: json!({
"type": "object",
"required": ["city"],
"properties": { "city": { "type": "string" } }
}),
defer_loading: true,
},
DynamicToolSpec {
name: "weather_lookup".to_string(),
description: "lookup weather".to_string(),
input_schema: json!({
"type": "object",
"required": ["zip"],
"properties": { "zip": { "type": "string" } }
}),
defer_loading: false,
},
];
let dynamic_tools_for_hook = dynamic_tools.clone();
let mut builder = test_codex()
.with_pre_build_hook(move |codex_home| {
let rollout_path = codex_home.join(&rollout_rel_path_for_hook);
let parent = rollout_path
.parent()
.expect("rollout path should have parent");
fs::create_dir_all(parent).expect("should create rollout directory");
let session_meta_line = SessionMetaLine {
meta: SessionMeta {
id: thread_id,
forked_from_id: None,
timestamp: "2026-01-27T12:00:00Z".to_string(),
cwd: codex_home.to_path_buf(),
originator: "test".to_string(),
cli_version: "test".to_string(),
source: SessionSource::default(),
agent_path: None,
agent_nickname: None,
agent_role: None,
model_provider: None,
base_instructions: None,
dynamic_tools: Some(dynamic_tools_for_hook),
memory_mode: None,
},
git: None,
};
let lines = [
RolloutLine {
timestamp: "2026-01-27T12:00:00Z".to_string(),
item: RolloutItem::SessionMeta(session_meta_line),
},
RolloutLine {
timestamp: "2026-01-27T12:00:01Z".to_string(),
item: RolloutItem::EventMsg(EventMsg::UserMessage(UserMessageEvent {
message: "hello from backfill".to_string(),
images: None,
local_images: Vec::new(),
text_elements: Vec::new(),
})),
},
];
let jsonl = lines
.iter()
.map(|line| serde_json::to_string(line).expect("rollout line should serialize"))
.collect::<Vec<_>>()
.join("\n");
fs::write(&rollout_path, format!("{jsonl}\n")).expect("should write rollout file");
})
.with_config(|config| {
config
.features
.enable(Feature::Sqlite)
.expect("test config should allow feature update");
});
let test = builder.build(&server).await?;
let db_path = codex_state::state_db_path(test.config.sqlite_home.as_path());
let rollout_path = test.config.codex_home.join(&rollout_rel_path);
let default_provider = test.config.model_provider_id.clone();
for _ in 0..20 {
if tokio::fs::try_exists(&db_path).await.unwrap_or(false) {
break;
}
tokio::time::sleep(Duration::from_millis(25)).await;
}
let db = test.codex.state_db().expect("state db enabled");
let mut metadata = None;
for _ in 0..40 {
metadata = db.get_thread(thread_id).await?;
if metadata.is_some() {
break;
}
tokio::time::sleep(Duration::from_millis(25)).await;
}
let metadata = metadata.expect("backfilled thread should exist in state db");
assert_eq!(metadata.id, thread_id);
assert_eq!(metadata.rollout_path, rollout_path.to_path_buf());
assert_eq!(metadata.model_provider, default_provider);
assert!(metadata.first_user_message.is_some());
let mut stored_tools = None;
for _ in 0..40 {
stored_tools = db.get_dynamic_tools(thread_id).await?;
if stored_tools.is_some() {
break;
}
tokio::time::sleep(Duration::from_millis(25)).await;
}
let stored_tools = stored_tools.expect("dynamic tools should be stored");
assert_eq!(stored_tools, dynamic_tools);
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn user_messages_persist_in_state_db() -> Result<()> {
let server = start_mock_server().await;
mount_sse_sequence(
&server,
vec![
responses::sse(vec![ev_response_created("resp-1"), ev_completed("resp-1")]),
responses::sse(vec![ev_response_created("resp-2"), ev_completed("resp-2")]),
],
)
.await;
let mut builder = test_codex().with_config(|config| {
config
.features
.enable(Feature::Sqlite)
.expect("test config should allow feature update");
});
let test = builder.build(&server).await?;
let db_path = codex_state::state_db_path(test.config.sqlite_home.as_path());
for _ in 0..100 {
if tokio::fs::try_exists(&db_path).await.unwrap_or(false) {
break;
}
tokio::time::sleep(Duration::from_millis(25)).await;
}
test.submit_turn("hello from sqlite").await?;
test.submit_turn("another message").await?;
let db = test.codex.state_db().expect("state db enabled");
let thread_id = test.session_configured.session_id;
let mut metadata = None;
for _ in 0..100 {
metadata = db.get_thread(thread_id).await?;
if metadata
.as_ref()
.map(|entry| entry.first_user_message.is_some())
.unwrap_or(false)
{
break;
}
tokio::time::sleep(Duration::from_millis(25)).await;
}
let metadata = metadata.expect("thread should exist in state db");
assert!(metadata.first_user_message.is_some());
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn web_search_marks_thread_memory_mode_polluted_when_configured() -> Result<()> {
let server = start_mock_server().await;
mount_sse_sequence(
&server,
vec![responses::sse(vec![
ev_response_created("resp-1"),
ev_web_search_call_done("ws-1", "completed", "weather seattle"),
ev_completed("resp-1"),
])],
)
.await;
let mut builder = test_codex().with_config(|config| {
config
.features
.enable(Feature::Sqlite)
.expect("test config should allow feature update");
config.memories.no_memories_if_mcp_or_web_search = true;
});
let test = builder.build(&server).await?;
let db = test.codex.state_db().expect("state db enabled");
let thread_id = test.session_configured.session_id;
test.submit_turn("search the web").await?;
let mut memory_mode = None;
for _ in 0..100 {
memory_mode = db.get_thread_memory_mode(thread_id).await?;
if memory_mode.as_deref() == Some("polluted") {
break;
}
tokio::time::sleep(Duration::from_millis(25)).await;
}
assert_eq!(memory_mode.as_deref(), Some("polluted"));
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn mcp_call_marks_thread_memory_mode_polluted_when_configured() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = start_mock_server().await;
let call_id = "call-123";
let server_name = "rmcp";
let tool_name = format!("mcp__{server_name}__echo");
mount_sse_once(
&server,
responses::sse(vec![
ev_response_created("resp-1"),
ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"),
ev_completed("resp-1"),
]),
)
.await;
mount_sse_once(
&server,
responses::sse(vec![
responses::ev_assistant_message("msg-1", "rmcp echo tool completed."),
ev_completed("resp-2"),
]),
)
.await;
let rmcp_test_server_bin = stdio_server_bin()?;
let mut builder = test_codex().with_config(move |config| {
config
.features
.enable(Feature::Sqlite)
.expect("test config should allow feature update");
config.memories.no_memories_if_mcp_or_web_search = true;
let mut servers = config.mcp_servers.get().clone();
servers.insert(
server_name.to_string(),
McpServerConfig {
transport: McpServerTransportConfig::Stdio {
command: rmcp_test_server_bin,
args: Vec::new(),
env: Some(HashMap::from([(
"MCP_TEST_VALUE".to_string(),
"propagated-env".to_string(),
)])),
env_vars: Vec::new(),
cwd: None,
},
enabled: true,
required: false,
supports_parallel_tool_calls: false,
disabled_reason: None,
startup_timeout_sec: Some(Duration::from_secs(10)),
tool_timeout_sec: None,
enabled_tools: None,
disabled_tools: None,
scopes: None,
oauth_resource: None,
tools: HashMap::new(),
},
);
config
.mcp_servers
.set(servers)
.expect("test mcp servers should accept any configuration");
});
let test = builder.build(&server).await?;
let db = test.codex.state_db().expect("state db enabled");
let thread_id = test.session_configured.session_id;
test.codex
.submit(Op::UserTurn {
items: vec![UserInput::Text {
text: "call the rmcp echo tool".to_string(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
cwd: test.cwd_path().to_path_buf(),
approval_policy: AskForApproval::Never,
approvals_reviewer: None,
sandbox_policy: SandboxPolicy::new_read_only_policy(),
model: test.session_configured.model.clone(),
effort: None,
summary: None,
service_tier: None,
collaboration_mode: None,
personality: None,
})
.await?;
wait_for_event(&test.codex, |event| {
matches!(event, EventMsg::McpToolCallEnd(_))
})
.await;
wait_for_event_match(&test.codex, |event| match event {
EventMsg::Error(err) => Some(Err(anyhow::anyhow!(err.message.clone()))),
EventMsg::TurnComplete(_) => Some(Ok(())),
_ => None,
})
.await?;
let mut memory_mode = None;
for _ in 0..100 {
memory_mode = db.get_thread_memory_mode(thread_id).await?;
if memory_mode.as_deref() == Some("polluted") {
break;
}
tokio::time::sleep(Duration::from_millis(25)).await;
}
assert_eq!(memory_mode.as_deref(), Some("polluted"));
Ok(())
}
#[tokio::test(flavor = "current_thread")]
async fn tool_call_logs_include_thread_id() -> Result<()> {
let server = start_mock_server().await;
let call_id = "call-1";
let args = json!({
"command": "echo hello",
"timeout_ms": 1_000,
"login": false,
});
let args_json = serde_json::to_string(&args)?;
mount_sse_sequence(
&server,
vec![
responses::sse(vec![
ev_response_created("resp-1"),
ev_function_call(call_id, "shell_command", &args_json),
ev_completed("resp-1"),
]),
responses::sse(vec![ev_completed("resp-2")]),
],
)
.await;
let mut builder = test_codex().with_config(|config| {
config
.features
.enable(Feature::Sqlite)
.expect("test config should allow feature update");
});
let test = builder.build(&server).await?;
let db = test.codex.state_db().expect("state db enabled");
let expected_thread_id = test.session_configured.session_id.to_string();
test.submit_turn("run a shell command").await?;
let log_db_layer = codex_state::log_db::start(db.clone());
let subscriber = tracing_subscriber::registry().with(log_db_layer.clone());
let dispatch = tracing::Dispatch::new(subscriber);
tracing::dispatcher::with_default(&dispatch, || {
let span = tracing::info_span!("test_log_span", thread_id = %expected_thread_id);
let _entered = span.enter();
tracing::info!("ToolCall: shell_command {{\"command\":\"echo hello\"}}");
});
log_db_layer.flush().await;
let mut found = None;
for _ in 0..80 {
let query = codex_state::LogQuery {
descending: true,
limit: Some(20),
..Default::default()
};
let rows = db.query_logs(&query).await?;
if let Some(row) = rows.into_iter().find(|row| {
row.message
.as_deref()
.is_some_and(|m| m.contains("ToolCall:"))
}) {
let thread_id = row.thread_id;
let message = row.message;
found = Some((thread_id, message));
break;
}
tokio::time::sleep(Duration::from_millis(25)).await;
}
let (thread_id, message) = found.expect("expected ToolCall log row");
assert_eq!(thread_id, Some(expected_thread_id));
assert!(
message
.as_deref()
.is_some_and(|text| text.contains("ToolCall:")),
"expected ToolCall message, got {message:?}"
);
Ok(())
}