This commit is contained in:
Ahmed Ibrahim
2025-08-02 19:09:13 -07:00
parent 9805ad1fbc
commit 324926e240
8 changed files with 83 additions and 125 deletions

View File

@@ -83,6 +83,7 @@ pub async fn run_conversation_loop(
};
// Buffer all events to include in InitialState when streaming is enabled
// TODO: this should be expanded to load sessions from the disk.
let mut buffered_events: Vec<CodexEventNotificationParams> = Vec::new();
let mut streaming_enabled = *stream_rx.borrow();

View File

@@ -14,6 +14,7 @@ use crate::outgoing_message::OutgoingMessageSender;
use crate::tool_handlers::create_conversation::handle_create_conversation;
use crate::tool_handlers::send_message::handle_send_message;
use crate::tool_handlers::stream_conversation;
use crate::tool_handlers::stream_conversation::handle_stream_conversation;
use codex_core::Codex;
use codex_core::config::Config as CodexConfig;
@@ -380,10 +381,7 @@ impl MessageProcessor {
handle_send_message(self, request_id, args).await;
}
ToolCallRequestParams::ConversationStream(args) => {
crate::tool_handlers::stream_conversation::handle_stream_conversation(
self, request_id, args,
)
.await;
handle_stream_conversation(self, request_id, args).await;
}
_ => {
let result = CallToolResult {
@@ -615,32 +613,80 @@ impl MessageProcessor {
params: <mcp_types::CancelledNotification as mcp_types::ModelContextProtocolNotification>::Params,
) {
let request_id = params.request_id;
// First, route cancellation for tracked tool calls (e.g., ConversationStream)
if let Some(orig) = {
let mut guard = self.tool_request_map.lock().await;
guard.remove(&request_id)
} {
match orig {
ToolCallRequestParams::ConversationStream(args) => {
stream_conversation::handle_cancel(self, &args).await;
return;
}
_ => {
// TODO: Implement later. Things like interrupt.
self.handle_mcp_protocol_cancelled_notification(request_id, orig)
.await;
} else {
self.handle_legacy_cancelled_notification(request_id).await;
}
}
async fn handle_mcp_protocol_cancelled_notification(
&self,
request_id: RequestId,
orig: ToolCallRequestParams,
) {
match orig {
ToolCallRequestParams::ConversationStream(args) => {
stream_conversation::handle_cancel(self, &args).await;
}
ToolCallRequestParams::ConversationSendMessage(args) => {
// Cancel in-flight user input for this conversation by interrupting
// the submission with the same request id we used when sending.
let request_id_string = match &request_id {
RequestId::String(s) => s.clone(),
RequestId::Integer(i) => i.to_string(),
};
let session_id = args.conversation_id.0;
let codex_arc = {
let sessions_guard = self.session_map.lock().await;
match sessions_guard.get(&session_id) {
Some(codex) => Arc::clone(codex),
None => {
tracing::warn!(
"Cancel send_message: session not found for session_id: {session_id}"
);
return;
}
}
};
if let Err(e) = codex_arc
.submit_with_id(Submission {
id: request_id_string,
op: codex_core::protocol::Op::Interrupt,
})
.await
{
tracing::error!("Failed to submit interrupt for send_message cancel: {e}");
}
}
ToolCallRequestParams::ConversationCreate(_)
| ToolCallRequestParams::ConversationsList(_) => {
// Likely fast/non-streaming; nothing to cancel currently.
tracing::debug!(
"Cancel conversationsList received for request_id: {:?} (no-op)",
request_id
);
}
}
// Create a stable string form early for logging and submission id.
}
async fn handle_legacy_cancelled_notification(&self, request_id: RequestId) {
let request_id_string = match &request_id {
RequestId::String(s) => s.clone(),
RequestId::Integer(i) => i.to_string(),
};
// Obtain the session_id while holding the first lock, then release.
let session_id = {
let map_guard = self.running_requests_id_to_codex_uuid.lock().await;
match map_guard.get(&request_id) {
Some(id) => *id, // Uuid is Copy
Some(id) => *id,
None => {
tracing::warn!("Session not found for request_id: {}", request_id_string);
return;
@@ -649,7 +695,6 @@ impl MessageProcessor {
};
tracing::info!("session_id: {session_id}");
// Obtain the Codex Arc while holding the session_map lock, then release.
let codex_arc = {
let sessions_guard = self.session_map.lock().await;
match sessions_guard.get(&session_id) {
@@ -661,18 +706,17 @@ impl MessageProcessor {
}
};
// Submit interrupt to Codex.
let err = codex_arc
if let Err(e) = codex_arc
.submit_with_id(Submission {
id: request_id_string,
op: codex_core::protocol::Op::Interrupt,
})
.await;
if let Err(e) = err {
.await
{
tracing::error!("Failed to submit interrupt to Codex: {e}");
return;
}
// unregister the id so we don't keep it in the map
self.running_requests_id_to_codex_uuid
.lock()
.await

View File

@@ -1,7 +1,9 @@
mod config;
mod mcp_process;
mod mock_model_server;
mod responses;
pub use config::create_config_toml;
pub use mcp_process::McpProcess;
pub use mock_model_server::create_mock_chat_completions_server;
pub use responses::create_apply_patch_sse_response;

View File

@@ -2,6 +2,7 @@ use std::path::Path;
use std::process::Stdio;
use std::sync::atomic::AtomicI64;
use std::sync::atomic::Ordering;
use std::time::Duration;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
@@ -295,7 +296,7 @@ impl McpProcess {
/// Wait for an agent_message with a bounded timeout. Returns Some(params) if received, None on timeout.
pub async fn maybe_wait_for_agent_message(
&mut self,
dur: std::time::Duration,
dur: Duration,
) -> anyhow::Result<Option<serde_json::Value>> {
match tokio::time::timeout(dur, self.wait_for_agent_message()).await {
Ok(Ok(v)) => Ok(Some(v)),

View File

@@ -1,8 +1,9 @@
#![allow(clippy::expect_used, clippy::unwrap_used)]
use std::path::Path;
use mcp_test_support::McpProcess;
use mcp_test_support::create_config_toml;
use mcp_test_support::create_final_assistant_message_sse_response;
use mcp_test_support::create_mock_chat_completions_server;
use mcp_types::JSONRPCResponse;
@@ -103,26 +104,4 @@ async fn test_conversation_create_and_send_message_ok() {
drop(server);
}
// Helper to create a config.toml pointing at the mock model server.
fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> {
let config_toml = codex_home.join("config.toml");
std::fs::write(
config_toml,
format!(
r#"
model = "mock-model"
approval_policy = "never"
sandbox_mode = "danger-full-access"
model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "chat"
request_max_retries = 0
stream_max_retries = 0
"#
),
)
}
// create_config_toml is provided by tests/common

View File

@@ -1,7 +1,7 @@
#![cfg(unix)]
// Support code lives in the `mcp_test_support` crate under tests/common.
use std::path::Path;
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
use codex_mcp_server::CodexToolCallParam;
@@ -12,6 +12,7 @@ use tempfile::TempDir;
use tokio::time::timeout;
use mcp_test_support::McpProcess;
use mcp_test_support::create_config_toml;
use mcp_test_support::create_mock_chat_completions_server;
use mcp_test_support::create_shell_sse_response;
@@ -66,7 +67,7 @@ async fn shell_command_interruption() -> anyhow::Result<()> {
// Create Codex configuration
let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), server.uri())?;
create_config_toml(codex_home.path(), &server.uri())?;
let mut mcp_process = McpProcess::new(codex_home.path()).await?;
timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??;
@@ -149,29 +150,4 @@ async fn shell_command_interruption() -> anyhow::Result<()> {
Ok(())
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
fn create_config_toml(codex_home: &Path, server_uri: String) -> std::io::Result<()> {
let config_toml = codex_home.join("config.toml");
std::fs::write(
config_toml,
format!(
r#"
model = "mock-model"
approval_policy = "never"
sandbox_mode = "danger-full-access"
model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "chat"
request_max_retries = 0
stream_max_retries = 0
"#
),
)
}
// Helpers are provided by tests/common

View File

@@ -1,11 +1,12 @@
#![allow(clippy::expect_used)]
use std::path::Path;
use std::thread::sleep;
use std::time::Duration;
use codex_mcp_server::CodexToolCallParam;
use mcp_test_support::McpProcess;
use mcp_test_support::create_config_toml;
use mcp_test_support::create_final_assistant_message_sse_response;
use mcp_test_support::create_mock_chat_completions_server;
use mcp_types::JSONRPC_VERSION;
@@ -135,29 +136,4 @@ async fn test_send_message_session_not_found() {
assert_eq!(result["isError"], json!(true));
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> {
let config_toml = codex_home.join("config.toml");
std::fs::write(
config_toml,
format!(
r#"
model = "mock-model"
approval_policy = "never"
sandbox_mode = "danger-full-access"
model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "chat"
request_max_retries = 0
stream_max_retries = 0
"#
),
)
}
// Helpers are provided by tests/common

View File

@@ -1,8 +1,9 @@
#![allow(clippy::expect_used, clippy::unwrap_used)]
use std::path::Path;
use mcp_test_support::McpProcess;
use mcp_test_support::create_config_toml;
use mcp_test_support::create_final_assistant_message_sse_response;
use mcp_test_support::create_mock_chat_completions_server;
use mcp_types::JSONRPCNotification;
@@ -249,26 +250,4 @@ async fn test_cancel_stream_then_reconnect_catches_up_initial_state() {
drop(server);
}
// Helper to create a config.toml pointing at the mock model server.
fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> {
let config_toml = codex_home.join("config.toml");
std::fs::write(
config_toml,
format!(
r#"
model = "mock-model"
approval_policy = "never"
sandbox_mode = "danger-full-access"
model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "chat"
request_max_retries = 0
stream_max_retries = 0
"#
),
)
}
// create_config_toml is provided by tests/common