mirror of
https://github.com/openai/codex.git
synced 2026-02-01 22:47:52 +00:00
Compare commits
42 Commits
jif/fix-ui
...
stream-con
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3b90ca959b | ||
|
|
c6cfdf705c | ||
|
|
622a84f4ba | ||
|
|
34704ff055 | ||
|
|
0bf33f7359 | ||
|
|
0c9d8f13e5 | ||
|
|
4259e5787f | ||
|
|
c85c6dfccd | ||
|
|
c1e9083cbd | ||
|
|
f68bf94db1 | ||
|
|
e054715bea | ||
|
|
c72fe752cc | ||
|
|
985c97985b | ||
|
|
7dec04ae4f | ||
|
|
c182126bca | ||
|
|
a362ad00ce | ||
|
|
c515d2869e | ||
|
|
bfbe523f81 | ||
|
|
95423b26d7 | ||
|
|
5bab2bd2f8 | ||
|
|
1294def888 | ||
|
|
ab70497539 | ||
|
|
2a40d07a06 | ||
|
|
2e07f4b033 | ||
|
|
324926e240 | ||
|
|
9805ad1fbc | ||
|
|
792efc990c | ||
|
|
ec6a4f9e2a | ||
|
|
c01b9d2d2a | ||
|
|
d5efc45869 | ||
|
|
dbcb9e7ca6 | ||
|
|
8d413194f3 | ||
|
|
19d3e17572 | ||
|
|
a5b3c151ac | ||
|
|
0110749efa | ||
|
|
bea4a5358a | ||
|
|
4c13829e8b | ||
|
|
5ccd02b0fe | ||
|
|
21c334ae54 | ||
|
|
66ea94f723 | ||
|
|
ae6becc58d | ||
|
|
3a456c1fbb |
1
codex-rs/Cargo.lock
generated
1
codex-rs/Cargo.lock
generated
@@ -831,6 +831,7 @@ dependencies = [
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"tokio-test",
|
||||
"tokio-util",
|
||||
"toml 0.9.4",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
|
||||
@@ -31,6 +31,7 @@ tokio = { version = "1", features = [
|
||||
"rt-multi-thread",
|
||||
"signal",
|
||||
] }
|
||||
tokio-util = { version = "0.7" }
|
||||
toml = "0.9"
|
||||
tracing = { version = "0.1.41", features = ["log"] }
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] }
|
||||
|
||||
@@ -15,7 +15,6 @@ use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::ExecApprovalRequestEvent;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::Submission;
|
||||
use codex_core::protocol::TaskCompleteEvent;
|
||||
use mcp_types::CallToolResult;
|
||||
use mcp_types::ContentBlock;
|
||||
@@ -79,27 +78,18 @@ pub async fn run_codex_tool_session(
|
||||
)
|
||||
.await;
|
||||
|
||||
// Use the original MCP request ID as the `sub_id` for the Codex submission so that
|
||||
// any events emitted for this tool-call can be correlated with the
|
||||
// originating `tools/call` request.
|
||||
let sub_id = match &id {
|
||||
RequestId::String(s) => s.clone(),
|
||||
RequestId::Integer(n) => n.to_string(),
|
||||
};
|
||||
running_requests_id_to_codex_uuid
|
||||
.lock()
|
||||
.await
|
||||
.insert(id.clone(), session_id);
|
||||
let submission = Submission {
|
||||
id: sub_id.clone(),
|
||||
op: Op::UserInput {
|
||||
if let Err(e) = codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: initial_prompt.clone(),
|
||||
}],
|
||||
},
|
||||
};
|
||||
|
||||
if let Err(e) = codex.submit_with_id(submission).await {
|
||||
})
|
||||
.await
|
||||
{
|
||||
tracing::error!("Failed to submit initial prompt: {e}");
|
||||
// unregister the id so we don't keep it in the map
|
||||
running_requests_id_to_codex_uuid.lock().await.remove(&id);
|
||||
@@ -151,10 +141,7 @@ async fn run_codex_tool_session_inner(
|
||||
request_id: RequestId,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||
) {
|
||||
let request_id_str = match &request_id {
|
||||
RequestId::String(s) => s.clone(),
|
||||
RequestId::Integer(n) => n.to_string(),
|
||||
};
|
||||
let request_id_str = crate::request_id::request_id_to_string(&request_id);
|
||||
|
||||
// Stream events until the task needs to pause for user interaction or
|
||||
// completes.
|
||||
|
||||
@@ -1,124 +1,369 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::exec_approval::handle_exec_approval_request;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::outgoing_message::OutgoingNotificationMeta;
|
||||
use crate::patch_approval::handle_patch_approval_request;
|
||||
use codex_core::Codex;
|
||||
use codex_core::error::Result as CodexResult;
|
||||
use codex_core::protocol::AgentMessageEvent;
|
||||
use codex_core::protocol::ApplyPatchApprovalRequestEvent;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::ExecApprovalRequestEvent;
|
||||
use codex_core::protocol::FileChange;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use mcp_types::RequestId;
|
||||
use tokio::sync::Mutex;
|
||||
// no streaming watch channel; streaming is toggled via set_streaming on the struct
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::error;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub async fn run_conversation_loop(
|
||||
use crate::exec_approval::handle_exec_approval_request;
|
||||
use crate::mcp_protocol::CodexEventNotificationParams;
|
||||
use crate::mcp_protocol::ConversationId;
|
||||
use crate::mcp_protocol::InitialStateNotificationParams;
|
||||
use crate::mcp_protocol::InitialStatePayload;
|
||||
use crate::mcp_protocol::NotificationMeta;
|
||||
use crate::mcp_protocol::ServerNotification;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::patch_approval::handle_patch_approval_request;
|
||||
use crate::request_id::request_id_to_string;
|
||||
|
||||
/// Conversation struct that owns the Codex session and all per-conversation state.
|
||||
pub(crate) struct Conversation {
|
||||
codex: Arc<Codex>,
|
||||
session_id: Uuid,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
request_id: RequestId,
|
||||
) {
|
||||
let request_id_str = match &request_id {
|
||||
RequestId::String(s) => s.clone(),
|
||||
RequestId::Integer(n) => n.to_string(),
|
||||
};
|
||||
state: Mutex<ConversationState>,
|
||||
cancel: CancellationToken,
|
||||
}
|
||||
|
||||
// Stream events until the task needs to pause for user interaction or
|
||||
// completes.
|
||||
loop {
|
||||
match codex.next_event().await {
|
||||
Ok(event) => {
|
||||
outgoing
|
||||
.send_event_as_notification(
|
||||
&event,
|
||||
Some(OutgoingNotificationMeta::new(Some(request_id.clone()))),
|
||||
)
|
||||
struct ConversationState {
|
||||
streaming_enabled: bool,
|
||||
buffered_events: Vec<CodexEventNotificationParams>,
|
||||
pending_elicitations: Vec<PendingElicitation>,
|
||||
}
|
||||
|
||||
impl Conversation {
|
||||
pub(crate) fn new(
|
||||
codex: Arc<Codex>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
request_id: RequestId,
|
||||
session_id: Uuid,
|
||||
) -> Arc<Self> {
|
||||
let conv = Arc::new(Self {
|
||||
codex,
|
||||
session_id,
|
||||
outgoing,
|
||||
request_id,
|
||||
state: Mutex::new(ConversationState {
|
||||
streaming_enabled: false,
|
||||
buffered_events: Vec::new(),
|
||||
pending_elicitations: Vec::new(),
|
||||
}),
|
||||
cancel: CancellationToken::new(),
|
||||
});
|
||||
// Detach a background loop tied to this Conversation
|
||||
spawn_conversation_loop(conv.clone());
|
||||
conv
|
||||
}
|
||||
|
||||
pub(crate) async fn set_streaming(&self, enabled: bool) {
|
||||
if enabled {
|
||||
let (events_snapshot, pending_snapshot) = {
|
||||
let mut st = self.state.lock().await;
|
||||
st.streaming_enabled = true;
|
||||
(
|
||||
st.buffered_events.clone(),
|
||||
std::mem::take(&mut st.pending_elicitations),
|
||||
)
|
||||
};
|
||||
self.emit_initial_state_with(events_snapshot).await;
|
||||
self.drain_pending_elicitations_from(pending_snapshot).await;
|
||||
} else {
|
||||
let mut st = self.state.lock().await;
|
||||
st.streaming_enabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn codex(&self) -> Arc<Codex> {
|
||||
self.codex.clone()
|
||||
}
|
||||
|
||||
pub(crate) async fn try_submit_user_input(
|
||||
&self,
|
||||
request_id: RequestId,
|
||||
items: Vec<InputItem>,
|
||||
) -> CodexResult<()> {
|
||||
let _ = request_id; // request_id is not used to enforce uniqueness; Codex generates ids.
|
||||
self.codex.submit(Op::UserInput { items }).await.map(|_| ())
|
||||
}
|
||||
|
||||
async fn handle_event(&self, event: Event) {
|
||||
{
|
||||
let mut st = self.state.lock().await;
|
||||
st.buffered_events.push(CodexEventNotificationParams {
|
||||
meta: None,
|
||||
msg: event.msg.clone(),
|
||||
});
|
||||
}
|
||||
self.stream_event_if_enabled(&event.msg).await;
|
||||
|
||||
match event.msg {
|
||||
EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
|
||||
command,
|
||||
cwd,
|
||||
call_id,
|
||||
reason: _,
|
||||
}) => {
|
||||
self.process_exec_request(command, cwd, call_id, event.id.clone())
|
||||
.await;
|
||||
}
|
||||
EventMsg::Error(err) => {
|
||||
error!("Codex runtime error: {}", err.message);
|
||||
}
|
||||
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
|
||||
call_id,
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
}) => {
|
||||
self.start_patch_approval(PatchRequest {
|
||||
call_id,
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
event_id: event.id.clone(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
EventMsg::TaskComplete(_) => {}
|
||||
EventMsg::TaskStarted => {}
|
||||
EventMsg::SessionConfigured(ev) => {
|
||||
error!("unexpected SessionConfigured event: {:?}", ev);
|
||||
}
|
||||
EventMsg::AgentMessageDelta(_) => {}
|
||||
EventMsg::AgentReasoningDelta(_) => {}
|
||||
EventMsg::AgentMessage(AgentMessageEvent { .. }) => {}
|
||||
EventMsg::TokenCount(_)
|
||||
| EventMsg::AgentReasoning(_)
|
||||
| EventMsg::AgentReasoningRawContent(_)
|
||||
| EventMsg::AgentReasoningRawContentDelta(_)
|
||||
| EventMsg::McpToolCallBegin(_)
|
||||
| EventMsg::McpToolCallEnd(_)
|
||||
| EventMsg::ExecCommandBegin(_)
|
||||
| EventMsg::ExecCommandEnd(_)
|
||||
| EventMsg::BackgroundEvent(_)
|
||||
| EventMsg::ExecCommandOutputDelta(_)
|
||||
| EventMsg::PatchApplyBegin(_)
|
||||
| EventMsg::PatchApplyEnd(_)
|
||||
| EventMsg::GetHistoryEntryResponse(_)
|
||||
| EventMsg::PlanUpdate(_)
|
||||
| EventMsg::TurnDiff(_)
|
||||
| EventMsg::ShutdownComplete => {}
|
||||
}
|
||||
}
|
||||
|
||||
match event.msg {
|
||||
EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
|
||||
async fn emit_initial_state_with(&self, events: Vec<CodexEventNotificationParams>) {
|
||||
let params = InitialStateNotificationParams {
|
||||
meta: Some(NotificationMeta {
|
||||
conversation_id: Some(ConversationId(self.session_id)),
|
||||
request_id: None,
|
||||
}),
|
||||
initial_state: InitialStatePayload { events },
|
||||
};
|
||||
self.outgoing
|
||||
.send_server_notification(ServerNotification::InitialState(params))
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn drain_pending_elicitations_from(&self, items: Vec<PendingElicitation>) {
|
||||
for item in items {
|
||||
match item {
|
||||
PendingElicitation::ExecRequest(ExecRequest {
|
||||
command,
|
||||
cwd,
|
||||
event_id,
|
||||
call_id,
|
||||
}) => {
|
||||
handle_exec_approval_request(
|
||||
command,
|
||||
cwd,
|
||||
self.outgoing.clone(),
|
||||
self.codex.clone(),
|
||||
self.request_id.clone(),
|
||||
request_id_to_string(&self.request_id),
|
||||
event_id,
|
||||
call_id,
|
||||
reason: _,
|
||||
}) => {
|
||||
handle_exec_approval_request(
|
||||
command,
|
||||
cwd,
|
||||
outgoing.clone(),
|
||||
codex.clone(),
|
||||
request_id.clone(),
|
||||
request_id_str.clone(),
|
||||
event.id.clone(),
|
||||
call_id,
|
||||
)
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
EventMsg::Error(_) => {
|
||||
error!("Codex runtime error");
|
||||
}
|
||||
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
|
||||
)
|
||||
.await;
|
||||
}
|
||||
PendingElicitation::PatchRequest(PatchRequest {
|
||||
call_id,
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
event_id,
|
||||
}) => {
|
||||
handle_patch_approval_request(
|
||||
call_id,
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
}) => {
|
||||
handle_patch_approval_request(
|
||||
call_id,
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
outgoing.clone(),
|
||||
codex.clone(),
|
||||
request_id.clone(),
|
||||
request_id_str.clone(),
|
||||
event.id.clone(),
|
||||
)
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
EventMsg::TaskComplete(_) => {}
|
||||
EventMsg::SessionConfigured(_) => {
|
||||
tracing::error!("unexpected SessionConfigured event");
|
||||
}
|
||||
EventMsg::AgentMessageDelta(_) => {
|
||||
// TODO: think how we want to support this in the MCP
|
||||
}
|
||||
EventMsg::AgentReasoningDelta(_) => {
|
||||
// TODO: think how we want to support this in the MCP
|
||||
}
|
||||
EventMsg::AgentMessage(AgentMessageEvent { .. }) => {
|
||||
// TODO: think how we want to support this in the MCP
|
||||
}
|
||||
EventMsg::AgentReasoningRawContent(_)
|
||||
| EventMsg::AgentReasoningRawContentDelta(_)
|
||||
| EventMsg::TaskStarted
|
||||
| EventMsg::TokenCount(_)
|
||||
| EventMsg::AgentReasoning(_)
|
||||
| EventMsg::McpToolCallBegin(_)
|
||||
| EventMsg::McpToolCallEnd(_)
|
||||
| EventMsg::ExecCommandBegin(_)
|
||||
| EventMsg::ExecCommandEnd(_)
|
||||
| EventMsg::TurnDiff(_)
|
||||
| EventMsg::BackgroundEvent(_)
|
||||
| EventMsg::ExecCommandOutputDelta(_)
|
||||
| EventMsg::PatchApplyBegin(_)
|
||||
| EventMsg::PatchApplyEnd(_)
|
||||
| EventMsg::GetHistoryEntryResponse(_)
|
||||
| EventMsg::PlanUpdate(_)
|
||||
| EventMsg::ShutdownComplete => {
|
||||
// For now, we do not do anything extra for these
|
||||
// events. Note that
|
||||
// send(codex_event_to_notification(&event)) above has
|
||||
// already dispatched these events as notifications,
|
||||
// though we may want to do give different treatment to
|
||||
// individual events in the future.
|
||||
}
|
||||
self.outgoing.clone(),
|
||||
self.codex.clone(),
|
||||
self.request_id.clone(),
|
||||
request_id_to_string(&self.request_id),
|
||||
event_id,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Codex runtime error: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn process_exec_request(
|
||||
&self,
|
||||
command: Vec<String>,
|
||||
cwd: PathBuf,
|
||||
call_id: String,
|
||||
event_id: String,
|
||||
) {
|
||||
let should_stream = {
|
||||
let st = self.state.lock().await;
|
||||
st.streaming_enabled
|
||||
};
|
||||
if should_stream {
|
||||
handle_exec_approval_request(
|
||||
command,
|
||||
cwd,
|
||||
self.outgoing.clone(),
|
||||
self.codex.clone(),
|
||||
self.request_id.clone(),
|
||||
request_id_to_string(&self.request_id),
|
||||
event_id,
|
||||
call_id,
|
||||
)
|
||||
.await;
|
||||
} else {
|
||||
let mut st = self.state.lock().await;
|
||||
st.pending_elicitations
|
||||
.push(PendingElicitation::ExecRequest(ExecRequest {
|
||||
command,
|
||||
cwd,
|
||||
event_id,
|
||||
call_id,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
async fn start_patch_approval(&self, req: PatchRequest) {
|
||||
let PatchRequest {
|
||||
call_id,
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
event_id,
|
||||
} = req;
|
||||
let should_stream = {
|
||||
let st = self.state.lock().await;
|
||||
st.streaming_enabled
|
||||
};
|
||||
if should_stream {
|
||||
handle_patch_approval_request(
|
||||
call_id,
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
self.outgoing.clone(),
|
||||
self.codex.clone(),
|
||||
self.request_id.clone(),
|
||||
request_id_to_string(&self.request_id),
|
||||
event_id,
|
||||
)
|
||||
.await;
|
||||
} else {
|
||||
let mut st = self.state.lock().await;
|
||||
st.pending_elicitations
|
||||
.push(PendingElicitation::PatchRequest(PatchRequest {
|
||||
call_id,
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
event_id,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
async fn stream_event_if_enabled(&self, msg: &EventMsg) {
|
||||
if !{ self.state.lock().await.streaming_enabled } {
|
||||
return;
|
||||
}
|
||||
let method = msg.to_string();
|
||||
let params = CodexEventNotificationParams {
|
||||
meta: None,
|
||||
msg: msg.clone(),
|
||||
};
|
||||
match serde_json::to_value(¶ms) {
|
||||
Ok(params_val) => {
|
||||
self.outgoing
|
||||
.send_custom_notification(&method, params_val)
|
||||
.await;
|
||||
}
|
||||
Err(err) => {
|
||||
error!("Failed to serialize event params: {err:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum PendingElicitation {
|
||||
ExecRequest(ExecRequest),
|
||||
PatchRequest(PatchRequest),
|
||||
}
|
||||
|
||||
struct PatchRequest {
|
||||
call_id: String,
|
||||
reason: Option<String>,
|
||||
grant_root: Option<PathBuf>,
|
||||
changes: HashMap<PathBuf, FileChange>,
|
||||
event_id: String,
|
||||
}
|
||||
|
||||
struct ExecRequest {
|
||||
command: Vec<String>,
|
||||
cwd: PathBuf,
|
||||
event_id: String,
|
||||
call_id: String,
|
||||
}
|
||||
|
||||
impl Drop for Conversation {
|
||||
fn drop(&mut self) {
|
||||
self.cancel.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_conversation_loop(this: Arc<Conversation>) {
|
||||
tokio::spawn(async move {
|
||||
let codex = this.codex.clone();
|
||||
let cancel = this.cancel.clone();
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {
|
||||
break;
|
||||
}
|
||||
res = codex.next_event() => {
|
||||
match res {
|
||||
Ok(event) => this.handle_event(event).await,
|
||||
Err(e) => {
|
||||
error!("Codex next_event error (session {}): {e}", this.session_id);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ pub mod mcp_protocol;
|
||||
pub(crate) mod message_processor;
|
||||
mod outgoing_message;
|
||||
mod patch_approval;
|
||||
mod request_id;
|
||||
pub(crate) mod tool_handlers;
|
||||
|
||||
use crate::message_processor::MessageProcessor;
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -13,10 +12,11 @@ use crate::mcp_protocol::ToolCallResponseResult;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::tool_handlers::create_conversation::handle_create_conversation;
|
||||
use crate::tool_handlers::send_message::handle_send_message;
|
||||
use crate::tool_handlers::stream_conversation;
|
||||
use crate::tool_handlers::stream_conversation::handle_stream_conversation;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::config::Config as CodexConfig;
|
||||
use codex_core::protocol::Submission;
|
||||
use mcp_types::CallToolRequest;
|
||||
use mcp_types::CallToolRequestParams;
|
||||
use mcp_types::CallToolResult;
|
||||
@@ -43,8 +43,10 @@ pub(crate) struct MessageProcessor {
|
||||
initialized: bool,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
|
||||
conversation_map: Arc<Mutex<HashMap<Uuid, Arc<crate::conversation_loop::Conversation>>>>,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||
running_session_ids: Arc<Mutex<HashSet<Uuid>>>,
|
||||
/// Track request IDs to the original ToolCallRequestParams for cancellation handling
|
||||
tool_request_map: Arc<Mutex<HashMap<RequestId, ToolCallRequestParams>>>,
|
||||
}
|
||||
|
||||
impl MessageProcessor {
|
||||
@@ -59,23 +61,22 @@ impl MessageProcessor {
|
||||
initialized: false,
|
||||
codex_linux_sandbox_exe,
|
||||
session_map: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_map: Arc::new(Mutex::new(HashMap::new())),
|
||||
running_requests_id_to_codex_uuid: Arc::new(Mutex::new(HashMap::new())),
|
||||
running_session_ids: Arc::new(Mutex::new(HashSet::new())),
|
||||
tool_request_map: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn session_map(&self) -> Arc<Mutex<HashMap<Uuid, Arc<Codex>>>> {
|
||||
self.session_map.clone()
|
||||
pub(crate) fn conversation_map(
|
||||
&self,
|
||||
) -> Arc<Mutex<HashMap<Uuid, Arc<crate::conversation_loop::Conversation>>>> {
|
||||
self.conversation_map.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn outgoing(&self) -> Arc<OutgoingMessageSender> {
|
||||
self.outgoing.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn running_session_ids(&self) -> Arc<Mutex<HashSet<Uuid>>> {
|
||||
self.running_session_ids.clone()
|
||||
}
|
||||
|
||||
pub(crate) async fn process_request(&mut self, request: JSONRPCRequest) {
|
||||
// Hold on to the ID so we can respond.
|
||||
let request_id = request.id.clone();
|
||||
@@ -353,6 +354,11 @@ impl MessageProcessor {
|
||||
}
|
||||
}
|
||||
async fn handle_new_tool_calls(&self, request_id: RequestId, params: ToolCallRequestParams) {
|
||||
// Track the request to allow graceful cancellation routing later.
|
||||
{
|
||||
let mut tool_request_map = self.tool_request_map.lock().await;
|
||||
tool_request_map.insert(request_id.clone(), params.clone());
|
||||
}
|
||||
match params {
|
||||
ToolCallRequestParams::ConversationCreate(args) => {
|
||||
handle_create_conversation(self, request_id, args).await;
|
||||
@@ -360,6 +366,9 @@ impl MessageProcessor {
|
||||
ToolCallRequestParams::ConversationSendMessage(args) => {
|
||||
handle_send_message(self, request_id, args).await;
|
||||
}
|
||||
ToolCallRequestParams::ConversationStream(args) => {
|
||||
handle_stream_conversation(self, request_id, args).await;
|
||||
}
|
||||
_ => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
@@ -584,23 +593,72 @@ impl MessageProcessor {
|
||||
// ---------------------------------------------------------------------
|
||||
// Notification handlers
|
||||
// ---------------------------------------------------------------------
|
||||
|
||||
async fn handle_cancelled_notification(
|
||||
&self,
|
||||
params: <mcp_types::CancelledNotification as mcp_types::ModelContextProtocolNotification>::Params,
|
||||
) {
|
||||
let request_id = params.request_id;
|
||||
// Create a stable string form early for logging and submission id.
|
||||
let request_id_string = match &request_id {
|
||||
RequestId::String(s) => s.clone(),
|
||||
RequestId::Integer(i) => i.to_string(),
|
||||
};
|
||||
|
||||
// Obtain the session_id while holding the first lock, then release.
|
||||
if let Some(orig) = {
|
||||
let mut tool_request_map = self.tool_request_map.lock().await;
|
||||
tool_request_map.remove(&request_id)
|
||||
} {
|
||||
self.handle_mcp_protocol_cancelled_notification(request_id, orig)
|
||||
.await;
|
||||
} else {
|
||||
self.handle_legacy_cancelled_notification(request_id).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_mcp_protocol_cancelled_notification(
|
||||
&self,
|
||||
request_id: RequestId,
|
||||
orig: ToolCallRequestParams,
|
||||
) {
|
||||
match orig {
|
||||
ToolCallRequestParams::ConversationStream(args) => {
|
||||
stream_conversation::handle_cancel(self, &args).await;
|
||||
}
|
||||
ToolCallRequestParams::ConversationSendMessage(args) => {
|
||||
// Cancel in-flight user input for this conversation by interrupting the session.
|
||||
|
||||
let session_id = args.conversation_id.0;
|
||||
let codex_arc = {
|
||||
let sessions_guard = self.conversation_map.lock().await;
|
||||
match sessions_guard.get(&session_id) {
|
||||
Some(conv) => conv.codex().clone(),
|
||||
None => {
|
||||
tracing::warn!(
|
||||
"Cancel send_message: session not found for session_id: {session_id}"
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = codex_arc.submit(codex_core::protocol::Op::Interrupt).await {
|
||||
tracing::error!("Failed to submit interrupt for send_message cancel: {e}");
|
||||
}
|
||||
}
|
||||
ToolCallRequestParams::ConversationCreate(_)
|
||||
| ToolCallRequestParams::ConversationsList(_) => {
|
||||
// Likely fast/non-streaming; nothing to cancel currently.
|
||||
tracing::debug!(
|
||||
"Cancel conversationsList received for request_id: {:?} (no-op)",
|
||||
request_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_legacy_cancelled_notification(&self, request_id: RequestId) {
|
||||
use crate::request_id::request_id_to_string;
|
||||
let request_id_string = request_id_to_string(&request_id);
|
||||
|
||||
let session_id = {
|
||||
let map_guard = self.running_requests_id_to_codex_uuid.lock().await;
|
||||
match map_guard.get(&request_id) {
|
||||
Some(id) => *id, // Uuid is Copy
|
||||
Some(id) => *id,
|
||||
None => {
|
||||
tracing::warn!("Session not found for request_id: {}", request_id_string);
|
||||
return;
|
||||
@@ -609,7 +667,6 @@ impl MessageProcessor {
|
||||
};
|
||||
tracing::info!("session_id: {session_id}");
|
||||
|
||||
// Obtain the Codex Arc while holding the session_map lock, then release.
|
||||
let codex_arc = {
|
||||
let sessions_guard = self.session_map.lock().await;
|
||||
match sessions_guard.get(&session_id) {
|
||||
@@ -621,18 +678,11 @@ impl MessageProcessor {
|
||||
}
|
||||
};
|
||||
|
||||
// Submit interrupt to Codex.
|
||||
let err = codex_arc
|
||||
.submit_with_id(Submission {
|
||||
id: request_id_string,
|
||||
op: codex_core::protocol::Op::Interrupt,
|
||||
})
|
||||
.await;
|
||||
if let Err(e) = err {
|
||||
if let Err(e) = codex_arc.submit(codex_core::protocol::Op::Interrupt).await {
|
||||
tracing::error!("Failed to submit interrupt to Codex: {e}");
|
||||
return;
|
||||
}
|
||||
// unregister the id so we don't keep it in the map
|
||||
|
||||
self.running_requests_id_to_codex_uuid
|
||||
.lock()
|
||||
.await
|
||||
|
||||
@@ -109,7 +109,7 @@ impl OutgoingMessageSender {
|
||||
|
||||
// should be backwards compatible.
|
||||
// it will replace send_event_as_notification eventually.
|
||||
async fn send_event_as_notification_new_schema(
|
||||
pub(crate) async fn send_event_as_notification_new_schema(
|
||||
&self,
|
||||
event: &Event,
|
||||
params: Option<serde_json::Value>,
|
||||
@@ -124,6 +124,37 @@ impl OutgoingMessageSender {
|
||||
let outgoing_message = OutgoingMessage::Error(OutgoingError { id, error });
|
||||
let _ = self.sender.send(outgoing_message).await;
|
||||
}
|
||||
|
||||
/// Send a custom notification with an explicit method name and params object.
|
||||
pub(crate) async fn send_custom_notification(&self, method: &str, params: serde_json::Value) {
|
||||
let outgoing_message = OutgoingMessage::Notification(OutgoingNotification {
|
||||
method: method.to_string(),
|
||||
params: Some(params),
|
||||
});
|
||||
let _ = self.sender.send(outgoing_message).await;
|
||||
}
|
||||
|
||||
/// Send a typed server notification by serializing it into a method/params pair.
|
||||
pub(crate) async fn send_server_notification(
|
||||
&self,
|
||||
notification: crate::mcp_protocol::ServerNotification,
|
||||
) {
|
||||
match serde_json::to_value(notification) {
|
||||
Ok(serde_json::Value::Object(mut map)) => {
|
||||
let method = map
|
||||
.remove("method")
|
||||
.and_then(|v| v.as_str().map(|s| s.to_string()));
|
||||
let params = map.remove("params").unwrap_or(serde_json::Value::Null);
|
||||
if let Some(method) = method {
|
||||
self.send_custom_notification(&method, params).await;
|
||||
} else {
|
||||
warn!("ServerNotification missing method after serialization");
|
||||
}
|
||||
}
|
||||
Ok(_) => warn!("ServerNotification did not serialize to an object"),
|
||||
Err(err) => warn!("Failed to serialize ServerNotification: {err:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Outgoing message from the server to the client.
|
||||
|
||||
9
codex-rs/mcp-server/src/request_id.rs
Normal file
9
codex-rs/mcp-server/src/request_id.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
use mcp_types::RequestId;
|
||||
|
||||
/// Utility to convert an MCP `RequestId` into a `String`.
|
||||
pub(crate) fn request_id_to_string(id: &RequestId) -> String {
|
||||
match id {
|
||||
RequestId::String(s) => s.clone(),
|
||||
RequestId::Integer(i) => i.to_string(),
|
||||
}
|
||||
}
|
||||
@@ -1,18 +1,14 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::codex_wrapper::init_codex;
|
||||
use codex_core::config::Config as CodexConfig;
|
||||
use codex_core::config::ConfigOverrides;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
use mcp_types::RequestId;
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::conversation_loop::run_conversation_loop;
|
||||
use crate::conversation_loop::Conversation;
|
||||
use crate::json_to_toml::json_to_toml;
|
||||
use crate::mcp_protocol::ConversationCreateArgs;
|
||||
use crate::mcp_protocol::ConversationCreateResult;
|
||||
@@ -121,24 +117,17 @@ pub(crate) async fn handle_create_conversation(
|
||||
let session_id = codex_conversation.session_id;
|
||||
let codex_arc = Arc::new(codex_conversation.codex);
|
||||
|
||||
// Store session for future calls
|
||||
insert_session(
|
||||
session_id,
|
||||
codex_arc.clone(),
|
||||
message_processor.session_map(),
|
||||
)
|
||||
.await;
|
||||
// Run the conversation loop in the background so this request can return immediately.
|
||||
// Construct conversation and start its loop, store it, then reply with id and model
|
||||
let outgoing = message_processor.outgoing();
|
||||
let spawn_id = id.clone();
|
||||
tokio::spawn(async move {
|
||||
run_conversation_loop(codex_arc.clone(), outgoing, spawn_id).await;
|
||||
});
|
||||
|
||||
// Reply with the new conversation id and effective model
|
||||
let conversation = Conversation::new(codex_arc.clone(), outgoing, id.clone(), session_id);
|
||||
let conv_map = message_processor.conversation_map();
|
||||
{
|
||||
let mut guard = conv_map.lock().await;
|
||||
guard.insert(session_id, conversation);
|
||||
}
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
id.clone(),
|
||||
Some(ToolCallResponseResult::ConversationCreate(
|
||||
ConversationCreateResult::Ok {
|
||||
conversation_id: ConversationId(session_id),
|
||||
@@ -149,12 +138,3 @@ pub(crate) async fn handle_create_conversation(
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn insert_session(
|
||||
session_id: Uuid,
|
||||
codex: Arc<Codex>,
|
||||
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
|
||||
) {
|
||||
let mut guard = session_map.lock().await;
|
||||
guard.insert(session_id, codex);
|
||||
}
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
pub(crate) mod create_conversation;
|
||||
pub(crate) mod send_message;
|
||||
pub(crate) mod stream_conversation;
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::Submission;
|
||||
use mcp_types::RequestId;
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::conversation_loop::Conversation;
|
||||
use crate::mcp_protocol::ConversationSendMessageArgs;
|
||||
use crate::mcp_protocol::ConversationSendMessageResult;
|
||||
use crate::mcp_protocol::ToolCallResponseResult;
|
||||
@@ -41,7 +39,8 @@ pub(crate) async fn handle_send_message(
|
||||
}
|
||||
|
||||
let session_id = conversation_id.0;
|
||||
let Some(codex) = get_session(session_id, message_processor.session_map()).await else {
|
||||
let Some(conversation) = get_session(session_id, message_processor.conversation_map()).await
|
||||
else {
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
@@ -56,46 +55,15 @@ pub(crate) async fn handle_send_message(
|
||||
return;
|
||||
};
|
||||
|
||||
let running = {
|
||||
let running_sessions = message_processor.running_session_ids();
|
||||
let mut running_sessions = running_sessions.lock().await;
|
||||
!running_sessions.insert(session_id)
|
||||
};
|
||||
let res = conversation.try_submit_user_input(id.clone(), items).await;
|
||||
|
||||
if running {
|
||||
if let Err(e) = res {
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationSendMessage(
|
||||
ConversationSendMessageResult::Error {
|
||||
message: "Session is already running".to_string(),
|
||||
},
|
||||
)),
|
||||
Some(true),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
let request_id_string = match &id {
|
||||
RequestId::String(s) => s.clone(),
|
||||
RequestId::Integer(i) => i.to_string(),
|
||||
};
|
||||
|
||||
let submit_res = codex
|
||||
.submit_with_id(Submission {
|
||||
id: request_id_string,
|
||||
op: Op::UserInput { items },
|
||||
})
|
||||
.await;
|
||||
|
||||
if let Err(e) = submit_res {
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationSendMessage(
|
||||
ConversationSendMessageResult::Error {
|
||||
message: format!("Failed to submit user input: {e}"),
|
||||
message: e.to_string(),
|
||||
},
|
||||
)),
|
||||
Some(true),
|
||||
@@ -117,8 +85,8 @@ pub(crate) async fn handle_send_message(
|
||||
|
||||
pub(crate) async fn get_session(
|
||||
session_id: Uuid,
|
||||
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
|
||||
) -> Option<Arc<Codex>> {
|
||||
let guard = session_map.lock().await;
|
||||
conversation_map: Arc<Mutex<HashMap<Uuid, Arc<Conversation>>>>,
|
||||
) -> Option<Arc<Conversation>> {
|
||||
let guard = conversation_map.lock().await;
|
||||
guard.get(&session_id).cloned()
|
||||
}
|
||||
|
||||
57
codex-rs/mcp-server/src/tool_handlers/stream_conversation.rs
Normal file
57
codex-rs/mcp-server/src/tool_handlers/stream_conversation.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
use mcp_types::RequestId;
|
||||
|
||||
use crate::mcp_protocol::ConversationStreamArgs;
|
||||
use crate::mcp_protocol::ConversationStreamResult;
|
||||
use crate::mcp_protocol::ToolCallResponseResult;
|
||||
use crate::message_processor::MessageProcessor;
|
||||
use crate::tool_handlers::send_message::get_session;
|
||||
|
||||
/// Handles the ConversationStream tool call: verifies the session and
|
||||
/// enables streaming for the session, replying with an OK result.
|
||||
pub(crate) async fn handle_stream_conversation(
|
||||
message_processor: &MessageProcessor,
|
||||
id: RequestId,
|
||||
arguments: ConversationStreamArgs,
|
||||
) {
|
||||
let ConversationStreamArgs { conversation_id } = arguments;
|
||||
|
||||
let session_id = conversation_id.0;
|
||||
|
||||
// Ensure the session exists
|
||||
let conv = get_session(session_id, message_processor.conversation_map()).await;
|
||||
|
||||
if conv.is_none() {
|
||||
// Return an error with no result payload per MCP error pattern
|
||||
message_processor
|
||||
.send_response_with_optional_error(id, None, Some(true))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationStream(
|
||||
ConversationStreamResult {},
|
||||
)),
|
||||
Some(false),
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Some(conv) = conv {
|
||||
tokio::spawn(async move {
|
||||
conv.set_streaming(true).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Handles cancellation for ConversationStream by disabling streaming for the session.
|
||||
pub(crate) async fn handle_cancel(
|
||||
message_processor: &MessageProcessor,
|
||||
args: &ConversationStreamArgs,
|
||||
) {
|
||||
let session_id = args.conversation_id.0;
|
||||
if let Some(conv) = get_session(session_id, message_processor.conversation_map()).await {
|
||||
conv.set_streaming(false).await;
|
||||
}
|
||||
}
|
||||
26
codex-rs/mcp-server/tests/common/config.rs
Normal file
26
codex-rs/mcp-server/tests/common/config.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
use std::path::Path;
|
||||
|
||||
/// Write a minimal Codex config.toml pointing at the provided mock server URI.
|
||||
/// Used by tests that don't exercise approval/sandbox variations.
|
||||
pub fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> {
|
||||
let config_toml = codex_home.join("config.toml");
|
||||
std::fs::write(
|
||||
config_toml,
|
||||
format!(
|
||||
r#"
|
||||
model = "mock-model"
|
||||
approval_policy = "never"
|
||||
sandbox_mode = "danger-full-access"
|
||||
|
||||
model_provider = "mock_provider"
|
||||
|
||||
[model_providers.mock_provider]
|
||||
name = "Mock provider for test"
|
||||
base_url = "{server_uri}/v1"
|
||||
wire_api = "chat"
|
||||
request_max_retries = 0
|
||||
stream_max_retries = 0
|
||||
"#
|
||||
),
|
||||
)
|
||||
}
|
||||
@@ -1,7 +1,9 @@
|
||||
mod config;
|
||||
mod mcp_process;
|
||||
mod mock_model_server;
|
||||
mod responses;
|
||||
|
||||
pub use config::create_config_toml;
|
||||
pub use mcp_process::McpProcess;
|
||||
pub use mock_model_server::create_mock_chat_completions_server;
|
||||
pub use responses::create_apply_patch_sse_response;
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::path::Path;
|
||||
use std::process::Stdio;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::BufReader;
|
||||
@@ -17,6 +18,7 @@ use codex_mcp_server::CodexToolCallReplyParam;
|
||||
use codex_mcp_server::mcp_protocol::ConversationCreateArgs;
|
||||
use codex_mcp_server::mcp_protocol::ConversationId;
|
||||
use codex_mcp_server::mcp_protocol::ConversationSendMessageArgs;
|
||||
use codex_mcp_server::mcp_protocol::ConversationStreamArgs;
|
||||
use codex_mcp_server::mcp_protocol::ToolCallRequestParams;
|
||||
|
||||
use mcp_types::CallToolRequestParams;
|
||||
@@ -201,6 +203,20 @@ impl McpProcess {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn send_conversation_stream_tool_call(
|
||||
&mut self,
|
||||
session_id: &str,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = ToolCallRequestParams::ConversationStream(ConversationStreamArgs {
|
||||
conversation_id: ConversationId(Uuid::parse_str(session_id)?),
|
||||
});
|
||||
self.send_request(
|
||||
mcp_types::CallToolRequest::METHOD,
|
||||
Some(serde_json::to_value(params)?),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn send_conversation_create_tool_call(
|
||||
&mut self,
|
||||
prompt: &str,
|
||||
@@ -236,6 +252,83 @@ impl McpProcess {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Create a conversation and return its conversation_id as a string.
|
||||
pub async fn create_conversation_and_get_id(
|
||||
&mut self,
|
||||
prompt: &str,
|
||||
model: &str,
|
||||
cwd: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
let req_id = self
|
||||
.send_conversation_create_tool_call(prompt, model, cwd)
|
||||
.await?;
|
||||
let resp = self
|
||||
.read_stream_until_response_message(RequestId::Integer(req_id))
|
||||
.await?;
|
||||
let conv_id = resp.result["structuredContent"]["conversation_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::format_err!("missing conversation_id"))?
|
||||
.to_string();
|
||||
Ok(conv_id)
|
||||
}
|
||||
|
||||
/// Connect stream for a conversation and wait for the initial_state notification.
|
||||
/// Returns (requestId, params) where params are the initial_state notification params.
|
||||
pub async fn connect_stream_and_expect_initial_state(
|
||||
&mut self,
|
||||
session_id: &str,
|
||||
) -> anyhow::Result<(i64, serde_json::Value)> {
|
||||
let req_id = self.send_conversation_stream_tool_call(session_id).await?;
|
||||
// Wait for stream() tool-call response first
|
||||
let _ = self
|
||||
.read_stream_until_response_message(RequestId::Integer(req_id))
|
||||
.await?;
|
||||
// Then the initial_state notification
|
||||
let note = self
|
||||
.read_stream_until_notification_method("notifications/initial_state")
|
||||
.await?;
|
||||
let params = note
|
||||
.params
|
||||
.ok_or_else(|| anyhow::format_err!("initial_state must have params"))?;
|
||||
Ok((req_id, params))
|
||||
}
|
||||
|
||||
/// Wait for an agent_message with a bounded timeout. Returns Some(params) if received, None on timeout.
|
||||
pub async fn maybe_wait_for_agent_message(
|
||||
&mut self,
|
||||
dur: Duration,
|
||||
) -> anyhow::Result<Option<serde_json::Value>> {
|
||||
match tokio::time::timeout(dur, self.wait_for_agent_message()).await {
|
||||
Ok(Ok(v)) => Ok(Some(v)),
|
||||
Ok(Err(e)) => Err(e),
|
||||
Err(_elapsed) => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a user message to a conversation and wait for the OK tool-call response.
|
||||
pub async fn send_user_message_and_wait_ok(
|
||||
&mut self,
|
||||
message: &str,
|
||||
session_id: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let req_id = self
|
||||
.send_user_message_tool_call(message, session_id)
|
||||
.await?;
|
||||
let _ = self
|
||||
.read_stream_until_response_message(RequestId::Integer(req_id))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Wait until an agent_message notification arrives; returns its params.
|
||||
pub async fn wait_for_agent_message(&mut self) -> anyhow::Result<serde_json::Value> {
|
||||
let note = self
|
||||
.read_stream_until_notification_method("agent_message")
|
||||
.await?;
|
||||
note.params
|
||||
.ok_or_else(|| anyhow::format_err!("agent_message missing params"))
|
||||
}
|
||||
|
||||
async fn send_request(
|
||||
&mut self,
|
||||
method: &str,
|
||||
@@ -329,53 +422,51 @@ impl McpProcess {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_stream_until_notification_method(
|
||||
&mut self,
|
||||
method: &str,
|
||||
) -> anyhow::Result<JSONRPCNotification> {
|
||||
loop {
|
||||
let message = self.read_jsonrpc_message().await?;
|
||||
match message {
|
||||
JSONRPCMessage::Notification(n) => {
|
||||
if n.method == method {
|
||||
return Ok(n);
|
||||
}
|
||||
}
|
||||
JSONRPCMessage::Request(_) => {
|
||||
// ignore
|
||||
}
|
||||
JSONRPCMessage::Error(_) => {
|
||||
anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}");
|
||||
}
|
||||
JSONRPCMessage::Response(_) => {
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_stream_until_configured_response_message(
|
||||
&mut self,
|
||||
) -> anyhow::Result<String> {
|
||||
let mut sid_old: Option<String> = None;
|
||||
let mut sid_new: Option<String> = None;
|
||||
loop {
|
||||
let message = self.read_jsonrpc_message().await?;
|
||||
eprint!("message: {message:?}");
|
||||
|
||||
match message {
|
||||
JSONRPCMessage::Notification(notification) => {
|
||||
if let Some(params) = notification.params {
|
||||
// Back-compat schema: method == "codex/event" and msg.type == "session_configured"
|
||||
if notification.method == "codex/event" {
|
||||
if let Some(msg) = params.get("msg") {
|
||||
if msg.get("type").and_then(|v| v.as_str())
|
||||
== Some("session_configured")
|
||||
{
|
||||
if let Some(session_id) =
|
||||
msg.get("session_id").and_then(|v| v.as_str())
|
||||
{
|
||||
sid_old = Some(session_id.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// New schema: method is the Display of EventMsg::SessionConfigured => "SessionConfigured"
|
||||
if notification.method == "session_configured" {
|
||||
if notification.method == "session_configured" {
|
||||
if let Some(params) = notification.params {
|
||||
if let Some(msg) = params.get("msg") {
|
||||
if let Some(session_id) =
|
||||
msg.get("session_id").and_then(|v| v.as_str())
|
||||
{
|
||||
sid_new = Some(session_id.to_string());
|
||||
return Ok(session_id.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sid_old.is_some() && sid_new.is_some() {
|
||||
// Both seen, they must match
|
||||
assert_eq!(
|
||||
sid_old.as_ref().unwrap(),
|
||||
sid_new.as_ref().unwrap(),
|
||||
"session_id mismatch between old and new schema"
|
||||
);
|
||||
return Ok(sid_old.unwrap());
|
||||
}
|
||||
}
|
||||
JSONRPCMessage::Request(_) => {
|
||||
anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}");
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
#![allow(clippy::expect_used, clippy::unwrap_used)]
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use mcp_test_support::McpProcess;
|
||||
use mcp_test_support::create_config_toml;
|
||||
use mcp_test_support::create_final_assistant_message_sse_response;
|
||||
use mcp_test_support::create_mock_chat_completions_server;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
@@ -103,26 +102,4 @@ async fn test_conversation_create_and_send_message_ok() {
|
||||
drop(server);
|
||||
}
|
||||
|
||||
// Helper to create a config.toml pointing at the mock model server.
|
||||
fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> {
|
||||
let config_toml = codex_home.join("config.toml");
|
||||
std::fs::write(
|
||||
config_toml,
|
||||
format!(
|
||||
r#"
|
||||
model = "mock-model"
|
||||
approval_policy = "never"
|
||||
sandbox_mode = "danger-full-access"
|
||||
|
||||
model_provider = "mock_provider"
|
||||
|
||||
[model_providers.mock_provider]
|
||||
name = "Mock provider for test"
|
||||
base_url = "{server_uri}/v1"
|
||||
wire_api = "chat"
|
||||
request_max_retries = 0
|
||||
stream_max_retries = 0
|
||||
"#
|
||||
),
|
||||
)
|
||||
}
|
||||
// create_config_toml is provided by tests/common
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
#![cfg(unix)]
|
||||
// Support code lives in the `mcp_test_support` crate under tests/common.
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use codex_mcp_server::CodexToolCallParam;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::ModelContextProtocolNotification;
|
||||
use mcp_types::RequestId;
|
||||
use serde_json::json;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use mcp_test_support::McpProcess;
|
||||
use mcp_test_support::create_config_toml;
|
||||
use mcp_test_support::create_mock_chat_completions_server;
|
||||
use mcp_test_support::create_shell_sse_response;
|
||||
|
||||
@@ -66,7 +66,7 @@ async fn shell_command_interruption() -> anyhow::Result<()> {
|
||||
|
||||
// Create Codex configuration
|
||||
let codex_home = TempDir::new()?;
|
||||
create_config_toml(codex_home.path(), server.uri())?;
|
||||
create_config_toml(codex_home.path(), &server.uri())?;
|
||||
let mut mcp_process = McpProcess::new(codex_home.path()).await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??;
|
||||
|
||||
@@ -95,7 +95,7 @@ async fn shell_command_interruption() -> anyhow::Result<()> {
|
||||
// Send interrupt notification
|
||||
mcp_process
|
||||
.send_notification(
|
||||
"notifications/cancelled",
|
||||
mcp_types::CancelledNotification::METHOD,
|
||||
Some(json!({ "requestId": codex_request_id })),
|
||||
)
|
||||
.await?;
|
||||
@@ -126,7 +126,7 @@ async fn shell_command_interruption() -> anyhow::Result<()> {
|
||||
// Send interrupt notification
|
||||
mcp_process
|
||||
.send_notification(
|
||||
"notifications/cancelled",
|
||||
mcp_types::CancelledNotification::METHOD,
|
||||
Some(json!({ "requestId": codex_reply_request_id })),
|
||||
)
|
||||
.await?;
|
||||
@@ -148,30 +148,3 @@ async fn shell_command_interruption() -> anyhow::Result<()> {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn create_config_toml(codex_home: &Path, server_uri: String) -> std::io::Result<()> {
|
||||
let config_toml = codex_home.join("config.toml");
|
||||
std::fs::write(
|
||||
config_toml,
|
||||
format!(
|
||||
r#"
|
||||
model = "mock-model"
|
||||
approval_policy = "never"
|
||||
sandbox_mode = "danger-full-access"
|
||||
|
||||
model_provider = "mock_provider"
|
||||
|
||||
[model_providers.mock_provider]
|
||||
name = "Mock provider for test"
|
||||
base_url = "{server_uri}/v1"
|
||||
wire_api = "chat"
|
||||
request_max_retries = 0
|
||||
stream_max_retries = 0
|
||||
"#
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
#![allow(clippy::expect_used)]
|
||||
|
||||
use std::path::Path;
|
||||
use std::thread::sleep;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_mcp_server::CodexToolCallParam;
|
||||
use mcp_test_support::McpProcess;
|
||||
use mcp_test_support::create_config_toml;
|
||||
use mcp_test_support::create_final_assistant_message_sse_response;
|
||||
use mcp_test_support::create_mock_chat_completions_server;
|
||||
use mcp_types::JSONRPC_VERSION;
|
||||
@@ -20,11 +19,9 @@ const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_send_message_success() {
|
||||
// Spin up a mock completions server that immediately ends the Codex turn.
|
||||
// Two Codex turns hit the mock model (session start + send-user-message). Provide two SSE responses.
|
||||
// Spin up a mock completions server that ends the Codex turn for the send-user-message call.
|
||||
let responses = vec![
|
||||
create_final_assistant_message_sse_response("Done").expect("build mock assistant message"),
|
||||
create_final_assistant_message_sse_response("Done").expect("build mock assistant message"),
|
||||
];
|
||||
let server = create_mock_chat_completions_server(responses).await;
|
||||
|
||||
@@ -41,29 +38,11 @@ async fn test_send_message_success() {
|
||||
.expect("init timed out")
|
||||
.expect("init failed");
|
||||
|
||||
// Kick off a Codex session so we have a valid session_id.
|
||||
let codex_request_id = mcp_process
|
||||
.send_codex_tool_call(CodexToolCallParam {
|
||||
prompt: "Start a session".to_string(),
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.expect("send codex tool call");
|
||||
|
||||
// Wait for the session_configured event to get the session_id.
|
||||
// Create a conversation using the tool and get its conversation_id
|
||||
let session_id = mcp_process
|
||||
.read_stream_until_configured_response_message()
|
||||
.create_conversation_and_get_id("", "mock-model", "/repo")
|
||||
.await
|
||||
.expect("read session_configured");
|
||||
|
||||
// The original codex call will finish quickly given our mock; consume its response.
|
||||
timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)),
|
||||
)
|
||||
.await
|
||||
.expect("codex response timeout")
|
||||
.expect("codex response error");
|
||||
.expect("create conversation");
|
||||
|
||||
// Now exercise the send-user-message tool.
|
||||
let send_msg_request_id = mcp_process
|
||||
@@ -135,29 +114,4 @@ async fn test_send_message_session_not_found() {
|
||||
assert_eq!(result["isError"], json!(true));
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> {
|
||||
let config_toml = codex_home.join("config.toml");
|
||||
std::fs::write(
|
||||
config_toml,
|
||||
format!(
|
||||
r#"
|
||||
model = "mock-model"
|
||||
approval_policy = "never"
|
||||
sandbox_mode = "danger-full-access"
|
||||
|
||||
model_provider = "mock_provider"
|
||||
|
||||
[model_providers.mock_provider]
|
||||
name = "Mock provider for test"
|
||||
base_url = "{server_uri}/v1"
|
||||
wire_api = "chat"
|
||||
request_max_retries = 0
|
||||
stream_max_retries = 0
|
||||
"#
|
||||
),
|
||||
)
|
||||
}
|
||||
// Helpers are provided by tests/common
|
||||
|
||||
251
codex-rs/mcp-server/tests/stream_conversation.rs
Normal file
251
codex-rs/mcp-server/tests/stream_conversation.rs
Normal file
@@ -0,0 +1,251 @@
|
||||
#![allow(clippy::expect_used, clippy::unwrap_used)]
|
||||
|
||||
use mcp_test_support::McpProcess;
|
||||
use mcp_test_support::create_config_toml;
|
||||
use mcp_test_support::create_final_assistant_message_sse_response;
|
||||
use mcp_test_support::create_mock_chat_completions_server;
|
||||
use mcp_types::JSONRPCNotification;
|
||||
use mcp_types::ModelContextProtocolNotification;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::timeout;
|
||||
|
||||
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(3);
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_connect_then_send_receives_initial_state_and_notifications() {
|
||||
let responses = vec![
|
||||
create_final_assistant_message_sse_response("Done").expect("build mock assistant message"),
|
||||
];
|
||||
let server = create_mock_chat_completions_server(responses).await;
|
||||
|
||||
let codex_home = TempDir::new().expect("create temp dir");
|
||||
create_config_toml(codex_home.path(), &server.uri()).expect("write config.toml");
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path())
|
||||
.await
|
||||
.expect("spawn mcp process");
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
|
||||
.await
|
||||
.expect("init timeout")
|
||||
.expect("init failed");
|
||||
|
||||
// Create conversation
|
||||
let conv_id = mcp
|
||||
.create_conversation_and_get_id("", "o3", "/repo")
|
||||
.await
|
||||
.expect("create conversation");
|
||||
|
||||
// Connect the stream
|
||||
let (_stream_req, params) = mcp
|
||||
.connect_stream_and_expect_initial_state(&conv_id)
|
||||
.await
|
||||
.expect("initial_state params");
|
||||
let expected_params = json!({
|
||||
"_meta": {
|
||||
"conversationId": conv_id.as_str(),
|
||||
},
|
||||
"initial_state": {
|
||||
"events": []
|
||||
}
|
||||
});
|
||||
assert_eq!(params, expected_params);
|
||||
|
||||
// Send a message and expect a subsequent notification (non-initial_state)
|
||||
mcp.send_user_message_and_wait_ok("Hello there", &conv_id)
|
||||
.await
|
||||
.expect("send message ok");
|
||||
|
||||
// Read until we see an event notification (new schema example: agent_message)
|
||||
let params = mcp.wait_for_agent_message().await.expect("agent message");
|
||||
let expected_params = json!({
|
||||
"msg": {
|
||||
"type": "agent_message",
|
||||
"message": "Done"
|
||||
}
|
||||
});
|
||||
assert_eq!(params, expected_params);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_send_then_connect_receives_initial_state_with_message() {
|
||||
let responses = vec![
|
||||
create_final_assistant_message_sse_response("Done").expect("build mock assistant message"),
|
||||
];
|
||||
let server = create_mock_chat_completions_server(responses).await;
|
||||
|
||||
let codex_home = TempDir::new().expect("create temp dir");
|
||||
create_config_toml(codex_home.path(), &server.uri()).expect("write config.toml");
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path())
|
||||
.await
|
||||
.expect("spawn mcp process");
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
|
||||
.await
|
||||
.expect("init timeout")
|
||||
.expect("init failed");
|
||||
|
||||
// Create conversation
|
||||
let conv_id = mcp
|
||||
.create_conversation_and_get_id("", "o3", "/repo")
|
||||
.await
|
||||
.expect("create conversation");
|
||||
|
||||
// Send a message BEFORE connecting stream
|
||||
mcp.send_user_message_and_wait_ok("Hello world", &conv_id)
|
||||
.await
|
||||
.expect("send message ok");
|
||||
|
||||
// Now connect stream and expect InitialState with the prior message included
|
||||
let (_stream_req, params) = mcp
|
||||
.connect_stream_and_expect_initial_state(&conv_id)
|
||||
.await
|
||||
.expect("initial_state params");
|
||||
let events = params["initial_state"]["events"]
|
||||
.as_array()
|
||||
.expect("events array");
|
||||
if !events.iter().any(|ev| {
|
||||
ev.get("msg")
|
||||
.and_then(|m| m.get("type"))
|
||||
.and_then(|t| t.as_str())
|
||||
== Some("agent_message")
|
||||
&& ev
|
||||
.get("msg")
|
||||
.and_then(|m| m.get("message"))
|
||||
.and_then(|t| t.as_str())
|
||||
== Some("Done")
|
||||
}) {
|
||||
// Fallback to live notification if not present in initial state
|
||||
let note: JSONRPCNotification = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_notification_method("agent_message"),
|
||||
)
|
||||
.await
|
||||
.expect("event note timeout")
|
||||
.expect("event note err");
|
||||
let params = note.params.expect("params");
|
||||
let expected_params = json!({
|
||||
"msg": {
|
||||
"type": "agent_message",
|
||||
"message": "Done"
|
||||
}
|
||||
});
|
||||
assert_eq!(params, expected_params);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_cancel_stream_then_reconnect_catches_up_initial_state() {
|
||||
// One response is sufficient for the assertions in this test
|
||||
let responses = vec![
|
||||
create_final_assistant_message_sse_response("Done 1")
|
||||
.expect("build mock assistant message"),
|
||||
create_final_assistant_message_sse_response("Done 2")
|
||||
.expect("build mock assistant message"),
|
||||
];
|
||||
let server = create_mock_chat_completions_server(responses).await;
|
||||
|
||||
let codex_home = TempDir::new().expect("create temp dir");
|
||||
create_config_toml(codex_home.path(), &server.uri()).expect("write config.toml");
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path())
|
||||
.await
|
||||
.expect("spawn mcp process");
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
|
||||
.await
|
||||
.expect("init timeout")
|
||||
.expect("init failed");
|
||||
|
||||
// Create and connect stream A
|
||||
let conv_id = mcp
|
||||
.create_conversation_and_get_id("", "o3", "/repo")
|
||||
.await
|
||||
.expect("create");
|
||||
let (stream_a_id, _params) = mcp
|
||||
.connect_stream_and_expect_initial_state(&conv_id)
|
||||
.await
|
||||
.expect("stream A initial_state");
|
||||
|
||||
// Send M1 and ensure we get live agent_message
|
||||
mcp.send_user_message_and_wait_ok("Hello M1", &conv_id)
|
||||
.await
|
||||
.expect("send M1");
|
||||
let _params = mcp.wait_for_agent_message().await.expect("agent M1");
|
||||
|
||||
// Ensure the first task has fully completed before cancelling the stream
|
||||
// so that the session is no longer marked as running.
|
||||
let _ = mcp
|
||||
.read_stream_until_notification_method("task_complete")
|
||||
.await
|
||||
.expect("task complete");
|
||||
|
||||
// Cancel stream A
|
||||
mcp.send_notification(
|
||||
mcp_types::CancelledNotification::METHOD,
|
||||
Some(json!({ "requestId": stream_a_id })),
|
||||
)
|
||||
.await
|
||||
.expect("send cancelled");
|
||||
|
||||
// Send M2 while stream is cancelled; we should NOT get agent_message live
|
||||
mcp.send_user_message_and_wait_ok("Hello M2", &conv_id)
|
||||
.await
|
||||
.expect("send M2");
|
||||
let maybe = mcp
|
||||
.maybe_wait_for_agent_message(std::time::Duration::from_millis(300))
|
||||
.await
|
||||
.expect("maybe wait");
|
||||
assert!(
|
||||
maybe.is_none(),
|
||||
"should not get live agent_message after cancel"
|
||||
);
|
||||
|
||||
// Connect stream B and expect initial_state that includes the response
|
||||
let (_stream_req, params) = mcp
|
||||
.connect_stream_and_expect_initial_state(&conv_id)
|
||||
.await
|
||||
.expect("stream B initial_state");
|
||||
let events = params["initial_state"]["events"]
|
||||
.as_array()
|
||||
.expect("events array");
|
||||
let expected = vec![
|
||||
json!({
|
||||
"msg": {
|
||||
"type": "task_started",
|
||||
},
|
||||
}),
|
||||
json!({
|
||||
"msg": {
|
||||
"message": "Done 1",
|
||||
"type": "agent_message",
|
||||
},
|
||||
}),
|
||||
json!({
|
||||
"msg": {
|
||||
"last_agent_message": "Done 1",
|
||||
"type": "task_complete",
|
||||
},
|
||||
}),
|
||||
json!({
|
||||
"msg": {
|
||||
"type": "task_started",
|
||||
},
|
||||
}),
|
||||
json!({
|
||||
"msg": {
|
||||
"message": "Done 2",
|
||||
"type": "agent_message",
|
||||
},
|
||||
}),
|
||||
json!({
|
||||
"msg": {
|
||||
"last_agent_message": "Done 2",
|
||||
"type": "task_complete",
|
||||
},
|
||||
}),
|
||||
];
|
||||
assert_eq!(*events, expected);
|
||||
}
|
||||
|
||||
//
|
||||
Reference in New Issue
Block a user