mirror of
https://github.com/openai/codex.git
synced 2026-04-24 06:35:50 +00:00
stream init
This commit is contained in:
@@ -1,6 +1,11 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::exec_approval::handle_exec_approval_request;
|
||||
use crate::mcp_protocol::CodexEventNotificationParams;
|
||||
use crate::mcp_protocol::ConversationId;
|
||||
use crate::mcp_protocol::InitialStateNotificationParams;
|
||||
use crate::mcp_protocol::InitialStatePayload;
|
||||
use crate::mcp_protocol::NotificationMeta;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::outgoing_message::OutgoingNotificationMeta;
|
||||
use crate::patch_approval::handle_patch_approval_request;
|
||||
@@ -10,48 +15,63 @@ use codex_core::protocol::ApplyPatchApprovalRequestEvent;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::ExecApprovalRequestEvent;
|
||||
use mcp_types::RequestId;
|
||||
use tokio::sync::watch::Receiver as WatchReceiver;
|
||||
use tracing::error;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub async fn run_conversation_loop(
|
||||
codex: Arc<Codex>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
request_id: RequestId,
|
||||
mut stream_rx: WatchReceiver<bool>,
|
||||
session_id: Uuid,
|
||||
) {
|
||||
let request_id_str = match &request_id {
|
||||
RequestId::String(s) => s.clone(),
|
||||
RequestId::Integer(n) => n.to_string(),
|
||||
};
|
||||
|
||||
// Stream events until the task needs to pause for user interaction or
|
||||
// completes.
|
||||
loop {
|
||||
match codex.next_event().await {
|
||||
Ok(event) => {
|
||||
outgoing
|
||||
.send_event_as_notification(
|
||||
&event,
|
||||
Some(OutgoingNotificationMeta::new(Some(request_id.clone()))),
|
||||
)
|
||||
.await;
|
||||
// Buffer all events for InitialState
|
||||
let mut buffered_events: Vec<CodexEventNotificationParams> = Vec::new();
|
||||
let mut streaming_enabled = *stream_rx.borrow();
|
||||
|
||||
match event.msg {
|
||||
loop {
|
||||
tokio::select! {
|
||||
res = codex.next_event() => {
|
||||
match res {
|
||||
Ok(event) => {
|
||||
// Always buffer the event
|
||||
buffered_events.push(CodexEventNotificationParams { meta: None, msg: event.msg.clone() });
|
||||
|
||||
if streaming_enabled {
|
||||
outgoing
|
||||
.send_event_as_notification(
|
||||
&event,
|
||||
Some(OutgoingNotificationMeta::new(Some(request_id.clone()))),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
match event.msg {
|
||||
EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
|
||||
command,
|
||||
cwd,
|
||||
call_id,
|
||||
reason: _,
|
||||
}) => {
|
||||
handle_exec_approval_request(
|
||||
command,
|
||||
cwd,
|
||||
outgoing.clone(),
|
||||
codex.clone(),
|
||||
request_id.clone(),
|
||||
request_id_str.clone(),
|
||||
event.id.clone(),
|
||||
call_id,
|
||||
)
|
||||
.await;
|
||||
if streaming_enabled {
|
||||
handle_exec_approval_request(
|
||||
command,
|
||||
cwd,
|
||||
outgoing.clone(),
|
||||
codex.clone(),
|
||||
request_id.clone(),
|
||||
request_id_str.clone(),
|
||||
event.id.clone(),
|
||||
call_id,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
EventMsg::Error(_) => {
|
||||
@@ -63,18 +83,20 @@ pub async fn run_conversation_loop(
|
||||
grant_root,
|
||||
changes,
|
||||
}) => {
|
||||
handle_patch_approval_request(
|
||||
call_id,
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
outgoing.clone(),
|
||||
codex.clone(),
|
||||
request_id.clone(),
|
||||
request_id_str.clone(),
|
||||
event.id.clone(),
|
||||
)
|
||||
.await;
|
||||
if streaming_enabled {
|
||||
handle_patch_approval_request(
|
||||
call_id,
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
outgoing.clone(),
|
||||
codex.clone(),
|
||||
request_id.clone(),
|
||||
request_id_str.clone(),
|
||||
event.id.clone(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
EventMsg::TaskComplete(_) => {}
|
||||
@@ -111,10 +133,35 @@ pub async fn run_conversation_loop(
|
||||
// though we may want to do give different treatment to
|
||||
// individual events in the future.
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Codex runtime error: {e}");
|
||||
}
|
||||
}
|
||||
},
|
||||
changed = stream_rx.changed() => {
|
||||
if changed.is_ok() {
|
||||
let now = *stream_rx.borrow();
|
||||
if now && !streaming_enabled {
|
||||
streaming_enabled = true;
|
||||
// Emit InitialState with all buffered events
|
||||
let params = InitialStateNotificationParams {
|
||||
meta: Some(NotificationMeta { conversation_id: Some(ConversationId(session_id)), request_id: Some(request_id.clone()) }),
|
||||
initial_state: InitialStatePayload { events: buffered_events.clone() },
|
||||
};
|
||||
if let Ok(params_val) = serde_json::to_value(¶ms) {
|
||||
outgoing
|
||||
.send_custom_notification("notifications/initial_state", params_val)
|
||||
.await;
|
||||
} else {
|
||||
error!("Failed to serialize InitialState params");
|
||||
}
|
||||
} else if !now && streaming_enabled {
|
||||
// streaming disabled
|
||||
streaming_enabled = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Codex runtime error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ use crate::mcp_protocol::ToolCallResponseResult;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::tool_handlers::create_conversation::handle_create_conversation;
|
||||
use crate::tool_handlers::send_message::handle_send_message;
|
||||
use crate::tool_handlers::stream_conversation;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::config::Config as CodexConfig;
|
||||
@@ -35,6 +36,7 @@ use mcp_types::ServerNotification;
|
||||
use mcp_types::TextContent;
|
||||
use serde_json::json;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::watch;
|
||||
use tokio::task;
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -45,6 +47,10 @@ pub(crate) struct MessageProcessor {
|
||||
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||
running_session_ids: Arc<Mutex<HashSet<Uuid>>>,
|
||||
// Per-session streaming state signal (true when client connected via ConversationStream)
|
||||
streaming_session_senders: Arc<Mutex<HashMap<Uuid, watch::Sender<bool>>>>,
|
||||
// Track request IDs to the original ToolCallRequestParams for cancellation handling
|
||||
tool_request_map: Arc<Mutex<HashMap<RequestId, ToolCallRequestParams>>>,
|
||||
}
|
||||
|
||||
impl MessageProcessor {
|
||||
@@ -61,6 +67,8 @@ impl MessageProcessor {
|
||||
session_map: Arc::new(Mutex::new(HashMap::new())),
|
||||
running_requests_id_to_codex_uuid: Arc::new(Mutex::new(HashMap::new())),
|
||||
running_session_ids: Arc::new(Mutex::new(HashSet::new())),
|
||||
streaming_session_senders: Arc::new(Mutex::new(HashMap::new())),
|
||||
tool_request_map: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,6 +84,12 @@ impl MessageProcessor {
|
||||
self.running_session_ids.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn streaming_session_senders(
|
||||
&self,
|
||||
) -> Arc<Mutex<HashMap<Uuid, watch::Sender<bool>>>> {
|
||||
self.streaming_session_senders.clone()
|
||||
}
|
||||
|
||||
pub(crate) async fn process_request(&mut self, request: JSONRPCRequest) {
|
||||
// Hold on to the ID so we can respond.
|
||||
let request_id = request.id.clone();
|
||||
@@ -353,6 +367,11 @@ impl MessageProcessor {
|
||||
}
|
||||
}
|
||||
async fn handle_new_tool_calls(&self, request_id: RequestId, params: ToolCallRequestParams) {
|
||||
// Track the request to allow graceful cancellation routing later.
|
||||
{
|
||||
let mut guard = self.tool_request_map.lock().await;
|
||||
guard.insert(request_id.clone(), params);
|
||||
}
|
||||
match params {
|
||||
ToolCallRequestParams::ConversationCreate(args) => {
|
||||
handle_create_conversation(self, request_id, args).await;
|
||||
@@ -360,6 +379,12 @@ impl MessageProcessor {
|
||||
ToolCallRequestParams::ConversationSendMessage(args) => {
|
||||
handle_send_message(self, request_id, args).await;
|
||||
}
|
||||
ToolCallRequestParams::ConversationStream(args) => {
|
||||
crate::tool_handlers::stream_conversation::handle_stream_conversation(
|
||||
self, request_id, args,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
_ => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
@@ -590,6 +615,21 @@ impl MessageProcessor {
|
||||
params: <mcp_types::CancelledNotification as mcp_types::ModelContextProtocolNotification>::Params,
|
||||
) {
|
||||
let request_id = params.request_id;
|
||||
// First, route cancellation for tracked tool calls (e.g., ConversationStream)
|
||||
if let Some(orig) = {
|
||||
let mut guard = self.tool_request_map.lock().await;
|
||||
guard.remove(&request_id)
|
||||
} {
|
||||
match orig {
|
||||
ToolCallRequestParams::ConversationStream(args) => {
|
||||
stream_conversation::handle_cancel(self, &args).await;
|
||||
return;
|
||||
}
|
||||
_ => {
|
||||
// TODO: Implement later. Things like interrupt.
|
||||
}
|
||||
}
|
||||
}
|
||||
// Create a stable string form early for logging and submission id.
|
||||
let request_id_string = match &request_id {
|
||||
RequestId::String(s) => s.clone(),
|
||||
|
||||
@@ -124,6 +124,15 @@ impl OutgoingMessageSender {
|
||||
let outgoing_message = OutgoingMessage::Error(OutgoingError { id, error });
|
||||
let _ = self.sender.send(outgoing_message).await;
|
||||
}
|
||||
|
||||
/// Send a custom notification with an explicit method name and params object.
|
||||
pub(crate) async fn send_custom_notification(&self, method: &str, params: serde_json::Value) {
|
||||
let outgoing_message = OutgoingMessage::Notification(OutgoingNotification {
|
||||
method: method.to_string(),
|
||||
params: Some(params),
|
||||
});
|
||||
let _ = self.sender.send(outgoing_message).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Outgoing message from the server to the client.
|
||||
|
||||
@@ -10,6 +10,7 @@ use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
use mcp_types::RequestId;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::watch;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::conversation_loop::run_conversation_loop;
|
||||
@@ -128,11 +129,19 @@ pub(crate) async fn handle_create_conversation(
|
||||
message_processor.session_map(),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Create per-session streaming control channel (initially disabled)
|
||||
let (stream_tx, stream_rx) = watch::channel(false);
|
||||
{
|
||||
let senders = message_processor.streaming_session_senders();
|
||||
let mut guard = senders.lock().await;
|
||||
guard.insert(session_id, stream_tx);
|
||||
}
|
||||
// Run the conversation loop in the background so this request can return immediately.
|
||||
let outgoing = message_processor.outgoing();
|
||||
let spawn_id = id.clone();
|
||||
tokio::spawn(async move {
|
||||
run_conversation_loop(codex_arc.clone(), outgoing, spawn_id).await;
|
||||
run_conversation_loop(codex_arc.clone(), outgoing, spawn_id, stream_rx, session_id).await;
|
||||
});
|
||||
|
||||
// Reply with the new conversation id and effective model
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
pub(crate) mod create_conversation;
|
||||
pub(crate) mod send_message;
|
||||
pub(crate) mod stream_conversation;
|
||||
|
||||
79
codex-rs/mcp-server/src/tool_handlers/stream_conversation.rs
Normal file
79
codex-rs/mcp-server/src/tool_handlers/stream_conversation.rs
Normal file
@@ -0,0 +1,79 @@
|
||||
use mcp_types::RequestId;
|
||||
|
||||
use crate::mcp_protocol::ConversationStreamArgs;
|
||||
use crate::mcp_protocol::ConversationStreamResult;
|
||||
use crate::mcp_protocol::ToolCallResponseResult;
|
||||
use crate::message_processor::MessageProcessor;
|
||||
use crate::tool_handlers::send_message::get_session;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub(crate) async fn handle_stream_conversation(
|
||||
message_processor: &MessageProcessor,
|
||||
id: RequestId,
|
||||
arguments: ConversationStreamArgs,
|
||||
) {
|
||||
let ConversationStreamArgs { conversation_id } = arguments;
|
||||
|
||||
let session_id = conversation_id.0;
|
||||
|
||||
// Ensure the session exists
|
||||
let session_exists = get_session(session_id, message_processor.session_map())
|
||||
.await
|
||||
.is_some();
|
||||
|
||||
if !session_exists {
|
||||
// Return an error with no result payload per MCP error pattern
|
||||
message_processor
|
||||
.send_response_with_optional_error(id, None, Some(true))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
// Toggle streaming to enabled via the per-session watch channel
|
||||
let senders_map = message_processor.streaming_session_senders();
|
||||
let tx = {
|
||||
let guard = senders_map.lock().await;
|
||||
guard.get(&session_id).cloned()
|
||||
};
|
||||
match tx {
|
||||
Some(tx) => {
|
||||
let _ = tx.send(true);
|
||||
}
|
||||
None => {
|
||||
// No channel found for the session; treat as error
|
||||
message_processor
|
||||
.send_response_with_optional_error(id, None, Some(true))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Acknowledge the stream request
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationStream(
|
||||
ConversationStreamResult {},
|
||||
)),
|
||||
Some(false),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_cancel(
|
||||
message_processor: &MessageProcessor,
|
||||
args: &ConversationStreamArgs,
|
||||
) {
|
||||
disable_stream_for_session(message_processor, args.conversation_id.0).await;
|
||||
}
|
||||
|
||||
async fn disable_stream_for_session(message_processor: &MessageProcessor, session_id: Uuid) {
|
||||
let sender_opt: Option<tokio::sync::watch::Sender<bool>> = {
|
||||
let senders = message_processor.streaming_session_senders();
|
||||
let guard = senders.lock().await;
|
||||
guard.get(&session_id).cloned()
|
||||
};
|
||||
if let Some(tx) = sender_opt {
|
||||
let _ = tx.send(false);
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,7 @@ use codex_mcp_server::CodexToolCallReplyParam;
|
||||
use codex_mcp_server::mcp_protocol::ConversationCreateArgs;
|
||||
use codex_mcp_server::mcp_protocol::ConversationId;
|
||||
use codex_mcp_server::mcp_protocol::ConversationSendMessageArgs;
|
||||
use codex_mcp_server::mcp_protocol::ConversationStreamArgs;
|
||||
use codex_mcp_server::mcp_protocol::ToolCallRequestParams;
|
||||
|
||||
use mcp_types::CallToolRequestParams;
|
||||
@@ -201,6 +202,20 @@ impl McpProcess {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn send_conversation_stream_tool_call(
|
||||
&mut self,
|
||||
session_id: &str,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = ToolCallRequestParams::ConversationStream(ConversationStreamArgs {
|
||||
conversation_id: ConversationId(Uuid::parse_str(session_id)?),
|
||||
});
|
||||
self.send_request(
|
||||
mcp_types::CallToolRequest::METHOD,
|
||||
Some(serde_json::to_value(params)?),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn send_conversation_create_tool_call(
|
||||
&mut self,
|
||||
prompt: &str,
|
||||
@@ -236,6 +251,99 @@ impl McpProcess {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Create a conversation and return its conversation_id as a string.
|
||||
pub async fn create_conversation_and_get_id(
|
||||
&mut self,
|
||||
prompt: &str,
|
||||
model: &str,
|
||||
cwd: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
let req_id = self
|
||||
.send_conversation_create_tool_call(prompt, model, cwd)
|
||||
.await?;
|
||||
let resp = self
|
||||
.read_stream_until_response_message(RequestId::Integer(req_id))
|
||||
.await?;
|
||||
let conv_id = resp.result["structuredContent"]["conversation_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::format_err!("missing conversation_id"))?
|
||||
.to_string();
|
||||
Ok(conv_id)
|
||||
}
|
||||
|
||||
/// Connect stream for a conversation and wait for the initial_state notification.
|
||||
/// Returns the params of the initial_state notification for further inspection.
|
||||
pub async fn connect_stream_and_expect_initial_state(
|
||||
&mut self,
|
||||
session_id: &str,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let req_id = self.send_conversation_stream_tool_call(session_id).await?;
|
||||
// Wait for stream() tool-call response first
|
||||
let _ = self
|
||||
.read_stream_until_response_message(RequestId::Integer(req_id))
|
||||
.await?;
|
||||
// Then the initial_state notification
|
||||
let note = self
|
||||
.read_stream_until_notification_method("notifications/initial_state")
|
||||
.await?;
|
||||
note.params
|
||||
.ok_or_else(|| anyhow::format_err!("initial_state must have params"))
|
||||
}
|
||||
|
||||
/// Connect stream and also return the request id for later cancellation.
|
||||
pub async fn connect_stream_get_req_and_initial_state(
|
||||
&mut self,
|
||||
session_id: &str,
|
||||
) -> anyhow::Result<(i64, serde_json::Value)> {
|
||||
let req_id = self.send_conversation_stream_tool_call(session_id).await?;
|
||||
let _ = self
|
||||
.read_stream_until_response_message(RequestId::Integer(req_id))
|
||||
.await?;
|
||||
let note = self
|
||||
.read_stream_until_notification_method("notifications/initial_state")
|
||||
.await?;
|
||||
let params = note
|
||||
.params
|
||||
.ok_or_else(|| anyhow::format_err!("initial_state must have params"))?;
|
||||
Ok((req_id, params))
|
||||
}
|
||||
|
||||
/// Wait for an agent_message with a bounded timeout. Returns Some(params) if received, None on timeout.
|
||||
pub async fn maybe_wait_for_agent_message(
|
||||
&mut self,
|
||||
dur: std::time::Duration,
|
||||
) -> anyhow::Result<Option<serde_json::Value>> {
|
||||
match tokio::time::timeout(dur, self.wait_for_agent_message()).await {
|
||||
Ok(Ok(v)) => Ok(Some(v)),
|
||||
Ok(Err(e)) => Err(e),
|
||||
Err(_elapsed) => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a user message to a conversation and wait for the OK tool-call response.
|
||||
pub async fn send_user_message_and_wait_ok(
|
||||
&mut self,
|
||||
message: &str,
|
||||
session_id: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let req_id = self
|
||||
.send_user_message_tool_call(message, session_id)
|
||||
.await?;
|
||||
let _ = self
|
||||
.read_stream_until_response_message(RequestId::Integer(req_id))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Wait until an agent_message notification arrives; returns its params.
|
||||
pub async fn wait_for_agent_message(&mut self) -> anyhow::Result<serde_json::Value> {
|
||||
let note = self
|
||||
.read_stream_until_notification_method("agent_message")
|
||||
.await?;
|
||||
note.params
|
||||
.ok_or_else(|| anyhow::format_err!("agent_message missing params"))
|
||||
}
|
||||
|
||||
async fn send_request(
|
||||
&mut self,
|
||||
method: &str,
|
||||
@@ -329,6 +437,31 @@ impl McpProcess {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_stream_until_notification_method(
|
||||
&mut self,
|
||||
method: &str,
|
||||
) -> anyhow::Result<JSONRPCNotification> {
|
||||
loop {
|
||||
let message = self.read_jsonrpc_message().await?;
|
||||
match message {
|
||||
JSONRPCMessage::Notification(n) => {
|
||||
if n.method == method {
|
||||
return Ok(n);
|
||||
}
|
||||
}
|
||||
JSONRPCMessage::Request(_) => {
|
||||
// ignore
|
||||
}
|
||||
JSONRPCMessage::Error(_) => {
|
||||
anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}");
|
||||
}
|
||||
JSONRPCMessage::Response(_) => {
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_stream_until_configured_response_message(
|
||||
&mut self,
|
||||
) -> anyhow::Result<String> {
|
||||
|
||||
228
codex-rs/mcp-server/tests/stream_conversation.rs
Normal file
228
codex-rs/mcp-server/tests/stream_conversation.rs
Normal file
@@ -0,0 +1,228 @@
|
||||
#![allow(clippy::expect_used, clippy::unwrap_used)]
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use mcp_test_support::McpProcess;
|
||||
use mcp_test_support::create_final_assistant_message_sse_response;
|
||||
use mcp_test_support::create_mock_chat_completions_server;
|
||||
use mcp_types::JSONRPCNotification;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::timeout;
|
||||
|
||||
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_connect_then_send_receives_initial_state_and_notifications() {
|
||||
let responses = vec![
|
||||
create_final_assistant_message_sse_response("Done").expect("build mock assistant message"),
|
||||
];
|
||||
let server = create_mock_chat_completions_server(responses).await;
|
||||
|
||||
let codex_home = TempDir::new().expect("create temp dir");
|
||||
create_config_toml(codex_home.path(), &server.uri()).expect("write config.toml");
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path())
|
||||
.await
|
||||
.expect("spawn mcp process");
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
|
||||
.await
|
||||
.expect("init timeout")
|
||||
.expect("init failed");
|
||||
|
||||
// Create conversation
|
||||
let conv_id = mcp
|
||||
.create_conversation_and_get_id("", "o3", "/repo")
|
||||
.await
|
||||
.expect("create conversation");
|
||||
|
||||
// Connect the stream
|
||||
let params = mcp
|
||||
.connect_stream_and_expect_initial_state(&conv_id)
|
||||
.await
|
||||
.expect("initial_state params");
|
||||
assert_eq!(
|
||||
params["_meta"]["conversationId"].as_str(),
|
||||
Some(conv_id.as_str())
|
||||
);
|
||||
assert!(params["initial_state"]["events"].is_array());
|
||||
assert!(
|
||||
params["initial_state"]["events"]
|
||||
.as_array()
|
||||
.unwrap()
|
||||
.is_empty()
|
||||
);
|
||||
|
||||
// Send a message and expect a subsequent notification (non-initial_state)
|
||||
mcp.send_user_message_and_wait_ok("Hello there", &conv_id)
|
||||
.await
|
||||
.expect("send message ok");
|
||||
|
||||
// Read until we see an event notification (new schema example: agent_message)
|
||||
let params = mcp.wait_for_agent_message().await.expect("agent message");
|
||||
assert_eq!(params["msg"]["type"].as_str(), Some("agent_message"));
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_send_then_connect_receives_initial_state_with_message() {
|
||||
let responses = vec![
|
||||
create_final_assistant_message_sse_response("Done").expect("build mock assistant message"),
|
||||
];
|
||||
let server = create_mock_chat_completions_server(responses).await;
|
||||
|
||||
let codex_home = TempDir::new().expect("create temp dir");
|
||||
create_config_toml(codex_home.path(), &server.uri()).expect("write config.toml");
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path())
|
||||
.await
|
||||
.expect("spawn mcp process");
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
|
||||
.await
|
||||
.expect("init timeout")
|
||||
.expect("init failed");
|
||||
|
||||
// Create conversation
|
||||
let conv_id = mcp
|
||||
.create_conversation_and_get_id("", "o3", "/repo")
|
||||
.await
|
||||
.expect("create conversation");
|
||||
|
||||
// Send a message BEFORE connecting stream
|
||||
mcp.send_user_message_and_wait_ok("Hello world", &conv_id)
|
||||
.await
|
||||
.expect("send message ok");
|
||||
|
||||
// Now connect stream and expect InitialState with the prior message included
|
||||
let params = mcp
|
||||
.connect_stream_and_expect_initial_state(&conv_id)
|
||||
.await
|
||||
.expect("initial_state params");
|
||||
let events = params["initial_state"]["events"]
|
||||
.as_array()
|
||||
.expect("events array");
|
||||
if !events.iter().any(|ev| {
|
||||
ev.get("msg")
|
||||
.and_then(|m| m.get("type"))
|
||||
.and_then(|t| t.as_str())
|
||||
== Some("agent_message")
|
||||
&& ev
|
||||
.get("msg")
|
||||
.and_then(|m| m.get("message"))
|
||||
.and_then(|t| t.as_str())
|
||||
== Some("Done")
|
||||
}) {
|
||||
// Fallback to live notification if not present in initial state
|
||||
let note: JSONRPCNotification = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_notification_method("agent_message"),
|
||||
)
|
||||
.await
|
||||
.expect("event note timeout")
|
||||
.expect("event note err");
|
||||
let params = note.params.expect("params");
|
||||
assert_eq!(params["msg"]["type"].as_str(), Some("agent_message"));
|
||||
assert_eq!(params["msg"]["message"].as_str(), Some("Done"));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_cancel_stream_then_reconnect_catches_up_initial_state() {
|
||||
// One response is sufficient for the assertions in this test
|
||||
let responses = vec![
|
||||
create_final_assistant_message_sse_response("Done").expect("build mock assistant message"),
|
||||
];
|
||||
let server = create_mock_chat_completions_server(responses).await;
|
||||
|
||||
let codex_home = TempDir::new().expect("create temp dir");
|
||||
create_config_toml(codex_home.path(), &server.uri()).expect("write config.toml");
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path())
|
||||
.await
|
||||
.expect("spawn mcp process");
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
|
||||
.await
|
||||
.expect("init timeout")
|
||||
.expect("init failed");
|
||||
|
||||
// Create and connect stream A
|
||||
let conv_id = mcp
|
||||
.create_conversation_and_get_id("", "o3", "/repo")
|
||||
.await
|
||||
.expect("create");
|
||||
let (stream_a_id, _params) = mcp
|
||||
.connect_stream_get_req_and_initial_state(&conv_id)
|
||||
.await
|
||||
.expect("stream A initial_state");
|
||||
|
||||
// Send M1 and ensure we get live agent_message
|
||||
mcp.send_user_message_and_wait_ok("Hello M1", &conv_id)
|
||||
.await
|
||||
.expect("send M1");
|
||||
let _params = mcp.wait_for_agent_message().await.expect("agent M1");
|
||||
|
||||
// Cancel stream A
|
||||
mcp.send_notification(
|
||||
"notifications/cancelled",
|
||||
Some(json!({ "requestId": stream_a_id })),
|
||||
)
|
||||
.await
|
||||
.expect("send cancelled");
|
||||
|
||||
// Send M2 while stream is cancelled; we should NOT get agent_message live
|
||||
mcp.send_user_message_and_wait_ok("Hello M2", &conv_id)
|
||||
.await
|
||||
.expect("send M2");
|
||||
let maybe = mcp
|
||||
.maybe_wait_for_agent_message(std::time::Duration::from_millis(300))
|
||||
.await
|
||||
.expect("maybe wait");
|
||||
assert!(
|
||||
maybe.is_none(),
|
||||
"should not get live agent_message after cancel"
|
||||
);
|
||||
|
||||
// Connect stream B and expect initial_state that includes the response
|
||||
let params = mcp
|
||||
.connect_stream_and_expect_initial_state(&conv_id)
|
||||
.await
|
||||
.expect("stream B initial_state");
|
||||
let events = params["initial_state"]["events"]
|
||||
.as_array()
|
||||
.expect("events array");
|
||||
assert!(events.iter().any(|ev| {
|
||||
ev.get("msg")
|
||||
.and_then(|m| m.get("type"))
|
||||
.and_then(|t| t.as_str())
|
||||
== Some("agent_message")
|
||||
&& ev
|
||||
.get("msg")
|
||||
.and_then(|m| m.get("message"))
|
||||
.and_then(|t| t.as_str())
|
||||
== Some("Done")
|
||||
}));
|
||||
}
|
||||
|
||||
// Helper to create a config.toml pointing at the mock model server.
|
||||
fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> {
|
||||
let config_toml = codex_home.join("config.toml");
|
||||
std::fs::write(
|
||||
config_toml,
|
||||
format!(
|
||||
r#"
|
||||
model = "mock-model"
|
||||
approval_policy = "never"
|
||||
sandbox_mode = "danger-full-access"
|
||||
|
||||
model_provider = "mock_provider"
|
||||
|
||||
[model_providers.mock_provider]
|
||||
name = "Mock provider for test"
|
||||
base_url = "{server_uri}/v1"
|
||||
wire_api = "chat"
|
||||
request_max_retries = 0
|
||||
stream_max_retries = 0
|
||||
"#
|
||||
),
|
||||
)
|
||||
}
|
||||
Reference in New Issue
Block a user