From c2db9386aa7ac5b11c2a0d52e7fddb8f96a9a2dd Mon Sep 17 00:00:00 2001 From: Friel Date: Thu, 9 Apr 2026 19:00:41 +0000 Subject: [PATCH] Inherit forked agent prompt cache keys (cherry picked from commit 61807cc6cf4c870ce643394739f99b8076b69894) --- codex-rs/README.md | 3 +- codex-rs/cli/src/main.rs | 76 +++-- codex-rs/codex-api/src/endpoint/responses.rs | 5 +- codex-rs/codex-api/tests/clients.rs | 7 + codex-rs/core/src/agent/control.rs | 220 ++++++++++--- codex-rs/core/src/agent/control_tests.rs | 306 +++++++++++++++++- codex-rs/core/src/client.rs | 155 ++++++++- codex-rs/core/src/client_tests.rs | 73 ++++- codex-rs/core/src/codex_delegate.rs | 1 + codex-rs/core/src/compact.rs | 46 ++- codex-rs/core/src/compact_remote.rs | 26 +- codex-rs/core/src/compact_tests.rs | 96 ++++++ codex-rs/core/src/inherited_thread_state.rs | 64 ++++ codex-rs/core/src/lib.rs | 1 + codex-rs/core/src/session/mcp.rs | 2 + codex-rs/core/src/session/mod.rs | 14 + codex-rs/core/src/session/session.rs | 8 +- codex-rs/core/src/session/tests.rs | 135 ++++++++ .../core/src/session/tests/guardian_tests.rs | 1 + codex-rs/core/src/session/turn.rs | 19 +- codex-rs/core/src/state/mod.rs | 1 + codex-rs/core/src/state/service.rs | 7 + codex-rs/core/src/thread_manager.rs | 13 + codex-rs/core/tests/responses_headers.rs | 3 + codex-rs/core/tests/suite/client.rs | 2 + .../core/tests/suite/client_websockets.rs | 1 + codex-rs/exec/src/cli.rs | 20 +- codex-rs/exec/src/cli_tests.rs | 19 ++ codex-rs/exec/src/lib.rs | 177 +++++++--- codex-rs/exec/src/lib_tests.rs | 97 ++++++ codex-rs/exec/src/main.rs | 5 +- codex-rs/exec/src/main_tests.rs | 21 ++ codex-rs/exec/tests/suite/fork.rs | 173 ++++++++++ codex-rs/exec/tests/suite/mod.rs | 1 + 34 files changed, 1640 insertions(+), 158 deletions(-) create mode 100644 codex-rs/core/src/inherited_thread_state.rs create mode 100644 codex-rs/exec/tests/suite/fork.rs diff --git a/codex-rs/README.md b/codex-rs/README.md index 2cc3a6b8f1..d348ae0701 100644 --- a/codex-rs/README.md +++ b/codex-rs/README.md @@ -50,7 +50,8 @@ The legacy `notify` setting is deprecated and will be removed in a future releas ### `codex exec` to run Codex programmatically/non-interactively -To run Codex non-interactively, run `codex exec PROMPT` (you can also pass the prompt via `stdin`) and Codex will work on your task until it decides that it is done and exits. If you provide both a prompt argument and piped stdin, Codex appends stdin as a `` block after the prompt so patterns like `echo "my output" | codex exec "Summarize this concisely"` work naturally. Output is printed to the terminal directly. You can set the `RUST_LOG` environment variable to see more about what's going on. +To run Codex non-interactively, run `codex exec PROMPT` (you can also pass the prompt via `stdin`) and Codex will work on your task until it decides that it is done and exits. Output is printed to the terminal directly. You can set the `RUST_LOG` environment variable to see more about what's going on. +Use `codex exec --fork PROMPT` to fork an existing session without launching the interactive picker/UI. Use `codex exec --ephemeral ...` to run without persisting session rollout files to disk. ### Experimenting with the Codex Sandbox diff --git a/codex-rs/cli/src/main.rs b/codex-rs/cli/src/main.rs index 7b6e7448d4..af19d1081c 100644 --- a/codex-rs/cli/src/main.rs +++ b/codex-rs/cli/src/main.rs @@ -764,12 +764,16 @@ async fn cli_main(arg0_paths: Arg0DispatchPaths) -> anyhow::Result<()> { .await?; handle_app_exit(exit_info)?; } - Some(Subcommand::Exec(mut exec_cli)) => { + Some(Subcommand::Exec(exec_cli)) => { reject_remote_mode_for_subcommand( root_remote.as_deref(), root_remote_auth_token_env.as_deref(), "exec", )?; + let mut exec_cli = match exec_cli.validate() { + Ok(exec_cli) => exec_cli, + Err(err) => err.exit(), + }; exec_cli .shared .inherit_exec_root_options(&interactive.shared); @@ -1790,6 +1794,40 @@ mod tests { assert_eq!(args.session_id.as_deref(), Some("session-123")); assert_eq!(args.prompt.as_deref(), Some("re-review")); } + #[test] + fn exec_fork_accepts_prompt_positional() { + let cli = MultitoolCli::try_parse_from([ + "codex", + "exec", + "--json", + "--fork", + "session-123", + "2+2", + ]) + .expect("parse should succeed"); + + let Some(Subcommand::Exec(exec)) = cli.subcommand else { + panic!("expected exec subcommand"); + }; + + assert_eq!(exec.fork_session_id.as_deref(), Some("session-123")); + assert!(exec.command.is_none()); + assert_eq!(exec.prompt.as_deref(), Some("2+2")); + } + + #[test] + fn exec_fork_conflicts_with_resume_subcommand() { + let cli = + MultitoolCli::try_parse_from(["codex", "exec", "--fork", "session-123", "resume"]) + .expect("parse should succeed"); + + let Some(Subcommand::Exec(exec)) = cli.subcommand else { + panic!("expected exec subcommand"); + }; + + let validate_result = exec.validate(); + assert!(validate_result.is_err()); + } #[test] fn dangerous_bypass_conflicts_with_approval_policy() { @@ -2028,17 +2066,14 @@ mod tests { update_action: None, exit_reason: ExitReason::UserRequested, }; - let lines = format_exit_messages(exit_info, /*color_enabled*/ false); + let lines = format_exit_messages(exit_info, false); assert!(lines.is_empty()); } #[test] fn format_exit_messages_includes_resume_hint_without_color() { - let exit_info = sample_exit_info( - Some("123e4567-e89b-12d3-a456-426614174000"), - /*thread_name*/ None, - ); - let lines = format_exit_messages(exit_info, /*color_enabled*/ false); + let exit_info = sample_exit_info(Some("123e4567-e89b-12d3-a456-426614174000"), None); + let lines = format_exit_messages(exit_info, false); assert_eq!( lines, vec![ @@ -2051,11 +2086,8 @@ mod tests { #[test] fn format_exit_messages_applies_color_when_enabled() { - let exit_info = sample_exit_info( - Some("123e4567-e89b-12d3-a456-426614174000"), - /*thread_name*/ None, - ); - let lines = format_exit_messages(exit_info, /*color_enabled*/ true); + let exit_info = sample_exit_info(Some("123e4567-e89b-12d3-a456-426614174000"), None); + let lines = format_exit_messages(exit_info, true); assert_eq!(lines.len(), 2); assert!(lines[1].contains("\u{1b}[36m")); } @@ -2066,7 +2098,7 @@ mod tests { Some("123e4567-e89b-12d3-a456-426614174000"), Some("my-thread"), ); - let lines = format_exit_messages(exit_info, /*color_enabled*/ false); + let lines = format_exit_messages(exit_info, false); assert_eq!( lines, vec![ @@ -2291,12 +2323,8 @@ mod tests { #[test] fn reject_remote_mode_for_non_interactive_subcommands() { - let err = reject_remote_mode_for_subcommand( - Some("127.0.0.1:4500"), - /*remote_auth_token_env*/ None, - "exec", - ) - .expect_err("non-interactive subcommands should reject --remote"); + let err = reject_remote_mode_for_subcommand(Some("127.0.0.1:4500"), None, "exec") + .expect_err("non-interactive subcommands should reject --remote"); assert!( err.to_string() .contains("only supported for interactive TUI commands") @@ -2305,12 +2333,8 @@ mod tests { #[test] fn reject_remote_auth_token_env_for_non_interactive_subcommands() { - let err = reject_remote_mode_for_subcommand( - /*remote*/ None, - Some("CODEX_REMOTE_AUTH_TOKEN"), - "exec", - ) - .expect_err("non-interactive subcommands should reject --remote-auth-token-env"); + let err = reject_remote_mode_for_subcommand(None, Some("CODEX_REMOTE_AUTH_TOKEN"), "exec") + .expect_err("non-interactive subcommands should reject --remote-auth-token-env"); assert!( err.to_string() .contains("only supported for interactive TUI commands") @@ -2324,7 +2348,7 @@ mod tests { out_dir: PathBuf::from("/tmp/out"), }); let err = reject_remote_mode_for_app_server_subcommand( - /*remote*/ None, + None, Some("CODEX_REMOTE_AUTH_TOKEN"), Some(&subcommand), ) diff --git a/codex-rs/codex-api/src/endpoint/responses.rs b/codex-rs/codex-api/src/endpoint/responses.rs index 17b478d1fd..2763b0e2e5 100644 --- a/codex-rs/codex-api/src/endpoint/responses.rs +++ b/codex-rs/codex-api/src/endpoint/responses.rs @@ -31,6 +31,7 @@ pub struct ResponsesClient { #[derive(Default)] pub struct ResponsesOptions { pub conversation_id: Option, + pub prompt_cache_key: Option, pub session_source: Option, pub extra_headers: HeaderMap, pub compression: Compression, @@ -73,6 +74,7 @@ impl ResponsesClient { ) -> Result { let ResponsesOptions { conversation_id, + prompt_cache_key, session_source, extra_headers, compression, @@ -89,7 +91,8 @@ impl ResponsesClient { if let Some(ref conv_id) = conversation_id { insert_header(&mut headers, "x-client-request-id", conv_id); } - headers.extend(build_conversation_headers(conversation_id)); + let session_id = prompt_cache_key.or_else(|| conversation_id.clone()); + headers.extend(build_conversation_headers(session_id)); if let Some(subagent) = subagent_header(&session_source) { insert_header(&mut headers, "x-openai-subagent", &subagent); } diff --git a/codex-rs/codex-api/tests/clients.rs b/codex-rs/codex-api/tests/clients.rs index 218a99f9b2..a73797e54f 100644 --- a/codex-rs/codex-api/tests/clients.rs +++ b/codex-rs/codex-api/tests/clients.rs @@ -445,6 +445,7 @@ async fn azure_default_store_attaches_ids_and_headers() -> Result<()> { request, ResponsesOptions { conversation_id: Some("sess_123".into()), + prompt_cache_key: Some("prompt_cache_123".into()), session_source: Some(SessionSource::SubAgent(SubAgentSource::Review)), extra_headers, compression: Compression::None, @@ -459,6 +460,12 @@ async fn azure_default_store_attaches_ids_and_headers() -> Result<()> { assert_eq!( req.headers.get("session_id").and_then(|v| v.to_str().ok()), + Some("prompt_cache_123") + ); + assert_eq!( + req.headers + .get("x-client-request-id") + .and_then(|v| v.to_str().ok()), Some("sess_123") ); assert_eq!( diff --git a/codex-rs/core/src/agent/control.rs b/codex-rs/core/src/agent/control.rs index 705d2d168f..e6299f225b 100644 --- a/codex-rs/core/src/agent/control.rs +++ b/codex-rs/core/src/agent/control.rs @@ -5,10 +5,15 @@ use crate::agent::role::DEFAULT_ROLE_NAME; use crate::agent::role::resolve_role_config; use crate::agent::status::is_final; use crate::codex_thread::ThreadConfigSnapshot; +use crate::find_archived_thread_path_by_id_str; +use crate::find_thread_path_by_id_str; +use crate::inherited_thread_state::InheritedThreadState; +use crate::rollout::RolloutRecorder; use crate::session::emit_subagent_session_started; use crate::session_prefix::format_subagent_context_line; use crate::session_prefix::format_subagent_notification_message; use crate::shell_snapshot::ShellSnapshot; +use crate::state::McpToolSnapshot; use crate::thread_manager::ResumeThreadWithHistoryOptions; use crate::thread_manager::ThreadManagerState; use crate::thread_rollout_truncation::truncate_rollout_to_last_n_fork_turns; @@ -42,6 +47,8 @@ use tracing::warn; const AGENT_NAMES: &str = include_str!("agent_names.txt"); const ROOT_LAST_TASK_MESSAGE: &str = "Main thread"; +const CODEX_EXPERIMENTAL_FORK_PREVIOUS_RESPONSE_ID_ENV: &str = + "CODEX_EXPERIMENTAL_FORK_PREVIOUS_RESPONSE_ID"; #[derive(Clone, Debug, PartialEq, Eq)] pub(crate) enum SpawnAgentForkMode { @@ -219,6 +226,18 @@ impl AgentControl { // The same `AgentControl` is sent to spawn the thread. let new_thread = match (session_source, options.fork_mode.as_ref()) { (Some(session_source), Some(_)) => { + let inherited_thread_state = InheritedThreadState::builder() + .prompt_cache_key( + parent_prompt_cache_key_for_source(&state, Some(&session_source)).await, + ) + .response_continuation( + parent_response_continuation_for_source(&state, Some(&session_source)) + .await, + ) + .mcp_tool_snapshot( + parent_mcp_tool_snapshot_for_source(&state, Some(&session_source)).await, + ) + .build(); self.spawn_forked_thread( &state, config, @@ -226,6 +245,7 @@ impl AgentControl { &options, inherited_shell_snapshot, inherited_exec_policy, + inherited_thread_state, ) .await? } @@ -240,6 +260,7 @@ impl AgentControl { inherited_shell_snapshot, inherited_exec_policy, options.environments.clone(), + Default::default(), ) .await? } @@ -325,6 +346,7 @@ impl AgentControl { }) } + #[allow(clippy::too_many_arguments)] async fn spawn_forked_thread( &self, state: &Arc, @@ -333,6 +355,7 @@ impl AgentControl { options: &SpawnAgentOptions, inherited_shell_snapshot: Option>, inherited_exec_policy: Option>, + inherited_thread_state: InheritedThreadState, ) -> CodexResult { if options.fork_parent_spawn_call_id.is_none() { return Err(CodexErr::Fatal( @@ -380,45 +403,60 @@ impl AgentControl { )) })?; - let mut forked_rollout_items = parent_history.items; - if let SpawnAgentForkMode::LastNTurns(last_n_turns) = fork_mode { - forked_rollout_items = - truncate_rollout_to_last_n_fork_turns(&forked_rollout_items, *last_n_turns); - } - // MultiAgentV2 root/subagent usage hints are injected as standalone developer - // messages at thread start. When forking history, drop hints from the parent - // so the child gets a fresh hint that matches its own session source/config. - let multi_agent_v2_usage_hint_texts_to_filter: Vec = - if let Some(parent_thread) = parent_thread.as_ref() { - parent_thread - .codex - .session - .configured_multi_agent_v2_usage_hint_texts() - .await - } else if config.features.enabled(Feature::MultiAgentV2) { - [ - config.multi_agent_v2.root_agent_usage_hint_text.clone(), - config.multi_agent_v2.subagent_usage_hint_text.clone(), - ] - .into_iter() - .flatten() - .collect() - } else { - Vec::new() - }; - forked_rollout_items.retain(|item| { - if let RolloutItem::ResponseItem(ResponseItem::Message { role, content, .. }) = item - && role == "developer" - && let [ContentItem::InputText { text }] = content.as_slice() - && multi_agent_v2_usage_hint_texts_to_filter - .iter() - .any(|usage_hint_text| usage_hint_text == text) - { - return false; +let response_continuation = inherited_thread_state.response_continuation(); + let use_response_continuation_baseline = + response_continuation.is_some() && matches!(fork_mode, SpawnAgentForkMode::FullHistory); + let mut forked_rollout_items = if let (Some(response_continuation), true) = + (&response_continuation, use_response_continuation_baseline) + { + previous_response_fork_rollout_items( + parent_history.items, + response_continuation.fork_baseline_input(), + ) + } else { + let mut items = parent_history.items; + if let SpawnAgentForkMode::LastNTurns(last_n_turns) = fork_mode { + items = truncate_rollout_to_last_n_fork_turns(&items, *last_n_turns); } + items + }; + if !use_response_continuation_baseline { + // MultiAgentV2 root/subagent usage hints are injected as standalone developer + // messages at thread start. When forking history, drop hints from the parent + // so the child gets a fresh hint that matches its own session source/config. + let multi_agent_v2_usage_hint_texts_to_filter: Vec = + if let Some(parent_thread) = parent_thread.as_ref() { + parent_thread + .codex + .session + .configured_multi_agent_v2_usage_hint_texts() + .await + } else if config.features.enabled(Feature::MultiAgentV2) { + [ + config.multi_agent_v2.root_agent_usage_hint_text.clone(), + config.multi_agent_v2.subagent_usage_hint_text.clone(), + ] + .into_iter() + .flatten() + .collect() + } else { + Vec::new() + }; + forked_rollout_items.retain(|item| { + if let RolloutItem::ResponseItem(ResponseItem::Message { role, content, .. }) = + item + && role == "developer" + && let [ContentItem::InputText { text }] = content.as_slice() + && multi_agent_v2_usage_hint_texts_to_filter + .iter() + .any(|usage_hint_text| usage_hint_text == text) + { + return false; + } - keep_forked_rollout_item(item) - }); + keep_forked_rollout_item(item) + }); + } state .fork_thread_with_source( @@ -430,6 +468,7 @@ impl AgentControl { inherited_shell_snapshot, inherited_exec_policy, options.environments.clone(), + inherited_thread_state, ) .await } @@ -586,6 +625,7 @@ impl AgentControl { session_source, inherited_shell_snapshot, inherited_exec_policy, + inherited_thread_state: Default::default(), }) .await?; let mut agent_metadata = agent_metadata; @@ -1190,6 +1230,112 @@ impl AgentControl { } } +async fn parent_prompt_cache_key_for_source( + state: &Arc, + session_source: Option<&SessionSource>, +) -> Option { + let Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id, .. + })) = session_source + else { + return None; + }; + + state + .get_thread(*parent_thread_id) + .await + .ok() + .map(|parent_thread| parent_thread.codex.session.prompt_cache_key()) +} + +fn previous_response_fork_rollout_items( + source_items: Vec, + baseline_input: Vec, +) -> Vec { + let source_session_meta = source_items.iter().find_map(|item| match item { + RolloutItem::SessionMeta(meta) => Some(meta.clone()), + RolloutItem::ResponseItem(_) + | RolloutItem::Compacted(_) + | RolloutItem::TurnContext(_) + | RolloutItem::EventMsg(_) => None, + }); + let latest_turn_context = source_items.iter().rev().find_map(|item| match item { + RolloutItem::TurnContext(turn_context) => Some(turn_context.clone()), + RolloutItem::ResponseItem(_) + | RolloutItem::Compacted(_) + | RolloutItem::SessionMeta(_) + | RolloutItem::EventMsg(_) => None, + }); + + source_session_meta + .into_iter() + .map(RolloutItem::SessionMeta) + .chain(baseline_input.into_iter().map(RolloutItem::ResponseItem)) + .chain( + latest_turn_context + .into_iter() + .map(RolloutItem::TurnContext), + ) + .collect() +} + +async fn parent_mcp_tool_snapshot_for_source( + state: &Arc, + session_source: Option<&SessionSource>, +) -> Option { + let Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id, .. + })) = session_source + else { + return None; + }; + + let parent_thread = state.get_thread(*parent_thread_id).await.ok()?; + let tools = parent_thread + .codex + .session + .services + .mcp_connection_manager + .read() + .await + .list_all_tools() + .await; + Some(McpToolSnapshot { tools }) +} + +fn fork_previous_response_id_enabled() -> bool { + std::env::var(CODEX_EXPERIMENTAL_FORK_PREVIOUS_RESPONSE_ID_ENV) + .is_ok_and(|value| fork_previous_response_id_value_enabled(&value)) +} + +fn fork_previous_response_id_value_enabled(value: &str) -> bool { + matches!( + value.to_ascii_lowercase().as_str(), + "1" | "true" | "yes" | "on" + ) +} + +async fn parent_response_continuation_for_source( + state: &Arc, + session_source: Option<&SessionSource>, +) -> Option { + if !fork_previous_response_id_enabled() { + return None; + } + let Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id, .. + })) = session_source + else { + return None; + }; + + state + .get_thread(*parent_thread_id) + .await + .ok() + .and_then(|parent_thread| parent_thread.codex.session.response_continuation_for_fork()) +} + fn thread_spawn_parent_thread_id(session_source: &SessionSource) -> Option { match session_source { SessionSource::SubAgent(SubAgentSource::ThreadSpawn { diff --git a/codex-rs/core/src/agent/control_tests.rs b/codex-rs/core/src/agent/control_tests.rs index 7ef2120d5c..922c5b7723 100644 --- a/codex-rs/core/src/agent/control_tests.rs +++ b/codex-rs/core/src/agent/control_tests.rs @@ -8,6 +8,8 @@ use crate::config::ConfigBuilder; use crate::context::ContextualUserFragment; use crate::context::SubagentNotification; use assert_matches::assert_matches; +use codex_config::types::McpServerConfig; +use codex_config::types::McpServerTransportConfig; use codex_features::Feature; use codex_login::CodexAuth; use codex_protocol::AgentPath; @@ -28,6 +30,12 @@ use codex_thread_store::ArchiveThreadParams; use codex_thread_store::LocalThreadStore; use codex_thread_store::LocalThreadStoreConfig; use codex_thread_store::ThreadStore; +use core_test_support::responses::ev_completed; +use core_test_support::responses::ev_response_created; +use core_test_support::responses::mount_sse_once; +use core_test_support::responses::namespace_child_tool; +use core_test_support::responses::sse; +use core_test_support::responses::start_mock_server; use pretty_assertions::assert_eq; use tempfile::TempDir; use tokio::time::Duration; @@ -71,6 +79,57 @@ fn assistant_message(text: &str, phase: Option) -> ResponseItem { } } +#[test] +fn fork_previous_response_id_env_value_parses_truthy_values() { + for value in ["1", "true", "TRUE", "yes", "on"] { + assert!( + fork_previous_response_id_value_enabled(value), + "{value} should enable previous response forking" + ); + } + + for value in ["", "0", "false", "off", "no", "enabled"] { + assert!( + !fork_previous_response_id_value_enabled(value), + "{value} should not enable previous response forking" + ); + } +} + +#[tokio::test] +async fn previous_response_fork_rollout_items_preserve_latest_turn_context() { + let harness = AgentControlHarness::new().await; + let (_thread_id, owner_thread) = harness.start_thread().await; + let owner_turn = owner_thread.codex.session.new_default_turn().await; + let mut first_turn_context = owner_turn.to_turn_context_item(); + first_turn_context.model = "first-model".to_string(); + let mut latest_turn_context = first_turn_context.clone(); + latest_turn_context.model = "latest-model".to_string(); + + let baseline_item = assistant_message( + "parent final from previous response", + Some(MessagePhase::FinalAnswer), + ); + let items = previous_response_fork_rollout_items( + vec![ + RolloutItem::TurnContext(first_turn_context), + RolloutItem::ResponseItem(assistant_message( + "parent rollout item should not be copied", + Some(MessagePhase::FinalAnswer), + )), + RolloutItem::TurnContext(latest_turn_context.clone()), + ], + vec![baseline_item.clone()], + ); + + assert_eq!(items.len(), 2); + assert_matches!(&items[0], RolloutItem::ResponseItem(item) if *item == baseline_item); + assert_matches!( + &items[1], + RolloutItem::TurnContext(turn_context) if turn_context.model == latest_turn_context.model + ); +} + fn spawn_agent_call(call_id: &str) -> ResponseItem { ResponseItem::FunctionCall { id: None, @@ -703,6 +762,37 @@ async fn spawn_agent_can_fork_parent_thread_history_with_sanitized_items() { .await .expect("child thread should be registered"); assert_ne!(child_thread_id, parent_thread_id); + assert_eq!( + child_thread.codex.session.prompt_cache_key(), + parent_thread.codex.session.prompt_cache_key(), + ); + assert!(!Arc::ptr_eq( + &child_thread.codex.session.services.mcp_connection_manager, + &parent_thread.codex.session.services.mcp_connection_manager, + )); + let mcp_tool_snapshot = child_thread + .codex + .session + .services + .mcp_tool_snapshot + .lock() + .await + .clone() + .expect("forked child should inherit an MCP tool snapshot"); + let parent_mcp_tools = parent_thread + .codex + .session + .services + .mcp_connection_manager + .read() + .await + .list_all_tools() + .await; + let mut snapshot_tool_names = mcp_tool_snapshot.tools.keys().cloned().collect::>(); + snapshot_tool_names.sort(); + let mut parent_tool_names = parent_mcp_tools.keys().cloned().collect::>(); + parent_tool_names.sort(); + assert_eq!(snapshot_tool_names, parent_tool_names); let history = child_thread.codex.session.clone_history().await; let expected_history = [ ResponseItem::Message { @@ -751,6 +841,203 @@ async fn spawn_agent_can_fork_parent_thread_history_with_sanitized_items() { .expect("parent shutdown should submit"); } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn forked_spawn_first_request_uses_parent_cache_key_and_mcp_snapshot() -> anyhow::Result<()> { + let server = start_mock_server().await; + let child_response_mock = mount_sse_once( + &server, + sse(vec![ev_response_created("resp-1"), ev_completed("resp-1")]), + ) + .await; + let (_home, mut config) = test_config().await; + config.model_provider.base_url = Some(format!("{}/v1", server.uri())); + config.model_provider.supports_websockets = false; + let mcp_server_path = config.codex_home.join("fake_mcp_server.py"); + std::fs::write( + &mcp_server_path, + r#"import json +import sys + +def read_message(): + line = sys.stdin.buffer.readline() + if not line: + return None + return json.loads(line) + +def write_message(message): + body = json.dumps(message).encode("utf-8") + sys.stdout.buffer.write(body) + sys.stdout.buffer.write(b"\n") + sys.stdout.buffer.flush() + +while True: + message = read_message() + if message is None: + break + method = message.get("method") + request_id = message.get("id") + if request_id is None: + continue + if method == "initialize": + write_message({ + "jsonrpc": "2.0", + "id": request_id, + "result": { + "protocolVersion": "2025-06-18", + "capabilities": {"tools": {"listChanged": False}}, + "serverInfo": {"name": "fake-mcp", "version": "1.0.0"}, + }, + }) + elif method == "tools/list": + write_message({ + "jsonrpc": "2.0", + "id": request_id, + "result": { + "tools": [{ + "name": "echo", + "description": "Echo from fake MCP", + "inputSchema": { + "type": "object", + "properties": {}, + "additionalProperties": False, + }, + }], + }, + }) + else: + write_message({ + "jsonrpc": "2.0", + "id": request_id, + "error": {"code": -32601, "message": "method not found"}, + }) +"#, + )?; + config + .mcp_servers + .set(std::collections::HashMap::from([( + "rmcp".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: "python3".to_string(), + args: vec![mcp_server_path.to_string_lossy().to_string()], + env: None, + env_vars: Vec::new(), + cwd: None, + }, + experimental_environment: None, + enabled: true, + required: false, + supports_parallel_tool_calls: false, + disabled_reason: None, + startup_timeout_sec: Some(Duration::from_secs(5)), + tool_timeout_sec: None, + default_tools_approval_mode: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + oauth_resource: None, + tools: std::collections::HashMap::new(), + }, + )])) + .expect("test config should allow MCP servers"); + + let manager = ThreadManager::with_models_provider_and_home_for_tests( + CodexAuth::from_api_key("dummy"), + config.model_provider.clone(), + config.codex_home.to_path_buf(), + std::sync::Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()), + ); + let control = manager.agent_control(); + let parent = manager.start_thread(config.clone()).await?; + let parent_thread_id = parent.thread_id; + let parent_prompt_cache_key = parent.thread.codex.session.prompt_cache_key(); + let parent_mcp_tools = parent + .thread + .codex + .session + .services + .mcp_connection_manager + .read() + .await + .list_all_tools() + .await; + let startup_failures = parent + .thread + .codex + .session + .services + .mcp_connection_manager + .read() + .await + .required_startup_failures(&["rmcp".to_string()]) + .await; + assert!( + parent_mcp_tools.contains_key("mcp__rmcp__echo"), + "parent MCP manager should expose live MCP tools before forking: tools={parent_mcp_tools:#?}; failures={startup_failures:#?}" + ); + parent + .thread + .inject_user_message_without_turn("parent seed".to_string()) + .await; + parent + .thread + .codex + .session + .ensure_rollout_materialized() + .await; + parent.thread.codex.session.flush_rollout().await?; + + let child_thread_id = control + .spawn_agent_with_metadata( + config, + text_input("child request boundary"), + Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id, + depth: 1, + agent_path: None, + agent_nickname: Some("worker".to_string()), + agent_role: None, + })), + SpawnAgentOptions { + fork_parent_spawn_call_id: Some("spawn-call-request-boundary".to_string()), + fork_mode: Some(SpawnAgentForkMode::FullHistory), + ..Default::default() + }, + ) + .await? + .thread_id; + let child_thread = manager + .get_thread(child_thread_id) + .await + .expect("child thread should be registered"); + + timeout(Duration::from_secs(5), async { + loop { + let event = child_thread + .next_event() + .await + .expect("child event channel should stay open"); + if matches!(event.msg, EventMsg::TurnComplete(_)) { + break; + } + } + }) + .await + .expect("child turn should complete"); + let body = child_response_mock.single_request().body_json(); + let expected_prompt_cache_key = parent_prompt_cache_key.to_string(); + assert_eq!( + body["prompt_cache_key"].as_str(), + Some(expected_prompt_cache_key.as_str()) + ); + assert!( + namespace_child_tool(&body, "mcp__rmcp__", "echo").is_some(), + "first forked child request should expose parent MCP snapshot tools: {body:#}" + ); + + Ok(()) +} + #[tokio::test] async fn spawn_agent_fork_flushes_parent_rollout_before_loading_history() { let harness = AgentControlHarness::new().await; @@ -1550,7 +1837,7 @@ async fn resume_thread_subagent_restores_stored_nickname_and_role() { manager, control, }; - let (parent_thread_id, _parent_thread) = harness.start_thread().await; + let (parent_thread_id, parent_thread) = harness.start_thread().await; let agent_path = AgentPath::from_string("/root/explorer".to_string()) .expect("test agent path should be valid"); @@ -1640,13 +1927,22 @@ async fn resume_thread_subagent_restores_stored_nickname_and_role() { .expect("resume should succeed"); assert_eq!(resumed_thread_id, child_thread_id); - let resumed_snapshot = harness + let resumed_thread = harness .manager .get_thread(resumed_thread_id) .await - .expect("resumed child thread should exist") - .config_snapshot() - .await; + .expect("resumed child thread should exist"); + assert_eq!( + resumed_thread.codex.session.prompt_cache_key(), + resumed_thread_id, + "resume should keep the resumed thread's own cache key" + ); + assert_ne!( + resumed_thread.codex.session.prompt_cache_key(), + parent_thread.codex.session.prompt_cache_key(), + "resume must not opportunistically inherit cache state from a live parent" + ); + let resumed_snapshot = resumed_thread.config_snapshot().await; let SessionSource::SubAgent(SubAgentSource::ThreadSpawn { parent_thread_id: resumed_parent_thread_id, depth: resumed_depth, diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index ba81b451a7..d362b3afe6 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -154,6 +154,7 @@ struct ModelClientState { conversation_id: ThreadId, window_generation: AtomicU64, installation_id: String, + prompt_cache_key_override: Option, provider: SharedModelProvider, auth_env_telemetry: AuthEnvTelemetry, session_source: SessionSource, @@ -163,6 +164,7 @@ struct ModelClientState { beta_features_header: Option, disable_websockets: AtomicBool, cached_websocket_session: StdMutex, + latest_response_continuation: StdMutex>, } /// Resolved API client setup for a single request attempt. @@ -237,15 +239,32 @@ struct LastResponse { items_added: Vec, } +#[derive(Debug, Clone)] +pub(crate) struct ResponseContinuation { + request: ResponsesApiRequest, + last_response: LastResponse, +} + #[derive(Debug, Default)] struct WebsocketSession { connection: Option, last_request: Option, last_response_rx: Option>, + last_response: Option, connection_reused: StdMutex, } impl WebsocketSession { + fn from_response_continuation(continuation: ResponseContinuation) -> Self { + Self { + connection: None, + last_request: Some(continuation.request), + last_response_rx: None, + last_response: Some(continuation.last_response), + connection_reused: StdMutex::new(false), + } + } + fn set_connection_reused(&self, connection_reused: bool) { *self .connection_reused @@ -261,6 +280,24 @@ impl WebsocketSession { } } +impl ResponseContinuation { + pub(crate) fn for_fork(mut self) -> Self { + self.request + .input + .retain(|item| !matches!(item, ResponseItem::Reasoning { .. })); + self + } + + pub(crate) fn fork_baseline_input(&self) -> Vec { + self.request + .input + .iter() + .chain(self.last_response.items_added.iter()) + .cloned() + .collect() + } +} + enum WebsocketStreamOutcome { Stream(ResponseStream), FallbackToHttp, @@ -298,12 +335,42 @@ impl ModelClient { auth_manager: Option>, conversation_id: ThreadId, installation_id: String, + prompt_cache_key_override: Option, provider_info: ModelProviderInfo, session_source: SessionSource, model_verbosity: Option, enable_request_compression: bool, include_timing_metrics: bool, beta_features_header: Option, + ) -> Self { + Self::new_with_response_continuation( + auth_manager, + conversation_id, + installation_id, + prompt_cache_key_override, + provider_info, + session_source, + model_verbosity, + enable_request_compression, + include_timing_metrics, + beta_features_header, + /*response_continuation*/ None, + ) + } + + #[allow(clippy::too_many_arguments)] + pub(crate) fn new_with_response_continuation( + auth_manager: Option>, + conversation_id: ThreadId, + installation_id: String, + prompt_cache_key_override: Option, + provider_info: ModelProviderInfo, + session_source: SessionSource, + model_verbosity: Option, + enable_request_compression: bool, + include_timing_metrics: bool, + beta_features_header: Option, + response_continuation: Option, ) -> Self { let model_provider = create_model_provider(provider_info, auth_manager); let codex_api_key_env_enabled = model_provider @@ -312,11 +379,16 @@ impl ModelClient { .is_some_and(|manager| manager.codex_api_key_env_enabled()); let auth_env_telemetry = collect_auth_env_telemetry(model_provider.info(), codex_api_key_env_enabled); + let cached_websocket_session = response_continuation + .clone() + .map(WebsocketSession::from_response_continuation) + .unwrap_or_default(); Self { state: Arc::new(ModelClientState { conversation_id, window_generation: AtomicU64::new(0), installation_id, + prompt_cache_key_override, provider: model_provider, auth_env_telemetry, session_source, @@ -325,7 +397,8 @@ impl ModelClient { include_timing_metrics, beta_features_header, disable_websockets: AtomicBool::new(false), - cached_websocket_session: StdMutex::new(WebsocketSession::default()), + cached_websocket_session: StdMutex::new(cached_websocket_session), + latest_response_continuation: StdMutex::new(response_continuation), }), } } @@ -364,6 +437,24 @@ impl ModelClient { format!("{conversation_id}:{window_generation}") } + pub(crate) fn prompt_cache_key(&self) -> ThreadId { + self.state + .prompt_cache_key_override + .unwrap_or(self.state.conversation_id) + } + + pub(crate) fn response_continuation_for_fork(&self) -> Option { + if !self.responses_websocket_enabled() { + return None; + } + self.state + .latest_response_continuation + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone() + .map(ResponseContinuation::for_fork) + } + fn take_cached_websocket_session(&self) -> WebsocketSession { let mut cached_websocket_session = self .state @@ -787,6 +878,7 @@ impl ModelClient { ) -> ApiHeaderMap { let turn_metadata_header = parse_turn_metadata_header(turn_metadata_header); let conversation_id = self.state.conversation_id.to_string(); + let prompt_cache_key = self.prompt_cache_key().to_string(); let mut headers = build_responses_headers( self.state.beta_features_header.as_deref(), turn_state, @@ -795,7 +887,10 @@ impl ModelClient { if let Ok(header_value) = HeaderValue::from_str(&conversation_id) { headers.insert("x-client-request-id", header_value); } - headers.extend(build_conversation_headers(Some(conversation_id))); + // The Responses websocket backend uses the `session_id` handshake header as the prompt + // cache id. Keep `x-client-request-id` on the conversation id so forked agents keep a + // unique request identity while sharing their parent's prompt cache id. + headers.extend(build_conversation_headers(Some(prompt_cache_key))); headers.extend(self.build_responses_identity_headers()); headers.insert( OPENAI_BETA_HEADER, @@ -824,6 +919,7 @@ impl ModelClientSession { self.websocket_session.connection = None; self.websocket_session.last_request = None; self.websocket_session.last_response_rx = None; + self.websocket_session.last_response = None; self.websocket_session .set_connection_reused(/*connection_reused*/ false); } @@ -877,7 +973,7 @@ impl ModelClientSession { &prompt.output_schema, prompt.output_schema_strict, ); - let prompt_cache_key = Some(self.client.state.conversation_id.to_string()); + let prompt_cache_key = Some(self.client.prompt_cache_key().to_string()); let request = ResponsesApiRequest { model: model_info.slug.clone(), instructions: instructions.clone(), @@ -916,8 +1012,10 @@ impl ModelClientSession { ) -> ApiResponsesOptions { let turn_metadata_header = parse_turn_metadata_header(turn_metadata_header); let conversation_id = self.client.state.conversation_id.to_string(); + let prompt_cache_key = self.client.prompt_cache_key().to_string(); ApiResponsesOptions { conversation_id: Some(conversation_id), + prompt_cache_key: Some(prompt_cache_key), session_source: Some(self.client.state.session_source.clone()), extra_headers: { let mut headers = build_responses_headers( @@ -973,13 +1071,16 @@ impl ModelClientSession { } fn get_last_response(&mut self) -> Option { - self.websocket_session - .last_response_rx - .take() - .and_then(|mut receiver| match receiver.try_recv() { - Ok(last_response) => Some(last_response), - Err(TryRecvError::Closed) | Err(TryRecvError::Empty) => None, - }) + if let Some(mut receiver) = self.websocket_session.last_response_rx.take() { + match receiver.try_recv() { + Ok(last_response) => { + self.websocket_session.last_response = Some(last_response.clone()); + return Some(last_response); + } + Err(TryRecvError::Closed) | Err(TryRecvError::Empty) => {} + } + } + self.websocket_session.last_response.clone() } fn prepare_websocket_request( @@ -1084,7 +1185,9 @@ impl ModelClientSession { }; if needs_new { - self.websocket_session.last_request = None; + if self.websocket_session.last_response.is_none() { + self.websocket_session.last_request = None; + } self.websocket_session.last_response_rx = None; let turn_state = options .turn_state @@ -1178,6 +1281,8 @@ impl ModelClientSession { stream, session_telemetry.clone(), InferenceTraceAttempt::disabled(), + /*client_state*/ None, + /*request*/ None, ); return Ok(stream); } @@ -1228,6 +1333,8 @@ impl ModelClientSession { stream, session_telemetry.clone(), inference_trace_attempt, + /*client_state*/ None, + /*request*/ None, ); return Ok(stream); } @@ -1367,6 +1474,7 @@ impl ModelClientSession { let ws_request = self.prepare_websocket_request(ws_payload, &request); self.websocket_session.last_request = Some(request); + self.websocket_session.last_response = None; let inference_trace_attempt = if warmup { // Prewarm sends `generate=false`; it is connection setup, not a // model inference attempt that should appear in rollout traces. @@ -1399,6 +1507,8 @@ impl ModelClientSession { stream_result, session_telemetry.clone(), inference_trace_attempt, + Some(Arc::clone(&self.client.state)), + self.websocket_session.last_request.clone(), ); self.websocket_session.last_response_rx = Some(last_request_rx); return Ok(WebsocketStreamOutcome::Stream(stream)); @@ -1679,6 +1789,8 @@ fn map_response_events( api_stream: S, session_telemetry: SessionTelemetry, inference_trace_attempt: InferenceTraceAttempt, + client_state: Option>, + request: Option, ) -> (ResponseStream, oneshot::Receiver) where S: futures::Stream> @@ -1749,11 +1861,24 @@ where &token_usage, &items_added, ); + let last_response = LastResponse { + response_id: response_id.clone(), + items_added: std::mem::take(&mut items_added), + }; + if let (Some(client_state), Some(request)) = (&client_state, &request) + && !last_response.response_id.is_empty() + { + *client_state + .latest_response_continuation + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) = + Some(ResponseContinuation { + request: request.clone(), + last_response: last_response.clone(), + }); + } if let Some(sender) = tx_last_response.take() { - let _ = sender.send(LastResponse { - response_id: response_id.clone(), - items_added: std::mem::take(&mut items_added), - }); + let _ = sender.send(last_response); } if tx_event .send(Ok(ResponseEvent::Completed { diff --git a/codex-rs/core/src/client_tests.rs b/codex-rs/core/src/client_tests.rs index e56500ba5f..ef6e812f8a 100644 --- a/codex-rs/core/src/client_tests.rs +++ b/codex-rs/core/src/client_tests.rs @@ -1,6 +1,9 @@ use super::AuthRequestTelemetryContext; +use super::LastResponse; use super::ModelClient; use super::PendingUnauthorizedRetry; +use super::ResponseContinuation; +use super::ResponsesApiRequest; use super::UnauthorizedRecoveryExecution; use super::X_CODEX_INSTALLATION_ID_HEADER; use super::X_CODEX_PARENT_THREAD_ID_HEADER; @@ -16,6 +19,8 @@ use codex_model_provider_info::create_oss_provider_with_base_url; use codex_otel::SessionTelemetry; use codex_protocol::ThreadId; use codex_protocol::models::ContentItem; +use codex_protocol::models::ReasoningItemContent; +use codex_protocol::models::ReasoningItemReasoningSummary; use codex_protocol::models::ResponseItem; use codex_protocol::openai_models::ModelInfo; use codex_protocol::protocol::InternalSessionSource; @@ -41,11 +46,13 @@ use tempfile::TempDir; use tokio::sync::Notify; fn test_model_client(session_source: SessionSource) -> ModelClient { + let conversation_id = ThreadId::new(); let provider = create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses); ModelClient::new( /*auth_manager*/ None, - ThreadId::new(), + conversation_id, /*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(), + /*prompt_cache_key_override*/ None, provider, session_source, /*model_verbosity*/ None, @@ -147,6 +154,66 @@ fn output_message(id: &str, text: &str) -> ResponseItem { } } +fn user_message_item(text: &str) -> ResponseItem { + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: text.to_string(), + }], + phase: None, + } +} + +fn reasoning_item(id: &str, text: &str) -> ResponseItem { + ResponseItem::Reasoning { + id: id.to_string(), + summary: vec![ReasoningItemReasoningSummary::SummaryText { + text: "summary".to_string(), + }], + content: Some(vec![ReasoningItemContent::ReasoningText { + text: text.to_string(), + }]), + encrypted_content: None, + } +} + +#[test] +fn response_continuation_for_fork_drops_historical_reasoning_but_keeps_latest() { + let user_message = user_message_item("hello"); + let old_reasoning = reasoning_item("rs-old", "old analysis"); + let latest_reasoning = reasoning_item("rs-latest", "latest analysis"); + let latest_message = output_message("msg-latest", "assistant output"); + let response_continuation = ResponseContinuation { + request: ResponsesApiRequest { + model: "gpt-test".to_string(), + instructions: "base instructions".to_string(), + input: vec![user_message.clone(), old_reasoning], + tools: Vec::new(), + tool_choice: "auto".to_string(), + parallel_tool_calls: false, + reasoning: None, + store: false, + stream: true, + include: Vec::new(), + service_tier: None, + prompt_cache_key: Some(ThreadId::new().to_string()), + text: None, + client_metadata: None, + }, + last_response: LastResponse { + response_id: "parent-resp".to_string(), + items_added: vec![latest_reasoning.clone(), latest_message.clone()], + }, + } + .for_fork(); + + assert_eq!( + response_continuation.fork_baseline_input(), + vec![user_message, latest_reasoning, latest_message] + ); +} + async fn replay_until_cancelled(temp: &TempDir) -> anyhow::Result { let mut rollout = replay_bundle(temp.path())?; for _ in 0..50 { @@ -287,6 +354,8 @@ async fn dropped_response_stream_traces_cancelled_partial_output() -> anyhow::Re api_stream, test_session_telemetry(), attempt, + /*client_state*/ None, + /*request*/ None, ); let observed = stream @@ -342,6 +411,8 @@ async fn dropped_backpressured_response_stream_traces_cancelled_partial_output() api_stream, test_session_telemetry(), attempt, + /*client_state*/ None, + /*request*/ None, ); // Fill the mapper channel with non-terminal events, then yield one output diff --git a/codex-rs/core/src/codex_delegate.rs b/codex-rs/core/src/codex_delegate.rs index d142d33a2f..693b71ac02 100644 --- a/codex-rs/core/src/codex_delegate.rs +++ b/codex-rs/core/src/codex_delegate.rs @@ -93,6 +93,7 @@ pub(crate) async fn run_codex_thread_interactive( user_shell_override: None, inherited_exec_policy: Some(Arc::clone(&parent_session.services.exec_policy)), parent_rollout_thread_trace: codex_rollout_trace::ThreadTraceContext::disabled(), + inherited_thread_state: Default::default(), parent_trace: None, environment_selections: ResolvedTurnEnvironments { turn_environments: parent_ctx.environments.clone(), diff --git a/codex-rs/core/src/compact.rs b/codex-rs/core/src/compact.rs index 58a2610fcb..be3a53680d 100644 --- a/codex-rs/core/src/compact.rs +++ b/codex-rs/core/src/compact.rs @@ -42,6 +42,8 @@ use codex_model_provider_info::ModelProviderInfo; pub const SUMMARIZATION_PROMPT: &str = include_str!("../templates/compact/prompt.md"); pub const SUMMARY_PREFIX: &str = include_str!("../templates/compact/summary_prefix.md"); const COMPACT_USER_MESSAGE_MAX_TOKENS: usize = 20_000; +pub(crate) const UNIFIED_EXEC_PROCESS_WARNING_PREFIX: &str = + "Warning: The maximum number of unified exec process"; /// Controls whether compaction replacement history must include initial context. /// @@ -368,19 +370,51 @@ pub fn content_items_to_text(content: &[ContentItem]) -> Option { } pub(crate) fn collect_user_messages(items: &[ResponseItem]) -> Vec { - items - .iter() - .filter_map(|item| match crate::event_mapping::parse_turn_item(item) { + let mut messages = Vec::new(); + let mut previous_message: Option = Some(String::new()); + for item in items { + let message = match crate::event_mapping::parse_turn_item(item) { Some(TurnItem::UserMessage(user)) => { - if is_summary_message(&user.message()) { + if is_summary_message(&user.message()) + || is_compaction_filtered_user_message(&user.message()) + { None } else { Some(user.message()) } } _ => None, - }) - .collect() + }; + let Some(message) = message else { + previous_message = None; + continue; + }; + if message.is_empty() { + continue; + } + if previous_message.as_deref() == Some(message.as_str()) { + continue; + } + previous_message = Some(message.clone()); + messages.push(message); + } + messages +} + +pub(crate) fn is_compaction_filtered_user_message(message: &str) -> bool { + message.starts_with(UNIFIED_EXEC_PROCESS_WARNING_PREFIX) +} + +pub(crate) fn is_compaction_filtered_history_item(item: &ResponseItem) -> bool { + let ResponseItem::Message { role, content, .. } = item else { + return false; + }; + if role != "user" { + return false; + } + content_items_to_text(content) + .as_deref() + .is_some_and(is_compaction_filtered_user_message) } pub(crate) fn is_summary_message(message: &str) -> bool { diff --git a/codex-rs/core/src/compact_remote.rs b/codex-rs/core/src/compact_remote.rs index d8adb20772..4b0a80c8a5 100644 --- a/codex-rs/core/src/compact_remote.rs +++ b/codex-rs/core/src/compact_remote.rs @@ -6,6 +6,7 @@ use crate::compact::CompactionAnalyticsAttempt; use crate::compact::InitialContextInjection; use crate::compact::compaction_status_from_result; use crate::compact::insert_initial_context_before_last_real_user_or_summary; +use crate::compact::is_compaction_filtered_history_item; use crate::context_manager::ContextManager; use crate::context_manager::TotalTokenUsageBreakdown; use crate::context_manager::estimate_response_item_model_visible_bytes; @@ -144,8 +145,17 @@ async fn run_remote_compact_task_inner_impl( // This is the history selected for remote compaction, after any trimming required to fit the // compact endpoint. The checkpoint below records it separately from the next sampling request, // whose prompt will repeat current developer/context prefix items. - let trace_input_history = history.raw_items().to_vec(); - let prompt_input = history.for_prompt(&turn_context.model_info.input_modalities); + let trace_input_history = history + .raw_items() + .iter() + .filter(|item| !is_compaction_filtered_history_item(item)) + .cloned() + .collect::>(); + let prompt_input = history + .for_prompt(&turn_context.model_info.input_modalities) + .into_iter() + .filter(|item| !is_compaction_filtered_history_item(item)) + .collect::>(); let tool_router = built_tools( sess.as_ref(), turn_context.as_ref(), @@ -250,17 +260,21 @@ pub(crate) async fn process_compacted_history( /// We drop: /// - `developer` messages because remote output can include stale/duplicated /// instruction content. -/// - non-user-content `user` messages (session prefix/instruction wrappers), -/// while preserving real user messages and persisted hook prompts. +/// - non-user-content `user` messages (session prefix/instruction wrappers). +/// - user warnings that are known to be local runtime noise and should not +/// survive compaction. /// /// This intentionally keeps: /// - `assistant` messages (future remote compaction models may emit them) -/// - `user`-role warnings and compaction-generated summary messages because -/// they parse as `TurnItem::UserMessage`. +/// - compaction-generated summary messages because they parse as +/// `TurnItem::UserMessage`. fn should_keep_compacted_history_item(item: &ResponseItem) -> bool { match item { ResponseItem::Message { role, .. } if role == "developer" => false, ResponseItem::Message { role, .. } if role == "user" => { + if is_compaction_filtered_history_item(item) { + return false; + } matches!( crate::event_mapping::parse_turn_item(item), Some(TurnItem::UserMessage(_) | TurnItem::HookPrompt(_)) diff --git a/codex-rs/core/src/compact_tests.rs b/codex-rs/core/src/compact_tests.rs index 8fdb7fb4b2..c237f5b52d 100644 --- a/codex-rs/core/src/compact_tests.rs +++ b/codex-rs/core/src/compact_tests.rs @@ -120,6 +120,102 @@ do things assert_eq!(vec!["real user message".to_string()], collected); } +#[test] +fn collect_user_messages_filters_unified_exec_process_warnings() { + let items = vec![ + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "before warning".to_string(), + }], + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "Warning: The maximum number of unified exec processes you can keep open is 5 and you currently have 5 processes open. Reuse older processes or close them to prevent automatic pruning of old processes".to_string(), + }], + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "after warning".to_string(), + }], + phase: None, + }, + ]; + + let collected = collect_user_messages(&items); + + assert_eq!( + vec!["before warning".to_string(), "after warning".to_string()], + collected + ); + assert!(is_compaction_filtered_history_item(&items[1])); +} + +#[test] +fn collect_user_messages_drops_contiguous_duplicates_and_empty_messages() { + let items = vec![ + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: String::new(), + }], + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "repeat".to_string(), + }], + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "repeat".to_string(), + }], + phase: None, + }, + ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: "keeps the next user message non-contiguous".to_string(), + }], + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "repeat".to_string(), + }], + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: String::new(), + }], + phase: None, + }, + ]; + + let collected = collect_user_messages(&items); + + assert_eq!(vec!["repeat".to_string(), "repeat".to_string()], collected); +} + #[test] fn build_token_limited_compacted_history_truncates_overlong_user_messages() { // Use a small truncation limit so the test remains fast while still validating diff --git a/codex-rs/core/src/inherited_thread_state.rs b/codex-rs/core/src/inherited_thread_state.rs new file mode 100644 index 0000000000..2eb0a91427 --- /dev/null +++ b/codex-rs/core/src/inherited_thread_state.rs @@ -0,0 +1,64 @@ +use codex_protocol::ThreadId; + +use crate::client::ResponseContinuation; +use crate::state::McpToolSnapshot; + +#[derive(Clone, Default)] +pub(crate) struct InheritedThreadState { + prompt_cache_key: Option, + response_continuation: Option, + mcp_tool_snapshot: Option, +} + +impl InheritedThreadState { + pub(crate) fn builder() -> InheritedThreadStateBuilder { + InheritedThreadStateBuilder::default() + } + + pub(crate) fn prompt_cache_key(&self) -> Option { + self.prompt_cache_key + } + + pub(crate) fn response_continuation(&self) -> Option { + self.response_continuation.clone() + } + + pub(crate) fn mcp_tool_snapshot(&self) -> Option { + self.mcp_tool_snapshot.clone() + } +} + +#[derive(Default)] +pub(crate) struct InheritedThreadStateBuilder { + prompt_cache_key: Option, + response_continuation: Option, + mcp_tool_snapshot: Option, +} + +impl InheritedThreadStateBuilder { + pub(crate) fn prompt_cache_key(mut self, prompt_cache_key: Option) -> Self { + self.prompt_cache_key = prompt_cache_key; + self + } + + pub(crate) fn response_continuation( + mut self, + response_continuation: Option, + ) -> Self { + self.response_continuation = response_continuation; + self + } + + pub(crate) fn mcp_tool_snapshot(mut self, mcp_tool_snapshot: Option) -> Self { + self.mcp_tool_snapshot = mcp_tool_snapshot; + self + } + + pub(crate) fn build(self) -> InheritedThreadState { + InheritedThreadState { + prompt_cache_key: self.prompt_cache_key, + response_continuation: self.response_continuation, + mcp_tool_snapshot: self.mcp_tool_snapshot, + } + } +} diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index a396851f98..32fe4ba1f9 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -40,6 +40,7 @@ mod git_info_tests; mod goals; mod guardian; mod hook_runtime; +mod inherited_thread_state; mod installation_id; pub(crate) mod landlock; pub use landlock::spawn_command_under_linux_sandbox; diff --git a/codex-rs/core/src/session/mcp.rs b/codex-rs/core/src/session/mcp.rs index 2aa5adee28..decb075c1c 100644 --- a/codex-rs/core/src/session/mcp.rs +++ b/codex-rs/core/src/session/mcp.rs @@ -267,6 +267,8 @@ impl Session { std::mem::replace(&mut *manager, refreshed_manager) }; old_manager.shutdown().await; + let mut snapshot = self.services.mcp_tool_snapshot.lock().await; + *snapshot = None; } pub(crate) async fn refresh_mcp_servers_if_requested(&self, turn_context: &TurnContext) { diff --git a/codex-rs/core/src/session/mod.rs b/codex-rs/core/src/session/mod.rs index c18976fde1..4666975216 100644 --- a/codex-rs/core/src/session/mod.rs +++ b/codex-rs/core/src/session/mod.rs @@ -32,6 +32,7 @@ use crate::context::PersonalitySpecInstructions; use crate::default_skill_metadata_budget; use crate::environment_selection::ResolvedTurnEnvironments; use crate::exec_policy::ExecPolicyManager; +use crate::inherited_thread_state::InheritedThreadState; use crate::installation_id::resolve_installation_id; use crate::parse_turn_item; use crate::path_utils::normalize_for_native_workdir; @@ -406,6 +407,7 @@ pub(crate) struct CodexSpawnArgs { /// Root sessions and non-thread-spawn subagents pass a disabled context; /// `Session::new` creates the root trace itself when rollout tracing is enabled. pub(crate) parent_rollout_thread_trace: ThreadTraceContext, + pub(crate) inherited_thread_state: InheritedThreadState, pub(crate) user_shell_override: Option, pub(crate) parent_trace: Option, pub(crate) environment_selections: ResolvedTurnEnvironments, @@ -464,6 +466,7 @@ impl Codex { user_shell_override, inherited_exec_policy, parent_rollout_thread_trace, + inherited_thread_state, parent_trace: _, environment_selections, analytics_events_client, @@ -640,6 +643,7 @@ impl Codex { skills_watcher, agent_control, environment_manager, + inherited_thread_state, analytics_events_client, thread_store, parent_rollout_thread_trace, @@ -1032,6 +1036,16 @@ impl Session { self.services.live_thread.as_ref() } + pub(crate) fn prompt_cache_key(&self) -> ThreadId { + self.services.model_client.prompt_cache_key() + } + + pub(crate) fn response_continuation_for_fork( + &self, + ) -> Option { + self.services.model_client.response_continuation_for_fork() + } + /// Flush rollout writes and return the final durability-barrier result. pub(crate) async fn flush_rollout(&self) -> std::io::Result<()> { if let Some(live_thread) = self.live_thread() { diff --git a/codex-rs/core/src/session/session.rs b/codex-rs/core/src/session/session.rs index 50b3345d61..8047b7a9f6 100644 --- a/codex-rs/core/src/session/session.rs +++ b/codex-rs/core/src/session/session.rs @@ -343,6 +343,7 @@ impl Session { skills_watcher: Arc, agent_control: AgentControl, environment_manager: Arc, + inherited_thread_state: InheritedThreadState, analytics_events_client: Option, thread_store: Arc, parent_rollout_thread_trace: ThreadTraceContext, @@ -822,6 +823,8 @@ impl Session { config.analytics_enabled, ) }); + let prompt_cache_key_override = inherited_thread_state.prompt_cache_key(); + let mcp_tool_snapshot = inherited_thread_state.mcp_tool_snapshot(); let services = SessionServices { // Initialize the MCP connection manager with an uninitialized // instance. It will be replaced with one created via @@ -864,16 +867,19 @@ impl Session { state_db: state_db_ctx.clone(), live_thread: live_thread_init.as_ref().cloned(), thread_store: Arc::clone(&thread_store), - model_client: ModelClient::new( + mcp_tool_snapshot: Mutex::new(mcp_tool_snapshot), + model_client: ModelClient::new_with_response_continuation( Some(Arc::clone(&auth_manager)), conversation_id, installation_id, + prompt_cache_key_override, session_configuration.provider.clone(), session_configuration.session_source.clone(), config.model_verbosity, config.features.enabled(Feature::EnableRequestCompression), config.features.enabled(Feature::RuntimeMetrics), Self::build_model_client_beta_features_header(config.as_ref()), + inherited_thread_state.response_continuation(), ), code_mode_service: crate::tools::code_mode::CodeModeService::new(), environment_manager, diff --git a/codex-rs/core/src/session/tests.rs b/codex-rs/core/src/session/tests.rs index 89633daf35..cf031891c9 100644 --- a/codex-rs/core/src/session/tests.rs +++ b/codex-rs/core/src/session/tests.rs @@ -133,6 +133,7 @@ use core_test_support::responses::ev_function_call; use core_test_support::responses::ev_response_created; use core_test_support::responses::mount_sse_once; use core_test_support::responses::mount_sse_sequence; +use core_test_support::responses::namespace_child_tool; use core_test_support::responses::sse; use core_test_support::responses::start_mock_server; use core_test_support::test_codex::test_codex; @@ -350,6 +351,7 @@ fn test_model_client_session() -> crate::client::ModelClientSession { ThreadId::try_from("00000000-0000-4000-8000-000000000001") .expect("test thread id should be valid"), /*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(), + /*prompt_cache_key_override*/ None, ModelProviderInfo::create_openai_provider(/* base_url */ /*base_url*/ None), codex_protocol::protocol::SessionSource::Exec, /*model_verbosity*/ None, @@ -1752,6 +1754,126 @@ async fn fork_startup_context_then_first_turn_diff_snapshot() -> anyhow::Result< Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn inherited_thread_state_shapes_first_responses_request() -> anyhow::Result<()> { + let server = start_mock_server().await; + let response_mock = mount_sse_once( + &server, + sse(vec![ev_response_created("resp-1"), ev_completed("resp-1")]), + ) + .await; + let inherited_prompt_cache_key = ThreadId::try_from("00000000-0000-4000-8000-000000000002") + .expect("test thread id should be valid"); + let inherited_tool = codex_mcp::ToolInfo { + server_name: "snapshot".to_string(), + callable_name: "echo".to_string(), + callable_namespace: "mcp__snapshot__".to_string(), + server_instructions: None, + tool: rmcp::model::Tool { + name: "echo".to_string().into(), + title: None, + description: Some("Echo from the inherited MCP snapshot".to_string().into()), + input_schema: std::sync::Arc::new(rmcp::model::object(json!({ + "type": "object", + "properties": { + "message": { "type": "string" } + }, + "required": ["message"], + "additionalProperties": false + }))), + output_schema: None, + annotations: None, + execution: None, + icons: None, + meta: None, + }, + connector_id: None, + connector_name: None, + plugin_display_names: Vec::new(), + connector_description: None, + }; + let inherited_thread_state = crate::inherited_thread_state::InheritedThreadState::builder() + .prompt_cache_key(Some(inherited_prompt_cache_key)) + .mcp_tool_snapshot(Some(crate::state::McpToolSnapshot { + tools: std::collections::HashMap::from([( + "mcp__snapshot__echo".to_string(), + inherited_tool, + )]), + })) + .build(); + let (session, rx_event) = make_session_with_config_inherited_and_rx( + |config| { + config.model_provider.base_url = Some(format!("{}/v1", server.uri())); + config.model_provider.supports_websockets = false; + config + .mcp_servers + .set(std::collections::HashMap::new()) + .expect("empty mcp server config should be valid"); + }, + inherited_thread_state, + ) + .await?; + let turn_context = session.new_default_turn().await; + + session + .spawn_task( + Arc::clone(&turn_context), + vec![UserInput::Text { + text: "use inherited state".to_string(), + text_elements: Vec::new(), + }], + crate::tasks::RegularTask::new(), + ) + .await; + + let mut observed_events = Vec::new(); + let wait_result = timeout(Duration::from_secs(5), async { + while let Ok(event) = rx_event.recv().await { + observed_events.push(match &event.msg { + EventMsg::SessionConfigured(_) => "SessionConfigured", + EventMsg::McpStartupComplete(_) => "McpStartupComplete", + EventMsg::TurnStarted(_) => "TurnStarted", + EventMsg::RawResponseItem(_) => "RawResponseItem", + EventMsg::ItemStarted(_) => "ItemStarted", + EventMsg::ItemCompleted(_) => "ItemCompleted", + EventMsg::UserMessage(_) => "UserMessage", + EventMsg::StreamError(_) => "StreamError", + EventMsg::Error(_) => "Error", + EventMsg::TurnAborted(_) => "TurnAborted", + EventMsg::TurnComplete(_) => "TurnComplete", + _ => "Other", + }); + match event.msg { + EventMsg::TurnComplete(_) => return, + EventMsg::Error(error) => panic!("turn errored: {}", error.message), + EventMsg::TurnAborted(aborted) => panic!("turn aborted: {:?}", aborted.reason), + _ => {} + } + } + }) + .await; + if let Err(err) = wait_result { + panic!( + "turn should complete: {err:?}; captured requests: {}; observed events: {observed_events:#?}", + response_mock.requests().len(), + ); + } + + let request = response_mock.single_request(); + let body = request.body_json(); + let expected_prompt_cache_key = inherited_prompt_cache_key.to_string(); + assert_eq!( + body["prompt_cache_key"].as_str(), + Some(expected_prompt_cache_key.as_str()) + ); + assert!( + namespace_child_tool(&body, "mcp__snapshot__", "echo").is_some(), + "first request should expose inherited MCP snapshot tools: {body:#}" + ); + + Ok(()) +} + #[tokio::test] async fn record_initial_history_forked_hydrates_previous_turn_settings() { let (session, turn_context) = make_session_and_context().await; @@ -3487,6 +3609,7 @@ async fn session_new_fails_when_zsh_fork_enabled_without_zsh_path() { Arc::new(SkillsWatcher::noop()), AgentControl::default(), Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()), + Default::default(), /*analytics_events_client*/ None, Arc::new(codex_thread_store::LocalThreadStore::new( codex_thread_store::LocalThreadStoreConfig::from_config(config.as_ref()), @@ -3599,6 +3722,7 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) { &config.permissions.approval_policy, &config.permissions.permission_profile, ))), + mcp_tool_snapshot: Mutex::new(None), mcp_startup_cancellation_token: Mutex::new(CancellationToken::new()), unified_exec_manager: UnifiedExecProcessManager::new( config.background_terminal_max_timeout, @@ -3642,6 +3766,7 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) { Some(auth_manager.clone()), conversation_id, /*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(), + /*prompt_cache_key_override*/ None, session_configuration.provider.clone(), session_configuration.session_source.clone(), config.model_verbosity, @@ -3721,6 +3846,13 @@ async fn make_session_with_config( async fn make_session_with_config_and_rx( mutator: impl FnOnce(&mut Config), +) -> anyhow::Result<(Arc, async_channel::Receiver)> { + make_session_with_config_inherited_and_rx(mutator, Default::default()).await +} + +async fn make_session_with_config_inherited_and_rx( + mutator: impl FnOnce(&mut Config), + inherited_thread_state: crate::inherited_thread_state::InheritedThreadState, ) -> anyhow::Result<(Arc, async_channel::Receiver)> { let codex_home = tempfile::tempdir().expect("create temp dir"); let mut config = build_test_config(codex_home.path()).await; @@ -3805,6 +3937,7 @@ async fn make_session_with_config_and_rx( Arc::new(SkillsWatcher::noop()), AgentControl::default(), Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()), + inherited_thread_state, /*analytics_events_client*/ None, Arc::new(codex_thread_store::LocalThreadStore::new( codex_thread_store::LocalThreadStoreConfig::from_config(config.as_ref()), @@ -5081,6 +5214,7 @@ where &config.permissions.approval_policy, &config.permissions.permission_profile, ))), + mcp_tool_snapshot: Mutex::new(None), mcp_startup_cancellation_token: Mutex::new(CancellationToken::new()), unified_exec_manager: UnifiedExecProcessManager::new( config.background_terminal_max_timeout, @@ -5124,6 +5258,7 @@ where Some(Arc::clone(&auth_manager)), conversation_id, /*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(), + /*prompt_cache_key_override*/ None, session_configuration.provider.clone(), session_configuration.session_source.clone(), config.model_verbosity, diff --git a/codex-rs/core/src/session/tests/guardian_tests.rs b/codex-rs/core/src/session/tests/guardian_tests.rs index d6a87d466a..b2cea6af0b 100644 --- a/codex-rs/core/src/session/tests/guardian_tests.rs +++ b/codex-rs/core/src/session/tests/guardian_tests.rs @@ -753,6 +753,7 @@ async fn guardian_subagent_does_not_inherit_parent_exec_policy_rules() { inherited_shell_snapshot: None, inherited_exec_policy: Some(Arc::new(parent_exec_policy)), parent_rollout_thread_trace: codex_rollout_trace::ThreadTraceContext::disabled(), + inherited_thread_state: Default::default(), user_shell_override: None, parent_trace: None, environment_selections: ResolvedTurnEnvironments { diff --git a/codex-rs/core/src/session/turn.rs b/codex-rs/core/src/session/turn.rs index 2b37372a33..aa36273fb6 100644 --- a/codex-rs/core/src/session/turn.rs +++ b/codex-rs/core/src/session/turn.rs @@ -1122,13 +1122,18 @@ pub(crate) async fn built_tools( skills_outcome: Option<&SkillLoadOutcome>, cancellation_token: &CancellationToken, ) -> CodexResult> { - let mcp_connection_manager = sess.services.mcp_connection_manager.read().await; - let has_mcp_servers = mcp_connection_manager.has_servers(); - let all_mcp_tools = mcp_connection_manager - .list_all_tools() - .or_cancel(cancellation_token) - .await?; - drop(mcp_connection_manager); + let inherited_mcp_tools = sess.services.mcp_tool_snapshot.lock().await.clone(); + let (has_mcp_servers, all_mcp_tools) = if let Some(snapshot) = inherited_mcp_tools { + (!snapshot.tools.is_empty(), snapshot.tools) + } else { + let mcp_connection_manager = sess.services.mcp_connection_manager.read().await; + let has_mcp_servers = mcp_connection_manager.has_servers(); + let all_mcp_tools = mcp_connection_manager + .list_all_tools() + .or_cancel(cancellation_token) + .await?; + (has_mcp_servers, all_mcp_tools) + }; let loaded_plugins = sess .services .plugins_manager diff --git a/codex-rs/core/src/state/mod.rs b/codex-rs/core/src/state/mod.rs index 13f3bf6c86..b874067cc4 100644 --- a/codex-rs/core/src/state/mod.rs +++ b/codex-rs/core/src/state/mod.rs @@ -2,6 +2,7 @@ mod service; mod session; mod turn; +pub(crate) use service::McpToolSnapshot; pub(crate) use service::SessionServices; pub(crate) use session::SessionState; pub(crate) use turn::ActiveTurn; diff --git a/codex-rs/core/src/state/service.rs b/codex-rs/core/src/state/service.rs index 9cd9e97fbb..8fc60d030e 100644 --- a/codex-rs/core/src/state/service.rs +++ b/codex-rs/core/src/state/service.rs @@ -21,6 +21,7 @@ use codex_exec_server::EnvironmentManager; use codex_hooks::Hooks; use codex_login::AuthManager; use codex_mcp::McpConnectionManager; +use codex_mcp::ToolInfo as McpToolInfo; use codex_models_manager::manager::SharedModelsManager; use codex_otel::SessionTelemetry; use codex_rollout::state_db::StateDbHandle; @@ -34,8 +35,14 @@ use tokio::sync::RwLock; use tokio::sync::watch; use tokio_util::sync::CancellationToken; +#[derive(Clone, Default)] +pub(crate) struct McpToolSnapshot { + pub(crate) tools: HashMap, +} + pub(crate) struct SessionServices { pub(crate) mcp_connection_manager: Arc>, + pub(crate) mcp_tool_snapshot: Mutex>, pub(crate) mcp_startup_cancellation_token: Mutex, pub(crate) unified_exec_manager: UnifiedExecProcessManager, #[cfg_attr(not(unix), allow(dead_code))] diff --git a/codex-rs/core/src/thread_manager.rs b/codex-rs/core/src/thread_manager.rs index eb7419076d..5821bf237f 100644 --- a/codex-rs/core/src/thread_manager.rs +++ b/codex-rs/core/src/thread_manager.rs @@ -6,6 +6,7 @@ use crate::config::ThreadStoreConfig; use crate::environment_selection::default_thread_environment_selections; use crate::environment_selection::resolve_environment_selections; use crate::file_watcher::FileWatcher; +use crate::inherited_thread_state::InheritedThreadState; use crate::mcp::McpManager; use crate::rollout::RolloutRecorder; use crate::rollout::truncation; @@ -230,6 +231,7 @@ pub(crate) struct ResumeThreadWithHistoryOptions { pub(crate) session_source: SessionSource, pub(crate) inherited_shell_snapshot: Option>, pub(crate) inherited_exec_policy: Option>, + pub(crate) inherited_thread_state: InheritedThreadState, } /// Shared, `Arc`-owned state for [`ThreadManager`]. This `Arc` is required to have a single @@ -583,6 +585,7 @@ impl ThreadManager { options.metrics_service_name, /*inherited_shell_snapshot*/ None, /*inherited_exec_policy*/ None, + Default::default(), options.parent_trace, options.environments, /*user_shell_override*/ None, @@ -924,6 +927,7 @@ impl ThreadManagerState { /*inherited_shell_snapshot*/ None, /*inherited_exec_policy*/ None, /*environments*/ None, + Default::default(), )) .await } @@ -939,6 +943,7 @@ impl ThreadManagerState { inherited_shell_snapshot: Option>, inherited_exec_policy: Option>, environments: Option>, + inherited_thread_state: InheritedThreadState, ) -> CodexResult { let environments = environments.unwrap_or_else(|| { default_thread_environment_selections(self.environment_manager.as_ref(), &config.cwd) @@ -954,6 +959,7 @@ impl ThreadManagerState { metrics_service_name, inherited_shell_snapshot, inherited_exec_policy, + inherited_thread_state, /*parent_trace*/ None, environments, /*user_shell_override*/ None, @@ -972,6 +978,7 @@ impl ThreadManagerState { session_source, inherited_shell_snapshot, inherited_exec_policy, + inherited_thread_state, } = options; let environments = default_thread_environment_selections(self.environment_manager.as_ref(), &config.cwd); @@ -986,6 +993,7 @@ impl ThreadManagerState { /*metrics_service_name*/ None, inherited_shell_snapshot, inherited_exec_policy, + inherited_thread_state, /*parent_trace*/ None, environments, /*user_shell_override*/ None, @@ -1004,6 +1012,7 @@ impl ThreadManagerState { inherited_shell_snapshot: Option>, inherited_exec_policy: Option>, environments: Option>, + inherited_thread_state: InheritedThreadState, ) -> CodexResult { let environments = environments.unwrap_or_else(|| { default_thread_environment_selections(self.environment_manager.as_ref(), &config.cwd) @@ -1019,6 +1028,7 @@ impl ThreadManagerState { /*metrics_service_name*/ None, inherited_shell_snapshot, inherited_exec_policy, + inherited_thread_state, /*parent_trace*/ None, environments, /*user_shell_override*/ None, @@ -1052,6 +1062,7 @@ impl ThreadManagerState { metrics_service_name, /*inherited_shell_snapshot*/ None, /*inherited_exec_policy*/ None, + Default::default(), parent_trace, environments, user_shell_override, @@ -1072,6 +1083,7 @@ impl ThreadManagerState { metrics_service_name: Option, inherited_shell_snapshot: Option>, inherited_exec_policy: Option>, + inherited_thread_state: InheritedThreadState, parent_trace: Option, environments: Vec, user_shell_override: Option, @@ -1137,6 +1149,7 @@ impl ThreadManagerState { inherited_shell_snapshot, inherited_exec_policy, parent_rollout_thread_trace, + inherited_thread_state, user_shell_override, parent_trace, environment_selections, diff --git a/codex-rs/core/tests/responses_headers.rs b/codex-rs/core/tests/responses_headers.rs index 56e9893116..e0c1e8ea79 100644 --- a/codex-rs/core/tests/responses_headers.rs +++ b/codex-rs/core/tests/responses_headers.rs @@ -102,6 +102,7 @@ async fn responses_stream_includes_subagent_header_on_review() { /*auth_manager*/ None, conversation_id, /*installation_id*/ TEST_INSTALLATION_ID.to_string(), + /*prompt_cache_key_override*/ None, provider.clone(), session_source, config.model_verbosity, @@ -228,6 +229,7 @@ async fn responses_stream_includes_subagent_header_on_other() { /*auth_manager*/ None, conversation_id, /*installation_id*/ TEST_INSTALLATION_ID.to_string(), + /*prompt_cache_key_override*/ None, provider.clone(), session_source, config.model_verbosity, @@ -343,6 +345,7 @@ async fn responses_respects_model_info_overrides_from_config() { /*auth_manager*/ None, conversation_id, /*installation_id*/ TEST_INSTALLATION_ID.to_string(), + /*prompt_cache_key_override*/ None, provider.clone(), session_source, config.model_verbosity, diff --git a/codex-rs/core/tests/suite/client.rs b/codex-rs/core/tests/suite/client.rs index f4960af550..6e8a2567fa 100644 --- a/codex-rs/core/tests/suite/client.rs +++ b/codex-rs/core/tests/suite/client.rs @@ -884,6 +884,7 @@ async fn send_provider_auth_request(server: &MockServer, auth: ModelProviderAuth ))), conversation_id, /*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(), + /*prompt_cache_key_override*/ None, provider, SessionSource::Exec, config.model_verbosity, @@ -2288,6 +2289,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() { /*auth_manager*/ None, conversation_id, /*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(), + /*prompt_cache_key_override*/ None, provider.clone(), SessionSource::Exec, config.model_verbosity, diff --git a/codex-rs/core/tests/suite/client_websockets.rs b/codex-rs/core/tests/suite/client_websockets.rs index cdbb65aabd..1691a6081a 100755 --- a/codex-rs/core/tests/suite/client_websockets.rs +++ b/codex-rs/core/tests/suite/client_websockets.rs @@ -1855,6 +1855,7 @@ async fn websocket_harness_with_provider_options( /*auth_manager*/ None, conversation_id, /*installation_id*/ TEST_INSTALLATION_ID.to_string(), + /*prompt_cache_key_override*/ None, provider.clone(), SessionSource::Exec, config.model_verbosity, diff --git a/codex-rs/exec/src/cli.rs b/codex-rs/exec/src/cli.rs index 2b12898c3c..d0169e1eae 100644 --- a/codex-rs/exec/src/cli.rs +++ b/codex-rs/exec/src/cli.rs @@ -16,6 +16,12 @@ pub struct Cli { #[command(subcommand)] pub command: Option, + /// Fork from an existing session id (or thread name) before sending the prompt. + /// + /// This creates a new session with copied history, similar to `codex fork`. + #[arg(long = "fork", value_name = "SESSION_ID")] + pub fork_session_id: Option, + #[clap(flatten)] pub shared: ExecSharedCliOptions, @@ -81,6 +87,19 @@ pub struct Cli { pub prompt: Option, } +impl Cli { + pub fn validate(self) -> Result { + if self.fork_session_id.is_some() && self.command.is_some() { + return Err(clap::Error::raw( + clap::error::ErrorKind::ArgumentConflict, + "--fork cannot be used with subcommands", + )); + } + + Ok(self) + } +} + impl std::ops::Deref for Cli { type Target = SharedCliOptions; @@ -156,7 +175,6 @@ fn mark_exec_global_args(cmd: clap::Command) -> clap::Command { arg.global(true) }) } - #[derive(Debug, clap::Subcommand)] pub enum Command { /// Resume a previous session by id or pick the most recent with --last. diff --git a/codex-rs/exec/src/cli_tests.rs b/codex-rs/exec/src/cli_tests.rs index 45f2aa330d..9d785ddc28 100644 --- a/codex-rs/exec/src/cli_tests.rs +++ b/codex-rs/exec/src/cli_tests.rs @@ -80,3 +80,22 @@ fn removed_full_auto_flag_reports_migration_path() { Some("warning: `--full-auto` is deprecated; use `--sandbox workspace-write` instead.") ); } + +#[test] +fn fork_option_parses_prompt() { + const PROMPT: &str = "echo fork-non-interactive"; + let cli = Cli::parse_from(["codex-exec", "--fork", "session-123", "--json", PROMPT]); + + assert_eq!(cli.fork_session_id.as_deref(), Some("session-123")); + assert_eq!(cli.prompt.as_deref(), Some(PROMPT)); + assert!(cli.command.is_none()); +} + +#[test] +fn fork_option_conflicts_with_subcommands() { + let err = Cli::try_parse_from(["codex-exec", "--fork", "session-123", "resume"]) + .and_then(Cli::validate) + .expect_err("fork should conflict with subcommands"); + + assert_eq!(err.kind(), clap::error::ErrorKind::ArgumentConflict); +} diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index f2f0ed030b..ce0c66cd65 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -34,6 +34,8 @@ use codex_app_server_protocol::ReviewTarget as ApiReviewTarget; use codex_app_server_protocol::ServerNotification; use codex_app_server_protocol::ServerRequest; use codex_app_server_protocol::Thread as AppServerThread; +use codex_app_server_protocol::ThreadForkParams; +use codex_app_server_protocol::ThreadForkResponse; use codex_app_server_protocol::ThreadItem as AppServerThreadItem; use codex_app_server_protocol::ThreadListParams; use codex_app_server_protocol::ThreadListResponse; @@ -198,6 +200,7 @@ struct ExecRunArgs { config: Config, dangerously_bypass_approvals_and_sandbox: bool, exec_span: tracing::Span, + fork_session_id: Option, images: Vec, json_mode: bool, last_message_file: Option, @@ -230,6 +233,7 @@ pub async fn run_main(cli: Cli, arg0_paths: Arg0DispatchPaths) -> anyhow::Result let Cli { command, + fork_session_id, shared, skip_git_repo_check, ephemeral, @@ -529,6 +533,7 @@ pub async fn run_main(cli: Cli, arg0_paths: Arg0DispatchPaths) -> anyhow::Result config, dangerously_bypass_approvals_and_sandbox, exec_span: exec_span.clone(), + fork_session_id, images, json_mode, last_message_file, @@ -550,6 +555,7 @@ async fn run_exec_session(args: ExecRunArgs) -> anyhow::Result<()> { config, dangerously_bypass_approvals_and_sandbox, exec_span, + fork_session_id, images, json_mode, last_message_file, @@ -667,58 +673,78 @@ async fn run_exec_session(args: ExecRunArgs) -> anyhow::Result<()> { anyhow::anyhow!("failed to initialize in-process app-server client: {err}") })?; - // Handle resume subcommand through existing `thread/list` + `thread/resume` - // APIs so exec no longer reaches into rollout storage directly. - let (primary_thread_id, fallback_session_configured) = if let Some(ExecCommand::Resume(args)) = - command.as_ref() - { - if let Some(thread_id) = resolve_resume_thread_id(&client, &config, args).await? { - let response: ThreadResumeResponse = send_request_with_response( - &client, - ClientRequest::ThreadResume { - request_id: request_ids.next(), - params: thread_resume_params_from_config(&config, thread_id), - }, - "thread/resume", - ) - .await - .map_err(anyhow::Error::msg)?; - let session_configured = - session_configured_from_thread_resume_response(&response, &config) - .map_err(anyhow::Error::msg)?; - (session_configured.session_id, session_configured) - } else { - let response: ThreadStartResponse = send_request_with_response( - &client, - ClientRequest::ThreadStart { - request_id: request_ids.next(), - params: thread_start_params_from_config(&config), - }, - "thread/start", - ) - .await - .map_err(anyhow::Error::msg)?; - let session_configured = - session_configured_from_thread_start_response(&response, &config) - .map_err(anyhow::Error::msg)?; - (session_configured.session_id, session_configured) + // Handle resume/fork/start through app-server APIs so exec no longer reaches into + // rollout storage directly for normal bootstrap. + let (primary_thread_id, fallback_session_configured) = match command.as_ref() { + Some(ExecCommand::Resume(args)) => { + if let Some(thread_id) = resolve_resume_thread_id(&client, &config, args).await? { + let response: ThreadResumeResponse = send_request_with_response( + &client, + ClientRequest::ThreadResume { + request_id: request_ids.next(), + params: thread_resume_params_from_config(&config, thread_id), + }, + "thread/resume", + ) + .await + .map_err(anyhow::Error::msg)?; + let session_configured = + session_configured_from_thread_resume_response(&response, &config) + .map_err(anyhow::Error::msg)?; + (session_configured.session_id, session_configured) + } else { + let response: ThreadStartResponse = send_request_with_response( + &client, + ClientRequest::ThreadStart { + request_id: request_ids.next(), + params: thread_start_params_from_config(&config), + }, + "thread/start", + ) + .await + .map_err(anyhow::Error::msg)?; + let session_configured = + session_configured_from_thread_start_response(&response, &config) + .map_err(anyhow::Error::msg)?; + (session_configured.session_id, session_configured) + } + } + Some(ExecCommand::Review(_)) | None => { + if let Some(session_id) = fork_session_id.as_deref() { + let response: ThreadForkResponse = send_request_with_response( + &client, + ClientRequest::ThreadFork { + request_id: request_ids.next(), + params: thread_fork_params_from_config( + &config, session_id, /*path*/ None, + ), + }, + "thread/fork", + ) + .await + .map_err(anyhow::Error::msg)?; + let session_configured = + session_configured_from_thread_fork_response(&response, &config) + .map_err(anyhow::Error::msg)?; + (session_configured.session_id, session_configured) + } else { + let response: ThreadStartResponse = send_request_with_response( + &client, + ClientRequest::ThreadStart { + request_id: request_ids.next(), + params: thread_start_params_from_config(&config), + }, + "thread/start", + ) + .await + .map_err(anyhow::Error::msg)?; + let session_configured = + session_configured_from_thread_start_response(&response, &config) + .map_err(anyhow::Error::msg)?; + (session_configured.session_id, session_configured) + } } - } else { - let response: ThreadStartResponse = send_request_with_response( - &client, - ClientRequest::ThreadStart { - request_id: request_ids.next(), - params: thread_start_params_from_config(&config), - }, - "thread/start", - ) - .await - .map_err(anyhow::Error::msg)?; - let session_configured = session_configured_from_thread_start_response(&response, &config) - .map_err(anyhow::Error::msg)?; - (session_configured.session_id, session_configured) }; - let primary_thread_id_for_span = primary_thread_id.to_string(); // Use the start/resume response as the authoritative bootstrap payload. // Waiting for a later streamed `SessionConfigured` event adds up to 10s of @@ -1028,6 +1054,33 @@ fn approvals_reviewer_override_from_config( Some(config.approvals_reviewer.into()) } +fn thread_fork_params_from_config( + config: &Config, + thread_id: &str, + path: Option, +) -> ThreadForkParams { + let permissions = permissions_selection_from_config(config); + let sandbox = permissions.is_none().then(|| { + sandbox_mode_from_permission_profile( + &config.permissions.permission_profile(), + config.cwd.as_path(), + ) + }); + ThreadForkParams { + thread_id: thread_id.to_string(), + path, + model: config.model.clone(), + model_provider: Some(config.model_provider_id.clone()), + cwd: Some(config.cwd.to_string_lossy().to_string()), + approval_policy: Some(config.permissions.approval_policy.value().into()), + approvals_reviewer: approvals_reviewer_override_from_config(config), + sandbox: sandbox.flatten(), + permissions, + config: config_request_overrides_from_config(config), + ..ThreadForkParams::default() + } +} + async fn send_request_with_response( client: &InProcessAppServerClient, request: ClientRequest, @@ -1093,6 +1146,30 @@ fn session_configured_from_thread_resume_response( ) } +fn session_configured_from_thread_fork_response( + response: &ThreadForkResponse, + config: &Config, +) -> Result { + session_configured_from_thread_response( + &response.thread.id, + response.thread.name.clone(), + response.thread.path.clone(), + response.model.clone(), + response.model_provider.clone(), + response.service_tier, + response.approval_policy.to_core(), + response.approvals_reviewer.to_core(), + response + .permission_profile + .clone() + .map(Into::into) + .unwrap_or_else(|| config.permissions.permission_profile()), + response.active_permission_profile.clone().map(Into::into), + response.cwd.clone(), + response.reasoning_effort, + ) +} + fn review_target_to_api(target: ReviewTarget) -> ApiReviewTarget { match target { ReviewTarget::UncommittedChanges => ApiReviewTarget::UncommittedChanges, diff --git a/codex-rs/exec/src/lib_tests.rs b/codex-rs/exec/src/lib_tests.rs index 648d512689..9d72d7d5a7 100644 --- a/codex-rs/exec/src/lib_tests.rs +++ b/codex-rs/exec/src/lib_tests.rs @@ -408,6 +408,11 @@ async fn thread_lifecycle_params_include_legacy_sandbox_when_no_active_profile() let start_params = thread_start_params_from_config(&config); let resume_params = thread_resume_params_from_config(&config, "thread-id".to_string()); + let fork_params = thread_fork_params_from_config( + &config, + "67e55044-10b1-426f-9247-bb680e5fe0c8", + /*path*/ None, + ); assert_eq!(config.permissions.active_permission_profile(), None); assert_eq!( @@ -420,6 +425,43 @@ async fn thread_lifecycle_params_include_legacy_sandbox_when_no_active_profile() Some(codex_app_server_protocol::SandboxMode::DangerFullAccess) ); assert_eq!(resume_params.permissions, None); + assert_eq!( + fork_params.sandbox, + Some(codex_app_server_protocol::SandboxMode::DangerFullAccess) + ); + assert_eq!(fork_params.permissions, None); +} + +#[tokio::test] +async fn thread_fork_params_include_review_policy_when_auto_review_is_enabled() { + let codex_home = tempdir().expect("create temp codex home"); + let cwd = tempdir().expect("create temp cwd"); + let config = ConfigBuilder::default() + .codex_home(codex_home.path().to_path_buf()) + .harness_overrides(ConfigOverrides { + approvals_reviewer: Some(ApprovalsReviewer::AutoReview), + ..Default::default() + }) + .fallback_cwd(Some(cwd.path().to_path_buf())) + .build() + .await + .expect("build config for fork params"); + + let params = thread_fork_params_from_config( + &config, + "67e55044-10b1-426f-9247-bb680e5fe0c8", + /*path*/ None, + ); + + assert_eq!( + params.approvals_reviewer, + Some(codex_app_server_protocol::ApprovalsReviewer::AutoReview) + ); + assert_eq!(params.sandbox, None); + assert_eq!( + params.permissions, + permissions_selection_from_config(&config) + ); } #[tokio::test] @@ -498,3 +540,58 @@ fn sample_thread_start_response() -> ThreadStartResponse { reasoning_effort: None, } } + +#[tokio::test] +async fn session_configured_from_thread_fork_response_preserves_permission_profile() { + let codex_home = tempdir().expect("create temp codex home"); + let cwd = tempdir().expect("create temp cwd"); + let config = ConfigBuilder::default() + .codex_home(codex_home.path().to_path_buf()) + .fallback_cwd(Some(cwd.path().to_path_buf())) + .build() + .await + .expect("build config"); + let permission_profile = PermissionProfile::Disabled; + let response = ThreadForkResponse { + thread: codex_app_server_protocol::Thread { + id: "67e55044-10b1-426f-9247-bb680e5fe0c8".to_string(), + forked_from_id: Some("f6f10963-370f-4f42-8f3b-bb680e5fe0c8".to_string()), + preview: String::new(), + ephemeral: false, + model_provider: "openai".to_string(), + created_at: 0, + updated_at: 0, + status: codex_app_server_protocol::ThreadStatus::Idle, + path: Some(PathBuf::from("/tmp/fork-rollout.jsonl")), + cwd: test_path_buf("/tmp").abs(), + cli_version: "0.0.0".to_string(), + source: codex_app_server_protocol::SessionSource::Cli, + agent_nickname: None, + agent_role: None, + git_info: None, + name: Some("forked-thread".to_string()), + turns: vec![], + }, + model: "gpt-5.4".to_string(), + model_provider: "openai".to_string(), + service_tier: None, + cwd: test_path_buf("/tmp").abs(), + instruction_sources: Vec::new(), + approval_policy: codex_app_server_protocol::AskForApproval::OnRequest, + approvals_reviewer: codex_app_server_protocol::ApprovalsReviewer::AutoReview, + sandbox: codex_app_server_protocol::SandboxPolicy::WorkspaceWrite { + writable_roots: vec![], + network_access: false, + exclude_tmpdir_env_var: false, + exclude_slash_tmp: false, + }, + permission_profile: Some(permission_profile.clone().into()), + active_permission_profile: None, + reasoning_effort: None, + }; + + let event = session_configured_from_thread_fork_response(&response, &config) + .expect("build fork session configured event"); + + assert_eq!(event.permission_profile, permission_profile); +} diff --git a/codex-rs/exec/src/main.rs b/codex-rs/exec/src/main.rs index 79a681b146..f820319ce7 100644 --- a/codex-rs/exec/src/main.rs +++ b/codex-rs/exec/src/main.rs @@ -29,7 +29,10 @@ fn main() -> anyhow::Result<()> { arg0_dispatch_or_else(|arg0_paths: Arg0DispatchPaths| async move { let top_cli = TopCli::parse(); // Merge root-level overrides into inner CLI struct so downstream logic remains unchanged. - let mut inner = top_cli.inner; + let mut inner = match top_cli.inner.validate() { + Ok(inner) => inner, + Err(err) => err.exit(), + }; inner .config_overrides .raw_overrides diff --git a/codex-rs/exec/src/main_tests.rs b/codex-rs/exec/src/main_tests.rs index a9cb0ec633..f99b7ed251 100644 --- a/codex-rs/exec/src/main_tests.rs +++ b/codex-rs/exec/src/main_tests.rs @@ -35,3 +35,24 @@ fn top_cli_parses_resume_prompt_after_config_flag() { "reasoning_level=xhigh" ); } + +#[test] +fn top_cli_parses_fork_option_with_root_config() { + let cli = TopCli::parse_from([ + "codex-exec", + "--config", + "reasoning_level=xhigh", + "--fork", + "session-123", + "echo fork", + ]); + + assert_eq!(cli.inner.fork_session_id.as_deref(), Some("session-123")); + assert!(cli.inner.command.is_none()); + assert_eq!(cli.inner.prompt.as_deref(), Some("echo fork")); + assert_eq!(cli.config_overrides.raw_overrides.len(), 1); + assert_eq!( + cli.config_overrides.raw_overrides[0], + "reasoning_level=xhigh" + ); +} diff --git a/codex-rs/exec/tests/suite/fork.rs b/codex-rs/exec/tests/suite/fork.rs new file mode 100644 index 0000000000..e706b2b558 --- /dev/null +++ b/codex-rs/exec/tests/suite/fork.rs @@ -0,0 +1,173 @@ +#![allow(clippy::unwrap_used, clippy::expect_used)] + +use anyhow::Context; +use codex_utils_cargo_bin::find_resource; +use core_test_support::test_codex_exec::test_codex_exec; +use serde_json::Value; +use std::string::ToString; +use uuid::Uuid; +use walkdir::WalkDir; + +/// Utility: scan the sessions dir for a rollout file that contains `marker` +/// in any response_item.message.content entry. Returns the absolute path. +fn find_session_file_containing_marker( + sessions_dir: &std::path::Path, + marker: &str, +) -> Option { + for entry in WalkDir::new(sessions_dir) { + let entry = match entry { + Ok(e) => e, + Err(_) => continue, + }; + if !entry.file_type().is_file() { + continue; + } + if !entry.file_name().to_string_lossy().ends_with(".jsonl") { + continue; + } + let path = entry.path(); + let Ok(content) = std::fs::read_to_string(path) else { + continue; + }; + // Skip the first meta line and scan remaining JSONL entries. + let mut lines = content.lines(); + if lines.next().is_none() { + continue; + } + for line in lines { + if line.trim().is_empty() { + continue; + } + let Ok(item): Result = serde_json::from_str(line) else { + continue; + }; + if item.get("type").and_then(|t| t.as_str()) == Some("response_item") + && let Some(payload) = item.get("payload") + && payload.get("type").and_then(|t| t.as_str()) == Some("message") + && payload + .get("content") + .map(ToString::to_string) + .unwrap_or_default() + .contains(marker) + { + return Some(path.to_path_buf()); + } + } + } + None +} + +/// Extract the conversation UUID from the first SessionMeta line in the rollout file. +fn extract_conversation_id(path: &std::path::Path) -> String { + let content = std::fs::read_to_string(path).unwrap(); + let mut lines = content.lines(); + let meta_line = lines.next().expect("missing meta line"); + let meta: Value = serde_json::from_str(meta_line).expect("invalid meta json"); + meta.get("payload") + .and_then(|p| p.get("id")) + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string() +} + +fn extract_forked_from_id(path: &std::path::Path) -> Option { + let content = std::fs::read_to_string(path).unwrap(); + let mut lines = content.lines(); + let meta_line = lines.next().expect("missing meta line"); + let meta: Value = serde_json::from_str(meta_line).expect("invalid meta json"); + meta.get("payload") + .and_then(|payload| payload.get("forked_from_id")) + .and_then(Value::as_str) + .map(ToString::to_string) +} + +fn rollout_contains_fork_reference(path: &std::path::Path) -> bool { + extract_fork_reference(path).is_some() +} + +fn extract_fork_reference(path: &std::path::Path) -> Option<(String, usize)> { + let Ok(content) = std::fs::read_to_string(path) else { + return None; + }; + content.lines().skip(1).find_map(|line| { + let item = serde_json::from_str::(line).ok()?; + if item.get("type").and_then(Value::as_str) != Some("fork_reference") { + return None; + } + let payload = item.get("payload")?; + let rollout_path = payload.get("rollout_path")?.as_str()?.to_string(); + let nth_user_message = payload + .get("nth_user_message")? + .as_u64() + .and_then(|value| usize::try_from(value).ok())?; + Some((rollout_path, nth_user_message)) + }) +} + +fn exec_fixture() -> anyhow::Result { + Ok(find_resource!("tests/fixtures/cli_responses_fixture.sse")?) +} + +#[test] +fn exec_fork_by_id_creates_new_session_with_copied_history() -> anyhow::Result<()> { + let test = test_codex_exec(); + let fixture = exec_fixture()?; + + let marker = format!("fork-base-{}", Uuid::new_v4()); + let prompt = format!("echo {marker}"); + + test.cmd() + .env("CODEX_RS_SSE_FIXTURE", &fixture) + .env("OPENAI_BASE_URL", "http://unused.local") + .arg("--skip-git-repo-check") + .arg(&prompt) + .assert() + .success(); + + let sessions_dir = test.home_path().join("sessions"); + let original_path = find_session_file_containing_marker(&sessions_dir, &marker) + .context("no session file found after first run")?; + let session_id = extract_conversation_id(&original_path); + + let marker2 = format!("fork-follow-up-{}", Uuid::new_v4()); + let prompt2 = format!("echo {marker2}"); + + test.cmd() + .env("CODEX_RS_SSE_FIXTURE", &fixture) + .env("OPENAI_BASE_URL", "http://unused.local") + .arg("--skip-git-repo-check") + .arg("--fork") + .arg(&session_id) + .arg(&prompt2) + .assert() + .success(); + + let forked_path = find_session_file_containing_marker(&sessions_dir, &marker2) + .context("no forked session file found for second marker")?; + + assert_ne!( + forked_path, original_path, + "fork should create a new session file" + ); + + let forked_content = std::fs::read_to_string(&forked_path)?; + assert_eq!( + extract_forked_from_id(&forked_path).as_deref(), + Some(session_id.as_str()) + ); + let fork_reference = + extract_fork_reference(&forked_path).context("forked rollout should record a reference")?; + assert_eq!(fork_reference.0, original_path.to_string_lossy().as_ref()); + assert_eq!(fork_reference.1, usize::MAX); + assert!(rollout_contains_fork_reference(&forked_path)); + assert!(forked_content.contains(&marker2)); + + let original_content = std::fs::read_to_string(&original_path)?; + assert!(original_content.contains(&marker)); + assert!( + !original_content.contains(&marker2), + "original session should not receive the forked prompt" + ); + + Ok(()) +} diff --git a/codex-rs/exec/tests/suite/mod.rs b/codex-rs/exec/tests/suite/mod.rs index c6fa0f9fde..5513badc37 100644 --- a/codex-rs/exec/tests/suite/mod.rs +++ b/codex-rs/exec/tests/suite/mod.rs @@ -3,6 +3,7 @@ mod add_dir; mod apply_patch; mod auth_env; mod ephemeral; +mod fork; mod mcp_required_exit; mod originator; mod output_schema;