mirror of
https://github.com/openai/codex.git
synced 2026-05-09 13:52:41 +00:00
Compare commits
2 Commits
eric/codex
...
dev/friel/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d1ec95ec0e | ||
|
|
c2db9386aa |
@@ -50,7 +50,8 @@ The legacy `notify` setting is deprecated and will be removed in a future releas
|
||||
|
||||
### `codex exec` to run Codex programmatically/non-interactively
|
||||
|
||||
To run Codex non-interactively, run `codex exec PROMPT` (you can also pass the prompt via `stdin`) and Codex will work on your task until it decides that it is done and exits. If you provide both a prompt argument and piped stdin, Codex appends stdin as a `<stdin>` block after the prompt so patterns like `echo "my output" | codex exec "Summarize this concisely"` work naturally. Output is printed to the terminal directly. You can set the `RUST_LOG` environment variable to see more about what's going on.
|
||||
To run Codex non-interactively, run `codex exec PROMPT` (you can also pass the prompt via `stdin`) and Codex will work on your task until it decides that it is done and exits. Output is printed to the terminal directly. You can set the `RUST_LOG` environment variable to see more about what's going on.
|
||||
Use `codex exec --fork <SESSION_ID> PROMPT` to fork an existing session without launching the interactive picker/UI.
|
||||
Use `codex exec --ephemeral ...` to run without persisting session rollout files to disk.
|
||||
|
||||
### Experimenting with the Codex Sandbox
|
||||
|
||||
@@ -764,12 +764,16 @@ async fn cli_main(arg0_paths: Arg0DispatchPaths) -> anyhow::Result<()> {
|
||||
.await?;
|
||||
handle_app_exit(exit_info)?;
|
||||
}
|
||||
Some(Subcommand::Exec(mut exec_cli)) => {
|
||||
Some(Subcommand::Exec(exec_cli)) => {
|
||||
reject_remote_mode_for_subcommand(
|
||||
root_remote.as_deref(),
|
||||
root_remote_auth_token_env.as_deref(),
|
||||
"exec",
|
||||
)?;
|
||||
let mut exec_cli = match exec_cli.validate() {
|
||||
Ok(exec_cli) => exec_cli,
|
||||
Err(err) => err.exit(),
|
||||
};
|
||||
exec_cli
|
||||
.shared
|
||||
.inherit_exec_root_options(&interactive.shared);
|
||||
@@ -1790,6 +1794,40 @@ mod tests {
|
||||
assert_eq!(args.session_id.as_deref(), Some("session-123"));
|
||||
assert_eq!(args.prompt.as_deref(), Some("re-review"));
|
||||
}
|
||||
/// `codex exec --json --fork <SESSION_ID> <PROMPT>` must parse: the value after `--fork`
/// becomes the fork session id and the trailing positional becomes the prompt, with no
/// nested exec command selected.
#[test]
fn exec_fork_accepts_prompt_positional() {
    let argv = ["codex", "exec", "--json", "--fork", "session-123", "2+2"];
    let parsed = MultitoolCli::try_parse_from(argv).expect("parse should succeed");

    let exec_args = match parsed.subcommand {
        Some(Subcommand::Exec(exec_args)) => exec_args,
        _ => panic!("expected exec subcommand"),
    };

    assert_eq!(exec_args.fork_session_id.as_deref(), Some("session-123"));
    assert!(exec_args.command.is_none());
    assert_eq!(exec_args.prompt.as_deref(), Some("2+2"));
}
|
||||
|
||||
/// Combining `--fork` with the `resume` subcommand is accepted by the argument parser
/// but must be rejected by the follow-up `validate()` step.
#[test]
fn exec_fork_conflicts_with_resume_subcommand() {
    let argv = ["codex", "exec", "--fork", "session-123", "resume"];
    let parsed = MultitoolCli::try_parse_from(argv).expect("parse should succeed");

    let exec_args = match parsed.subcommand {
        Some(Subcommand::Exec(exec_args)) => exec_args,
        _ => panic!("expected exec subcommand"),
    };

    assert!(exec_args.validate().is_err());
}
|
||||
|
||||
#[test]
|
||||
fn dangerous_bypass_conflicts_with_approval_policy() {
|
||||
@@ -2028,17 +2066,14 @@ mod tests {
|
||||
update_action: None,
|
||||
exit_reason: ExitReason::UserRequested,
|
||||
};
|
||||
let lines = format_exit_messages(exit_info, /*color_enabled*/ false);
|
||||
let lines = format_exit_messages(exit_info, false);
|
||||
assert!(lines.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_exit_messages_includes_resume_hint_without_color() {
|
||||
let exit_info = sample_exit_info(
|
||||
Some("123e4567-e89b-12d3-a456-426614174000"),
|
||||
/*thread_name*/ None,
|
||||
);
|
||||
let lines = format_exit_messages(exit_info, /*color_enabled*/ false);
|
||||
let exit_info = sample_exit_info(Some("123e4567-e89b-12d3-a456-426614174000"), None);
|
||||
let lines = format_exit_messages(exit_info, false);
|
||||
assert_eq!(
|
||||
lines,
|
||||
vec![
|
||||
@@ -2051,11 +2086,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn format_exit_messages_applies_color_when_enabled() {
|
||||
let exit_info = sample_exit_info(
|
||||
Some("123e4567-e89b-12d3-a456-426614174000"),
|
||||
/*thread_name*/ None,
|
||||
);
|
||||
let lines = format_exit_messages(exit_info, /*color_enabled*/ true);
|
||||
let exit_info = sample_exit_info(Some("123e4567-e89b-12d3-a456-426614174000"), None);
|
||||
let lines = format_exit_messages(exit_info, true);
|
||||
assert_eq!(lines.len(), 2);
|
||||
assert!(lines[1].contains("\u{1b}[36m"));
|
||||
}
|
||||
@@ -2066,7 +2098,7 @@ mod tests {
|
||||
Some("123e4567-e89b-12d3-a456-426614174000"),
|
||||
Some("my-thread"),
|
||||
);
|
||||
let lines = format_exit_messages(exit_info, /*color_enabled*/ false);
|
||||
let lines = format_exit_messages(exit_info, false);
|
||||
assert_eq!(
|
||||
lines,
|
||||
vec![
|
||||
@@ -2291,12 +2323,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn reject_remote_mode_for_non_interactive_subcommands() {
|
||||
let err = reject_remote_mode_for_subcommand(
|
||||
Some("127.0.0.1:4500"),
|
||||
/*remote_auth_token_env*/ None,
|
||||
"exec",
|
||||
)
|
||||
.expect_err("non-interactive subcommands should reject --remote");
|
||||
let err = reject_remote_mode_for_subcommand(Some("127.0.0.1:4500"), None, "exec")
|
||||
.expect_err("non-interactive subcommands should reject --remote");
|
||||
assert!(
|
||||
err.to_string()
|
||||
.contains("only supported for interactive TUI commands")
|
||||
@@ -2305,12 +2333,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn reject_remote_auth_token_env_for_non_interactive_subcommands() {
|
||||
let err = reject_remote_mode_for_subcommand(
|
||||
/*remote*/ None,
|
||||
Some("CODEX_REMOTE_AUTH_TOKEN"),
|
||||
"exec",
|
||||
)
|
||||
.expect_err("non-interactive subcommands should reject --remote-auth-token-env");
|
||||
let err = reject_remote_mode_for_subcommand(None, Some("CODEX_REMOTE_AUTH_TOKEN"), "exec")
|
||||
.expect_err("non-interactive subcommands should reject --remote-auth-token-env");
|
||||
assert!(
|
||||
err.to_string()
|
||||
.contains("only supported for interactive TUI commands")
|
||||
@@ -2324,7 +2348,7 @@ mod tests {
|
||||
out_dir: PathBuf::from("/tmp/out"),
|
||||
});
|
||||
let err = reject_remote_mode_for_app_server_subcommand(
|
||||
/*remote*/ None,
|
||||
None,
|
||||
Some("CODEX_REMOTE_AUTH_TOKEN"),
|
||||
Some(&subcommand),
|
||||
)
|
||||
|
||||
@@ -31,6 +31,7 @@ pub struct ResponsesClient<T: HttpTransport> {
|
||||
#[derive(Default)]
|
||||
pub struct ResponsesOptions {
|
||||
pub conversation_id: Option<String>,
|
||||
pub prompt_cache_key: Option<String>,
|
||||
pub session_source: Option<SessionSource>,
|
||||
pub extra_headers: HeaderMap,
|
||||
pub compression: Compression,
|
||||
@@ -73,6 +74,7 @@ impl<T: HttpTransport> ResponsesClient<T> {
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
let ResponsesOptions {
|
||||
conversation_id,
|
||||
prompt_cache_key,
|
||||
session_source,
|
||||
extra_headers,
|
||||
compression,
|
||||
@@ -89,7 +91,8 @@ impl<T: HttpTransport> ResponsesClient<T> {
|
||||
if let Some(ref conv_id) = conversation_id {
|
||||
insert_header(&mut headers, "x-client-request-id", conv_id);
|
||||
}
|
||||
headers.extend(build_conversation_headers(conversation_id));
|
||||
let session_id = prompt_cache_key.or_else(|| conversation_id.clone());
|
||||
headers.extend(build_conversation_headers(session_id));
|
||||
if let Some(subagent) = subagent_header(&session_source) {
|
||||
insert_header(&mut headers, "x-openai-subagent", &subagent);
|
||||
}
|
||||
|
||||
@@ -445,6 +445,7 @@ async fn azure_default_store_attaches_ids_and_headers() -> Result<()> {
|
||||
request,
|
||||
ResponsesOptions {
|
||||
conversation_id: Some("sess_123".into()),
|
||||
prompt_cache_key: Some("prompt_cache_123".into()),
|
||||
session_source: Some(SessionSource::SubAgent(SubAgentSource::Review)),
|
||||
extra_headers,
|
||||
compression: Compression::None,
|
||||
@@ -459,6 +460,12 @@ async fn azure_default_store_attaches_ids_and_headers() -> Result<()> {
|
||||
|
||||
assert_eq!(
|
||||
req.headers.get("session_id").and_then(|v| v.to_str().ok()),
|
||||
Some("prompt_cache_123")
|
||||
);
|
||||
assert_eq!(
|
||||
req.headers
|
||||
.get("x-client-request-id")
|
||||
.and_then(|v| v.to_str().ok()),
|
||||
Some("sess_123")
|
||||
);
|
||||
assert_eq!(
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
//! `codex-core`.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
@@ -315,6 +316,34 @@ impl McpConnectionManager {
|
||||
failures
|
||||
}
|
||||
|
||||
pub fn required_startup_failures_future(
|
||||
&self,
|
||||
required_servers: Vec<String>,
|
||||
) -> impl Future<Output = Vec<McpStartupFailure>> + Send + 'static {
|
||||
let clients = self.clients.clone();
|
||||
async move {
|
||||
let mut failures = Vec::new();
|
||||
for server_name in required_servers {
|
||||
let Some(async_managed_client) = clients.get(&server_name).cloned() else {
|
||||
failures.push(McpStartupFailure {
|
||||
server: server_name.clone(),
|
||||
error: format!("required MCP server `{server_name}` was not initialized"),
|
||||
});
|
||||
continue;
|
||||
};
|
||||
|
||||
match async_managed_client.client().await {
|
||||
Ok(_) => {}
|
||||
Err(error) => failures.push(McpStartupFailure {
|
||||
server: server_name.clone(),
|
||||
error: startup_outcome_error_message(error),
|
||||
}),
|
||||
}
|
||||
}
|
||||
failures
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a single map that contains all tools. Each key is the
|
||||
/// fully-qualified name for the tool.
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
@@ -329,6 +358,22 @@ impl McpConnectionManager {
|
||||
qualify_tools(tools)
|
||||
}
|
||||
|
||||
pub fn list_all_tools_future(
|
||||
&self,
|
||||
) -> impl Future<Output = HashMap<String, ToolInfo>> + Send + 'static {
|
||||
let clients = self.clients.values().cloned().collect::<Vec<_>>();
|
||||
async move {
|
||||
let mut tools = Vec::new();
|
||||
for managed_client in clients {
|
||||
let Some(server_tools) = managed_client.listed_tools().await else {
|
||||
continue;
|
||||
};
|
||||
tools.extend(server_tools);
|
||||
}
|
||||
qualify_tools(tools)
|
||||
}
|
||||
}
|
||||
|
||||
/// Force-refresh codex apps tools by bypassing the in-process cache.
|
||||
///
|
||||
/// On success, the refreshed tools replace the cache contents and the
|
||||
|
||||
@@ -5,10 +5,12 @@ use crate::agent::role::DEFAULT_ROLE_NAME;
|
||||
use crate::agent::role::resolve_role_config;
|
||||
use crate::agent::status::is_final;
|
||||
use crate::codex_thread::ThreadConfigSnapshot;
|
||||
use crate::inherited_thread_state::InheritedThreadState;
|
||||
use crate::session::emit_subagent_session_started;
|
||||
use crate::session_prefix::format_subagent_context_line;
|
||||
use crate::session_prefix::format_subagent_notification_message;
|
||||
use crate::shell_snapshot::ShellSnapshot;
|
||||
use crate::state::McpToolSnapshot;
|
||||
use crate::thread_manager::ResumeThreadWithHistoryOptions;
|
||||
use crate::thread_manager::ThreadManagerState;
|
||||
use crate::thread_rollout_truncation::truncate_rollout_to_last_n_fork_turns;
|
||||
@@ -42,6 +44,8 @@ use tracing::warn;
|
||||
|
||||
const AGENT_NAMES: &str = include_str!("agent_names.txt");
|
||||
const ROOT_LAST_TASK_MESSAGE: &str = "Main thread";
|
||||
/// Name of the environment variable read by `fork_previous_response_id_enabled`; a truthy
/// value there is required before `parent_response_continuation_for_source` returns anything.
const CODEX_EXPERIMENTAL_FORK_PREVIOUS_RESPONSE_ID_ENV: &str =
|
||||
"CODEX_EXPERIMENTAL_FORK_PREVIOUS_RESPONSE_ID";
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub(crate) enum SpawnAgentForkMode {
|
||||
@@ -219,6 +223,18 @@ impl AgentControl {
|
||||
// The same `AgentControl` is sent to spawn the thread.
|
||||
let new_thread = match (session_source, options.fork_mode.as_ref()) {
|
||||
(Some(session_source), Some(_)) => {
|
||||
let inherited_thread_state = InheritedThreadState::builder()
|
||||
.prompt_cache_key(
|
||||
parent_prompt_cache_key_for_source(&state, Some(&session_source)).await,
|
||||
)
|
||||
.response_continuation(
|
||||
parent_response_continuation_for_source(&state, Some(&session_source))
|
||||
.await,
|
||||
)
|
||||
.mcp_tool_snapshot(
|
||||
parent_mcp_tool_snapshot_for_source(&state, Some(&session_source)).await,
|
||||
)
|
||||
.build();
|
||||
self.spawn_forked_thread(
|
||||
&state,
|
||||
config,
|
||||
@@ -226,6 +242,7 @@ impl AgentControl {
|
||||
&options,
|
||||
inherited_shell_snapshot,
|
||||
inherited_exec_policy,
|
||||
inherited_thread_state,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
@@ -240,6 +257,7 @@ impl AgentControl {
|
||||
inherited_shell_snapshot,
|
||||
inherited_exec_policy,
|
||||
options.environments.clone(),
|
||||
Default::default(),
|
||||
)
|
||||
.await?
|
||||
}
|
||||
@@ -325,6 +343,7 @@ impl AgentControl {
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn spawn_forked_thread(
|
||||
&self,
|
||||
state: &Arc<ThreadManagerState>,
|
||||
@@ -333,6 +352,7 @@ impl AgentControl {
|
||||
options: &SpawnAgentOptions,
|
||||
inherited_shell_snapshot: Option<Arc<ShellSnapshot>>,
|
||||
inherited_exec_policy: Option<Arc<crate::exec_policy::ExecPolicyManager>>,
|
||||
inherited_thread_state: InheritedThreadState,
|
||||
) -> CodexResult<crate::thread_manager::NewThread> {
|
||||
if options.fork_parent_spawn_call_id.is_none() {
|
||||
return Err(CodexErr::Fatal(
|
||||
@@ -380,45 +400,59 @@ impl AgentControl {
|
||||
))
|
||||
})?;
|
||||
|
||||
let mut forked_rollout_items = parent_history.items;
|
||||
if let SpawnAgentForkMode::LastNTurns(last_n_turns) = fork_mode {
|
||||
forked_rollout_items =
|
||||
truncate_rollout_to_last_n_fork_turns(&forked_rollout_items, *last_n_turns);
|
||||
}
|
||||
// MultiAgentV2 root/subagent usage hints are injected as standalone developer
|
||||
// messages at thread start. When forking history, drop hints from the parent
|
||||
// so the child gets a fresh hint that matches its own session source/config.
|
||||
let multi_agent_v2_usage_hint_texts_to_filter: Vec<String> =
|
||||
if let Some(parent_thread) = parent_thread.as_ref() {
|
||||
parent_thread
|
||||
.codex
|
||||
.session
|
||||
.configured_multi_agent_v2_usage_hint_texts()
|
||||
.await
|
||||
} else if config.features.enabled(Feature::MultiAgentV2) {
|
||||
[
|
||||
config.multi_agent_v2.root_agent_usage_hint_text.clone(),
|
||||
config.multi_agent_v2.subagent_usage_hint_text.clone(),
|
||||
]
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
forked_rollout_items.retain(|item| {
|
||||
if let RolloutItem::ResponseItem(ResponseItem::Message { role, content, .. }) = item
|
||||
&& role == "developer"
|
||||
&& let [ContentItem::InputText { text }] = content.as_slice()
|
||||
&& multi_agent_v2_usage_hint_texts_to_filter
|
||||
.iter()
|
||||
.any(|usage_hint_text| usage_hint_text == text)
|
||||
{
|
||||
return false;
|
||||
let response_continuation = inherited_thread_state.response_continuation();
|
||||
let use_response_continuation_baseline =
|
||||
response_continuation.is_some() && matches!(fork_mode, SpawnAgentForkMode::FullHistory);
|
||||
let mut forked_rollout_items = if let (Some(response_continuation), true) =
|
||||
(&response_continuation, use_response_continuation_baseline)
|
||||
{
|
||||
previous_response_fork_rollout_items(
|
||||
parent_history.items,
|
||||
response_continuation.fork_baseline_input(),
|
||||
)
|
||||
} else {
|
||||
let mut items = parent_history.items;
|
||||
if let SpawnAgentForkMode::LastNTurns(last_n_turns) = fork_mode {
|
||||
items = truncate_rollout_to_last_n_fork_turns(&items, *last_n_turns);
|
||||
}
|
||||
items
|
||||
};
|
||||
if !use_response_continuation_baseline {
|
||||
// MultiAgentV2 root/subagent usage hints are injected as standalone developer
|
||||
// messages at thread start. When forking history, drop hints from the parent
|
||||
// so the child gets a fresh hint that matches its own session source/config.
|
||||
let multi_agent_v2_usage_hint_texts_to_filter: Vec<String> =
|
||||
if let Some(parent_thread) = parent_thread.as_ref() {
|
||||
parent_thread
|
||||
.codex
|
||||
.session
|
||||
.configured_multi_agent_v2_usage_hint_texts()
|
||||
.await
|
||||
} else if config.features.enabled(Feature::MultiAgentV2) {
|
||||
[
|
||||
config.multi_agent_v2.root_agent_usage_hint_text.clone(),
|
||||
config.multi_agent_v2.subagent_usage_hint_text.clone(),
|
||||
]
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
forked_rollout_items.retain(|item| {
|
||||
if let RolloutItem::ResponseItem(ResponseItem::Message { role, content, .. }) = item
|
||||
&& role == "developer"
|
||||
&& let [ContentItem::InputText { text }] = content.as_slice()
|
||||
&& multi_agent_v2_usage_hint_texts_to_filter
|
||||
.iter()
|
||||
.any(|usage_hint_text| usage_hint_text == text)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
keep_forked_rollout_item(item)
|
||||
});
|
||||
keep_forked_rollout_item(item)
|
||||
});
|
||||
}
|
||||
|
||||
state
|
||||
.fork_thread_with_source(
|
||||
@@ -430,6 +464,7 @@ impl AgentControl {
|
||||
inherited_shell_snapshot,
|
||||
inherited_exec_policy,
|
||||
options.environments.clone(),
|
||||
inherited_thread_state,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -586,6 +621,7 @@ impl AgentControl {
|
||||
session_source,
|
||||
inherited_shell_snapshot,
|
||||
inherited_exec_policy,
|
||||
inherited_thread_state: Default::default(),
|
||||
})
|
||||
.await?;
|
||||
let mut agent_metadata = agent_metadata;
|
||||
@@ -1190,6 +1226,114 @@ impl AgentControl {
|
||||
}
|
||||
}
|
||||
|
||||
async fn parent_prompt_cache_key_for_source(
|
||||
state: &Arc<ThreadManagerState>,
|
||||
session_source: Option<&SessionSource>,
|
||||
) -> Option<ThreadId> {
|
||||
let Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
|
||||
parent_thread_id, ..
|
||||
})) = session_source
|
||||
else {
|
||||
return None;
|
||||
};
|
||||
|
||||
state
|
||||
.get_thread(*parent_thread_id)
|
||||
.await
|
||||
.ok()
|
||||
.map(|parent_thread| parent_thread.codex.session.prompt_cache_key())
|
||||
}
|
||||
|
||||
fn previous_response_fork_rollout_items(
|
||||
source_items: Vec<RolloutItem>,
|
||||
baseline_input: Vec<ResponseItem>,
|
||||
) -> Vec<RolloutItem> {
|
||||
let source_session_meta = source_items.iter().find_map(|item| match item {
|
||||
RolloutItem::SessionMeta(meta) => Some(meta.clone()),
|
||||
RolloutItem::ResponseItem(_)
|
||||
| RolloutItem::Compacted(_)
|
||||
| RolloutItem::TurnContext(_)
|
||||
| RolloutItem::EventMsg(_) => None,
|
||||
});
|
||||
let latest_turn_context = source_items.iter().rev().find_map(|item| match item {
|
||||
RolloutItem::TurnContext(turn_context) => Some(turn_context.clone()),
|
||||
RolloutItem::ResponseItem(_)
|
||||
| RolloutItem::Compacted(_)
|
||||
| RolloutItem::SessionMeta(_)
|
||||
| RolloutItem::EventMsg(_) => None,
|
||||
});
|
||||
|
||||
source_session_meta
|
||||
.into_iter()
|
||||
.map(RolloutItem::SessionMeta)
|
||||
.chain(baseline_input.into_iter().map(RolloutItem::ResponseItem))
|
||||
.chain(
|
||||
latest_turn_context
|
||||
.into_iter()
|
||||
.map(RolloutItem::TurnContext),
|
||||
)
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn parent_mcp_tool_snapshot_for_source(
|
||||
state: &Arc<ThreadManagerState>,
|
||||
session_source: Option<&SessionSource>,
|
||||
) -> Option<McpToolSnapshot> {
|
||||
let Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
|
||||
parent_thread_id, ..
|
||||
})) = session_source
|
||||
else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let parent_thread = state.get_thread(*parent_thread_id).await.ok()?;
|
||||
let list_all_tools = {
|
||||
let mcp_connection_manager = parent_thread
|
||||
.codex
|
||||
.session
|
||||
.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await;
|
||||
mcp_connection_manager.list_all_tools_future()
|
||||
};
|
||||
let tools = list_all_tools.await;
|
||||
Some(McpToolSnapshot { tools })
|
||||
}
|
||||
|
||||
fn fork_previous_response_id_enabled() -> bool {
|
||||
std::env::var(CODEX_EXPERIMENTAL_FORK_PREVIOUS_RESPONSE_ID_ENV)
|
||||
.is_ok_and(|value| fork_previous_response_id_value_enabled(&value))
|
||||
}
|
||||
|
||||
/// Returns `true` when `value` is one of the accepted truthy spellings: `1`, `true`,
/// `yes` or `on`, compared ASCII case-insensitively.
///
/// The match is exact apart from ASCII case: surrounding whitespace, the empty string
/// and any other text (for example `enabled`) are treated as "not enabled".
fn fork_previous_response_id_value_enabled(value: &str) -> bool {
    // `eq_ignore_ascii_case` compares in place, avoiding the temporary lowercased
    // `String` that `to_ascii_lowercase()` would allocate on every call.
    ["1", "true", "yes", "on"]
        .iter()
        .any(|truthy| value.eq_ignore_ascii_case(truthy))
}
|
||||
|
||||
async fn parent_response_continuation_for_source(
|
||||
state: &Arc<ThreadManagerState>,
|
||||
session_source: Option<&SessionSource>,
|
||||
) -> Option<crate::client::ResponseContinuation> {
|
||||
if !fork_previous_response_id_enabled() {
|
||||
return None;
|
||||
}
|
||||
let Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
|
||||
parent_thread_id, ..
|
||||
})) = session_source
|
||||
else {
|
||||
return None;
|
||||
};
|
||||
|
||||
state
|
||||
.get_thread(*parent_thread_id)
|
||||
.await
|
||||
.ok()
|
||||
.and_then(|parent_thread| parent_thread.codex.session.response_continuation_for_fork())
|
||||
}
|
||||
|
||||
fn thread_spawn_parent_thread_id(session_source: &SessionSource) -> Option<ThreadId> {
|
||||
match session_source {
|
||||
SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
|
||||
|
||||
@@ -8,6 +8,8 @@ use crate::config::ConfigBuilder;
|
||||
use crate::context::ContextualUserFragment;
|
||||
use crate::context::SubagentNotification;
|
||||
use assert_matches::assert_matches;
|
||||
use codex_config::types::McpServerConfig;
|
||||
use codex_config::types::McpServerTransportConfig;
|
||||
use codex_features::Feature;
|
||||
use codex_login::CodexAuth;
|
||||
use codex_protocol::AgentPath;
|
||||
@@ -28,6 +30,12 @@ use codex_thread_store::ArchiveThreadParams;
|
||||
use codex_thread_store::LocalThreadStore;
|
||||
use codex_thread_store::LocalThreadStoreConfig;
|
||||
use codex_thread_store::ThreadStore;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::mount_sse_once;
|
||||
use core_test_support::responses::namespace_child_tool;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::Duration;
|
||||
@@ -71,6 +79,57 @@ fn assistant_message(text: &str, phase: Option<MessagePhase>) -> ResponseItem {
|
||||
}
|
||||
}
|
||||
|
||||
/// Pins the spellings `fork_previous_response_id_value_enabled` accepts and rejects.
#[test]
fn fork_previous_response_id_env_value_parses_truthy_values() {
    const TRUTHY: [&str; 5] = ["1", "true", "TRUE", "yes", "on"];
    const FALSY: [&str; 6] = ["", "0", "false", "off", "no", "enabled"];

    for value in TRUTHY {
        let enabled = fork_previous_response_id_value_enabled(value);
        assert!(enabled, "{value} should enable previous response forking");
    }

    for value in FALSY {
        let enabled = fork_previous_response_id_value_enabled(value);
        assert!(
            !enabled,
            "{value} should not enable previous response forking"
        );
    }
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn previous_response_fork_rollout_items_preserve_latest_turn_context() {
|
||||
let harness = AgentControlHarness::new().await;
|
||||
let (_thread_id, owner_thread) = harness.start_thread().await;
|
||||
let owner_turn = owner_thread.codex.session.new_default_turn().await;
|
||||
let mut first_turn_context = owner_turn.to_turn_context_item();
|
||||
first_turn_context.model = "first-model".to_string();
|
||||
let mut latest_turn_context = first_turn_context.clone();
|
||||
latest_turn_context.model = "latest-model".to_string();
|
||||
|
||||
let baseline_item = assistant_message(
|
||||
"parent final from previous response",
|
||||
Some(MessagePhase::FinalAnswer),
|
||||
);
|
||||
let items = previous_response_fork_rollout_items(
|
||||
vec![
|
||||
RolloutItem::TurnContext(first_turn_context),
|
||||
RolloutItem::ResponseItem(assistant_message(
|
||||
"parent rollout item should not be copied",
|
||||
Some(MessagePhase::FinalAnswer),
|
||||
)),
|
||||
RolloutItem::TurnContext(latest_turn_context.clone()),
|
||||
],
|
||||
vec![baseline_item.clone()],
|
||||
);
|
||||
|
||||
assert_eq!(items.len(), 2);
|
||||
assert_matches!(&items[0], RolloutItem::ResponseItem(item) if *item == baseline_item);
|
||||
assert_matches!(
|
||||
&items[1],
|
||||
RolloutItem::TurnContext(turn_context) if turn_context.model == latest_turn_context.model
|
||||
);
|
||||
}
|
||||
|
||||
fn spawn_agent_call(call_id: &str) -> ResponseItem {
|
||||
ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
@@ -703,6 +762,39 @@ async fn spawn_agent_can_fork_parent_thread_history_with_sanitized_items() {
|
||||
.await
|
||||
.expect("child thread should be registered");
|
||||
assert_ne!(child_thread_id, parent_thread_id);
|
||||
assert_eq!(
|
||||
child_thread.codex.session.prompt_cache_key(),
|
||||
parent_thread.codex.session.prompt_cache_key(),
|
||||
);
|
||||
assert!(!Arc::ptr_eq(
|
||||
&child_thread.codex.session.services.mcp_connection_manager,
|
||||
&parent_thread.codex.session.services.mcp_connection_manager,
|
||||
));
|
||||
let mcp_tool_snapshot = child_thread
|
||||
.codex
|
||||
.session
|
||||
.services
|
||||
.mcp_tool_snapshot
|
||||
.lock()
|
||||
.await
|
||||
.clone()
|
||||
.expect("forked child should inherit an MCP tool snapshot");
|
||||
let list_all_tools = {
|
||||
let mcp_connection_manager = parent_thread
|
||||
.codex
|
||||
.session
|
||||
.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await;
|
||||
mcp_connection_manager.list_all_tools_future()
|
||||
};
|
||||
let parent_mcp_tools = list_all_tools.await;
|
||||
let mut snapshot_tool_names = mcp_tool_snapshot.tools.keys().cloned().collect::<Vec<_>>();
|
||||
snapshot_tool_names.sort();
|
||||
let mut parent_tool_names = parent_mcp_tools.keys().cloned().collect::<Vec<_>>();
|
||||
parent_tool_names.sort();
|
||||
assert_eq!(snapshot_tool_names, parent_tool_names);
|
||||
let history = child_thread.codex.session.clone_history().await;
|
||||
let expected_history = [
|
||||
ResponseItem::Message {
|
||||
@@ -751,6 +843,199 @@ async fn spawn_agent_can_fork_parent_thread_history_with_sanitized_items() {
|
||||
.expect("parent shutdown should submit");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn forked_spawn_first_request_uses_parent_cache_key_and_mcp_snapshot() -> anyhow::Result<()> {
|
||||
let server = start_mock_server().await;
|
||||
let child_response_mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![ev_response_created("resp-1"), ev_completed("resp-1")]),
|
||||
)
|
||||
.await;
|
||||
let (_home, mut config) = test_config().await;
|
||||
config.model_provider.base_url = Some(format!("{}/v1", server.uri()));
|
||||
config.model_provider.supports_websockets = false;
|
||||
let mcp_server_path = config.codex_home.join("fake_mcp_server.py");
|
||||
std::fs::write(
|
||||
&mcp_server_path,
|
||||
r#"import json
|
||||
import sys
|
||||
|
||||
def read_message():
|
||||
line = sys.stdin.buffer.readline()
|
||||
if not line:
|
||||
return None
|
||||
return json.loads(line)
|
||||
|
||||
def write_message(message):
|
||||
body = json.dumps(message).encode("utf-8")
|
||||
sys.stdout.buffer.write(body)
|
||||
sys.stdout.buffer.write(b"\n")
|
||||
sys.stdout.buffer.flush()
|
||||
|
||||
while True:
|
||||
message = read_message()
|
||||
if message is None:
|
||||
break
|
||||
method = message.get("method")
|
||||
request_id = message.get("id")
|
||||
if request_id is None:
|
||||
continue
|
||||
if method == "initialize":
|
||||
write_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": {
|
||||
"protocolVersion": "2025-06-18",
|
||||
"capabilities": {"tools": {"listChanged": False}},
|
||||
"serverInfo": {"name": "fake-mcp", "version": "1.0.0"},
|
||||
},
|
||||
})
|
||||
elif method == "tools/list":
|
||||
write_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": {
|
||||
"tools": [{
|
||||
"name": "echo",
|
||||
"description": "Echo from fake MCP",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"additionalProperties": False,
|
||||
},
|
||||
}],
|
||||
},
|
||||
})
|
||||
else:
|
||||
write_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"error": {"code": -32601, "message": "method not found"},
|
||||
})
|
||||
"#,
|
||||
)?;
|
||||
config
|
||||
.mcp_servers
|
||||
.set(std::collections::HashMap::from([(
|
||||
"rmcp".to_string(),
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::Stdio {
|
||||
command: "python3".to_string(),
|
||||
args: vec![mcp_server_path.to_string_lossy().to_string()],
|
||||
env: None,
|
||||
env_vars: Vec::new(),
|
||||
cwd: None,
|
||||
},
|
||||
experimental_environment: None,
|
||||
enabled: true,
|
||||
required: false,
|
||||
supports_parallel_tool_calls: false,
|
||||
disabled_reason: None,
|
||||
startup_timeout_sec: Some(Duration::from_secs(5)),
|
||||
tool_timeout_sec: None,
|
||||
default_tools_approval_mode: None,
|
||||
enabled_tools: None,
|
||||
disabled_tools: None,
|
||||
scopes: None,
|
||||
oauth_resource: None,
|
||||
tools: std::collections::HashMap::new(),
|
||||
},
|
||||
)]))
|
||||
.expect("test config should allow MCP servers");
|
||||
|
||||
let manager = ThreadManager::with_models_provider_and_home_for_tests(
|
||||
CodexAuth::from_api_key("dummy"),
|
||||
config.model_provider.clone(),
|
||||
config.codex_home.to_path_buf(),
|
||||
std::sync::Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()),
|
||||
);
|
||||
let control = manager.agent_control();
|
||||
let parent = manager.start_thread(config.clone()).await?;
|
||||
let parent_thread_id = parent.thread_id;
|
||||
let parent_prompt_cache_key = parent.thread.codex.session.prompt_cache_key();
|
||||
let (list_all_tools, required_startup_failures) = {
|
||||
let mcp_connection_manager = parent
|
||||
.thread
|
||||
.codex
|
||||
.session
|
||||
.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await;
|
||||
(
|
||||
mcp_connection_manager.list_all_tools_future(),
|
||||
mcp_connection_manager.required_startup_failures_future(vec!["rmcp".to_string()]),
|
||||
)
|
||||
};
|
||||
let parent_mcp_tools = list_all_tools.await;
|
||||
let startup_failures = required_startup_failures.await;
|
||||
assert!(
|
||||
parent_mcp_tools.contains_key("mcp__rmcp__echo"),
|
||||
"parent MCP manager should expose live MCP tools before forking: tools={parent_mcp_tools:#?}; failures={startup_failures:#?}"
|
||||
);
|
||||
parent
|
||||
.thread
|
||||
.inject_user_message_without_turn("parent seed".to_string())
|
||||
.await;
|
||||
parent
|
||||
.thread
|
||||
.codex
|
||||
.session
|
||||
.ensure_rollout_materialized()
|
||||
.await;
|
||||
parent.thread.codex.session.flush_rollout().await?;
|
||||
|
||||
let child_thread_id = control
|
||||
.spawn_agent_with_metadata(
|
||||
config,
|
||||
text_input("child request boundary"),
|
||||
Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
|
||||
parent_thread_id,
|
||||
depth: 1,
|
||||
agent_path: None,
|
||||
agent_nickname: Some("worker".to_string()),
|
||||
agent_role: None,
|
||||
})),
|
||||
SpawnAgentOptions {
|
||||
fork_parent_spawn_call_id: Some("spawn-call-request-boundary".to_string()),
|
||||
fork_mode: Some(SpawnAgentForkMode::FullHistory),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?
|
||||
.thread_id;
|
||||
let child_thread = manager
|
||||
.get_thread(child_thread_id)
|
||||
.await
|
||||
.expect("child thread should be registered");
|
||||
|
||||
timeout(Duration::from_secs(5), async {
|
||||
loop {
|
||||
let event = child_thread
|
||||
.next_event()
|
||||
.await
|
||||
.expect("child event channel should stay open");
|
||||
if matches!(event.msg, EventMsg::TurnComplete(_)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("child turn should complete");
|
||||
let body = child_response_mock.single_request().body_json();
|
||||
let expected_prompt_cache_key = parent_prompt_cache_key.to_string();
|
||||
assert_eq!(
|
||||
body["prompt_cache_key"].as_str(),
|
||||
Some(expected_prompt_cache_key.as_str())
|
||||
);
|
||||
assert!(
|
||||
namespace_child_tool(&body, "mcp__rmcp__", "echo").is_some(),
|
||||
"first forked child request should expose parent MCP snapshot tools: {body:#}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn spawn_agent_fork_flushes_parent_rollout_before_loading_history() {
|
||||
let harness = AgentControlHarness::new().await;
|
||||
@@ -1550,7 +1835,7 @@ async fn resume_thread_subagent_restores_stored_nickname_and_role() {
|
||||
manager,
|
||||
control,
|
||||
};
|
||||
let (parent_thread_id, _parent_thread) = harness.start_thread().await;
|
||||
let (parent_thread_id, parent_thread) = harness.start_thread().await;
|
||||
let agent_path = AgentPath::from_string("/root/explorer".to_string())
|
||||
.expect("test agent path should be valid");
|
||||
|
||||
@@ -1640,13 +1925,22 @@ async fn resume_thread_subagent_restores_stored_nickname_and_role() {
|
||||
.expect("resume should succeed");
|
||||
assert_eq!(resumed_thread_id, child_thread_id);
|
||||
|
||||
let resumed_snapshot = harness
|
||||
let resumed_thread = harness
|
||||
.manager
|
||||
.get_thread(resumed_thread_id)
|
||||
.await
|
||||
.expect("resumed child thread should exist")
|
||||
.config_snapshot()
|
||||
.await;
|
||||
.expect("resumed child thread should exist");
|
||||
assert_eq!(
|
||||
resumed_thread.codex.session.prompt_cache_key(),
|
||||
resumed_thread_id,
|
||||
"resume should keep the resumed thread's own cache key"
|
||||
);
|
||||
assert_ne!(
|
||||
resumed_thread.codex.session.prompt_cache_key(),
|
||||
parent_thread.codex.session.prompt_cache_key(),
|
||||
"resume must not opportunistically inherit cache state from a live parent"
|
||||
);
|
||||
let resumed_snapshot = resumed_thread.config_snapshot().await;
|
||||
let SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
|
||||
parent_thread_id: resumed_parent_thread_id,
|
||||
depth: resumed_depth,
|
||||
|
||||
@@ -154,6 +154,7 @@ struct ModelClientState {
|
||||
conversation_id: ThreadId,
|
||||
window_generation: AtomicU64,
|
||||
installation_id: String,
|
||||
prompt_cache_key_override: Option<ThreadId>,
|
||||
provider: SharedModelProvider,
|
||||
auth_env_telemetry: AuthEnvTelemetry,
|
||||
session_source: SessionSource,
|
||||
@@ -163,6 +164,7 @@ struct ModelClientState {
|
||||
beta_features_header: Option<String>,
|
||||
disable_websockets: AtomicBool,
|
||||
cached_websocket_session: StdMutex<WebsocketSession>,
|
||||
latest_response_continuation: StdMutex<Option<ResponseContinuation>>,
|
||||
}
|
||||
|
||||
/// Resolved API client setup for a single request attempt.
|
||||
@@ -237,15 +239,32 @@ struct LastResponse {
|
||||
items_added: Vec<ResponseItem>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ResponseContinuation {
|
||||
request: ResponsesApiRequest,
|
||||
last_response: LastResponse,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct WebsocketSession {
|
||||
connection: Option<ApiWebSocketConnection>,
|
||||
last_request: Option<ResponsesApiRequest>,
|
||||
last_response_rx: Option<oneshot::Receiver<LastResponse>>,
|
||||
last_response: Option<LastResponse>,
|
||||
connection_reused: StdMutex<bool>,
|
||||
}
|
||||
|
||||
impl WebsocketSession {
|
||||
fn from_response_continuation(continuation: ResponseContinuation) -> Self {
|
||||
Self {
|
||||
connection: None,
|
||||
last_request: Some(continuation.request),
|
||||
last_response_rx: None,
|
||||
last_response: Some(continuation.last_response),
|
||||
connection_reused: StdMutex::new(false),
|
||||
}
|
||||
}
|
||||
|
||||
fn set_connection_reused(&self, connection_reused: bool) {
|
||||
*self
|
||||
.connection_reused
|
||||
@@ -261,6 +280,24 @@ impl WebsocketSession {
|
||||
}
|
||||
}
|
||||
|
||||
impl ResponseContinuation {
|
||||
pub(crate) fn for_fork(mut self) -> Self {
|
||||
self.request
|
||||
.input
|
||||
.retain(|item| !matches!(item, ResponseItem::Reasoning { .. }));
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn fork_baseline_input(&self) -> Vec<ResponseItem> {
|
||||
self.request
|
||||
.input
|
||||
.iter()
|
||||
.chain(self.last_response.items_added.iter())
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
enum WebsocketStreamOutcome {
|
||||
Stream(ResponseStream),
|
||||
FallbackToHttp,
|
||||
@@ -298,12 +335,42 @@ impl ModelClient {
|
||||
auth_manager: Option<Arc<AuthManager>>,
|
||||
conversation_id: ThreadId,
|
||||
installation_id: String,
|
||||
prompt_cache_key_override: Option<ThreadId>,
|
||||
provider_info: ModelProviderInfo,
|
||||
session_source: SessionSource,
|
||||
model_verbosity: Option<VerbosityConfig>,
|
||||
enable_request_compression: bool,
|
||||
include_timing_metrics: bool,
|
||||
beta_features_header: Option<String>,
|
||||
) -> Self {
|
||||
Self::new_with_response_continuation(
|
||||
auth_manager,
|
||||
conversation_id,
|
||||
installation_id,
|
||||
prompt_cache_key_override,
|
||||
provider_info,
|
||||
session_source,
|
||||
model_verbosity,
|
||||
enable_request_compression,
|
||||
include_timing_metrics,
|
||||
beta_features_header,
|
||||
/*response_continuation*/ None,
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn new_with_response_continuation(
|
||||
auth_manager: Option<Arc<AuthManager>>,
|
||||
conversation_id: ThreadId,
|
||||
installation_id: String,
|
||||
prompt_cache_key_override: Option<ThreadId>,
|
||||
provider_info: ModelProviderInfo,
|
||||
session_source: SessionSource,
|
||||
model_verbosity: Option<VerbosityConfig>,
|
||||
enable_request_compression: bool,
|
||||
include_timing_metrics: bool,
|
||||
beta_features_header: Option<String>,
|
||||
response_continuation: Option<ResponseContinuation>,
|
||||
) -> Self {
|
||||
let model_provider = create_model_provider(provider_info, auth_manager);
|
||||
let codex_api_key_env_enabled = model_provider
|
||||
@@ -312,11 +379,16 @@ impl ModelClient {
|
||||
.is_some_and(|manager| manager.codex_api_key_env_enabled());
|
||||
let auth_env_telemetry =
|
||||
collect_auth_env_telemetry(model_provider.info(), codex_api_key_env_enabled);
|
||||
let cached_websocket_session = response_continuation
|
||||
.clone()
|
||||
.map(WebsocketSession::from_response_continuation)
|
||||
.unwrap_or_default();
|
||||
Self {
|
||||
state: Arc::new(ModelClientState {
|
||||
conversation_id,
|
||||
window_generation: AtomicU64::new(0),
|
||||
installation_id,
|
||||
prompt_cache_key_override,
|
||||
provider: model_provider,
|
||||
auth_env_telemetry,
|
||||
session_source,
|
||||
@@ -325,7 +397,8 @@ impl ModelClient {
|
||||
include_timing_metrics,
|
||||
beta_features_header,
|
||||
disable_websockets: AtomicBool::new(false),
|
||||
cached_websocket_session: StdMutex::new(WebsocketSession::default()),
|
||||
cached_websocket_session: StdMutex::new(cached_websocket_session),
|
||||
latest_response_continuation: StdMutex::new(response_continuation),
|
||||
}),
|
||||
}
|
||||
}
|
||||
@@ -364,6 +437,24 @@ impl ModelClient {
|
||||
format!("{conversation_id}:{window_generation}")
|
||||
}
|
||||
|
||||
/// Returns the prompt cache key for this client: the configured override
/// when one is present, otherwise the client's own conversation id.
pub(crate) fn prompt_cache_key(&self) -> ThreadId {
    let state = &self.state;
    match state.prompt_cache_key_override {
        Some(override_key) => override_key,
        None => state.conversation_id,
    }
}
|
||||
|
||||
/// Returns a fork-ready copy of the latest recorded response continuation.
///
/// Yields `None` when the Responses websocket is not enabled or when no
/// continuation has been recorded. A poisoned lock is tolerated by reading
/// the inner value anyway.
pub(crate) fn response_continuation_for_fork(&self) -> Option<ResponseContinuation> {
    if !self.responses_websocket_enabled() {
        return None;
    }
    let latest = self
        .state
        .latest_response_continuation
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    let snapshot = latest.clone()?;
    Some(snapshot.for_fork())
}
|
||||
|
||||
fn take_cached_websocket_session(&self) -> WebsocketSession {
|
||||
let mut cached_websocket_session = self
|
||||
.state
|
||||
@@ -787,6 +878,7 @@ impl ModelClient {
|
||||
) -> ApiHeaderMap {
|
||||
let turn_metadata_header = parse_turn_metadata_header(turn_metadata_header);
|
||||
let conversation_id = self.state.conversation_id.to_string();
|
||||
let prompt_cache_key = self.prompt_cache_key().to_string();
|
||||
let mut headers = build_responses_headers(
|
||||
self.state.beta_features_header.as_deref(),
|
||||
turn_state,
|
||||
@@ -795,7 +887,10 @@ impl ModelClient {
|
||||
if let Ok(header_value) = HeaderValue::from_str(&conversation_id) {
|
||||
headers.insert("x-client-request-id", header_value);
|
||||
}
|
||||
headers.extend(build_conversation_headers(Some(conversation_id)));
|
||||
// The Responses websocket backend uses the `session_id` handshake header as the prompt
|
||||
// cache id. Keep `x-client-request-id` on the conversation id so forked agents keep a
|
||||
// unique request identity while sharing their parent's prompt cache id.
|
||||
headers.extend(build_conversation_headers(Some(prompt_cache_key)));
|
||||
headers.extend(self.build_responses_identity_headers());
|
||||
headers.insert(
|
||||
OPENAI_BETA_HEADER,
|
||||
@@ -824,6 +919,7 @@ impl ModelClientSession {
|
||||
self.websocket_session.connection = None;
|
||||
self.websocket_session.last_request = None;
|
||||
self.websocket_session.last_response_rx = None;
|
||||
self.websocket_session.last_response = None;
|
||||
self.websocket_session
|
||||
.set_connection_reused(/*connection_reused*/ false);
|
||||
}
|
||||
@@ -877,7 +973,7 @@ impl ModelClientSession {
|
||||
&prompt.output_schema,
|
||||
prompt.output_schema_strict,
|
||||
);
|
||||
let prompt_cache_key = Some(self.client.state.conversation_id.to_string());
|
||||
let prompt_cache_key = Some(self.client.prompt_cache_key().to_string());
|
||||
let request = ResponsesApiRequest {
|
||||
model: model_info.slug.clone(),
|
||||
instructions: instructions.clone(),
|
||||
@@ -916,8 +1012,10 @@ impl ModelClientSession {
|
||||
) -> ApiResponsesOptions {
|
||||
let turn_metadata_header = parse_turn_metadata_header(turn_metadata_header);
|
||||
let conversation_id = self.client.state.conversation_id.to_string();
|
||||
let prompt_cache_key = self.client.prompt_cache_key().to_string();
|
||||
ApiResponsesOptions {
|
||||
conversation_id: Some(conversation_id),
|
||||
prompt_cache_key: Some(prompt_cache_key),
|
||||
session_source: Some(self.client.state.session_source.clone()),
|
||||
extra_headers: {
|
||||
let mut headers = build_responses_headers(
|
||||
@@ -973,13 +1071,16 @@ impl ModelClientSession {
|
||||
}
|
||||
|
||||
fn get_last_response(&mut self) -> Option<LastResponse> {
|
||||
self.websocket_session
|
||||
.last_response_rx
|
||||
.take()
|
||||
.and_then(|mut receiver| match receiver.try_recv() {
|
||||
Ok(last_response) => Some(last_response),
|
||||
Err(TryRecvError::Closed) | Err(TryRecvError::Empty) => None,
|
||||
})
|
||||
if let Some(mut receiver) = self.websocket_session.last_response_rx.take() {
|
||||
match receiver.try_recv() {
|
||||
Ok(last_response) => {
|
||||
self.websocket_session.last_response = Some(last_response.clone());
|
||||
return Some(last_response);
|
||||
}
|
||||
Err(TryRecvError::Closed) | Err(TryRecvError::Empty) => {}
|
||||
}
|
||||
}
|
||||
self.websocket_session.last_response.clone()
|
||||
}
|
||||
|
||||
fn prepare_websocket_request(
|
||||
@@ -1084,7 +1185,9 @@ impl ModelClientSession {
|
||||
};
|
||||
|
||||
if needs_new {
|
||||
self.websocket_session.last_request = None;
|
||||
if self.websocket_session.last_response.is_none() {
|
||||
self.websocket_session.last_request = None;
|
||||
}
|
||||
self.websocket_session.last_response_rx = None;
|
||||
let turn_state = options
|
||||
.turn_state
|
||||
@@ -1178,6 +1281,8 @@ impl ModelClientSession {
|
||||
stream,
|
||||
session_telemetry.clone(),
|
||||
InferenceTraceAttempt::disabled(),
|
||||
/*client_state*/ None,
|
||||
/*request*/ None,
|
||||
);
|
||||
return Ok(stream);
|
||||
}
|
||||
@@ -1228,6 +1333,8 @@ impl ModelClientSession {
|
||||
stream,
|
||||
session_telemetry.clone(),
|
||||
inference_trace_attempt,
|
||||
/*client_state*/ None,
|
||||
/*request*/ None,
|
||||
);
|
||||
return Ok(stream);
|
||||
}
|
||||
@@ -1367,6 +1474,7 @@ impl ModelClientSession {
|
||||
|
||||
let ws_request = self.prepare_websocket_request(ws_payload, &request);
|
||||
self.websocket_session.last_request = Some(request);
|
||||
self.websocket_session.last_response = None;
|
||||
let inference_trace_attempt = if warmup {
|
||||
// Prewarm sends `generate=false`; it is connection setup, not a
|
||||
// model inference attempt that should appear in rollout traces.
|
||||
@@ -1399,6 +1507,8 @@ impl ModelClientSession {
|
||||
stream_result,
|
||||
session_telemetry.clone(),
|
||||
inference_trace_attempt,
|
||||
Some(Arc::clone(&self.client.state)),
|
||||
self.websocket_session.last_request.clone(),
|
||||
);
|
||||
self.websocket_session.last_response_rx = Some(last_request_rx);
|
||||
return Ok(WebsocketStreamOutcome::Stream(stream));
|
||||
@@ -1657,6 +1767,8 @@ fn map_response_stream(
|
||||
api_stream: codex_api::ResponseStream,
|
||||
session_telemetry: SessionTelemetry,
|
||||
inference_trace_attempt: InferenceTraceAttempt,
|
||||
client_state: Option<Arc<ModelClientState>>,
|
||||
request: Option<ResponsesApiRequest>,
|
||||
) -> (ResponseStream, oneshot::Receiver<LastResponse>) {
|
||||
let codex_api::ResponseStream {
|
||||
rx_event,
|
||||
@@ -1671,6 +1783,8 @@ fn map_response_stream(
|
||||
api_stream,
|
||||
session_telemetry,
|
||||
inference_trace_attempt,
|
||||
client_state,
|
||||
request,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1679,6 +1793,8 @@ fn map_response_events<S>(
|
||||
api_stream: S,
|
||||
session_telemetry: SessionTelemetry,
|
||||
inference_trace_attempt: InferenceTraceAttempt,
|
||||
client_state: Option<Arc<ModelClientState>>,
|
||||
request: Option<ResponsesApiRequest>,
|
||||
) -> (ResponseStream, oneshot::Receiver<LastResponse>)
|
||||
where
|
||||
S: futures::Stream<Item = std::result::Result<ResponseEvent, ApiError>>
|
||||
@@ -1749,11 +1865,24 @@ where
|
||||
&token_usage,
|
||||
&items_added,
|
||||
);
|
||||
let last_response = LastResponse {
|
||||
response_id: response_id.clone(),
|
||||
items_added: std::mem::take(&mut items_added),
|
||||
};
|
||||
if let (Some(client_state), Some(request)) = (&client_state, &request)
|
||||
&& !last_response.response_id.is_empty()
|
||||
{
|
||||
*client_state
|
||||
.latest_response_continuation
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner) =
|
||||
Some(ResponseContinuation {
|
||||
request: request.clone(),
|
||||
last_response: last_response.clone(),
|
||||
});
|
||||
}
|
||||
if let Some(sender) = tx_last_response.take() {
|
||||
let _ = sender.send(LastResponse {
|
||||
response_id: response_id.clone(),
|
||||
items_added: std::mem::take(&mut items_added),
|
||||
});
|
||||
let _ = sender.send(last_response);
|
||||
}
|
||||
if tx_event
|
||||
.send(Ok(ResponseEvent::Completed {
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
use super::AuthRequestTelemetryContext;
|
||||
use super::LastResponse;
|
||||
use super::ModelClient;
|
||||
use super::PendingUnauthorizedRetry;
|
||||
use super::ResponseContinuation;
|
||||
use super::ResponsesApiRequest;
|
||||
use super::UnauthorizedRecoveryExecution;
|
||||
use super::X_CODEX_INSTALLATION_ID_HEADER;
|
||||
use super::X_CODEX_PARENT_THREAD_ID_HEADER;
|
||||
@@ -16,6 +19,8 @@ use codex_model_provider_info::create_oss_provider_with_base_url;
|
||||
use codex_otel::SessionTelemetry;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ReasoningItemContent;
|
||||
use codex_protocol::models::ReasoningItemReasoningSummary;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::openai_models::ModelInfo;
|
||||
use codex_protocol::protocol::InternalSessionSource;
|
||||
@@ -41,11 +46,13 @@ use tempfile::TempDir;
|
||||
use tokio::sync::Notify;
|
||||
|
||||
fn test_model_client(session_source: SessionSource) -> ModelClient {
|
||||
let conversation_id = ThreadId::new();
|
||||
let provider = create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses);
|
||||
ModelClient::new(
|
||||
/*auth_manager*/ None,
|
||||
ThreadId::new(),
|
||||
conversation_id,
|
||||
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
|
||||
/*prompt_cache_key_override*/ None,
|
||||
provider,
|
||||
session_source,
|
||||
/*model_verbosity*/ None,
|
||||
@@ -147,6 +154,66 @@ fn output_message(id: &str, text: &str) -> ResponseItem {
|
||||
}
|
||||
}
|
||||
|
||||
fn user_message_item(text: &str) -> ResponseItem {
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: text.to_string(),
|
||||
}],
|
||||
phase: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn reasoning_item(id: &str, text: &str) -> ResponseItem {
|
||||
ResponseItem::Reasoning {
|
||||
id: id.to_string(),
|
||||
summary: vec![ReasoningItemReasoningSummary::SummaryText {
|
||||
text: "summary".to_string(),
|
||||
}],
|
||||
content: Some(vec![ReasoningItemContent::ReasoningText {
|
||||
text: text.to_string(),
|
||||
}]),
|
||||
encrypted_content: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// `for_fork` must drop reasoning that lives in the request input while the
/// reasoning produced by the latest response stays in the fork baseline.
#[test]
fn response_continuation_for_fork_drops_historical_reasoning_but_keeps_latest() {
    let user_message = user_message_item("hello");
    let latest_reasoning = reasoning_item("rs-latest", "latest analysis");
    let latest_message = output_message("msg-latest", "assistant output");

    let parent_request = ResponsesApiRequest {
        model: "gpt-test".to_string(),
        instructions: "base instructions".to_string(),
        // The historical reasoning item sits in the request input and should be dropped.
        input: vec![
            user_message.clone(),
            reasoning_item("rs-old", "old analysis"),
        ],
        tools: Vec::new(),
        tool_choice: "auto".to_string(),
        parallel_tool_calls: false,
        reasoning: None,
        store: false,
        stream: true,
        include: Vec::new(),
        service_tier: None,
        prompt_cache_key: Some(ThreadId::new().to_string()),
        text: None,
        client_metadata: None,
    };
    let parent_response = LastResponse {
        response_id: "parent-resp".to_string(),
        items_added: vec![latest_reasoning.clone(), latest_message.clone()],
    };
    let forked = ResponseContinuation {
        request: parent_request,
        last_response: parent_response,
    }
    .for_fork();

    let expected_baseline = vec![user_message, latest_reasoning, latest_message];
    assert_eq!(forked.fork_baseline_input(), expected_baseline);
}
|
||||
|
||||
async fn replay_until_cancelled(temp: &TempDir) -> anyhow::Result<RolloutTrace> {
|
||||
let mut rollout = replay_bundle(temp.path())?;
|
||||
for _ in 0..50 {
|
||||
@@ -287,6 +354,8 @@ async fn dropped_response_stream_traces_cancelled_partial_output() -> anyhow::Re
|
||||
api_stream,
|
||||
test_session_telemetry(),
|
||||
attempt,
|
||||
/*client_state*/ None,
|
||||
/*request*/ None,
|
||||
);
|
||||
|
||||
let observed = stream
|
||||
@@ -342,6 +411,8 @@ async fn dropped_backpressured_response_stream_traces_cancelled_partial_output()
|
||||
api_stream,
|
||||
test_session_telemetry(),
|
||||
attempt,
|
||||
/*client_state*/ None,
|
||||
/*request*/ None,
|
||||
);
|
||||
|
||||
// Fill the mapper channel with non-terminal events, then yield one output
|
||||
|
||||
@@ -93,6 +93,7 @@ pub(crate) async fn run_codex_thread_interactive(
|
||||
user_shell_override: None,
|
||||
inherited_exec_policy: Some(Arc::clone(&parent_session.services.exec_policy)),
|
||||
parent_rollout_thread_trace: codex_rollout_trace::ThreadTraceContext::disabled(),
|
||||
inherited_thread_state: Default::default(),
|
||||
parent_trace: None,
|
||||
environment_selections: ResolvedTurnEnvironments {
|
||||
turn_environments: parent_ctx.environments.clone(),
|
||||
|
||||
@@ -42,6 +42,8 @@ use codex_model_provider_info::ModelProviderInfo;
|
||||
pub const SUMMARIZATION_PROMPT: &str = include_str!("../templates/compact/prompt.md");
|
||||
pub const SUMMARY_PREFIX: &str = include_str!("../templates/compact/summary_prefix.md");
|
||||
const COMPACT_USER_MESSAGE_MAX_TOKENS: usize = 20_000;
|
||||
pub(crate) const UNIFIED_EXEC_PROCESS_WARNING_PREFIX: &str =
|
||||
"Warning: The maximum number of unified exec process";
|
||||
|
||||
/// Controls whether compaction replacement history must include initial context.
|
||||
///
|
||||
@@ -368,19 +370,51 @@ pub fn content_items_to_text(content: &[ContentItem]) -> Option<String> {
|
||||
}
|
||||
|
||||
pub(crate) fn collect_user_messages(items: &[ResponseItem]) -> Vec<String> {
|
||||
items
|
||||
.iter()
|
||||
.filter_map(|item| match crate::event_mapping::parse_turn_item(item) {
|
||||
let mut messages = Vec::new();
|
||||
let mut previous_message: Option<String> = Some(String::new());
|
||||
for item in items {
|
||||
let message = match crate::event_mapping::parse_turn_item(item) {
|
||||
Some(TurnItem::UserMessage(user)) => {
|
||||
if is_summary_message(&user.message()) {
|
||||
if is_summary_message(&user.message())
|
||||
|| is_compaction_filtered_user_message(&user.message())
|
||||
{
|
||||
None
|
||||
} else {
|
||||
Some(user.message())
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
let Some(message) = message else {
|
||||
previous_message = None;
|
||||
continue;
|
||||
};
|
||||
if message.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if previous_message.as_deref() == Some(message.as_str()) {
|
||||
continue;
|
||||
}
|
||||
previous_message = Some(message.clone());
|
||||
messages.push(message);
|
||||
}
|
||||
messages
|
||||
}
|
||||
|
||||
pub(crate) fn is_compaction_filtered_user_message(message: &str) -> bool {
|
||||
message.starts_with(UNIFIED_EXEC_PROCESS_WARNING_PREFIX)
|
||||
}
|
||||
|
||||
pub(crate) fn is_compaction_filtered_history_item(item: &ResponseItem) -> bool {
|
||||
let ResponseItem::Message { role, content, .. } = item else {
|
||||
return false;
|
||||
};
|
||||
if role != "user" {
|
||||
return false;
|
||||
}
|
||||
content_items_to_text(content)
|
||||
.as_deref()
|
||||
.is_some_and(is_compaction_filtered_user_message)
|
||||
}
|
||||
|
||||
pub(crate) fn is_summary_message(message: &str) -> bool {
|
||||
|
||||
@@ -6,6 +6,7 @@ use crate::compact::CompactionAnalyticsAttempt;
|
||||
use crate::compact::InitialContextInjection;
|
||||
use crate::compact::compaction_status_from_result;
|
||||
use crate::compact::insert_initial_context_before_last_real_user_or_summary;
|
||||
use crate::compact::is_compaction_filtered_history_item;
|
||||
use crate::context_manager::ContextManager;
|
||||
use crate::context_manager::TotalTokenUsageBreakdown;
|
||||
use crate::context_manager::estimate_response_item_model_visible_bytes;
|
||||
@@ -144,8 +145,17 @@ async fn run_remote_compact_task_inner_impl(
|
||||
// This is the history selected for remote compaction, after any trimming required to fit the
|
||||
// compact endpoint. The checkpoint below records it separately from the next sampling request,
|
||||
// whose prompt will repeat current developer/context prefix items.
|
||||
let trace_input_history = history.raw_items().to_vec();
|
||||
let prompt_input = history.for_prompt(&turn_context.model_info.input_modalities);
|
||||
let trace_input_history = history
|
||||
.raw_items()
|
||||
.iter()
|
||||
.filter(|item| !is_compaction_filtered_history_item(item))
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
let prompt_input = history
|
||||
.for_prompt(&turn_context.model_info.input_modalities)
|
||||
.into_iter()
|
||||
.filter(|item| !is_compaction_filtered_history_item(item))
|
||||
.collect::<Vec<_>>();
|
||||
let tool_router = built_tools(
|
||||
sess.as_ref(),
|
||||
turn_context.as_ref(),
|
||||
@@ -250,17 +260,21 @@ pub(crate) async fn process_compacted_history(
|
||||
/// We drop:
|
||||
/// - `developer` messages because remote output can include stale/duplicated
|
||||
/// instruction content.
|
||||
/// - non-user-content `user` messages (session prefix/instruction wrappers),
|
||||
/// while preserving real user messages and persisted hook prompts.
|
||||
/// - non-user-content `user` messages (session prefix/instruction wrappers).
|
||||
/// - user warnings that are known to be local runtime noise and should not
|
||||
/// survive compaction.
|
||||
///
|
||||
/// This intentionally keeps:
|
||||
/// - `assistant` messages (future remote compaction models may emit them)
|
||||
/// - `user`-role warnings and compaction-generated summary messages because
|
||||
/// they parse as `TurnItem::UserMessage`.
|
||||
/// - compaction-generated summary messages because they parse as
|
||||
/// `TurnItem::UserMessage`.
|
||||
fn should_keep_compacted_history_item(item: &ResponseItem) -> bool {
|
||||
match item {
|
||||
ResponseItem::Message { role, .. } if role == "developer" => false,
|
||||
ResponseItem::Message { role, .. } if role == "user" => {
|
||||
if is_compaction_filtered_history_item(item) {
|
||||
return false;
|
||||
}
|
||||
matches!(
|
||||
crate::event_mapping::parse_turn_item(item),
|
||||
Some(TurnItem::UserMessage(_) | TurnItem::HookPrompt(_))
|
||||
|
||||
@@ -120,6 +120,102 @@ do things
|
||||
assert_eq!(vec!["real user message".to_string()], collected);
|
||||
}
|
||||
|
||||
/// A unified-exec process warning between two real user messages must be
/// skipped by `collect_user_messages` and flagged as a filtered history item.
#[test]
fn collect_user_messages_filters_unified_exec_process_warnings() {
    let user_text = |text: &str| ResponseItem::Message {
        id: None,
        role: "user".to_string(),
        content: vec![ContentItem::InputText {
            text: text.to_string(),
        }],
        phase: None,
    };
    let items = vec![
        user_text("before warning"),
        user_text("Warning: The maximum number of unified exec processes you can keep open is 5 and you currently have 5 processes open. Reuse older processes or close them to prevent automatic pruning of old processes"),
        user_text("after warning"),
    ];

    let collected = collect_user_messages(&items);

    let expected = vec!["before warning".to_string(), "after warning".to_string()];
    assert_eq!(expected, collected);
    assert!(is_compaction_filtered_history_item(&items[1]));
}
|
||||
|
||||
/// Empty user messages are dropped and back-to-back duplicates collapse, but
/// the same text is kept again once an assistant message sits in between.
#[test]
fn collect_user_messages_drops_contiguous_duplicates_and_empty_messages() {
    let message = |role: &str, content: ContentItem| ResponseItem::Message {
        id: None,
        role: role.to_string(),
        content: vec![content],
        phase: None,
    };
    let user = |text: &str| {
        message(
            "user",
            ContentItem::InputText {
                text: text.to_string(),
            },
        )
    };
    let items = vec![
        user(""),
        user("repeat"),
        user("repeat"),
        message(
            "assistant",
            ContentItem::OutputText {
                text: "keeps the next user message non-contiguous".to_string(),
            },
        ),
        user("repeat"),
        user(""),
    ];

    let collected = collect_user_messages(&items);

    let expected = vec!["repeat".to_string(), "repeat".to_string()];
    assert_eq!(expected, collected);
}
|
||||
|
||||
#[test]
|
||||
fn build_token_limited_compacted_history_truncates_overlong_user_messages() {
|
||||
// Use a small truncation limit so the test remains fast while still validating
|
||||
|
||||
64
codex-rs/core/src/inherited_thread_state.rs
Normal file
64
codex-rs/core/src/inherited_thread_state.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
use codex_protocol::ThreadId;
|
||||
|
||||
use crate::client::ResponseContinuation;
|
||||
use crate::state::McpToolSnapshot;
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub(crate) struct InheritedThreadState {
|
||||
prompt_cache_key: Option<ThreadId>,
|
||||
response_continuation: Option<ResponseContinuation>,
|
||||
mcp_tool_snapshot: Option<McpToolSnapshot>,
|
||||
}
|
||||
|
||||
impl InheritedThreadState {
|
||||
pub(crate) fn builder() -> InheritedThreadStateBuilder {
|
||||
InheritedThreadStateBuilder::default()
|
||||
}
|
||||
|
||||
pub(crate) fn prompt_cache_key(&self) -> Option<ThreadId> {
|
||||
self.prompt_cache_key
|
||||
}
|
||||
|
||||
pub(crate) fn response_continuation(&self) -> Option<ResponseContinuation> {
|
||||
self.response_continuation.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn mcp_tool_snapshot(&self) -> Option<McpToolSnapshot> {
|
||||
self.mcp_tool_snapshot.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct InheritedThreadStateBuilder {
|
||||
prompt_cache_key: Option<ThreadId>,
|
||||
response_continuation: Option<ResponseContinuation>,
|
||||
mcp_tool_snapshot: Option<McpToolSnapshot>,
|
||||
}
|
||||
|
||||
impl InheritedThreadStateBuilder {
|
||||
pub(crate) fn prompt_cache_key(mut self, prompt_cache_key: Option<ThreadId>) -> Self {
|
||||
self.prompt_cache_key = prompt_cache_key;
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn response_continuation(
|
||||
mut self,
|
||||
response_continuation: Option<ResponseContinuation>,
|
||||
) -> Self {
|
||||
self.response_continuation = response_continuation;
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn mcp_tool_snapshot(mut self, mcp_tool_snapshot: Option<McpToolSnapshot>) -> Self {
|
||||
self.mcp_tool_snapshot = mcp_tool_snapshot;
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn build(self) -> InheritedThreadState {
|
||||
InheritedThreadState {
|
||||
prompt_cache_key: self.prompt_cache_key,
|
||||
response_continuation: self.response_continuation,
|
||||
mcp_tool_snapshot: self.mcp_tool_snapshot,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -40,6 +40,7 @@ mod git_info_tests;
|
||||
mod goals;
|
||||
mod guardian;
|
||||
mod hook_runtime;
|
||||
mod inherited_thread_state;
|
||||
mod installation_id;
|
||||
pub(crate) mod landlock;
|
||||
pub use landlock::spawn_command_under_linux_sandbox;
|
||||
|
||||
@@ -267,6 +267,8 @@ impl Session {
|
||||
std::mem::replace(&mut *manager, refreshed_manager)
|
||||
};
|
||||
old_manager.shutdown().await;
|
||||
let mut snapshot = self.services.mcp_tool_snapshot.lock().await;
|
||||
*snapshot = None;
|
||||
}
|
||||
|
||||
pub(crate) async fn refresh_mcp_servers_if_requested(&self, turn_context: &TurnContext) {
|
||||
|
||||
@@ -32,6 +32,7 @@ use crate::context::PersonalitySpecInstructions;
|
||||
use crate::default_skill_metadata_budget;
|
||||
use crate::environment_selection::ResolvedTurnEnvironments;
|
||||
use crate::exec_policy::ExecPolicyManager;
|
||||
use crate::inherited_thread_state::InheritedThreadState;
|
||||
use crate::installation_id::resolve_installation_id;
|
||||
use crate::parse_turn_item;
|
||||
use crate::path_utils::normalize_for_native_workdir;
|
||||
@@ -406,6 +407,7 @@ pub(crate) struct CodexSpawnArgs {
|
||||
/// Root sessions and non-thread-spawn subagents pass a disabled context;
|
||||
/// `Session::new` creates the root trace itself when rollout tracing is enabled.
|
||||
pub(crate) parent_rollout_thread_trace: ThreadTraceContext,
|
||||
pub(crate) inherited_thread_state: InheritedThreadState,
|
||||
pub(crate) user_shell_override: Option<shell::Shell>,
|
||||
pub(crate) parent_trace: Option<W3cTraceContext>,
|
||||
pub(crate) environment_selections: ResolvedTurnEnvironments,
|
||||
@@ -464,6 +466,7 @@ impl Codex {
|
||||
user_shell_override,
|
||||
inherited_exec_policy,
|
||||
parent_rollout_thread_trace,
|
||||
inherited_thread_state,
|
||||
parent_trace: _,
|
||||
environment_selections,
|
||||
analytics_events_client,
|
||||
@@ -640,6 +643,7 @@ impl Codex {
|
||||
skills_watcher,
|
||||
agent_control,
|
||||
environment_manager,
|
||||
inherited_thread_state,
|
||||
analytics_events_client,
|
||||
thread_store,
|
||||
parent_rollout_thread_trace,
|
||||
@@ -1032,6 +1036,16 @@ impl Session {
|
||||
self.services.live_thread.as_ref()
|
||||
}
|
||||
|
||||
/// Returns the prompt cache key reported by this session's model client.
pub(crate) fn prompt_cache_key(&self) -> ThreadId {
    let model_client = &self.services.model_client;
    model_client.prompt_cache_key()
}
|
||||
|
||||
pub(crate) fn response_continuation_for_fork(
|
||||
&self,
|
||||
) -> Option<crate::client::ResponseContinuation> {
|
||||
self.services.model_client.response_continuation_for_fork()
|
||||
}
|
||||
|
||||
/// Flush rollout writes and return the final durability-barrier result.
|
||||
pub(crate) async fn flush_rollout(&self) -> std::io::Result<()> {
|
||||
if let Some(live_thread) = self.live_thread() {
|
||||
|
||||
@@ -343,6 +343,7 @@ impl Session {
|
||||
skills_watcher: Arc<SkillsWatcher>,
|
||||
agent_control: AgentControl,
|
||||
environment_manager: Arc<EnvironmentManager>,
|
||||
inherited_thread_state: InheritedThreadState,
|
||||
analytics_events_client: Option<AnalyticsEventsClient>,
|
||||
thread_store: Arc<dyn ThreadStore>,
|
||||
parent_rollout_thread_trace: ThreadTraceContext,
|
||||
@@ -822,6 +823,8 @@ impl Session {
|
||||
config.analytics_enabled,
|
||||
)
|
||||
});
|
||||
let prompt_cache_key_override = inherited_thread_state.prompt_cache_key();
|
||||
let mcp_tool_snapshot = inherited_thread_state.mcp_tool_snapshot();
|
||||
let services = SessionServices {
|
||||
// Initialize the MCP connection manager with an uninitialized
|
||||
// instance. It will be replaced with one created via
|
||||
@@ -864,16 +867,19 @@ impl Session {
|
||||
state_db: state_db_ctx.clone(),
|
||||
live_thread: live_thread_init.as_ref().cloned(),
|
||||
thread_store: Arc::clone(&thread_store),
|
||||
model_client: ModelClient::new(
|
||||
mcp_tool_snapshot: Mutex::new(mcp_tool_snapshot),
|
||||
model_client: ModelClient::new_with_response_continuation(
|
||||
Some(Arc::clone(&auth_manager)),
|
||||
conversation_id,
|
||||
installation_id,
|
||||
prompt_cache_key_override,
|
||||
session_configuration.provider.clone(),
|
||||
session_configuration.session_source.clone(),
|
||||
config.model_verbosity,
|
||||
config.features.enabled(Feature::EnableRequestCompression),
|
||||
config.features.enabled(Feature::RuntimeMetrics),
|
||||
Self::build_model_client_beta_features_header(config.as_ref()),
|
||||
inherited_thread_state.response_continuation(),
|
||||
),
|
||||
code_mode_service: crate::tools::code_mode::CodeModeService::new(),
|
||||
environment_manager,
|
||||
|
||||
@@ -133,6 +133,7 @@ use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::mount_sse_once;
|
||||
use core_test_support::responses::mount_sse_sequence;
|
||||
use core_test_support::responses::namespace_child_tool;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
@@ -350,6 +351,7 @@ fn test_model_client_session() -> crate::client::ModelClientSession {
|
||||
ThreadId::try_from("00000000-0000-4000-8000-000000000001")
|
||||
.expect("test thread id should be valid"),
|
||||
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
|
||||
/*prompt_cache_key_override*/ None,
|
||||
ModelProviderInfo::create_openai_provider(/* base_url */ /*base_url*/ None),
|
||||
codex_protocol::protocol::SessionSource::Exec,
|
||||
/*model_verbosity*/ None,
|
||||
@@ -1752,6 +1754,126 @@ async fn fork_startup_context_then_first_turn_diff_snapshot() -> anyhow::Result<
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn inherited_thread_state_shapes_first_responses_request() -> anyhow::Result<()> {
|
||||
let server = start_mock_server().await;
|
||||
let response_mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![ev_response_created("resp-1"), ev_completed("resp-1")]),
|
||||
)
|
||||
.await;
|
||||
let inherited_prompt_cache_key = ThreadId::try_from("00000000-0000-4000-8000-000000000002")
|
||||
.expect("test thread id should be valid");
|
||||
let inherited_tool = codex_mcp::ToolInfo {
|
||||
server_name: "snapshot".to_string(),
|
||||
callable_name: "echo".to_string(),
|
||||
callable_namespace: "mcp__snapshot__".to_string(),
|
||||
server_instructions: None,
|
||||
tool: rmcp::model::Tool {
|
||||
name: "echo".to_string().into(),
|
||||
title: None,
|
||||
description: Some("Echo from the inherited MCP snapshot".to_string().into()),
|
||||
input_schema: std::sync::Arc::new(rmcp::model::object(json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": { "type": "string" }
|
||||
},
|
||||
"required": ["message"],
|
||||
"additionalProperties": false
|
||||
}))),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
execution: None,
|
||||
icons: None,
|
||||
meta: None,
|
||||
},
|
||||
connector_id: None,
|
||||
connector_name: None,
|
||||
plugin_display_names: Vec::new(),
|
||||
connector_description: None,
|
||||
};
|
||||
let inherited_thread_state = crate::inherited_thread_state::InheritedThreadState::builder()
|
||||
.prompt_cache_key(Some(inherited_prompt_cache_key))
|
||||
.mcp_tool_snapshot(Some(crate::state::McpToolSnapshot {
|
||||
tools: std::collections::HashMap::from([(
|
||||
"mcp__snapshot__echo".to_string(),
|
||||
inherited_tool,
|
||||
)]),
|
||||
}))
|
||||
.build();
|
||||
let (session, rx_event) = make_session_with_config_inherited_and_rx(
|
||||
|config| {
|
||||
config.model_provider.base_url = Some(format!("{}/v1", server.uri()));
|
||||
config.model_provider.supports_websockets = false;
|
||||
config
|
||||
.mcp_servers
|
||||
.set(std::collections::HashMap::new())
|
||||
.expect("empty mcp server config should be valid");
|
||||
},
|
||||
inherited_thread_state,
|
||||
)
|
||||
.await?;
|
||||
let turn_context = session.new_default_turn().await;
|
||||
|
||||
session
|
||||
.spawn_task(
|
||||
Arc::clone(&turn_context),
|
||||
vec![UserInput::Text {
|
||||
text: "use inherited state".to_string(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
crate::tasks::RegularTask::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut observed_events = Vec::new();
|
||||
let wait_result = timeout(Duration::from_secs(5), async {
|
||||
while let Ok(event) = rx_event.recv().await {
|
||||
observed_events.push(match &event.msg {
|
||||
EventMsg::SessionConfigured(_) => "SessionConfigured",
|
||||
EventMsg::McpStartupComplete(_) => "McpStartupComplete",
|
||||
EventMsg::TurnStarted(_) => "TurnStarted",
|
||||
EventMsg::RawResponseItem(_) => "RawResponseItem",
|
||||
EventMsg::ItemStarted(_) => "ItemStarted",
|
||||
EventMsg::ItemCompleted(_) => "ItemCompleted",
|
||||
EventMsg::UserMessage(_) => "UserMessage",
|
||||
EventMsg::StreamError(_) => "StreamError",
|
||||
EventMsg::Error(_) => "Error",
|
||||
EventMsg::TurnAborted(_) => "TurnAborted",
|
||||
EventMsg::TurnComplete(_) => "TurnComplete",
|
||||
_ => "Other",
|
||||
});
|
||||
match event.msg {
|
||||
EventMsg::TurnComplete(_) => return,
|
||||
EventMsg::Error(error) => panic!("turn errored: {}", error.message),
|
||||
EventMsg::TurnAborted(aborted) => panic!("turn aborted: {:?}", aborted.reason),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
})
|
||||
.await;
|
||||
if let Err(err) = wait_result {
|
||||
panic!(
|
||||
"turn should complete: {err:?}; captured requests: {}; observed events: {observed_events:#?}",
|
||||
response_mock.requests().len(),
|
||||
);
|
||||
}
|
||||
|
||||
let request = response_mock.single_request();
|
||||
let body = request.body_json();
|
||||
let expected_prompt_cache_key = inherited_prompt_cache_key.to_string();
|
||||
assert_eq!(
|
||||
body["prompt_cache_key"].as_str(),
|
||||
Some(expected_prompt_cache_key.as_str())
|
||||
);
|
||||
assert!(
|
||||
namespace_child_tool(&body, "mcp__snapshot__", "echo").is_some(),
|
||||
"first request should expose inherited MCP snapshot tools: {body:#}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn record_initial_history_forked_hydrates_previous_turn_settings() {
|
||||
let (session, turn_context) = make_session_and_context().await;
|
||||
@@ -3487,6 +3609,7 @@ async fn session_new_fails_when_zsh_fork_enabled_without_zsh_path() {
|
||||
Arc::new(SkillsWatcher::noop()),
|
||||
AgentControl::default(),
|
||||
Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()),
|
||||
Default::default(),
|
||||
/*analytics_events_client*/ None,
|
||||
Arc::new(codex_thread_store::LocalThreadStore::new(
|
||||
codex_thread_store::LocalThreadStoreConfig::from_config(config.as_ref()),
|
||||
@@ -3599,6 +3722,7 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) {
|
||||
&config.permissions.approval_policy,
|
||||
&config.permissions.permission_profile,
|
||||
))),
|
||||
mcp_tool_snapshot: Mutex::new(None),
|
||||
mcp_startup_cancellation_token: Mutex::new(CancellationToken::new()),
|
||||
unified_exec_manager: UnifiedExecProcessManager::new(
|
||||
config.background_terminal_max_timeout,
|
||||
@@ -3642,6 +3766,7 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) {
|
||||
Some(auth_manager.clone()),
|
||||
conversation_id,
|
||||
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
|
||||
/*prompt_cache_key_override*/ None,
|
||||
session_configuration.provider.clone(),
|
||||
session_configuration.session_source.clone(),
|
||||
config.model_verbosity,
|
||||
@@ -3721,6 +3846,13 @@ async fn make_session_with_config(
|
||||
|
||||
async fn make_session_with_config_and_rx(
|
||||
mutator: impl FnOnce(&mut Config),
|
||||
) -> anyhow::Result<(Arc<Session>, async_channel::Receiver<Event>)> {
|
||||
make_session_with_config_inherited_and_rx(mutator, Default::default()).await
|
||||
}
|
||||
|
||||
async fn make_session_with_config_inherited_and_rx(
|
||||
mutator: impl FnOnce(&mut Config),
|
||||
inherited_thread_state: crate::inherited_thread_state::InheritedThreadState,
|
||||
) -> anyhow::Result<(Arc<Session>, async_channel::Receiver<Event>)> {
|
||||
let codex_home = tempfile::tempdir().expect("create temp dir");
|
||||
let mut config = build_test_config(codex_home.path()).await;
|
||||
@@ -3805,6 +3937,7 @@ async fn make_session_with_config_and_rx(
|
||||
Arc::new(SkillsWatcher::noop()),
|
||||
AgentControl::default(),
|
||||
Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()),
|
||||
inherited_thread_state,
|
||||
/*analytics_events_client*/ None,
|
||||
Arc::new(codex_thread_store::LocalThreadStore::new(
|
||||
codex_thread_store::LocalThreadStoreConfig::from_config(config.as_ref()),
|
||||
@@ -5081,6 +5214,7 @@ where
|
||||
&config.permissions.approval_policy,
|
||||
&config.permissions.permission_profile,
|
||||
))),
|
||||
mcp_tool_snapshot: Mutex::new(None),
|
||||
mcp_startup_cancellation_token: Mutex::new(CancellationToken::new()),
|
||||
unified_exec_manager: UnifiedExecProcessManager::new(
|
||||
config.background_terminal_max_timeout,
|
||||
@@ -5124,6 +5258,7 @@ where
|
||||
Some(Arc::clone(&auth_manager)),
|
||||
conversation_id,
|
||||
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
|
||||
/*prompt_cache_key_override*/ None,
|
||||
session_configuration.provider.clone(),
|
||||
session_configuration.session_source.clone(),
|
||||
config.model_verbosity,
|
||||
|
||||
@@ -753,6 +753,7 @@ async fn guardian_subagent_does_not_inherit_parent_exec_policy_rules() {
|
||||
inherited_shell_snapshot: None,
|
||||
inherited_exec_policy: Some(Arc::new(parent_exec_policy)),
|
||||
parent_rollout_thread_trace: codex_rollout_trace::ThreadTraceContext::disabled(),
|
||||
inherited_thread_state: Default::default(),
|
||||
user_shell_override: None,
|
||||
parent_trace: None,
|
||||
environment_selections: ResolvedTurnEnvironments {
|
||||
|
||||
@@ -1122,13 +1122,18 @@ pub(crate) async fn built_tools(
|
||||
skills_outcome: Option<&SkillLoadOutcome>,
|
||||
cancellation_token: &CancellationToken,
|
||||
) -> CodexResult<Arc<ToolRouter>> {
|
||||
let mcp_connection_manager = sess.services.mcp_connection_manager.read().await;
|
||||
let has_mcp_servers = mcp_connection_manager.has_servers();
|
||||
let all_mcp_tools = mcp_connection_manager
|
||||
.list_all_tools()
|
||||
.or_cancel(cancellation_token)
|
||||
.await?;
|
||||
drop(mcp_connection_manager);
|
||||
let inherited_mcp_tools = sess.services.mcp_tool_snapshot.lock().await.clone();
|
||||
let (has_mcp_servers, all_mcp_tools) = if let Some(snapshot) = inherited_mcp_tools {
|
||||
(!snapshot.tools.is_empty(), snapshot.tools)
|
||||
} else {
|
||||
let mcp_connection_manager = sess.services.mcp_connection_manager.read().await;
|
||||
let has_mcp_servers = mcp_connection_manager.has_servers();
|
||||
let all_mcp_tools = mcp_connection_manager
|
||||
.list_all_tools()
|
||||
.or_cancel(cancellation_token)
|
||||
.await?;
|
||||
(has_mcp_servers, all_mcp_tools)
|
||||
};
|
||||
let loaded_plugins = sess
|
||||
.services
|
||||
.plugins_manager
|
||||
|
||||
@@ -2,6 +2,7 @@ mod service;
|
||||
mod session;
|
||||
mod turn;
|
||||
|
||||
pub(crate) use service::McpToolSnapshot;
|
||||
pub(crate) use service::SessionServices;
|
||||
pub(crate) use session::SessionState;
|
||||
pub(crate) use turn::ActiveTurn;
|
||||
|
||||
@@ -21,6 +21,7 @@ use codex_exec_server::EnvironmentManager;
|
||||
use codex_hooks::Hooks;
|
||||
use codex_login::AuthManager;
|
||||
use codex_mcp::McpConnectionManager;
|
||||
use codex_mcp::ToolInfo as McpToolInfo;
|
||||
use codex_models_manager::manager::SharedModelsManager;
|
||||
use codex_otel::SessionTelemetry;
|
||||
use codex_rollout::state_db::StateDbHandle;
|
||||
@@ -34,8 +35,14 @@ use tokio::sync::RwLock;
|
||||
use tokio::sync::watch;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
/// A captured set of MCP tool definitions held by a session.
///
/// When `SessionServices::mcp_tool_snapshot` holds one, `built_tools` uses these tools
/// instead of listing tools from the MCP connection manager.
#[derive(Clone, Default)]
|
||||
pub(crate) struct McpToolSnapshot {
|
||||
/// Tool definitions keyed by string.
/// NOTE(review): keys look like fully qualified tool names (tests use
/// `mcp__snapshot__echo`) — confirm against the producer of the snapshot.
pub(crate) tools: HashMap<String, McpToolInfo>,
|
||||
}
|
||||
|
||||
pub(crate) struct SessionServices {
|
||||
pub(crate) mcp_connection_manager: Arc<RwLock<McpConnectionManager>>,
|
||||
pub(crate) mcp_tool_snapshot: Mutex<Option<McpToolSnapshot>>,
|
||||
pub(crate) mcp_startup_cancellation_token: Mutex<CancellationToken>,
|
||||
pub(crate) unified_exec_manager: UnifiedExecProcessManager,
|
||||
#[cfg_attr(not(unix), allow(dead_code))]
|
||||
|
||||
@@ -6,6 +6,7 @@ use crate::config::ThreadStoreConfig;
|
||||
use crate::environment_selection::default_thread_environment_selections;
|
||||
use crate::environment_selection::resolve_environment_selections;
|
||||
use crate::file_watcher::FileWatcher;
|
||||
use crate::inherited_thread_state::InheritedThreadState;
|
||||
use crate::mcp::McpManager;
|
||||
use crate::rollout::RolloutRecorder;
|
||||
use crate::rollout::truncation;
|
||||
@@ -230,6 +231,7 @@ pub(crate) struct ResumeThreadWithHistoryOptions {
|
||||
pub(crate) session_source: SessionSource,
|
||||
pub(crate) inherited_shell_snapshot: Option<Arc<ShellSnapshot>>,
|
||||
pub(crate) inherited_exec_policy: Option<Arc<crate::exec_policy::ExecPolicyManager>>,
|
||||
pub(crate) inherited_thread_state: InheritedThreadState,
|
||||
}
|
||||
|
||||
/// Shared, `Arc`-owned state for [`ThreadManager`]. This `Arc` is required to have a single
|
||||
@@ -583,6 +585,7 @@ impl ThreadManager {
|
||||
options.metrics_service_name,
|
||||
/*inherited_shell_snapshot*/ None,
|
||||
/*inherited_exec_policy*/ None,
|
||||
Default::default(),
|
||||
options.parent_trace,
|
||||
options.environments,
|
||||
/*user_shell_override*/ None,
|
||||
@@ -924,6 +927,7 @@ impl ThreadManagerState {
|
||||
/*inherited_shell_snapshot*/ None,
|
||||
/*inherited_exec_policy*/ None,
|
||||
/*environments*/ None,
|
||||
Default::default(),
|
||||
))
|
||||
.await
|
||||
}
|
||||
@@ -939,6 +943,7 @@ impl ThreadManagerState {
|
||||
inherited_shell_snapshot: Option<Arc<ShellSnapshot>>,
|
||||
inherited_exec_policy: Option<Arc<crate::exec_policy::ExecPolicyManager>>,
|
||||
environments: Option<Vec<TurnEnvironmentSelection>>,
|
||||
inherited_thread_state: InheritedThreadState,
|
||||
) -> CodexResult<NewThread> {
|
||||
let environments = environments.unwrap_or_else(|| {
|
||||
default_thread_environment_selections(self.environment_manager.as_ref(), &config.cwd)
|
||||
@@ -954,6 +959,7 @@ impl ThreadManagerState {
|
||||
metrics_service_name,
|
||||
inherited_shell_snapshot,
|
||||
inherited_exec_policy,
|
||||
inherited_thread_state,
|
||||
/*parent_trace*/ None,
|
||||
environments,
|
||||
/*user_shell_override*/ None,
|
||||
@@ -972,6 +978,7 @@ impl ThreadManagerState {
|
||||
session_source,
|
||||
inherited_shell_snapshot,
|
||||
inherited_exec_policy,
|
||||
inherited_thread_state,
|
||||
} = options;
|
||||
let environments =
|
||||
default_thread_environment_selections(self.environment_manager.as_ref(), &config.cwd);
|
||||
@@ -986,6 +993,7 @@ impl ThreadManagerState {
|
||||
/*metrics_service_name*/ None,
|
||||
inherited_shell_snapshot,
|
||||
inherited_exec_policy,
|
||||
inherited_thread_state,
|
||||
/*parent_trace*/ None,
|
||||
environments,
|
||||
/*user_shell_override*/ None,
|
||||
@@ -1004,6 +1012,7 @@ impl ThreadManagerState {
|
||||
inherited_shell_snapshot: Option<Arc<ShellSnapshot>>,
|
||||
inherited_exec_policy: Option<Arc<crate::exec_policy::ExecPolicyManager>>,
|
||||
environments: Option<Vec<TurnEnvironmentSelection>>,
|
||||
inherited_thread_state: InheritedThreadState,
|
||||
) -> CodexResult<NewThread> {
|
||||
let environments = environments.unwrap_or_else(|| {
|
||||
default_thread_environment_selections(self.environment_manager.as_ref(), &config.cwd)
|
||||
@@ -1019,6 +1028,7 @@ impl ThreadManagerState {
|
||||
/*metrics_service_name*/ None,
|
||||
inherited_shell_snapshot,
|
||||
inherited_exec_policy,
|
||||
inherited_thread_state,
|
||||
/*parent_trace*/ None,
|
||||
environments,
|
||||
/*user_shell_override*/ None,
|
||||
@@ -1052,6 +1062,7 @@ impl ThreadManagerState {
|
||||
metrics_service_name,
|
||||
/*inherited_shell_snapshot*/ None,
|
||||
/*inherited_exec_policy*/ None,
|
||||
Default::default(),
|
||||
parent_trace,
|
||||
environments,
|
||||
user_shell_override,
|
||||
@@ -1072,6 +1083,7 @@ impl ThreadManagerState {
|
||||
metrics_service_name: Option<String>,
|
||||
inherited_shell_snapshot: Option<Arc<ShellSnapshot>>,
|
||||
inherited_exec_policy: Option<Arc<crate::exec_policy::ExecPolicyManager>>,
|
||||
inherited_thread_state: InheritedThreadState,
|
||||
parent_trace: Option<W3cTraceContext>,
|
||||
environments: Vec<TurnEnvironmentSelection>,
|
||||
user_shell_override: Option<crate::shell::Shell>,
|
||||
@@ -1137,6 +1149,7 @@ impl ThreadManagerState {
|
||||
inherited_shell_snapshot,
|
||||
inherited_exec_policy,
|
||||
parent_rollout_thread_trace,
|
||||
inherited_thread_state,
|
||||
user_shell_override,
|
||||
parent_trace,
|
||||
environment_selections,
|
||||
|
||||
@@ -102,6 +102,7 @@ async fn responses_stream_includes_subagent_header_on_review() {
|
||||
/*auth_manager*/ None,
|
||||
conversation_id,
|
||||
/*installation_id*/ TEST_INSTALLATION_ID.to_string(),
|
||||
/*prompt_cache_key_override*/ None,
|
||||
provider.clone(),
|
||||
session_source,
|
||||
config.model_verbosity,
|
||||
@@ -228,6 +229,7 @@ async fn responses_stream_includes_subagent_header_on_other() {
|
||||
/*auth_manager*/ None,
|
||||
conversation_id,
|
||||
/*installation_id*/ TEST_INSTALLATION_ID.to_string(),
|
||||
/*prompt_cache_key_override*/ None,
|
||||
provider.clone(),
|
||||
session_source,
|
||||
config.model_verbosity,
|
||||
@@ -343,6 +345,7 @@ async fn responses_respects_model_info_overrides_from_config() {
|
||||
/*auth_manager*/ None,
|
||||
conversation_id,
|
||||
/*installation_id*/ TEST_INSTALLATION_ID.to_string(),
|
||||
/*prompt_cache_key_override*/ None,
|
||||
provider.clone(),
|
||||
session_source,
|
||||
config.model_verbosity,
|
||||
|
||||
@@ -884,6 +884,7 @@ async fn send_provider_auth_request(server: &MockServer, auth: ModelProviderAuth
|
||||
))),
|
||||
conversation_id,
|
||||
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
|
||||
/*prompt_cache_key_override*/ None,
|
||||
provider,
|
||||
SessionSource::Exec,
|
||||
config.model_verbosity,
|
||||
@@ -2288,6 +2289,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
|
||||
/*auth_manager*/ None,
|
||||
conversation_id,
|
||||
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
|
||||
/*prompt_cache_key_override*/ None,
|
||||
provider.clone(),
|
||||
SessionSource::Exec,
|
||||
config.model_verbosity,
|
||||
|
||||
@@ -1855,6 +1855,7 @@ async fn websocket_harness_with_provider_options(
|
||||
/*auth_manager*/ None,
|
||||
conversation_id,
|
||||
/*installation_id*/ TEST_INSTALLATION_ID.to_string(),
|
||||
/*prompt_cache_key_override*/ None,
|
||||
provider.clone(),
|
||||
SessionSource::Exec,
|
||||
config.model_verbosity,
|
||||
|
||||
@@ -16,6 +16,12 @@ pub struct Cli {
|
||||
#[command(subcommand)]
|
||||
pub command: Option<Command>,
|
||||
|
||||
/// Fork from an existing session id (or thread name) before sending the prompt.
|
||||
///
|
||||
/// This creates a new session with copied history, similar to `codex fork`.
|
||||
#[arg(long = "fork", value_name = "SESSION_ID")]
|
||||
pub fork_session_id: Option<String>,
|
||||
|
||||
#[clap(flatten)]
|
||||
pub shared: ExecSharedCliOptions,
|
||||
|
||||
@@ -81,6 +87,19 @@ pub struct Cli {
|
||||
pub prompt: Option<String>,
|
||||
}
|
||||
|
||||
impl Cli {
|
||||
pub fn validate(self) -> Result<Self, clap::Error> {
|
||||
if self.fork_session_id.is_some() && self.command.is_some() {
|
||||
return Err(clap::Error::raw(
|
||||
clap::error::ErrorKind::ArgumentConflict,
|
||||
"--fork cannot be used with subcommands",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for Cli {
|
||||
type Target = SharedCliOptions;
|
||||
|
||||
@@ -156,7 +175,6 @@ fn mark_exec_global_args(cmd: clap::Command) -> clap::Command {
|
||||
arg.global(true)
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Subcommand)]
|
||||
pub enum Command {
|
||||
/// Resume a previous session by id or pick the most recent with --last.
|
||||
|
||||
@@ -80,3 +80,22 @@ fn removed_full_auto_flag_reports_migration_path() {
|
||||
Some("warning: `--full-auto` is deprecated; use `--sandbox workspace-write` instead.")
|
||||
);
|
||||
}
|
||||
|
||||
/// `--fork <id>` followed by other flags and a positional prompt parses without a subcommand.
#[test]
fn fork_option_parses_prompt() {
    const PROMPT: &str = "echo fork-non-interactive";
    let argv = ["codex-exec", "--fork", "session-123", "--json", PROMPT];

    let parsed = Cli::parse_from(argv);

    let fork_id = parsed.fork_session_id.as_deref();
    assert_eq!(fork_id, Some("session-123"));
    let prompt = parsed.prompt.as_deref();
    assert_eq!(prompt, Some(PROMPT));
    assert!(parsed.command.is_none());
}
|
||||
|
||||
/// Combining `--fork` with a subcommand must fail with an argument-conflict error,
/// either at parse time or in `Cli::validate`.
#[test]
fn fork_option_conflicts_with_subcommands() {
    let argv = ["codex-exec", "--fork", "session-123", "resume"];

    let outcome = Cli::try_parse_from(argv).and_then(|cli| cli.validate());
    let err = outcome.expect_err("fork should conflict with subcommands");

    assert_eq!(err.kind(), clap::error::ErrorKind::ArgumentConflict);
}
|
||||
|
||||
@@ -34,6 +34,8 @@ use codex_app_server_protocol::ReviewTarget as ApiReviewTarget;
|
||||
use codex_app_server_protocol::ServerNotification;
|
||||
use codex_app_server_protocol::ServerRequest;
|
||||
use codex_app_server_protocol::Thread as AppServerThread;
|
||||
use codex_app_server_protocol::ThreadForkParams;
|
||||
use codex_app_server_protocol::ThreadForkResponse;
|
||||
use codex_app_server_protocol::ThreadItem as AppServerThreadItem;
|
||||
use codex_app_server_protocol::ThreadListParams;
|
||||
use codex_app_server_protocol::ThreadListResponse;
|
||||
@@ -198,6 +200,7 @@ struct ExecRunArgs {
|
||||
config: Config,
|
||||
dangerously_bypass_approvals_and_sandbox: bool,
|
||||
exec_span: tracing::Span,
|
||||
fork_session_id: Option<String>,
|
||||
images: Vec<PathBuf>,
|
||||
json_mode: bool,
|
||||
last_message_file: Option<PathBuf>,
|
||||
@@ -230,6 +233,7 @@ pub async fn run_main(cli: Cli, arg0_paths: Arg0DispatchPaths) -> anyhow::Result
|
||||
|
||||
let Cli {
|
||||
command,
|
||||
fork_session_id,
|
||||
shared,
|
||||
skip_git_repo_check,
|
||||
ephemeral,
|
||||
@@ -529,6 +533,7 @@ pub async fn run_main(cli: Cli, arg0_paths: Arg0DispatchPaths) -> anyhow::Result
|
||||
config,
|
||||
dangerously_bypass_approvals_and_sandbox,
|
||||
exec_span: exec_span.clone(),
|
||||
fork_session_id,
|
||||
images,
|
||||
json_mode,
|
||||
last_message_file,
|
||||
@@ -550,6 +555,7 @@ async fn run_exec_session(args: ExecRunArgs) -> anyhow::Result<()> {
|
||||
config,
|
||||
dangerously_bypass_approvals_and_sandbox,
|
||||
exec_span,
|
||||
fork_session_id,
|
||||
images,
|
||||
json_mode,
|
||||
last_message_file,
|
||||
@@ -667,58 +673,78 @@ async fn run_exec_session(args: ExecRunArgs) -> anyhow::Result<()> {
|
||||
anyhow::anyhow!("failed to initialize in-process app-server client: {err}")
|
||||
})?;
|
||||
|
||||
// Handle resume subcommand through existing `thread/list` + `thread/resume`
|
||||
// APIs so exec no longer reaches into rollout storage directly.
|
||||
let (primary_thread_id, fallback_session_configured) = if let Some(ExecCommand::Resume(args)) =
|
||||
command.as_ref()
|
||||
{
|
||||
if let Some(thread_id) = resolve_resume_thread_id(&client, &config, args).await? {
|
||||
let response: ThreadResumeResponse = send_request_with_response(
|
||||
&client,
|
||||
ClientRequest::ThreadResume {
|
||||
request_id: request_ids.next(),
|
||||
params: thread_resume_params_from_config(&config, thread_id),
|
||||
},
|
||||
"thread/resume",
|
||||
)
|
||||
.await
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
let session_configured =
|
||||
session_configured_from_thread_resume_response(&response, &config)
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
(session_configured.session_id, session_configured)
|
||||
} else {
|
||||
let response: ThreadStartResponse = send_request_with_response(
|
||||
&client,
|
||||
ClientRequest::ThreadStart {
|
||||
request_id: request_ids.next(),
|
||||
params: thread_start_params_from_config(&config),
|
||||
},
|
||||
"thread/start",
|
||||
)
|
||||
.await
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
let session_configured =
|
||||
session_configured_from_thread_start_response(&response, &config)
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
(session_configured.session_id, session_configured)
|
||||
// Handle resume/fork/start through app-server APIs so exec no longer reaches into
|
||||
// rollout storage directly for normal bootstrap.
|
||||
let (primary_thread_id, fallback_session_configured) = match command.as_ref() {
|
||||
Some(ExecCommand::Resume(args)) => {
|
||||
if let Some(thread_id) = resolve_resume_thread_id(&client, &config, args).await? {
|
||||
let response: ThreadResumeResponse = send_request_with_response(
|
||||
&client,
|
||||
ClientRequest::ThreadResume {
|
||||
request_id: request_ids.next(),
|
||||
params: thread_resume_params_from_config(&config, thread_id),
|
||||
},
|
||||
"thread/resume",
|
||||
)
|
||||
.await
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
let session_configured =
|
||||
session_configured_from_thread_resume_response(&response, &config)
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
(session_configured.session_id, session_configured)
|
||||
} else {
|
||||
let response: ThreadStartResponse = send_request_with_response(
|
||||
&client,
|
||||
ClientRequest::ThreadStart {
|
||||
request_id: request_ids.next(),
|
||||
params: thread_start_params_from_config(&config),
|
||||
},
|
||||
"thread/start",
|
||||
)
|
||||
.await
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
let session_configured =
|
||||
session_configured_from_thread_start_response(&response, &config)
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
(session_configured.session_id, session_configured)
|
||||
}
|
||||
}
|
||||
Some(ExecCommand::Review(_)) | None => {
|
||||
if let Some(session_id) = fork_session_id.as_deref() {
|
||||
let response: ThreadForkResponse = send_request_with_response(
|
||||
&client,
|
||||
ClientRequest::ThreadFork {
|
||||
request_id: request_ids.next(),
|
||||
params: thread_fork_params_from_config(
|
||||
&config, session_id, /*path*/ None,
|
||||
),
|
||||
},
|
||||
"thread/fork",
|
||||
)
|
||||
.await
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
let session_configured =
|
||||
session_configured_from_thread_fork_response(&response, &config)
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
(session_configured.session_id, session_configured)
|
||||
} else {
|
||||
let response: ThreadStartResponse = send_request_with_response(
|
||||
&client,
|
||||
ClientRequest::ThreadStart {
|
||||
request_id: request_ids.next(),
|
||||
params: thread_start_params_from_config(&config),
|
||||
},
|
||||
"thread/start",
|
||||
)
|
||||
.await
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
let session_configured =
|
||||
session_configured_from_thread_start_response(&response, &config)
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
(session_configured.session_id, session_configured)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let response: ThreadStartResponse = send_request_with_response(
|
||||
&client,
|
||||
ClientRequest::ThreadStart {
|
||||
request_id: request_ids.next(),
|
||||
params: thread_start_params_from_config(&config),
|
||||
},
|
||||
"thread/start",
|
||||
)
|
||||
.await
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
let session_configured = session_configured_from_thread_start_response(&response, &config)
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
(session_configured.session_id, session_configured)
|
||||
};
|
||||
|
||||
let primary_thread_id_for_span = primary_thread_id.to_string();
|
||||
// Use the start/resume response as the authoritative bootstrap payload.
|
||||
// Waiting for a later streamed `SessionConfigured` event adds up to 10s of
|
||||
@@ -1028,6 +1054,33 @@ fn approvals_reviewer_override_from_config(
|
||||
Some(config.approvals_reviewer.into())
|
||||
}
|
||||
|
||||
fn thread_fork_params_from_config(
|
||||
config: &Config,
|
||||
thread_id: &str,
|
||||
path: Option<PathBuf>,
|
||||
) -> ThreadForkParams {
|
||||
let permissions = permissions_selection_from_config(config);
|
||||
let sandbox = permissions.is_none().then(|| {
|
||||
sandbox_mode_from_permission_profile(
|
||||
&config.permissions.permission_profile(),
|
||||
config.cwd.as_path(),
|
||||
)
|
||||
});
|
||||
ThreadForkParams {
|
||||
thread_id: thread_id.to_string(),
|
||||
path,
|
||||
model: config.model.clone(),
|
||||
model_provider: Some(config.model_provider_id.clone()),
|
||||
cwd: Some(config.cwd.to_string_lossy().to_string()),
|
||||
approval_policy: Some(config.permissions.approval_policy.value().into()),
|
||||
approvals_reviewer: approvals_reviewer_override_from_config(config),
|
||||
sandbox: sandbox.flatten(),
|
||||
permissions,
|
||||
config: config_request_overrides_from_config(config),
|
||||
..ThreadForkParams::default()
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_request_with_response<T>(
|
||||
client: &InProcessAppServerClient,
|
||||
request: ClientRequest,
|
||||
@@ -1093,6 +1146,30 @@ fn session_configured_from_thread_resume_response(
|
||||
)
|
||||
}
|
||||
|
||||
fn session_configured_from_thread_fork_response(
|
||||
response: &ThreadForkResponse,
|
||||
config: &Config,
|
||||
) -> Result<SessionConfiguredEvent, String> {
|
||||
session_configured_from_thread_response(
|
||||
&response.thread.id,
|
||||
response.thread.name.clone(),
|
||||
response.thread.path.clone(),
|
||||
response.model.clone(),
|
||||
response.model_provider.clone(),
|
||||
response.service_tier,
|
||||
response.approval_policy.to_core(),
|
||||
response.approvals_reviewer.to_core(),
|
||||
response
|
||||
.permission_profile
|
||||
.clone()
|
||||
.map(Into::into)
|
||||
.unwrap_or_else(|| config.permissions.permission_profile()),
|
||||
response.active_permission_profile.clone().map(Into::into),
|
||||
response.cwd.clone(),
|
||||
response.reasoning_effort,
|
||||
)
|
||||
}
|
||||
|
||||
fn review_target_to_api(target: ReviewTarget) -> ApiReviewTarget {
|
||||
match target {
|
||||
ReviewTarget::UncommittedChanges => ApiReviewTarget::UncommittedChanges,
|
||||
|
||||
@@ -408,6 +408,11 @@ async fn thread_lifecycle_params_include_legacy_sandbox_when_no_active_profile()
|
||||
|
||||
let start_params = thread_start_params_from_config(&config);
|
||||
let resume_params = thread_resume_params_from_config(&config, "thread-id".to_string());
|
||||
let fork_params = thread_fork_params_from_config(
|
||||
&config,
|
||||
"67e55044-10b1-426f-9247-bb680e5fe0c8",
|
||||
/*path*/ None,
|
||||
);
|
||||
|
||||
assert_eq!(config.permissions.active_permission_profile(), None);
|
||||
assert_eq!(
|
||||
@@ -420,6 +425,43 @@ async fn thread_lifecycle_params_include_legacy_sandbox_when_no_active_profile()
|
||||
Some(codex_app_server_protocol::SandboxMode::DangerFullAccess)
|
||||
);
|
||||
assert_eq!(resume_params.permissions, None);
|
||||
assert_eq!(
|
||||
fork_params.sandbox,
|
||||
Some(codex_app_server_protocol::SandboxMode::DangerFullAccess)
|
||||
);
|
||||
assert_eq!(fork_params.permissions, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn thread_fork_params_include_review_policy_when_auto_review_is_enabled() {
|
||||
let codex_home = tempdir().expect("create temp codex home");
|
||||
let cwd = tempdir().expect("create temp cwd");
|
||||
let config = ConfigBuilder::default()
|
||||
.codex_home(codex_home.path().to_path_buf())
|
||||
.harness_overrides(ConfigOverrides {
|
||||
approvals_reviewer: Some(ApprovalsReviewer::AutoReview),
|
||||
..Default::default()
|
||||
})
|
||||
.fallback_cwd(Some(cwd.path().to_path_buf()))
|
||||
.build()
|
||||
.await
|
||||
.expect("build config for fork params");
|
||||
|
||||
let params = thread_fork_params_from_config(
|
||||
&config,
|
||||
"67e55044-10b1-426f-9247-bb680e5fe0c8",
|
||||
/*path*/ None,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
params.approvals_reviewer,
|
||||
Some(codex_app_server_protocol::ApprovalsReviewer::AutoReview)
|
||||
);
|
||||
assert_eq!(params.sandbox, None);
|
||||
assert_eq!(
|
||||
params.permissions,
|
||||
permissions_selection_from_config(&config)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -498,3 +540,58 @@ fn sample_thread_start_response() -> ThreadStartResponse {
|
||||
reasoning_effort: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn session_configured_from_thread_fork_response_preserves_permission_profile() {
|
||||
let codex_home = tempdir().expect("create temp codex home");
|
||||
let cwd = tempdir().expect("create temp cwd");
|
||||
let config = ConfigBuilder::default()
|
||||
.codex_home(codex_home.path().to_path_buf())
|
||||
.fallback_cwd(Some(cwd.path().to_path_buf()))
|
||||
.build()
|
||||
.await
|
||||
.expect("build config");
|
||||
let permission_profile = PermissionProfile::Disabled;
|
||||
let response = ThreadForkResponse {
|
||||
thread: codex_app_server_protocol::Thread {
|
||||
id: "67e55044-10b1-426f-9247-bb680e5fe0c8".to_string(),
|
||||
forked_from_id: Some("f6f10963-370f-4f42-8f3b-bb680e5fe0c8".to_string()),
|
||||
preview: String::new(),
|
||||
ephemeral: false,
|
||||
model_provider: "openai".to_string(),
|
||||
created_at: 0,
|
||||
updated_at: 0,
|
||||
status: codex_app_server_protocol::ThreadStatus::Idle,
|
||||
path: Some(PathBuf::from("/tmp/fork-rollout.jsonl")),
|
||||
cwd: test_path_buf("/tmp").abs(),
|
||||
cli_version: "0.0.0".to_string(),
|
||||
source: codex_app_server_protocol::SessionSource::Cli,
|
||||
agent_nickname: None,
|
||||
agent_role: None,
|
||||
git_info: None,
|
||||
name: Some("forked-thread".to_string()),
|
||||
turns: vec![],
|
||||
},
|
||||
model: "gpt-5.4".to_string(),
|
||||
model_provider: "openai".to_string(),
|
||||
service_tier: None,
|
||||
cwd: test_path_buf("/tmp").abs(),
|
||||
instruction_sources: Vec::new(),
|
||||
approval_policy: codex_app_server_protocol::AskForApproval::OnRequest,
|
||||
approvals_reviewer: codex_app_server_protocol::ApprovalsReviewer::AutoReview,
|
||||
sandbox: codex_app_server_protocol::SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![],
|
||||
network_access: false,
|
||||
exclude_tmpdir_env_var: false,
|
||||
exclude_slash_tmp: false,
|
||||
},
|
||||
permission_profile: Some(permission_profile.clone().into()),
|
||||
active_permission_profile: None,
|
||||
reasoning_effort: None,
|
||||
};
|
||||
|
||||
let event = session_configured_from_thread_fork_response(&response, &config)
|
||||
.expect("build fork session configured event");
|
||||
|
||||
assert_eq!(event.permission_profile, permission_profile);
|
||||
}
|
||||
|
||||
@@ -29,7 +29,10 @@ fn main() -> anyhow::Result<()> {
|
||||
arg0_dispatch_or_else(|arg0_paths: Arg0DispatchPaths| async move {
|
||||
let top_cli = TopCli::parse();
|
||||
// Merge root-level overrides into inner CLI struct so downstream logic remains unchanged.
|
||||
let mut inner = top_cli.inner;
|
||||
let mut inner = match top_cli.inner.validate() {
|
||||
Ok(inner) => inner,
|
||||
Err(err) => err.exit(),
|
||||
};
|
||||
inner
|
||||
.config_overrides
|
||||
.raw_overrides
|
||||
|
||||
@@ -35,3 +35,24 @@ fn top_cli_parses_resume_prompt_after_config_flag() {
|
||||
"reasoning_level=xhigh"
|
||||
);
|
||||
}
|
||||
|
||||
/// `--fork <id>` combined with a root-level `--config` override must parse
/// into the fork session id, the prompt, and exactly one raw override.
#[test]
fn top_cli_parses_fork_option_with_root_config() {
    let argv = [
        "codex-exec",
        "--config",
        "reasoning_level=xhigh",
        "--fork",
        "session-123",
        "echo fork",
    ];
    let cli = TopCli::parse_from(argv);

    let exec = &cli.inner;
    assert_eq!(exec.fork_session_id.as_deref(), Some("session-123"));
    assert!(exec.command.is_none());
    assert_eq!(exec.prompt.as_deref(), Some("echo fork"));

    let raw_overrides = &cli.config_overrides.raw_overrides;
    assert_eq!(raw_overrides.len(), 1);
    assert_eq!(raw_overrides[0], "reasoning_level=xhigh");
}
|
||||
|
||||
149
codex-rs/exec/tests/suite/fork.rs
Normal file
149
codex-rs/exec/tests/suite/fork.rs
Normal file
@@ -0,0 +1,149 @@
|
||||
#![allow(clippy::unwrap_used, clippy::expect_used)]
|
||||
|
||||
use anyhow::Context;
|
||||
use codex_utils_cargo_bin::find_resource;
|
||||
use core_test_support::test_codex_exec::test_codex_exec;
|
||||
use serde_json::Value;
|
||||
use std::string::ToString;
|
||||
use uuid::Uuid;
|
||||
use walkdir::WalkDir;
|
||||
|
||||
/// Utility: scan the sessions dir for a rollout file that contains `marker`
|
||||
/// in any response_item.message.content entry. Returns the absolute path.
|
||||
fn find_session_file_containing_marker(
|
||||
sessions_dir: &std::path::Path,
|
||||
marker: &str,
|
||||
) -> Option<std::path::PathBuf> {
|
||||
for entry in WalkDir::new(sessions_dir) {
|
||||
let entry = match entry {
|
||||
Ok(e) => e,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if !entry.file_type().is_file() {
|
||||
continue;
|
||||
}
|
||||
if !entry.file_name().to_string_lossy().ends_with(".jsonl") {
|
||||
continue;
|
||||
}
|
||||
let path = entry.path();
|
||||
let Ok(content) = std::fs::read_to_string(path) else {
|
||||
continue;
|
||||
};
|
||||
// Skip the first meta line and scan remaining JSONL entries.
|
||||
let mut lines = content.lines();
|
||||
if lines.next().is_none() {
|
||||
continue;
|
||||
}
|
||||
for line in lines {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let Ok(item): Result<Value, _> = serde_json::from_str(line) else {
|
||||
continue;
|
||||
};
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("response_item")
|
||||
&& let Some(payload) = item.get("payload")
|
||||
&& payload.get("type").and_then(|t| t.as_str()) == Some("message")
|
||||
&& payload
|
||||
.get("content")
|
||||
.map(ToString::to_string)
|
||||
.unwrap_or_default()
|
||||
.contains(marker)
|
||||
{
|
||||
return Some(path.to_path_buf());
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Extract the conversation UUID from the first SessionMeta line in the rollout file.
|
||||
fn extract_conversation_id(path: &std::path::Path) -> String {
|
||||
let content = std::fs::read_to_string(path).unwrap();
|
||||
let mut lines = content.lines();
|
||||
let meta_line = lines.next().expect("missing meta line");
|
||||
let meta: Value = serde_json::from_str(meta_line).expect("invalid meta json");
|
||||
meta.get("payload")
|
||||
.and_then(|p| p.get("id"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default()
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn extract_forked_from_id(path: &std::path::Path) -> Option<String> {
|
||||
let content = std::fs::read_to_string(path).unwrap();
|
||||
let mut lines = content.lines();
|
||||
let meta_line = lines.next().expect("missing meta line");
|
||||
let meta: Value = serde_json::from_str(meta_line).expect("invalid meta json");
|
||||
meta.get("payload")
|
||||
.and_then(|payload| payload.get("forked_from_id"))
|
||||
.and_then(Value::as_str)
|
||||
.map(ToString::to_string)
|
||||
}
|
||||
|
||||
fn exec_fixture() -> anyhow::Result<std::path::PathBuf> {
|
||||
Ok(find_resource!("tests/fixtures/cli_responses_fixture.sse")?)
|
||||
}
|
||||
|
||||
/// Runs `codex exec` once to create a session, then runs it again with
/// `--fork <session id>` and checks that the fork is written to a new rollout
/// file that records its ancestor, contains both prompts, and leaves the
/// original rollout without the follow-up prompt.
#[test]
fn exec_fork_by_id_creates_new_session_with_copied_history() -> anyhow::Result<()> {
    let test = test_codex_exec();
    let fixture = exec_fixture()?;
    let sessions_dir = test.home_path().join("sessions");

    // First run: create the base session, tagged with a unique marker.
    let base_marker = format!("fork-base-{}", Uuid::new_v4());
    let base_prompt = format!("echo {base_marker}");
    test.cmd()
        .env("CODEX_RS_SSE_FIXTURE", &fixture)
        .env("OPENAI_BASE_URL", "http://unused.local")
        .arg("--skip-git-repo-check")
        .arg(&base_prompt)
        .assert()
        .success();

    let original_path = find_session_file_containing_marker(&sessions_dir, &base_marker)
        .context("no session file found after first run")?;
    let session_id = extract_conversation_id(&original_path);

    // Second run: fork the base session with a different unique marker.
    let follow_up_marker = format!("fork-follow-up-{}", Uuid::new_v4());
    let follow_up_prompt = format!("echo {follow_up_marker}");
    test.cmd()
        .env("CODEX_RS_SSE_FIXTURE", &fixture)
        .env("OPENAI_BASE_URL", "http://unused.local")
        .arg("--skip-git-repo-check")
        .arg("--fork")
        .arg(&session_id)
        .arg(&follow_up_prompt)
        .assert()
        .success();

    let forked_path = find_session_file_containing_marker(&sessions_dir, &follow_up_marker)
        .context("no forked session file found for second marker")?;

    assert_ne!(
        forked_path, original_path,
        "fork should create a new session file"
    );

    let forked_content = std::fs::read_to_string(&forked_path)?;
    let forked_from = extract_forked_from_id(&forked_path);
    assert_eq!(forked_from.as_deref(), Some(session_id.as_str()));
    assert!(
        forked_content.contains(&base_marker),
        "forked session should copy ancestor rollout history"
    );
    assert!(forked_content.contains(&follow_up_marker));

    let original_content = std::fs::read_to_string(&original_path)?;
    assert!(original_content.contains(&base_marker));
    assert!(
        !original_content.contains(&follow_up_marker),
        "original session should not receive the forked prompt"
    );

    Ok(())
}
|
||||
@@ -3,6 +3,7 @@ mod add_dir;
|
||||
mod apply_patch;
|
||||
mod auth_env;
|
||||
mod ephemeral;
|
||||
mod fork;
|
||||
mod mcp_required_exit;
|
||||
mod originator;
|
||||
mod output_schema;
|
||||
|
||||
@@ -175,6 +175,7 @@ impl MemoryStartupContext {
|
||||
Some(Arc::clone(&self.auth_manager)),
|
||||
self.thread_id,
|
||||
installation_id,
|
||||
/*prompt_cache_key_override*/ None,
|
||||
config.model_provider.clone(),
|
||||
session_source,
|
||||
config.model_verbosity,
|
||||
|
||||
Reference in New Issue
Block a user