Compare commits

...

11 Commits

Author         SHA1        Message                                                        Date
jimmyfraiture  8a7f75eeef  Just fix                                                       2025-09-29 13:08:17 +01:00
jimmyfraiture  6283dc42f8  Rename                                                         2025-09-29 12:58:54 +01:00
jimmyfraiture  0d340b1bec  P5                                                             2025-09-29 12:03:57 +01:00
jimmyfraiture  c9f6b5dffc  P4                                                             2025-09-29 11:06:44 +01:00
jimmyfraiture  2efe961ac1  P3                                                             2025-09-29 10:49:19 +01:00
jimmyfraiture  491ba05f71  P2                                                             2025-09-29 10:30:24 +01:00
jimmyfraiture  cd7e37c6b0  P1                                                             2025-09-29 09:48:56 +01:00
jimmyfraiture  3cdf35e198  Merge remote-tracking branch 'origin/main' into jif/sandbox-1  2025-09-26 15:50:40 +02:00
jimmyfraiture  caab5a19ee  Move some stuff around                                         2025-09-26 14:46:07 +02:00
jimmyfraiture  a29380cdff  Isolate apply patch adapter                                    2025-09-26 14:02:38 +02:00
jimmyfraiture  805de19381  V1                                                             2025-09-26 13:42:58 +02:00
86 changed files with 4479 additions and 3837 deletions

35
codex-rs/Cargo.lock generated
View File

@@ -573,6 +573,38 @@ version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9b18233253483ce2f65329a24072ec414db782531bdbb7d0bbc4bd2ce6b7e21"
[[package]]
name = "codex-agent"
version = "0.0.0"
dependencies = [
"anyhow",
"async-trait",
"base64",
"codex-apply-patch",
"codex-file-search",
"codex-protocol",
"core_test_support",
"libc",
"mcp-types",
"portable-pty",
"pretty_assertions",
"serde",
"serde_json",
"sha1",
"shlex",
"similar",
"tempfile",
"thiserror 2.0.16",
"time",
"tokio",
"tracing",
"tree-sitter",
"tree-sitter-bash",
"uuid",
"which",
"wildmatch",
]
[[package]]
name = "codex-ansi-escape"
version = "0.0.0"
@@ -677,6 +709,7 @@ dependencies = [
"base64",
"bytes",
"chrono",
"codex-agent",
"codex-apply-patch",
"codex-file-search",
"codex-mcp-client",
@@ -715,8 +748,6 @@ dependencies = [
"toml",
"toml_edit",
"tracing",
"tree-sitter",
"tree-sitter-bash",
"uuid",
"walkdir",
"which",

View File

@@ -1,5 +1,6 @@
[workspace]
members = [
"agent",
"ansi-escape",
"apply-patch",
"arg0",
@@ -33,6 +34,7 @@ edition = "2024"
[workspace.dependencies]
# Internal
codex-agent = { path = "agent" }
codex-ansi-escape = { path = "ansi-escape" }
codex-apply-patch = { path = "apply-patch" }
codex-arg0 = { path = "arg0" }

View File

@@ -97,6 +97,7 @@ The same setting can be persisted in `~/.codex/config.toml` via the top-level `s
This folder is the root of a Cargo workspace. It contains quite a bit of experimental code, but here are the key crates:
- [`core/`](./core) contains the business logic for Codex. Ultimately, we hope this to be a library crate that is generally useful for building other Rust/native applications that use Codex.
- [`docs/agent_runtime_baseline.md`](./docs/agent_runtime_baseline.md) documents the current agent runtime interfaces (`Codex`, `Session`, `SessionTask`) and links to the ongoing refactor plan in `agent_refactor.md`.
- [`exec/`](./exec) "headless" CLI for use in automation.
- [`tui/`](./tui) CLI that launches a fullscreen TUI built with [Ratatui](https://ratatui.rs/).
- [`cli/`](./cli) CLI multitool that provides the aforementioned CLIs via subcommands.

37
codex-rs/agent/Cargo.toml Normal file
View File

@@ -0,0 +1,37 @@
[package]
name = "codex-agent"
version.workspace = true
edition.workspace = true
[dependencies]
anyhow = { workspace = true }
async-trait = { workspace = true }
codex-protocol = { workspace = true }
codex-apply-patch = { workspace = true }
mcp-types = { workspace = true }
base64 = { workspace = true }
serde_json = { workspace = true }
libc = { workspace = true }
portable-pty = { workspace = true }
serde = { workspace = true, features = ["derive"] }
sha1 = { workspace = true }
shlex = { workspace = true }
similar = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["macros", "process", "rt-multi-thread", "sync", "time"] }
uuid = { workspace = true, features = ["serde", "v4"] }
which = { workspace = true }
wildmatch = { workspace = true }
codex-file-search = { workspace = true }
time = { workspace = true, features = ["formatting", "parsing", "local-offset", "macros"] }
tracing = { workspace = true }
tree-sitter = { workspace = true }
tree-sitter-bash = { workspace = true }
[dev-dependencies]
core_test_support = { workspace = true }
tempfile = { workspace = true }
pretty_assertions = { workspace = true }
[lints]
workspace = true

View File

@@ -1,18 +1,22 @@
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::function_tool::FunctionCallError;
use crate::protocol::FileChange;
use crate::protocol::ReviewDecision;
use crate::safety::SafetyCheck;
use crate::safety::assess_patch_safety;
use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;
use codex_apply_patch::ApplyPatchAction;
use codex_apply_patch::ApplyPatchFileChange;
use std::collections::HashMap;
use std::path::PathBuf;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::FileChange;
use codex_protocol::protocol::ReviewDecision;
use codex_protocol::protocol::SandboxPolicy;
use crate::function_tool::FunctionCallError;
use crate::safety::SafetyCheck;
use crate::safety::assess_patch_safety;
use crate::services::ApprovalCoordinator;
pub const CODEX_APPLY_PATCH_ARG1: &str = "--codex-run-as-apply-patch";
pub(crate) enum InternalApplyPatchInvocation {
pub enum InternalApplyPatchInvocation {
/// The `apply_patch` call was handled programmatically, without any sort
/// of sandbox, because the user explicitly approved it. This is the
/// result to use with the `shell` function call that contained `apply_patch`.
@@ -27,23 +31,30 @@ pub(crate) enum InternalApplyPatchInvocation {
DelegateToExec(ApplyPatchExec),
}
pub(crate) struct ApplyPatchExec {
pub(crate) action: ApplyPatchAction,
pub(crate) user_explicitly_approved_this_action: bool,
#[derive(Debug)]
pub struct ApplyPatchExec {
pub action: ApplyPatchAction,
pub user_explicitly_approved_this_action: bool,
}
pub(crate) async fn apply_patch(
sess: &Session,
turn_context: &TurnContext,
pub struct ApplyPatchContext<'a> {
pub approval_policy: AskForApproval,
pub sandbox_policy: &'a SandboxPolicy,
pub cwd: &'a Path,
}
pub async fn apply_patch(
approvals: &dyn ApprovalCoordinator,
context: ApplyPatchContext<'_>,
sub_id: &str,
call_id: &str,
action: ApplyPatchAction,
) -> InternalApplyPatchInvocation {
match assess_patch_safety(
&action,
turn_context.approval_policy,
&turn_context.sandbox_policy,
&turn_context.cwd,
context.approval_policy,
context.sandbox_policy,
context.cwd,
) {
SafetyCheck::AutoApprove { .. } => {
InternalApplyPatchInvocation::DelegateToExec(ApplyPatchExec {
@@ -52,17 +63,11 @@ pub(crate) async fn apply_patch(
})
}
SafetyCheck::AskUser => {
// Compute a readable summary of path changes to include in the
// approval request so the user can make an informed decision.
//
// Note that it might be worth expanding this approval request to
// give the user the option to expand the set of writable roots so
// that similar patches can be auto-approved in the future during
// this session.
let rx_approve = sess
let approval = approvals
.request_patch_approval(sub_id.to_owned(), call_id.to_owned(), &action, None, None)
.await;
match rx_approve.await.unwrap_or_default() {
match approval {
ReviewDecision::Approved | ReviewDecision::ApprovedForSession => {
InternalApplyPatchInvocation::DelegateToExec(ApplyPatchExec {
action,
@@ -82,9 +87,7 @@ pub(crate) async fn apply_patch(
}
}
pub(crate) fn convert_apply_patch_to_protocol(
action: &ApplyPatchAction,
) -> HashMap<PathBuf, FileChange> {
pub fn convert_apply_patch_to_protocol(action: &ApplyPatchAction) -> HashMap<PathBuf, FileChange> {
let changes = action.changes();
let mut result = HashMap::with_capacity(changes.len());
for (path, change) in changes {
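
A minimal sketch of how a host might now drive this decoupled entry point, assuming it holds an `ApprovalCoordinator` implementation (the trait from `services.rs`) and has already parsed an `ApplyPatchAction`; the policy values below are placeholders:

use std::path::Path;
use codex_agent::apply_patch::ApplyPatchContext;
use codex_agent::apply_patch::InternalApplyPatchInvocation;
use codex_agent::apply_patch::apply_patch;
use codex_agent::services::ApprovalCoordinator;
use codex_apply_patch::ApplyPatchAction;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::SandboxPolicy;

// Delegates the safety assessment and, if needed, the approval round-trip to
// the shared runtime, returning how the host should execute the patch.
async fn handle_patch(
    approvals: &dyn ApprovalCoordinator,
    action: ApplyPatchAction,
) -> InternalApplyPatchInvocation {
    let context = ApplyPatchContext {
        approval_policy: AskForApproval::OnRequest,
        sandbox_policy: &SandboxPolicy::ReadOnly,
        cwd: Path::new("."),
    };
    apply_patch(approvals, context, "sub-1", "call-1", action).await
}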

View File

@@ -0,0 +1,305 @@
//! Shared configuration data structures for Codex runtime and hosts.
//!
//! This module intentionally focuses on simple data containers without
//! business logic so they can be reused across crates.
use std::collections::HashMap;
use std::path::PathBuf;
use std::time::Duration;
use wildmatch::WildMatchPattern;
use serde::Deserialize;
use serde::Deserializer;
use serde::Serialize;
use serde::de::Error as SerdeError;
#[derive(Serialize, Debug, Clone, PartialEq)]
pub struct McpServerConfig {
pub command: String,
#[serde(default)]
pub args: Vec<String>,
#[serde(default)]
pub env: Option<HashMap<String, String>>,
/// Startup timeout in seconds for initializing the MCP server and initially listing tools.
#[serde(
default,
with = "option_duration_secs",
skip_serializing_if = "Option::is_none"
)]
pub startup_timeout_sec: Option<Duration>,
/// Default timeout for MCP tool calls initiated via this server.
#[serde(default, with = "option_duration_secs")]
pub tool_timeout_sec: Option<Duration>,
}
impl<'de> Deserialize<'de> for McpServerConfig {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
struct RawMcpServerConfig {
command: String,
#[serde(default)]
args: Vec<String>,
#[serde(default)]
env: Option<HashMap<String, String>>,
#[serde(default)]
startup_timeout_sec: Option<f64>,
#[serde(default)]
startup_timeout_ms: Option<u64>,
#[serde(default, with = "option_duration_secs")]
tool_timeout_sec: Option<Duration>,
}
let raw = RawMcpServerConfig::deserialize(deserializer)?;
let startup_timeout_sec = match (raw.startup_timeout_sec, raw.startup_timeout_ms) {
(Some(sec), _) => {
let duration = Duration::try_from_secs_f64(sec).map_err(SerdeError::custom)?;
Some(duration)
}
(None, Some(ms)) => Some(Duration::from_millis(ms)),
(None, None) => None,
};
Ok(Self {
command: raw.command,
args: raw.args,
env: raw.env,
startup_timeout_sec,
tool_timeout_sec: raw.tool_timeout_sec,
})
}
}
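#[cfg(test)]
mod startup_timeout_sketch {
    use super::*;
    use std::time::Duration;

    // A minimal sketch of the dual-field deserialization above: the
    // fractional-seconds form and the legacy millisecond form both land in
    // `startup_timeout_sec`, and the seconds form wins when both are present.
    #[test]
    fn accepts_secs_and_ms() {
        let a: McpServerConfig =
            serde_json::from_str(r#"{ "command": "mcp", "startup_timeout_sec": 2.5 }"#)
                .expect("valid config");
        assert_eq!(a.startup_timeout_sec, Some(Duration::from_millis(2500)));

        let b: McpServerConfig =
            serde_json::from_str(r#"{ "command": "mcp", "startup_timeout_ms": 1500 }"#)
                .expect("valid config");
        assert_eq!(b.startup_timeout_sec, Some(Duration::from_millis(1500)));
    }
}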
mod option_duration_secs {
use serde::Deserialize;
use serde::Deserializer;
use serde::Serializer;
use std::time::Duration;
pub fn serialize<S>(value: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match value {
Some(duration) => serializer.serialize_some(&duration.as_secs_f64()),
None => serializer.serialize_none(),
}
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
where
D: Deserializer<'de>,
{
let secs = Option::<f64>::deserialize(deserializer)?;
secs.map(|secs| Duration::try_from_secs_f64(secs).map_err(serde::de::Error::custom))
.transpose()
}
}
#[derive(Deserialize, Debug, Copy, Clone, PartialEq)]
pub enum UriBasedFileOpener {
#[serde(rename = "vscode")]
VsCode,
#[serde(rename = "vscode-insiders")]
VsCodeInsiders,
#[serde(rename = "windsurf")]
Windsurf,
#[serde(rename = "cursor")]
Cursor,
/// Option to disable the URI-based file opener.
#[serde(rename = "none")]
None,
}
impl UriBasedFileOpener {
pub fn get_scheme(&self) -> Option<&str> {
match self {
UriBasedFileOpener::VsCode => Some("vscode"),
UriBasedFileOpener::VsCodeInsiders => Some("vscode-insiders"),
UriBasedFileOpener::Windsurf => Some("windsurf"),
UriBasedFileOpener::Cursor => Some("cursor"),
UriBasedFileOpener::None => None,
}
}
}
/// Settings that govern if and what will be written to `~/.codex/history.jsonl`.
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct History {
/// Controls whether history entries are written to disk.
pub persistence: HistoryPersistence,
/// If set, the maximum size of the history file in bytes.
/// TODO(mbolin): Not currently honored.
pub max_bytes: Option<usize>,
}
#[derive(Deserialize, Debug, Copy, Clone, PartialEq, Default)]
#[serde(rename_all = "kebab-case")]
pub enum HistoryPersistence {
/// Save all history entries to disk.
#[default]
SaveAll,
/// Do not write history to disk.
None,
}
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(untagged)]
pub enum Notifications {
Enabled(bool),
Custom(Vec<String>),
}
impl Default for Notifications {
fn default() -> Self {
Self::Enabled(false)
}
}
/// Collection of settings that are specific to the TUI.
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Tui {
/// Enable desktop notifications from the TUI when the terminal is unfocused.
/// Defaults to `false`.
#[serde(default)]
pub notifications: Notifications,
}
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct SandboxWorkspaceWrite {
#[serde(default)]
pub writable_roots: Vec<PathBuf>,
#[serde(default)]
pub network_access: bool,
#[serde(default)]
pub exclude_tmpdir_env_var: bool,
#[serde(default)]
pub exclude_slash_tmp: bool,
}
impl From<SandboxWorkspaceWrite> for codex_protocol::mcp_protocol::SandboxSettings {
fn from(sandbox_workspace_write: SandboxWorkspaceWrite) -> Self {
Self {
writable_roots: sandbox_workspace_write.writable_roots,
network_access: Some(sandbox_workspace_write.network_access),
exclude_tmpdir_env_var: Some(sandbox_workspace_write.exclude_tmpdir_env_var),
exclude_slash_tmp: Some(sandbox_workspace_write.exclude_slash_tmp),
}
}
}
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
#[serde(rename_all = "kebab-case")]
pub enum ShellEnvironmentPolicyInherit {
/// "Core" environment variables for the platform. On UNIX, this would
/// include HOME, LOGNAME, PATH, SHELL, and USER, among others.
Core,
/// Inherits the full environment from the parent process.
#[default]
All,
/// Do not inherit any environment variables from the parent process.
None,
}
/// Policy for building the `env` when spawning a process via either the
/// `shell` or `local_shell` tool.
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct ShellEnvironmentPolicyToml {
pub inherit: Option<ShellEnvironmentPolicyInherit>,
pub ignore_default_excludes: Option<bool>,
/// List of regular expressions.
pub exclude: Option<Vec<String>>,
pub r#set: Option<HashMap<String, String>>,
/// List of regular expressions.
pub include_only: Option<Vec<String>>,
pub experimental_use_profile: Option<bool>,
}
pub type EnvironmentVariablePattern = WildMatchPattern<'*', '?'>;
/// Deriving the `env` based on this policy works as follows:
/// 1. Create an initial map based on the `inherit` policy.
/// 2. If `ignore_default_excludes` is false, filter the map using the default
/// exclude pattern(s), which are: `"*KEY*"` and `"*TOKEN*"`.
/// 3. If `exclude` is not empty, filter the map using the provided patterns.
/// 4. Insert any entries from `r#set` into the map.
/// 5. If non-empty, filter the map using the `include_only` patterns.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct ShellEnvironmentPolicy {
/// Starting point when building the environment.
pub inherit: ShellEnvironmentPolicyInherit,
/// True to skip the check to exclude default environment variables that
/// contain "KEY" or "TOKEN" in their name.
pub ignore_default_excludes: bool,
/// Environment variable names to exclude from the environment.
pub exclude: Vec<EnvironmentVariablePattern>,
/// (key, value) pairs to insert in the environment.
pub r#set: HashMap<String, String>,
/// Environment variable names to retain in the environment.
pub include_only: Vec<EnvironmentVariablePattern>,
/// If true, the shell profile will be used to run the command.
pub use_profile: bool,
}
impl From<ShellEnvironmentPolicyToml> for ShellEnvironmentPolicy {
fn from(toml: ShellEnvironmentPolicyToml) -> Self {
// Default to inheriting the full environment when not specified.
let inherit = toml.inherit.unwrap_or(ShellEnvironmentPolicyInherit::All);
let ignore_default_excludes = toml.ignore_default_excludes.unwrap_or(false);
let exclude = toml
.exclude
.unwrap_or_default()
.into_iter()
.map(|s| EnvironmentVariablePattern::new_case_insensitive(&s))
.collect();
let r#set = toml.r#set.unwrap_or_default();
let include_only = toml
.include_only
.unwrap_or_default()
.into_iter()
.map(|s| EnvironmentVariablePattern::new_case_insensitive(&s))
.collect();
let use_profile = toml.experimental_use_profile.unwrap_or(false);
Self {
inherit,
ignore_default_excludes,
exclude,
r#set,
include_only,
use_profile,
}
}
}
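#[cfg(test)]
mod shell_env_policy_sketch {
    use super::*;

    // A minimal sketch of the conversion above: unset TOML fields fall back
    // to the documented defaults (inherit the full environment, keep the
    // default "*KEY*"/"*TOKEN*" excludes) while explicit entries carry over.
    #[test]
    fn converts_with_defaults() {
        let toml = ShellEnvironmentPolicyToml {
            exclude: Some(vec!["AWS_*".to_string()]),
            r#set: Some(HashMap::from([("CI".to_string(), "1".to_string())])),
            ..Default::default()
        };
        let policy: ShellEnvironmentPolicy = toml.into();
        assert_eq!(policy.inherit, ShellEnvironmentPolicyInherit::All);
        assert!(!policy.ignore_default_excludes);
        assert_eq!(policy.exclude.len(), 1);
        assert_eq!(policy.r#set.get("CI").map(String::as_str), Some("1"));
    }
}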
#[derive(Deserialize, Debug, Clone, PartialEq, Eq, Default, Hash)]
#[serde(rename_all = "kebab-case")]
pub enum ReasoningSummaryFormat {
#[default]
None,
Experimental,
}

View File

@@ -0,0 +1,117 @@
use codex_protocol::models::ResponseItem;
/// Transcript of conversation history shared across agent hosts.
#[derive(Debug, Clone, Default)]
pub struct ConversationHistory {
/// Oldest items appear at the start of the vector.
items: Vec<ResponseItem>,
}
impl ConversationHistory {
pub fn new() -> Self {
Self { items: Vec::new() }
}
/// Returns a clone of the stored transcript.
pub fn contents(&self) -> Vec<ResponseItem> {
self.items.clone()
}
/// Records additional response items, filtering out non-API messages.
pub fn record_items<I>(&mut self, items: I)
where
I: IntoIterator,
I::Item: std::ops::Deref<Target = ResponseItem>,
{
for item in items {
if !is_api_message(&item) {
continue;
}
self.items.push(item.clone());
}
}
pub fn replace(&mut self, items: Vec<ResponseItem>) {
self.items = items;
}
}
/// Detects whether the given message should be persisted to history.
fn is_api_message(message: &ResponseItem) -> bool {
match message {
ResponseItem::Message { role, .. } => role.as_str() != "system",
ResponseItem::FunctionCallOutput { .. }
| ResponseItem::FunctionCall { .. }
| ResponseItem::CustomToolCall { .. }
| ResponseItem::CustomToolCallOutput { .. }
| ResponseItem::LocalShellCall { .. }
| ResponseItem::Reasoning { .. }
| ResponseItem::WebSearchCall { .. } => true,
ResponseItem::Other => false,
}
}
#[cfg(test)]
mod tests {
use super::*;
use codex_protocol::models::ContentItem;
fn assistant_msg(text: &str) -> ResponseItem {
ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: text.to_string(),
}],
}
}
fn user_msg(text: &str) -> ResponseItem {
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::OutputText {
text: text.to_string(),
}],
}
}
#[test]
fn filters_non_api_messages() {
let mut h = ConversationHistory::default();
let system = ResponseItem::Message {
id: None,
role: "system".to_string(),
content: vec![ContentItem::OutputText {
text: "ignored".to_string(),
}],
};
h.record_items([&system, &ResponseItem::Other]);
let u = user_msg("hi");
let a = assistant_msg("hello");
h.record_items([&u, &a]);
let items = h.contents();
assert_eq!(
items,
vec![
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::OutputText {
text: "hi".to_string()
}]
},
ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: "hello".to_string()
}]
}
]
);
}
}

View File

@@ -5,7 +5,8 @@ use tokio::sync::mpsc;
use tokio::task::JoinHandle;
#[derive(Debug)]
pub(crate) struct ExecCommandSession {
#[allow(dead_code)]
pub struct ExecCommandSession {
/// Queue for writing bytes to the process stdin (PTY master write side).
writer_tx: mpsc::Sender<Vec<u8>>,
/// Broadcast stream of output chunks read from the PTY. New subscribers
@@ -29,8 +30,9 @@ pub(crate) struct ExecCommandSession {
exit_status: std::sync::Arc<std::sync::atomic::AtomicBool>,
}
#[allow(dead_code)]
impl ExecCommandSession {
pub(crate) fn new(
pub fn new(
writer_tx: mpsc::Sender<Vec<u8>>,
output_tx: broadcast::Sender<Vec<u8>>,
killer: Box<dyn portable_pty::ChildKiller + Send + Sync>,
@@ -54,7 +56,7 @@ impl ExecCommandSession {
)
}
pub(crate) fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
pub fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
self.writer_tx.clone()
}
@@ -62,7 +64,7 @@ impl ExecCommandSession {
self.output_tx.subscribe()
}
pub(crate) fn has_exited(&self) -> bool {
pub fn has_exited(&self) -> bool {
self.exit_status.load(std::sync::atomic::Ordering::SeqCst)
}
}

View File

@@ -0,0 +1,11 @@
mod exec_command_params;
mod exec_command_session;
mod session_id;
mod session_manager;
pub use exec_command_params::ExecCommandParams;
pub use exec_command_params::WriteStdinParams;
pub use exec_command_session::ExecCommandSession;
pub use session_id::SessionId;
pub use session_manager::ExecCommandOutput;
pub use session_manager::SessionManager as ExecSessionManager;

View File

@@ -2,4 +2,4 @@ use serde::Deserialize;
use serde::Serialize;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub(crate) struct SessionId(pub u32);
pub struct SessionId(pub u32);

View File

@@ -5,6 +5,7 @@ use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicU32;
use std::vec::Vec;
use portable_pty::CommandBuilder;
use portable_pty::PtySize;
@@ -28,6 +29,7 @@ pub struct SessionManager {
sessions: Mutex<HashMap<SessionId, ExecCommandSession>>,
}
#[allow(dead_code)]
#[derive(Debug)]
pub struct ExecCommandOutput {
wall_time: Duration,
@@ -37,7 +39,7 @@ pub struct ExecCommandOutput {
}
impl ExecCommandOutput {
pub(crate) fn to_text_output(&self) -> String {
pub fn to_text_output(&self) -> String {
let wall_time_secs = self.wall_time.as_secs_f32();
let termination_status = match self.exit_status {
ExitStatus::Exited(code) => format!("Process exited with code {code}"),
@@ -61,6 +63,7 @@ Output:
}
}
#[allow(dead_code)]
#[derive(Debug)]
pub enum ExitStatus {
Exited(i32),
@@ -79,7 +82,11 @@ impl SessionManager {
.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
);
let (session, mut output_rx, mut exit_rx) = create_exec_command_session(params.clone())
let (session, mut output_rx, mut exit_rx): (
ExecCommandSession,
tokio::sync::broadcast::Receiver<Vec<u8>>,
tokio::sync::oneshot::Receiver<i32>,
) = create_exec_command_session(params.clone())
.await
.map_err(|err| {
format!(

View File

@@ -0,0 +1,7 @@
use thiserror::Error;
#[derive(Debug, Error, PartialEq)]
pub enum FunctionCallError {
#[error("{0}")]
RespondToModel(String),
}

48
codex-rs/agent/src/lib.rs Normal file
View File

@@ -0,0 +1,48 @@
pub mod apply_patch;
pub mod bash;
pub mod command_safety;
pub mod config_types;
pub mod conversation_history;
pub mod exec_command;
pub mod function_tool;
pub mod model_family;
pub mod model_provider;
pub mod notifications;
pub mod rollout;
pub mod runtime;
pub mod runtime_config;
pub mod safety;
pub mod sandbox;
pub mod services;
pub mod session_services;
pub mod session_state;
pub mod shell;
pub mod token_data;
pub mod tooling;
pub mod truncate;
pub mod turn_diff_tracker;
pub mod unified_exec;
pub use apply_patch::*;
pub use bash::*;
pub use command_safety::*;
pub use config_types::*;
pub use conversation_history::*;
pub use function_tool::*;
pub use model_family::*;
pub use model_provider::*;
pub use notifications::*;
pub use rollout::*;
pub use runtime::*;
pub use runtime_config::*;
pub use safety::*;
pub use sandbox::*;
pub use services::*;
pub use session_services::*;
pub use session_state::*;
pub use shell::*;
pub use token_data::*;
pub use tooling::*;
pub use truncate::*;
pub use turn_diff_tracker::*;
pub use unified_exec::*;

View File

@@ -0,0 +1,15 @@
use crate::config_types::ReasoningSummaryFormat;
use crate::tooling::ApplyPatchToolType;
/// Metadata describing consistent behaviour across a family of models.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModelFamily {
pub slug: String,
pub family: String,
pub needs_special_apply_patch_instructions: bool,
pub supports_reasoning_summaries: bool,
pub reasoning_summary_format: ReasoningSummaryFormat,
pub uses_local_shell_tool: bool,
pub apply_patch_tool_type: Option<ApplyPatchToolType>,
pub base_instructions: String,
}

View File

@@ -0,0 +1,54 @@
use std::collections::HashMap;
use codex_protocol::mcp_protocol::AuthMode;
use serde::Deserialize;
use serde::Serialize;
/// Wire protocol variants supported by model providers.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum WireApi {
Responses,
#[default]
Chat,
}
/// Serializable representation of a provider definition shared across hosts.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ModelProviderInfo {
pub name: String,
pub base_url: Option<String>,
pub env_key: Option<String>,
pub env_key_instructions: Option<String>,
#[serde(default)]
pub wire_api: WireApi,
pub query_params: Option<HashMap<String, String>>,
pub http_headers: Option<HashMap<String, String>>,
pub env_http_headers: Option<HashMap<String, String>>,
pub request_max_retries: Option<u64>,
pub stream_max_retries: Option<u64>,
pub stream_idle_timeout_ms: Option<u64>,
#[serde(default)]
pub requires_openai_auth: bool,
}
impl ModelProviderInfo {
pub fn wire_api(&self) -> WireApi {
self.wire_api
}
pub fn requires_auth(&self) -> bool {
self.requires_openai_auth
}
pub fn base_url(&self, auth_mode: AuthMode) -> String {
let fallback = if auth_mode == AuthMode::ChatGPT {
"https://chatgpt.com/backend-api/codex"
} else {
"https://api.openai.com/v1"
};
self.base_url
.clone()
.unwrap_or_else(|| fallback.to_string())
}
}
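
A short sketch of the fallback logic, assuming `AuthMode` exposes the `ApiKey` and `ChatGPT` variants used elsewhere in the workspace; a provider that leaves `base_url` unset resolves its endpoint per auth mode:

let provider = ModelProviderInfo {
    name: "openai".to_string(),
    base_url: None,
    env_key: Some("OPENAI_API_KEY".to_string()),
    env_key_instructions: None,
    wire_api: WireApi::Responses,
    query_params: None,
    http_headers: None,
    env_http_headers: None,
    request_max_retries: None,
    stream_max_retries: None,
    stream_idle_timeout_ms: None,
    requires_openai_auth: false,
};
assert_eq!(provider.base_url(AuthMode::ApiKey), "https://api.openai.com/v1");
assert_eq!(
    provider.base_url(AuthMode::ChatGPT),
    "https://chatgpt.com/backend-api/codex"
);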

View File

@@ -0,0 +1,15 @@
use serde::Serialize;
/// Cross-host notification payloads emitted by the agent runtime.
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum UserNotification {
#[serde(rename_all = "kebab-case")]
AgentTurnComplete {
turn_id: String,
/// Messages submitted by the user to start the turn.
input_messages: Vec<String>,
/// Final assistant message emitted at turn completion.
last_assistant_message: Option<String>,
},
}

View File

@@ -1,9 +1,13 @@
use std::cmp::Reverse;
use std::io::{self};
use std::io;
use std::path::Path;
use std::path::PathBuf;
use codex_file_search as file_search;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::RolloutLine;
use serde_json::Value;
use std::num::NonZero;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
@@ -11,40 +15,29 @@ use time::OffsetDateTime;
use time::PrimitiveDateTime;
use time::format_description::FormatItem;
use time::macros::format_description;
use tokio::fs;
use tokio::io::AsyncBufReadExt;
use uuid::Uuid;
use super::SESSIONS_SUBDIR;
use crate::protocol::EventMsg;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::RolloutLine;
/// Returned page of conversation summaries.
#[derive(Debug, Default, PartialEq)]
pub struct ConversationsPage {
/// Conversation summaries ordered newest first.
pub items: Vec<ConversationItem>,
/// Opaque pagination token to resume after the last item, or `None` if end.
pub next_cursor: Option<Cursor>,
/// Total number of files touched while scanning this request.
pub num_scanned_files: usize,
/// True if a hard scan cap was hit; consider resuming with `next_cursor`.
pub reached_scan_cap: bool,
}
/// Summary information for a conversation rollout file.
#[derive(Debug, PartialEq)]
pub struct ConversationItem {
/// Absolute path to the rollout file.
pub path: PathBuf,
/// First up to 5 JSONL records parsed as JSON (includes meta line).
pub head: Vec<serde_json::Value>,
pub head: Vec<Value>,
}
/// Hard cap to bound worst-case work per request.
const MAX_SCAN_FILES: usize = 100;
const HEAD_RECORD_LIMIT: usize = 10;
/// Pagination cursor identifying a file by timestamp and UUID.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Cursor {
ts: OffsetDateTime,
@@ -82,10 +75,7 @@ impl<'de> serde::Deserialize<'de> for Cursor {
}
}
/// Retrieve recorded conversation file paths with token pagination. The returned `next_cursor`
/// can be supplied on the next call to resume after the last returned item, resilient to
/// concurrent new sessions being appended. Ordering is stable by timestamp desc, then UUID desc.
pub(crate) async fn get_conversations(
pub async fn get_conversations(
codex_home: &Path,
page_size: usize,
cursor: Option<&Cursor>,
@@ -94,31 +84,57 @@ pub(crate) async fn get_conversations(
root.push(SESSIONS_SUBDIR);
if !root.exists() {
return Ok(ConversationsPage {
items: Vec::new(),
next_cursor: None,
num_scanned_files: 0,
reached_scan_cap: false,
});
return Ok(ConversationsPage::default());
}
let anchor = cursor.cloned();
let result = traverse_directories_for_paths(root.clone(), page_size, anchor).await?;
Ok(result)
traverse_directories_for_paths(root, page_size, anchor).await
}
/// Load the full contents of a single conversation session file at `path`.
/// Returns the entire file contents as a String.
#[allow(dead_code)]
pub(crate) async fn get_conversation(path: &Path) -> io::Result<String> {
tokio::fs::read_to_string(path).await
pub async fn get_conversation(path: &Path) -> io::Result<String> {
fs::read_to_string(path).await
}
pub async fn find_conversation_path_by_id_str(
codex_home: &Path,
id_str: &str,
) -> io::Result<Option<PathBuf>> {
if Uuid::parse_str(id_str).is_err() {
return Ok(None);
}
let mut root = codex_home.to_path_buf();
root.push(SESSIONS_SUBDIR);
if !root.exists() {
return Ok(None);
}
let limit = NonZero::new(1).ok_or_else(|| io::Error::other("search limit must be non-zero"))?;
let threads =
NonZero::new(2).ok_or_else(|| io::Error::other("thread pool size must be non-zero"))?;
let cancel = Arc::new(AtomicBool::new(false));
let exclude: Vec<String> = Vec::new();
let compute_indices = false;
let results = file_search::run(
id_str,
limit,
&root,
exclude,
threads,
cancel,
compute_indices,
)
.map_err(|e| io::Error::other(format!("file search failed: {e}")))?;
Ok(results
.matches
.into_iter()
.next()
.map(|m| root.join(m.path)))
}
/// Load conversation file paths from disk using directory traversal.
///
/// Directory layout: `~/.codex/sessions/YYYY/MM/DD/rollout-YYYY-MM-DDThh-mm-ss-<uuid>.jsonl`
/// Returned newest (latest) first.
async fn traverse_directories_for_paths(
root: PathBuf,
page_size: usize,
@@ -157,8 +173,7 @@ async fn traverse_directories_for_paths(
.map(|(ts, id)| (ts, id, name_str.to_string(), path.to_path_buf()))
})
.await?;
// Stable ordering within the same second: (timestamp desc, uuid desc)
day_files.sort_by_key(|(ts, sid, _name_str, _path)| (Reverse(*ts), Reverse(*sid)));
day_files.sort_by_key(|(ts, sid, _, _)| (Reverse(*ts), Reverse(*sid)));
for (ts, sid, _name_str, path) in day_files.into_iter() {
scanned_files += 1;
if scanned_files >= MAX_SCAN_FILES && items.len() >= page_size {
@@ -174,13 +189,10 @@ async fn traverse_directories_for_paths(
if items.len() == page_size {
break 'outer;
}
// Read head and simultaneously detect message events within the same
// first N JSONL records to avoid a second file read.
let (head, saw_session_meta, saw_user_event) =
read_head_and_flags(&path, HEAD_RECORD_LIMIT)
.await
.unwrap_or((Vec::new(), false, false));
// Apply filters: must have session meta and at least one user message event
if saw_session_meta && saw_user_event {
items.push(ConversationItem { path, head });
}
@@ -198,23 +210,6 @@ async fn traverse_directories_for_paths(
})
}
/// Pagination cursor token format: "<file_ts>|<uuid>" where `file_ts` matches the
/// filename timestamp portion (YYYY-MM-DDThh-mm-ss) used in rollout filenames.
/// The cursor orders files by timestamp desc, then UUID desc.
fn parse_cursor(token: &str) -> Option<Cursor> {
let (file_ts, uuid_str) = token.split_once('|')?;
let Ok(uuid) = Uuid::parse_str(uuid_str) else {
return None;
};
let format: &[FormatItem] =
format_description!("[year]-[month]-[day]T[hour]-[minute]-[second]");
let ts = PrimitiveDateTime::parse(file_ts, format).ok()?.assume_utc();
Some(Cursor::new(ts, uuid))
}
fn build_next_cursor(items: &[ConversationItem]) -> Option<Cursor> {
let last = items.last()?;
let file_name = last.path.file_name()?.to_string_lossy();
@@ -222,14 +217,12 @@ fn build_next_cursor(items: &[ConversationItem]) -> Option<Cursor> {
Some(Cursor::new(ts, id))
}
/// Collects immediate subdirectories of `parent`, parses their (string) names with `parse`,
/// and returns them sorted descending by the parsed key.
async fn collect_dirs_desc<T, F>(parent: &Path, parse: F) -> io::Result<Vec<(T, PathBuf)>>
where
T: Ord + Copy,
F: Fn(&str) -> Option<T>,
{
let mut dir = tokio::fs::read_dir(parent).await?;
let mut dir = fs::read_dir(parent).await?;
let mut vec: Vec<(T, PathBuf)> = Vec::new();
while let Some(entry) = dir.next_entry().await? {
if entry
@@ -247,12 +240,11 @@ where
Ok(vec)
}
/// Collects files in a directory and parses them with `parse`.
async fn collect_files<T, F>(parent: &Path, parse: F) -> io::Result<Vec<T>>
where
F: Fn(&str, &Path) -> Option<T>,
{
let mut dir = tokio::fs::read_dir(parent).await?;
let mut dir = fs::read_dir(parent).await?;
let mut collected: Vec<T> = Vec::new();
while let Some(entry) = dir.next_entry().await? {
if entry
@@ -270,15 +262,11 @@ where
}
fn parse_timestamp_uuid_from_filename(name: &str) -> Option<(OffsetDateTime, Uuid)> {
// Expected: rollout-YYYY-MM-DDThh-mm-ss-<uuid>.jsonl
let core = name.strip_prefix("rollout-")?.strip_suffix(".jsonl")?;
// Scan from the right for a '-' such that the suffix parses as a UUID.
let (sep_idx, uuid) = core
.match_indices('-')
.rev()
.find_map(|(i, _)| Uuid::parse_str(&core[i + 1..]).ok().map(|u| (i, u)))?;
let ts_str = &core[..sep_idx];
let format: &[FormatItem] =
format_description!("[year]-[month]-[day]T[hour]-[minute]-[second]");
@@ -286,16 +274,23 @@ fn parse_timestamp_uuid_from_filename(name: &str) -> Option<(OffsetDateTime, Uui
Some((ts, uuid))
}
fn parse_cursor(token: &str) -> Option<Cursor> {
let (file_ts, uuid_str) = token.split_once('|')?;
let uuid = Uuid::parse_str(uuid_str).ok()?;
let format: &[FormatItem] =
format_description!("[year]-[month]-[day]T[hour]-[minute]-[second]");
let ts = PrimitiveDateTime::parse(file_ts, format).ok()?.assume_utc();
Some(Cursor::new(ts, uuid))
}
async fn read_head_and_flags(
path: &Path,
max_records: usize,
) -> io::Result<(Vec<serde_json::Value>, bool, bool)> {
use tokio::io::AsyncBufReadExt;
) -> io::Result<(Vec<Value>, bool, bool)> {
let file = tokio::fs::File::open(path).await?;
let reader = tokio::io::BufReader::new(file);
let mut lines = reader.lines();
let mut head: Vec<serde_json::Value> = Vec::new();
let mut head: Vec<Value> = Vec::new();
let mut saw_session_meta = false;
let mut saw_user_event = false;
@@ -322,12 +317,7 @@ async fn read_head_and_flags(
head.push(val);
}
}
RolloutItem::TurnContext(_) => {
// Not included in `head`; skip.
}
RolloutItem::Compacted(_) => {
// Not included in `head`; skip.
}
RolloutItem::TurnContext(_) | RolloutItem::Compacted(_) => {}
RolloutItem::EventMsg(ev) => {
if matches!(ev, EventMsg::UserMessage(_)) {
saw_user_event = true;
@@ -338,48 +328,3 @@ async fn read_head_and_flags(
Ok((head, saw_session_meta, saw_user_event))
}
/// Locate a recorded conversation rollout file by its UUID string using the existing
/// paginated listing implementation. Returns `Ok(Some(path))` if found, `Ok(None)` if not present
/// or the id is invalid.
pub async fn find_conversation_path_by_id_str(
codex_home: &Path,
id_str: &str,
) -> io::Result<Option<PathBuf>> {
// Validate UUID format early.
if Uuid::parse_str(id_str).is_err() {
return Ok(None);
}
let mut root = codex_home.to_path_buf();
root.push(SESSIONS_SUBDIR);
if !root.exists() {
return Ok(None);
}
// This is safe because we know the values are valid.
#[allow(clippy::unwrap_used)]
let limit = NonZero::new(1).unwrap();
// This is safe because we know the values are valid.
#[allow(clippy::unwrap_used)]
let threads = NonZero::new(2).unwrap();
let cancel = Arc::new(AtomicBool::new(false));
let exclude: Vec<String> = Vec::new();
let compute_indices = false;
let results = file_search::run(
id_str,
limit,
&root,
exclude,
threads,
cancel,
compute_indices,
)
.map_err(|e| io::Error::other(format!("file search failed: {e}")))?;
Ok(results
.matches
.into_iter()
.next()
.map(|m| root.join(m.path)))
}
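
A minimal sketch of consuming the paginated listing from a host, assuming a tokio runtime and the module paths declared in `rollout/mod.rs`; `codex_home` is a placeholder:

use std::path::Path;
use codex_agent::rollout::list::Cursor;
use codex_agent::rollout::list::get_conversations;

async fn print_all_sessions(codex_home: &Path) -> std::io::Result<()> {
    let mut cursor: Option<Cursor> = None;
    loop {
        let page = get_conversations(codex_home, 25, cursor.as_ref()).await?;
        for item in &page.items {
            println!("{}", item.path.display());
        }
        match page.next_cursor {
            // Resume after the last returned item on the next iteration.
            Some(next) if !page.items.is_empty() => cursor = Some(next),
            _ => break,
        }
    }
    Ok(())
}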

View File

@@ -0,0 +1,11 @@
pub const SESSIONS_SUBDIR: &str = "sessions";
pub const ARCHIVED_SESSIONS_SUBDIR: &str = "archived_sessions";
pub mod list;
pub mod policy;
pub mod recorder;
pub use recorder::GitInfoCollector;
pub use recorder::RolloutConfig;
pub use recorder::RolloutRecorder;
pub use recorder::RolloutRecorderParams;

View File

@@ -1,14 +1,13 @@
use crate::protocol::EventMsg;
use crate::protocol::RolloutItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::RolloutItem;
/// Whether a rollout `item` should be persisted in rollout files.
#[inline]
pub(crate) fn is_persisted_response_item(item: &RolloutItem) -> bool {
pub fn is_persisted_response_item(item: &RolloutItem) -> bool {
match item {
RolloutItem::ResponseItem(item) => should_persist_response_item(item),
RolloutItem::EventMsg(ev) => should_persist_event_msg(ev),
// Persist Codex executive markers so we can analyze flows (e.g., compaction, API turns).
RolloutItem::Compacted(_) | RolloutItem::TurnContext(_) | RolloutItem::SessionMeta(_) => {
true
}
@@ -17,7 +16,7 @@ pub(crate) fn is_persisted_response_item(item: &RolloutItem) -> bool {
/// Whether a `ResponseItem` should be persisted in rollout files.
#[inline]
pub(crate) fn should_persist_response_item(item: &ResponseItem) -> bool {
pub fn should_persist_response_item(item: &ResponseItem) -> bool {
match item {
ResponseItem::Message { .. }
| ResponseItem::Reasoning { .. }
@@ -33,7 +32,7 @@ pub(crate) fn should_persist_response_item(item: &ResponseItem) -> bool {
/// Whether an `EventMsg` should be persisted in rollout files.
#[inline]
pub(crate) fn should_persist_event_msg(ev: &EventMsg) -> bool {
pub fn should_persist_event_msg(ev: &EventMsg) -> bool {
match ev {
EventMsg::UserMessage(_)
| EventMsg::AgentMessage(_)

View File

@@ -1,19 +1,26 @@
//! Persist Codex session rollouts (.jsonl) so sessions can be replayed or inspected later.
use std::fs;
use std::fs::File;
use std::fs::{self};
use std::io::Error as IoError;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use async_trait::async_trait;
use codex_protocol::mcp_protocol::ConversationId;
use codex_protocol::protocol::GitInfo;
use codex_protocol::protocol::InitialHistory;
use codex_protocol::protocol::ResumedHistory;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::RolloutLine;
use codex_protocol::protocol::SessionMeta;
use codex_protocol::protocol::SessionMetaLine;
use serde_json::Value;
use time::OffsetDateTime;
use time::format_description::FormatItem;
use time::macros::format_description;
use tokio::io::AsyncWriteExt;
use tokio::sync::mpsc;
use tokio::sync::mpsc::Sender;
use tokio::sync::mpsc::{self};
use tokio::sync::oneshot;
use tracing::info;
use tracing::warn;
@@ -23,35 +30,31 @@ use super::list::ConversationsPage;
use super::list::Cursor;
use super::list::get_conversations;
use super::policy::is_persisted_response_item;
use crate::config::Config;
use crate::default_client::ORIGINATOR;
use crate::git_info::collect_git_info;
use codex_protocol::protocol::InitialHistory;
use codex_protocol::protocol::ResumedHistory;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::RolloutLine;
use codex_protocol::protocol::SessionMeta;
use codex_protocol::protocol::SessionMetaLine;
/// Records all [`ResponseItem`]s for a session and flushes them to disk after
/// every update.
///
/// Rollouts are recorded as JSONL and can be inspected with tools such as:
///
/// ```ignore
/// $ jq -C . ~/.codex/sessions/rollout-2025-05-07T17-24-21-5973b6c0-94b8-487b-a530-2aeb6098ae0e.jsonl
/// $ fx ~/.codex/sessions/rollout-2025-05-07T17-24-21-5973b6c0-94b8-487b-a530-2aeb6098ae0e.jsonl
/// ```
#[async_trait]
pub trait GitInfoCollector: Send + Sync {
async fn collect(&self, cwd: &Path) -> Option<GitInfo>;
}
#[derive(Clone)]
pub struct RolloutConfig {
pub codex_home: PathBuf,
pub originator: String,
pub cli_version: String,
pub git_info_collector: Option<Arc<dyn GitInfoCollector>>,
}
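// A minimal sketch of a host-side implementation: `collect_git_info` below
// stands in for core's existing helper, which this crate no longer calls
// directly.
//
//     struct CoreGitInfoCollector;
//
//     #[async_trait]
//     impl GitInfoCollector for CoreGitInfoCollector {
//         async fn collect(&self, cwd: &Path) -> Option<GitInfo> {
//             crate::git_info::collect_git_info(cwd).await
//         }
//     }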
#[derive(Clone)]
pub struct RolloutRecorder {
tx: Sender<RolloutCmd>,
pub(crate) rollout_path: PathBuf,
rollout_path: PathBuf,
}
#[derive(Clone)]
pub enum RolloutRecorderParams {
Create {
conversation_id: ConversationId,
cwd: PathBuf,
instructions: Option<String>,
},
Resume {
@@ -61,19 +64,19 @@ pub enum RolloutRecorderParams {
enum RolloutCmd {
AddItems(Vec<RolloutItem>),
/// Ensure all prior writes are processed; respond when flushed.
Flush {
ack: oneshot::Sender<()>,
},
Shutdown {
ack: oneshot::Sender<()>,
},
Flush { ack: oneshot::Sender<()> },
Shutdown { ack: oneshot::Sender<()> },
}
impl RolloutRecorderParams {
pub fn new(conversation_id: ConversationId, instructions: Option<String>) -> Self {
pub fn new(
conversation_id: ConversationId,
cwd: PathBuf,
instructions: Option<String>,
) -> Self {
Self::Create {
conversation_id,
cwd,
instructions,
}
}
@@ -84,7 +87,6 @@ impl RolloutRecorderParams {
}
impl RolloutRecorder {
/// List conversations (rollout files) under the provided Codex home directory.
pub async fn list_conversations(
codex_home: &Path,
page_size: usize,
@@ -93,13 +95,14 @@ impl RolloutRecorder {
get_conversations(codex_home, page_size, cursor).await
}
/// Attempt to create a new [`RolloutRecorder`]. If the sessions directory
/// cannot be created or the rollout file cannot be opened, we return the
/// error so the caller can decide whether to disable persistence.
pub async fn new(config: &Config, params: RolloutRecorderParams) -> std::io::Result<Self> {
let (file, rollout_path, meta) = match params {
pub async fn new(
config: &RolloutConfig,
params: RolloutRecorderParams,
) -> std::io::Result<Self> {
let (file, rollout_path, meta, cwd) = match params {
RolloutRecorderParams::Create {
conversation_id,
cwd,
instructions,
} => {
let LogFileInfo {
@@ -107,7 +110,7 @@ impl RolloutRecorder {
path,
conversation_id: session_id,
timestamp,
} = create_log_file(config, conversation_id)?;
} = create_log_file(&config.codex_home, conversation_id)?;
let timestamp_format: &[FormatItem] = format_description!(
"[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
@@ -117,18 +120,16 @@ impl RolloutRecorder {
.format(timestamp_format)
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
(
tokio::fs::File::from_std(file),
path,
Some(SessionMeta {
id: session_id,
timestamp,
cwd: config.cwd.clone(),
originator: ORIGINATOR.value.clone(),
cli_version: env!("CARGO_PKG_VERSION").to_string(),
instructions,
}),
)
let meta = SessionMeta {
id: session_id,
timestamp,
cwd: cwd.clone(),
originator: config.originator.clone(),
cli_version: config.cli_version.clone(),
instructions,
};
(tokio::fs::File::from_std(file), path, Some(meta), Some(cwd))
}
RolloutRecorderParams::Resume { path } => (
tokio::fs::OpenOptions::new()
@@ -137,31 +138,21 @@ impl RolloutRecorder {
.await?,
path,
None,
None,
),
};
// Clone the cwd for the spawned task to collect git info asynchronously
let cwd = config.cwd.clone();
// A reasonably-sized bounded channel. If the buffer fills up, the send
// future will yield, which is fine; we only need to ensure we do not
// perform *blocking* I/O on the caller's thread.
let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
let collector = config.git_info_collector.clone();
// Spawn a Tokio task that owns the file handle and performs async
// writes. Using `tokio::fs::File` keeps everything on the async I/O
// driver instead of blocking the runtime.
tokio::task::spawn(rollout_writer(file, rx, meta, cwd));
tokio::task::spawn(rollout_writer(file, rx, meta, cwd, collector));
Ok(Self { tx, rollout_path })
}
pub(crate) async fn record_items(&self, items: &[RolloutItem]) -> std::io::Result<()> {
pub async fn record_items(&self, items: &[RolloutItem]) -> std::io::Result<()> {
let mut filtered = Vec::new();
for item in items {
// Note that function calls may look a bit strange if they are
// "fully qualified MCP tool calls," so we could consider
// reformatting them in that case.
if is_persisted_response_item(item) {
filtered.push(item.clone());
}
@@ -175,7 +166,6 @@ impl RolloutRecorder {
.map_err(|e| IoError::other(format!("failed to queue rollout items: {e}")))
}
/// Flush all queued writes and wait until they are committed by the writer task.
pub async fn flush(&self) -> std::io::Result<()> {
let (tx, rx) = oneshot::channel();
self.tx
@@ -186,7 +176,26 @@ impl RolloutRecorder {
.map_err(|e| IoError::other(format!("failed waiting for rollout flush: {e}")))
}
pub(crate) async fn get_rollout_history(path: &Path) -> std::io::Result<InitialHistory> {
pub async fn shutdown(&self) -> std::io::Result<()> {
let (tx_done, rx_done) = oneshot::channel();
match self.tx.send(RolloutCmd::Shutdown { ack: tx_done }).await {
Ok(_) => rx_done
.await
.map_err(|e| IoError::other(format!("failed waiting for rollout shutdown: {e}"))),
Err(e) => {
warn!("failed to send rollout shutdown command: {e}");
Err(IoError::other(format!(
"failed to send rollout shutdown command: {e}"
)))
}
}
}
pub fn get_rollout_path(&self) -> PathBuf {
self.rollout_path.clone()
}
pub async fn get_rollout_history(path: &Path) -> std::io::Result<InitialHistory> {
info!("Resuming rollout from {path:?}");
let text = tokio::fs::read_to_string(path).await?;
if text.trim().is_empty() {
@@ -207,33 +216,17 @@ impl RolloutRecorder {
}
};
// Parse the rollout line structure
match serde_json::from_value::<RolloutLine>(v.clone()) {
Ok(rollout_line) => match rollout_line.item {
RolloutItem::SessionMeta(session_meta_line) => {
// Use the FIRST SessionMeta encountered in the file as the canonical
// conversation id and main session information. Keep all items intact.
if conversation_id.is_none() {
conversation_id = Some(session_meta_line.meta.id);
}
items.push(RolloutItem::SessionMeta(session_meta_line));
}
RolloutItem::ResponseItem(item) => {
items.push(RolloutItem::ResponseItem(item));
}
RolloutItem::Compacted(item) => {
items.push(RolloutItem::Compacted(item));
}
RolloutItem::TurnContext(item) => {
items.push(RolloutItem::TurnContext(item));
}
RolloutItem::EventMsg(_ev) => {
items.push(RolloutItem::EventMsg(_ev));
}
other => items.push(other),
},
Err(e) => {
warn!("failed to parse rollout line: {v:?}, error: {e}");
}
Err(e) => warn!("failed to parse rollout line: {v:?}, error: {e}"),
}
}
@@ -256,57 +249,28 @@ impl RolloutRecorder {
rollout_path: path.to_path_buf(),
}))
}
pub(crate) fn get_rollout_path(&self) -> PathBuf {
self.rollout_path.clone()
}
pub async fn shutdown(&self) -> std::io::Result<()> {
let (tx_done, rx_done) = oneshot::channel();
match self.tx.send(RolloutCmd::Shutdown { ack: tx_done }).await {
Ok(_) => rx_done
.await
.map_err(|e| IoError::other(format!("failed waiting for rollout shutdown: {e}"))),
Err(e) => {
warn!("failed to send rollout shutdown command: {e}");
Err(IoError::other(format!(
"failed to send rollout shutdown command: {e}"
)))
}
}
}
}
struct LogFileInfo {
/// Opened file handle to the rollout file.
file: File,
/// Full path to the rollout file.
path: PathBuf,
/// Session ID (also embedded in filename).
conversation_id: ConversationId,
/// Timestamp for the start of the session.
timestamp: OffsetDateTime,
}
fn create_log_file(
config: &Config,
codex_home: &Path,
conversation_id: ConversationId,
) -> std::io::Result<LogFileInfo> {
// Resolve ~/.codex/sessions/YYYY/MM/DD and create it if missing.
let timestamp = OffsetDateTime::now_local()
.map_err(|e| IoError::other(format!("failed to get local time: {e}")))?;
let mut dir = config.codex_home.clone();
let mut dir = codex_home.to_path_buf();
dir.push(SESSIONS_SUBDIR);
dir.push(timestamp.year().to_string());
dir.push(format!("{:02}", u8::from(timestamp.month())));
dir.push(format!("{:02}", timestamp.day()));
fs::create_dir_all(&dir)?;
// Custom format for YYYY-MM-DDThh-mm-ss. Use `-` instead of `:` for
// compatibility with filesystems that do not allow colons in filenames.
let format: &[FormatItem] =
format_description!("[year]-[month]-[day]T[hour]-[minute]-[second]");
let date_str = timestamp
@@ -314,7 +278,6 @@ fn create_log_file(
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
let filename = format!("rollout-{date_str}-{conversation_id}.jsonl");
let path = dir.join(filename);
let file = std::fs::OpenOptions::new()
.append(true)
@@ -333,25 +296,27 @@ async fn rollout_writer(
file: tokio::fs::File,
mut rx: mpsc::Receiver<RolloutCmd>,
mut meta: Option<SessionMeta>,
cwd: std::path::PathBuf,
cwd: Option<PathBuf>,
git_info_collector: Option<Arc<dyn GitInfoCollector>>,
) -> std::io::Result<()> {
let mut writer = JsonlWriter { file };
// If we have a meta, collect git info asynchronously and write meta first
if let Some(session_meta) = meta.take() {
let git_info = collect_git_info(&cwd).await;
let git_info =
if let (Some(provider), Some(cwd)) = (git_info_collector.as_ref(), cwd.as_ref()) {
provider.collect(cwd.as_path()).await
} else {
None
};
let session_meta_line = SessionMetaLine {
meta: session_meta,
git: git_info,
};
// Write the SessionMeta as the first item in the file, wrapped in a rollout line
writer
.write_rollout_item(RolloutItem::SessionMeta(session_meta_line))
.await?;
}
// Process rollout commands
while let Some(cmd) = rx.recv().await {
match cmd {
RolloutCmd::AddItems(items) => {
@@ -362,7 +327,6 @@ async fn rollout_writer(
}
}
RolloutCmd::Flush { ack } => {
// Ensure underlying file is flushed and then ack.
if let Err(e) = writer.file.flush().await {
let _ = ack.send(());
return Err(e);
@@ -397,11 +361,14 @@ impl JsonlWriter {
};
self.write_line(&line).await
}
async fn write_line(&mut self, item: &impl serde::Serialize) -> std::io::Result<()> {
let mut json = serde_json::to_string(item)?;
json.push('\n');
self.file.write_all(json.as_bytes()).await?;
self.file.flush().await?;
Ok(())
let mut buf = serde_json::to_vec(item)
.map_err(|e| IoError::other(format!("failed to serialise rollout line: {e}")))?;
buf.push(b'\n');
self.file
.write_all(&buf)
.await
.map_err(|e| IoError::other(format!("failed to write rollout line: {e}")))
}
}

View File

@@ -0,0 +1,16 @@
use async_trait::async_trait;
use codex_protocol::protocol::Event;
use codex_protocol::protocol::Op;
use codex_protocol::protocol::Submission;
/// Minimal async interface for interacting with an agent runtime.
#[async_trait]
pub trait AgentRuntime: Send + Sync {
type Error: std::error::Error + Send + Sync + 'static;
async fn submit(&self, op: Op) -> Result<String, Self::Error>;
async fn submit_with_id(&self, submission: Submission) -> Result<(), Self::Error>;
async fn next_event(&self) -> Result<Event, Self::Error>;
}
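
A brief sketch of how a host might drive any implementation of this trait; `Op` construction and event handling are left to the host, and the single `next_event` call stands in for a real event loop:

use codex_agent::runtime::AgentRuntime;
use codex_protocol::protocol::Event;
use codex_protocol::protocol::Op;

async fn one_turn<R: AgentRuntime>(runtime: &R, op: Op) -> Result<Event, R::Error> {
    let _sub_id = runtime.submit(op).await?;
    // A real host would loop here until it observes a turn-complete event.
    runtime.next_event().await
}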

View File

@@ -0,0 +1,46 @@
use std::collections::HashMap;
use std::path::PathBuf;
use crate::config_types::History;
use crate::config_types::McpServerConfig;
use crate::config_types::ShellEnvironmentPolicy;
use crate::model_family::ModelFamily;
use crate::model_provider::ModelProviderInfo;
use codex_protocol::config_types::ReasoningEffort;
use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::config_types::Verbosity;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::SandboxPolicy;
/// Configuration surface consumed by the agent runtime regardless of host.
#[derive(Debug, Clone, PartialEq)]
pub struct AgentConfig {
pub model: String,
pub review_model: String,
pub model_family: ModelFamily,
pub model_context_window: Option<u64>,
pub model_auto_compact_token_limit: Option<i64>,
pub model_reasoning_effort: Option<ReasoningEffort>,
pub model_reasoning_summary: ReasoningSummary,
pub model_verbosity: Option<Verbosity>,
pub model_provider: ModelProviderInfo,
pub approval_policy: AskForApproval,
pub sandbox_policy: SandboxPolicy,
pub shell_environment_policy: ShellEnvironmentPolicy,
pub user_instructions: Option<String>,
pub base_instructions: Option<String>,
pub notify: Option<Vec<String>>,
pub cwd: PathBuf,
pub codex_home: PathBuf,
pub history: History,
pub mcp_servers: HashMap<String, McpServerConfig>,
pub include_plan_tool: bool,
pub include_apply_patch_tool: bool,
pub include_view_image_tool: bool,
pub tools_web_search_request: bool,
pub use_experimental_streamable_shell_tool: bool,
pub use_experimental_unified_exec_tool: bool,
pub show_raw_agent_reasoning: bool,
pub codex_linux_sandbox_exe: Option<PathBuf>,
pub project_doc_max_bytes: usize,
}

View File

@@ -1,17 +1,14 @@
use std::collections::HashSet;
use std::path::Component;
use std::path::Path;
use std::path::PathBuf;
use codex_apply_patch::ApplyPatchAction;
use codex_apply_patch::ApplyPatchFileChange;
use crate::exec::SandboxType;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::SandboxPolicy;
use crate::command_safety::is_dangerous_command::command_might_be_dangerous;
use crate::command_safety::is_safe_command::is_known_safe_command;
use crate::protocol::AskForApproval;
use crate::protocol::SandboxPolicy;
use crate::sandbox::SandboxType;
#[derive(Debug, PartialEq)]
pub enum SafetyCheck {
@@ -199,81 +196,196 @@ fn is_write_patch_constrained_to_writable_paths(
SandboxPolicy::DangerFullAccess => {
return true;
}
SandboxPolicy::WorkspaceWrite { .. } => sandbox_policy.get_writable_roots_with_cwd(cwd),
SandboxPolicy::WorkspaceWrite {
writable_roots,
exclude_slash_tmp: _exclude_slash_tmp,
exclude_tmpdir_env_var: _exclude_tmpdir,
network_access: _network_access,
} => writable_roots,
};
// Normalize a path by removing `.` and resolving `..` without touching the
// filesystem (works even if the file does not exist).
fn normalize(path: &Path) -> Option<PathBuf> {
let mut out = PathBuf::new();
for comp in path.components() {
match comp {
Component::ParentDir => {
out.pop();
}
Component::CurDir => { /* skip */ }
other => out.push(other.as_os_str()),
}
}
Some(out)
// If the policy allows writes outside the workspace (DangerFullAccess),
// we've already returned true above. At this point we only have
// `WorkspaceWrite`, which includes the cwd implicitly, so first check if
// the patch fully lives within the cwd. If it does then we're fine.
let workspace_root = cwd.canonicalize().unwrap_or_else(|_| cwd.to_path_buf());
if all_changes_within_root(action, &workspace_root) {
return true;
}
// Determine whether `path` is inside **any** writable root. Both `path`
// and roots are converted to absolute, normalized forms before the
// prefix check.
let is_path_writable = |p: &PathBuf| {
let abs = if p.is_absolute() {
p.clone()
} else {
cwd.join(p)
};
let abs = match normalize(&abs) {
Some(v) => v,
None => return false,
};
if writable_roots.is_empty() {
return false;
}
writable_roots
.iter()
.any(|writable_root| writable_root.is_path_writable(&abs))
};
// When `/tmp` is excluded, filter it out of writable roots. Some patch commands write
// temporary files there even for workspace-only updates.
let mut writable_roots: Vec<&PathBuf> = writable_roots.iter().collect();
if matches!(
sandbox_policy,
SandboxPolicy::WorkspaceWrite {
exclude_slash_tmp: true,
..
}
) {
writable_roots.retain(|path| !path.as_path().starts_with("/tmp"));
}
for (path, change) in action.changes() {
match change {
ApplyPatchFileChange::Add { .. } | ApplyPatchFileChange::Delete { .. } => {
if !is_path_writable(path) {
return false;
let mut all_within_declared_root = true;
for change in action.changes() {
match change.0.strip_prefix(&workspace_root) {
Ok(relative_path) => {
if !is_within_any_root(relative_path, &writable_roots) {
all_within_declared_root = false;
break;
}
}
ApplyPatchFileChange::Update { move_path, .. } => {
if !is_path_writable(path) {
return false;
}
if let Some(dest) = move_path
&& !is_path_writable(dest)
{
return false;
}
Err(_) => {
all_within_declared_root = false;
break;
}
}
}
true
all_within_declared_root
}
#[cfg(test)]
fn all_changes_within_root(action: &ApplyPatchAction, root: &Path) -> bool {
action
.changes()
.iter()
.all(|(path, _)| path.starts_with(root))
}
fn is_within_any_root(path: &Path, roots: &[&PathBuf]) -> bool {
roots.iter().any(|root| path.starts_with(root.as_path()))
}
#[cfg(any())]
mod tests {
use super::*;
use tempfile::TempDir;
#[test]
fn test_writable_roots_constraint() {
// Use a temporary directory as our workspace to avoid touching
// the real current working directory.
let tmp = TempDir::new().unwrap();
let cwd = tmp.path().to_path_buf();
fn reject_empty_patch() {
let action = ApplyPatchAction::new_for_test(vec![]);
let sandbox_policy = SandboxPolicy::ReadOnly;
let cwd = Path::new(".");
assert_eq!(
assess_patch_safety(&action, AskForApproval::OnRequest, &sandbox_policy, cwd),
SafetyCheck::Reject {
reason: "empty patch".to_string(),
}
);
}
#[test]
fn auto_allow_patch_in_workspace_write_sandbox() {
let patch_action = ApplyPatchAction::new_for_test(vec![ApplyPatchFileChange::new_update(
PathBuf::from("src/main.rs"),
"diff --git a/src/main.rs b/src/main.rs\n".to_string(),
None,
"".to_string(),
)]);
let sandbox_policy = SandboxPolicy::WorkspaceWrite {
writable_roots: vec![],
network_access: false,
exclude_tmpdir_env_var: false,
exclude_slash_tmp: false,
};
assert_eq!(
assess_patch_safety(
&patch_action,
AskForApproval::OnRequest,
&sandbox_policy,
Path::new("."),
),
SafetyCheck::AutoApprove {
sandbox_type: get_platform_sandbox().unwrap_or(SandboxType::None),
}
);
}
#[test]
fn reject_patch_if_policy_is_never_and_writes_outside_of_workspace() {
let patch_action = ApplyPatchAction::new_for_test(vec![ApplyPatchFileChange::new_update(
PathBuf::from("../outside_file.txt"),
"diff --git a/../outside_file.txt b/../outside_file.txt\n".to_string(),
None,
"".to_string(),
)]);
let sandbox_policy = SandboxPolicy::WorkspaceWrite {
writable_roots: vec![],
network_access: false,
exclude_tmpdir_env_var: false,
exclude_slash_tmp: false,
};
assert_eq!(
assess_patch_safety(
&patch_action,
AskForApproval::Never,
&sandbox_policy,
Path::new("."),
),
SafetyCheck::Reject {
reason: "writing outside of the project; rejected by user approval settings"
.to_string(),
}
);
}
#[test]
fn assess_command_safety_known_safe_command() {
let command = vec!["ls".to_string()];
let approval_policy = AskForApproval::OnRequest;
let sandbox_policy = SandboxPolicy::ReadOnly;
let approved = HashSet::new();
let request_escalated_privileges = false;
let safety_check = assess_command_safety(
&command,
approval_policy,
&sandbox_policy,
&approved,
request_escalated_privileges,
);
assert_eq!(
safety_check,
SafetyCheck::AutoApprove {
sandbox_type: SandboxType::None
}
);
}
#[test]
fn assess_command_safety_dangerous_command_to_reject() {
let command = vec!["rm".to_string(), "-rf".to_string(), "/".to_string()];
let approval_policy = AskForApproval::OnRequest;
let sandbox_policy = SandboxPolicy::ReadOnly;
let approved = HashSet::new();
let request_escalated_privileges = false;
let safety_check = assess_command_safety(
&command,
approval_policy,
&sandbox_policy,
&approved,
request_escalated_privileges,
);
assert_eq!(safety_check, SafetyCheck::AskUser);
}
#[test]
fn patch_within_declared_root() {
let tempdir = tempfile::tempdir().unwrap();
let cwd = tempdir.path().to_path_buf();
let parent = cwd.parent().unwrap().to_path_buf();
// Helper to build a single-entry patch that adds a file at `p`.
let make_add_change = |p: PathBuf| ApplyPatchAction::new_add_for_test(&p, "".to_string());
let add_inside = make_add_change(cwd.join("inner.txt"));

View File

@@ -0,0 +1,3 @@
pub mod types;
pub use types::SandboxType;

View File

@@ -0,0 +1,10 @@
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SandboxType {
None,
/// Only available on macOS.
MacosSeatbelt,
/// Only available on Linux.
LinuxSeccomp,
}

View File

@@ -0,0 +1,138 @@
use std::collections::HashMap;
use std::path::PathBuf;
use async_trait::async_trait;
use codex_apply_patch::ApplyPatchAction;
use codex_protocol::mcp_protocol::AuthMode;
use codex_protocol::protocol::ReviewDecision;
use codex_protocol::protocol::RolloutItem;
use mcp_types::Tool;
use serde_json::Value;
use crate::exec_command::ExecCommandOutput;
use crate::exec_command::ExecCommandParams;
use crate::exec_command::WriteStdinParams;
use crate::notifications::UserNotification;
use crate::rollout::RolloutRecorder;
use crate::token_data::PlanType;
use crate::unified_exec::UnifiedExecError;
use crate::unified_exec::UnifiedExecRequest;
use crate::unified_exec::UnifiedExecResult;
/// Authentication context made available to the provider layer.
#[async_trait]
pub trait ProviderAuth: Send + Sync {
fn mode(&self) -> AuthMode;
async fn access_token(&self) -> std::io::Result<String>;
fn account_id(&self) -> Option<String>;
fn plan_type(&self) -> Option<PlanType>;
}
/// Provides access to credentials required when talking to model providers.
#[async_trait]
pub trait CredentialsProvider: Send + Sync {
fn auth(&self) -> Option<std::sync::Arc<dyn ProviderAuth>>;
async fn refresh_token(&self) -> std::io::Result<Option<String>>;
}
/// Emits user-facing notifications for turn completion or other events.
pub trait Notifier: Send + Sync {
fn notify(&self, notification: &UserNotification);
}
/// Runtime callbacks for user approval workflows.
#[async_trait]
pub trait ApprovalCoordinator: Send + Sync {
async fn request_patch_approval(
&self,
sub_id: String,
call_id: String,
action: &ApplyPatchAction,
reason: Option<String>,
grant_root: Option<PathBuf>,
) -> ReviewDecision;
async fn request_command_approval(
&self,
sub_id: String,
call_id: String,
command: Vec<String>,
cwd: PathBuf,
reason: Option<String>,
) -> ReviewDecision;
async fn add_approved_command(&self, command: Vec<String>);
}
/// Aggregates and dispatches MCP tool calls across configured servers.
#[async_trait]
pub trait McpInterface: Send + Sync {
fn list_all_tools(&self) -> HashMap<String, Tool>;
fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)>;
async fn call_tool(
&self,
server: &str,
tool: &str,
arguments: Option<Value>,
) -> anyhow::Result<mcp_types::CallToolResult>;
}
/// Persists rollout events for later inspection or replay.
#[async_trait]
pub trait RolloutSink: Send + Sync {
async fn record_items(&self, items: &[RolloutItem]) -> std::io::Result<()>;
async fn flush(&self) -> std::io::Result<()>;
async fn shutdown(&self) -> std::io::Result<()>;
fn get_rollout_path(&self) -> PathBuf;
}
#[async_trait]
impl RolloutSink for RolloutRecorder {
async fn record_items(&self, items: &[RolloutItem]) -> std::io::Result<()> {
RolloutRecorder::record_items(self, items).await
}
async fn flush(&self) -> std::io::Result<()> {
RolloutRecorder::flush(self).await
}
async fn shutdown(&self) -> std::io::Result<()> {
RolloutRecorder::shutdown(self).await
}
fn get_rollout_path(&self) -> PathBuf {
RolloutRecorder::get_rollout_path(self)
}
}
/// Handles sandboxed exec orchestration, including long-running sessions.
#[async_trait]
pub trait SandboxManager: Send + Sync {
async fn handle_exec_command_request(
&self,
params: ExecCommandParams,
) -> Result<ExecCommandOutput, String>;
async fn handle_write_stdin_request(
&self,
params: WriteStdinParams,
) -> Result<ExecCommandOutput, String>;
async fn handle_unified_exec_request(
&self,
request: UnifiedExecRequest<'_>,
) -> Result<UnifiedExecResult, UnifiedExecError>;
fn codex_linux_sandbox_exe(&self) -> &Option<PathBuf>;
fn user_shell(&self) -> &crate::shell::Shell;
}
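// Illustrative host-side sketch (not part of this diff): an approval
// coordinator that waves every request through, e.g. for an unattended
// harness. `ReviewDecision::Approved` is assumed to be the affirmative
// variant in `codex_protocol`.
struct AutoApprove;

#[async_trait]
impl ApprovalCoordinator for AutoApprove {
    async fn request_patch_approval(
        &self,
        _sub_id: String,
        _call_id: String,
        _action: &ApplyPatchAction,
        _reason: Option<String>,
        _grant_root: Option<PathBuf>,
    ) -> ReviewDecision {
        ReviewDecision::Approved
    }

    async fn request_command_approval(
        &self,
        _sub_id: String,
        _call_id: String,
        _command: Vec<String>,
        _cwd: PathBuf,
        _reason: Option<String>,
    ) -> ReviewDecision {
        ReviewDecision::Approved
    }

    async fn add_approved_command(&self, _command: Vec<String>) {}
}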

View File

@@ -0,0 +1,18 @@
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::services::McpInterface;
use crate::services::Notifier;
use crate::services::RolloutSink;
use crate::services::SandboxManager;
/// Aggregated services that back a running agent session. Hosts provide
/// implementations for these traits and hand them to the runtime at spawn.
pub struct SessionServices {
pub mcp: Arc<dyn McpInterface>,
pub notifier: Arc<dyn Notifier>,
pub sandbox: Arc<dyn SandboxManager>,
pub rollout: Mutex<Option<Arc<dyn RolloutSink>>>,
pub show_raw_agent_reasoning: bool,
}
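// Illustrative construction sketch (not part of this diff); `MyMcp`,
// `TracingNotifier`, and `MySandbox` are hypothetical host implementations
// of the service traits:
//
//     let services = SessionServices {
//         mcp: Arc::new(MyMcp::connect().await?),
//         notifier: Arc::new(TracingNotifier),
//         sandbox: Arc::new(MySandbox::default()),
//         rollout: Mutex::new(None), // none until a RolloutSink is installed
//         show_raw_agent_reasoning: false,
//     };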

View File

@@ -1,26 +1,24 @@
//! Session-wide mutable state.
use std::collections::HashSet;
+use codex_protocol::models::ResponseItem;
+use codex_protocol::protocol::RateLimitSnapshot;
+use codex_protocol::protocol::TokenUsage;
+use codex_protocol::protocol::TokenUsageInfo;
use crate::conversation_history::ConversationHistory;
-use crate::protocol::RateLimitSnapshot;
-use crate::protocol::TokenUsage;
-use crate::protocol::TokenUsageInfo;
/// Persistent, session-scoped state previously stored directly on `Session`.
#[derive(Default)]
-pub(crate) struct SessionState {
-    pub(crate) approved_commands: HashSet<Vec<String>>,
-    pub(crate) history: ConversationHistory,
-    pub(crate) token_info: Option<TokenUsageInfo>,
-    pub(crate) latest_rate_limits: Option<RateLimitSnapshot>,
+pub struct SessionState {
+    approved_commands: HashSet<Vec<String>>,
+    history: ConversationHistory,
+    token_info: Option<TokenUsageInfo>,
+    latest_rate_limits: Option<RateLimitSnapshot>,
}
impl SessionState {
    /// Create a new session state mirroring previous `State::default()` semantics.
-    pub(crate) fn new() -> Self {
+    pub fn new() -> Self {
        Self {
            history: ConversationHistory::new(),
            ..Default::default()
@@ -28,7 +26,7 @@ impl SessionState {
        }
    }
    // History helpers
-    pub(crate) fn record_items<I>(&mut self, items: I)
+    pub fn record_items<I>(&mut self, items: I)
    where
        I: IntoIterator,
        I::Item: std::ops::Deref<Target = ResponseItem>,
@@ -36,25 +34,25 @@ impl SessionState {
        self.history.record_items(items)
    }
-    pub(crate) fn history_snapshot(&self) -> Vec<ResponseItem> {
+    pub fn history_snapshot(&self) -> Vec<ResponseItem> {
        self.history.contents()
    }
-    pub(crate) fn replace_history(&mut self, items: Vec<ResponseItem>) {
+    pub fn replace_history(&mut self, items: Vec<ResponseItem>) {
        self.history.replace(items);
    }
    // Approved command helpers
-    pub(crate) fn add_approved_command(&mut self, cmd: Vec<String>) {
+    pub fn add_approved_command(&mut self, cmd: Vec<String>) {
        self.approved_commands.insert(cmd);
    }
-    pub(crate) fn approved_commands_ref(&self) -> &HashSet<Vec<String>> {
+    pub fn approved_commands_ref(&self) -> &HashSet<Vec<String>> {
        &self.approved_commands
    }
    // Token/rate limit helpers
-    pub(crate) fn update_token_info_from_usage(
+    pub fn update_token_info_from_usage(
        &mut self,
        usage: &TokenUsage,
        model_context_window: Option<u64>,
@@ -66,15 +64,13 @@ impl SessionState {
        );
    }
-    pub(crate) fn set_rate_limits(&mut self, snapshot: RateLimitSnapshot) {
+    pub fn set_rate_limits(&mut self, snapshot: RateLimitSnapshot) {
        self.latest_rate_limits = Some(snapshot);
    }
-    pub(crate) fn token_info_and_rate_limits(
+    pub fn token_info_and_rate_limits(
        &self,
    ) -> (Option<TokenUsageInfo>, Option<RateLimitSnapshot>) {
        (self.token_info.clone(), self.latest_rate_limits.clone())
    }
    // Pending input/approval moved to TurnState.
}

271
codex-rs/agent/src/shell.rs Normal file
View File

@@ -0,0 +1,271 @@
use serde::Deserialize;
use serde::Serialize;
use shlex;
use std::path::PathBuf;
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct ZshShell {
pub(crate) shell_path: String,
pub(crate) zshrc_path: String,
}
impl ZshShell {
pub fn new(shell_path: impl Into<String>, zshrc_path: impl Into<String>) -> Self {
Self {
shell_path: shell_path.into(),
zshrc_path: zshrc_path.into(),
}
}
pub fn shell_path(&self) -> &str {
&self.shell_path
}
pub fn zshrc_path(&self) -> &str {
&self.zshrc_path
}
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct BashShell {
pub(crate) shell_path: String,
pub(crate) bashrc_path: String,
}
impl BashShell {
pub fn new(shell_path: impl Into<String>, bashrc_path: impl Into<String>) -> Self {
Self {
shell_path: shell_path.into(),
bashrc_path: bashrc_path.into(),
}
}
pub fn shell_path(&self) -> &str {
&self.shell_path
}
pub fn bashrc_path(&self) -> &str {
&self.bashrc_path
}
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct PowerShellConfig {
pub(crate) exe: String, // Executable name or path, e.g. "pwsh" or "powershell.exe".
pub(crate) bash_exe_fallback: Option<PathBuf>, // In case the model generates a bash command.
}
impl PowerShellConfig {
pub fn new(exe: impl Into<String>, bash_exe_fallback: Option<PathBuf>) -> Self {
Self {
exe: exe.into(),
bash_exe_fallback,
}
}
pub fn exe(&self) -> &str {
&self.exe
}
pub fn bash_exe_fallback(&self) -> Option<&PathBuf> {
self.bash_exe_fallback.as_ref()
}
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub enum Shell {
Zsh(ZshShell),
Bash(BashShell),
PowerShell(PowerShellConfig),
Unknown,
}
impl Shell {
pub fn format_default_shell_invocation(&self, command: Vec<String>) -> Option<Vec<String>> {
match self {
Shell::Zsh(zsh) => format_shell_invocation_with_rc(
command.as_slice(),
&zsh.shell_path,
&zsh.zshrc_path,
),
Shell::Bash(bash) => format_shell_invocation_with_rc(
command.as_slice(),
&bash.shell_path,
&bash.bashrc_path,
),
Shell::PowerShell(ps) => {
// If model generated a bash command, prefer a detected bash fallback
if let Some(script) = strip_bash_lc(command.as_slice()) {
return match &ps.bash_exe_fallback {
Some(bash) => Some(vec![
bash.to_string_lossy().to_string(),
"-lc".to_string(),
script,
]),
// No bash fallback → run the script under PowerShell.
                        // It will likely fail (except for some simple commands), but the
                        // error should give the model a clue that it is running under
                        // PowerShell so it can fix the command on retry.
None => Some(vec![
ps.exe.clone(),
"-NoProfile".to_string(),
"-Command".to_string(),
script,
]),
};
}
// Not a bash command. If model did not generate a PowerShell command,
// turn it into a PowerShell command.
let first = command.first().map(String::as_str);
if first != Some(ps.exe.as_str()) {
// TODO (CODEX_2900): Handle escaping newlines.
if command.iter().any(|a| a.contains('\n') || a.contains('\r')) {
return Some(command);
}
let joined = shlex::try_join(command.iter().map(String::as_str)).ok();
return joined.map(|arg| {
vec![
ps.exe.clone(),
"-NoProfile".to_string(),
"-Command".to_string(),
arg,
]
});
}
// Model generated a PowerShell command. Run it.
Some(command)
}
Shell::Unknown => None,
}
}
pub fn name(&self) -> Option<String> {
match self {
Shell::Zsh(zsh) => std::path::Path::new(&zsh.shell_path)
.file_name()
.map(|s| s.to_string_lossy().to_string()),
Shell::Bash(bash) => std::path::Path::new(&bash.shell_path)
.file_name()
.map(|s| s.to_string_lossy().to_string()),
Shell::PowerShell(ps) => Some(ps.exe.clone()),
Shell::Unknown => None,
}
}
}
fn format_shell_invocation_with_rc(
command: &[String],
shell_path: &str,
rc_path: &str,
) -> Option<Vec<String>> {
let joined = strip_bash_lc(command)
.or_else(|| shlex::try_join(command.iter().map(String::as_str)).ok())?;
let rc_command = if std::path::Path::new(rc_path).exists() {
format!("source {rc_path} && ({joined})")
} else {
joined
};
Some(vec![shell_path.to_string(), "-lc".to_string(), rc_command])
}
fn strip_bash_lc(command: &[String]) -> Option<String> {
match command {
// exactly three items
[first, second, third]
// first two must be "bash", "-lc"
if first == "bash" && second == "-lc" =>
{
Some(third.clone())
}
_ => None,
}
}
#[cfg(unix)]
fn detect_default_user_shell() -> Shell {
use libc::getpwuid;
use libc::getuid;
use std::ffi::CStr;
unsafe {
let uid = getuid();
let pw = getpwuid(uid);
if !pw.is_null() {
let shell_path = CStr::from_ptr((*pw).pw_shell)
.to_string_lossy()
.into_owned();
let home_path = CStr::from_ptr((*pw).pw_dir).to_string_lossy().into_owned();
if shell_path.ends_with("/zsh") {
return Shell::Zsh(ZshShell {
shell_path,
zshrc_path: format!("{home_path}/.zshrc"),
});
}
if shell_path.ends_with("/bash") {
return Shell::Bash(BashShell {
shell_path,
bashrc_path: format!("{home_path}/.bashrc"),
});
}
}
}
Shell::Unknown
}
#[cfg(unix)]
pub async fn default_user_shell() -> Shell {
detect_default_user_shell()
}
#[cfg(target_os = "windows")]
pub async fn default_user_shell() -> Shell {
use tokio::process::Command;
// Prefer PowerShell 7+ (`pwsh`) if available, otherwise fall back to Windows PowerShell.
let has_pwsh = Command::new("pwsh")
.arg("-NoLogo")
.arg("-NoProfile")
.arg("-Command")
.arg("$PSVersionTable.PSVersion.Major")
.output()
.await
.map(|o| o.status.success())
.unwrap_or(false);
let bash_exe = if Command::new("bash.exe")
.arg("--version")
.output()
.await
.ok()
.map(|o| o.status.success())
.unwrap_or(false)
{
which::which("bash.exe").ok()
} else {
None
};
if has_pwsh {
Shell::PowerShell(PowerShellConfig {
exe: "pwsh.exe".to_string(),
bash_exe_fallback: bash_exe,
})
} else {
Shell::PowerShell(PowerShellConfig {
exe: "powershell.exe".to_string(),
bash_exe_fallback: bash_exe,
})
}
}
#[cfg(all(not(target_os = "windows"), not(unix)))]
pub async fn default_user_shell() -> Shell {
Shell::Unknown
}
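// Illustrative usage sketch (not part of this diff): how a host wraps a
// model-generated `bash -lc` command for a detected bash shell. The rc path
// is hypothetical and deliberately nonexistent.
#[cfg(test)]
mod invocation_example {
    use super::*;

    #[test]
    fn bash_lc_commands_pass_through_without_an_rc_file() {
        let shell = Shell::Bash(BashShell::new("/bin/bash", "/nonexistent/.bashrc"));
        let argv = shell
            .format_default_shell_invocation(vec![
                "bash".to_string(),
                "-lc".to_string(),
                "echo hi".to_string(),
            ])
            .expect("bash shells always yield an invocation");
        // The rc file is missing, so the script is not wrapped in
        // `source <rc> && (<script>)`.
        assert_eq!(
            argv,
            vec![
                "/bin/bash".to_string(),
                "-lc".to_string(),
                "echo hi".to_string()
            ]
        );
    }
}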

View File

@@ -0,0 +1,182 @@
use base64::Engine;
use serde::Deserialize;
use serde::Serialize;
use thiserror::Error;
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Default)]
pub struct TokenData {
/// Flat info parsed from the JWT in auth.json.
#[serde(
deserialize_with = "deserialize_id_token",
serialize_with = "serialize_id_token"
)]
pub id_token: IdTokenInfo,
/// This is a JWT.
pub access_token: String,
pub refresh_token: String,
pub account_id: Option<String>,
}
/// Flat subset of useful claims in id_token from auth.json.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct IdTokenInfo {
pub email: Option<String>,
/// The ChatGPT subscription plan type
/// (e.g., "free", "plus", "pro", "business", "enterprise", "edu").
/// (Note: values may vary by backend.)
pub chatgpt_plan_type: Option<PlanType>,
pub raw_jwt: String,
}
impl IdTokenInfo {
pub fn get_chatgpt_plan_type(&self) -> Option<String> {
self.chatgpt_plan_type.as_ref().map(|t| match t {
PlanType::Known(plan) => format!("{plan:?}"),
PlanType::Unknown(s) => s.clone(),
})
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PlanType {
Known(KnownPlan),
Unknown(String),
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum KnownPlan {
Free,
Plus,
Pro,
Team,
Business,
Enterprise,
Edu,
}
#[derive(Deserialize)]
struct IdClaims {
#[serde(default)]
email: Option<String>,
#[serde(rename = "https://api.openai.com/auth", default)]
auth: Option<AuthClaims>,
}
#[derive(Deserialize)]
struct AuthClaims {
#[serde(default)]
chatgpt_plan_type: Option<PlanType>,
}
#[derive(Debug, Error)]
pub enum IdTokenInfoError {
#[error("invalid ID token format")]
InvalidFormat,
#[error(transparent)]
Base64(#[from] base64::DecodeError),
#[error(transparent)]
Json(#[from] serde_json::Error),
}
pub fn parse_id_token(id_token: &str) -> Result<IdTokenInfo, IdTokenInfoError> {
// JWT format: header.payload.signature
let mut parts = id_token.split('.');
let (_header_b64, payload_b64, _sig_b64) = match (parts.next(), parts.next(), parts.next()) {
(Some(h), Some(p), Some(s)) if !h.is_empty() && !p.is_empty() && !s.is_empty() => (h, p, s),
_ => return Err(IdTokenInfoError::InvalidFormat),
};
let payload_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload_b64)?;
let claims: IdClaims = serde_json::from_slice(&payload_bytes)?;
Ok(IdTokenInfo {
email: claims.email,
chatgpt_plan_type: claims.auth.and_then(|a| a.chatgpt_plan_type),
raw_jwt: id_token.to_string(),
})
}
fn deserialize_id_token<'de, D>(deserializer: D) -> Result<IdTokenInfo, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
parse_id_token(&s).map_err(serde::de::Error::custom)
}
fn serialize_id_token<S>(id_token: &IdTokenInfo, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&id_token.raw_jwt)
}
#[cfg(test)]
mod tests {
use super::*;
use serde::Serialize;
#[test]
fn id_token_info_parses_email_and_plan() {
#[derive(Serialize)]
struct Header {
alg: &'static str,
typ: &'static str,
}
let header = Header {
alg: "none",
typ: "JWT",
};
let payload = serde_json::json!({
"email": "user@example.com",
"https://api.openai.com/auth": {
"chatgpt_plan_type": "pro"
}
});
fn b64url_no_pad(bytes: &[u8]) -> String {
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
}
let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
let signature_b64 = b64url_no_pad(b"sig");
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
let info = parse_id_token(&fake_jwt).expect("should parse");
assert_eq!(info.email.as_deref(), Some("user@example.com"));
assert_eq!(info.get_chatgpt_plan_type().as_deref(), Some("Pro"));
}
#[test]
fn id_token_info_handles_missing_fields() {
#[derive(Serialize)]
struct Header {
alg: &'static str,
typ: &'static str,
}
let header = Header {
alg: "none",
typ: "JWT",
};
let payload = serde_json::json!({ "sub": "123" });
fn b64url_no_pad(bytes: &[u8]) -> String {
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
}
let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
let signature_b64 = b64url_no_pad(b"sig");
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
let info = parse_id_token(&fake_jwt).expect("should parse");
assert!(info.email.is_none());
assert!(info.get_chatgpt_plan_type().is_none());
}
}

View File

@@ -0,0 +1,10 @@
use serde::Deserialize;
use serde::Serialize;
/// Represents which apply_patch tool variant a model expects.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum ApplyPatchToolType {
Freeform,
Function,
}

View File

@@ -0,0 +1,180 @@
//! Utilities for truncating large chunks of output while preserving a prefix
//! and suffix on UTF-8 boundaries.
/// Truncate the middle of a UTF-8 string to at most `max_bytes` bytes,
/// preserving the beginning and the end. Returns the possibly truncated
/// string and `Some(original_token_count)` (estimated at 4 bytes/token)
/// if truncation occurred; otherwise returns the original string and `None`.
pub fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>) {
if s.len() <= max_bytes {
return (s.to_string(), None);
}
let est_tokens = (s.len() as u64).div_ceil(4);
if max_bytes == 0 {
return (format!("{est_tokens} tokens truncated…"), Some(est_tokens));
}
fn truncate_on_boundary(input: &str, max_len: usize) -> &str {
if input.len() <= max_len {
return input;
}
let mut end = max_len;
while end > 0 && !input.is_char_boundary(end) {
end -= 1;
}
&input[..end]
}
fn pick_prefix_end(s: &str, left_budget: usize) -> usize {
if let Some(head) = s.get(..left_budget)
&& let Some(i) = head.rfind('\n')
{
return i + 1;
}
truncate_on_boundary(s, left_budget).len()
}
fn pick_suffix_start(s: &str, right_budget: usize) -> usize {
let start_tail = s.len().saturating_sub(right_budget);
if let Some(tail) = s.get(start_tail..)
&& let Some(i) = tail.find('\n')
{
return start_tail + i + 1;
}
let mut idx = start_tail.min(s.len());
while idx < s.len() && !s.is_char_boundary(idx) {
idx += 1;
}
idx
}
let mut guess_tokens = est_tokens;
for _ in 0..4 {
let marker = format!("{guess_tokens} tokens truncated…");
let marker_len = marker.len();
let keep_budget = max_bytes.saturating_sub(marker_len);
if keep_budget == 0 {
return (format!("{est_tokens} tokens truncated…"), Some(est_tokens));
}
let left_budget = keep_budget / 2;
let right_budget = keep_budget - left_budget;
let prefix_end = pick_prefix_end(s, left_budget);
let mut suffix_start = pick_suffix_start(s, right_budget);
if suffix_start < prefix_end {
suffix_start = prefix_end;
}
let kept_content_bytes = prefix_end + (s.len() - suffix_start);
let truncated_content_bytes = s.len().saturating_sub(kept_content_bytes);
let new_tokens = (truncated_content_bytes as u64).div_ceil(4);
if new_tokens == guess_tokens {
let mut out = String::with_capacity(marker_len + kept_content_bytes + 1);
out.push_str(&s[..prefix_end]);
out.push_str(&marker);
out.push('\n');
out.push_str(&s[suffix_start..]);
return (out, Some(est_tokens));
}
guess_tokens = new_tokens;
}
let marker = format!("{guess_tokens} tokens truncated…");
let marker_len = marker.len();
let keep_budget = max_bytes.saturating_sub(marker_len);
if keep_budget == 0 {
return (format!("{est_tokens} tokens truncated…"), Some(est_tokens));
}
let left_budget = keep_budget / 2;
let right_budget = keep_budget - left_budget;
let prefix_end = pick_prefix_end(s, left_budget);
let suffix_start = pick_suffix_start(s, right_budget);
let mut out = String::with_capacity(marker_len + prefix_end + (s.len() - suffix_start) + 1);
out.push_str(&s[..prefix_end]);
out.push_str(&marker);
out.push('\n');
out.push_str(&s[suffix_start..]);
(out, Some(est_tokens))
}
#[cfg(test)]
mod tests {
use super::truncate_middle;
#[test]
fn truncate_middle_no_newlines_fallback() {
let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ*";
let max_bytes = 32;
let (out, original) = truncate_middle(s, max_bytes);
assert!(out.starts_with("abc"));
assert!(out.contains("tokens truncated"));
assert!(out.ends_with("XYZ*"));
assert_eq!(original, Some((s.len() as u64).div_ceil(4)));
}
#[test]
fn truncate_middle_prefers_newline_boundaries() {
let mut s = String::new();
for i in 1..=20 {
s.push_str(&format!("{i:03}\n"));
}
assert_eq!(s.len(), 80);
let max_bytes = 64;
let (out, tokens) = truncate_middle(&s, max_bytes);
assert!(out.starts_with("001\n002\n003\n004\n"));
assert!(out.contains("tokens truncated"));
assert!(out.ends_with("017\n018\n019\n020\n"));
assert_eq!(tokens, Some(20));
}
#[test]
fn truncate_middle_handles_utf8_content() {
let s = "😀😀😀😀😀😀😀😀😀😀\nsecond line with ascii text\n";
let max_bytes = 32;
let (out, tokens) = truncate_middle(s, max_bytes);
assert!(out.contains("tokens truncated"));
assert!(!out.contains('\u{fffd}'));
assert_eq!(tokens, Some((s.len() as u64).div_ceil(4)));
}
#[test]
fn truncate_middle_prefers_newline_boundaries_2() {
// Build a multi-line string of 20 numbered lines (each "NNN\n").
let mut s = String::new();
for i in 1..=20 {
s.push_str(&format!("{i:03}\n"));
}
// Total length: 20 lines * 4 bytes per line = 80 bytes.
assert_eq!(s.len(), 80);
// Choose a cap that forces truncation while leaving room for
// a few lines on each side after accounting for the marker.
let max_bytes = 64;
// Expect exact output: first 4 lines, marker, last 4 lines, and correct token estimate (80/4 = 20).
assert_eq!(
truncate_middle(&s, max_bytes),
(
r#"001
002
003
004
…12 tokens truncated…
017
018
019
020
"#
.to_string(),
Some(20)
)
);
}
}

View File

@@ -0,0 +1,896 @@
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use std::path::PathBuf;
use std::process::Command;
use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use sha1::digest::Output;
use uuid::Uuid;
use codex_protocol::protocol::FileChange;
const ZERO_OID: &str = "0000000000000000000000000000000000000000";
const DEV_NULL: &str = "/dev/null";
struct BaselineFileInfo {
path: PathBuf,
content: Vec<u8>,
mode: FileMode,
oid: String,
}
/// Tracks sets of changes to files and exposes the overall unified diff.
/// Internally, this works as follows:
/// 1. Maintain an in-memory baseline snapshot of files when they are first seen.
/// For new additions, do not create a baseline so that diffs are shown as proper additions (using /dev/null).
/// 2. Keep a stable internal filename (uuid) per external path for rename tracking.
/// 3. To compute the aggregated unified diff, compare each baseline snapshot to the current file on disk entirely in-memory
/// using the `similar` crate and emit unified diffs with rewritten external paths.
#[derive(Default)]
pub struct TurnDiffTracker {
/// Map external path -> internal filename (uuid).
external_to_temp_name: HashMap<PathBuf, String>,
/// Internal filename -> baseline file info.
baseline_file_info: HashMap<String, BaselineFileInfo>,
/// Internal filename -> external path as of current accumulated state (after applying all changes).
/// This is where renames are tracked.
temp_name_to_current_path: HashMap<String, PathBuf>,
/// Cache of known git worktree roots to avoid repeated filesystem walks.
git_root_cache: Vec<PathBuf>,
}
impl TurnDiffTracker {
pub fn new() -> Self {
Self::default()
}
/// Front-run apply patch calls to track the starting contents of any modified files.
/// - Creates an in-memory baseline snapshot for files that already exist on disk when first seen.
/// - For additions, we intentionally do not create a baseline snapshot so that diffs are proper additions.
/// - Also updates internal mappings for move/rename events.
pub fn on_patch_begin(&mut self, changes: &HashMap<PathBuf, FileChange>) {
for (path, change) in changes.iter() {
// Ensure a stable internal filename exists for this external path.
if !self.external_to_temp_name.contains_key(path.as_path()) {
let internal = Uuid::new_v4().to_string();
self.external_to_temp_name
.insert(path.clone(), internal.clone());
self.temp_name_to_current_path
.insert(internal.clone(), path.clone());
// If the file exists on disk now, snapshot as baseline; else leave missing to represent /dev/null.
let baseline_file_info = if path.exists() {
let mode = file_mode_for_path(path);
let mode_val = mode.unwrap_or(FileMode::Regular);
let content = blob_bytes(path, mode_val).unwrap_or_default();
let oid = if mode == Some(FileMode::Symlink) {
format!("{:x}", git_blob_sha1_hex_bytes(&content))
} else {
self.git_blob_oid_for_path(path)
.unwrap_or_else(|| format!("{:x}", git_blob_sha1_hex_bytes(&content)))
};
Some(BaselineFileInfo {
path: path.clone(),
content,
mode: mode_val,
oid,
})
} else {
Some(BaselineFileInfo {
path: path.clone(),
content: vec![],
mode: FileMode::Regular,
oid: ZERO_OID.to_string(),
})
};
if let Some(baseline_file_info) = baseline_file_info {
self.baseline_file_info
.insert(internal.clone(), baseline_file_info);
}
}
// Track rename/move in current mapping if provided in an Update.
if let FileChange::Update {
move_path: Some(dest),
..
} = change
{
let uuid_filename = match self.external_to_temp_name.get(path.as_path()) {
Some(i) => i.clone(),
None => {
// This should be rare, but if we haven't mapped the source, create it with no baseline.
let i = Uuid::new_v4().to_string();
self.baseline_file_info.insert(
i.clone(),
BaselineFileInfo {
path: path.clone(),
content: vec![],
mode: FileMode::Regular,
oid: ZERO_OID.to_string(),
},
);
i
}
};
// Update current external mapping for temp file name.
self.temp_name_to_current_path
.insert(uuid_filename.clone(), dest.clone());
// Update forward file_mapping: external current -> internal name.
self.external_to_temp_name.remove(path);
self.external_to_temp_name
.insert(dest.clone(), uuid_filename);
};
}
}
fn get_path_for_internal(&self, internal: &str) -> Option<PathBuf> {
self.temp_name_to_current_path
.get(internal)
.cloned()
.or_else(|| {
self.baseline_file_info
.get(internal)
.map(|info| info.path.clone())
})
}
/// Find the git worktree root for a file/directory by walking up to the first ancestor containing a `.git` entry.
/// Uses a simple cache of known roots and avoids negative-result caching for simplicity.
fn find_git_root_cached(&mut self, start: &Path) -> Option<PathBuf> {
let dir = if start.is_dir() {
start
} else {
start.parent()?
};
// Fast path: if any cached root is an ancestor of this path, use it.
if let Some(root) = self
.git_root_cache
.iter()
.find(|r| dir.starts_with(r))
.cloned()
{
return Some(root);
}
// Walk up to find a `.git` marker.
let mut cur = dir.to_path_buf();
loop {
let git_marker = cur.join(".git");
if git_marker.is_dir() || git_marker.is_file() {
if !self.git_root_cache.iter().any(|r| r == &cur) {
self.git_root_cache.push(cur.clone());
}
return Some(cur);
}
// On Windows, avoid walking above the drive or UNC share root.
#[cfg(windows)]
{
if is_windows_drive_or_unc_root(&cur) {
return None;
}
}
if let Some(parent) = cur.parent() {
cur = parent.to_path_buf();
} else {
return None;
}
}
}
/// Return a display string for `path` relative to its git root if found, else absolute.
fn relative_to_git_root_str(&mut self, path: &Path) -> String {
let s = if let Some(root) = self.find_git_root_cached(path) {
if let Ok(rel) = path.strip_prefix(&root) {
rel.display().to_string()
} else {
path.display().to_string()
}
} else {
path.display().to_string()
};
s.replace('\\', "/")
}
/// Ask git to compute the blob SHA-1 for the file at `path` within its repository.
/// Returns None if no repository is found or git invocation fails.
fn git_blob_oid_for_path(&mut self, path: &Path) -> Option<String> {
let root = self.find_git_root_cached(path)?;
// Compute a path relative to the repo root for better portability across platforms.
let rel = path.strip_prefix(&root).unwrap_or(path);
let output = Command::new("git")
.arg("-C")
.arg(&root)
.arg("hash-object")
.arg("--")
.arg(rel)
.output()
.ok()?;
if !output.status.success() {
return None;
}
let s = String::from_utf8_lossy(&output.stdout).trim().to_string();
if s.len() == 40 { Some(s) } else { None }
}
/// Recompute the aggregated unified diff by comparing all of the in-memory snapshots that were
/// collected before the first time they were touched by apply_patch during this turn with
/// the current repo state.
pub fn get_unified_diff(&mut self) -> Result<Option<String>> {
let mut aggregated = String::new();
// Compute diffs per tracked internal file in a stable order by external path.
let mut baseline_file_names: Vec<String> =
self.baseline_file_info.keys().cloned().collect();
// Sort lexicographically by full repo-relative path to match git behavior.
baseline_file_names.sort_by_key(|internal| {
self.get_path_for_internal(internal)
.map(|p| self.relative_to_git_root_str(&p))
.unwrap_or_default()
});
for internal in baseline_file_names {
aggregated.push_str(self.get_file_diff(&internal).as_str());
if !aggregated.ends_with('\n') {
aggregated.push('\n');
}
}
if aggregated.trim().is_empty() {
Ok(None)
} else {
Ok(Some(aggregated))
}
}
fn get_file_diff(&mut self, internal_file_name: &str) -> String {
let mut aggregated = String::new();
// Snapshot lightweight fields only.
let (baseline_external_path, baseline_mode, left_oid) = {
if let Some(info) = self.baseline_file_info.get(internal_file_name) {
(info.path.clone(), info.mode, info.oid.clone())
} else {
(PathBuf::new(), FileMode::Regular, ZERO_OID.to_string())
}
};
let current_external_path = match self.get_path_for_internal(internal_file_name) {
Some(p) => p,
None => return aggregated,
};
let current_mode = file_mode_for_path(&current_external_path).unwrap_or(FileMode::Regular);
let right_bytes = blob_bytes(&current_external_path, current_mode);
// Compute displays with &mut self before borrowing any baseline content.
let left_display = self.relative_to_git_root_str(&baseline_external_path);
let right_display = self.relative_to_git_root_str(&current_external_path);
// Compute right oid before borrowing baseline content.
let right_oid = if let Some(b) = right_bytes.as_ref() {
if current_mode == FileMode::Symlink {
format!("{:x}", git_blob_sha1_hex_bytes(b))
} else {
self.git_blob_oid_for_path(&current_external_path)
.unwrap_or_else(|| format!("{:x}", git_blob_sha1_hex_bytes(b)))
}
} else {
ZERO_OID.to_string()
};
// Borrow baseline content only after all &mut self uses are done.
let left_present = left_oid.as_str() != ZERO_OID;
let left_bytes: Option<&[u8]> = if left_present {
self.baseline_file_info
.get(internal_file_name)
.map(|i| i.content.as_slice())
} else {
None
};
// Fast path: identical bytes or both missing.
if left_bytes == right_bytes.as_deref() {
return aggregated;
}
aggregated.push_str(&format!("diff --git a/{left_display} b/{right_display}\n"));
let is_add = !left_present && right_bytes.is_some();
let is_delete = left_present && right_bytes.is_none();
if is_add {
aggregated.push_str(&format!("new file mode {current_mode}\n"));
} else if is_delete {
aggregated.push_str(&format!("deleted file mode {baseline_mode}\n"));
} else if baseline_mode != current_mode {
aggregated.push_str(&format!("old mode {baseline_mode}\n"));
aggregated.push_str(&format!("new mode {current_mode}\n"));
}
let left_text = left_bytes.and_then(|b| std::str::from_utf8(b).ok());
let right_text = right_bytes
.as_deref()
.and_then(|b| std::str::from_utf8(b).ok());
let can_text_diff = matches!(
(left_text, right_text, is_add, is_delete),
(Some(_), Some(_), _, _) | (_, Some(_), true, _) | (Some(_), _, _, true)
);
if can_text_diff {
let l = left_text.unwrap_or("");
let r = right_text.unwrap_or("");
aggregated.push_str(&format!("index {left_oid}..{right_oid}\n"));
let old_header = if left_present {
format!("a/{left_display}")
} else {
DEV_NULL.to_string()
};
let new_header = if right_bytes.is_some() {
format!("b/{right_display}")
} else {
DEV_NULL.to_string()
};
let diff = similar::TextDiff::from_lines(l, r);
let unified = diff
.unified_diff()
.context_radius(3)
.header(&old_header, &new_header)
.to_string();
aggregated.push_str(&unified);
} else {
aggregated.push_str(&format!("index {left_oid}..{right_oid}\n"));
let old_header = if left_present {
format!("a/{left_display}")
} else {
DEV_NULL.to_string()
};
let new_header = if right_bytes.is_some() {
format!("b/{right_display}")
} else {
DEV_NULL.to_string()
};
aggregated.push_str(&format!("--- {old_header}\n"));
aggregated.push_str(&format!("+++ {new_header}\n"));
aggregated.push_str("Binary files differ\n");
}
aggregated
}
}
/// Compute the Git SHA-1 blob object ID for the given content (bytes).
fn git_blob_sha1_hex_bytes(data: &[u8]) -> Output<sha1::Sha1> {
// Git blob hash is sha1 of: "blob <len>\0<data>"
let header = format!("blob {}\0", data.len());
use sha1::Digest;
let mut hasher = sha1::Sha1::new();
hasher.update(header.as_bytes());
hasher.update(data);
hasher.finalize()
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum FileMode {
Regular,
#[cfg(unix)]
Executable,
Symlink,
}
impl FileMode {
fn as_str(self) -> &'static str {
match self {
FileMode::Regular => "100644",
#[cfg(unix)]
FileMode::Executable => "100755",
FileMode::Symlink => "120000",
}
}
}
impl std::fmt::Display for FileMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}
#[cfg(unix)]
fn file_mode_for_path(path: &Path) -> Option<FileMode> {
use std::os::unix::fs::PermissionsExt;
let meta = fs::symlink_metadata(path).ok()?;
let ft = meta.file_type();
if ft.is_symlink() {
return Some(FileMode::Symlink);
}
let mode = meta.permissions().mode();
let is_exec = (mode & 0o111) != 0;
Some(if is_exec {
FileMode::Executable
} else {
FileMode::Regular
})
}
#[cfg(not(unix))]
fn file_mode_for_path(_path: &Path) -> Option<FileMode> {
// Default to non-executable on non-unix.
Some(FileMode::Regular)
}
fn blob_bytes(path: &Path, mode: FileMode) -> Option<Vec<u8>> {
if path.exists() {
let contents = if mode == FileMode::Symlink {
symlink_blob_bytes(path)
.ok_or_else(|| anyhow!("failed to read symlink target for {}", path.display()))
} else {
fs::read(path)
.with_context(|| format!("failed to read current file for diff {}", path.display()))
};
contents.ok()
} else {
None
}
}
#[cfg(unix)]
fn symlink_blob_bytes(path: &Path) -> Option<Vec<u8>> {
use std::os::unix::ffi::OsStrExt;
let target = std::fs::read_link(path).ok()?;
Some(target.as_os_str().as_bytes().to_vec())
}
#[cfg(not(unix))]
fn symlink_blob_bytes(_path: &Path) -> Option<Vec<u8>> {
None
}
#[cfg(windows)]
fn is_windows_drive_or_unc_root(p: &std::path::Path) -> bool {
use std::path::Component;
let mut comps = p.components();
matches!(
(comps.next(), comps.next(), comps.next()),
(Some(Component::Prefix(_)), Some(Component::RootDir), None)
)
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
use tempfile::tempdir;
/// Compute the Git SHA-1 blob object ID for the given content (string).
/// This delegates to the bytes version to avoid UTF-8 lossy conversions here.
fn git_blob_sha1_hex(data: &str) -> String {
format!("{:x}", git_blob_sha1_hex_bytes(data.as_bytes()))
}
fn normalize_diff_for_test(input: &str, root: &Path) -> String {
let root_str = root.display().to_string().replace('\\', "/");
let replaced = input.replace(&root_str, "<TMP>");
// Split into blocks on lines starting with "diff --git ", sort blocks for determinism, and rejoin
let mut blocks: Vec<String> = Vec::new();
let mut current = String::new();
for line in replaced.lines() {
if line.starts_with("diff --git ") && !current.is_empty() {
blocks.push(current);
current = String::new();
}
if !current.is_empty() {
current.push('\n');
}
current.push_str(line);
}
if !current.is_empty() {
blocks.push(current);
}
blocks.sort();
let mut out = blocks.join("\n");
if !out.ends_with('\n') {
out.push('\n');
}
out
}
#[test]
fn accumulates_add_and_update() {
let mut acc = TurnDiffTracker::new();
let dir = tempdir().unwrap();
let file = dir.path().join("a.txt");
// First patch: add file (baseline should be /dev/null).
let add_changes = HashMap::from([(
file.clone(),
FileChange::Add {
content: "foo\n".to_string(),
},
)]);
acc.on_patch_begin(&add_changes);
// Simulate apply: create the file on disk.
fs::write(&file, "foo\n").unwrap();
let first = acc.get_unified_diff().unwrap().unwrap();
let first = normalize_diff_for_test(&first, dir.path());
let expected_first = {
let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular);
let right_oid = git_blob_sha1_hex("foo\n");
format!(
r#"diff --git a/<TMP>/a.txt b/<TMP>/a.txt
new file mode {mode}
index {ZERO_OID}..{right_oid}
--- {DEV_NULL}
+++ b/<TMP>/a.txt
@@ -0,0 +1 @@
+foo
"#,
)
};
assert_eq!(first, expected_first);
// Second patch: update the file on disk.
let update_changes = HashMap::from([(
file.clone(),
FileChange::Update {
unified_diff: "".to_owned(),
move_path: None,
},
)]);
acc.on_patch_begin(&update_changes);
// Simulate apply: append a new line.
fs::write(&file, "foo\nbar\n").unwrap();
let combined = acc.get_unified_diff().unwrap().unwrap();
let combined = normalize_diff_for_test(&combined, dir.path());
let expected_combined = {
let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular);
let right_oid = git_blob_sha1_hex("foo\nbar\n");
format!(
r#"diff --git a/<TMP>/a.txt b/<TMP>/a.txt
new file mode {mode}
index {ZERO_OID}..{right_oid}
--- {DEV_NULL}
+++ b/<TMP>/a.txt
@@ -0,0 +1,2 @@
+foo
+bar
"#,
)
};
assert_eq!(combined, expected_combined);
}
#[test]
fn accumulates_delete() {
let dir = tempdir().unwrap();
let file = dir.path().join("b.txt");
fs::write(&file, "x\n").unwrap();
let mut acc = TurnDiffTracker::new();
let del_changes = HashMap::from([(
file.clone(),
FileChange::Delete {
content: "x\n".to_string(),
},
)]);
acc.on_patch_begin(&del_changes);
// Simulate apply: delete the file from disk.
let baseline_mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular);
fs::remove_file(&file).unwrap();
let diff = acc.get_unified_diff().unwrap().unwrap();
let diff = normalize_diff_for_test(&diff, dir.path());
let expected = {
let left_oid = git_blob_sha1_hex("x\n");
format!(
r#"diff --git a/<TMP>/b.txt b/<TMP>/b.txt
deleted file mode {baseline_mode}
index {left_oid}..{ZERO_OID}
--- a/<TMP>/b.txt
+++ {DEV_NULL}
@@ -1 +0,0 @@
-x
"#,
)
};
assert_eq!(diff, expected);
}
#[test]
fn accumulates_move_and_update() {
let dir = tempdir().unwrap();
let src = dir.path().join("src.txt");
let dest = dir.path().join("dst.txt");
fs::write(&src, "line\n").unwrap();
let mut acc = TurnDiffTracker::new();
let mv_changes = HashMap::from([(
src.clone(),
FileChange::Update {
unified_diff: "".to_owned(),
move_path: Some(dest.clone()),
},
)]);
acc.on_patch_begin(&mv_changes);
// Simulate apply: move and update content.
fs::rename(&src, &dest).unwrap();
fs::write(&dest, "line2\n").unwrap();
let out = acc.get_unified_diff().unwrap().unwrap();
let out = normalize_diff_for_test(&out, dir.path());
let expected = {
let left_oid = git_blob_sha1_hex("line\n");
let right_oid = git_blob_sha1_hex("line2\n");
format!(
r#"diff --git a/<TMP>/src.txt b/<TMP>/dst.txt
index {left_oid}..{right_oid}
--- a/<TMP>/src.txt
+++ b/<TMP>/dst.txt
@@ -1 +1 @@
-line
+line2
"#
)
};
assert_eq!(out, expected);
}
#[test]
fn move_without_change_yields_no_diff() {
let dir = tempdir().unwrap();
let src = dir.path().join("moved.txt");
let dest = dir.path().join("renamed.txt");
fs::write(&src, "same\n").unwrap();
let mut acc = TurnDiffTracker::new();
let mv_changes = HashMap::from([(
src.clone(),
FileChange::Update {
unified_diff: "".to_owned(),
move_path: Some(dest.clone()),
},
)]);
acc.on_patch_begin(&mv_changes);
// Simulate apply: move only, no content change.
fs::rename(&src, &dest).unwrap();
let diff = acc.get_unified_diff().unwrap();
assert_eq!(diff, None);
}
#[test]
fn move_declared_but_file_only_appears_at_dest_is_add() {
let dir = tempdir().unwrap();
let src = dir.path().join("src.txt");
let dest = dir.path().join("dest.txt");
let mut acc = TurnDiffTracker::new();
let mv = HashMap::from([(
src,
FileChange::Update {
unified_diff: "".into(),
move_path: Some(dest.clone()),
},
)]);
acc.on_patch_begin(&mv);
// No file existed initially; create only dest
fs::write(&dest, "hello\n").unwrap();
let diff = acc.get_unified_diff().unwrap().unwrap();
let diff = normalize_diff_for_test(&diff, dir.path());
let expected = {
let mode = file_mode_for_path(&dest).unwrap_or(FileMode::Regular);
let right_oid = git_blob_sha1_hex("hello\n");
format!(
r#"diff --git a/<TMP>/src.txt b/<TMP>/dest.txt
new file mode {mode}
index {ZERO_OID}..{right_oid}
--- {DEV_NULL}
+++ b/<TMP>/dest.txt
@@ -0,0 +1 @@
+hello
"#,
)
};
assert_eq!(diff, expected);
}
#[test]
fn update_persists_across_new_baseline_for_new_file() {
let dir = tempdir().unwrap();
let a = dir.path().join("a.txt");
let b = dir.path().join("b.txt");
fs::write(&a, "foo\n").unwrap();
fs::write(&b, "z\n").unwrap();
let mut acc = TurnDiffTracker::new();
// First: update existing a.txt (baseline snapshot is created for a).
let update_a = HashMap::from([(
a.clone(),
FileChange::Update {
unified_diff: "".to_owned(),
move_path: None,
},
)]);
acc.on_patch_begin(&update_a);
// Simulate apply: modify a.txt on disk.
fs::write(&a, "foo\nbar\n").unwrap();
let first = acc.get_unified_diff().unwrap().unwrap();
let first = normalize_diff_for_test(&first, dir.path());
let expected_first = {
let left_oid = git_blob_sha1_hex("foo\n");
let right_oid = git_blob_sha1_hex("foo\nbar\n");
format!(
r#"diff --git a/<TMP>/a.txt b/<TMP>/a.txt
index {left_oid}..{right_oid}
--- a/<TMP>/a.txt
+++ b/<TMP>/a.txt
@@ -1 +1,2 @@
foo
+bar
"#
)
};
assert_eq!(first, expected_first);
// Next: introduce a brand-new path b.txt into baseline snapshots via a delete change.
let del_b = HashMap::from([(
b.clone(),
FileChange::Delete {
content: "z\n".to_string(),
},
)]);
acc.on_patch_begin(&del_b);
// Simulate apply: delete b.txt.
let baseline_mode = file_mode_for_path(&b).unwrap_or(FileMode::Regular);
fs::remove_file(&b).unwrap();
let combined = acc.get_unified_diff().unwrap().unwrap();
let combined = normalize_diff_for_test(&combined, dir.path());
let expected = {
let left_oid_a = git_blob_sha1_hex("foo\n");
let right_oid_a = git_blob_sha1_hex("foo\nbar\n");
let left_oid_b = git_blob_sha1_hex("z\n");
format!(
r#"diff --git a/<TMP>/a.txt b/<TMP>/a.txt
index {left_oid_a}..{right_oid_a}
--- a/<TMP>/a.txt
+++ b/<TMP>/a.txt
@@ -1 +1,2 @@
foo
+bar
diff --git a/<TMP>/b.txt b/<TMP>/b.txt
deleted file mode {baseline_mode}
index {left_oid_b}..{ZERO_OID}
--- a/<TMP>/b.txt
+++ {DEV_NULL}
@@ -1 +0,0 @@
-z
"#,
)
};
assert_eq!(combined, expected);
}
#[test]
fn binary_files_differ_update() {
let dir = tempdir().unwrap();
let file = dir.path().join("bin.dat");
// Initial non-UTF8 bytes
let left_bytes: Vec<u8> = vec![0xff, 0xfe, 0xfd, 0x00];
// Updated non-UTF8 bytes
let right_bytes: Vec<u8> = vec![0x01, 0x02, 0x03, 0x00];
fs::write(&file, &left_bytes).unwrap();
let mut acc = TurnDiffTracker::new();
let update_changes = HashMap::from([(
file.clone(),
FileChange::Update {
unified_diff: "".to_owned(),
move_path: None,
},
)]);
acc.on_patch_begin(&update_changes);
// Apply update on disk
fs::write(&file, &right_bytes).unwrap();
let diff = acc.get_unified_diff().unwrap().unwrap();
let diff = normalize_diff_for_test(&diff, dir.path());
let expected = {
let left_oid = format!("{:x}", git_blob_sha1_hex_bytes(&left_bytes));
let right_oid = format!("{:x}", git_blob_sha1_hex_bytes(&right_bytes));
format!(
r#"diff --git a/<TMP>/bin.dat b/<TMP>/bin.dat
index {left_oid}..{right_oid}
--- a/<TMP>/bin.dat
+++ b/<TMP>/bin.dat
Binary files differ
"#
)
};
assert_eq!(diff, expected);
}
#[test]
fn filenames_with_spaces_add_and_update() {
let mut acc = TurnDiffTracker::new();
let dir = tempdir().unwrap();
let file = dir.path().join("name with spaces.txt");
// First patch: add file (baseline should be /dev/null).
let add_changes = HashMap::from([(
file.clone(),
FileChange::Add {
content: "foo\n".to_string(),
},
)]);
acc.on_patch_begin(&add_changes);
// Simulate apply: create the file on disk.
fs::write(&file, "foo\n").unwrap();
let first = acc.get_unified_diff().unwrap().unwrap();
let first = normalize_diff_for_test(&first, dir.path());
let expected_first = {
let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular);
let right_oid = git_blob_sha1_hex("foo\n");
format!(
r#"diff --git a/<TMP>/name with spaces.txt b/<TMP>/name with spaces.txt
new file mode {mode}
index {ZERO_OID}..{right_oid}
--- {DEV_NULL}
+++ b/<TMP>/name with spaces.txt
@@ -0,0 +1 @@
+foo
"#,
)
};
assert_eq!(first, expected_first);
// Second patch: update the file on disk.
let update_changes = HashMap::from([(
file.clone(),
FileChange::Update {
unified_diff: "".to_owned(),
move_path: None,
},
)]);
acc.on_patch_begin(&update_changes);
// Simulate apply: append a new line with a space.
fs::write(&file, "foo\nbar baz\n").unwrap();
let combined = acc.get_unified_diff().unwrap().unwrap();
let combined = normalize_diff_for_test(&combined, dir.path());
let expected_combined = {
let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular);
let right_oid = git_blob_sha1_hex("foo\nbar baz\n");
format!(
r#"diff --git a/<TMP>/name with spaces.txt b/<TMP>/name with spaces.txt
new file mode {mode}
index {ZERO_OID}..{right_oid}
--- {DEV_NULL}
+++ b/<TMP>/name with spaces.txt
@@ -0,0 +1,2 @@
+foo
+bar baz
"#,
)
};
assert_eq!(combined, expected_combined);
}
}

View File

@@ -1,7 +1,7 @@
use thiserror::Error;
#[derive(Debug, Error)]
-pub(crate) enum UnifiedExecError {
+pub enum UnifiedExecError {
#[error("Failed to create unified exec session: {pty_error}")]
CreateSession {
#[source]

View File

@@ -22,27 +22,27 @@ use crate::truncate::truncate_middle;
mod errors;
-pub(crate) use errors::UnifiedExecError;
+pub use errors::UnifiedExecError;
const DEFAULT_TIMEOUT_MS: u64 = 1_000;
const MAX_TIMEOUT_MS: u64 = 60_000;
const UNIFIED_EXEC_OUTPUT_MAX_BYTES: usize = 128 * 1024; // 128 KiB
#[derive(Debug)]
-pub(crate) struct UnifiedExecRequest<'a> {
+pub struct UnifiedExecRequest<'a> {
pub session_id: Option<i32>,
pub input_chunks: &'a [String],
pub timeout_ms: Option<u64>,
}
#[derive(Debug, Clone, PartialEq)]
-pub(crate) struct UnifiedExecResult {
+pub struct UnifiedExecResult {
pub session_id: Option<i32>,
pub output: String,
}
#[derive(Debug, Default)]
-pub(crate) struct UnifiedExecSessionManager {
+pub struct UnifiedExecSessionManager {
next_session_id: AtomicI32,
sessions: Mutex<HashMap<i32, ManagedUnifiedExecSession>>,
}

112
codex-rs/agent_refactor.md Normal file
View File

@@ -0,0 +1,112 @@
# Agent Runtime Refactor
## Goals
- Decouple the Codex agent loop from CLI-specific wiring so it can run as a reusable library or standalone binary.
- Preserve the current behaviour of `codex-core` (tooling, approvals, sandboxing, MCP integration) while providing a cleaner embedding surface.
- Enable specialised hosts—CLI, training harnesses, response API bridges—to share the same runtime with minimal glue code.
## Proposed Architecture
### 1. `codex-agent` crate (new)
- Owns the session runtime: `AgentRuntime`, `AgentHandle`, the session/turn state types, and the task runners now under `core/src/tasks`.
- Exposes a queue-like API: `AgentHandle::submit(Op/Submission)` and `AgentHandle::next_event()`, mirroring today's behaviour (a usage sketch follows this list).
- Re-exports protocol types from `codex-protocol` so consumers do not depend on the entire `codex-core` tree.
- Houses the agent loop (`run_task`, `run_turn`, exec/safety plumbing) together with the sandbox planner (`ExecPlan`, `PreparedExec`, etc.).
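A minimal sketch of that embedding surface, assuming the spawn/submit names land as described (`AgentRuntime::spawn`, `AgentHandle::submit`, `AgentHandle::next_event`); error handling and `Op` construction are host concerns:

```rust
// Hypothetical host loop; constructor names are assumptions.
let handle = AgentRuntime::spawn(agent_config, services).await?;
handle.submit(user_input_op).await?; // a codex-protocol `Op` built by the host
while let Some(event) = handle.next_event().await {
    println!("{event:?}"); // forward to the host UI or transport bridge
}
```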
### 2. Shared configuration surface
- Introduce `AgentConfig` as the minimal runtime configuration (model, provider, approvals, sandbox defaults, cwd, user/base instructions, feature flags relevant to the loop).
- Provide `From<&Config>` for CLI compatibility; training/other hosts construct `AgentConfig` directly (see the sketch after this list).
- CLI-only concerns (logging, auth prompts, workspace presets) stay inside `codex-core` and are translated before spawning the runtime.
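For example, a training host could reuse the CLI conversion and then override the runtime-relevant fields; this assumes the `AgentConfig` fields stay public (names taken from the `From<&Config>` impl in `core/src/agent_config.rs`):

```rust
// Start from a CLI-style Config, then adjust what the host cares about.
let mut agent_config = AgentConfig::from(&config);
agent_config.approval_policy = AskForApproval::Never;
agent_config.sandbox_policy = SandboxPolicy::ReadOnly;
agent_config.cwd = training_workspace.clone(); // hypothetical host path
```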
### 3. Service abstraction layer
- Define traits that the runtime depends on instead of concrete CLI structs:
- `CredentialsProvider` (wraps `AuthManager`).
- `Notifier` (reuses `UserNotifier` contract).
- `McpInterface` (start/list tools, dispatch tool calls).
- `SandboxManager` (wraps `BackendRegistry`/`prepare_exec_invocation` wiring).
- `RolloutSink` (write/flush rollout items; default no-op).
- Provide default implementations in `codex-core` that simply wrap the existing services (`SessionServices`); a minimal host-side adapter is sketched below.
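As an example of how small these adapters can be, a hedged sketch of a `Notifier` backed by `tracing` (import paths are assumptions; the trait lives in `agent/src/services.rs`):

```rust
use codex_agent::services::Notifier; // path assumed
use codex_agent::notifications::UserNotification; // path assumed

struct TracingNotifier;

impl Notifier for TracingNotifier {
    fn notify(&self, notification: &UserNotification) {
        // Assumes `UserNotification` derives Debug.
        tracing::info!(notification = ?notification, "turn notification");
    }
}
```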
### 4. Task subsystem consolidation
- Keep the new `SessionTask` trait and concrete tasks (`RegularTask`, `ReviewTask`, `CompactTask`) inside `codex-agent` so custom hosts can opt into additional tasks without touching CLI crates.
- Ensure task lifecycle management (`spawn_task`, `abort_all_tasks`, `ActiveTurn`) stays encapsulated in the runtime and surfaces only high-level signals (events, cancellation APIs).
### 5. Sandbox execution layer
- Move the recently created `core/src/sandbox` module into `codex-agent` (or re-export) so runtime owns exec planning.
- Runtime exposes an injectable `SandboxRuntimeConfig` (paths, seatbelt binary, stdout streaming choice) and calls into `SandboxManager` to execute plans.
- Respect existing environment variables and approval policies; no semantic changes to seatbelt handling.
### 6. Host integrations
- CLI crate: replaces direct usage of `Codex::spawn` with `AgentRuntime::spawn`, adapting CLI config/auth providers to runtime traits. Behaviour remains identical.
- Training binary (`codex-agent-bin`): thin crate that parses CLI flags (Response API URL, auth token, optional instructions) and bridges remote Ops/Events to the runtime via chosen transport (MCP channel, HTTP/WebSocket bridge).
- Additional hosts can embed the runtime by implementing the service traits and providing transport glue.
### 7. Transport adapters
- Internally keep `async_channel` for runtime queues.
- Provide helper adapters (`AgentTransport` trait) so callers can hook streams (local channel, TCP bridge, etc.) while keeping backpressure and graceful shutdown semantics consistent; one possible shape is sketched below.
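One possible shape for the proposed trait, assuming it carries codex-protocol `Submission`/`Event` values; the real definition may differ:

```rust
#[async_trait::async_trait]
pub trait AgentTransport: Send + Sync {
    /// Deliver an event to the host; an error signals a broken transport.
    async fn send_event(&self, event: Event) -> anyhow::Result<()>;
    /// Receive the next submission; `None` means the peer shut down.
    async fn recv_submission(&self) -> Option<Submission>;
}
```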
## Guidelines
- **Config boundary**: new code must depend on `AgentConfig`; only CLI/front-ends may use the broader `Config` struct. Avoid adding CLI-specific fields to the runtime config.
- **Trait-based services**: any runtime dependency that could vary across hosts (MCP, rollout persistence, sandbox execution, notifications) should be expressed as a trait with a default implementation living in `codex-core`.
- **Task authoring**: additional tasks must implement `SessionTask`; tasks are responsible for calling `run_task`/`exit_review_mode` helpers and returning final assistant output for `TaskComplete` events.
- **Sandbox safety**: all exec/patch calls must flow through `plan_exec`/`plan_apply_patch` (now under `codex-agent::sandbox`) to preserve approval semantics. Never bypass `SandboxManager`.
- **MCP usage**: runtime talks only through `McpInterface`; hosts provide concrete connectors (existing CLI manager, lightweight training stub, etc.).
- **Rollout handling**: the default `RolloutSink` should no-op; hosts that require persistence (CLI, evaluation harness) supply an implementation that wraps the existing recorder (a no-op sketch follows this list).
- **Transport/backpressure**: treat the runtime queue as bounded and handle cancellations; adapters must propagate `Op::Shutdown` promptly.
- **Observability**: keep tracing instrumentation intact; new modules should use existing `tracing` spans for start/end of tasks, exec calls, and MCP interactions.
- **Code quality**: write minimalist, idiomatic code that leverages Rust's type system and standard library.
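A hedged sketch of the default no-op sink mentioned above, written against the `RolloutSink` trait from `agent/src/services.rs` (import paths are assumptions):

```rust
use std::path::PathBuf;

use codex_agent::services::RolloutSink; // path assumed
use codex_protocol::protocol::RolloutItem;

struct NoopRolloutSink;

#[async_trait::async_trait]
impl RolloutSink for NoopRolloutSink {
    async fn record_items(&self, _items: &[RolloutItem]) -> std::io::Result<()> {
        Ok(()) // drop everything; no persistence
    }

    async fn flush(&self) -> std::io::Result<()> {
        Ok(())
    }

    async fn shutdown(&self) -> std::io::Result<()> {
        Ok(())
    }

    fn get_rollout_path(&self) -> PathBuf {
        PathBuf::new() // no backing file; a real sink returns its log path
    }
}
```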
## Current Scope Snapshot
- `codex-agent` owns the execution/runtime surface: conversation history, rollout recording, function tool plumbing, sandbox planning, command/apply_patch safety, and the new `ApprovalCoordinator` trait that abstracts user approvals. Host-agnostic helpers such as shell formatting, bash parsing, and command safety now live here.
- `codex-core` focuses on CLI integration: loading user configuration, wiring concrete services (auth, MCP, sandbox manager), translating CLI policies into runtime configs, and exposing the embedded runtime to front-ends. It re-exports runtime modules needed by existing callers but should avoid hosting new agent logic.
- Session bootstrap now flows through a host-provided `prepare_session_bootstrap` helper: the CLI constructs rollout/MCP/sandbox services, builds the new `codex_agent::SessionServices` + `SessionState`, pre-builds the initial `TurnContext` (model client + tool config), and hands them to `Session::new` instead of constructing them inline.
## Implementation Plan
1. **Baseline & documentation**
- Capture current interfaces (`Codex`, `Session`, `SessionTask`) and update developer docs to reference this refactor plan.
- Add smoke tests covering multi-task scenarios (regular + review + compact) to guard against regressions during extraction.
2. **Introduce `AgentConfig`**
- Define struct + conversion helpers inside `codex-core`.
- Refactor internal `Session::new` / `TurnContext` builders to accept `AgentConfig` without changing external behaviour.
3. **Service trait extraction**
- Carve out trait definitions (`CredentialsProvider`, `McpInterface`, `SandboxManager`, `RolloutSink`, `Notifier`).
- Provide adapters backed by existing `SessionServices`.
- Update `Session` and helper modules to depend on traits rather than concrete structs.
4. **Create `codex-agent` crate**
- Scaffold crate, move runtime modules (`codex.rs`, `state`, `tasks`, `sandbox`) while keeping module paths stable via `pub use` re-exports.
- Resolve module imports to reference trait abstractions / helper crates (e.g., `codex_protocol`, `codex-apply-patch`).
- Ensure crate exposes `AgentRuntime`, `AgentHandle`, and service traits.
5. **Adapt `codex-core`**
- Replace `Codex::spawn` with thin wrapper that constructs `AgentConfig`, runtime service adapters, and delegates to `codex-agent`.
- Update public API to re-export runtime types if downstream crates expect them.
- Confirm unit tests continue to pass.
6. **Update front-ends**
- CLI crate: switch to new runtime API; verify login/auth flows, approvals, and sandbox invocations.
- Other binaries (`chatgpt`, etc.) migrate similarly, adjusting imports/config conversions.
7. **Add training binary**
- Implement new `codex-agent-bin` crate providing CLI for Response API URL + auth.
- Reuse existing MCP client logic where possible; otherwise, provide minimal HTTP bridge translating Ops/Events.
- Add integration tests using mocked Response API.
8. **Refine transport adapters**
- Add optional helper module offering channel/TCP/WebSocket adapters along with graceful shutdown behaviour.
- Document how hosts select or implement transports (a hedged channel-adapter sketch follows this plan).
9. **Finalize rollout persistence strategy**
- Implement `RolloutSink` adapters (file-based, in-memory, disabled).
- Ensure CLI wires existing recorder; training binary can opt in/out via flags.
10. **Docs & polish**
- Update repository documentation (`README`, architecture docs) to reference the new crates and APIs.
- Record migration notes for downstream consumers.
- Run `just fmt`, scoped `just fix -p`, and targeted tests for touched crates before merging.
11. **Validation**
- Execute `cargo test -p codex-agent`, `cargo test -p codex-core`, and full suite (`cargo test --all-features`) once shared crates change.
- Perform manual verification: CLI session, review task, training binary against mock Response API, ensuring approvals and sandboxing behave identically.
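For step 8, a minimal in-process channel adapter could look like the sketch below. `Op` and `Event` are the real protocol types (see `Op::Shutdown` in the guardrails); `ChannelTransport` itself and its layout are assumptions:

```rust
use codex_protocol::protocol::Event;
use codex_protocol::protocol::Op;
use tokio::sync::mpsc;

/// Hypothetical in-process transport: the host keeps one end of each
/// bounded channel, the runtime keeps the other.
pub struct ChannelTransport {
    pub ops_tx: mpsc::Sender<Op>,
    pub events_rx: mpsc::Receiver<Event>,
}

impl ChannelTransport {
    /// Bounded queues give the runtime backpressure; hosts must still send
    /// `Op::Shutdown` promptly so the runtime can drain and exit cleanly.
    pub fn new(capacity: usize) -> (Self, mpsc::Receiver<Op>, mpsc::Sender<Event>) {
        let (ops_tx, ops_rx) = mpsc::channel(capacity);
        let (events_tx, events_rx) = mpsc::channel(capacity);
        (Self { ops_tx, events_rx }, ops_rx, events_tx)
    }
}
```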

View File

@@ -20,6 +20,7 @@ base64 = { workspace = true }
bytes = { workspace = true }
chrono = { workspace = true, features = ["serde"] }
codex-apply-patch = { workspace = true }
codex-agent = { workspace = true }
codex-file-search = { workspace = true }
codex-mcp-client = { workspace = true }
codex-protocol = { workspace = true }
@@ -60,8 +61,6 @@ tokio-util = { workspace = true }
toml = { workspace = true }
toml_edit = { workspace = true }
tracing = { workspace = true, features = ["log"] }
tree-sitter = { workspace = true }
tree-sitter-bash = { workspace = true }
uuid = { workspace = true, features = ["serde", "v4"] }
which = { workspace = true }
wildmatch = { workspace = true }

View File

@@ -0,0 +1,38 @@
pub use codex_agent::AgentConfig;
use crate::config::Config;
impl From<&Config> for AgentConfig {
fn from(config: &Config) -> Self {
Self {
model: config.model.clone(),
review_model: config.review_model.clone(),
model_family: config.model_family.clone(),
model_context_window: config.model_context_window,
model_auto_compact_token_limit: config.model_auto_compact_token_limit,
model_reasoning_effort: config.model_reasoning_effort,
model_reasoning_summary: config.model_reasoning_summary,
model_verbosity: config.model_verbosity,
model_provider: config.model_provider.clone(),
approval_policy: config.approval_policy,
sandbox_policy: config.sandbox_policy.clone(),
shell_environment_policy: config.shell_environment_policy.clone(),
user_instructions: config.user_instructions.clone(),
base_instructions: config.base_instructions.clone(),
notify: config.notify.clone(),
cwd: config.cwd.clone(),
codex_home: config.codex_home.clone(),
history: config.history.clone(),
mcp_servers: config.mcp_servers.clone(),
include_plan_tool: config.include_plan_tool,
include_apply_patch_tool: config.include_apply_patch_tool,
include_view_image_tool: config.include_view_image_tool,
tools_web_search_request: config.tools_web_search_request,
use_experimental_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
use_experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
project_doc_max_bytes: config.project_doc_max_bytes,
}
}
}
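A hypothetical host-side call site for this conversion, using only the re-exported types shown above:

```rust
use codex_core::AgentConfig;
use codex_core::config::Config;

/// Hypothetical helper: derive the runtime-facing config from the CLI's
/// fully loaded `Config` without copying CLI-only fields by hand.
fn runtime_config(cli_config: &Config) -> AgentConfig {
    AgentConfig::from(cli_config)
}
```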

View File

@@ -0,0 +1,148 @@
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
use codex_agent::notifications::UserNotification;
use codex_agent::services::CredentialsProvider;
use codex_agent::services::McpInterface;
use codex_agent::services::Notifier;
use codex_agent::services::ProviderAuth;
use codex_agent::services::SandboxManager;
use codex_agent::token_data::PlanType;
use codex_protocol::mcp_protocol::AuthMode;
use mcp_types::CallToolResult;
use mcp_types::Tool;
use serde_json::Value;
use crate::auth::AuthManager;
use crate::auth::CodexAuth;
use crate::exec_command::ExecCommandOutput;
use crate::exec_command::ExecCommandParams;
use crate::exec_command::ExecSessionManager;
use crate::exec_command::WriteStdinParams;
use crate::mcp_connection_manager::McpConnectionManager;
use crate::unified_exec::UnifiedExecError;
use crate::unified_exec::UnifiedExecRequest;
use crate::unified_exec::UnifiedExecResult;
use crate::unified_exec::UnifiedExecSessionManager;
use crate::user_notification::UserNotifier;
#[async_trait]
impl ProviderAuth for CodexAuth {
fn mode(&self) -> AuthMode {
self.mode
}
async fn access_token(&self) -> std::io::Result<String> {
self.get_token().await
}
fn account_id(&self) -> Option<String> {
self.get_account_id()
}
fn plan_type(&self) -> Option<PlanType> {
self.get_plan_type()
}
}
#[async_trait]
impl CredentialsProvider for AuthManager {
fn auth(&self) -> Option<Arc<dyn ProviderAuth>> {
AuthManager::auth(self).map(|auth| Arc::new(auth) as Arc<dyn ProviderAuth>)
}
async fn refresh_token(&self) -> std::io::Result<Option<String>> {
AuthManager::refresh_token(self).await
}
}
impl Notifier for UserNotifier {
fn notify(&self, notification: &UserNotification) {
UserNotifier::notify(self, notification);
}
}
#[async_trait]
impl McpInterface for McpConnectionManager {
fn list_all_tools(&self) -> HashMap<String, Tool> {
McpConnectionManager::list_all_tools(self)
}
fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> {
McpConnectionManager::parse_tool_name(self, tool_name)
}
async fn call_tool(
&self,
server: &str,
tool: &str,
arguments: Option<Value>,
) -> Result<CallToolResult> {
McpConnectionManager::call_tool(self, server, tool, arguments).await
}
}
/// Default [`SandboxManager`] used by the CLI runtime. Wraps the existing exec
/// session managers and exposes their functionality via the trait-based
/// interface so other hosts can substitute different implementations.
pub struct DefaultSandboxManager {
exec_session_manager: ExecSessionManager,
unified_exec_manager: UnifiedExecSessionManager,
codex_linux_sandbox_exe: Option<PathBuf>,
user_shell: crate::shell::Shell,
}
impl DefaultSandboxManager {
pub fn new(
exec_session_manager: ExecSessionManager,
unified_exec_manager: UnifiedExecSessionManager,
codex_linux_sandbox_exe: Option<PathBuf>,
user_shell: crate::shell::Shell,
) -> Self {
Self {
exec_session_manager,
unified_exec_manager,
codex_linux_sandbox_exe,
user_shell,
}
}
}
#[async_trait]
impl SandboxManager for DefaultSandboxManager {
async fn handle_exec_command_request(
&self,
params: ExecCommandParams,
) -> Result<ExecCommandOutput, String> {
self.exec_session_manager
.handle_exec_command_request(params)
.await
}
async fn handle_write_stdin_request(
&self,
params: WriteStdinParams,
) -> Result<ExecCommandOutput, String> {
self.exec_session_manager
.handle_write_stdin_request(params)
.await
}
async fn handle_unified_exec_request(
&self,
request: UnifiedExecRequest<'_>,
) -> Result<UnifiedExecResult, UnifiedExecError> {
self.unified_exec_manager.handle_request(request).await
}
fn codex_linux_sandbox_exe(&self) -> &Option<PathBuf> {
&self.codex_linux_sandbox_exe
}
fn user_shell(&self) -> &crate::shell::Shell {
&self.user_shell
}
}
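Wiring this up at bootstrap might look like the sketch below, written as if inside `codex-core`; the `Default` impls on both session managers are assumptions made for brevity:

```rust
use crate::exec_command::ExecSessionManager;
use crate::unified_exec::UnifiedExecSessionManager;

// Hedged bootstrap sketch: assumes both managers implement `Default` and
// that the host has already detected the user's shell.
fn build_sandbox(user_shell: crate::shell::Shell) -> DefaultSandboxManager {
    DefaultSandboxManager::new(
        ExecSessionManager::default(),
        UnifiedExecSessionManager::default(),
        None, // codex_linux_sandbox_exe: only set when the Linux helper binary ships
        user_shell,
    )
}
```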

View File

@@ -135,7 +135,7 @@ impl CodexAuth {
self.get_current_token_data().and_then(|t| t.account_id)
}
pub(crate) fn get_plan_type(&self) -> Option<PlanType> {
pub fn get_plan_type(&self) -> Option<PlanType> {
self.get_current_token_data()
.and_then(|t| t.id_token.chatgpt_plan_type)
}

View File

@@ -22,6 +22,7 @@ use crate::client_common::ResponseStream;
use crate::error::CodexErr;
use crate::error::Result;
use crate::model_family::ModelFamily;
use crate::model_provider_info::ModelProviderExt;
use crate::openai_tools::create_tools_json_for_chat_completions_api;
use crate::util::backoff;
use codex_protocol::models::ContentItem;

View File

@@ -1,10 +1,10 @@
use std::fmt;
use std::io::BufRead;
use std::path::Path;
use std::sync::OnceLock;
use std::time::Duration;
use crate::AuthManager;
use crate::auth::CodexAuth;
use crate::CredentialsProvider;
use bytes::Bytes;
use codex_protocol::mcp_protocol::AuthMode;
use codex_protocol::mcp_protocol::ConversationId;
@@ -23,6 +23,9 @@ use tracing::debug;
use tracing::trace;
use tracing::warn;
use crate::ModelProviderInfo;
use crate::WireApi;
use crate::agent_config::AgentConfig;
use crate::chat_completions::AggregateStreamExt;
use crate::chat_completions::stream_chat_completions;
use crate::client_common::Prompt;
@@ -31,15 +34,13 @@ use crate::client_common::ResponseStream;
use crate::client_common::ResponsesApiRequest;
use crate::client_common::create_reasoning_param_for_request;
use crate::client_common::create_text_param_for_request;
use crate::config::Config;
use crate::default_client::create_client;
use crate::error::CodexErr;
use crate::error::Result;
use crate::error::UsageLimitReachedError;
use crate::flags::CODEX_RS_SSE_FIXTURE;
use crate::model_family::ModelFamily;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;
use crate::model_provider_info::ModelProviderExt;
use crate::openai_model_info::get_model_info;
use crate::openai_tools::create_tools_json_for_responses_api;
use crate::protocol::RateLimitSnapshot;
@@ -69,10 +70,10 @@ struct Error {
resets_in_seconds: Option<u64>,
}
#[derive(Debug, Clone)]
#[derive(Clone)]
pub struct ModelClient {
config: Arc<Config>,
auth_manager: Option<Arc<AuthManager>>,
config: Arc<AgentConfig>,
auth_manager: Option<Arc<dyn CredentialsProvider>>,
client: reqwest::Client,
provider: ModelProviderInfo,
conversation_id: ConversationId,
@@ -80,10 +81,22 @@ pub struct ModelClient {
summary: ReasoningSummaryConfig,
}
impl fmt::Debug for ModelClient {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ModelClient")
.field("config", &self.config)
.field("provider", &self.provider)
.field("conversation_id", &self.conversation_id)
.field("effort", &self.effort)
.field("summary", &self.summary)
.finish()
}
}
impl ModelClient {
pub fn new(
config: Arc<Config>,
auth_manager: Option<Arc<AuthManager>>,
config: Arc<AgentConfig>,
auth_manager: Option<Arc<dyn CredentialsProvider>>,
provider: ModelProviderInfo,
effort: Option<ReasoningEffortConfig>,
summary: ReasoningSummaryConfig,
@@ -259,7 +272,7 @@ impl ModelClient {
async fn attempt_stream_responses(
&self,
payload_json: &Value,
auth_manager: &Option<Arc<AuthManager>>,
auth_manager: &Option<Arc<dyn CredentialsProvider>>,
) -> std::result::Result<ResponseStream, StreamAttemptError> {
// Always fetch the latest auth in case a prior attempt refreshed the token.
let auth = auth_manager.as_ref().and_then(|m| m.auth());
@@ -285,8 +298,8 @@ impl ModelClient {
.json(payload_json);
if let Some(auth) = auth.as_ref()
&& auth.mode == AuthMode::ChatGPT
&& let Some(account_id) = auth.get_account_id()
&& auth.mode() == AuthMode::ChatGPT
&& let Some(account_id) = auth.account_id()
{
req_builder = req_builder.header("chatgpt-account-id", account_id);
}
@@ -372,7 +385,7 @@ impl ModelClient {
// token.
let plan_type = error
.plan_type
.or_else(|| auth.as_ref().and_then(CodexAuth::get_plan_type));
.or_else(|| auth.as_ref().and_then(|a| a.plan_type()));
let resets_in_seconds = error.resets_in_seconds;
let codex_err = CodexErr::UsageLimitReached(UsageLimitReachedError {
plan_type,
@@ -419,7 +432,7 @@ impl ModelClient {
self.summary
}
pub fn get_auth_manager(&self) -> Option<Arc<AuthManager>> {
pub fn get_auth_manager(&self) -> Option<Arc<dyn CredentialsProvider>> {
self.auth_manager.clone()
}
}
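One subtlety in this hunk: `#[derive(Debug)]` has to go because `Arc<dyn CredentialsProvider>` is a trait object without a `Debug` bound, so the manual `impl fmt::Debug` simply omits the auth and HTTP-client fields. A standalone illustration of the same pattern (names here are generic, not the real types):

```rust
use std::fmt;
use std::sync::Arc;

trait Service {} // stand-in for a non-Debug trait object such as CredentialsProvider

struct Client {
    service: Option<Arc<dyn Service>>, // this field blocks #[derive(Debug)]
    name: String,
}

impl fmt::Debug for Client {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Skip the trait-object field, mirroring ModelClient's manual impl.
        f.debug_struct("Client").field("name", &self.name).finish()
    }
}
```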

File diff suppressed because it is too large

View File

@@ -7,6 +7,7 @@ use crate::Prompt;
use crate::client_common::ResponseEvent;
use crate::error::CodexErr;
use crate::error::Result as CodexResult;
use crate::model_provider_info::ModelProviderExt;
use crate::protocol::AgentMessageEvent;
use crate::protocol::CompactedItem;
use crate::protocol::ErrorEvent;

View File

@@ -1,3 +1,4 @@
use crate::ModelProviderInfo;
use crate::config_profile::ConfigProfile;
use crate::config_types::History;
use crate::config_types::McpServerConfig;
@@ -12,7 +13,6 @@ use crate::git_info::resolve_root_git_project_for_trust;
use crate::model_family::ModelFamily;
use crate::model_family::derive_default_model_family;
use crate::model_family::find_family_for_model;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::built_in_model_providers;
use crate::openai_model_info::get_model_info;
use crate::protocol::AskForApproval;

View File

@@ -1,305 +1,3 @@
//! Types used to define the fields of [`crate::config::Config`].
//! Re-exported configuration data structures now defined in `codex-agent`.
// Note this file should generally be restricted to simple struct/enum
// definitions that do not contain business logic.
use std::collections::HashMap;
use std::path::PathBuf;
use std::time::Duration;
use wildmatch::WildMatchPattern;
use serde::Deserialize;
use serde::Deserializer;
use serde::Serialize;
use serde::de::Error as SerdeError;
#[derive(Serialize, Debug, Clone, PartialEq)]
pub struct McpServerConfig {
pub command: String,
#[serde(default)]
pub args: Vec<String>,
#[serde(default)]
pub env: Option<HashMap<String, String>>,
/// Startup timeout in seconds for initializing MCP server & initially listing tools.
#[serde(
default,
with = "option_duration_secs",
skip_serializing_if = "Option::is_none"
)]
pub startup_timeout_sec: Option<Duration>,
/// Default timeout for MCP tool calls initiated via this server.
#[serde(default, with = "option_duration_secs")]
pub tool_timeout_sec: Option<Duration>,
}
impl<'de> Deserialize<'de> for McpServerConfig {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
struct RawMcpServerConfig {
command: String,
#[serde(default)]
args: Vec<String>,
#[serde(default)]
env: Option<HashMap<String, String>>,
#[serde(default)]
startup_timeout_sec: Option<f64>,
#[serde(default)]
startup_timeout_ms: Option<u64>,
#[serde(default, with = "option_duration_secs")]
tool_timeout_sec: Option<Duration>,
}
let raw = RawMcpServerConfig::deserialize(deserializer)?;
let startup_timeout_sec = match (raw.startup_timeout_sec, raw.startup_timeout_ms) {
(Some(sec), _) => {
let duration = Duration::try_from_secs_f64(sec).map_err(SerdeError::custom)?;
Some(duration)
}
(None, Some(ms)) => Some(Duration::from_millis(ms)),
(None, None) => None,
};
Ok(Self {
command: raw.command,
args: raw.args,
env: raw.env,
startup_timeout_sec,
tool_timeout_sec: raw.tool_timeout_sec,
})
}
}
mod option_duration_secs {
use serde::Deserialize;
use serde::Deserializer;
use serde::Serializer;
use std::time::Duration;
pub fn serialize<S>(value: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match value {
Some(duration) => serializer.serialize_some(&duration.as_secs_f64()),
None => serializer.serialize_none(),
}
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
where
D: Deserializer<'de>,
{
let secs = Option::<f64>::deserialize(deserializer)?;
secs.map(|secs| Duration::try_from_secs_f64(secs).map_err(serde::de::Error::custom))
.transpose()
}
}
#[derive(Deserialize, Debug, Copy, Clone, PartialEq)]
pub enum UriBasedFileOpener {
#[serde(rename = "vscode")]
VsCode,
#[serde(rename = "vscode-insiders")]
VsCodeInsiders,
#[serde(rename = "windsurf")]
Windsurf,
#[serde(rename = "cursor")]
Cursor,
/// Option to disable the URI-based file opener.
#[serde(rename = "none")]
None,
}
impl UriBasedFileOpener {
pub fn get_scheme(&self) -> Option<&str> {
match self {
UriBasedFileOpener::VsCode => Some("vscode"),
UriBasedFileOpener::VsCodeInsiders => Some("vscode-insiders"),
UriBasedFileOpener::Windsurf => Some("windsurf"),
UriBasedFileOpener::Cursor => Some("cursor"),
UriBasedFileOpener::None => None,
}
}
}
/// Settings that govern if and what will be written to `~/.codex/history.jsonl`.
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct History {
/// If true, history entries will not be written to disk.
pub persistence: HistoryPersistence,
/// If set, the maximum size of the history file in bytes.
/// TODO(mbolin): Not currently honored.
pub max_bytes: Option<usize>,
}
#[derive(Deserialize, Debug, Copy, Clone, PartialEq, Default)]
#[serde(rename_all = "kebab-case")]
pub enum HistoryPersistence {
/// Save all history entries to disk.
#[default]
SaveAll,
/// Do not write history to disk.
None,
}
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(untagged)]
pub enum Notifications {
Enabled(bool),
Custom(Vec<String>),
}
impl Default for Notifications {
fn default() -> Self {
Self::Enabled(false)
}
}
/// Collection of settings that are specific to the TUI.
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Tui {
/// Enable desktop notifications from the TUI when the terminal is unfocused.
/// Defaults to `false`.
#[serde(default)]
pub notifications: Notifications,
}
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct SandboxWorkspaceWrite {
#[serde(default)]
pub writable_roots: Vec<PathBuf>,
#[serde(default)]
pub network_access: bool,
#[serde(default)]
pub exclude_tmpdir_env_var: bool,
#[serde(default)]
pub exclude_slash_tmp: bool,
}
impl From<SandboxWorkspaceWrite> for codex_protocol::mcp_protocol::SandboxSettings {
fn from(sandbox_workspace_write: SandboxWorkspaceWrite) -> Self {
Self {
writable_roots: sandbox_workspace_write.writable_roots,
network_access: Some(sandbox_workspace_write.network_access),
exclude_tmpdir_env_var: Some(sandbox_workspace_write.exclude_tmpdir_env_var),
exclude_slash_tmp: Some(sandbox_workspace_write.exclude_slash_tmp),
}
}
}
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
#[serde(rename_all = "kebab-case")]
pub enum ShellEnvironmentPolicyInherit {
/// "Core" environment variables for the platform. On UNIX, this would
/// include HOME, LOGNAME, PATH, SHELL, and USER, among others.
Core,
/// Inherits the full environment from the parent process.
#[default]
All,
/// Do not inherit any environment variables from the parent process.
None,
}
/// Policy for building the `env` when spawning a process via either the
/// `shell` or `local_shell` tool.
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct ShellEnvironmentPolicyToml {
pub inherit: Option<ShellEnvironmentPolicyInherit>,
pub ignore_default_excludes: Option<bool>,
/// List of regular expressions.
pub exclude: Option<Vec<String>>,
pub r#set: Option<HashMap<String, String>>,
/// List of regular expressions.
pub include_only: Option<Vec<String>>,
pub experimental_use_profile: Option<bool>,
}
pub type EnvironmentVariablePattern = WildMatchPattern<'*', '?'>;
/// Deriving the `env` based on this policy works as follows:
/// 1. Create an initial map based on the `inherit` policy.
/// 2. If `ignore_default_excludes` is false, filter the map using the default
/// exclude pattern(s), which are: `"*KEY*"` and `"*TOKEN*"`.
/// 3. If `exclude` is not empty, filter the map using the provided patterns.
/// 4. Insert any entries from `r#set` into the map.
/// 5. If non-empty, filter the map using the `include_only` patterns.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct ShellEnvironmentPolicy {
/// Starting point when building the environment.
pub inherit: ShellEnvironmentPolicyInherit,
/// True to skip the check to exclude default environment variables that
/// contain "KEY" or "TOKEN" in their name.
pub ignore_default_excludes: bool,
/// Environment variable names to exclude from the environment.
pub exclude: Vec<EnvironmentVariablePattern>,
/// (key, value) pairs to insert in the environment.
pub r#set: HashMap<String, String>,
/// Environment variable names to retain in the environment.
pub include_only: Vec<EnvironmentVariablePattern>,
/// If true, the shell profile will be used to run the command.
pub use_profile: bool,
}
impl From<ShellEnvironmentPolicyToml> for ShellEnvironmentPolicy {
fn from(toml: ShellEnvironmentPolicyToml) -> Self {
// Default to inheriting the full environment when not specified.
let inherit = toml.inherit.unwrap_or(ShellEnvironmentPolicyInherit::All);
let ignore_default_excludes = toml.ignore_default_excludes.unwrap_or(false);
let exclude = toml
.exclude
.unwrap_or_default()
.into_iter()
.map(|s| EnvironmentVariablePattern::new_case_insensitive(&s))
.collect();
let r#set = toml.r#set.unwrap_or_default();
let include_only = toml
.include_only
.unwrap_or_default()
.into_iter()
.map(|s| EnvironmentVariablePattern::new_case_insensitive(&s))
.collect();
let use_profile = toml.experimental_use_profile.unwrap_or(false);
Self {
inherit,
ignore_default_excludes,
exclude,
r#set,
include_only,
use_profile,
}
}
}
#[derive(Deserialize, Debug, Clone, PartialEq, Eq, Default, Hash)]
#[serde(rename_all = "kebab-case")]
pub enum ReasoningSummaryFormat {
#[default]
None,
Experimental,
}
pub use codex_agent::config_types::*;

View File

@@ -1,120 +1 @@
use codex_protocol::models::ResponseItem;
/// Transcript of conversation history
#[derive(Debug, Clone, Default)]
pub(crate) struct ConversationHistory {
/// The oldest items are at the beginning of the vector.
items: Vec<ResponseItem>,
}
impl ConversationHistory {
pub(crate) fn new() -> Self {
Self { items: Vec::new() }
}
/// Returns a clone of the contents in the transcript.
pub(crate) fn contents(&self) -> Vec<ResponseItem> {
self.items.clone()
}
/// `items` is ordered from oldest to newest.
pub(crate) fn record_items<I>(&mut self, items: I)
where
I: IntoIterator,
I::Item: std::ops::Deref<Target = ResponseItem>,
{
for item in items {
if !is_api_message(&item) {
continue;
}
self.items.push(item.clone());
}
}
pub(crate) fn replace(&mut self, items: Vec<ResponseItem>) {
self.items = items;
}
}
/// Anything that is not a system message or "reasoning" message is considered
/// an API message.
fn is_api_message(message: &ResponseItem) -> bool {
match message {
ResponseItem::Message { role, .. } => role.as_str() != "system",
ResponseItem::FunctionCallOutput { .. }
| ResponseItem::FunctionCall { .. }
| ResponseItem::CustomToolCall { .. }
| ResponseItem::CustomToolCallOutput { .. }
| ResponseItem::LocalShellCall { .. }
| ResponseItem::Reasoning { .. }
| ResponseItem::WebSearchCall { .. } => true,
ResponseItem::Other => false,
}
}
#[cfg(test)]
mod tests {
use super::*;
use codex_protocol::models::ContentItem;
fn assistant_msg(text: &str) -> ResponseItem {
ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: text.to_string(),
}],
}
}
fn user_msg(text: &str) -> ResponseItem {
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::OutputText {
text: text.to_string(),
}],
}
}
#[test]
fn filters_non_api_messages() {
let mut h = ConversationHistory::default();
// System message is not an API message; Other is ignored.
let system = ResponseItem::Message {
id: None,
role: "system".to_string(),
content: vec![ContentItem::OutputText {
text: "ignored".to_string(),
}],
};
h.record_items([&system, &ResponseItem::Other]);
// User and assistant should be retained.
let u = user_msg("hi");
let a = assistant_msg("hello");
h.record_items([&u, &a]);
let items = h.contents();
assert_eq!(
items,
vec![
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::OutputText {
text: "hi".to_string()
}]
},
ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: "hello".to_string()
}]
}
]
);
}
}
pub use codex_agent::ConversationHistory;

View File

@@ -308,19 +308,16 @@ mod tests {
Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest),
Some(workspace_write_policy(vec!["/repo"], false)),
Some(Shell::Bash(BashShell {
shell_path: "/bin/bash".into(),
bashrc_path: "/home/user/.bashrc".into(),
})),
Some(Shell::Bash(BashShell::new(
"/bin/bash",
"/home/user/.bashrc",
))),
);
let context2 = EnvironmentContext::new(
Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest),
Some(workspace_write_policy(vec!["/repo"], false)),
Some(Shell::Zsh(ZshShell {
shell_path: "/bin/zsh".into(),
zshrc_path: "/home/user/.zshrc".into(),
})),
Some(Shell::Zsh(ZshShell::new("/bin/zsh", "/home/user/.zshrc"))),
);
assert!(context1.equals_except_shell(&context2));

View File

@@ -27,6 +27,7 @@ use crate::protocol::SandboxPolicy;
use crate::seatbelt::spawn_command_under_seatbelt;
use crate::spawn::StdioPolicy;
use crate::spawn::spawn_child_async;
pub use codex_agent::sandbox::SandboxType;
const DEFAULT_TIMEOUT_MS: u64 = 10_000;
@@ -61,17 +62,6 @@ impl ExecParams {
}
}
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SandboxType {
None,
/// Only available on macOS.
MacosSeatbelt,
/// Only available on Linux.
LinuxSeccomp,
}
#[derive(Clone)]
pub struct StdoutStream {
pub sub_id: String,

View File

@@ -3,6 +3,11 @@ use std::collections::BTreeMap;
use crate::openai_tools::JsonSchema;
use crate::openai_tools::ResponsesApiTool;
pub use codex_agent::exec_command::ExecCommandOutput;
pub use codex_agent::exec_command::ExecCommandParams;
pub use codex_agent::exec_command::ExecSessionManager;
pub use codex_agent::exec_command::WriteStdinParams;
pub const EXEC_COMMAND_TOOL_NAME: &str = "exec_command";
pub const WRITE_STDIN_TOOL_NAME: &str = "write_stdin";

View File

@@ -1,14 +0,0 @@
mod exec_command_params;
mod exec_command_session;
mod responses_api;
mod session_id;
mod session_manager;
pub use exec_command_params::ExecCommandParams;
pub use exec_command_params::WriteStdinParams;
pub(crate) use exec_command_session::ExecCommandSession;
pub use responses_api::EXEC_COMMAND_TOOL_NAME;
pub use responses_api::WRITE_STDIN_TOOL_NAME;
pub use responses_api::create_exec_command_tool_for_responses_api;
pub use responses_api::create_write_stdin_tool_for_responses_api;
pub use session_manager::SessionManager as ExecSessionManager;

View File

@@ -1,7 +1 @@
use thiserror::Error;
#[derive(Debug, Error, PartialEq)]
pub enum FunctionCallError {
#[error("{0}")]
RespondToModel(String),
}
pub use codex_agent::function_tool::*;

View File

@@ -5,9 +5,13 @@
// the TUI or the tracing stack).
#![deny(clippy::print_stdout, clippy::print_stderr)]
mod apply_patch;
pub mod agent_config;
pub mod agent_services;
pub mod auth;
pub mod bash;
// `bash` helpers now live in `codex-agent`; re-export the module for
// backwards compatibility with existing imports.
pub use codex_agent::bash;
pub use codex_agent::command_safety;
mod chat_completions;
mod client;
mod client_common;
@@ -15,7 +19,6 @@ pub mod codex;
mod codex_conversation;
pub mod token_data;
pub use codex_conversation::CodexConversation;
mod command_safety;
pub mod config;
pub mod config_edit;
pub mod config_profile;
@@ -40,8 +43,6 @@ mod truncate;
mod unified_exec;
mod user_instructions;
pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID;
pub use model_provider_info::ModelProviderInfo;
pub use model_provider_info::WireApi;
pub use model_provider_info::built_in_model_providers;
pub use model_provider_info::create_oss_provider_with_base_url;
mod conversation_manager;
@@ -60,30 +61,30 @@ mod openai_tools;
pub mod plan_tool;
pub mod project_doc;
mod rollout;
pub(crate) mod safety;
pub mod sandbox;
pub mod seatbelt;
pub mod shell;
pub mod spawn;
pub mod terminal;
mod tool_apply_patch;
pub mod turn_diff_tracker;
pub use codex_protocol::protocol::SessionMeta;
pub use rollout::ARCHIVED_SESSIONS_SUBDIR;
pub use rollout::RolloutRecorder;
pub use rollout::SESSIONS_SUBDIR;
pub use rollout::SessionMeta;
pub use rollout::find_conversation_path_by_id_str;
pub use rollout::list::ConversationItem;
pub use rollout::list::ConversationsPage;
pub use rollout::list::Cursor;
pub use rollout::list::find_conversation_path_by_id_str;
mod function_tool;
mod state;
mod tasks;
mod user_notification;
pub mod util;
pub use apply_patch::CODEX_APPLY_PATCH_ARG1;
pub use command_safety::is_safe_command;
pub use safety::get_platform_sandbox;
pub use codex_agent::apply_patch::CODEX_APPLY_PATCH_ARG1;
pub use codex_agent::command_safety::is_safe_command;
pub use codex_agent::safety::get_platform_sandbox;
// Re-export the protocol types from the standalone `codex-protocol` crate so existing
// `codex_core::protocol::...` references continue to work across the workspace.
pub use codex_protocol::protocol;
@@ -91,6 +92,8 @@ pub use codex_protocol::protocol;
// as those in the protocol crate when constructing protocol messages.
pub use codex_protocol::config_types as protocol_config_types;
pub use agent_config::AgentConfig;
pub use agent_services::DefaultSandboxManager;
pub use client::ModelClient;
pub use client_common::Prompt;
pub use client_common::REVIEW_PROMPT;
@@ -98,6 +101,14 @@ pub use client_common::ResponseEvent;
pub use client_common::ResponseStream;
pub use codex::compact::content_items_to_text;
pub use codex::compact::is_session_prefix_message;
pub use codex_agent::ModelProviderInfo;
pub use codex_agent::WireApi;
pub use codex_agent::services::CredentialsProvider;
pub use codex_agent::services::McpInterface;
pub use codex_agent::services::Notifier;
pub use codex_agent::services::ProviderAuth;
pub use codex_agent::services::RolloutSink;
pub use codex_agent::services::SandboxManager;
pub use codex_protocol::models::ContentItem;
pub use codex_protocol::models::LocalShellAction;
pub use codex_protocol::models::LocalShellExecAction;

View File

@@ -27,7 +27,7 @@ use std::time::Duration;
use tokio::fs;
use tokio::io::AsyncReadExt;
use crate::config::Config;
use crate::agent_config::AgentConfig;
use crate::config_types::HistoryPersistence;
use codex_protocol::mcp_protocol::ConversationId;
@@ -49,7 +49,7 @@ pub struct HistoryEntry {
pub text: String,
}
fn history_filepath(config: &Config) -> PathBuf {
fn history_filepath(config: &AgentConfig) -> PathBuf {
let mut path = config.codex_home.clone();
path.push(HISTORY_FILENAME);
path
@@ -61,7 +61,7 @@ fn history_filepath(config: &Config) -> PathBuf {
pub(crate) async fn append_entry(
text: &str,
conversation_id: &ConversationId,
config: &Config,
config: &AgentConfig,
) -> Result<()> {
match config.history.persistence {
HistoryPersistence::SaveAll => {
@@ -140,7 +140,7 @@ pub(crate) async fn append_entry(
/// Asynchronously fetch the history file's *identifier* (inode on Unix) and
/// the current number of entries by counting newline characters.
pub(crate) async fn history_metadata(config: &Config) -> (u64, usize) {
pub(crate) async fn history_metadata(config: &AgentConfig) -> (u64, usize) {
let path = history_filepath(config);
#[cfg(unix)]
@@ -187,7 +187,7 @@ pub(crate) async fn history_metadata(config: &Config) -> (u64, usize) {
/// Note this function is not async because it uses a sync advisory file
/// locking API.
#[cfg(unix)]
pub(crate) fn lookup(log_id: u64, offset: usize, config: &Config) -> Option<HistoryEntry> {
pub(crate) fn lookup(log_id: u64, offset: usize, config: &AgentConfig) -> Option<HistoryEntry> {
use std::io::BufRead;
use std::io::BufReader;
use std::os::unix::fs::MetadataExt;

View File

@@ -1,48 +1,12 @@
use crate::config_types::ReasoningSummaryFormat;
use crate::tool_apply_patch::ApplyPatchToolType;
use codex_agent::ApplyPatchToolType;
pub use codex_agent::ModelFamily;
/// The `instructions` field in the payload sent to a model should always start
/// with this content.
const BASE_INSTRUCTIONS: &str = include_str!("../prompt.md");
const GPT_5_CODEX_INSTRUCTIONS: &str = include_str!("../gpt_5_codex_prompt.md");
/// A model family is a group of models that share certain characteristics.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModelFamily {
/// The full model slug used to derive this model family, e.g.
/// "gpt-4.1-2025-04-14".
pub slug: String,
/// The model family name, e.g. "gpt-4.1". Note this should be able to be used
/// with [`crate::openai_model_info::get_model_info`].
pub family: String,
/// True if the model needs additional instructions on how to use the
/// "virtual" `apply_patch` CLI.
pub needs_special_apply_patch_instructions: bool,
// Whether the `reasoning` field can be set when making a request to this
// model family. Note it has `effort` and `summary` subfields (though
// `summary` is optional).
pub supports_reasoning_summaries: bool,
// Define if we need a special handling of reasoning summary
pub reasoning_summary_format: ReasoningSummaryFormat,
// This should be set to true when the model expects a tool named
// "local_shell" to be provided. Its contract must be understood natively by
// the model such that its description can be omitted.
// See https://platform.openai.com/docs/guides/tools-local-shell
pub uses_local_shell_tool: bool,
/// Present if the model performs better when `apply_patch` is provided as
/// a tool call instead of just a bash command
pub apply_patch_tool_type: Option<ApplyPatchToolType>,
// Instructions to use for querying the model
pub base_instructions: String,
}
macro_rules! model_family {
(
$slug:expr, $family:expr $(, $key:ident : $value:expr )* $(,)?

View File

@@ -5,15 +5,19 @@
//! 2. User-defined entries inside `~/.codex/config.toml` under the `model_providers`
//! key. These override or extend the defaults at runtime.
use crate::CodexAuth;
use async_trait::async_trait;
pub use codex_agent::ModelProviderInfo;
pub use codex_agent::WireApi;
use codex_protocol::mcp_protocol::AuthMode;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::env::VarError;
use std::sync::Arc;
use std::time::Duration;
use crate::CodexAuth;
use crate::ProviderAuth;
use crate::error::EnvVarError;
const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000;
const DEFAULT_STREAM_MAX_RETRIES: u64 = 5;
const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
@@ -22,135 +26,58 @@ const MAX_STREAM_MAX_RETRIES: u64 = 100;
/// Hard cap for user-configured `request_max_retries`.
const MAX_REQUEST_MAX_RETRIES: u64 = 100;
/// Wire protocol that the provider speaks. Most third-party services only
/// implement the classic OpenAI Chat Completions JSON schema, whereas OpenAI
/// itself (and a handful of others) additionally expose the more modern
/// *Responses* API. The two protocols use different request/response shapes
/// and *cannot* be auto-detected at runtime, therefore each provider entry
/// must declare which one it expects.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum WireApi {
/// The Responses API exposed by OpenAI at `/v1/responses`.
Responses,
#[async_trait]
pub trait ModelProviderExt {
async fn create_request_builder(
&self,
client: &reqwest::Client,
auth: &Option<Arc<dyn ProviderAuth>>,
) -> crate::error::Result<reqwest::RequestBuilder>;
/// Regular Chat Completions compatible with `/v1/chat/completions`.
#[default]
Chat,
fn get_full_url(&self, auth: &Option<Arc<dyn ProviderAuth>>) -> String;
fn is_azure_responses_endpoint(&self) -> bool;
fn apply_http_headers(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder;
fn api_key(&self) -> crate::error::Result<Option<String>>;
fn request_max_retries(&self) -> u64;
fn stream_max_retries(&self) -> u64;
fn stream_idle_timeout(&self) -> Duration;
}
/// Serializable representation of a provider definition.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ModelProviderInfo {
/// Friendly display name.
pub name: String,
/// Base URL for the provider's OpenAI-compatible API.
pub base_url: Option<String>,
/// Environment variable that stores the user's API key for this provider.
pub env_key: Option<String>,
/// Optional instructions to help the user get a valid value for the
/// variable and set it.
pub env_key_instructions: Option<String>,
/// Which wire protocol this provider expects.
#[serde(default)]
pub wire_api: WireApi,
/// Optional query parameters to append to the base URL.
pub query_params: Option<HashMap<String, String>>,
/// Additional HTTP headers to include in requests to this provider where
/// the (key, value) pairs are the header name and value.
pub http_headers: Option<HashMap<String, String>>,
/// Optional HTTP headers to include in requests to this provider where the
/// (key, value) pairs are the header name and _environment variable_ whose
/// value should be used. If the environment variable is not set, or the
/// value is empty, the header will not be included in the request.
pub env_http_headers: Option<HashMap<String, String>>,
/// Maximum number of times to retry a failed HTTP request to this provider.
pub request_max_retries: Option<u64>,
/// Number of times to retry reconnecting a dropped streaming response before failing.
pub stream_max_retries: Option<u64>,
/// Idle timeout (in milliseconds) to wait for activity on a streaming response before treating
/// the connection as lost.
pub stream_idle_timeout_ms: Option<u64>,
/// Does this provider require an OpenAI API Key or ChatGPT login token? If true,
/// user is presented with login screen on first run, and login preference and token/key
/// are stored in auth.json. If false (which is the default), login screen is skipped,
/// and API key (if needed) comes from the "env_key" environment variable.
#[serde(default)]
pub requires_openai_auth: bool,
}
impl ModelProviderInfo {
/// Construct a `POST` RequestBuilder for the given URL using the provided
/// reqwest Client applying:
/// • provider-specific headers (static + env based)
/// • Bearer auth header when an API key is available.
/// • Auth token for OAuth.
///
/// If the provider declares an `env_key` but the variable is missing/empty, returns an [`Err`] identical to the
/// one produced by [`ModelProviderInfo::api_key`].
pub async fn create_request_builder<'a>(
&'a self,
client: &'a reqwest::Client,
auth: &Option<CodexAuth>,
#[async_trait]
impl ModelProviderExt for ModelProviderInfo {
async fn create_request_builder(
&self,
client: &reqwest::Client,
auth: &Option<Arc<dyn ProviderAuth>>,
) -> crate::error::Result<reqwest::RequestBuilder> {
let effective_auth = match self.api_key() {
Ok(Some(key)) => Some(CodexAuth::from_api_key(&key)),
Ok(None) => auth.clone(),
Err(err) => {
if auth.is_some() {
auth.clone()
} else {
return Err(err);
}
}
let effective_auth: Option<Arc<dyn ProviderAuth>> = match self.api_key()? {
Some(key) => Some(Arc::new(CodexAuth::from_api_key(&key))),
None => auth.clone(),
};
let url = self.get_full_url(&effective_auth);
let mut builder = client.post(url);
if let Some(auth) = effective_auth.as_ref() {
builder = builder.bearer_auth(auth.get_token().await?);
builder = builder.bearer_auth(auth.access_token().await?);
}
Ok(self.apply_http_headers(builder))
}
fn get_query_string(&self) -> String {
self.query_params
.as_ref()
.map_or_else(String::new, |params| {
let full_params = params
.iter()
.map(|(k, v)| format!("{k}={v}"))
.collect::<Vec<_>>()
.join("&");
format!("?{full_params}")
})
}
pub(crate) fn get_full_url(&self, auth: &Option<CodexAuth>) -> String {
let default_base_url = if matches!(
auth,
Some(CodexAuth {
mode: AuthMode::ChatGPT,
..
})
) {
fn get_full_url(&self, auth: &Option<Arc<dyn ProviderAuth>>) -> String {
let default_base_url = if auth.as_ref().map(|a| a.mode()) == Some(AuthMode::ChatGPT) {
"https://chatgpt.com/backend-api/codex"
} else {
"https://api.openai.com/v1"
};
let query_string = self.get_query_string();
let query_string = get_query_string(self);
let base_url = self
.base_url
.clone()
@@ -162,7 +89,7 @@ impl ModelProviderInfo {
}
}
pub(crate) fn is_azure_responses_endpoint(&self) -> bool {
fn is_azure_responses_endpoint(&self) -> bool {
if self.wire_api != WireApi::Responses {
return false;
}
@@ -177,9 +104,6 @@ impl ModelProviderInfo {
.unwrap_or(false)
}
/// Apply provider-specific HTTP headers (both static and environment-based)
/// onto an existing `reqwest::RequestBuilder` and return the updated
/// builder.
fn apply_http_headers(&self, mut builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
if let Some(extra) = &self.http_headers {
for (k, v) in extra {
@@ -199,10 +123,7 @@ impl ModelProviderInfo {
builder
}
/// If `env_key` is Some, returns the API key for this provider if present
/// (and non-empty) in the environment. If `env_key` is required but
/// cannot be found, returns an error.
pub fn api_key(&self) -> crate::error::Result<Option<String>> {
fn api_key(&self) -> crate::error::Result<Option<String>> {
match &self.env_key {
Some(env_key) => {
let env_value = std::env::var(env_key);
@@ -225,28 +146,39 @@ impl ModelProviderInfo {
}
}
/// Effective maximum number of request retries for this provider.
pub fn request_max_retries(&self) -> u64 {
fn request_max_retries(&self) -> u64 {
self.request_max_retries
.unwrap_or(DEFAULT_REQUEST_MAX_RETRIES)
.min(MAX_REQUEST_MAX_RETRIES)
}
/// Effective maximum number of stream reconnection attempts for this provider.
pub fn stream_max_retries(&self) -> u64 {
fn stream_max_retries(&self) -> u64 {
self.stream_max_retries
.unwrap_or(DEFAULT_STREAM_MAX_RETRIES)
.min(MAX_STREAM_MAX_RETRIES)
}
/// Effective idle timeout for streaming responses.
pub fn stream_idle_timeout(&self) -> Duration {
fn stream_idle_timeout(&self) -> Duration {
self.stream_idle_timeout_ms
.map(Duration::from_millis)
.unwrap_or(Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS))
}
}
fn get_query_string(provider: &ModelProviderInfo) -> String {
provider
.query_params
.as_ref()
.map_or_else(String::new, |params| {
let full_params = params
.iter()
.map(|(k, v)| format!("{k}={v}"))
.collect::<Vec<_>>()
.join("&");
format!("?{full_params}")
})
}
const DEFAULT_OLLAMA_PORT: u32 = 11434;
pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "oss";
@@ -255,20 +187,11 @@ pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "oss";
pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
use ModelProviderInfo as P;
// We do not want to be in the business of adjudicating which third-party
// providers are bundled with Codex CLI, so we only include the OpenAI and
// open source ("oss") providers by default. Users are encouraged to add to
// `model_providers` in config.toml to add their own providers.
[
(
"openai",
P {
name: "OpenAI".into(),
// Allow users to override the default OpenAI endpoint by
// exporting `OPENAI_BASE_URL`. This is useful when pointing
// Codex at a proxy, mock server, or Azure-style deployment
// without requiring a full TOML override for the built-in
// OpenAI provider.
base_url: std::env::var("OPENAI_BASE_URL")
.ok()
.filter(|v| !v.trim().is_empty()),
@@ -292,7 +215,6 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
.into_iter()
.collect(),
),
// Use global defaults for retry/timeout unless overridden in config.toml.
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
@@ -307,8 +229,6 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
}
pub fn create_oss_provider() -> ModelProviderInfo {
// These CODEX_OSS_ environment variables are experimental: we may
// switch to reading values from config.toml instead.
let codex_oss_base_url = match std::env::var("CODEX_OSS_BASE_URL")
.ok()
.filter(|v| !v.trim().is_empty())
@@ -361,130 +281,11 @@ mod tests {
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn test_deserialize_ollama_model_provider_toml() {
let azure_provider_toml = r#"
name = "Ollama"
base_url = "http://localhost:11434/v1"
"#;
let expected_provider = ModelProviderInfo {
name: "Ollama".into(),
base_url: Some("http://localhost:11434/v1".into()),
env_key: None,
env_key_instructions: None,
wire_api: WireApi::Chat,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: false,
};
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
assert_eq!(expected_provider, provider);
}
#[test]
fn test_deserialize_azure_model_provider_toml() {
let azure_provider_toml = r#"
name = "Azure"
base_url = "https://xxxxx.openai.azure.com/openai"
env_key = "AZURE_OPENAI_API_KEY"
query_params = { api-version = "2025-04-01-preview" }
"#;
let expected_provider = ModelProviderInfo {
name: "Azure".into(),
base_url: Some("https://xxxxx.openai.azure.com/openai".into()),
env_key: Some("AZURE_OPENAI_API_KEY".into()),
env_key_instructions: None,
wire_api: WireApi::Chat,
query_params: Some(maplit::hashmap! {
"api-version".to_string() => "2025-04-01-preview".to_string(),
}),
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: false,
};
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
assert_eq!(expected_provider, provider);
}
#[test]
fn test_deserialize_example_model_provider_toml() {
let azure_provider_toml = r#"
name = "Example"
base_url = "https://example.com"
env_key = "API_KEY"
http_headers = { "X-Example-Header" = "example-value" }
env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
"#;
let expected_provider = ModelProviderInfo {
name: "Example".into(),
base_url: Some("https://example.com".into()),
env_key: Some("API_KEY".into()),
env_key_instructions: None,
wire_api: WireApi::Chat,
query_params: None,
http_headers: Some(maplit::hashmap! {
"X-Example-Header".to_string() => "example-value".to_string(),
}),
env_http_headers: Some(maplit::hashmap! {
"X-Example-Env-Header".to_string() => "EXAMPLE_ENV_VAR".to_string(),
}),
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: false,
};
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
assert_eq!(expected_provider, provider);
}
#[test]
fn detects_azure_responses_base_urls() {
fn provider_for(base_url: &str) -> ModelProviderInfo {
ModelProviderInfo {
name: "test".into(),
base_url: Some(base_url.into()),
env_key: None,
env_key_instructions: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: false,
}
}
let positive_cases = [
"https://foo.openai.azure.com/openai",
"https://foo.openai.azure.us/openai/deployments/bar",
"https://foo.cognitiveservices.azure.cn/openai",
"https://foo.aoai.azure.com/openai",
"https://foo.openai.azure-api.net/openai",
"https://foo.z01.azurefd.net/",
];
for base_url in positive_cases {
let provider = provider_for(base_url);
assert!(
provider.is_azure_responses_endpoint(),
"expected {base_url} to be detected as Azure"
);
}
let named_provider = ModelProviderInfo {
name: "Azure".into(),
base_url: Some("https://example.com".into()),
#[tokio::test]
async fn creates_request_builder_with_auth() {
let provider = ModelProviderInfo {
name: "openai".to_string(),
base_url: None,
env_key: None,
env_key_instructions: None,
wire_api: WireApi::Responses,
@@ -494,21 +295,33 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: false,
requires_openai_auth: true,
};
assert!(named_provider.is_azure_responses_endpoint());
let client = reqwest::Client::new();
let auth =
Some(Arc::new(CodexAuth::create_dummy_chatgpt_auth_for_testing())
as Arc<dyn ProviderAuth>);
let negative_cases = [
"https://api.openai.com/v1",
"https://example.com/openai",
"https://myproxy.azurewebsites.net/openai",
];
for base_url in negative_cases {
let provider = provider_for(base_url);
assert!(
!provider.is_azure_responses_endpoint(),
"expected {base_url} not to be detected as Azure"
);
}
let builder = provider
.create_request_builder(&client, &auth)
.await
.expect("builder");
let request = builder.build().expect("request");
assert_eq!(request.method(), reqwest::Method::POST);
assert_eq!(
request.url().as_str(),
"https://chatgpt.com/backend-api/codex/responses"
);
}
#[test]
fn azure_detection() {
let mut provider = create_oss_provider();
assert!(!provider.is_azure_responses_endpoint());
provider.name = "azure".to_string();
provider.wire_api = WireApi::Responses;
assert!(provider.is_azure_responses_endpoint());
}
}

View File

@@ -12,7 +12,7 @@
//! that order.
//! 3. We do **not** walk past the Git root.
use crate::config::Config;
use crate::agent_config::AgentConfig;
use std::path::PathBuf;
use tokio::io::AsyncReadExt;
use tracing::error;
@@ -26,7 +26,7 @@ const PROJECT_DOC_SEPARATOR: &str = "\n\n--- project-doc ---\n\n";
/// Combines `Config::instructions` and `AGENTS.md` (if present) into a single
/// string of instructions.
pub(crate) async fn get_user_instructions(config: &Config) -> Option<String> {
pub(crate) async fn get_user_instructions(config: &AgentConfig) -> Option<String> {
match read_project_docs(config).await {
Ok(Some(project_doc)) => match &config.user_instructions {
Some(original_instructions) => Some(format!(
@@ -48,7 +48,7 @@ pub(crate) async fn get_user_instructions(config: &Config) -> Option<String> {
/// concatenation of all discovered docs. If no documentation file is found the
/// function returns `Ok(None)`. Unexpected I/O failures bubble up as `Err` so
/// callers can decide how to handle them.
pub async fn read_project_docs(config: &Config) -> std::io::Result<Option<String>> {
pub async fn read_project_docs(config: &AgentConfig) -> std::io::Result<Option<String>> {
let max_total = config.project_doc_max_bytes;
if max_total == 0 {
@@ -106,7 +106,7 @@ pub async fn read_project_docs(config: &Config) -> std::io::Result<Option<String
/// contents. The list is ordered from repository root to the current working
/// directory (inclusive). Symlinks are allowed. When `project_doc_max_bytes`
/// is zero, returns an empty list.
pub fn discover_project_doc_paths(config: &Config) -> std::io::Result<Vec<PathBuf>> {
pub fn discover_project_doc_paths(config: &AgentConfig) -> std::io::Result<Vec<PathBuf>> {
let mut dir = config.cwd.clone();
if let Ok(canon) = dir.canonicalize() {
dir = canon;
@@ -176,6 +176,7 @@ pub fn discover_project_doc_paths(config: &Config) -> std::io::Result<Vec<PathBu
#[cfg(test)]
mod tests {
use super::*;
use crate::config::Config;
use crate::config::ConfigOverrides;
use crate::config::ConfigToml;
use std::fs;
@@ -186,7 +187,7 @@ mod tests {
/// optionally specify a custom `instructions` string when `None` the
/// value is cleared to mimic a scenario where no system instructions have
/// been configured.
fn make_config(root: &TempDir, limit: usize, instructions: Option<&str>) -> Config {
fn make_config(root: &TempDir, limit: usize, instructions: Option<&str>) -> AgentConfig {
let codex_home = TempDir::new().unwrap();
let mut config = Config::load_from_base_config_with_overrides(
ConfigToml::default(),
@@ -199,7 +200,8 @@ mod tests {
config.project_doc_max_bytes = limit;
config.user_instructions = instructions.map(ToOwned::to_owned);
config
AgentConfig::from(&config)
}
/// AGENTS.md missing should yield `None`.

View File

@@ -1,16 +1,29 @@
//! Rollout module: persistence and discovery of session rollout files.
use std::path::Path;
pub const SESSIONS_SUBDIR: &str = "sessions";
pub const ARCHIVED_SESSIONS_SUBDIR: &str = "archived_sessions";
use async_trait::async_trait;
use codex_agent::rollout::GitInfoCollector as SharedGitInfoCollector;
use codex_protocol::protocol::GitInfo;
pub mod list;
pub(crate) mod policy;
pub mod recorder;
pub use codex_agent::rollout::ARCHIVED_SESSIONS_SUBDIR;
pub use codex_agent::rollout::RolloutConfig;
pub use codex_agent::rollout::RolloutRecorder;
pub use codex_agent::rollout::RolloutRecorderParams;
pub use codex_agent::rollout::SESSIONS_SUBDIR;
pub use codex_protocol::protocol::SessionMeta;
pub use list::find_conversation_path_by_id_str;
pub use recorder::RolloutRecorder;
pub use recorder::RolloutRecorderParams;
pub mod list {
pub use codex_agent::rollout::list::*;
}
#[cfg(test)]
pub mod tests;
use crate::git_info::collect_git_info;
pub struct CoreGitInfoCollector;
#[async_trait]
impl SharedGitInfoCollector for CoreGitInfoCollector {
async fn collect(&self, cwd: &Path) -> Option<GitInfo> {
collect_git_info(cwd).await
}
}
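Hosts without a git checkout (for example the planned training binary) can satisfy the same trait with a no-op collector; a small sketch using the signature shown above:

```rust
use std::path::Path;

use async_trait::async_trait;
use codex_agent::rollout::GitInfoCollector;
use codex_protocol::protocol::GitInfo;

/// Hypothetical collector for hosts that never attach git metadata.
struct NoGitInfoCollector;

#[async_trait]
impl GitInfoCollector for NoGitInfoCollector {
    async fn collect(&self, _cwd: &Path) -> Option<GitInfo> {
        None
    }
}
```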

View File

@@ -0,0 +1,33 @@
use std::collections::HashMap;
use std::env;
use crate::exec::ExecParams;
use crate::function_tool::FunctionCallError;
use codex_agent::apply_patch::ApplyPatchExec;
use codex_agent::apply_patch::CODEX_APPLY_PATCH_ARG1;
pub(crate) fn build_exec_params_for_apply_patch(
exec: &ApplyPatchExec,
original: &ExecParams,
) -> Result<ExecParams, FunctionCallError> {
let path_to_codex = env::current_exe()
.ok()
.map(|p| p.to_string_lossy().to_string())
.ok_or_else(|| {
FunctionCallError::RespondToModel(
"failed to determine path to codex executable".to_string(),
)
})?;
let patch = exec.action.patch.clone();
Ok(ExecParams {
command: vec![path_to_codex, CODEX_APPLY_PATCH_ARG1.to_string(), patch],
cwd: exec.action.cwd.clone(),
timeout_ms: original.timeout_ms,
// Run apply_patch with a minimal environment for determinism and to
// avoid leaking host environment variables into the patch process.
env: HashMap::new(),
with_escalated_permissions: original.with_escalated_permissions,
justification: original.justification.clone(),
})
}

View File

@@ -0,0 +1,87 @@
use std::path::Path;
use std::path::PathBuf;
use async_trait::async_trait;
use crate::error::Result;
use crate::exec::ExecParams;
use crate::exec::ExecToolCallOutput;
use crate::exec::SandboxType;
use crate::exec::StdoutStream;
use crate::exec::process_exec_tool_call;
use crate::protocol::SandboxPolicy;
#[async_trait]
pub trait SpawnBackend: Send + Sync {
fn sandbox_type(&self) -> SandboxType;
async fn spawn(
&self,
params: ExecParams,
sandbox_policy: &SandboxPolicy,
sandbox_cwd: &Path,
codex_linux_sandbox_exe: &Option<PathBuf>,
stdout_stream: Option<StdoutStream>,
) -> Result<ExecToolCallOutput> {
process_exec_tool_call(
params,
self.sandbox_type(),
sandbox_policy,
sandbox_cwd,
codex_linux_sandbox_exe,
stdout_stream,
)
.await
}
}
#[derive(Clone, Copy, Debug, Default)]
pub struct DirectBackend;
#[async_trait]
impl SpawnBackend for DirectBackend {
fn sandbox_type(&self) -> SandboxType {
SandboxType::None
}
}
#[derive(Clone, Copy, Debug, Default)]
pub struct SeatbeltBackend;
#[async_trait]
impl SpawnBackend for SeatbeltBackend {
fn sandbox_type(&self) -> SandboxType {
SandboxType::MacosSeatbelt
}
}
#[derive(Clone, Copy, Debug, Default)]
pub struct LinuxBackend;
#[async_trait]
impl SpawnBackend for LinuxBackend {
fn sandbox_type(&self) -> SandboxType {
SandboxType::LinuxSeccomp
}
}
#[derive(Default)]
pub struct BackendRegistry {
direct: DirectBackend,
seatbelt: SeatbeltBackend,
linux: LinuxBackend,
}
impl BackendRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn for_type(&self, sandbox: SandboxType) -> &dyn SpawnBackend {
match sandbox {
SandboxType::None => &self.direct,
SandboxType::MacosSeatbelt => &self.seatbelt,
SandboxType::LinuxSeccomp => &self.linux,
}
}
}
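Usage is a straight lookup from a planned sandbox type to the backend that spawns under it; a minimal sketch:

```rust
use codex_agent::sandbox::SandboxType;

// The registry is cheap to construct; `spawn` itself comes from the
// trait's default body, so backends only declare their sandbox type.
fn pick_backend(registry: &BackendRegistry, sandbox: SandboxType) -> &dyn SpawnBackend {
    registry.for_type(sandbox)
}
```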

View File

@@ -0,0 +1,51 @@
mod apply_patch_adapter;
mod backend;
mod planner;
pub use backend::BackendRegistry;
pub use backend::DirectBackend;
pub use backend::LinuxBackend;
pub use backend::SeatbeltBackend;
pub use backend::SpawnBackend;
pub use planner::ExecPlan;
pub use planner::ExecRequest;
pub use planner::PatchExecRequest;
pub(crate) use planner::PreparedExec;
pub use planner::plan_apply_patch;
pub use planner::plan_exec;
pub(crate) use planner::prepare_exec_invocation;
use crate::error::Result;
use crate::exec::ExecParams;
use crate::exec::ExecToolCallOutput;
use crate::exec::StdoutStream;
use crate::protocol::SandboxPolicy;
pub struct ExecRuntimeContext<'a> {
pub sandbox_policy: &'a SandboxPolicy,
pub sandbox_cwd: &'a std::path::Path,
pub codex_linux_sandbox_exe: &'a Option<std::path::PathBuf>,
pub stdout_stream: Option<StdoutStream>,
}
pub async fn run_with_plan(
params: ExecParams,
plan: &ExecPlan,
registry: &BackendRegistry,
runtime_ctx: &ExecRuntimeContext<'_>,
) -> Result<ExecToolCallOutput> {
let ExecPlan::Approved { sandbox, .. } = plan else {
unreachable!("run_with_plan called without approved plan");
};
registry
.for_type(*sandbox)
.spawn(
params,
runtime_ctx.sandbox_policy,
runtime_ctx.sandbox_cwd,
runtime_ctx.codex_linux_sandbox_exe,
runtime_ctx.stdout_stream.clone(),
)
.await
}
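Because `run_with_plan` panics on anything other than an `Approved` plan, callers should branch on the planner result first. A hedged end-to-end sketch, written as if inside this module:

```rust
use std::collections::HashSet;

// Plan first, dispatch through the registry only when approved; `AskUser`
// and `Reject` are surfaced via the ApprovalCoordinator instead.
async fn plan_then_run(
    params: ExecParams,
    approval: AskForApproval,
    policy: &SandboxPolicy,
    approved_cmds: &HashSet<Vec<String>>,
    registry: &BackendRegistry,
    ctx: &ExecRuntimeContext<'_>,
) -> Option<Result<ExecToolCallOutput>> {
    let plan = plan_exec(&ExecRequest {
        params: &params,
        approval,
        policy,
        approved_session_commands: approved_cmds,
    });
    match &plan {
        ExecPlan::Approved { .. } => Some(run_with_plan(params, &plan, registry, ctx).await),
        _ => None,
    }
}
```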

View File

@@ -0,0 +1,217 @@
use std::collections::HashSet;
use std::path::Path;
use codex_agent::apply_patch::ApplyPatchExec;
use codex_agent::safety::SafetyCheck;
use codex_agent::safety::assess_command_safety;
use codex_agent::safety::assess_patch_safety;
use codex_agent::sandbox::SandboxType;
use codex_agent::services::ApprovalCoordinator;
use codex_apply_patch::ApplyPatchAction;
use super::apply_patch_adapter::build_exec_params_for_apply_patch;
use crate::codex::TurnContext;
use crate::exec::ExecParams;
use crate::function_tool::FunctionCallError;
use crate::protocol::AskForApproval;
use crate::protocol::ReviewDecision;
use crate::protocol::SandboxPolicy;
#[derive(Clone, Debug)]
pub struct ExecRequest<'a> {
pub params: &'a ExecParams,
pub approval: AskForApproval,
pub policy: &'a SandboxPolicy,
pub approved_session_commands: &'a HashSet<Vec<String>>,
}
#[derive(Clone, Debug)]
pub enum ExecPlan {
Reject {
reason: String,
},
AskUser {
reason: Option<String>,
},
Approved {
sandbox: SandboxType,
on_failure_escalate: bool,
approved_by_user: bool,
},
}
impl ExecPlan {
pub fn approved(
sandbox: SandboxType,
on_failure_escalate: bool,
approved_by_user: bool,
) -> Self {
ExecPlan::Approved {
sandbox,
on_failure_escalate,
approved_by_user,
}
}
}
pub fn plan_exec(req: &ExecRequest<'_>) -> ExecPlan {
let params = req.params;
let with_escalated_permissions = params.with_escalated_permissions.unwrap_or(false);
let safety = assess_command_safety(
&params.command,
req.approval,
req.policy,
req.approved_session_commands,
with_escalated_permissions,
);
match safety {
SafetyCheck::AutoApprove { sandbox_type } => ExecPlan::Approved {
sandbox: sandbox_type,
on_failure_escalate: should_escalate_on_failure(req.approval, sandbox_type),
approved_by_user: false,
},
SafetyCheck::AskUser => ExecPlan::AskUser {
reason: params.justification.clone(),
},
SafetyCheck::Reject { reason } => ExecPlan::Reject { reason },
}
}
#[derive(Clone, Debug)]
pub struct PatchExecRequest<'a> {
pub action: &'a ApplyPatchAction,
pub approval: AskForApproval,
pub policy: &'a SandboxPolicy,
pub cwd: &'a Path,
pub user_explicitly_approved: bool,
}
pub fn plan_apply_patch(req: &PatchExecRequest<'_>) -> ExecPlan {
if req.user_explicitly_approved {
ExecPlan::Approved {
sandbox: SandboxType::None,
on_failure_escalate: false,
approved_by_user: true,
}
} else {
match assess_patch_safety(req.action, req.approval, req.policy, req.cwd) {
SafetyCheck::AutoApprove { sandbox_type } => ExecPlan::Approved {
sandbox: sandbox_type,
on_failure_escalate: should_escalate_on_failure(req.approval, sandbox_type),
approved_by_user: false,
},
SafetyCheck::AskUser => ExecPlan::AskUser { reason: None },
SafetyCheck::Reject { reason } => ExecPlan::Reject { reason },
}
}
}
#[derive(Debug)]
pub(crate) struct PreparedExec {
pub(crate) params: ExecParams,
pub(crate) plan: ExecPlan,
pub(crate) command_for_display: Vec<String>,
pub(crate) apply_patch_exec: Option<ApplyPatchExec>,
}
pub(crate) async fn prepare_exec_invocation(
approvals: &dyn ApprovalCoordinator,
turn_context: &TurnContext,
sub_id: &str,
call_id: &str,
params: ExecParams,
apply_patch_exec: Option<ApplyPatchExec>,
approved_session_commands: HashSet<Vec<String>>,
) -> Result<PreparedExec, FunctionCallError> {
let mut params = params;
let (plan, command_for_display) = if let Some(exec) = apply_patch_exec.as_ref() {
params = build_exec_params_for_apply_patch(exec, &params)?;
let command_for_display = vec!["apply_patch".to_string(), exec.action.patch.clone()];
let plan_req = PatchExecRequest {
action: &exec.action,
approval: turn_context.approval_policy,
policy: &turn_context.sandbox_policy,
cwd: &turn_context.cwd,
user_explicitly_approved: exec.user_explicitly_approved_this_action,
};
let plan = match plan_apply_patch(&plan_req) {
plan @ ExecPlan::Approved { .. } => plan,
ExecPlan::AskUser { .. } => {
return Err(FunctionCallError::RespondToModel(
"patch requires approval but none was recorded".to_string(),
));
}
ExecPlan::Reject { reason } => {
return Err(FunctionCallError::RespondToModel(format!(
"patch rejected: {reason}"
)));
}
};
(plan, command_for_display)
} else {
let command_for_display = params.command.clone();
let initial_plan = plan_exec(&ExecRequest {
params: &params,
approval: turn_context.approval_policy,
policy: &turn_context.sandbox_policy,
approved_session_commands: &approved_session_commands,
});
let plan = match initial_plan {
plan @ ExecPlan::Approved { .. } => plan,
ExecPlan::AskUser { reason } => {
let decision = approvals
.request_command_approval(
sub_id.to_string(),
call_id.to_string(),
params.command.clone(),
params.cwd.clone(),
reason,
)
.await;
match decision {
ReviewDecision::Approved => ExecPlan::approved(SandboxType::None, false, true),
ReviewDecision::ApprovedForSession => {
approvals.add_approved_command(params.command.clone()).await;
ExecPlan::approved(SandboxType::None, false, true)
}
ReviewDecision::Denied | ReviewDecision::Abort => {
return Err(FunctionCallError::RespondToModel(
"exec command rejected by user".to_string(),
));
}
}
}
ExecPlan::Reject { reason } => {
return Err(FunctionCallError::RespondToModel(format!(
"exec command rejected: {reason:?}"
)));
}
};
(plan, command_for_display)
};
Ok(PreparedExec {
params,
plan,
command_for_display,
apply_patch_exec,
})
}
fn should_escalate_on_failure(approval: AskForApproval, sandbox: SandboxType) -> bool {
matches!(
(approval, sandbox),
(
AskForApproval::UnlessTrusted | AskForApproval::OnFailure,
SandboxType::MacosSeatbelt | SandboxType::LinuxSeccomp
)
)
}
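// Illustrative sketch, not part of the diff: escalation only applies when the
// command ran inside a sandbox under a policy that allows retrying with the
// user's blessing; an unsandboxed run has nothing to escalate to.
#[cfg(test)]
mod escalation_example {
    use super::*;

    #[test]
    fn escalates_only_sandboxed_runs_under_lenient_policies() {
        assert!(should_escalate_on_failure(
            AskForApproval::OnFailure,
            SandboxType::MacosSeatbelt
        ));
        assert!(!should_escalate_on_failure(
            AskForApproval::OnFailure,
            SandboxType::None
        ));
    }
}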

View File

@@ -1,571 +1 @@
use serde::Deserialize;
use serde::Serialize;
use shlex;
use std::path::PathBuf;
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct ZshShell {
pub(crate) shell_path: String,
pub(crate) zshrc_path: String,
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct BashShell {
pub(crate) shell_path: String,
pub(crate) bashrc_path: String,
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct PowerShellConfig {
pub(crate) exe: String, // Executable name or path, e.g. "pwsh" or "powershell.exe".
pub(crate) bash_exe_fallback: Option<PathBuf>, // In case the model generates a bash command.
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub enum Shell {
Zsh(ZshShell),
Bash(BashShell),
PowerShell(PowerShellConfig),
Unknown,
}
impl Shell {
pub fn format_default_shell_invocation(&self, command: Vec<String>) -> Option<Vec<String>> {
match self {
Shell::Zsh(zsh) => format_shell_invocation_with_rc(
command.as_slice(),
&zsh.shell_path,
&zsh.zshrc_path,
),
Shell::Bash(bash) => format_shell_invocation_with_rc(
command.as_slice(),
&bash.shell_path,
&bash.bashrc_path,
),
Shell::PowerShell(ps) => {
// If the model generated a bash command, prefer a detected bash fallback.
if let Some(script) = strip_bash_lc(command.as_slice()) {
return match &ps.bash_exe_fallback {
Some(bash) => Some(vec![
bash.to_string_lossy().to_string(),
"-lc".to_string(),
script,
]),
// No bash fallback → run the script under PowerShell.
// It will likely fail (except for some simple commands), but the error
// should hint to the model that it's running under PowerShell so it can adjust on retry.
None => Some(vec![
ps.exe.clone(),
"-NoProfile".to_string(),
"-Command".to_string(),
script,
]),
};
}
// Not a bash command. If the model did not generate a PowerShell command,
// turn it into a PowerShell command.
let first = command.first().map(String::as_str);
if first != Some(ps.exe.as_str()) {
// TODO (CODEX_2900): Handle escaping newlines.
if command.iter().any(|a| a.contains('\n') || a.contains('\r')) {
return Some(command);
}
let joined = shlex::try_join(command.iter().map(String::as_str)).ok();
return joined.map(|arg| {
vec![
ps.exe.clone(),
"-NoProfile".to_string(),
"-Command".to_string(),
arg,
]
});
}
// Model generated a PowerShell command. Run it.
Some(command)
}
Shell::Unknown => None,
}
}
pub fn name(&self) -> Option<String> {
match self {
Shell::Zsh(zsh) => std::path::Path::new(&zsh.shell_path)
.file_name()
.map(|s| s.to_string_lossy().to_string()),
Shell::Bash(bash) => std::path::Path::new(&bash.shell_path)
.file_name()
.map(|s| s.to_string_lossy().to_string()),
Shell::PowerShell(ps) => Some(ps.exe.clone()),
Shell::Unknown => None,
}
}
}
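// Illustrative sketch, not part of the diff: `name()` reports only the
// executable's file name, and `Unknown` has none.
#[cfg(test)]
mod shell_name_example {
    use super::*;

    #[test]
    fn name_strips_the_directory_from_the_shell_path() {
        let zsh = Shell::Zsh(ZshShell {
            shell_path: "/bin/zsh".to_string(),
            zshrc_path: "/home/user/.zshrc".to_string(),
        });
        assert_eq!(zsh.name().as_deref(), Some("zsh"));
        assert_eq!(Shell::Unknown.name(), None);
    }
}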
fn format_shell_invocation_with_rc(
command: &[String],
shell_path: &str,
rc_path: &str,
) -> Option<Vec<String>> {
let joined = strip_bash_lc(command)
.or_else(|| shlex::try_join(command.iter().map(String::as_str)).ok())?;
let rc_command = if std::path::Path::new(rc_path).exists() {
format!("source {rc_path} && ({joined})")
} else {
joined
};
Some(vec![shell_path.to_string(), "-lc".to_string(), rc_command])
}
fn strip_bash_lc(command: &[String]) -> Option<String> {
match command {
// exactly three items
[first, second, third]
// first two must be "bash", "-lc"
if first == "bash" && second == "-lc" =>
{
Some(third.clone())
}
_ => None,
}
}
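// Illustrative sketch, not part of the diff: strip_bash_lc only unwraps the
// exact ["bash", "-lc", script] shape; any other arity or argv is passed through.
#[cfg(test)]
mod strip_bash_lc_example {
    use super::*;

    #[test]
    fn unwraps_only_exact_bash_lc_invocations() {
        let wrapped = vec![
            "bash".to_string(),
            "-lc".to_string(),
            "echo hi".to_string(),
        ];
        assert_eq!(strip_bash_lc(&wrapped), Some("echo hi".to_string()));

        // A fourth argument defeats the three-element pattern.
        let extra = vec![
            "bash".to_string(),
            "-lc".to_string(),
            "echo hi".to_string(),
            "-x".to_string(),
        ];
        assert_eq!(strip_bash_lc(&extra), None);
    }
}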
#[cfg(unix)]
fn detect_default_user_shell() -> Shell {
use libc::getpwuid;
use libc::getuid;
use std::ffi::CStr;
unsafe {
let uid = getuid();
let pw = getpwuid(uid);
if !pw.is_null() {
let shell_path = CStr::from_ptr((*pw).pw_shell)
.to_string_lossy()
.into_owned();
let home_path = CStr::from_ptr((*pw).pw_dir).to_string_lossy().into_owned();
if shell_path.ends_with("/zsh") {
return Shell::Zsh(ZshShell {
shell_path,
zshrc_path: format!("{home_path}/.zshrc"),
});
}
if shell_path.ends_with("/bash") {
return Shell::Bash(BashShell {
shell_path,
bashrc_path: format!("{home_path}/.bashrc"),
});
}
}
}
Shell::Unknown
}
#[cfg(unix)]
pub async fn default_user_shell() -> Shell {
detect_default_user_shell()
}
#[cfg(target_os = "windows")]
pub async fn default_user_shell() -> Shell {
use tokio::process::Command;
// Prefer PowerShell 7+ (`pwsh`) if available, otherwise fall back to Windows PowerShell.
let has_pwsh = Command::new("pwsh")
.arg("-NoLogo")
.arg("-NoProfile")
.arg("-Command")
.arg("$PSVersionTable.PSVersion.Major")
.output()
.await
.map(|o| o.status.success())
.unwrap_or(false);
let bash_exe = if Command::new("bash.exe")
.arg("--version")
.output()
.await
.ok()
.map(|o| o.status.success())
.unwrap_or(false)
{
which::which("bash.exe").ok()
} else {
None
};
if has_pwsh {
Shell::PowerShell(PowerShellConfig {
exe: "pwsh.exe".to_string(),
bash_exe_fallback: bash_exe,
})
} else {
Shell::PowerShell(PowerShellConfig {
exe: "powershell.exe".to_string(),
bash_exe_fallback: bash_exe,
})
}
}
#[cfg(all(not(target_os = "windows"), not(unix)))]
pub async fn default_user_shell() -> Shell {
Shell::Unknown
}
#[cfg(test)]
#[cfg(unix)]
mod tests {
use super::*;
use std::process::Command;
use std::string::ToString;
#[tokio::test]
async fn test_current_shell_detects_zsh() {
let shell = Command::new("sh")
.arg("-c")
.arg("echo $SHELL")
.output()
.unwrap();
let home = std::env::var("HOME").unwrap();
let shell_path = String::from_utf8_lossy(&shell.stdout).trim().to_string();
if shell_path.ends_with("/zsh") {
assert_eq!(
default_user_shell().await,
Shell::Zsh(ZshShell {
shell_path: shell_path.to_string(),
zshrc_path: format!("{home}/.zshrc",),
})
);
}
}
#[tokio::test]
async fn test_run_with_profile_zshrc_not_exists() {
let shell = Shell::Zsh(ZshShell {
shell_path: "/bin/zsh".to_string(),
zshrc_path: "/does/not/exist/.zshrc".to_string(),
});
let actual_cmd = shell.format_default_shell_invocation(vec!["myecho".to_string()]);
assert_eq!(
actual_cmd,
Some(vec![
"/bin/zsh".to_string(),
"-lc".to_string(),
"myecho".to_string()
])
);
}
#[tokio::test]
async fn test_run_with_profile_bashrc_not_exists() {
let shell = Shell::Bash(BashShell {
shell_path: "/bin/bash".to_string(),
bashrc_path: "/does/not/exist/.bashrc".to_string(),
});
let actual_cmd = shell.format_default_shell_invocation(vec!["myecho".to_string()]);
assert_eq!(
actual_cmd,
Some(vec![
"/bin/bash".to_string(),
"-lc".to_string(),
"myecho".to_string()
])
);
}
#[tokio::test]
async fn test_run_with_profile_bash_escaping_and_execution() {
let shell_path = "/bin/bash";
let cases = vec![
(
vec!["myecho"],
vec![shell_path, "-lc", "source BASHRC_PATH && (myecho)"],
Some("It works!\n"),
),
(
vec!["bash", "-lc", "echo 'single' \"double\""],
vec![
shell_path,
"-lc",
"source BASHRC_PATH && (echo 'single' \"double\")",
],
Some("single double\n"),
),
];
for (input, expected_cmd, expected_output) in cases {
use std::collections::HashMap;
use crate::exec::ExecParams;
use crate::exec::SandboxType;
use crate::exec::process_exec_tool_call;
use crate::protocol::SandboxPolicy;
let temp_home = tempfile::tempdir().unwrap();
let bashrc_path = temp_home.path().join(".bashrc");
std::fs::write(
&bashrc_path,
r#"
set -x
function myecho {
echo 'It works!'
}
"#,
)
.unwrap();
let shell = Shell::Bash(BashShell {
shell_path: shell_path.to_string(),
bashrc_path: bashrc_path.to_str().unwrap().to_string(),
});
let actual_cmd = shell
.format_default_shell_invocation(input.iter().map(ToString::to_string).collect());
let expected_cmd = expected_cmd
.iter()
.map(|s| s.replace("BASHRC_PATH", bashrc_path.to_str().unwrap()))
.collect();
assert_eq!(actual_cmd, Some(expected_cmd));
let output = process_exec_tool_call(
ExecParams {
command: actual_cmd.unwrap(),
cwd: PathBuf::from(temp_home.path()),
timeout_ms: None,
env: HashMap::from([(
"HOME".to_string(),
temp_home.path().to_str().unwrap().to_string(),
)]),
with_escalated_permissions: None,
justification: None,
},
SandboxType::None,
&SandboxPolicy::DangerFullAccess,
temp_home.path(),
&None,
None,
)
.await
.unwrap();
assert_eq!(output.exit_code, 0, "input: {input:?} output: {output:?}");
if let Some(expected) = expected_output {
assert_eq!(
output.stdout.text, expected,
"input: {input:?} output: {output:?}"
);
}
}
}
}
#[cfg(test)]
#[cfg(target_os = "macos")]
mod macos_tests {
use super::*;
use std::string::ToString;
#[tokio::test]
async fn test_run_with_profile_escaping_and_execution() {
let shell_path = "/bin/zsh";
let cases = vec![
(
vec!["myecho"],
vec![shell_path, "-lc", "source ZSHRC_PATH && (myecho)"],
Some("It works!\n"),
),
(
vec!["myecho"],
vec![shell_path, "-lc", "source ZSHRC_PATH && (myecho)"],
Some("It works!\n"),
),
(
vec!["bash", "-c", "echo 'single' \"double\""],
vec![
shell_path,
"-lc",
"source ZSHRC_PATH && (bash -c \"echo 'single' \\\"double\\\"\")",
],
Some("single double\n"),
),
(
vec!["bash", "-lc", "echo 'single' \"double\""],
vec![
shell_path,
"-lc",
"source ZSHRC_PATH && (echo 'single' \"double\")",
],
Some("single double\n"),
),
];
for (input, expected_cmd, expected_output) in cases {
use std::collections::HashMap;
use std::path::PathBuf;
use crate::exec::ExecParams;
use crate::exec::SandboxType;
use crate::exec::process_exec_tool_call;
use crate::protocol::SandboxPolicy;
// create a temp directory with a zshrc file in it
let temp_home = tempfile::tempdir().unwrap();
let zshrc_path = temp_home.path().join(".zshrc");
std::fs::write(
&zshrc_path,
r#"
set -x
function myecho {
echo 'It works!'
}
"#,
)
.unwrap();
let shell = Shell::Zsh(ZshShell {
shell_path: shell_path.to_string(),
zshrc_path: zshrc_path.to_str().unwrap().to_string(),
});
let actual_cmd = shell
.format_default_shell_invocation(input.iter().map(ToString::to_string).collect());
let expected_cmd = expected_cmd
.iter()
.map(|s| s.replace("ZSHRC_PATH", zshrc_path.to_str().unwrap()))
.collect();
assert_eq!(actual_cmd, Some(expected_cmd));
// Actually run the command and check output/exit code
let output = process_exec_tool_call(
ExecParams {
command: actual_cmd.unwrap(),
cwd: PathBuf::from(temp_home.path()),
timeout_ms: None,
env: HashMap::from([(
"HOME".to_string(),
temp_home.path().to_str().unwrap().to_string(),
)]),
with_escalated_permissions: None,
justification: None,
},
SandboxType::None,
&SandboxPolicy::DangerFullAccess,
temp_home.path(),
&None,
None,
)
.await
.unwrap();
assert_eq!(output.exit_code, 0, "input: {input:?} output: {output:?}");
if let Some(expected) = expected_output {
assert_eq!(
output.stdout.text, expected,
"input: {input:?} output: {output:?}"
);
}
}
}
}
#[cfg(test)]
#[cfg(target_os = "windows")]
mod tests_windows {
use super::*;
#[test]
fn test_format_default_shell_invocation_powershell() {
let cases = vec![
(
Shell::PowerShell(PowerShellConfig {
exe: "pwsh.exe".to_string(),
bash_exe_fallback: None,
}),
vec!["bash", "-lc", "echo hello"],
vec!["pwsh.exe", "-NoProfile", "-Command", "echo hello"],
),
(
Shell::PowerShell(PowerShellConfig {
exe: "powershell.exe".to_string(),
bash_exe_fallback: None,
}),
vec!["bash", "-lc", "echo hello"],
vec!["powershell.exe", "-NoProfile", "-Command", "echo hello"],
),
(
Shell::PowerShell(PowerShellConfig {
exe: "pwsh.exe".to_string(),
bash_exe_fallback: Some(PathBuf::from("bash.exe")),
}),
vec!["bash", "-lc", "echo hello"],
vec!["bash.exe", "-lc", "echo hello"],
),
(
Shell::PowerShell(PowerShellConfig {
exe: "pwsh.exe".to_string(),
bash_exe_fallback: Some(PathBuf::from("bash.exe")),
}),
vec![
"bash",
"-lc",
"apply_patch <<'EOF'\n*** Begin Patch\n*** Update File: destination_file.txt\n-original content\n+modified content\n*** End Patch\nEOF",
],
vec![
"bash.exe",
"-lc",
"apply_patch <<'EOF'\n*** Begin Patch\n*** Update File: destination_file.txt\n-original content\n+modified content\n*** End Patch\nEOF",
],
),
(
Shell::PowerShell(PowerShellConfig {
exe: "pwsh.exe".to_string(),
bash_exe_fallback: Some(PathBuf::from("bash.exe")),
}),
vec!["echo", "hello"],
vec!["pwsh.exe", "-NoProfile", "-Command", "echo hello"],
),
(
Shell::PowerShell(PowerShellConfig {
exe: "pwsh.exe".to_string(),
bash_exe_fallback: Some(PathBuf::from("bash.exe")),
}),
vec!["pwsh.exe", "-NoProfile", "-Command", "echo hello"],
vec!["pwsh.exe", "-NoProfile", "-Command", "echo hello"],
),
(
// TODO (CODEX_2900): Handle escaping newlines for powershell invocation.
Shell::PowerShell(PowerShellConfig {
exe: "powershell.exe".to_string(),
bash_exe_fallback: Some(PathBuf::from("bash.exe")),
}),
vec![
"codex-mcp-server.exe",
"--codex-run-as-apply-patch",
"*** Begin Patch\n*** Update File: C:\\Users\\person\\destination_file.txt\n-original content\n+modified content\n*** End Patch",
],
vec![
"codex-mcp-server.exe",
"--codex-run-as-apply-patch",
"*** Begin Patch\n*** Update File: C:\\Users\\person\\destination_file.txt\n-original content\n+modified content\n*** End Patch",
],
),
];
for (shell, input, expected_cmd) in cases {
let actual_cmd = shell
.format_default_shell_invocation(input.iter().map(|s| (*s).to_string()).collect());
assert_eq!(
actual_cmd,
Some(expected_cmd.iter().map(|s| (*s).to_string()).collect())
);
}
}
}
pub use codex_agent::shell::*;

View File

@@ -1,9 +1,6 @@
mod service;
mod session;
mod turn;
pub(crate) use service::SessionServices;
pub(crate) use session::SessionState;
pub(crate) use codex_agent::session_state::SessionState;
pub(crate) use turn::ActiveTurn;
pub(crate) use turn::RunningTask;
pub(crate) use turn::TaskKind;

View File

@@ -1,18 +0,0 @@
use crate::RolloutRecorder;
use crate::exec_command::ExecSessionManager;
use crate::mcp_connection_manager::McpConnectionManager;
use crate::unified_exec::UnifiedExecSessionManager;
use crate::user_notification::UserNotifier;
use std::path::PathBuf;
use tokio::sync::Mutex;
pub(crate) struct SessionServices {
pub(crate) mcp_connection_manager: McpConnectionManager,
pub(crate) session_manager: ExecSessionManager,
pub(crate) unified_exec_manager: UnifiedExecSessionManager,
pub(crate) notifier: UserNotifier,
pub(crate) rollout: Mutex<Option<RolloutRecorder>>,
pub(crate) codex_linux_sandbox_exe: Option<PathBuf>,
pub(crate) user_shell: crate::shell::Shell,
pub(crate) show_raw_agent_reasoning: bool,
}

View File

@@ -1,182 +1 @@
use base64::Engine;
use serde::Deserialize;
use serde::Serialize;
use thiserror::Error;
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Default)]
pub struct TokenData {
/// Flat info parsed from the JWT in auth.json.
#[serde(
deserialize_with = "deserialize_id_token",
serialize_with = "serialize_id_token"
)]
pub id_token: IdTokenInfo,
/// This is a JWT.
pub access_token: String,
pub refresh_token: String,
pub account_id: Option<String>,
}
/// Flat subset of useful claims in id_token from auth.json.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct IdTokenInfo {
pub email: Option<String>,
/// The ChatGPT subscription plan type
/// (e.g., "free", "plus", "pro", "business", "enterprise", "edu").
/// (Note: values may vary by backend.)
pub(crate) chatgpt_plan_type: Option<PlanType>,
pub raw_jwt: String,
}
impl IdTokenInfo {
pub fn get_chatgpt_plan_type(&self) -> Option<String> {
self.chatgpt_plan_type.as_ref().map(|t| match t {
PlanType::Known(plan) => format!("{plan:?}"),
PlanType::Unknown(s) => s.clone(),
})
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub(crate) enum PlanType {
Known(KnownPlan),
Unknown(String),
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub(crate) enum KnownPlan {
Free,
Plus,
Pro,
Team,
Business,
Enterprise,
Edu,
}
#[derive(Deserialize)]
struct IdClaims {
#[serde(default)]
email: Option<String>,
#[serde(rename = "https://api.openai.com/auth", default)]
auth: Option<AuthClaims>,
}
#[derive(Deserialize)]
struct AuthClaims {
#[serde(default)]
chatgpt_plan_type: Option<PlanType>,
}
#[derive(Debug, Error)]
pub enum IdTokenInfoError {
#[error("invalid ID token format")]
InvalidFormat,
#[error(transparent)]
Base64(#[from] base64::DecodeError),
#[error(transparent)]
Json(#[from] serde_json::Error),
}
pub fn parse_id_token(id_token: &str) -> Result<IdTokenInfo, IdTokenInfoError> {
// JWT format: header.payload.signature
let mut parts = id_token.split('.');
let (_header_b64, payload_b64, _sig_b64) = match (parts.next(), parts.next(), parts.next()) {
(Some(h), Some(p), Some(s)) if !h.is_empty() && !p.is_empty() && !s.is_empty() => (h, p, s),
_ => return Err(IdTokenInfoError::InvalidFormat),
};
let payload_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload_b64)?;
let claims: IdClaims = serde_json::from_slice(&payload_bytes)?;
Ok(IdTokenInfo {
email: claims.email,
chatgpt_plan_type: claims.auth.and_then(|a| a.chatgpt_plan_type),
raw_jwt: id_token.to_string(),
})
}
fn deserialize_id_token<'de, D>(deserializer: D) -> Result<IdTokenInfo, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
parse_id_token(&s).map_err(serde::de::Error::custom)
}
fn serialize_id_token<S>(id_token: &IdTokenInfo, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&id_token.raw_jwt)
}
#[cfg(test)]
mod tests {
use super::*;
use serde::Serialize;
#[test]
fn id_token_info_parses_email_and_plan() {
#[derive(Serialize)]
struct Header {
alg: &'static str,
typ: &'static str,
}
let header = Header {
alg: "none",
typ: "JWT",
};
let payload = serde_json::json!({
"email": "user@example.com",
"https://api.openai.com/auth": {
"chatgpt_plan_type": "pro"
}
});
fn b64url_no_pad(bytes: &[u8]) -> String {
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
}
let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
let signature_b64 = b64url_no_pad(b"sig");
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
let info = parse_id_token(&fake_jwt).expect("should parse");
assert_eq!(info.email.as_deref(), Some("user@example.com"));
assert_eq!(info.get_chatgpt_plan_type().as_deref(), Some("Pro"));
}
#[test]
fn id_token_info_handles_missing_fields() {
#[derive(Serialize)]
struct Header {
alg: &'static str,
typ: &'static str,
}
let header = Header {
alg: "none",
typ: "JWT",
};
let payload = serde_json::json!({ "sub": "123" });
fn b64url_no_pad(bytes: &[u8]) -> String {
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
}
let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
let signature_b64 = b64url_no_pad(b"sig");
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
let info = parse_id_token(&fake_jwt).expect("should parse");
assert!(info.email.is_none());
assert!(info.get_chatgpt_plan_type().is_none());
}
}
pub use codex_agent::token_data::*;

View File

@@ -1,5 +1,3 @@
use serde::Deserialize;
use serde::Serialize;
use std::collections::BTreeMap;
use crate::openai_tools::FreeformTool;
@@ -7,16 +5,10 @@ use crate::openai_tools::FreeformToolFormat;
use crate::openai_tools::JsonSchema;
use crate::openai_tools::OpenAiTool;
use crate::openai_tools::ResponsesApiTool;
pub use codex_agent::ApplyPatchToolType;
const APPLY_PATCH_LARK_GRAMMAR: &str = include_str!("tool_apply_patch.lark");
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum ApplyPatchToolType {
Freeform,
Function,
}
/// Returns a custom tool that can be used to edit files. Well-suited for GPT-5 models
/// https://platform.openai.com/docs/guides/function-calling#custom-tools
pub(crate) fn create_apply_patch_freeform_tool() -> OpenAiTool {

View File

@@ -1,180 +1 @@
//! Utilities for truncating large chunks of output while preserving a prefix
//! and suffix on UTF-8 boundaries.
/// Truncate the middle of a UTF-8 string to at most `max_bytes` bytes,
/// preserving the beginning and the end. Returns the possibly truncated
/// string and `Some(original_token_count)` (estimated at 4 bytes/token)
/// if truncation occurred; otherwise returns the original string and `None`.
pub(crate) fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>) {
if s.len() <= max_bytes {
return (s.to_string(), None);
}
let est_tokens = (s.len() as u64).div_ceil(4);
if max_bytes == 0 {
return (format!("{est_tokens} tokens truncated…"), Some(est_tokens));
}
fn truncate_on_boundary(input: &str, max_len: usize) -> &str {
if input.len() <= max_len {
return input;
}
let mut end = max_len;
while end > 0 && !input.is_char_boundary(end) {
end -= 1;
}
&input[..end]
}
fn pick_prefix_end(s: &str, left_budget: usize) -> usize {
if let Some(head) = s.get(..left_budget)
&& let Some(i) = head.rfind('\n')
{
return i + 1;
}
truncate_on_boundary(s, left_budget).len()
}
fn pick_suffix_start(s: &str, right_budget: usize) -> usize {
let start_tail = s.len().saturating_sub(right_budget);
if let Some(tail) = s.get(start_tail..)
&& let Some(i) = tail.find('\n')
{
return start_tail + i + 1;
}
let mut idx = start_tail.min(s.len());
while idx < s.len() && !s.is_char_boundary(idx) {
idx += 1;
}
idx
}
let mut guess_tokens = est_tokens;
for _ in 0..4 {
let marker = format!("{guess_tokens} tokens truncated…");
let marker_len = marker.len();
let keep_budget = max_bytes.saturating_sub(marker_len);
if keep_budget == 0 {
return (format!("{est_tokens} tokens truncated…"), Some(est_tokens));
}
let left_budget = keep_budget / 2;
let right_budget = keep_budget - left_budget;
let prefix_end = pick_prefix_end(s, left_budget);
let mut suffix_start = pick_suffix_start(s, right_budget);
if suffix_start < prefix_end {
suffix_start = prefix_end;
}
let kept_content_bytes = prefix_end + (s.len() - suffix_start);
let truncated_content_bytes = s.len().saturating_sub(kept_content_bytes);
let new_tokens = (truncated_content_bytes as u64).div_ceil(4);
if new_tokens == guess_tokens {
let mut out = String::with_capacity(marker_len + kept_content_bytes + 1);
out.push_str(&s[..prefix_end]);
out.push_str(&marker);
out.push('\n');
out.push_str(&s[suffix_start..]);
return (out, Some(est_tokens));
}
guess_tokens = new_tokens;
}
let marker = format!("{guess_tokens} tokens truncated…");
let marker_len = marker.len();
let keep_budget = max_bytes.saturating_sub(marker_len);
if keep_budget == 0 {
return (format!("{est_tokens} tokens truncated…"), Some(est_tokens));
}
let left_budget = keep_budget / 2;
let right_budget = keep_budget - left_budget;
let prefix_end = pick_prefix_end(s, left_budget);
let suffix_start = pick_suffix_start(s, right_budget);
let mut out = String::with_capacity(marker_len + prefix_end + (s.len() - suffix_start) + 1);
out.push_str(&s[..prefix_end]);
out.push_str(&marker);
out.push('\n');
out.push_str(&s[suffix_start..]);
(out, Some(est_tokens))
}
#[cfg(test)]
mod tests {
use super::truncate_middle;
#[test]
fn truncate_middle_no_newlines_fallback() {
let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ*";
let max_bytes = 32;
let (out, original) = truncate_middle(s, max_bytes);
assert!(out.starts_with("abc"));
assert!(out.contains("tokens truncated"));
assert!(out.ends_with("XYZ*"));
assert_eq!(original, Some((s.len() as u64).div_ceil(4)));
}
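// Illustrative sketch, not part of the diff: with a zero byte budget the whole
// string collapses to the marker, per the 4-bytes-per-token estimate
// (100 bytes → 25 tokens).
#[test]
fn truncate_middle_zero_budget_collapses_to_marker() {
    let s = "x".repeat(100);
    let (out, tokens) = truncate_middle(&s, 0);
    assert_eq!(out, "…25 tokens truncated…");
    assert_eq!(tokens, Some(25));
}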
#[test]
fn truncate_middle_prefers_newline_boundaries() {
let mut s = String::new();
for i in 1..=20 {
s.push_str(&format!("{i:03}\n"));
}
assert_eq!(s.len(), 80);
let max_bytes = 64;
let (out, tokens) = truncate_middle(&s, max_bytes);
assert!(out.starts_with("001\n002\n003\n004\n"));
assert!(out.contains("tokens truncated"));
assert!(out.ends_with("017\n018\n019\n020\n"));
assert_eq!(tokens, Some(20));
}
#[test]
fn truncate_middle_handles_utf8_content() {
let s = "😀😀😀😀😀😀😀😀😀😀\nsecond line with ascii text\n";
let max_bytes = 32;
let (out, tokens) = truncate_middle(s, max_bytes);
assert!(out.contains("tokens truncated"));
assert!(!out.contains('\u{fffd}'));
assert_eq!(tokens, Some((s.len() as u64).div_ceil(4)));
}
#[test]
fn truncate_middle_prefers_newline_boundaries_2() {
// Build a multi-line string of 20 numbered lines (each "NNN\n").
let mut s = String::new();
for i in 1..=20 {
s.push_str(&format!("{i:03}\n"));
}
// Total length: 20 lines * 4 bytes per line = 80 bytes.
assert_eq!(s.len(), 80);
// Choose a cap that forces truncation while leaving room for
// a few lines on each side after accounting for the marker.
let max_bytes = 64;
// Expect exact output: first 4 lines, marker, last 4 lines, and correct token estimate (80/4 = 20).
assert_eq!(
truncate_middle(&s, max_bytes),
(
r#"001
002
003
004
…12 tokens truncated…
017
018
019
020
"#
.to_string(),
Some(20)
)
);
}
}
pub use codex_agent::truncate::*;

View File

@@ -1,896 +1 @@
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use std::path::PathBuf;
use std::process::Command;
use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use sha1::digest::Output;
use uuid::Uuid;
use crate::protocol::FileChange;
const ZERO_OID: &str = "0000000000000000000000000000000000000000";
const DEV_NULL: &str = "/dev/null";
struct BaselineFileInfo {
path: PathBuf,
content: Vec<u8>,
mode: FileMode,
oid: String,
}
/// Tracks sets of changes to files and exposes the overall unified diff.
/// Internally, this works as follows:
/// 1. Maintain an in-memory baseline snapshot of each file when it is first seen.
/// For new additions, record an empty (ZERO_OID) baseline so that diffs are shown as proper additions (against /dev/null).
/// 2. Keep a stable internal filename (uuid) per external path for rename tracking.
/// 3. To compute the aggregated unified diff, compare each baseline snapshot to the current file on disk entirely in-memory
/// using the `similar` crate and emit unified diffs with rewritten external paths.
#[derive(Default)]
pub struct TurnDiffTracker {
/// Map external path -> internal filename (uuid).
external_to_temp_name: HashMap<PathBuf, String>,
/// Internal filename -> baseline file info.
baseline_file_info: HashMap<String, BaselineFileInfo>,
/// Internal filename -> external path as of current accumulated state (after applying all changes).
/// This is where renames are tracked.
temp_name_to_current_path: HashMap<String, PathBuf>,
/// Cache of known git worktree roots to avoid repeated filesystem walks.
git_root_cache: Vec<PathBuf>,
}
impl TurnDiffTracker {
pub fn new() -> Self {
Self::default()
}
/// Called ahead of apply_patch to record the starting contents of any modified files.
/// - Creates an in-memory baseline snapshot for files that already exist on disk when first seen.
/// - For additions, intentionally records an empty (ZERO_OID) baseline so that diffs are proper additions.
/// - Also updates internal mappings for move/rename events.
pub fn on_patch_begin(&mut self, changes: &HashMap<PathBuf, FileChange>) {
for (path, change) in changes.iter() {
// Ensure a stable internal filename exists for this external path.
if !self.external_to_temp_name.contains_key(path) {
let internal = Uuid::new_v4().to_string();
self.external_to_temp_name
.insert(path.clone(), internal.clone());
self.temp_name_to_current_path
.insert(internal.clone(), path.clone());
// If the file exists on disk now, snapshot it as the baseline; otherwise record an empty ZERO_OID baseline to represent /dev/null.
let baseline_file_info = if path.exists() {
let mode = file_mode_for_path(path);
let mode_val = mode.unwrap_or(FileMode::Regular);
let content = blob_bytes(path, mode_val).unwrap_or_default();
let oid = if mode == Some(FileMode::Symlink) {
format!("{:x}", git_blob_sha1_hex_bytes(&content))
} else {
self.git_blob_oid_for_path(path)
.unwrap_or_else(|| format!("{:x}", git_blob_sha1_hex_bytes(&content)))
};
Some(BaselineFileInfo {
path: path.clone(),
content,
mode: mode_val,
oid,
})
} else {
Some(BaselineFileInfo {
path: path.clone(),
content: vec![],
mode: FileMode::Regular,
oid: ZERO_OID.to_string(),
})
};
if let Some(baseline_file_info) = baseline_file_info {
self.baseline_file_info
.insert(internal.clone(), baseline_file_info);
}
}
// Track rename/move in current mapping if provided in an Update.
if let FileChange::Update {
move_path: Some(dest),
..
} = change
{
let uuid_filename = match self.external_to_temp_name.get(path) {
Some(i) => i.clone(),
None => {
// This should be rare, but if we haven't mapped the source, create it with an empty /dev/null baseline.
let i = Uuid::new_v4().to_string();
self.baseline_file_info.insert(
i.clone(),
BaselineFileInfo {
path: path.clone(),
content: vec![],
mode: FileMode::Regular,
oid: ZERO_OID.to_string(),
},
);
i
}
};
// Update current external mapping for temp file name.
self.temp_name_to_current_path
.insert(uuid_filename.clone(), dest.clone());
// Update forward file_mapping: external current -> internal name.
self.external_to_temp_name.remove(path);
self.external_to_temp_name
.insert(dest.clone(), uuid_filename);
};
}
}
fn get_path_for_internal(&self, internal: &str) -> Option<PathBuf> {
self.temp_name_to_current_path
.get(internal)
.cloned()
.or_else(|| {
self.baseline_file_info
.get(internal)
.map(|info| info.path.clone())
})
}
/// Find the git worktree root for a file/directory by walking up to the first ancestor containing a `.git` entry.
/// Keeps a small cache of known roots; negative results are intentionally not cached.
fn find_git_root_cached(&mut self, start: &Path) -> Option<PathBuf> {
let dir = if start.is_dir() {
start
} else {
start.parent()?
};
// Fast path: if any cached root is an ancestor of this path, use it.
if let Some(root) = self
.git_root_cache
.iter()
.find(|r| dir.starts_with(r))
.cloned()
{
return Some(root);
}
// Walk up to find a `.git` marker.
let mut cur = dir.to_path_buf();
loop {
let git_marker = cur.join(".git");
if git_marker.is_dir() || git_marker.is_file() {
if !self.git_root_cache.iter().any(|r| r == &cur) {
self.git_root_cache.push(cur.clone());
}
return Some(cur);
}
// On Windows, avoid walking above the drive or UNC share root.
#[cfg(windows)]
{
if is_windows_drive_or_unc_root(&cur) {
return None;
}
}
if let Some(parent) = cur.parent() {
cur = parent.to_path_buf();
} else {
return None;
}
}
}
/// Return a display string for `path` relative to its git root if found, else absolute.
fn relative_to_git_root_str(&mut self, path: &Path) -> String {
let s = if let Some(root) = self.find_git_root_cached(path) {
if let Ok(rel) = path.strip_prefix(&root) {
rel.display().to_string()
} else {
path.display().to_string()
}
} else {
path.display().to_string()
};
s.replace('\\', "/")
}
/// Ask git to compute the blob SHA-1 for the file at `path` within its repository.
/// Returns None if no repository is found or git invocation fails.
fn git_blob_oid_for_path(&mut self, path: &Path) -> Option<String> {
let root = self.find_git_root_cached(path)?;
// Compute a path relative to the repo root for better portability across platforms.
let rel = path.strip_prefix(&root).unwrap_or(path);
let output = Command::new("git")
.arg("-C")
.arg(&root)
.arg("hash-object")
.arg("--")
.arg(rel)
.output()
.ok()?;
if !output.status.success() {
return None;
}
let s = String::from_utf8_lossy(&output.stdout).trim().to_string();
if s.len() == 40 { Some(s) } else { None }
}
/// Recompute the aggregated unified diff by comparing each in-memory baseline snapshot
/// (captured before apply_patch first touched the file during this turn) with the
/// current state of the repo on disk.
pub fn get_unified_diff(&mut self) -> Result<Option<String>> {
let mut aggregated = String::new();
// Compute diffs per tracked internal file in a stable order by external path.
let mut baseline_file_names: Vec<String> =
self.baseline_file_info.keys().cloned().collect();
// Sort lexicographically by full repo-relative path to match git behavior.
baseline_file_names.sort_by_key(|internal| {
self.get_path_for_internal(internal)
.map(|p| self.relative_to_git_root_str(&p))
.unwrap_or_default()
});
for internal in baseline_file_names {
aggregated.push_str(self.get_file_diff(&internal).as_str());
if !aggregated.ends_with('\n') {
aggregated.push('\n');
}
}
if aggregated.trim().is_empty() {
Ok(None)
} else {
Ok(Some(aggregated))
}
}
fn get_file_diff(&mut self, internal_file_name: &str) -> String {
let mut aggregated = String::new();
// Snapshot lightweight fields only.
let (baseline_external_path, baseline_mode, left_oid) = {
if let Some(info) = self.baseline_file_info.get(internal_file_name) {
(info.path.clone(), info.mode, info.oid.clone())
} else {
(PathBuf::new(), FileMode::Regular, ZERO_OID.to_string())
}
};
let current_external_path = match self.get_path_for_internal(internal_file_name) {
Some(p) => p,
None => return aggregated,
};
let current_mode = file_mode_for_path(&current_external_path).unwrap_or(FileMode::Regular);
let right_bytes = blob_bytes(&current_external_path, current_mode);
// Compute displays with &mut self before borrowing any baseline content.
let left_display = self.relative_to_git_root_str(&baseline_external_path);
let right_display = self.relative_to_git_root_str(&current_external_path);
// Compute right oid before borrowing baseline content.
let right_oid = if let Some(b) = right_bytes.as_ref() {
if current_mode == FileMode::Symlink {
format!("{:x}", git_blob_sha1_hex_bytes(b))
} else {
self.git_blob_oid_for_path(&current_external_path)
.unwrap_or_else(|| format!("{:x}", git_blob_sha1_hex_bytes(b)))
}
} else {
ZERO_OID.to_string()
};
// Borrow baseline content only after all &mut self uses are done.
let left_present = left_oid.as_str() != ZERO_OID;
let left_bytes: Option<&[u8]> = if left_present {
self.baseline_file_info
.get(internal_file_name)
.map(|i| i.content.as_slice())
} else {
None
};
// Fast path: identical bytes or both missing.
if left_bytes == right_bytes.as_deref() {
return aggregated;
}
aggregated.push_str(&format!("diff --git a/{left_display} b/{right_display}\n"));
let is_add = !left_present && right_bytes.is_some();
let is_delete = left_present && right_bytes.is_none();
if is_add {
aggregated.push_str(&format!("new file mode {current_mode}\n"));
} else if is_delete {
aggregated.push_str(&format!("deleted file mode {baseline_mode}\n"));
} else if baseline_mode != current_mode {
aggregated.push_str(&format!("old mode {baseline_mode}\n"));
aggregated.push_str(&format!("new mode {current_mode}\n"));
}
let left_text = left_bytes.and_then(|b| std::str::from_utf8(b).ok());
let right_text = right_bytes
.as_deref()
.and_then(|b| std::str::from_utf8(b).ok());
let can_text_diff = matches!(
(left_text, right_text, is_add, is_delete),
(Some(_), Some(_), _, _) | (_, Some(_), true, _) | (Some(_), _, _, true)
);
if can_text_diff {
let l = left_text.unwrap_or("");
let r = right_text.unwrap_or("");
aggregated.push_str(&format!("index {left_oid}..{right_oid}\n"));
let old_header = if left_present {
format!("a/{left_display}")
} else {
DEV_NULL.to_string()
};
let new_header = if right_bytes.is_some() {
format!("b/{right_display}")
} else {
DEV_NULL.to_string()
};
let diff = similar::TextDiff::from_lines(l, r);
let unified = diff
.unified_diff()
.context_radius(3)
.header(&old_header, &new_header)
.to_string();
aggregated.push_str(&unified);
} else {
aggregated.push_str(&format!("index {left_oid}..{right_oid}\n"));
let old_header = if left_present {
format!("a/{left_display}")
} else {
DEV_NULL.to_string()
};
let new_header = if right_bytes.is_some() {
format!("b/{right_display}")
} else {
DEV_NULL.to_string()
};
aggregated.push_str(&format!("--- {old_header}\n"));
aggregated.push_str(&format!("+++ {new_header}\n"));
aggregated.push_str("Binary files differ\n");
}
aggregated
}
}
/// Compute the Git SHA-1 blob object ID for the given content (bytes).
fn git_blob_sha1_hex_bytes(data: &[u8]) -> Output<sha1::Sha1> {
// Git blob hash is sha1 of: "blob <len>\0<data>"
let header = format!("blob {}\0", data.len());
use sha1::Digest;
let mut hasher = sha1::Sha1::new();
hasher.update(header.as_bytes());
hasher.update(data);
hasher.finalize()
}
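// Illustrative sketch, not part of the diff: the computed OID should match
// `git hash-object`; for the content "foo\n" git reports the well-known blob id
// 257cc5642cb1a054f08cc83f2d943e56fd3ebe99.
#[cfg(test)]
mod blob_oid_example {
    use super::*;

    #[test]
    fn blob_oid_matches_git_hash_object() {
        let oid = format!("{:x}", git_blob_sha1_hex_bytes(b"foo\n"));
        assert_eq!(oid, "257cc5642cb1a054f08cc83f2d943e56fd3ebe99");
    }
}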
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum FileMode {
Regular,
#[cfg(unix)]
Executable,
Symlink,
}
impl FileMode {
fn as_str(self) -> &'static str {
match self {
FileMode::Regular => "100644",
#[cfg(unix)]
FileMode::Executable => "100755",
FileMode::Symlink => "120000",
}
}
}
impl std::fmt::Display for FileMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}
#[cfg(unix)]
fn file_mode_for_path(path: &Path) -> Option<FileMode> {
use std::os::unix::fs::PermissionsExt;
let meta = fs::symlink_metadata(path).ok()?;
let ft = meta.file_type();
if ft.is_symlink() {
return Some(FileMode::Symlink);
}
let mode = meta.permissions().mode();
let is_exec = (mode & 0o111) != 0;
Some(if is_exec {
FileMode::Executable
} else {
FileMode::Regular
})
}
#[cfg(not(unix))]
fn file_mode_for_path(_path: &Path) -> Option<FileMode> {
// Default to non-executable on non-unix.
Some(FileMode::Regular)
}
fn blob_bytes(path: &Path, mode: FileMode) -> Option<Vec<u8>> {
if path.exists() {
let contents = if mode == FileMode::Symlink {
symlink_blob_bytes(path)
.ok_or_else(|| anyhow!("failed to read symlink target for {}", path.display()))
} else {
fs::read(path)
.with_context(|| format!("failed to read current file for diff {}", path.display()))
};
contents.ok()
} else {
None
}
}
#[cfg(unix)]
fn symlink_blob_bytes(path: &Path) -> Option<Vec<u8>> {
use std::os::unix::ffi::OsStrExt;
let target = std::fs::read_link(path).ok()?;
Some(target.as_os_str().as_bytes().to_vec())
}
#[cfg(not(unix))]
fn symlink_blob_bytes(_path: &Path) -> Option<Vec<u8>> {
None
}
#[cfg(windows)]
fn is_windows_drive_or_unc_root(p: &std::path::Path) -> bool {
use std::path::Component;
let mut comps = p.components();
matches!(
(comps.next(), comps.next(), comps.next()),
(Some(Component::Prefix(_)), Some(Component::RootDir), None)
)
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
use tempfile::tempdir;
/// Compute the Git SHA-1 blob object ID for the given content (string).
/// This delegates to the bytes version to avoid UTF-8 lossy conversions here.
fn git_blob_sha1_hex(data: &str) -> String {
format!("{:x}", git_blob_sha1_hex_bytes(data.as_bytes()))
}
fn normalize_diff_for_test(input: &str, root: &Path) -> String {
let root_str = root.display().to_string().replace('\\', "/");
let replaced = input.replace(&root_str, "<TMP>");
// Split into blocks on lines starting with "diff --git ", sort blocks for determinism, and rejoin
let mut blocks: Vec<String> = Vec::new();
let mut current = String::new();
for line in replaced.lines() {
if line.starts_with("diff --git ") && !current.is_empty() {
blocks.push(current);
current = String::new();
}
if !current.is_empty() {
current.push('\n');
}
current.push_str(line);
}
if !current.is_empty() {
blocks.push(current);
}
blocks.sort();
let mut out = blocks.join("\n");
if !out.ends_with('\n') {
out.push('\n');
}
out
}
#[test]
fn accumulates_add_and_update() {
let mut acc = TurnDiffTracker::new();
let dir = tempdir().unwrap();
let file = dir.path().join("a.txt");
// First patch: add file (baseline should be /dev/null).
let add_changes = HashMap::from([(
file.clone(),
FileChange::Add {
content: "foo\n".to_string(),
},
)]);
acc.on_patch_begin(&add_changes);
// Simulate apply: create the file on disk.
fs::write(&file, "foo\n").unwrap();
let first = acc.get_unified_diff().unwrap().unwrap();
let first = normalize_diff_for_test(&first, dir.path());
let expected_first = {
let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular);
let right_oid = git_blob_sha1_hex("foo\n");
format!(
r#"diff --git a/<TMP>/a.txt b/<TMP>/a.txt
new file mode {mode}
index {ZERO_OID}..{right_oid}
--- {DEV_NULL}
+++ b/<TMP>/a.txt
@@ -0,0 +1 @@
+foo
"#,
)
};
assert_eq!(first, expected_first);
// Second patch: update the file on disk.
let update_changes = HashMap::from([(
file.clone(),
FileChange::Update {
unified_diff: "".to_owned(),
move_path: None,
},
)]);
acc.on_patch_begin(&update_changes);
// Simulate apply: append a new line.
fs::write(&file, "foo\nbar\n").unwrap();
let combined = acc.get_unified_diff().unwrap().unwrap();
let combined = normalize_diff_for_test(&combined, dir.path());
let expected_combined = {
let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular);
let right_oid = git_blob_sha1_hex("foo\nbar\n");
format!(
r#"diff --git a/<TMP>/a.txt b/<TMP>/a.txt
new file mode {mode}
index {ZERO_OID}..{right_oid}
--- {DEV_NULL}
+++ b/<TMP>/a.txt
@@ -0,0 +1,2 @@
+foo
+bar
"#,
)
};
assert_eq!(combined, expected_combined);
}
#[test]
fn accumulates_delete() {
let dir = tempdir().unwrap();
let file = dir.path().join("b.txt");
fs::write(&file, "x\n").unwrap();
let mut acc = TurnDiffTracker::new();
let del_changes = HashMap::from([(
file.clone(),
FileChange::Delete {
content: "x\n".to_string(),
},
)]);
acc.on_patch_begin(&del_changes);
// Simulate apply: delete the file from disk.
let baseline_mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular);
fs::remove_file(&file).unwrap();
let diff = acc.get_unified_diff().unwrap().unwrap();
let diff = normalize_diff_for_test(&diff, dir.path());
let expected = {
let left_oid = git_blob_sha1_hex("x\n");
format!(
r#"diff --git a/<TMP>/b.txt b/<TMP>/b.txt
deleted file mode {baseline_mode}
index {left_oid}..{ZERO_OID}
--- a/<TMP>/b.txt
+++ {DEV_NULL}
@@ -1 +0,0 @@
-x
"#,
)
};
assert_eq!(diff, expected);
}
#[test]
fn accumulates_move_and_update() {
let dir = tempdir().unwrap();
let src = dir.path().join("src.txt");
let dest = dir.path().join("dst.txt");
fs::write(&src, "line\n").unwrap();
let mut acc = TurnDiffTracker::new();
let mv_changes = HashMap::from([(
src.clone(),
FileChange::Update {
unified_diff: "".to_owned(),
move_path: Some(dest.clone()),
},
)]);
acc.on_patch_begin(&mv_changes);
// Simulate apply: move and update content.
fs::rename(&src, &dest).unwrap();
fs::write(&dest, "line2\n").unwrap();
let out = acc.get_unified_diff().unwrap().unwrap();
let out = normalize_diff_for_test(&out, dir.path());
let expected = {
let left_oid = git_blob_sha1_hex("line\n");
let right_oid = git_blob_sha1_hex("line2\n");
format!(
r#"diff --git a/<TMP>/src.txt b/<TMP>/dst.txt
index {left_oid}..{right_oid}
--- a/<TMP>/src.txt
+++ b/<TMP>/dst.txt
@@ -1 +1 @@
-line
+line2
"#
)
};
assert_eq!(out, expected);
}
#[test]
fn move_without_change_yields_no_diff() {
let dir = tempdir().unwrap();
let src = dir.path().join("moved.txt");
let dest = dir.path().join("renamed.txt");
fs::write(&src, "same\n").unwrap();
let mut acc = TurnDiffTracker::new();
let mv_changes = HashMap::from([(
src.clone(),
FileChange::Update {
unified_diff: "".to_owned(),
move_path: Some(dest.clone()),
},
)]);
acc.on_patch_begin(&mv_changes);
// Simulate apply: move only, no content change.
fs::rename(&src, &dest).unwrap();
let diff = acc.get_unified_diff().unwrap();
assert_eq!(diff, None);
}
#[test]
fn move_declared_but_file_only_appears_at_dest_is_add() {
let dir = tempdir().unwrap();
let src = dir.path().join("src.txt");
let dest = dir.path().join("dest.txt");
let mut acc = TurnDiffTracker::new();
let mv = HashMap::from([(
src,
FileChange::Update {
unified_diff: "".into(),
move_path: Some(dest.clone()),
},
)]);
acc.on_patch_begin(&mv);
// No file existed initially; create only dest
fs::write(&dest, "hello\n").unwrap();
let diff = acc.get_unified_diff().unwrap().unwrap();
let diff = normalize_diff_for_test(&diff, dir.path());
let expected = {
let mode = file_mode_for_path(&dest).unwrap_or(FileMode::Regular);
let right_oid = git_blob_sha1_hex("hello\n");
format!(
r#"diff --git a/<TMP>/src.txt b/<TMP>/dest.txt
new file mode {mode}
index {ZERO_OID}..{right_oid}
--- {DEV_NULL}
+++ b/<TMP>/dest.txt
@@ -0,0 +1 @@
+hello
"#,
)
};
assert_eq!(diff, expected);
}
#[test]
fn update_persists_across_new_baseline_for_new_file() {
let dir = tempdir().unwrap();
let a = dir.path().join("a.txt");
let b = dir.path().join("b.txt");
fs::write(&a, "foo\n").unwrap();
fs::write(&b, "z\n").unwrap();
let mut acc = TurnDiffTracker::new();
// First: update existing a.txt (baseline snapshot is created for a).
let update_a = HashMap::from([(
a.clone(),
FileChange::Update {
unified_diff: "".to_owned(),
move_path: None,
},
)]);
acc.on_patch_begin(&update_a);
// Simulate apply: modify a.txt on disk.
fs::write(&a, "foo\nbar\n").unwrap();
let first = acc.get_unified_diff().unwrap().unwrap();
let first = normalize_diff_for_test(&first, dir.path());
let expected_first = {
let left_oid = git_blob_sha1_hex("foo\n");
let right_oid = git_blob_sha1_hex("foo\nbar\n");
format!(
r#"diff --git a/<TMP>/a.txt b/<TMP>/a.txt
index {left_oid}..{right_oid}
--- a/<TMP>/a.txt
+++ b/<TMP>/a.txt
@@ -1 +1,2 @@
foo
+bar
"#
)
};
assert_eq!(first, expected_first);
// Next: introduce a brand-new path b.txt into baseline snapshots via a delete change.
let del_b = HashMap::from([(
b.clone(),
FileChange::Delete {
content: "z\n".to_string(),
},
)]);
acc.on_patch_begin(&del_b);
// Simulate apply: delete b.txt.
let baseline_mode = file_mode_for_path(&b).unwrap_or(FileMode::Regular);
fs::remove_file(&b).unwrap();
let combined = acc.get_unified_diff().unwrap().unwrap();
let combined = normalize_diff_for_test(&combined, dir.path());
let expected = {
let left_oid_a = git_blob_sha1_hex("foo\n");
let right_oid_a = git_blob_sha1_hex("foo\nbar\n");
let left_oid_b = git_blob_sha1_hex("z\n");
format!(
r#"diff --git a/<TMP>/a.txt b/<TMP>/a.txt
index {left_oid_a}..{right_oid_a}
--- a/<TMP>/a.txt
+++ b/<TMP>/a.txt
@@ -1 +1,2 @@
foo
+bar
diff --git a/<TMP>/b.txt b/<TMP>/b.txt
deleted file mode {baseline_mode}
index {left_oid_b}..{ZERO_OID}
--- a/<TMP>/b.txt
+++ {DEV_NULL}
@@ -1 +0,0 @@
-z
"#,
)
};
assert_eq!(combined, expected);
}
#[test]
fn binary_files_differ_update() {
let dir = tempdir().unwrap();
let file = dir.path().join("bin.dat");
// Initial non-UTF8 bytes
let left_bytes: Vec<u8> = vec![0xff, 0xfe, 0xfd, 0x00];
// Updated non-UTF8 bytes
let right_bytes: Vec<u8> = vec![0x01, 0x02, 0x03, 0x00];
fs::write(&file, &left_bytes).unwrap();
let mut acc = TurnDiffTracker::new();
let update_changes = HashMap::from([(
file.clone(),
FileChange::Update {
unified_diff: "".to_owned(),
move_path: None,
},
)]);
acc.on_patch_begin(&update_changes);
// Apply update on disk
fs::write(&file, &right_bytes).unwrap();
let diff = acc.get_unified_diff().unwrap().unwrap();
let diff = normalize_diff_for_test(&diff, dir.path());
let expected = {
let left_oid = format!("{:x}", git_blob_sha1_hex_bytes(&left_bytes));
let right_oid = format!("{:x}", git_blob_sha1_hex_bytes(&right_bytes));
format!(
r#"diff --git a/<TMP>/bin.dat b/<TMP>/bin.dat
index {left_oid}..{right_oid}
--- a/<TMP>/bin.dat
+++ b/<TMP>/bin.dat
Binary files differ
"#
)
};
assert_eq!(diff, expected);
}
#[test]
fn filenames_with_spaces_add_and_update() {
let mut acc = TurnDiffTracker::new();
let dir = tempdir().unwrap();
let file = dir.path().join("name with spaces.txt");
// First patch: add file (baseline should be /dev/null).
let add_changes = HashMap::from([(
file.clone(),
FileChange::Add {
content: "foo\n".to_string(),
},
)]);
acc.on_patch_begin(&add_changes);
// Simulate apply: create the file on disk.
fs::write(&file, "foo\n").unwrap();
let first = acc.get_unified_diff().unwrap().unwrap();
let first = normalize_diff_for_test(&first, dir.path());
let expected_first = {
let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular);
let right_oid = git_blob_sha1_hex("foo\n");
format!(
r#"diff --git a/<TMP>/name with spaces.txt b/<TMP>/name with spaces.txt
new file mode {mode}
index {ZERO_OID}..{right_oid}
--- {DEV_NULL}
+++ b/<TMP>/name with spaces.txt
@@ -0,0 +1 @@
+foo
"#,
)
};
assert_eq!(first, expected_first);
// Second patch: update the file on disk.
let update_changes = HashMap::from([(
file.clone(),
FileChange::Update {
unified_diff: "".to_owned(),
move_path: None,
},
)]);
acc.on_patch_begin(&update_changes);
// Simulate apply: append a new line with a space.
fs::write(&file, "foo\nbar baz\n").unwrap();
let combined = acc.get_unified_diff().unwrap().unwrap();
let combined = normalize_diff_for_test(&combined, dir.path());
let expected_combined = {
let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular);
let right_oid = git_blob_sha1_hex("foo\nbar baz\n");
format!(
r#"diff --git a/<TMP>/name with spaces.txt b/<TMP>/name with spaces.txt
new file mode {mode}
index {ZERO_OID}..{right_oid}
--- {DEV_NULL}
+++ b/<TMP>/name with spaces.txt
@@ -0,0 +1,2 @@
+foo
+bar baz
"#,
)
};
assert_eq!(combined, expected_combined);
}
}
pub use codex_agent::turn_diff_tracker::*;

View File

@@ -0,0 +1 @@
pub use codex_agent::unified_exec::*;

View File

@@ -1,7 +1,9 @@
use serde::Serialize;
use serde_json::to_string;
use tracing::error;
use tracing::warn;
pub use codex_agent::notifications::UserNotification;
#[derive(Debug, Default)]
pub(crate) struct UserNotifier {
notify_command: Option<Vec<String>>,
@@ -17,7 +19,7 @@ impl UserNotifier {
}
fn invoke_notify(&self, notify_command: &[String], notification: &UserNotification) {
let Ok(json) = serde_json::to_string(&notification) else {
let Ok(json) = to_string(notification) else {
error!("failed to serialise notification payload");
return;
};
@@ -28,7 +30,6 @@ impl UserNotifier {
}
command.arg(json);
// Fire-and-forget; we do not wait for completion.
if let Err(e) = command.spawn() {
warn!("failed to spawn notifier '{}': {e}", notify_command[0]);
}
@@ -40,44 +41,3 @@ impl UserNotifier {
}
}
}
/// The user can configure a program that will receive notifications. Each
/// notification is serialized as JSON and passed as an argument to the
/// program.
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub(crate) enum UserNotification {
#[serde(rename_all = "kebab-case")]
AgentTurnComplete {
turn_id: String,
/// Messages that the user sent to the agent to initiate the turn.
input_messages: Vec<String>,
/// The last message sent by the assistant in the turn.
last_assistant_message: Option<String>,
},
}
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
#[test]
fn test_user_notification() -> Result<()> {
let notification = UserNotification::AgentTurnComplete {
turn_id: "12345".to_string(),
input_messages: vec!["Rename `foo` to `bar` and update the callsites.".to_string()],
last_assistant_message: Some(
"Rename complete and verified `cargo build` succeeds.".to_string(),
),
};
let serialized = serde_json::to_string(&notification)?;
assert_eq!(
serialized,
r#"{"type":"agent-turn-complete","turn-id":"12345","input-messages":["Rename `foo` to `bar` and update the callsites."],"last-assistant-message":"Rename complete and verified `cargo build` succeeds."}"#
);
Ok(())
}
}

View File

@@ -1,5 +1,6 @@
use std::sync::Arc;
use codex_core::AgentConfig;
use codex_core::ContentItem;
use codex_core::LocalShellAction;
use codex_core::LocalShellExecAction;
@@ -69,9 +70,10 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
let effort = config.model_reasoning_effort;
let summary = config.model_reasoning_summary;
let config = Arc::new(config);
let agent_config = Arc::new(AgentConfig::from(config.as_ref()));
let client = ModelClient::new(
Arc::clone(&config),
Arc::clone(&agent_config),
None,
provider,
effort,

View File

@@ -1,5 +1,6 @@
use std::sync::Arc;
use codex_core::AgentConfig;
use codex_core::ContentItem;
use codex_core::ModelClient;
use codex_core::ModelProviderInfo;
@@ -62,9 +63,10 @@ async fn run_stream(sse_body: &str) -> Vec<ResponseEvent> {
let effort = config.model_reasoning_effort;
let summary = config.model_reasoning_summary;
let config = Arc::new(config);
let agent_config = Arc::new(AgentConfig::from(config.as_ref()));
let client = ModelClient::new(
Arc::clone(&config),
Arc::clone(&agent_config),
None,
provider,
effort,

View File

@@ -1,3 +1,4 @@
use codex_core::AgentConfig;
use codex_core::CodexAuth;
use codex_core::ContentItem;
use codex_core::ConversationManager;
@@ -661,9 +662,10 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
let effort = config.model_reasoning_effort;
let summary = config.model_reasoning_summary;
let config = Arc::new(config);
let agent_config = Arc::new(AgentConfig::from(config.as_ref()));
let client = ModelClient::new(
Arc::clone(&config),
Arc::clone(&agent_config),
None,
provider,
effort,

View File

@@ -12,6 +12,7 @@ mod fork_conversation;
mod json_result;
mod live_cli;
mod model_overrides;
mod multi_task_smoke;
mod prompt_caching;
mod review;
mod rollout_list_find;

View File

@@ -0,0 +1,115 @@
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::ReviewRequest;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use pretty_assertions::assert_eq;
/// Smoke test: ensure regular, review, and compact tasks can run sequentially in
/// a single session without leaving dangling state.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn regular_review_compact_sequence() {
skip_if_no_network!();
let server = start_mock_server().await;
let regular_body = sse(vec![
ev_assistant_message("regular-msg", "First turn complete."),
ev_completed("regular-resp"),
]);
let review_json = serde_json::json!({
"findings": [],
"overall_correctness": "ok",
"overall_explanation": "Looks good overall.",
"overall_confidence_score": 0.5
})
.to_string();
let review_body = sse(vec![
ev_assistant_message("review-msg", &review_json),
ev_completed("review-resp"),
]);
let compact_body = sse(vec![
ev_assistant_message("compact-msg", "Session summary."),
ev_completed("compact-resp"),
]);
mount_sse_sequence(&server, vec![regular_body, review_body, compact_body]).await;
let codex = test_codex().build(&server).await.unwrap().codex;
// 1. Regular user input turn.
codex
.submit(Op::UserInput {
items: vec![InputItem::Text {
text: "first turn".into(),
}],
})
.await
.unwrap();
let regular_complete =
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
let regular_last = match regular_complete {
EventMsg::TaskComplete(ev) => ev.last_agent_message,
other => panic!("expected TaskComplete for regular turn, got {other:?}"),
};
assert_eq!(regular_last.as_deref(), Some("First turn complete."));
// 2. Review task turn.
codex
.submit(Op::Review {
review_request: ReviewRequest {
prompt: "please review".to_string(),
user_facing_hint: "hint".to_string(),
},
})
.await
.unwrap();
let _entered = wait_for_event(&codex, |ev| matches!(ev, EventMsg::EnteredReviewMode(_))).await;
let exited = wait_for_event(&codex, |ev| matches!(ev, EventMsg::ExitedReviewMode(_))).await;
let review_output = match exited {
EventMsg::ExitedReviewMode(ev) => ev.review_output.expect("expected review output"),
other => panic!("expected ExitedReviewMode, got {other:?}"),
};
assert_eq!(review_output.overall_correctness, "ok");
assert_eq!(review_output.overall_explanation, "Looks good overall.");
assert!(review_output.findings.is_empty());
let review_complete =
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
let review_last = match review_complete {
EventMsg::TaskComplete(ev) => ev.last_agent_message,
other => panic!("expected TaskComplete for review, got {other:?}"),
}
.expect("review task should emit last_agent_message");
assert_eq!(review_last, review_json);
// 3. Manual compact task.
codex.submit(Op::Compact).await.unwrap();
let compact_complete =
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
let compact_last = match compact_complete {
EventMsg::TaskComplete(ev) => ev.last_agent_message,
other => panic!("expected TaskComplete for compact, got {other:?}"),
};
assert!(
compact_last.is_none(),
"compact task should not emit a trailing assistant message"
);
let requests = server.received_requests().await.unwrap();
assert_eq!(
requests.len(),
3,
"expected exactly three Responses API calls"
);
}

View File

@@ -0,0 +1,77 @@
# Agent Runtime Baseline
Codex currently exposes its agent runtime from the `codex-core` crate. The runtime is organised around three cooperating interfaces located in `core/src`:
- `Codex` (`core/src/codex.rs`) is the public façade that hosts use. It owns the submission/event queues and spawns the asynchronous runtime.
- `Session` (`core/src/codex.rs`) encapsulates conversation-scoped state and orchestrates task lifecycles once a session has been configured.
- `SessionTask` (`core/src/tasks/mod.rs`) is the trait implemented by the concrete task runners (`RegularTask`, `ReviewTask`, `CompactTask`).
The refactor tracked in `../agent_refactor.md` will extract these responsibilities into a dedicated `codex-agent` crate; this document captures the current layout before the extraction begins.
## `Codex`
`Codex` is the high-level queue API that front-ends interact with:
- `spawn` initialises the runtime with a `Config`, `AuthManager`, and `InitialHistory` and returns a `CodexSpawnOk` containing the queue endpoints and generated `ConversationId`.
- `submit` wraps an `Op` in a `Submission`, generates a unique id, and pushes it onto the bounded submission channel (`SUBMISSION_CHANNEL_CAPACITY = 64`).
- `next_event` pulls the next `Event` from the unbounded event receiver, propagating `CodexErr::InternalAgentDied` if the channel is closed.
Upon spawning, `Codex` constructs a `ConfigureSession` payload that gathers CLI-derived configuration (model details, approvals, sandbox policy, notify hooks, `cwd`) and uses `prepare_session_bootstrap` to assemble rollout/MCP/sandbox services before delegating to `Session::new`. The CLI now also constructs the initial `ModelClient`/`ToolsConfig` pair, wraps them in an `Arc<TurnContext>`, hands over the host-built `SessionServices`, and kicks off the background `submission_loop` task that drives the agent.
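Put together, a host drives the runtime roughly as in the sketch below. This is illustrative, not verified: the import paths for `Codex`/`CodexSpawnOk`/`InitialHistory`, the destructured fields, and the `event.msg` accessor are assumptions based on the descriptions above, while the `Op`/`EventMsg` shapes mirror the integration tests.
```rust
// Hedged sketch of a host-side loop over the queue API; see caveats above.
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;

async fn drive_agent(config: Config, auth_manager: AuthManager) -> anyhow::Result<()> {
    // `spawn` returns the queue endpoints plus the generated ConversationId.
    let CodexSpawnOk { codex, conversation_id, .. } =
        Codex::spawn(config, auth_manager, InitialHistory::New).await?;
    println!("configured conversation {conversation_id:?}");

    // `submit` wraps the Op in a Submission with a fresh id and pushes it
    // onto the bounded channel (SUBMISSION_CHANNEL_CAPACITY = 64).
    codex
        .submit(Op::UserInput {
            items: vec![InputItem::Text { text: "first turn".into() }],
        })
        .await?;

    // `next_event` yields events until the channel closes, at which point
    // it surfaces CodexErr::InternalAgentDied.
    loop {
        let event = codex.next_event().await?;
        if let EventMsg::TaskComplete(done) = event.msg {
            println!("last agent message: {:?}", done.last_agent_message);
            break;
        }
    }
    Ok(())
}
```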
## `Session`
`Session` bundles the pieces required to handle a configured conversation:
- Identifiers: `conversation_id` (originating from `InitialHistory`) and an internal `next_internal_sub_id` counter for action-specific ids.
- Communication: the `tx_event` sender used to emit `Event` messages back to the host.
- State holders: a `Mutex<SessionState>` for persistent session data and a `Mutex<Option<ActiveTurn>>` tracking the currently running tasks.
- Services: `SessionServices` (now defined in `codex-agent`) packages dependencies that are currently hard-wired to CLI types—`ExecSessionManager`, `UnifiedExecSessionManager`, `RolloutRecorder`, `McpConnectionManager`, etc.
`Session::new` now expects those services to be prepared by the host: `prepare_session_bootstrap` initialises rollout recording, MCP connections, default shell discovery, and history metadata in parallel before handing the assembled pieces to `Session::new`, which emits the initial `SessionConfigured` event and records any startup warnings so they can be surfaced after configuration.
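The parallel bootstrap can be pictured roughly as follows. Everything here other than `prepare_session_bootstrap`'s overall role is invented for illustration: the helper names (`init_rollout`, `connect_mcp`, `discover_default_shell`, `load_history_metadata`) and the `SessionServices` field names are hypothetical stand-ins.
```rust
// Illustrative only: the four helpers and the field names are invented;
// the point is the concurrent initialisation shape.
async fn bootstrap_sketch(config: &AgentConfig) -> SessionServices {
    let (rollout, mcp, shell, history) = tokio::join!(
        init_rollout(config),          // rollout recording
        connect_mcp(config),           // MCP connections
        discover_default_shell(),      // default shell discovery
        load_history_metadata(config), // history metadata
    );
    SessionServices { rollout, mcp, shell, history }
}
```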
Operationally, `Session` is responsible for:
- Translating incoming `Submission`s into task invocations via `run_task`, `run_turn`, and helper functions in `core/src/codex.rs` and `core/src/tasks`.
- Managing approvals and sandbox execution by calling into `sandbox::plan_*` helpers and dispatching events (`ExecApprovalRequestEvent`, `ApplyPatchApprovalRequestEvent`, etc.).
- Recording rollout items and forwarding MCP tool call updates.
- Tracking and cancelling running work when new input arrives or when approvals are rejected.
## `SessionTask`
Tasks implement the asynchronous work units that execute within a session. The trait lives in `core/src/tasks/mod.rs`:
```rust
#[async_trait]
pub(crate) trait SessionTask: Send + Sync + 'static {
fn kind(&self) -> TaskKind;
async fn run(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
) -> Option<String>;
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) { ... }
}
```
`SessionTaskContext` is a thin wrapper that hands tasks a clone of the `Session`, giving them access to helpers such as `send_event`, `plan_exec`, and `run_with_plan`. `Session::spawn_task` ensures only one task runs at a time by:
1. Calling `abort_all_tasks` to cancel the current `ActiveTurn`.
2. Wrapping the concrete task in `Arc<dyn SessionTask>`.
3. Spawning a Tokio task that awaits `run` and then reports completion through `Session::on_task_finished`.
`RunningTask`, `ActiveTurn`, and `TurnAbortReason` (from `core/src/state`) coordinate cancellation semantics and surface `TurnAborted`/`TaskComplete` events consistently.
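That single-task guarantee condenses to the following sketch. `TurnAbortReason::Replaced` and the accessor names (`task_context`, `turn_context`) are approximations of the real code, not verified signatures.
```rust
// Approximate shape of Session::spawn_task; helper names are not exact.
async fn spawn_task_sketch(session: Arc<Session>, task: impl SessionTask, sub_id: String) {
    // 1. Cancel the current ActiveTurn before starting new work.
    session.abort_all_tasks(TurnAbortReason::Replaced).await;

    // 2. Type-erase the runner so ActiveTurn can hold any task kind.
    let task: Arc<dyn SessionTask> = Arc::new(task);

    // 3. Run it in the background and report completion when done.
    let sess = Arc::clone(&session);
    tokio::spawn(async move {
        let last_message = task
            .run(sess.task_context(), sess.turn_context(), sub_id.clone(), Vec::new())
            .await;
        sess.on_task_finished(&sub_id, last_message).await;
    });
}
```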
Today the concrete implementations are:
- `RegularTask` (`core/src/tasks/regular.rs`) for the standard Codex workflow.
- `ReviewTask` (`core/src/tasks/review.rs`) used during review mode.
- `CompactTask` (`core/src/tasks/compact.rs`) which emits summarised history.
Each uses shared utilities in `core/src/codex.rs` (e.g., `run_task`, `exit_review_mode`, sandbox planners) and relies on the CLI-flavoured services packaged in `SessionServices`.
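For reference, a minimal, hypothetical implementation of the trait looks like this; `TaskKind::Regular` is an assumed variant, and the no-op body is a placeholder where a real task would delegate to the shared helpers.
```rust
// Hypothetical no-op runner showing the trait's shape; real tasks delegate
// to run_task, exit_review_mode, and the sandbox planners.
use std::sync::Arc;
use async_trait::async_trait;

struct NoopTask;

#[async_trait]
impl SessionTask for NoopTask {
    fn kind(&self) -> TaskKind {
        TaskKind::Regular // assumed variant
    }

    async fn run(
        self: Arc<Self>,
        _session: Arc<SessionTaskContext>,
        _ctx: Arc<TurnContext>,
        _sub_id: String,
        _input: Vec<InputItem>,
    ) -> Option<String> {
        // Returning None mirrors CompactTask, which emits no trailing
        // assistant message.
        None
    }
}
```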
## Next Steps
With this baseline documented, the next implementation steps are described in `../agent_refactor.md`. As we move work into the new `codex-agent` crate we should revisit this document to ensure the captured interfaces stay accurate and to outline any newly introduced abstractions (`AgentRuntime`, `AgentConfig`, service traits, etc.).

View File

@@ -25,9 +25,12 @@ async fn user_info_returns_email_from_auth_json() {
let codex_home = TempDir::new().expect("create tempdir");
let auth_path = get_auth_file(codex_home.path());
let mut id_token = IdTokenInfo::default();
id_token.email = Some("user@example.com".to_string());
id_token.raw_jwt = encode_id_token_with_email("user@example.com").expect("encode id token");
let raw_jwt = encode_id_token_with_email("user@example.com").expect("encode id token");
let id_token = IdTokenInfo {
email: Some("user@example.com".to_string()),
raw_jwt,
..IdTokenInfo::default()
};
let auth = AuthDotJson {
openai_api_key: None,

View File

@@ -590,8 +590,10 @@ mod tests {
AuthMode::ChatGPT => {
// Minimal valid JWT payload: header.payload.signature (all base64url, no padding)
const FAKE_JWT: &str = "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.e30.c2ln"; // {"alg":"none","typ":"JWT"}.{}."sig"
let mut id_info = IdTokenInfo::default();
id_info.raw_jwt = FAKE_JWT.to_string();
let id_info = IdTokenInfo {
raw_jwt: FAKE_JWT.to_string(),
..IdTokenInfo::default()
};
let auth = AuthDotJson {
openai_api_key: None,
tokens: Some(TokenData {

View File

@@ -2,6 +2,7 @@ use crate::exec_command::relativize_to_home;
use crate::text_formatting;
use chrono::DateTime;
use chrono::Local;
use codex_core::AgentConfig;
use codex_core::auth::get_auth_file;
use codex_core::auth::try_read_auth_json;
use codex_core::config::Config;
@@ -32,7 +33,8 @@ pub(crate) fn compose_model_display(
}
pub(crate) fn compose_agents_summary(config: &Config) -> String {
match discover_project_doc_paths(config) {
let agent_config = AgentConfig::from(config);
match discover_project_doc_paths(&agent_config) {
Ok(paths) => {
let mut rels: Vec<String> = Vec::new();
for p in paths {