mirror of
https://github.com/openai/codex.git
synced 2026-04-24 22:54:54 +00:00
Merge branch 'main' into chores-rollout
This commit is contained in:
2
.github/workflows/rust-release.yml
vendored
2
.github/workflows/rust-release.yml
vendored
@@ -93,7 +93,7 @@ jobs:
|
||||
sudo apt install -y musl-tools pkg-config
|
||||
|
||||
- name: Cargo build
|
||||
run: cargo build --target ${{ matrix.target }} --release --all-targets --all-features
|
||||
run: cargo build --target ${{ matrix.target }} --release --bin codex --bin codex-exec --bin codex-linux-sandbox
|
||||
|
||||
- name: Stage artifacts
|
||||
shell: bash
|
||||
|
||||
8
codex-rs/Cargo.lock
generated
8
codex-rs/Cargo.lock
generated
@@ -762,7 +762,9 @@ version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"clap",
|
||||
"codex-common",
|
||||
"codex-core",
|
||||
"dotenvy",
|
||||
"landlock",
|
||||
"libc",
|
||||
"seccompiler",
|
||||
@@ -1278,6 +1280,12 @@ version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10"
|
||||
|
||||
[[package]]
|
||||
name = "dotenvy"
|
||||
version = "0.15.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
|
||||
|
||||
[[package]]
|
||||
name = "dupe"
|
||||
version = "0.9.1"
|
||||
|
||||
@@ -41,7 +41,7 @@ pub(crate) async fn stream_chat_completions(
|
||||
|
||||
for item in &prompt.input {
|
||||
match item {
|
||||
ResponseItem::Message { role, content } => {
|
||||
ResponseItem::Message { role, content, .. } => {
|
||||
let mut text = String::new();
|
||||
for c in content {
|
||||
match c {
|
||||
@@ -58,6 +58,7 @@ pub(crate) async fn stream_chat_completions(
|
||||
name,
|
||||
arguments,
|
||||
call_id,
|
||||
..
|
||||
} => {
|
||||
messages.push(json!({
|
||||
"role": "assistant",
|
||||
@@ -259,6 +260,7 @@ async fn process_chat_sse<S>(
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: content.to_string(),
|
||||
}],
|
||||
id: None,
|
||||
};
|
||||
|
||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||
@@ -300,6 +302,7 @@ async fn process_chat_sse<S>(
|
||||
"tool_calls" if fn_call_state.active => {
|
||||
// Build the FunctionCall response item.
|
||||
let item = ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()),
|
||||
arguments: fn_call_state.arguments.clone(),
|
||||
call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new),
|
||||
@@ -402,6 +405,7 @@ where
|
||||
}))) => {
|
||||
if !this.cumulative.is_empty() {
|
||||
let aggregated_item = crate::models::ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![crate::models::ContentItem::OutputText {
|
||||
text: std::mem::take(&mut this.cumulative),
|
||||
|
||||
@@ -117,6 +117,15 @@ impl ModelClient {
|
||||
let full_instructions = prompt.get_full_instructions(&self.config.model);
|
||||
let tools_json = create_tools_json_for_responses_api(prompt, &self.config.model)?;
|
||||
let reasoning = create_reasoning_param_for_request(&self.config, self.effort, self.summary);
|
||||
|
||||
// Request encrypted COT if we are not storing responses,
|
||||
// otherwise reasoning items will be referenced by ID
|
||||
let include = if !prompt.store && reasoning.is_some() {
|
||||
vec!["reasoning.encrypted_content".to_string()]
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
let payload = ResponsesApiRequest {
|
||||
model: &self.config.model,
|
||||
instructions: &full_instructions,
|
||||
@@ -125,10 +134,10 @@ impl ModelClient {
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: false,
|
||||
reasoning,
|
||||
previous_response_id: prompt.prev_id.clone(),
|
||||
store: prompt.store,
|
||||
// TODO: make this configurable
|
||||
stream: true,
|
||||
include,
|
||||
};
|
||||
|
||||
trace!(
|
||||
|
||||
@@ -22,8 +22,6 @@ const BASE_INSTRUCTIONS: &str = include_str!("../prompt.md");
|
||||
pub struct Prompt {
|
||||
/// Conversation context input items.
|
||||
pub input: Vec<ResponseItem>,
|
||||
/// Optional previous response ID (when storage is enabled).
|
||||
pub prev_id: Option<String>,
|
||||
/// Optional instructions from the user to amend to the built-in agent
|
||||
/// instructions.
|
||||
pub user_instructions: Option<String>,
|
||||
@@ -133,11 +131,10 @@ pub(crate) struct ResponsesApiRequest<'a> {
|
||||
pub(crate) tool_choice: &'static str,
|
||||
pub(crate) parallel_tool_calls: bool,
|
||||
pub(crate) reasoning: Option<Reasoning>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) previous_response_id: Option<String>,
|
||||
/// true when using the Responses API.
|
||||
pub(crate) store: bool,
|
||||
pub(crate) stream: bool,
|
||||
pub(crate) include: Vec<String>,
|
||||
}
|
||||
|
||||
use crate::config::Config;
|
||||
|
||||
@@ -34,7 +34,6 @@ use tracing::trace;
|
||||
use tracing::warn;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::WireApi;
|
||||
use crate::client::ModelClient;
|
||||
use crate::client_common::Prompt;
|
||||
use crate::client_common::ResponseEvent;
|
||||
@@ -192,6 +191,7 @@ pub(crate) struct Session {
|
||||
sandbox_policy: SandboxPolicy,
|
||||
shell_environment_policy: ShellEnvironmentPolicy,
|
||||
writable_roots: Mutex<Vec<PathBuf>>,
|
||||
disable_response_storage: bool,
|
||||
|
||||
/// Manager for external MCP servers/tools.
|
||||
mcp_connection_manager: McpConnectionManager,
|
||||
@@ -220,13 +220,9 @@ impl Session {
|
||||
struct State {
|
||||
approved_commands: HashSet<Vec<String>>,
|
||||
current_task: Option<AgentTask>,
|
||||
/// Call IDs that have been sent from the Responses API but have not been sent back yet.
|
||||
/// You CANNOT send a Responses API follow-up message unless you have sent back the output for all pending calls or else it will 400.
|
||||
pending_call_ids: HashSet<String>,
|
||||
previous_response_id: Option<String>,
|
||||
pending_approvals: HashMap<String, oneshot::Sender<ReviewDecision>>,
|
||||
pending_input: Vec<ResponseInputItem>,
|
||||
zdr_transcript: Option<ConversationHistory>,
|
||||
history: ConversationHistory,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
@@ -258,6 +254,7 @@ impl Session {
|
||||
pub async fn request_command_approval(
|
||||
&self,
|
||||
sub_id: String,
|
||||
call_id: String,
|
||||
command: Vec<String>,
|
||||
cwd: PathBuf,
|
||||
reason: Option<String>,
|
||||
@@ -266,6 +263,7 @@ impl Session {
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
|
||||
call_id,
|
||||
command,
|
||||
cwd,
|
||||
reason,
|
||||
@@ -282,6 +280,7 @@ impl Session {
|
||||
pub async fn request_patch_approval(
|
||||
&self,
|
||||
sub_id: String,
|
||||
call_id: String,
|
||||
action: &ApplyPatchAction,
|
||||
reason: Option<String>,
|
||||
grant_root: Option<PathBuf>,
|
||||
@@ -290,6 +289,7 @@ impl Session {
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
|
||||
call_id,
|
||||
changes: convert_apply_patch_to_protocol(action),
|
||||
reason,
|
||||
grant_root,
|
||||
@@ -321,18 +321,11 @@ impl Session {
|
||||
debug!("Recording items for conversation: {items:?}");
|
||||
self.record_state_snapshot(items).await;
|
||||
|
||||
if let Some(transcript) = self.state.lock().unwrap().zdr_transcript.as_mut() {
|
||||
transcript.record_items(items);
|
||||
}
|
||||
self.state.lock().unwrap().history.record_items(items);
|
||||
}
|
||||
|
||||
async fn record_state_snapshot(&self, items: &[ResponseItem]) {
|
||||
let snapshot = {
|
||||
let state = self.state.lock().unwrap();
|
||||
crate::rollout::SessionStateSnapshot {
|
||||
previous_response_id: state.previous_response_id.clone(),
|
||||
}
|
||||
};
|
||||
let snapshot = { crate::rollout::SessionStateSnapshot {} };
|
||||
|
||||
let recorder = {
|
||||
let guard = self.rollout.lock().unwrap();
|
||||
@@ -434,8 +427,6 @@ impl Session {
|
||||
pub fn abort(&self) {
|
||||
info!("Aborting existing session");
|
||||
let mut state = self.state.lock().unwrap();
|
||||
// Don't clear pending_call_ids because we need to keep track of them to ensure we don't 400 on the next turn.
|
||||
// We will generate a synthetic aborted response for each pending call id.
|
||||
state.pending_approvals.clear();
|
||||
state.pending_input.clear();
|
||||
if let Some(task) = state.current_task.take() {
|
||||
@@ -480,15 +471,10 @@ impl Drop for Session {
|
||||
}
|
||||
|
||||
impl State {
|
||||
pub fn partial_clone(&self, retain_zdr_transcript: bool) -> Self {
|
||||
pub fn partial_clone(&self) -> Self {
|
||||
Self {
|
||||
approved_commands: self.approved_commands.clone(),
|
||||
previous_response_id: self.previous_response_id.clone(),
|
||||
zdr_transcript: if retain_zdr_transcript {
|
||||
self.zdr_transcript.clone()
|
||||
} else {
|
||||
None
|
||||
},
|
||||
history: self.history.clone(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
@@ -628,22 +614,13 @@ async fn submission_loop(
|
||||
);
|
||||
|
||||
// abort any current running session and clone its state
|
||||
let retain_zdr_transcript =
|
||||
record_conversation_history(disable_response_storage, provider.wire_api);
|
||||
let state = match sess.take() {
|
||||
Some(sess) => {
|
||||
sess.abort();
|
||||
sess.state
|
||||
.lock()
|
||||
.unwrap()
|
||||
.partial_clone(retain_zdr_transcript)
|
||||
sess.state.lock().unwrap().partial_clone()
|
||||
}
|
||||
None => State {
|
||||
zdr_transcript: if retain_zdr_transcript {
|
||||
Some(ConversationHistory::new())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
history: ConversationHistory::new(),
|
||||
..Default::default()
|
||||
},
|
||||
};
|
||||
@@ -694,18 +671,14 @@ async fn submission_loop(
|
||||
state: Mutex::new(state),
|
||||
rollout: Mutex::new(rollout_recorder),
|
||||
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
||||
disable_response_storage,
|
||||
}));
|
||||
|
||||
// Patch restored state into the newly created session.
|
||||
if let Some(sess_arc) = &sess {
|
||||
if restored_prev_id.is_some() || restored_items.is_some() {
|
||||
if restored_items.is_some() {
|
||||
let mut st = sess_arc.state.lock().unwrap();
|
||||
st.previous_response_id = restored_prev_id;
|
||||
if let (Some(hist), Some(items)) =
|
||||
(st.zdr_transcript.as_mut(), restored_items.as_ref())
|
||||
{
|
||||
hist.record_items(items.iter());
|
||||
}
|
||||
st.history.record_items(restored_items.unwrap().iter());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -852,14 +825,8 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
sess.record_conversation_items(&[initial_input_for_turn.clone().into()])
|
||||
.await;
|
||||
|
||||
let mut input_for_next_turn: Vec<ResponseInputItem> = vec![initial_input_for_turn];
|
||||
let last_agent_message: Option<String>;
|
||||
loop {
|
||||
let mut net_new_turn_input = input_for_next_turn
|
||||
.drain(..)
|
||||
.map(ResponseItem::from)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Note that pending_input would be something like a message the user
|
||||
// submitted through the UI while the model was running. Though the UI
|
||||
// may support this, the model might not.
|
||||
@@ -876,29 +843,7 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
// only record the new items that originated in this turn so that it
|
||||
// represents an append-only log without duplicates.
|
||||
let turn_input: Vec<ResponseItem> =
|
||||
if let Some(transcript) = sess.state.lock().unwrap().zdr_transcript.as_mut() {
|
||||
// If we are using Chat/ZDR, we need to send the transcript with
|
||||
// every turn. By induction, `transcript` already contains:
|
||||
// - The `input` that kicked off this task.
|
||||
// - Each `ResponseItem` that was recorded in the previous turn.
|
||||
// - Each response to a `ResponseItem` (in practice, the only
|
||||
// response type we seem to have is `FunctionCallOutput`).
|
||||
//
|
||||
// The only thing the `transcript` does not contain is the
|
||||
// `pending_input` that was injected while the model was
|
||||
// running. We need to add that to the conversation history
|
||||
// so that the model can see it in the next turn.
|
||||
[transcript.contents(), pending_input].concat()
|
||||
} else {
|
||||
// In practice, net_new_turn_input should contain only:
|
||||
// - User messages
|
||||
// - Outputs for function calls requested by the model
|
||||
net_new_turn_input.extend(pending_input);
|
||||
|
||||
// Responses API path – we can just send the new items and
|
||||
// record the same.
|
||||
net_new_turn_input
|
||||
};
|
||||
[sess.state.lock().unwrap().history.contents(), pending_input].concat();
|
||||
|
||||
let turn_input_messages: Vec<String> = turn_input
|
||||
.iter()
|
||||
@@ -974,8 +919,19 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
},
|
||||
);
|
||||
}
|
||||
(ResponseItem::Reasoning { .. }, None) => {
|
||||
// Omit from conversation history.
|
||||
(
|
||||
ResponseItem::Reasoning {
|
||||
id,
|
||||
summary,
|
||||
encrypted_content,
|
||||
},
|
||||
None,
|
||||
) => {
|
||||
items_to_record_in_conversation_history.push(ResponseItem::Reasoning {
|
||||
id: id.clone(),
|
||||
summary: summary.clone(),
|
||||
encrypted_content: encrypted_content.clone(),
|
||||
});
|
||||
}
|
||||
_ => {
|
||||
warn!("Unexpected response item: {item:?} with response: {response:?}");
|
||||
@@ -1004,8 +960,6 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
});
|
||||
break;
|
||||
}
|
||||
|
||||
input_for_next_turn = responses;
|
||||
}
|
||||
Err(e) => {
|
||||
info!("Turn error: {e:#}");
|
||||
@@ -1033,26 +987,11 @@ async fn run_turn(
|
||||
sub_id: String,
|
||||
input: Vec<ResponseItem>,
|
||||
) -> CodexResult<Vec<ProcessedResponseItem>> {
|
||||
// Decide whether to use server-side storage (previous_response_id) or disable it
|
||||
let (prev_id, store) = {
|
||||
let state = sess.state.lock().unwrap();
|
||||
let store = state.zdr_transcript.is_none();
|
||||
let prev_id = if store {
|
||||
state.previous_response_id.clone()
|
||||
} else {
|
||||
// When using ZDR, the Responses API may send previous_response_id
|
||||
// back, but trying to use it results in a 400.
|
||||
None
|
||||
};
|
||||
(prev_id, store)
|
||||
};
|
||||
|
||||
let extra_tools = sess.mcp_connection_manager.list_all_tools();
|
||||
let prompt = Prompt {
|
||||
input,
|
||||
prev_id,
|
||||
user_instructions: sess.user_instructions.clone(),
|
||||
store,
|
||||
store: !sess.disable_response_storage,
|
||||
extra_tools,
|
||||
base_instructions_override: sess.base_instructions.clone(),
|
||||
};
|
||||
@@ -1126,11 +1065,17 @@ async fn try_run_turn(
|
||||
// This usually happens because the user interrupted the model before we responded to one of its tool calls
|
||||
// and then the user sent a follow-up message.
|
||||
let missing_calls = {
|
||||
sess.state
|
||||
.lock()
|
||||
.unwrap()
|
||||
.pending_call_ids
|
||||
prompt
|
||||
.input
|
||||
.iter()
|
||||
.filter_map(|ri| match ri {
|
||||
ResponseItem::FunctionCall { call_id, .. } => Some(call_id),
|
||||
ResponseItem::LocalShellCall {
|
||||
call_id: Some(call_id),
|
||||
..
|
||||
} => Some(call_id),
|
||||
_ => None,
|
||||
})
|
||||
.filter_map(|call_id| {
|
||||
if completed_call_ids.contains(&call_id) {
|
||||
None
|
||||
@@ -1184,31 +1129,14 @@ async fn try_run_turn(
|
||||
};
|
||||
|
||||
match event {
|
||||
ResponseEvent::Created => {
|
||||
let mut state = sess.state.lock().unwrap();
|
||||
// We successfully created a new response and ensured that all pending calls were included so we can clear the pending call ids.
|
||||
state.pending_call_ids.clear();
|
||||
}
|
||||
ResponseEvent::Created => {}
|
||||
ResponseEvent::OutputItemDone(item) => {
|
||||
let call_id = match &item {
|
||||
ResponseItem::LocalShellCall {
|
||||
call_id: Some(call_id),
|
||||
..
|
||||
} => Some(call_id),
|
||||
ResponseItem::FunctionCall { call_id, .. } => Some(call_id),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(call_id) = call_id {
|
||||
// We just got a new call id so we need to make sure to respond to it in the next turn.
|
||||
let mut state = sess.state.lock().unwrap();
|
||||
state.pending_call_ids.insert(call_id.clone());
|
||||
}
|
||||
let response = handle_response_item(sess, sub_id, item.clone()).await?;
|
||||
|
||||
output.push(ProcessedResponseItem { item, response });
|
||||
}
|
||||
ResponseEvent::Completed {
|
||||
response_id,
|
||||
response_id: _,
|
||||
token_usage,
|
||||
} => {
|
||||
if let Some(token_usage) = token_usage {
|
||||
@@ -1221,8 +1149,6 @@ async fn try_run_turn(
|
||||
.ok();
|
||||
}
|
||||
|
||||
let mut state = sess.state.lock().unwrap();
|
||||
state.previous_response_id = Some(response_id);
|
||||
return Ok(output);
|
||||
}
|
||||
ResponseEvent::OutputTextDelta(delta) => {
|
||||
@@ -1262,7 +1188,7 @@ async fn handle_response_item(
|
||||
}
|
||||
None
|
||||
}
|
||||
ResponseItem::Reasoning { id: _, summary } => {
|
||||
ResponseItem::Reasoning { summary, .. } => {
|
||||
for item in summary {
|
||||
let text = match item {
|
||||
ReasoningItemReasoningSummary::SummaryText { text } => text,
|
||||
@@ -1279,6 +1205,7 @@ async fn handle_response_item(
|
||||
name,
|
||||
arguments,
|
||||
call_id,
|
||||
..
|
||||
} => {
|
||||
info!("FunctionCall: {arguments}");
|
||||
Some(handle_function_call(sess, sub_id.to_string(), name, arguments, call_id).await)
|
||||
@@ -1449,6 +1376,7 @@ async fn handle_container_exec_with_params(
|
||||
let rx_approve = sess
|
||||
.request_command_approval(
|
||||
sub_id.clone(),
|
||||
call_id.clone(),
|
||||
params.command.clone(),
|
||||
params.cwd.clone(),
|
||||
None,
|
||||
@@ -1576,6 +1504,7 @@ async fn handle_sandbox_error(
|
||||
let rx_approve = sess
|
||||
.request_command_approval(
|
||||
sub_id.clone(),
|
||||
call_id.clone(),
|
||||
params.command.clone(),
|
||||
params.cwd.clone(),
|
||||
Some("command failed; retry without sandbox?".to_string()),
|
||||
@@ -1593,9 +1522,7 @@ async fn handle_sandbox_error(
|
||||
sess.notify_background_event(&sub_id, "retrying command without sandbox")
|
||||
.await;
|
||||
|
||||
// Emit a fresh Begin event so progress bars reset.
|
||||
let retry_call_id = format!("{call_id}-retry");
|
||||
sess.notify_exec_command_begin(&sub_id, &retry_call_id, ¶ms)
|
||||
sess.notify_exec_command_begin(&sub_id, &call_id, ¶ms)
|
||||
.await;
|
||||
|
||||
// This is an escalated retry; the policy will not be
|
||||
@@ -1618,14 +1545,8 @@ async fn handle_sandbox_error(
|
||||
duration,
|
||||
} = retry_output;
|
||||
|
||||
sess.notify_exec_command_end(
|
||||
&sub_id,
|
||||
&retry_call_id,
|
||||
&stdout,
|
||||
&stderr,
|
||||
exit_code,
|
||||
)
|
||||
.await;
|
||||
sess.notify_exec_command_end(&sub_id, &call_id, &stdout, &stderr, exit_code)
|
||||
.await;
|
||||
|
||||
let is_success = exit_code == 0;
|
||||
let content = format_exec_output(
|
||||
@@ -1689,7 +1610,7 @@ async fn apply_patch(
|
||||
// Compute a readable summary of path changes to include in the
|
||||
// approval request so the user can make an informed decision.
|
||||
let rx_approve = sess
|
||||
.request_patch_approval(sub_id.clone(), &action, None, None)
|
||||
.request_patch_approval(sub_id.clone(), call_id.clone(), &action, None, None)
|
||||
.await;
|
||||
match rx_approve.await.unwrap_or_default() {
|
||||
ReviewDecision::Approved | ReviewDecision::ApprovedForSession => false,
|
||||
@@ -1727,7 +1648,13 @@ async fn apply_patch(
|
||||
));
|
||||
|
||||
let rx = sess
|
||||
.request_patch_approval(sub_id.clone(), &action, reason.clone(), Some(root.clone()))
|
||||
.request_patch_approval(
|
||||
sub_id.clone(),
|
||||
call_id.clone(),
|
||||
&action,
|
||||
reason.clone(),
|
||||
Some(root.clone()),
|
||||
)
|
||||
.await;
|
||||
|
||||
if !matches!(
|
||||
@@ -1811,6 +1738,7 @@ async fn apply_patch(
|
||||
let rx = sess
|
||||
.request_patch_approval(
|
||||
sub_id.clone(),
|
||||
call_id.clone(),
|
||||
&action,
|
||||
reason.clone(),
|
||||
Some(root.clone()),
|
||||
@@ -2069,7 +1997,7 @@ fn format_exec_output(output: &str, exit_code: i32, duration: Duration) -> Strin
|
||||
|
||||
fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<String> {
|
||||
responses.iter().rev().find_map(|item| {
|
||||
if let ResponseItem::Message { role, content } = item {
|
||||
if let ResponseItem::Message { role, content, .. } = item {
|
||||
if role == "assistant" {
|
||||
content.iter().rev().find_map(|ci| {
|
||||
if let ContentItem::OutputText { text } = ci {
|
||||
@@ -2086,15 +2014,3 @@ fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<St
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// See [`ConversationHistory`] for details.
|
||||
fn record_conversation_history(disable_response_storage: bool, wire_api: WireApi) -> bool {
|
||||
if disable_response_storage {
|
||||
return true;
|
||||
}
|
||||
|
||||
match wire_api {
|
||||
WireApi::Responses => false,
|
||||
WireApi::Chat => true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -561,7 +561,7 @@ fn default_model() -> String {
|
||||
/// function will Err if the path does not exist.
|
||||
/// - If `CODEX_HOME` is not set, this function does not verify that the
|
||||
/// directory exists.
|
||||
fn find_codex_home() -> std::io::Result<PathBuf> {
|
||||
pub fn find_codex_home() -> std::io::Result<PathBuf> {
|
||||
// Honor the `CODEX_HOME` environment variable when it is set to allow users
|
||||
// (and tests) to override the default location.
|
||||
if let Ok(val) = std::env::var("CODEX_HOME") {
|
||||
|
||||
@@ -1,12 +1,7 @@
|
||||
use crate::models::ResponseItem;
|
||||
|
||||
/// Transcript of conversation history that is needed:
|
||||
/// - for ZDR clients for which previous_response_id is not available, so we
|
||||
/// must include the transcript with every API call. This must include each
|
||||
/// `function_call` and its corresponding `function_call_output`.
|
||||
/// - for clients using the "chat completions" API as opposed to the
|
||||
/// "responses" API.
|
||||
#[derive(Debug, Clone)]
|
||||
/// Transcript of conversation history
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub(crate) struct ConversationHistory {
|
||||
/// The oldest items are at the beginning of the vector.
|
||||
items: Vec<ResponseItem>,
|
||||
@@ -44,7 +39,8 @@ fn is_api_message(message: &ResponseItem) -> bool {
|
||||
ResponseItem::Message { role, .. } => role.as_str() != "system",
|
||||
ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::LocalShellCall { .. } => true,
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => false,
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::Reasoning { .. } => true,
|
||||
ResponseItem::Other => false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::collections::HashMap;
|
||||
use base64::Engine;
|
||||
use mcp_types::CallToolResult;
|
||||
use serde::Deserialize;
|
||||
use serde::Deserializer;
|
||||
use serde::Serialize;
|
||||
use serde::ser::Serializer;
|
||||
|
||||
@@ -37,12 +38,14 @@ pub enum ContentItem {
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ResponseItem {
|
||||
Message {
|
||||
id: Option<String>,
|
||||
role: String,
|
||||
content: Vec<ContentItem>,
|
||||
},
|
||||
Reasoning {
|
||||
id: String,
|
||||
summary: Vec<ReasoningItemReasoningSummary>,
|
||||
encrypted_content: Option<String>,
|
||||
},
|
||||
LocalShellCall {
|
||||
/// Set when using the chat completions API.
|
||||
@@ -53,6 +56,7 @@ pub enum ResponseItem {
|
||||
action: LocalShellAction,
|
||||
},
|
||||
FunctionCall {
|
||||
id: Option<String>,
|
||||
name: String,
|
||||
// The Responses API returns the function call arguments as a *string* that contains
|
||||
// JSON, not as an already‑parsed object. We keep it as a raw string here and let
|
||||
@@ -78,7 +82,11 @@ pub enum ResponseItem {
|
||||
impl From<ResponseInputItem> for ResponseItem {
|
||||
fn from(item: ResponseInputItem) -> Self {
|
||||
match item {
|
||||
ResponseInputItem::Message { role, content } => Self::Message { role, content },
|
||||
ResponseInputItem::Message { role, content } => Self::Message {
|
||||
role,
|
||||
content,
|
||||
id: None,
|
||||
},
|
||||
ResponseInputItem::FunctionCallOutput { call_id, output } => {
|
||||
Self::FunctionCallOutput { call_id, output }
|
||||
}
|
||||
@@ -177,7 +185,7 @@ pub struct ShellToolCallParams {
|
||||
pub timeout_ms: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FunctionCallOutputPayload {
|
||||
pub content: String,
|
||||
#[expect(dead_code)]
|
||||
@@ -205,6 +213,19 @@ impl Serialize for FunctionCallOutputPayload {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for FunctionCallOutputPayload {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
Ok(FunctionCallOutputPayload {
|
||||
content: s,
|
||||
success: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Implement Display so callers can treat the payload like a plain string when logging or doing
|
||||
// trivial substring checks in tests (existing tests call `.contains()` on the output). Display
|
||||
// returns the raw `content` field.
|
||||
|
||||
@@ -422,6 +422,8 @@ pub struct ExecCommandEndEvent {
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ExecApprovalRequestEvent {
|
||||
/// Identifier for the associated exec call, if available.
|
||||
pub call_id: String,
|
||||
/// The command to be executed.
|
||||
pub command: Vec<String>,
|
||||
/// The command's working directory.
|
||||
@@ -433,6 +435,8 @@ pub struct ExecApprovalRequestEvent {
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ApplyPatchApprovalRequestEvent {
|
||||
/// Responses API call id for the associated patch apply call, if available.
|
||||
pub call_id: String,
|
||||
pub changes: HashMap<PathBuf, FileChange>,
|
||||
/// Optional explanatory reason (e.g. request for extra write access).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
|
||||
@@ -31,9 +31,7 @@ pub struct SessionMeta {
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
pub struct SessionStateSnapshot {
|
||||
pub previous_response_id: Option<String>,
|
||||
}
|
||||
pub struct SessionStateSnapshot {}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
pub struct SavedSession {
|
||||
@@ -126,8 +124,9 @@ impl RolloutRecorder {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. } => filtered.push(item.clone()),
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {
|
||||
| ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::Reasoning { .. } => filtered.push(item.clone()),
|
||||
ResponseItem::Other => {
|
||||
// These should never be serialized.
|
||||
continue;
|
||||
}
|
||||
@@ -179,13 +178,17 @@ impl RolloutRecorder {
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if let Ok(item) = serde_json::from_value::<ResponseItem>(v.clone()) {
|
||||
match item {
|
||||
match serde_json::from_value::<ResponseItem>(v.clone()) {
|
||||
Ok(item) => match item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. } => items.push(item),
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
|
||||
| ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::Reasoning { .. } => items.push(item),
|
||||
ResponseItem::Other => {}
|
||||
},
|
||||
Err(e) => {
|
||||
warn!("failed to parse item: {v:?}, error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -288,12 +291,14 @@ async fn rollout_writer(
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. } => {
|
||||
if let Err(e) = write_json_line(&mut file, &item).await {
|
||||
warn!("Failed to write item: {e}");
|
||||
| ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::Reasoning { .. } => {
|
||||
if let Ok(json) = serde_json::to_string(&item) {
|
||||
let _ = file.write_all(json.as_bytes()).await;
|
||||
let _ = file.write_all(b"\n").await;
|
||||
}
|
||||
}
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
|
||||
ResponseItem::Other => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,165 +0,0 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use codex_core::protocol::ErrorEvent;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
mod test_support;
|
||||
use serde_json::Value;
|
||||
use tempfile::TempDir;
|
||||
use test_support::load_default_config_for_test;
|
||||
use test_support::load_sse_fixture_with_id;
|
||||
use tokio::time::timeout;
|
||||
use wiremock::Match;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::Request;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
/// Matcher asserting that JSON body has NO `previous_response_id` field.
|
||||
struct NoPrevId;
|
||||
|
||||
impl Match for NoPrevId {
|
||||
fn matches(&self, req: &Request) -> bool {
|
||||
serde_json::from_slice::<Value>(&req.body)
|
||||
.map(|v| v.get("previous_response_id").is_none())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
/// Matcher asserting that JSON body HAS a `previous_response_id` field.
|
||||
struct HasPrevId;
|
||||
|
||||
impl Match for HasPrevId {
|
||||
fn matches(&self, req: &Request) -> bool {
|
||||
serde_json::from_slice::<Value>(&req.body)
|
||||
.map(|v| v.get("previous_response_id").is_some())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
/// Build minimal SSE stream with completed marker using the JSON fixture.
|
||||
fn sse_completed(id: &str) -> String {
|
||||
load_sse_fixture_with_id("tests/fixtures/completed_template.json", id)
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn keeps_previous_response_id_between_tasks() {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
|
||||
// First request – must NOT include `previous_response_id`.
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(NoPrevId)
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
// Second request – MUST include `previous_response_id`.
|
||||
let second = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp2"), "text/event-stream");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(HasPrevId)
|
||||
.respond_with(second)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
// Configure retry behavior explicitly to avoid mutating process-wide
|
||||
// environment variables.
|
||||
let model_provider = ModelProviderInfo {
|
||||
name: "openai".into(),
|
||||
base_url: format!("{}/v1", server.uri()),
|
||||
// Environment variable that should exist in the test environment.
|
||||
// ModelClient will return an error if the environment variable for the
|
||||
// provider is not set.
|
||||
env_key: Some("PATH".into()),
|
||||
env_key_instructions: None,
|
||||
wire_api: codex_core::WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
// disable retries so we don't get duplicate calls in this test
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: None,
|
||||
};
|
||||
|
||||
// Init session
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider = model_provider;
|
||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
let (codex, _init_id, _session_id) = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
|
||||
|
||||
// Task 1 – triggers first request (no previous_response_id)
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Wait for TaskComplete
|
||||
loop {
|
||||
let ev = timeout(Duration::from_secs(1), codex.next_event())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
if matches!(ev.msg, EventMsg::TaskComplete(_)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Task 2 – should include `previous_response_id` (triggers second request)
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "again".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Wait for TaskComplete or error
|
||||
loop {
|
||||
let ev = timeout(Duration::from_secs(1), codex.next_event())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
match ev.msg {
|
||||
EventMsg::TaskComplete(_) => break,
|
||||
EventMsg::Error(ErrorEvent { message }) => {
|
||||
panic!("unexpected error: {message}")
|
||||
}
|
||||
_ => {
|
||||
// Ignore other events.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -17,7 +17,9 @@ workspace = true
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
codex-common = { path = "../common", features = ["cli"] }
|
||||
codex-core = { path = "../core" }
|
||||
dotenvy = "0.15.7"
|
||||
tokio = { version = "1", features = ["rt-multi-thread"] }
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -43,6 +43,10 @@ where
|
||||
crate::run_main();
|
||||
}
|
||||
|
||||
// This modifies the environment, which is not thread-safe, so do this
|
||||
// before creating any threads/the Tokio runtime.
|
||||
load_dotenv();
|
||||
|
||||
// Regular invocation – create a Tokio runtime and execute the provided
|
||||
// async entry-point.
|
||||
let runtime = tokio::runtime::Runtime::new()?;
|
||||
@@ -61,3 +65,11 @@ where
|
||||
pub fn run_main() -> ! {
|
||||
panic!("codex-linux-sandbox is only supported on Linux");
|
||||
}
|
||||
|
||||
/// Load env vars from ~/.codex/.env and `$(pwd)/.env`.
|
||||
fn load_dotenv() {
|
||||
if let Ok(codex_home) = codex_core::config::find_codex_home() {
|
||||
dotenvy::from_path(codex_home.join(".env")).ok();
|
||||
}
|
||||
dotenvy::dotenv().ok();
|
||||
}
|
||||
|
||||
@@ -168,7 +168,7 @@ impl CodexToolCallParam {
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub(crate) struct CodexToolCallReplyParam {
|
||||
pub struct CodexToolCallReplyParam {
|
||||
/// The *session id* for this conversation.
|
||||
pub session_id: String,
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ use mcp_types::CallToolResult;
|
||||
use mcp_types::ContentBlock;
|
||||
use mcp_types::RequestId;
|
||||
use mcp_types::TextContent;
|
||||
use serde_json::json;
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -39,6 +40,7 @@ pub async fn run_codex_tool_session(
|
||||
config: CodexConfig,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||
) {
|
||||
let (codex, first_event, _ctrl_c, session_id) = match init_codex(config).await {
|
||||
Ok(res) => res,
|
||||
@@ -73,7 +75,10 @@ pub async fn run_codex_tool_session(
|
||||
RequestId::String(s) => s.clone(),
|
||||
RequestId::Integer(n) => n.to_string(),
|
||||
};
|
||||
|
||||
running_requests_id_to_codex_uuid
|
||||
.lock()
|
||||
.await
|
||||
.insert(id.clone(), session_id);
|
||||
let submission = Submission {
|
||||
id: sub_id.clone(),
|
||||
op: Op::UserInput {
|
||||
@@ -85,9 +90,12 @@ pub async fn run_codex_tool_session(
|
||||
|
||||
if let Err(e) = codex.submit_with_id(submission).await {
|
||||
tracing::error!("Failed to submit initial prompt: {e}");
|
||||
// unregister the id so we don't keep it in the map
|
||||
running_requests_id_to_codex_uuid.lock().await.remove(&id);
|
||||
return;
|
||||
}
|
||||
|
||||
run_codex_tool_session_inner(codex, outgoing, id).await;
|
||||
run_codex_tool_session_inner(codex, outgoing, id, running_requests_id_to_codex_uuid).await;
|
||||
}
|
||||
|
||||
pub async fn run_codex_tool_session_reply(
|
||||
@@ -95,7 +103,13 @@ pub async fn run_codex_tool_session_reply(
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
request_id: RequestId,
|
||||
prompt: String,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||
session_id: Uuid,
|
||||
) {
|
||||
running_requests_id_to_codex_uuid
|
||||
.lock()
|
||||
.await
|
||||
.insert(request_id.clone(), session_id);
|
||||
if let Err(e) = codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text { text: prompt }],
|
||||
@@ -103,15 +117,28 @@ pub async fn run_codex_tool_session_reply(
|
||||
.await
|
||||
{
|
||||
tracing::error!("Failed to submit user input: {e}");
|
||||
// unregister the id so we don't keep it in the map
|
||||
running_requests_id_to_codex_uuid
|
||||
.lock()
|
||||
.await
|
||||
.remove(&request_id);
|
||||
return;
|
||||
}
|
||||
|
||||
run_codex_tool_session_inner(codex, outgoing, request_id).await;
|
||||
run_codex_tool_session_inner(
|
||||
codex,
|
||||
outgoing,
|
||||
request_id,
|
||||
running_requests_id_to_codex_uuid,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn run_codex_tool_session_inner(
|
||||
codex: Arc<Codex>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
request_id: RequestId,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||
) {
|
||||
let request_id_str = match &request_id {
|
||||
RequestId::String(s) => s.clone(),
|
||||
@@ -129,6 +156,7 @@ async fn run_codex_tool_session_inner(
|
||||
EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
|
||||
command,
|
||||
cwd,
|
||||
call_id,
|
||||
reason: _,
|
||||
}) => {
|
||||
handle_exec_approval_request(
|
||||
@@ -139,16 +167,27 @@ async fn run_codex_tool_session_inner(
|
||||
request_id.clone(),
|
||||
request_id_str.clone(),
|
||||
event.id.clone(),
|
||||
call_id,
|
||||
)
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
EventMsg::Error(err_event) => {
|
||||
// Return a response to conclude the tool call when the Codex session reports an error (e.g., interruption).
|
||||
let result = json!({
|
||||
"error": err_event.message,
|
||||
});
|
||||
outgoing.send_response(request_id.clone(), result).await;
|
||||
break;
|
||||
}
|
||||
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
|
||||
call_id,
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
}) => {
|
||||
handle_patch_approval_request(
|
||||
call_id,
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
@@ -178,6 +217,11 @@ async fn run_codex_tool_session_inner(
|
||||
outgoing
|
||||
.send_response(request_id.clone(), result.into())
|
||||
.await;
|
||||
// unregister the id so we don't keep it in the map
|
||||
running_requests_id_to_codex_uuid
|
||||
.lock()
|
||||
.await
|
||||
.remove(&request_id);
|
||||
break;
|
||||
}
|
||||
EventMsg::SessionConfigured(_) => {
|
||||
@@ -192,8 +236,7 @@ async fn run_codex_tool_session_inner(
|
||||
EventMsg::AgentMessage(AgentMessageEvent { .. }) => {
|
||||
// TODO: think how we want to support this in the MCP
|
||||
}
|
||||
EventMsg::Error(_)
|
||||
| EventMsg::TaskStarted
|
||||
EventMsg::TaskStarted
|
||||
| EventMsg::TokenCount(_)
|
||||
| EventMsg::AgentReasoning(_)
|
||||
| EventMsg::McpToolCallBegin(_)
|
||||
|
||||
@@ -32,6 +32,7 @@ pub struct ExecApprovalElicitRequestParams {
|
||||
pub codex_elicitation: String,
|
||||
pub codex_mcp_tool_call_id: String,
|
||||
pub codex_event_id: String,
|
||||
pub codex_call_id: String,
|
||||
pub codex_command: Vec<String>,
|
||||
pub codex_cwd: PathBuf,
|
||||
}
|
||||
@@ -45,6 +46,7 @@ pub struct ExecApprovalResponse {
|
||||
pub decision: ReviewDecision,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn handle_exec_approval_request(
|
||||
command: Vec<String>,
|
||||
cwd: PathBuf,
|
||||
@@ -53,6 +55,7 @@ pub(crate) async fn handle_exec_approval_request(
|
||||
request_id: RequestId,
|
||||
tool_call_id: String,
|
||||
event_id: String,
|
||||
call_id: String,
|
||||
) {
|
||||
let escaped_command =
|
||||
shlex::try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" "));
|
||||
@@ -71,6 +74,7 @@ pub(crate) async fn handle_exec_approval_request(
|
||||
codex_elicitation: "exec-approval".to_string(),
|
||||
codex_mcp_tool_call_id: tool_call_id.clone(),
|
||||
codex_event_id: event_id.clone(),
|
||||
codex_call_id: call_id,
|
||||
codex_command: command,
|
||||
codex_cwd: cwd,
|
||||
};
|
||||
|
||||
@@ -27,6 +27,7 @@ use crate::outgoing_message::OutgoingMessage;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
|
||||
pub use crate::codex_tool_config::CodexToolCallParam;
|
||||
pub use crate::codex_tool_config::CodexToolCallReplyParam;
|
||||
pub use crate::exec_approval::ExecApprovalElicitRequestParams;
|
||||
pub use crate::exec_approval::ExecApprovalResponse;
|
||||
pub use crate::patch_approval::PatchApprovalElicitRequestParams;
|
||||
@@ -81,7 +82,7 @@ pub async fn run_main(codex_linux_sandbox_exe: Option<PathBuf>) -> IoResult<()>
|
||||
match msg {
|
||||
JSONRPCMessage::Request(r) => processor.process_request(r).await,
|
||||
JSONRPCMessage::Response(r) => processor.process_response(r).await,
|
||||
JSONRPCMessage::Notification(n) => processor.process_notification(n),
|
||||
JSONRPCMessage::Notification(n) => processor.process_notification(n).await,
|
||||
JSONRPCMessage::Error(e) => processor.process_error(e),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ use crate::outgoing_message::OutgoingMessageSender;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::config::Config as CodexConfig;
|
||||
use codex_core::protocol::Submission;
|
||||
use mcp_types::CallToolRequestParams;
|
||||
use mcp_types::CallToolResult;
|
||||
use mcp_types::ClientRequest;
|
||||
@@ -35,6 +36,7 @@ pub(crate) struct MessageProcessor {
|
||||
initialized: bool,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||
}
|
||||
|
||||
impl MessageProcessor {
|
||||
@@ -49,6 +51,7 @@ impl MessageProcessor {
|
||||
initialized: false,
|
||||
codex_linux_sandbox_exe,
|
||||
session_map: Arc::new(Mutex::new(HashMap::new())),
|
||||
running_requests_id_to_codex_uuid: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,7 +119,7 @@ impl MessageProcessor {
|
||||
}
|
||||
|
||||
/// Handle a fire-and-forget JSON-RPC notification.
|
||||
pub(crate) fn process_notification(&mut self, notification: JSONRPCNotification) {
|
||||
pub(crate) async fn process_notification(&mut self, notification: JSONRPCNotification) {
|
||||
let server_notification = match ServerNotification::try_from(notification) {
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
@@ -129,7 +132,7 @@ impl MessageProcessor {
|
||||
// handler so additional logic can be implemented incrementally.
|
||||
match server_notification {
|
||||
ServerNotification::CancelledNotification(params) => {
|
||||
self.handle_cancelled_notification(params);
|
||||
self.handle_cancelled_notification(params).await;
|
||||
}
|
||||
ServerNotification::ProgressNotification(params) => {
|
||||
self.handle_progress_notification(params);
|
||||
@@ -379,6 +382,7 @@ impl MessageProcessor {
|
||||
// Clone outgoing and session map to move into async task.
|
||||
let outgoing = self.outgoing.clone();
|
||||
let session_map = self.session_map.clone();
|
||||
let running_requests_id_to_codex_uuid = self.running_requests_id_to_codex_uuid.clone();
|
||||
|
||||
// Spawn an async task to handle the Codex session so that we do not
|
||||
// block the synchronous message-processing loop.
|
||||
@@ -390,6 +394,7 @@ impl MessageProcessor {
|
||||
config,
|
||||
outgoing,
|
||||
session_map,
|
||||
running_requests_id_to_codex_uuid,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
@@ -464,13 +469,12 @@ impl MessageProcessor {
|
||||
|
||||
// Clone outgoing and session map to move into async task.
|
||||
let outgoing = self.outgoing.clone();
|
||||
let running_requests_id_to_codex_uuid = self.running_requests_id_to_codex_uuid.clone();
|
||||
|
||||
// Spawn an async task to handle the Codex session so that we do not
|
||||
// block the synchronous message-processing loop.
|
||||
task::spawn(async move {
|
||||
let codex = {
|
||||
let session_map = session_map_mutex.lock().await;
|
||||
let codex = match session_map.get(&session_id) {
|
||||
Some(codex) => codex,
|
||||
match session_map.get(&session_id).cloned() {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
tracing::warn!("Session not found for session_id: {session_id}");
|
||||
let result = CallToolResult {
|
||||
@@ -482,21 +486,32 @@ impl MessageProcessor {
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
// unwrap_or_default is fine here because we know the result is valid JSON
|
||||
outgoing
|
||||
.send_response(request_id, serde_json::to_value(result).unwrap_or_default())
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
crate::codex_tool_runner::run_codex_tool_session_reply(
|
||||
codex.clone(),
|
||||
outgoing,
|
||||
request_id,
|
||||
prompt.clone(),
|
||||
)
|
||||
.await;
|
||||
// Spawn the long-running reply handler.
|
||||
tokio::spawn({
|
||||
let codex = codex.clone();
|
||||
let outgoing = outgoing.clone();
|
||||
let prompt = prompt.clone();
|
||||
let running_requests_id_to_codex_uuid = running_requests_id_to_codex_uuid.clone();
|
||||
|
||||
async move {
|
||||
crate::codex_tool_runner::run_codex_tool_session_reply(
|
||||
codex,
|
||||
outgoing,
|
||||
request_id,
|
||||
prompt,
|
||||
running_requests_id_to_codex_uuid,
|
||||
session_id,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -518,11 +533,58 @@ impl MessageProcessor {
|
||||
// Notification handlers
|
||||
// ---------------------------------------------------------------------
|
||||
|
||||
fn handle_cancelled_notification(
|
||||
async fn handle_cancelled_notification(
|
||||
&self,
|
||||
params: <mcp_types::CancelledNotification as mcp_types::ModelContextProtocolNotification>::Params,
|
||||
) {
|
||||
tracing::info!("notifications/cancelled -> params: {:?}", params);
|
||||
let request_id = params.request_id;
|
||||
// Create a stable string form early for logging and submission id.
|
||||
let request_id_string = match &request_id {
|
||||
RequestId::String(s) => s.clone(),
|
||||
RequestId::Integer(i) => i.to_string(),
|
||||
};
|
||||
|
||||
// Obtain the session_id while holding the first lock, then release.
|
||||
let session_id = {
|
||||
let map_guard = self.running_requests_id_to_codex_uuid.lock().await;
|
||||
match map_guard.get(&request_id) {
|
||||
Some(id) => *id, // Uuid is Copy
|
||||
None => {
|
||||
tracing::warn!("Session not found for request_id: {}", request_id_string);
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
tracing::info!("session_id: {session_id}");
|
||||
|
||||
// Obtain the Codex Arc while holding the session_map lock, then release.
|
||||
let codex_arc = {
|
||||
let sessions_guard = self.session_map.lock().await;
|
||||
match sessions_guard.get(&session_id) {
|
||||
Some(codex) => Arc::clone(codex),
|
||||
None => {
|
||||
tracing::warn!("Session not found for session_id: {session_id}");
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Submit interrupt to Codex.
|
||||
let err = codex_arc
|
||||
.submit_with_id(Submission {
|
||||
id: request_id_string,
|
||||
op: codex_core::protocol::Op::Interrupt,
|
||||
})
|
||||
.await;
|
||||
if let Err(e) = err {
|
||||
tracing::error!("Failed to submit interrupt to Codex: {e}");
|
||||
return;
|
||||
}
|
||||
// unregister the id so we don't keep it in the map
|
||||
self.running_requests_id_to_codex_uuid
|
||||
.lock()
|
||||
.await
|
||||
.remove(&request_id);
|
||||
}
|
||||
|
||||
fn handle_progress_notification(
|
||||
|
||||
@@ -27,6 +27,7 @@ pub struct PatchApprovalElicitRequestParams {
|
||||
pub codex_elicitation: String,
|
||||
pub codex_mcp_tool_call_id: String,
|
||||
pub codex_event_id: String,
|
||||
pub codex_call_id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub codex_reason: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
@@ -41,6 +42,7 @@ pub struct PatchApprovalResponse {
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn handle_patch_approval_request(
|
||||
call_id: String,
|
||||
reason: Option<String>,
|
||||
grant_root: Option<PathBuf>,
|
||||
changes: HashMap<PathBuf, FileChange>,
|
||||
@@ -66,6 +68,7 @@ pub(crate) async fn handle_patch_approval_request(
|
||||
codex_elicitation: "patch-approval".to_string(),
|
||||
codex_mcp_tool_call_id: tool_call_id.clone(),
|
||||
codex_event_id: event_id.clone(),
|
||||
codex_call_id: call_id,
|
||||
codex_reason: reason,
|
||||
codex_grant_root: grant_root,
|
||||
codex_changes: changes,
|
||||
|
||||
@@ -171,6 +171,7 @@ fn create_expected_elicitation_request(
|
||||
codex_event_id,
|
||||
codex_command: command,
|
||||
codex_cwd: workdir.to_path_buf(),
|
||||
codex_call_id: "call1234".to_string(),
|
||||
})?),
|
||||
})
|
||||
}
|
||||
@@ -384,6 +385,7 @@ fn create_expected_patch_approval_elicitation_request(
|
||||
codex_reason: reason,
|
||||
codex_grant_root: grant_root,
|
||||
codex_changes: changes,
|
||||
codex_call_id: "call1234".to_string(),
|
||||
})?),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ use tokio::process::ChildStdout;
|
||||
use anyhow::Context;
|
||||
use assert_cmd::prelude::*;
|
||||
use codex_mcp_server::CodexToolCallParam;
|
||||
use codex_mcp_server::CodexToolCallReplyParam;
|
||||
use mcp_types::CallToolRequestParams;
|
||||
use mcp_types::ClientCapabilities;
|
||||
use mcp_types::Implementation;
|
||||
@@ -154,6 +155,25 @@ impl McpProcess {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn send_codex_reply_tool_call(
|
||||
&mut self,
|
||||
session_id: &str,
|
||||
prompt: &str,
|
||||
) -> anyhow::Result<i64> {
|
||||
let codex_tool_call_params = CallToolRequestParams {
|
||||
name: "codex-reply".to_string(),
|
||||
arguments: Some(serde_json::to_value(CodexToolCallReplyParam {
|
||||
prompt: prompt.to_string(),
|
||||
session_id: session_id.to_string(),
|
||||
})?),
|
||||
};
|
||||
self.send_request(
|
||||
mcp_types::CallToolRequest::METHOD,
|
||||
Some(serde_json::to_value(codex_tool_call_params)?),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn send_request(
|
||||
&mut self,
|
||||
method: &str,
|
||||
@@ -171,6 +191,8 @@ impl McpProcess {
|
||||
Ok(request_id)
|
||||
}
|
||||
|
||||
// allow dead code
|
||||
#[allow(dead_code)]
|
||||
pub async fn send_response(
|
||||
&mut self,
|
||||
id: RequestId,
|
||||
@@ -198,7 +220,8 @@ impl McpProcess {
|
||||
let message = serde_json::from_str::<JSONRPCMessage>(&line)?;
|
||||
Ok(message)
|
||||
}
|
||||
|
||||
// allow dead code
|
||||
#[allow(dead_code)]
|
||||
pub async fn read_stream_until_request_message(&mut self) -> anyhow::Result<JSONRPCRequest> {
|
||||
loop {
|
||||
let message = self.read_jsonrpc_message().await?;
|
||||
@@ -221,6 +244,8 @@ impl McpProcess {
|
||||
}
|
||||
}
|
||||
|
||||
// allow dead code
|
||||
#[allow(dead_code)]
|
||||
pub async fn read_stream_until_response_message(
|
||||
&mut self,
|
||||
request_id: RequestId,
|
||||
@@ -247,4 +272,58 @@ impl McpProcess {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_stream_until_configured_response_message(
|
||||
&mut self,
|
||||
) -> anyhow::Result<String> {
|
||||
loop {
|
||||
let message = self.read_jsonrpc_message().await?;
|
||||
eprint!("message: {message:?}");
|
||||
|
||||
match message {
|
||||
JSONRPCMessage::Notification(notification) => {
|
||||
if notification.method == "codex/event" {
|
||||
if let Some(params) = notification.params {
|
||||
if let Some(msg) = params.get("msg") {
|
||||
if let Some(msg_type) = msg.get("type") {
|
||||
if msg_type == "session_configured" {
|
||||
if let Some(session_id) = msg.get("session_id") {
|
||||
return Ok(session_id
|
||||
.to_string()
|
||||
.trim_matches('"')
|
||||
.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
JSONRPCMessage::Request(_) => {
|
||||
anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}");
|
||||
}
|
||||
JSONRPCMessage::Error(_) => {
|
||||
anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}");
|
||||
}
|
||||
JSONRPCMessage::Response(_) => {
|
||||
anyhow::bail!("unexpected JSONRPCMessage::Response: {message:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// allow dead code
|
||||
#[allow(dead_code)]
|
||||
pub async fn send_notification(
|
||||
&mut self,
|
||||
method: &str,
|
||||
params: Option<serde_json::Value>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.send_jsonrpc_message(JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
method: method.to_string(),
|
||||
params,
|
||||
}))
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ mod responses;
|
||||
|
||||
pub use mcp_process::McpProcess;
|
||||
pub use mock_model_server::create_mock_chat_completions_server;
|
||||
#[allow(unused_imports)]
|
||||
pub use responses::create_apply_patch_sse_response;
|
||||
#[allow(unused_imports)]
|
||||
pub use responses::create_final_assistant_message_sse_response;
|
||||
pub use responses::create_shell_sse_response;
|
||||
|
||||
@@ -39,6 +39,8 @@ pub fn create_shell_sse_response(
|
||||
Ok(sse)
|
||||
}
|
||||
|
||||
// allow dead code
|
||||
#[allow(dead_code)]
|
||||
pub fn create_final_assistant_message_sse_response(message: &str) -> anyhow::Result<String> {
|
||||
let assistant_message = json!({
|
||||
"choices": [
|
||||
@@ -58,6 +60,8 @@ pub fn create_final_assistant_message_sse_response(message: &str) -> anyhow::Res
|
||||
Ok(sse)
|
||||
}
|
||||
|
||||
// allow dead code
|
||||
#[allow(dead_code)]
|
||||
pub fn create_apply_patch_sse_response(
|
||||
patch_content: &str,
|
||||
call_id: &str,
|
||||
|
||||
176
codex-rs/mcp-server/tests/interrupt.rs
Normal file
176
codex-rs/mcp-server/tests/interrupt.rs
Normal file
@@ -0,0 +1,176 @@
|
||||
#![cfg(unix)]
|
||||
mod common;
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use codex_mcp_server::CodexToolCallParam;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::RequestId;
|
||||
use serde_json::json;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use crate::common::McpProcess;
|
||||
use crate::common::create_mock_chat_completions_server;
|
||||
use crate::common::create_shell_sse_response;
|
||||
|
||||
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_shell_command_interruption() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if let Err(err) = shell_command_interruption().await {
|
||||
panic!("failure: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn shell_command_interruption() -> anyhow::Result<()> {
|
||||
// Use a cross-platform blocking command. On Windows plain `sleep` is not guaranteed to exist
|
||||
// (MSYS/GNU coreutils may be absent) and the failure causes the tool call to finish immediately,
|
||||
// which triggers a second model request before the test sends the explicit follow-up. That
|
||||
// prematurely consumes the second mocked SSE response and leads to a third POST (panic: no response for 2).
|
||||
// Powershell Start-Sleep is always available on Windows runners. On Unix we keep using `sleep`.
|
||||
#[cfg(target_os = "windows")]
|
||||
let shell_command = vec![
|
||||
"powershell".to_string(),
|
||||
"-Command".to_string(),
|
||||
"Start-Sleep -Seconds 60".to_string(),
|
||||
];
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
let shell_command = vec!["sleep".to_string(), "60".to_string()];
|
||||
let workdir_for_shell_function_call = TempDir::new()?;
|
||||
|
||||
// Create mock server with a single SSE response: the long sleep command
|
||||
let server = create_mock_chat_completions_server(vec![
|
||||
create_shell_sse_response(
|
||||
shell_command.clone(),
|
||||
Some(workdir_for_shell_function_call.path()),
|
||||
Some(60_000), // 60 seconds timeout in ms
|
||||
"call_sleep",
|
||||
)?,
|
||||
create_shell_sse_response(
|
||||
shell_command.clone(),
|
||||
Some(workdir_for_shell_function_call.path()),
|
||||
Some(60_000), // 60 seconds timeout in ms
|
||||
"call_sleep",
|
||||
)?,
|
||||
])
|
||||
.await;
|
||||
|
||||
// Create Codex configuration
|
||||
let codex_home = TempDir::new()?;
|
||||
create_config_toml(codex_home.path(), server.uri())?;
|
||||
let mut mcp_process = McpProcess::new(codex_home.path()).await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??;
|
||||
|
||||
// Send codex tool call that triggers "sleep 60"
|
||||
let codex_request_id = mcp_process
|
||||
.send_codex_tool_call(CodexToolCallParam {
|
||||
cwd: None,
|
||||
prompt: "First Run: run `sleep 60`".to_string(),
|
||||
model: None,
|
||||
profile: None,
|
||||
approval_policy: None,
|
||||
sandbox: None,
|
||||
config: None,
|
||||
base_instructions: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let session_id = mcp_process
|
||||
.read_stream_until_configured_response_message()
|
||||
.await?;
|
||||
|
||||
// Give the command a moment to start
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
|
||||
// Send interrupt notification
|
||||
mcp_process
|
||||
.send_notification(
|
||||
"notifications/cancelled",
|
||||
Some(json!({ "requestId": codex_request_id })),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Expect Codex to return an error or interruption response
|
||||
let codex_response: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)),
|
||||
)
|
||||
.await??;
|
||||
|
||||
assert!(
|
||||
codex_response
|
||||
.result
|
||||
.as_object()
|
||||
.map(|o| o.contains_key("error"))
|
||||
.unwrap_or(false),
|
||||
"Expected an interruption or error result, got: {codex_response:?}"
|
||||
);
|
||||
|
||||
let codex_reply_request_id = mcp_process
|
||||
.send_codex_reply_tool_call(&session_id, "Second Run: run `sleep 60`")
|
||||
.await?;
|
||||
|
||||
// Give the command a moment to start
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
|
||||
// Send interrupt notification
|
||||
mcp_process
|
||||
.send_notification(
|
||||
"notifications/cancelled",
|
||||
Some(json!({ "requestId": codex_reply_request_id })),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Expect Codex to return an error or interruption response
|
||||
let codex_response: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_response_message(RequestId::Integer(codex_reply_request_id)),
|
||||
)
|
||||
.await??;
|
||||
|
||||
assert!(
|
||||
codex_response
|
||||
.result
|
||||
.as_object()
|
||||
.map(|o| o.contains_key("error"))
|
||||
.unwrap_or(false),
|
||||
"Expected an interruption or error result, got: {codex_response:?}"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn create_config_toml(codex_home: &Path, server_uri: String) -> std::io::Result<()> {
|
||||
let config_toml = codex_home.join("config.toml");
|
||||
std::fs::write(
|
||||
config_toml,
|
||||
format!(
|
||||
r#"
|
||||
model = "mock-model"
|
||||
approval_policy = "never"
|
||||
sandbox_mode = "danger-full-access"
|
||||
|
||||
model_provider = "mock_provider"
|
||||
|
||||
[model_providers.mock_provider]
|
||||
name = "Mock provider for test"
|
||||
base_url = "{server_uri}/v1"
|
||||
wire_api = "chat"
|
||||
request_max_retries = 0
|
||||
stream_max_retries = 0
|
||||
"#
|
||||
),
|
||||
)
|
||||
}
|
||||
@@ -314,6 +314,7 @@ impl ChatWidget<'_> {
|
||||
self.bottom_pane.set_task_running(false);
|
||||
}
|
||||
EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
|
||||
call_id: _,
|
||||
command,
|
||||
cwd,
|
||||
reason,
|
||||
@@ -327,6 +328,7 @@ impl ChatWidget<'_> {
|
||||
self.bottom_pane.push_approval_request(request);
|
||||
}
|
||||
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
|
||||
call_id: _,
|
||||
changes,
|
||||
reason,
|
||||
grant_root,
|
||||
@@ -362,7 +364,7 @@ impl ChatWidget<'_> {
|
||||
cwd: _,
|
||||
}) => {
|
||||
self.conversation_history
|
||||
.add_active_exec_command(call_id, command);
|
||||
.reset_or_add_active_exec_command(call_id, command);
|
||||
self.request_redraw();
|
||||
}
|
||||
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
||||
|
||||
@@ -235,6 +235,30 @@ impl ConversationHistoryWidget {
|
||||
self.add_to_history(HistoryCell::new_active_exec_command(call_id, command));
|
||||
}
|
||||
|
||||
/// If an ActiveExecCommand with the same call_id already exists, replace
|
||||
/// it with a fresh one (resetting start time and view). Otherwise, add a new entry.
|
||||
pub fn reset_or_add_active_exec_command(&mut self, call_id: String, command: Vec<String>) {
|
||||
// Find the most recent matching ActiveExecCommand.
|
||||
let maybe_idx = self.entries.iter().rposition(|entry| {
|
||||
if let HistoryCell::ActiveExecCommand { call_id: id, .. } = &entry.cell {
|
||||
id == &call_id
|
||||
} else {
|
||||
false
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(idx) = maybe_idx {
|
||||
let width = self.cached_width.get();
|
||||
self.entries[idx].cell = HistoryCell::new_active_exec_command(call_id.clone(), command);
|
||||
if width > 0 {
|
||||
let height = self.entries[idx].cell.height(width);
|
||||
self.entries[idx].line_count.set(height);
|
||||
}
|
||||
} else {
|
||||
self.add_active_exec_command(call_id, command);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_active_mcp_tool_call(
|
||||
&mut self,
|
||||
call_id: String,
|
||||
|
||||
Reference in New Issue
Block a user