mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
Merge branch 'main' into pap/model-selection
This commit is contained in:
21
SUMMARY.md
Normal file
21
SUMMARY.md
Normal file
@@ -0,0 +1,21 @@
|
||||
You are a summarization assistant. A conversation follows between a user and a coding-focused AI (Codex). Your task is to generate a clear summary capturing:
|
||||
|
||||
• High-level objective or problem being solved
|
||||
• Key instructions or design decisions given by the user
|
||||
• Main code actions or behaviors from the AI
|
||||
• Important variables, functions, modules, or outputs discussed
|
||||
• Any unresolved questions or next steps
|
||||
|
||||
Produce the summary in a structured format like:
|
||||
|
||||
**Objective:** …
|
||||
|
||||
**User instructions:** … (bulleted)
|
||||
|
||||
**AI actions / code behavior:** … (bulleted)
|
||||
|
||||
**Important entities:** … (e.g. function names, variables, files)
|
||||
|
||||
**Open issues / next steps:** … (if any)
|
||||
|
||||
**Summary (concise):** (one or two sentences)
|
||||
2
codex-rs/Cargo.lock
generated
2
codex-rs/Cargo.lock
generated
@@ -2644,6 +2644,7 @@ version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assert_cmd",
|
||||
"codex-core",
|
||||
"codex-mcp-server",
|
||||
"mcp-types",
|
||||
"pretty_assertions",
|
||||
@@ -2651,6 +2652,7 @@ dependencies = [
|
||||
"shlex",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"uuid",
|
||||
"wiremock",
|
||||
]
|
||||
|
||||
|
||||
@@ -207,7 +207,14 @@ impl ModelClient {
|
||||
}
|
||||
}
|
||||
|
||||
let req_builder = self.provider.apply_http_headers(req_builder);
|
||||
req_builder = self.provider.apply_http_headers(req_builder);
|
||||
|
||||
let originator = self
|
||||
.config
|
||||
.internal_originator
|
||||
.as_deref()
|
||||
.unwrap_or("codex_cli_rs");
|
||||
req_builder = req_builder.header("originator", originator);
|
||||
|
||||
let res = req_builder.send().await;
|
||||
if let Ok(resp) = &res {
|
||||
|
||||
@@ -430,6 +430,12 @@ impl Session {
|
||||
let _ = self.tx_event.send(event).await;
|
||||
}
|
||||
|
||||
/// Build the full turn input by concatenating the current conversation
|
||||
/// history with additional items for this turn.
|
||||
pub fn turn_input_with_history(&self, extra: Vec<ResponseItem>) -> Vec<ResponseItem> {
|
||||
[self.state.lock().unwrap().history.contents(), extra].concat()
|
||||
}
|
||||
|
||||
/// Returns the input if there was no task running to inject into
|
||||
pub fn inject_input(&self, input: Vec<InputItem>) -> Result<(), Vec<InputItem>> {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
@@ -552,6 +558,25 @@ impl AgentTask {
|
||||
handle,
|
||||
}
|
||||
}
|
||||
fn compact(
|
||||
sess: Arc<Session>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
compact_instructions: String,
|
||||
) -> Self {
|
||||
let handle = tokio::spawn(run_compact_task(
|
||||
Arc::clone(&sess),
|
||||
sub_id.clone(),
|
||||
input,
|
||||
compact_instructions,
|
||||
))
|
||||
.abort_handle();
|
||||
Self {
|
||||
sess,
|
||||
sub_id,
|
||||
handle,
|
||||
}
|
||||
}
|
||||
|
||||
fn abort(self) {
|
||||
if !self.handle.is_finished() {
|
||||
@@ -878,6 +903,31 @@ async fn submission_loop(
|
||||
}
|
||||
});
|
||||
}
|
||||
Op::Compact => {
|
||||
let sess = match sess.as_ref() {
|
||||
Some(sess) => sess,
|
||||
None => {
|
||||
send_no_session_event(sub.id).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Create a summarization request as user input
|
||||
const SUMMARIZATION_PROMPT: &str = include_str!("../../../SUMMARY.md");
|
||||
|
||||
// Attempt to inject input into current task
|
||||
if let Err(items) = sess.inject_input(vec![InputItem::Text {
|
||||
text: "Start Summarization".to_string(),
|
||||
}]) {
|
||||
let task = AgentTask::compact(
|
||||
sess.clone(),
|
||||
sub.id,
|
||||
items,
|
||||
SUMMARIZATION_PROMPT.to_string(),
|
||||
);
|
||||
sess.set_task(task);
|
||||
}
|
||||
}
|
||||
Op::Shutdown => {
|
||||
info!("Shutting down Codex instance");
|
||||
|
||||
@@ -939,7 +989,7 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
return;
|
||||
}
|
||||
|
||||
let initial_input_for_turn = ResponseInputItem::from(input);
|
||||
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
|
||||
sess.record_conversation_items(&[initial_input_for_turn.clone().into()])
|
||||
.await;
|
||||
|
||||
@@ -960,8 +1010,7 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
// conversation history on each turn. The rollout file, however, should
|
||||
// only record the new items that originated in this turn so that it
|
||||
// represents an append-only log without duplicates.
|
||||
let turn_input: Vec<ResponseItem> =
|
||||
[sess.state.lock().unwrap().history.contents(), pending_input].concat();
|
||||
let turn_input: Vec<ResponseItem> = sess.turn_input_with_history(pending_input);
|
||||
|
||||
let turn_input_messages: Vec<String> = turn_input
|
||||
.iter()
|
||||
@@ -1287,6 +1336,88 @@ async fn try_run_turn(
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_compact_task(
|
||||
sess: Arc<Session>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
compact_instructions: String,
|
||||
) {
|
||||
let start_event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::TaskStarted,
|
||||
};
|
||||
if sess.tx_event.send(start_event).await.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
|
||||
let turn_input: Vec<ResponseItem> =
|
||||
sess.turn_input_with_history(vec![initial_input_for_turn.clone().into()]);
|
||||
|
||||
let prompt = Prompt {
|
||||
input: turn_input,
|
||||
user_instructions: None,
|
||||
store: !sess.disable_response_storage,
|
||||
extra_tools: HashMap::new(),
|
||||
base_instructions_override: Some(compact_instructions.clone()),
|
||||
};
|
||||
|
||||
let max_retries = sess.client.get_provider().stream_max_retries();
|
||||
let mut retries = 0;
|
||||
|
||||
loop {
|
||||
let attempt_result = drain_to_completed(&sess, &prompt).await;
|
||||
|
||||
match attempt_result {
|
||||
Ok(()) => break,
|
||||
Err(CodexErr::Interrupted) => return,
|
||||
Err(e) => {
|
||||
if retries < max_retries {
|
||||
retries += 1;
|
||||
let delay = backoff(retries);
|
||||
sess.notify_background_event(
|
||||
&sub_id,
|
||||
format!(
|
||||
"stream error: {e}; retrying {retries}/{max_retries} in {delay:?}…"
|
||||
),
|
||||
)
|
||||
.await;
|
||||
tokio::time::sleep(delay).await;
|
||||
continue;
|
||||
} else {
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: e.to_string(),
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sess.remove_task(&sub_id);
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::AgentMessage(AgentMessageEvent {
|
||||
message: "Compact task completed".to_string(),
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::TaskComplete(TaskCompleteEvent {
|
||||
last_agent_message: None,
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
|
||||
let mut state = sess.state.lock().unwrap();
|
||||
state.history.keep_last_messages(1);
|
||||
}
|
||||
|
||||
async fn handle_response_item(
|
||||
sess: &Session,
|
||||
sub_id: &str,
|
||||
@@ -1852,3 +1983,20 @@ fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<St
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn drain_to_completed(sess: &Session, prompt: &Prompt) -> CodexResult<()> {
|
||||
let mut stream = sess.client.clone().stream(prompt).await?;
|
||||
loop {
|
||||
let maybe_event = stream.next().await;
|
||||
let Some(event) = maybe_event else {
|
||||
return Err(CodexErr::Stream(
|
||||
"stream closed before response.completed".into(),
|
||||
));
|
||||
};
|
||||
match event {
|
||||
Ok(ResponseEvent::Completed { .. }) => return Ok(()),
|
||||
Ok(_) => continue,
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -147,6 +147,9 @@ pub struct Config {
|
||||
|
||||
/// Include an experimental plan tool that the model can use to update its current plan and status of each step.
|
||||
pub include_plan_tool: bool,
|
||||
|
||||
/// The value for the `originator` header included with Responses API requests.
|
||||
pub internal_originator: Option<String>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@@ -363,6 +366,9 @@ pub struct ConfigToml {
|
||||
|
||||
/// Experimental path to a file whose contents replace the built-in BASE_INSTRUCTIONS.
|
||||
pub experimental_instructions_file: Option<PathBuf>,
|
||||
|
||||
/// The value for the `originator` header included with Responses API requests.
|
||||
pub internal_originator: Option<String>,
|
||||
}
|
||||
|
||||
impl ConfigToml {
|
||||
@@ -556,6 +562,7 @@ impl Config {
|
||||
|
||||
experimental_resume,
|
||||
include_plan_tool: include_plan_tool.unwrap_or(false),
|
||||
internal_originator: cfg.internal_originator,
|
||||
};
|
||||
Ok(config)
|
||||
}
|
||||
@@ -914,6 +921,7 @@ disable_response_storage = true
|
||||
experimental_resume: None,
|
||||
base_instructions: None,
|
||||
include_plan_tool: false,
|
||||
internal_originator: None,
|
||||
},
|
||||
o3_profile_config
|
||||
);
|
||||
@@ -963,6 +971,7 @@ disable_response_storage = true
|
||||
experimental_resume: None,
|
||||
base_instructions: None,
|
||||
include_plan_tool: false,
|
||||
internal_originator: None,
|
||||
};
|
||||
|
||||
assert_eq!(expected_gpt3_profile_config, gpt3_profile_config);
|
||||
@@ -1027,6 +1036,7 @@ disable_response_storage = true
|
||||
experimental_resume: None,
|
||||
base_instructions: None,
|
||||
include_plan_tool: false,
|
||||
internal_originator: None,
|
||||
};
|
||||
|
||||
assert_eq!(expected_zdr_profile_config, zdr_profile_config);
|
||||
|
||||
@@ -30,6 +30,34 @@ impl ConversationHistory {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn keep_last_messages(&mut self, n: usize) {
|
||||
if n == 0 {
|
||||
self.items.clear();
|
||||
return;
|
||||
}
|
||||
|
||||
// Collect the last N message items (assistant/user), newest to oldest.
|
||||
let mut kept: Vec<ResponseItem> = Vec::with_capacity(n);
|
||||
for item in self.items.iter().rev() {
|
||||
if let ResponseItem::Message { role, content, .. } = item {
|
||||
kept.push(ResponseItem::Message {
|
||||
// we need to remove the id or the model will complain that messages are sent without
|
||||
// their reasonings
|
||||
id: None,
|
||||
role: role.clone(),
|
||||
content: content.clone(),
|
||||
});
|
||||
if kept.len() == n {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Preserve chronological order (oldest to newest) within the kept slice.
|
||||
kept.reverse();
|
||||
self.items = kept;
|
||||
}
|
||||
}
|
||||
|
||||
/// Anything that is not a system message or "reasoning" message is considered
|
||||
|
||||
@@ -12,10 +12,6 @@ use std::env::VarError;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::error::EnvVarError;
|
||||
|
||||
/// Value for the `OpenAI-Originator` header that is sent with requests to
|
||||
/// OpenAI.
|
||||
const OPENAI_ORIGINATOR_HEADER: &str = "codex_cli_rs";
|
||||
const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000;
|
||||
const DEFAULT_STREAM_MAX_RETRIES: u64 = 10;
|
||||
const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
|
||||
@@ -229,15 +225,9 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: Some(
|
||||
[
|
||||
(
|
||||
"originator".to_string(),
|
||||
OPENAI_ORIGINATOR_HEADER.to_string(),
|
||||
),
|
||||
("version".to_string(), env!("CARGO_PKG_VERSION").to_string()),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
[("version".to_string(), env!("CARGO_PKG_VERSION").to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
env_http_headers: Some(
|
||||
[
|
||||
|
||||
@@ -121,6 +121,10 @@ pub enum Op {
|
||||
/// Request a single history entry identified by `log_id` + `offset`.
|
||||
GetHistoryEntryRequest { offset: usize, log_id: u64 },
|
||||
|
||||
/// Request the agent to summarize the current conversation context.
|
||||
/// The agent will use its existing context (either conversation history or previous response id)
|
||||
/// to generate a summary which will be returned as an AgentMessage event.
|
||||
Compact,
|
||||
/// Request to shut down codex instance.
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
@@ -95,8 +95,8 @@ async fn includes_session_id_and_model_headers_in_request() {
|
||||
// get request from the server
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_session_id = request.headers.get("session_id").unwrap();
|
||||
let request_originator = request.headers.get("originator").unwrap();
|
||||
let request_authorization = request.headers.get("authorization").unwrap();
|
||||
let request_originator = request.headers.get("originator").unwrap();
|
||||
|
||||
assert!(current_session_id.is_some());
|
||||
assert_eq!(
|
||||
@@ -170,6 +170,59 @@ async fn includes_base_instructions_override_in_request() {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn originator_config_override_is_used() {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
..built_in_model_providers()["openai"].clone()
|
||||
};
|
||||
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider = model_provider;
|
||||
config.internal_originator = Some("my_override".to_string());
|
||||
|
||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
let CodexSpawnOk { codex, .. } = Codex::spawn(
|
||||
config,
|
||||
Some(CodexAuth::from_api_key("Test API Key".to_string())),
|
||||
ctrl_c.clone(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_originator = request.headers.get("originator").unwrap();
|
||||
assert_eq!(request_originator.to_str().unwrap(), "my_override");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn chatgpt_auth_sends_correct_request() {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
@@ -235,8 +288,8 @@ async fn chatgpt_auth_sends_correct_request() {
|
||||
// get request from the server
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_session_id = request.headers.get("session_id").unwrap();
|
||||
let request_originator = request.headers.get("originator").unwrap();
|
||||
let request_authorization = request.headers.get("authorization").unwrap();
|
||||
let request_originator = request.headers.get("originator").unwrap();
|
||||
let request_chatgpt_account_id = request.headers.get("chatgpt-account-id").unwrap();
|
||||
let request_body = request.body_json::<serde_json::Value>().unwrap();
|
||||
|
||||
|
||||
254
codex-rs/core/tests/compact.rs
Normal file
254
codex-rs/core/tests/compact.rs
Normal file
@@ -0,0 +1,254 @@
|
||||
#![expect(clippy::unwrap_used)]
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::CodexSpawnOk;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::built_in_model_providers;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use codex_login::CodexAuth;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::wait_for_event;
|
||||
use serde_json::Value;
|
||||
use tempfile::TempDir;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
// --- Test helpers -----------------------------------------------------------
|
||||
|
||||
/// Build an SSE stream body from a list of JSON events.
|
||||
fn sse(events: Vec<Value>) -> String {
|
||||
use std::fmt::Write as _;
|
||||
let mut out = String::new();
|
||||
for ev in events {
|
||||
let kind = ev.get("type").and_then(|v| v.as_str()).unwrap();
|
||||
writeln!(&mut out, "event: {kind}").unwrap();
|
||||
if !ev.as_object().map(|o| o.len() == 1).unwrap_or(false) {
|
||||
write!(&mut out, "data: {ev}\n\n").unwrap();
|
||||
} else {
|
||||
out.push('\n');
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Convenience: SSE event for a completed response with a specific id.
|
||||
fn ev_completed(id: &str) -> Value {
|
||||
serde_json::json!({
|
||||
"type": "response.completed",
|
||||
"response": {
|
||||
"id": id,
|
||||
"usage": {"input_tokens":0,"input_tokens_details":null,"output_tokens":0,"output_tokens_details":null,"total_tokens":0}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Convenience: SSE event for a single assistant message output item.
|
||||
fn ev_assistant_message(id: &str, text: &str) -> Value {
|
||||
serde_json::json!({
|
||||
"type": "response.output_item.done",
|
||||
"item": {
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"id": id,
|
||||
"content": [{"type": "output_text", "text": text}]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn sse_response(body: String) -> ResponseTemplate {
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(body, "text/event-stream")
|
||||
}
|
||||
|
||||
async fn mount_sse_once<M>(server: &MockServer, matcher: M, body: String)
|
||||
where
|
||||
M: wiremock::Match + Send + Sync + 'static,
|
||||
{
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(matcher)
|
||||
.respond_with(sse_response(body))
|
||||
.expect(1)
|
||||
.mount(server)
|
||||
.await;
|
||||
}
|
||||
|
||||
const FIRST_REPLY: &str = "FIRST_REPLY";
|
||||
const SUMMARY_TEXT: &str = "SUMMARY_ONLY_CONTEXT";
|
||||
const SUMMARIZE_TRIGGER: &str = "Start Summarization";
|
||||
const THIRD_USER_MSG: &str = "next turn";
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn summarize_context_three_requests_and_instructions() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Set up a mock server that we can inspect after the run.
|
||||
let server = MockServer::start().await;
|
||||
|
||||
// SSE 1: assistant replies normally so it is recorded in history.
|
||||
let sse1 = sse(vec![
|
||||
ev_assistant_message("m1", FIRST_REPLY),
|
||||
ev_completed("r1"),
|
||||
]);
|
||||
|
||||
// SSE 2: summarizer returns a summary message.
|
||||
let sse2 = sse(vec![
|
||||
ev_assistant_message("m2", SUMMARY_TEXT),
|
||||
ev_completed("r2"),
|
||||
]);
|
||||
|
||||
// SSE 3: minimal completed; we only need to capture the request body.
|
||||
let sse3 = sse(vec![ev_completed("r3")]);
|
||||
|
||||
// Mount three expectations, one per request, matched by body content.
|
||||
let first_matcher = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
body.contains("\"text\":\"hello world\"")
|
||||
&& !body.contains(&format!("\"text\":\"{SUMMARIZE_TRIGGER}\""))
|
||||
};
|
||||
mount_sse_once(&server, first_matcher, sse1).await;
|
||||
|
||||
let second_matcher = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
body.contains(&format!("\"text\":\"{SUMMARIZE_TRIGGER}\""))
|
||||
};
|
||||
mount_sse_once(&server, second_matcher, sse2).await;
|
||||
|
||||
let third_matcher = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
body.contains(&format!("\"text\":\"{THIRD_USER_MSG}\""))
|
||||
};
|
||||
mount_sse_once(&server, third_matcher, sse3).await;
|
||||
|
||||
// Build config pointing to the mock server and spawn Codex.
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
..built_in_model_providers()["openai"].clone()
|
||||
};
|
||||
let home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&home);
|
||||
config.model_provider = model_provider;
|
||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
let CodexSpawnOk { codex, .. } = Codex::spawn(
|
||||
config,
|
||||
Some(CodexAuth::from_api_key("dummy".to_string())),
|
||||
ctrl_c.clone(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// 1) Normal user input – should hit server once.
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello world".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
// 2) Summarize – second hit with summarization instructions.
|
||||
codex.submit(Op::Compact).await.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
// 3) Next user input – third hit; history should include only the summary.
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: THIRD_USER_MSG.into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
// Inspect the three captured requests.
|
||||
let requests = server.received_requests().await.unwrap();
|
||||
assert_eq!(requests.len(), 3, "expected exactly three requests");
|
||||
|
||||
let req1 = &requests[0];
|
||||
let req2 = &requests[1];
|
||||
let req3 = &requests[2];
|
||||
|
||||
let body1 = req1.body_json::<serde_json::Value>().unwrap();
|
||||
let body2 = req2.body_json::<serde_json::Value>().unwrap();
|
||||
let body3 = req3.body_json::<serde_json::Value>().unwrap();
|
||||
|
||||
// System instructions should change for the summarization turn.
|
||||
let instr1 = body1.get("instructions").and_then(|v| v.as_str()).unwrap();
|
||||
let instr2 = body2.get("instructions").and_then(|v| v.as_str()).unwrap();
|
||||
assert_ne!(
|
||||
instr1, instr2,
|
||||
"summarization should override base instructions"
|
||||
);
|
||||
assert!(
|
||||
instr2.contains("You are a summarization assistant"),
|
||||
"summarization instructions not applied"
|
||||
);
|
||||
|
||||
// The summarization request should include the injected user input marker.
|
||||
let input2 = body2.get("input").and_then(|v| v.as_array()).unwrap();
|
||||
// The last item is the user message created from the injected input.
|
||||
let last2 = input2.last().unwrap();
|
||||
assert_eq!(last2.get("type").unwrap().as_str().unwrap(), "message");
|
||||
assert_eq!(last2.get("role").unwrap().as_str().unwrap(), "user");
|
||||
let text2 = last2["content"][0]["text"].as_str().unwrap();
|
||||
assert!(text2.contains(SUMMARIZE_TRIGGER));
|
||||
|
||||
// Third request must contain only the summary from step 2 as prior history plus new user msg.
|
||||
let input3 = body3.get("input").and_then(|v| v.as_array()).unwrap();
|
||||
println!("third request body: {body3}");
|
||||
assert!(
|
||||
input3.len() >= 2,
|
||||
"expected summary + new user message in third request"
|
||||
);
|
||||
|
||||
// Collect all (role, text) message tuples.
|
||||
let mut messages: Vec<(String, String)> = Vec::new();
|
||||
for item in input3 {
|
||||
if item["type"].as_str() == Some("message") {
|
||||
let role = item["role"].as_str().unwrap_or_default().to_string();
|
||||
let text = item["content"][0]["text"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
messages.push((role, text));
|
||||
}
|
||||
}
|
||||
|
||||
// Exactly one assistant message should remain after compaction and the new user message is present.
|
||||
let assistant_count = messages.iter().filter(|(r, _)| r == "assistant").count();
|
||||
assert_eq!(
|
||||
assistant_count, 1,
|
||||
"exactly one assistant message should remain after compaction"
|
||||
);
|
||||
assert!(
|
||||
messages
|
||||
.iter()
|
||||
.any(|(r, t)| r == "user" && t == THIRD_USER_MSG),
|
||||
"third request should include the new user message"
|
||||
);
|
||||
assert!(
|
||||
!messages.iter().any(|(_, t)| t.contains("hello world")),
|
||||
"third request should not include the original user input"
|
||||
);
|
||||
assert!(
|
||||
!messages.iter().any(|(_, t)| t.contains(SUMMARIZE_TRIGGER)),
|
||||
"third request should not include the summarize trigger"
|
||||
);
|
||||
}
|
||||
@@ -19,10 +19,11 @@ mod codex_tool_config;
|
||||
mod codex_tool_runner;
|
||||
mod exec_approval;
|
||||
mod json_to_toml;
|
||||
mod mcp_protocol;
|
||||
mod message_processor;
|
||||
pub mod mcp_protocol;
|
||||
pub(crate) mod message_processor;
|
||||
mod outgoing_message;
|
||||
mod patch_approval;
|
||||
pub(crate) mod tool_handlers;
|
||||
|
||||
use crate::message_processor::MessageProcessor;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
|
||||
@@ -7,7 +7,10 @@ use serde::Serialize;
|
||||
use strum_macros::Display;
|
||||
use uuid::Uuid;
|
||||
|
||||
use mcp_types::CallToolResult;
|
||||
use mcp_types::ContentBlock;
|
||||
use mcp_types::RequestId;
|
||||
use mcp_types::TextContent;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
@@ -118,10 +121,47 @@ pub struct ToolCallResponse {
|
||||
pub request_id: RequestId,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub is_error: Option<bool>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
#[serde(default, skip_serializing_if = "Option::is_none", flatten)]
|
||||
pub result: Option<ToolCallResponseResult>,
|
||||
}
|
||||
|
||||
impl From<ToolCallResponse> for CallToolResult {
|
||||
fn from(val: ToolCallResponse) -> Self {
|
||||
let ToolCallResponse {
|
||||
request_id: _request_id,
|
||||
is_error,
|
||||
result,
|
||||
} = val;
|
||||
match result {
|
||||
Some(res) => match serde_json::to_value(&res) {
|
||||
Ok(v) => CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: v.to_string(),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error,
|
||||
structured_content: Some(v),
|
||||
},
|
||||
Err(e) => CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: format!("Failed to serialize tool result: {e}"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
},
|
||||
},
|
||||
None => CallToolResult {
|
||||
content: vec![],
|
||||
is_error,
|
||||
structured_content: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ToolCallResponseResult {
|
||||
@@ -141,8 +181,10 @@ pub struct ConversationCreateResult {
|
||||
pub struct ConversationStreamResult {}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ConversationSendMessageResult {
|
||||
pub success: bool,
|
||||
#[serde(tag = "status", rename_all = "camelCase")]
|
||||
pub enum ConversationSendMessageResult {
|
||||
Ok,
|
||||
Error { message: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
@@ -455,10 +497,13 @@ mod tests {
|
||||
},
|
||||
)),
|
||||
};
|
||||
let observed = to_val(&env);
|
||||
let req_id = env.request_id.clone();
|
||||
let observed = to_val(&CallToolResult::from(env));
|
||||
let expected = json!({
|
||||
"requestId": 1,
|
||||
"result": {
|
||||
"content": [
|
||||
{ "type": "text", "text": "{\"conversation_id\":\"d0f6ecbe-84a2-41c1-b23d-b20473b25eab\",\"model\":\"o3\"}" }
|
||||
],
|
||||
"structuredContent": {
|
||||
"conversation_id": "d0f6ecbe-84a2-41c1-b23d-b20473b25eab",
|
||||
"model": "o3"
|
||||
}
|
||||
@@ -467,6 +512,7 @@ mod tests {
|
||||
observed, expected,
|
||||
"response (ConversationCreate) must match"
|
||||
);
|
||||
assert_eq!(req_id, RequestId::Integer(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -478,15 +524,17 @@ mod tests {
|
||||
ConversationStreamResult {},
|
||||
)),
|
||||
};
|
||||
let observed = to_val(&env);
|
||||
let req_id = env.request_id.clone();
|
||||
let observed = to_val(&CallToolResult::from(env));
|
||||
let expected = json!({
|
||||
"requestId": 2,
|
||||
"result": {}
|
||||
"content": [ { "type": "text", "text": "{}" } ],
|
||||
"structuredContent": {}
|
||||
});
|
||||
assert_eq!(
|
||||
observed, expected,
|
||||
"response (ConversationStream) must have empty object result"
|
||||
);
|
||||
assert_eq!(req_id, RequestId::Integer(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -495,18 +543,20 @@ mod tests {
|
||||
request_id: RequestId::Integer(3),
|
||||
is_error: None,
|
||||
result: Some(ToolCallResponseResult::ConversationSendMessage(
|
||||
ConversationSendMessageResult { success: true },
|
||||
ConversationSendMessageResult::Ok,
|
||||
)),
|
||||
};
|
||||
let observed = to_val(&env);
|
||||
let req_id = env.request_id.clone();
|
||||
let observed = to_val(&CallToolResult::from(env));
|
||||
let expected = json!({
|
||||
"requestId": 3,
|
||||
"result": { "success": true }
|
||||
"content": [ { "type": "text", "text": "{\"status\":\"ok\"}" } ],
|
||||
"structuredContent": { "status": "ok" }
|
||||
});
|
||||
assert_eq!(
|
||||
observed, expected,
|
||||
"response (ConversationSendMessageAccepted) must match"
|
||||
);
|
||||
assert_eq!(req_id, RequestId::Integer(3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -526,10 +576,13 @@ mod tests {
|
||||
},
|
||||
)),
|
||||
};
|
||||
let observed = to_val(&env);
|
||||
let req_id = env.request_id.clone();
|
||||
let observed = to_val(&CallToolResult::from(env));
|
||||
let expected = json!({
|
||||
"requestId": 4,
|
||||
"result": {
|
||||
"content": [
|
||||
{ "type": "text", "text": "{\"conversations\":[{\"conversation_id\":\"67e55044-10b1-426f-9247-bb680e5fe0c8\",\"title\":\"Refactor config loader\"}],\"next_cursor\":\"next123\"}" }
|
||||
],
|
||||
"structuredContent": {
|
||||
"conversations": [
|
||||
{
|
||||
"conversation_id": "67e55044-10b1-426f-9247-bb680e5fe0c8",
|
||||
@@ -543,6 +596,7 @@ mod tests {
|
||||
observed, expected,
|
||||
"response (ConversationsList with cursor) must match"
|
||||
);
|
||||
assert_eq!(req_id, RequestId::Integer(4));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -552,15 +606,17 @@ mod tests {
|
||||
is_error: Some(true),
|
||||
result: None,
|
||||
};
|
||||
let observed = to_val(&env);
|
||||
let req_id = env.request_id.clone();
|
||||
let observed = to_val(&CallToolResult::from(env));
|
||||
let expected = json!({
|
||||
"requestId": 4,
|
||||
"content": [],
|
||||
"isError": true
|
||||
});
|
||||
assert_eq!(
|
||||
observed, expected,
|
||||
"error response must omit `result` and include `isError`"
|
||||
);
|
||||
assert_eq!(req_id, RequestId::Integer(4));
|
||||
}
|
||||
|
||||
// ----- Notifications -----
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -6,11 +7,16 @@ use crate::codex_tool_config::CodexToolCallParam;
|
||||
use crate::codex_tool_config::CodexToolCallReplyParam;
|
||||
use crate::codex_tool_config::create_tool_for_codex_tool_call_param;
|
||||
use crate::codex_tool_config::create_tool_for_codex_tool_call_reply_param;
|
||||
use crate::mcp_protocol::ToolCallRequestParams;
|
||||
use crate::mcp_protocol::ToolCallResponse;
|
||||
use crate::mcp_protocol::ToolCallResponseResult;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::tool_handlers::send_message::handle_send_message;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::config::Config as CodexConfig;
|
||||
use codex_core::protocol::Submission;
|
||||
use mcp_types::CallToolRequest;
|
||||
use mcp_types::CallToolRequestParams;
|
||||
use mcp_types::CallToolResult;
|
||||
use mcp_types::ClientRequest;
|
||||
@@ -37,6 +43,7 @@ pub(crate) struct MessageProcessor {
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||
running_session_ids: Arc<Mutex<HashSet<Uuid>>>,
|
||||
}
|
||||
|
||||
impl MessageProcessor {
|
||||
@@ -52,9 +59,18 @@ impl MessageProcessor {
|
||||
codex_linux_sandbox_exe,
|
||||
session_map: Arc::new(Mutex::new(HashMap::new())),
|
||||
running_requests_id_to_codex_uuid: Arc::new(Mutex::new(HashMap::new())),
|
||||
running_session_ids: Arc::new(Mutex::new(HashSet::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn session_map(&self) -> Arc<Mutex<HashMap<Uuid, Arc<Codex>>>> {
|
||||
self.session_map.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn running_session_ids(&self) -> Arc<Mutex<HashSet<Uuid>>> {
|
||||
self.running_session_ids.clone()
|
||||
}
|
||||
|
||||
pub(crate) async fn process_request(&mut self, request: JSONRPCRequest) {
|
||||
// Hold on to the ID so we can respond.
|
||||
let request_id = request.id.clone();
|
||||
@@ -300,6 +316,14 @@ impl MessageProcessor {
|
||||
params: <mcp_types::CallToolRequest as mcp_types::ModelContextProtocolRequest>::Params,
|
||||
) {
|
||||
tracing::info!("tools/call -> params: {:?}", params);
|
||||
// Serialize params into JSON and try to parse as new type
|
||||
if let Ok(new_params) =
|
||||
serde_json::to_value(¶ms).and_then(serde_json::from_value::<ToolCallRequestParams>)
|
||||
{
|
||||
// New tool call matched → forward
|
||||
self.handle_new_tool_calls(id, new_params).await;
|
||||
return;
|
||||
}
|
||||
let CallToolRequestParams { name, arguments } = params;
|
||||
|
||||
match name.as_str() {
|
||||
@@ -323,6 +347,26 @@ impl MessageProcessor {
|
||||
}
|
||||
}
|
||||
}
|
||||
async fn handle_new_tool_calls(&self, request_id: RequestId, params: ToolCallRequestParams) {
|
||||
match params {
|
||||
ToolCallRequestParams::ConversationSendMessage(args) => {
|
||||
handle_send_message(self, request_id, args).await;
|
||||
}
|
||||
_ => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: "Unknown tool".to_string(),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
self.send_response::<CallToolRequest>(request_id, result)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_tool_call_codex(&self, id: RequestId, arguments: Option<serde_json::Value>) {
|
||||
let (initial_prompt, config): (String, CodexConfig) = match arguments {
|
||||
@@ -631,4 +675,20 @@ impl MessageProcessor {
|
||||
) {
|
||||
tracing::info!("notifications/message -> params: {:?}", params);
|
||||
}
|
||||
|
||||
pub(crate) async fn send_response_with_optional_error(
|
||||
&self,
|
||||
id: RequestId,
|
||||
message: Option<ToolCallResponseResult>,
|
||||
error: Option<bool>,
|
||||
) {
|
||||
let response = ToolCallResponse {
|
||||
request_id: id.clone(),
|
||||
is_error: error,
|
||||
result: message,
|
||||
};
|
||||
let result: CallToolResult = response.into();
|
||||
self.send_response::<mcp_types::CallToolRequest>(id.clone(), result)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
1
codex-rs/mcp-server/src/tool_handlers/mod.rs
Normal file
1
codex-rs/mcp-server/src/tool_handlers/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub(crate) mod send_message;
|
||||
124
codex-rs/mcp-server/src/tool_handlers/send_message.rs
Normal file
124
codex-rs/mcp-server/src/tool_handlers/send_message.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::Submission;
|
||||
use mcp_types::RequestId;
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::mcp_protocol::ConversationSendMessageArgs;
|
||||
use crate::mcp_protocol::ConversationSendMessageResult;
|
||||
use crate::mcp_protocol::ToolCallResponseResult;
|
||||
use crate::message_processor::MessageProcessor;
|
||||
|
||||
pub(crate) async fn handle_send_message(
|
||||
message_processor: &MessageProcessor,
|
||||
id: RequestId,
|
||||
arguments: ConversationSendMessageArgs,
|
||||
) {
|
||||
let ConversationSendMessageArgs {
|
||||
conversation_id,
|
||||
content: items,
|
||||
parent_message_id: _,
|
||||
conversation_overrides: _,
|
||||
} = arguments;
|
||||
|
||||
if items.is_empty() {
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationSendMessage(
|
||||
ConversationSendMessageResult::Error {
|
||||
message: "No content items provided".to_string(),
|
||||
},
|
||||
)),
|
||||
Some(true),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
let session_id = conversation_id.0;
|
||||
let Some(codex) = get_session(session_id, message_processor.session_map()).await else {
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationSendMessage(
|
||||
ConversationSendMessageResult::Error {
|
||||
message: "Session does not exist".to_string(),
|
||||
},
|
||||
)),
|
||||
Some(true),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
};
|
||||
|
||||
let running = {
|
||||
let running_sessions = message_processor.running_session_ids();
|
||||
let mut running_sessions = running_sessions.lock().await;
|
||||
!running_sessions.insert(session_id)
|
||||
};
|
||||
|
||||
if running {
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationSendMessage(
|
||||
ConversationSendMessageResult::Error {
|
||||
message: "Session is already running".to_string(),
|
||||
},
|
||||
)),
|
||||
Some(true),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
let request_id_string = match &id {
|
||||
RequestId::String(s) => s.clone(),
|
||||
RequestId::Integer(i) => i.to_string(),
|
||||
};
|
||||
|
||||
let submit_res = codex
|
||||
.submit_with_id(Submission {
|
||||
id: request_id_string,
|
||||
op: Op::UserInput { items },
|
||||
})
|
||||
.await;
|
||||
|
||||
if let Err(e) = submit_res {
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationSendMessage(
|
||||
ConversationSendMessageResult::Error {
|
||||
message: format!("Failed to submit user input: {e}"),
|
||||
},
|
||||
)),
|
||||
Some(true),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationSendMessage(
|
||||
ConversationSendMessageResult::Ok,
|
||||
)),
|
||||
Some(false),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub(crate) async fn get_session(
|
||||
session_id: Uuid,
|
||||
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
|
||||
) -> Option<Arc<Codex>> {
|
||||
let guard = session_map.lock().await;
|
||||
guard.get(&session_id).cloned()
|
||||
}
|
||||
@@ -10,6 +10,7 @@ path = "lib.rs"
|
||||
anyhow = "1"
|
||||
assert_cmd = "2"
|
||||
codex-mcp-server = { path = "../.." }
|
||||
codex-core = { path = "../../../core" }
|
||||
mcp-types = { path = "../../../mcp-types" }
|
||||
pretty_assertions = "1.4.1"
|
||||
serde_json = "1"
|
||||
@@ -22,3 +23,4 @@ tokio = { version = "1", features = [
|
||||
"rt-multi-thread",
|
||||
] }
|
||||
wiremock = "0.6"
|
||||
uuid = { version = "1", features = ["serde", "v4"] }
|
||||
@@ -11,8 +11,13 @@ use tokio::process::ChildStdout;
|
||||
|
||||
use anyhow::Context;
|
||||
use assert_cmd::prelude::*;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_mcp_server::CodexToolCallParam;
|
||||
use codex_mcp_server::CodexToolCallReplyParam;
|
||||
use codex_mcp_server::mcp_protocol::ConversationId;
|
||||
use codex_mcp_server::mcp_protocol::ConversationSendMessageArgs;
|
||||
use codex_mcp_server::mcp_protocol::ToolCallRequestParams;
|
||||
|
||||
use mcp_types::CallToolRequestParams;
|
||||
use mcp_types::ClientCapabilities;
|
||||
use mcp_types::Implementation;
|
||||
@@ -29,6 +34,7 @@ use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use std::process::Command as StdCommand;
|
||||
use tokio::process::Command;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub struct McpProcess {
|
||||
next_request_id: AtomicI64,
|
||||
@@ -174,6 +180,26 @@ impl McpProcess {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn send_user_message_tool_call(
|
||||
&mut self,
|
||||
message: &str,
|
||||
session_id: &str,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = ToolCallRequestParams::ConversationSendMessage(ConversationSendMessageArgs {
|
||||
conversation_id: ConversationId(Uuid::parse_str(session_id)?),
|
||||
content: vec![InputItem::Text {
|
||||
text: message.to_string(),
|
||||
}],
|
||||
parent_message_id: None,
|
||||
conversation_overrides: None,
|
||||
});
|
||||
self.send_request(
|
||||
mcp_types::CallToolRequest::METHOD,
|
||||
Some(serde_json::to_value(params)?),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn send_request(
|
||||
&mut self,
|
||||
method: &str,
|
||||
|
||||
163
codex-rs/mcp-server/tests/send_message.rs
Normal file
163
codex-rs/mcp-server/tests/send_message.rs
Normal file
@@ -0,0 +1,163 @@
|
||||
#![allow(clippy::expect_used)]
|
||||
|
||||
use std::path::Path;
|
||||
use std::thread::sleep;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_mcp_server::CodexToolCallParam;
|
||||
use mcp_test_support::McpProcess;
|
||||
use mcp_test_support::create_final_assistant_message_sse_response;
|
||||
use mcp_test_support::create_mock_chat_completions_server;
|
||||
use mcp_types::JSONRPC_VERSION;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::RequestId;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::timeout;
|
||||
|
||||
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_send_message_success() {
|
||||
// Spin up a mock completions server that immediately ends the Codex turn.
|
||||
// Two Codex turns hit the mock model (session start + send-user-message). Provide two SSE responses.
|
||||
let responses = vec![
|
||||
create_final_assistant_message_sse_response("Done").expect("build mock assistant message"),
|
||||
create_final_assistant_message_sse_response("Done").expect("build mock assistant message"),
|
||||
];
|
||||
let server = create_mock_chat_completions_server(responses).await;
|
||||
|
||||
// Create a temporary Codex home with config pointing at the mock server.
|
||||
let codex_home = TempDir::new().expect("create temp dir");
|
||||
create_config_toml(codex_home.path(), &server.uri()).expect("write config.toml");
|
||||
|
||||
// Start MCP server process and initialize.
|
||||
let mut mcp_process = McpProcess::new(codex_home.path())
|
||||
.await
|
||||
.expect("spawn mcp process");
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize())
|
||||
.await
|
||||
.expect("init timed out")
|
||||
.expect("init failed");
|
||||
|
||||
// Kick off a Codex session so we have a valid session_id.
|
||||
let codex_request_id = mcp_process
|
||||
.send_codex_tool_call(CodexToolCallParam {
|
||||
prompt: "Start a session".to_string(),
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.expect("send codex tool call");
|
||||
|
||||
// Wait for the session_configured event to get the session_id.
|
||||
let session_id = mcp_process
|
||||
.read_stream_until_configured_response_message()
|
||||
.await
|
||||
.expect("read session_configured");
|
||||
|
||||
// The original codex call will finish quickly given our mock; consume its response.
|
||||
timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)),
|
||||
)
|
||||
.await
|
||||
.expect("codex response timeout")
|
||||
.expect("codex response error");
|
||||
|
||||
// Now exercise the send-user-message tool.
|
||||
let send_msg_request_id = mcp_process
|
||||
.send_user_message_tool_call("Hello again", &session_id)
|
||||
.await
|
||||
.expect("send send-message tool call");
|
||||
|
||||
let response: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_response_message(RequestId::Integer(send_msg_request_id)),
|
||||
)
|
||||
.await
|
||||
.expect("send-user-message response timeout")
|
||||
.expect("send-user-message response error");
|
||||
|
||||
assert_eq!(
|
||||
JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: RequestId::Integer(send_msg_request_id),
|
||||
result: json!({
|
||||
"content": [
|
||||
{
|
||||
"text": "{\"status\":\"ok\"}",
|
||||
"type": "text",
|
||||
}
|
||||
],
|
||||
"isError": false,
|
||||
"structuredContent": {
|
||||
"status": "ok"
|
||||
}
|
||||
}),
|
||||
},
|
||||
response
|
||||
);
|
||||
// wait for the server to hear the user message
|
||||
sleep(Duration::from_secs(1));
|
||||
|
||||
// Ensure the server and tempdir live until end of test
|
||||
drop(server);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_send_message_session_not_found() {
|
||||
// Start MCP without creating a Codex session
|
||||
let codex_home = TempDir::new().expect("tempdir");
|
||||
let mut mcp = McpProcess::new(codex_home.path()).await.expect("spawn");
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
|
||||
.await
|
||||
.expect("timeout")
|
||||
.expect("init");
|
||||
|
||||
let unknown = uuid::Uuid::new_v4().to_string();
|
||||
let req_id = mcp
|
||||
.send_user_message_tool_call("ping", &unknown)
|
||||
.await
|
||||
.expect("send tool");
|
||||
|
||||
let resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(req_id)),
|
||||
)
|
||||
.await
|
||||
.expect("timeout")
|
||||
.expect("resp");
|
||||
|
||||
let result = resp.result.clone();
|
||||
let content = result["content"][0]["text"].as_str().unwrap_or("");
|
||||
assert!(content.contains("Session does not exist"));
|
||||
assert_eq!(result["isError"], json!(true));
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> {
|
||||
let config_toml = codex_home.join("config.toml");
|
||||
std::fs::write(
|
||||
config_toml,
|
||||
format!(
|
||||
r#"
|
||||
model = "mock-model"
|
||||
approval_policy = "never"
|
||||
sandbox_mode = "danger-full-access"
|
||||
|
||||
model_provider = "mock_provider"
|
||||
|
||||
[model_providers.mock_provider]
|
||||
name = "Mock provider for test"
|
||||
base_url = "{server_uri}/v1"
|
||||
wire_api = "chat"
|
||||
request_max_retries = 0
|
||||
stream_max_retries = 0
|
||||
"#
|
||||
),
|
||||
)
|
||||
}
|
||||
@@ -10,6 +10,7 @@ use crate::tui;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::Op;
|
||||
use color_eyre::eyre::Result;
|
||||
use crossterm::SynchronizedUpdate;
|
||||
use crossterm::event::KeyCode;
|
||||
@@ -330,6 +331,12 @@ impl App<'_> {
|
||||
self.app_state = AppState::Chat { widget: new_widget };
|
||||
self.app_event_tx.send(AppEvent::RequestRedraw);
|
||||
}
|
||||
SlashCommand::Compact => {
|
||||
if let AppState::Chat { widget } = &mut self.app_state {
|
||||
widget.clear_token_usage();
|
||||
self.app_event_tx.send(AppEvent::CodexOp(Op::Compact));
|
||||
}
|
||||
}
|
||||
SlashCommand::Quit => {
|
||||
break;
|
||||
}
|
||||
@@ -416,7 +423,10 @@ impl App<'_> {
|
||||
// Ignore args; forward to the existing no-args handler
|
||||
self.app_event_tx.send(AppEvent::DispatchCommand(command));
|
||||
}
|
||||
SlashCommand::New | SlashCommand::Quit | SlashCommand::Diff => {
|
||||
SlashCommand::New
|
||||
| SlashCommand::Quit
|
||||
| SlashCommand::Diff
|
||||
| SlashCommand::Compact => {
|
||||
// For other commands, fall back to existing handling.
|
||||
// We can ignore args for now.
|
||||
self.app_event_tx.send(AppEvent::DispatchCommand(command));
|
||||
|
||||
@@ -566,6 +566,12 @@ impl ChatWidget<'_> {
|
||||
self.submit_op(op);
|
||||
self.request_redraw();
|
||||
}
|
||||
|
||||
pub(crate) fn clear_token_usage(&mut self) {
|
||||
self.token_usage = TokenUsage::default();
|
||||
self.bottom_pane
|
||||
.set_token_usage(self.token_usage.clone(), self.config.model_context_window);
|
||||
}
|
||||
}
|
||||
|
||||
impl WidgetRef for &ChatWidget<'_> {
|
||||
|
||||
@@ -13,6 +13,7 @@ pub enum SlashCommand {
|
||||
// DO NOT ALPHA-SORT! Enum order is presentation order in the popup, so
|
||||
// more frequently used commands should be listed first.
|
||||
New,
|
||||
Compact,
|
||||
Diff,
|
||||
Model,
|
||||
Quit,
|
||||
@@ -25,6 +26,7 @@ impl SlashCommand {
|
||||
pub fn description(self) -> &'static str {
|
||||
match self {
|
||||
SlashCommand::New => "Start a new chat.",
|
||||
SlashCommand::Compact => "Compact the chat history.",
|
||||
SlashCommand::Quit => "Exit the application.",
|
||||
SlashCommand::Model => "Select the model to use.",
|
||||
SlashCommand::Diff => {
|
||||
|
||||
Reference in New Issue
Block a user