mirror of
https://github.com/openai/codex.git
synced 2026-02-03 07:23:39 +00:00
Compare commits
82 Commits
tokencount
...
patch-tool
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
570639cf98 | ||
|
|
1c50fbb8a7 | ||
|
|
3316d04ed4 | ||
|
|
67a8566f59 | ||
|
|
2d36621f48 | ||
|
|
0a70810fc0 | ||
|
|
b5cf9e09ff | ||
|
|
b2067c73d9 | ||
|
|
13e8771ee9 | ||
|
|
bba567cee9 | ||
|
|
6577197fa4 | ||
|
|
fd1e12f34e | ||
|
|
ba6af23cb6 | ||
|
|
f805d17930 | ||
|
|
9580603fed | ||
|
|
da38a8f56a | ||
|
|
552a438cc9 | ||
|
|
a36a273d4e | ||
|
|
6884c6ccf6 | ||
|
|
1e5a613c55 | ||
|
|
90965fbc84 | ||
|
|
4fee2ca3fd | ||
|
|
3318cf9369 | ||
|
|
5ba0bcf035 | ||
|
|
6d55ef62f9 | ||
|
|
cecf3a82a6 | ||
|
|
c172e8e997 | ||
|
|
9bbeb75361 | ||
|
|
6ccd32c601 | ||
|
|
3b5a5412bb | ||
|
|
44bb53df1e | ||
|
|
9a7266a33f | ||
|
|
2abad8fece | ||
|
|
0d4a25b981 | ||
|
|
8453915e02 | ||
|
|
44587c2443 | ||
|
|
8f7b22b652 | ||
|
|
027944c64e | ||
|
|
bec51f6c05 | ||
|
|
66967500bb | ||
|
|
167b4f0e25 | ||
|
|
167154178b | ||
|
|
674e3d3c90 | ||
|
|
114ce9ff4d | ||
|
|
e13b35ecb0 | ||
|
|
377af75730 | ||
|
|
86e0f31a7e | ||
|
|
8f837f1093 | ||
|
|
162e1235a8 | ||
|
|
c09ed74a16 | ||
|
|
65f3528cad | ||
|
|
44262d8fd8 | ||
|
|
95a9938d3a | ||
|
|
f69f07b028 | ||
|
|
8d766088e6 | ||
|
|
87654ec0b7 | ||
|
|
51d9e05de7 | ||
|
|
8068cc75f8 | ||
|
|
acb28bf914 | ||
|
|
97338de578 | ||
|
|
5200b7a95d | ||
|
|
64e6c4afbb | ||
|
|
39db113cc9 | ||
|
|
45bd5ca4b9 | ||
|
|
c13c3dadbf | ||
|
|
8636bff46d | ||
|
|
43809a454e | ||
|
|
5c48600bb3 | ||
|
|
de6559f2ab | ||
|
|
5bcc9d8b77 | ||
|
|
5eab4c7ab4 | ||
|
|
f656e192bf | ||
|
|
ee5ecae7c0 | ||
|
|
58bb2048ac | ||
|
|
ac8a3155d6 | ||
|
|
ace14e8d36 | ||
|
|
2a76a08a9e | ||
|
|
16309d6b68 | ||
|
|
62bd0e3d9d | ||
|
|
a9c68ea270 | ||
|
|
ac58749bd3 | ||
|
|
79cbd2ab1b |
35
.github/workflows/rust-ci.yml
vendored
35
.github/workflows/rust-ci.yml
vendored
@@ -62,6 +62,26 @@ jobs:
|
||||
components: rustfmt
|
||||
- name: cargo fmt
|
||||
run: cargo fmt -- --config imports_granularity=Item --check
|
||||
- name: Verify codegen for mcp-types
|
||||
run: ./mcp-types/check_lib_rs.py
|
||||
|
||||
cargo_shear:
|
||||
name: cargo shear
|
||||
runs-on: ubuntu-24.04
|
||||
needs: changed
|
||||
if: ${{ needs.changed.outputs.codex == 'true' || needs.changed.outputs.workflows == 'true' || github.event_name == 'push' }}
|
||||
defaults:
|
||||
run:
|
||||
working-directory: codex-rs
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: dtolnay/rust-toolchain@1.89
|
||||
- uses: taiki-e/install-action@0c5db7f7f897c03b771660e91d065338615679f4 # v2
|
||||
with:
|
||||
tool: cargo-shear
|
||||
version: 1.5.1
|
||||
- name: cargo shear
|
||||
run: cargo shear
|
||||
|
||||
# --- CI to validate on different os/targets --------------------------------
|
||||
lint_build_test:
|
||||
@@ -160,12 +180,17 @@ jobs:
|
||||
find . -name Cargo.toml -mindepth 2 -maxdepth 2 -print0 \
|
||||
| xargs -0 -n1 -I{} bash -c 'cd "$(dirname "{}")" && cargo check --profile ${{ matrix.profile }}'
|
||||
|
||||
- name: cargo test
|
||||
- uses: taiki-e/install-action@0c5db7f7f897c03b771660e91d065338615679f4 # v2
|
||||
with:
|
||||
tool: nextest
|
||||
version: 0.9.103
|
||||
|
||||
- name: tests
|
||||
id: test
|
||||
# `cargo test` takes too long for release builds to run them on every PR
|
||||
# Tests take too long for release builds to run them on every PR.
|
||||
if: ${{ matrix.profile != 'release' }}
|
||||
continue-on-error: true
|
||||
run: cargo test --all-features --target ${{ matrix.target }} --profile ${{ matrix.profile }}
|
||||
run: cargo nextest run --all-features --no-fail-fast --target ${{ matrix.target }}
|
||||
env:
|
||||
RUST_BACKTRACE: 1
|
||||
|
||||
@@ -182,7 +207,7 @@ jobs:
|
||||
# --- Gatherer job that you mark as the ONLY required status -----------------
|
||||
results:
|
||||
name: CI results (required)
|
||||
needs: [changed, general, lint_build_test]
|
||||
needs: [changed, general, cargo_shear, lint_build_test]
|
||||
if: always()
|
||||
runs-on: ubuntu-24.04
|
||||
steps:
|
||||
@@ -190,6 +215,7 @@ jobs:
|
||||
shell: bash
|
||||
run: |
|
||||
echo "general: ${{ needs.general.result }}"
|
||||
echo "shear : ${{ needs.cargo_shear.result }}"
|
||||
echo "matrix : ${{ needs.lint_build_test.result }}"
|
||||
|
||||
# If nothing relevant changed (PR touching only root README, etc.),
|
||||
@@ -201,4 +227,5 @@ jobs:
|
||||
|
||||
# Otherwise require the jobs to have succeeded
|
||||
[[ '${{ needs.general.result }}' == 'success' ]] || { echo 'general failed'; exit 1; }
|
||||
[[ '${{ needs.cargo_shear.result }}' == 'success' ]] || { echo 'cargo_shear failed'; exit 1; }
|
||||
[[ '${{ needs.lint_build_test.result }}' == 'success' ]] || { echo 'matrix failed'; exit 1; }
|
||||
|
||||
19
.github/workflows/rust-release.yml
vendored
19
.github/workflows/rust-release.yml
vendored
@@ -219,3 +219,22 @@ jobs:
|
||||
with:
|
||||
tag: ${{ github.ref_name }}
|
||||
config: .github/dotslash-config.json
|
||||
|
||||
update-branch:
|
||||
name: Update latest-alpha-cli branch
|
||||
permissions:
|
||||
contents: write
|
||||
needs: release
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Update latest-alpha-cli branch
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
gh api \
|
||||
repos/${GITHUB_REPOSITORY}/git/refs/heads/latest-alpha-cli \
|
||||
-X PATCH \
|
||||
-f sha="${GITHUB_SHA}" \
|
||||
-F force=true
|
||||
|
||||
@@ -2,7 +2,10 @@
|
||||
|
||||
<p align="center"><code>npm i -g @openai/codex</code><br />or <code>brew install codex</code></p>
|
||||
|
||||
<p align="center"><strong>Codex CLI</strong> is a coding agent from OpenAI that runs locally on your computer.</br>If you are looking for the <em>cloud-based agent</em> from OpenAI, <strong>Codex Web</strong>, see <a href="https://chatgpt.com/codex">chatgpt.com/codex</a>.</p>
|
||||
<p align="center"><strong>Codex CLI</strong> is a coding agent from OpenAI that runs locally on your computer.
|
||||
</br>
|
||||
</br>If you want Codex in your code editor (VS Code, Cursor, Windsurf), <a href="https://developers.openai.com/codex/ide">install in your IDE</a>
|
||||
</br>If you are looking for the <em>cloud-based agent</em> from OpenAI, <strong>Codex Web</strong>, go to <a href="https://chatgpt.com/codex">chatgpt.com/codex</a></p>
|
||||
|
||||
<p align="center">
|
||||
<img src="./.github/codex-cli-splash.png" alt="Codex CLI splash" width="80%" />
|
||||
|
||||
816
codex-rs/Cargo.lock
generated
816
codex-rs/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -34,6 +34,7 @@ rust = {}
|
||||
|
||||
[workspace.lints.clippy]
|
||||
expect_used = "deny"
|
||||
redundant_clone = "deny"
|
||||
uninlined_format_args = "deny"
|
||||
unwrap_used = "deny"
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ npx @modelcontextprotocol/inspector codex mcp
|
||||
|
||||
You can enable notifications by configuring a script that is run whenever the agent finishes a turn. The [notify documentation](../docs/config.md#notify) includes a detailed example that explains how to get desktop notifications via [terminal-notifier](https://github.com/julienXX/terminal-notifier) on macOS.
|
||||
|
||||
### `codex exec` to run Codex programmatially/non-interactively
|
||||
### `codex exec` to run Codex programmatically/non-interactively
|
||||
|
||||
To run Codex non-interactively, run `codex exec PROMPT` (you can also pass the prompt via `stdin`) and Codex will work on your task until it decides that it is done and exits. Output is printed to the terminal directly. You can set the `RUST_LOG` environment variable to see more about what's going on.
|
||||
|
||||
|
||||
@@ -726,13 +726,15 @@ fn compute_replacements(
|
||||
line_index = start_idx + pattern.len();
|
||||
} else {
|
||||
return Err(ApplyPatchError::ComputeReplacements(format!(
|
||||
"Failed to find expected lines {:?} in {}",
|
||||
chunk.old_lines,
|
||||
path.display()
|
||||
"Failed to find expected lines in {}:\n{}",
|
||||
path.display(),
|
||||
chunk.old_lines.join("\n"),
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
replacements.sort_by(|(lhs_idx, _, _), (rhs_idx, _, _)| lhs_idx.cmp(rhs_idx));
|
||||
|
||||
Ok(replacements)
|
||||
}
|
||||
|
||||
@@ -1216,6 +1218,33 @@ PATCH"#,
|
||||
assert_eq!(contents, "a\nB\nc\nd\nE\nf\ng\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pure_addition_chunk_followed_by_removal() {
|
||||
let dir = tempdir().unwrap();
|
||||
let path = dir.path().join("panic.txt");
|
||||
fs::write(&path, "line1\nline2\nline3\n").unwrap();
|
||||
let patch = wrap_patch(&format!(
|
||||
r#"*** Update File: {}
|
||||
@@
|
||||
+after-context
|
||||
+second-line
|
||||
@@
|
||||
line1
|
||||
-line2
|
||||
-line3
|
||||
+line2-replacement"#,
|
||||
path.display()
|
||||
));
|
||||
let mut stdout = Vec::new();
|
||||
let mut stderr = Vec::new();
|
||||
apply_patch(&patch, &mut stdout, &mut stderr).unwrap();
|
||||
let contents = fs::read_to_string(path).unwrap();
|
||||
assert_eq!(
|
||||
contents,
|
||||
"line1\nline2-replacement\nafter-context\nsecond-line\n"
|
||||
);
|
||||
}
|
||||
|
||||
/// Ensure that patches authored with ASCII characters can update lines that
|
||||
/// contain typographic Unicode punctuation (e.g. EN DASH, NON-BREAKING
|
||||
/// HYPHEN). Historically `git apply` succeeds in such scenarios but our
|
||||
|
||||
@@ -617,7 +617,7 @@ fn test_parse_patch_lenient() {
|
||||
assert_eq!(
|
||||
parse_patch_text(&patch_text_in_double_quoted_heredoc, ParseMode::Lenient),
|
||||
Ok(ApplyPatchArgs {
|
||||
hunks: expected_patch.clone(),
|
||||
hunks: expected_patch,
|
||||
patch: patch_text.to_string(),
|
||||
workdir: None,
|
||||
})
|
||||
@@ -637,7 +637,7 @@ fn test_parse_patch_lenient() {
|
||||
"<<EOF\n*** Begin Patch\n*** Update File: file2.py\nEOF\n".to_string();
|
||||
assert_eq!(
|
||||
parse_patch_text(&patch_text_with_missing_closing_heredoc, ParseMode::Strict),
|
||||
Err(expected_error.clone())
|
||||
Err(expected_error)
|
||||
);
|
||||
assert_eq!(
|
||||
parse_patch_text(&patch_text_with_missing_closing_heredoc, ParseMode::Lenient),
|
||||
|
||||
@@ -11,8 +11,6 @@ anyhow = "1"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
codex-common = { path = "../common", features = ["cli"] }
|
||||
codex-core = { path = "../core" }
|
||||
codex-protocol = { path = "../protocol" }
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
|
||||
@@ -31,7 +31,7 @@ pub async fn run_apply_command(
|
||||
ConfigOverrides::default(),
|
||||
)?;
|
||||
|
||||
init_chatgpt_token_from_auth(&config.codex_home, &config.responses_originator_header).await?;
|
||||
init_chatgpt_token_from_auth(&config.codex_home).await?;
|
||||
|
||||
let task_response = get_task(&config, apply_cli.task_id).await?;
|
||||
apply_diff_from_task(task_response, cwd).await
|
||||
|
||||
@@ -13,10 +13,10 @@ pub(crate) async fn chatgpt_get_request<T: DeserializeOwned>(
|
||||
path: String,
|
||||
) -> anyhow::Result<T> {
|
||||
let chatgpt_base_url = &config.chatgpt_base_url;
|
||||
init_chatgpt_token_from_auth(&config.codex_home, &config.responses_originator_header).await?;
|
||||
init_chatgpt_token_from_auth(&config.codex_home).await?;
|
||||
|
||||
// Make direct HTTP request to ChatGPT backend API with the token
|
||||
let client = create_client(&config.responses_originator_header);
|
||||
let client = create_client();
|
||||
let url = format!("{chatgpt_base_url}{path}");
|
||||
|
||||
let token =
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use codex_core::CodexAuth;
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
use std::path::Path;
|
||||
use std::sync::LazyLock;
|
||||
use std::sync::RwLock;
|
||||
@@ -19,11 +18,8 @@ pub fn set_chatgpt_token_data(value: TokenData) {
|
||||
}
|
||||
|
||||
/// Initialize the ChatGPT token from auth.json file
|
||||
pub async fn init_chatgpt_token_from_auth(
|
||||
codex_home: &Path,
|
||||
originator: &str,
|
||||
) -> std::io::Result<()> {
|
||||
let auth = CodexAuth::from_codex_home(codex_home, AuthMode::ChatGPT, originator)?;
|
||||
pub async fn init_chatgpt_token_from_auth(codex_home: &Path) -> std::io::Result<()> {
|
||||
let auth = CodexAuth::from_codex_home(codex_home)?;
|
||||
if let Some(auth) = auth {
|
||||
let token_data = auth.get_token_data().await?;
|
||||
set_chatgpt_token_data(token_data);
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use codex_common::CliConfigOverrides;
|
||||
use codex_core::CodexAuth;
|
||||
use codex_core::auth::CLIENT_ID;
|
||||
use codex_core::auth::OPENAI_API_KEY_ENV_VAR;
|
||||
use codex_core::auth::login_with_api_key;
|
||||
use codex_core::auth::logout;
|
||||
use codex_core::config::Config;
|
||||
@@ -9,11 +8,10 @@ use codex_core::config::ConfigOverrides;
|
||||
use codex_login::ServerOptions;
|
||||
use codex_login::run_login_server;
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
use std::env;
|
||||
use std::path::PathBuf;
|
||||
|
||||
pub async fn login_with_chatgpt(codex_home: PathBuf, originator: String) -> std::io::Result<()> {
|
||||
let opts = ServerOptions::new(codex_home, CLIENT_ID.to_string(), originator);
|
||||
pub async fn login_with_chatgpt(codex_home: PathBuf) -> std::io::Result<()> {
|
||||
let opts = ServerOptions::new(codex_home, CLIENT_ID.to_string());
|
||||
let server = run_login_server(opts)?;
|
||||
|
||||
eprintln!(
|
||||
@@ -27,12 +25,7 @@ pub async fn login_with_chatgpt(codex_home: PathBuf, originator: String) -> std:
|
||||
pub async fn run_login_with_chatgpt(cli_config_overrides: CliConfigOverrides) -> ! {
|
||||
let config = load_config_or_exit(cli_config_overrides);
|
||||
|
||||
match login_with_chatgpt(
|
||||
config.codex_home,
|
||||
config.responses_originator_header.clone(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
match login_with_chatgpt(config.codex_home).await {
|
||||
Ok(_) => {
|
||||
eprintln!("Successfully logged in");
|
||||
std::process::exit(0);
|
||||
@@ -65,23 +58,11 @@ pub async fn run_login_with_api_key(
|
||||
pub async fn run_login_status(cli_config_overrides: CliConfigOverrides) -> ! {
|
||||
let config = load_config_or_exit(cli_config_overrides);
|
||||
|
||||
match CodexAuth::from_codex_home(
|
||||
&config.codex_home,
|
||||
config.preferred_auth_method,
|
||||
&config.responses_originator_header,
|
||||
) {
|
||||
match CodexAuth::from_codex_home(&config.codex_home) {
|
||||
Ok(Some(auth)) => match auth.mode {
|
||||
AuthMode::ApiKey => match auth.get_token().await {
|
||||
Ok(api_key) => {
|
||||
eprintln!("Logged in using an API key - {}", safe_format_key(&api_key));
|
||||
|
||||
if let Ok(env_api_key) = env::var(OPENAI_API_KEY_ENV_VAR)
|
||||
&& env_api_key == api_key
|
||||
{
|
||||
eprintln!(
|
||||
" API loaded from OPENAI_API_KEY environment variable or .env file"
|
||||
);
|
||||
}
|
||||
std::process::exit(0);
|
||||
}
|
||||
Err(e) => {
|
||||
|
||||
@@ -37,11 +37,8 @@ pub async fn run_main(opts: ProtoCli) -> anyhow::Result<()> {
|
||||
|
||||
let config = Config::load_with_cli_overrides(overrides_vec, ConfigOverrides::default())?;
|
||||
// Use conversation_manager API to start a conversation
|
||||
let conversation_manager = ConversationManager::new(AuthManager::shared(
|
||||
config.codex_home.clone(),
|
||||
config.preferred_auth_method,
|
||||
config.responses_originator_header.clone(),
|
||||
));
|
||||
let conversation_manager =
|
||||
ConversationManager::new(AuthManager::shared(config.codex_home.clone()));
|
||||
let NewConversation {
|
||||
conversation_id: _,
|
||||
conversation,
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
/// Returns a string representing the elapsed time since `start_time` like
|
||||
/// "1m15s" or "1.50s".
|
||||
/// "1m 15s" or "1.50s".
|
||||
pub fn format_elapsed(start_time: Instant) -> String {
|
||||
format_duration(start_time.elapsed())
|
||||
}
|
||||
@@ -12,7 +12,7 @@ pub fn format_elapsed(start_time: Instant) -> String {
|
||||
/// Formatting rules:
|
||||
/// * < 1 s -> "{milli}ms"
|
||||
/// * < 60 s -> "{sec:.2}s" (two decimal places)
|
||||
/// * >= 60 s -> "{min}m{sec:02}s"
|
||||
/// * >= 60 s -> "{min}m {sec:02}s"
|
||||
pub fn format_duration(duration: Duration) -> String {
|
||||
let millis = duration.as_millis() as i64;
|
||||
format_elapsed_millis(millis)
|
||||
@@ -26,7 +26,7 @@ fn format_elapsed_millis(millis: i64) -> String {
|
||||
} else {
|
||||
let minutes = millis / 60_000;
|
||||
let seconds = (millis % 60_000) / 1000;
|
||||
format!("{minutes}m{seconds:02}s")
|
||||
format!("{minutes}m {seconds:02}s")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,12 +61,18 @@ mod tests {
|
||||
fn test_format_duration_minutes() {
|
||||
// Durations ≥ 1 minute should be printed mmss.
|
||||
let dur = Duration::from_millis(75_000); // 1m15s
|
||||
assert_eq!(format_duration(dur), "1m15s");
|
||||
assert_eq!(format_duration(dur), "1m 15s");
|
||||
|
||||
let dur_exact = Duration::from_millis(60_000); // 1m0s
|
||||
assert_eq!(format_duration(dur_exact), "1m00s");
|
||||
assert_eq!(format_duration(dur_exact), "1m 00s");
|
||||
|
||||
let dur_long = Duration::from_millis(3_601_000);
|
||||
assert_eq!(format_duration(dur_long), "60m01s");
|
||||
assert_eq!(format_duration(dur_long), "60m 01s");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_duration_one_hour_has_space() {
|
||||
let dur_hour = Duration::from_millis(3_600_000);
|
||||
assert_eq!(format_duration(dur_hour), "60m 00s");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,6 +49,13 @@ pub fn builtin_model_presets() -> &'static [ModelPreset] {
|
||||
model: "gpt-5",
|
||||
effort: ReasoningEffort::High,
|
||||
},
|
||||
ModelPreset {
|
||||
id: "gpt-5-high-new",
|
||||
label: "gpt-5 high new",
|
||||
description: "— our latest release tuned to rely on the model's built-in reasoning defaults",
|
||||
model: "gpt-5-high-new",
|
||||
effort: ReasoningEffort::Medium,
|
||||
},
|
||||
];
|
||||
PRESETS
|
||||
}
|
||||
|
||||
@@ -26,14 +26,12 @@ eventsource-stream = "0.2.3"
|
||||
futures = "0.3"
|
||||
libc = "0.2.175"
|
||||
mcp-types = { path = "../mcp-types" }
|
||||
mime_guess = "2.0"
|
||||
os_info = "3.12.0"
|
||||
portable-pty = "0.9.0"
|
||||
rand = "0.9"
|
||||
regex-lite = "0.1.7"
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_bytes = "0.11"
|
||||
serde_json = "1"
|
||||
sha1 = "0.10.6"
|
||||
shlex = "1.3.0"
|
||||
@@ -56,7 +54,7 @@ tracing = { version = "0.1.41", features = ["log"] }
|
||||
tree-sitter = "0.25.9"
|
||||
tree-sitter-bash = "0.25.0"
|
||||
uuid = { version = "1", features = ["serde", "v4"] }
|
||||
whoami = "1.6.1"
|
||||
which = "6"
|
||||
wildmatch = "2.4.0"
|
||||
|
||||
|
||||
@@ -72,9 +70,6 @@ openssl-sys = { version = "*", features = ["vendored"] }
|
||||
[target.aarch64-unknown-linux-musl.dependencies]
|
||||
openssl-sys = { version = "*", features = ["vendored"] }
|
||||
|
||||
[target.'cfg(target_os = "windows")'.dependencies]
|
||||
which = "6"
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = "2"
|
||||
core_test_support = { path = "tests/common" }
|
||||
@@ -85,3 +80,6 @@ tempfile = "3"
|
||||
tokio-test = "0.4"
|
||||
walkdir = "2.5.0"
|
||||
wiremock = "0.6"
|
||||
|
||||
[package.metadata.cargo-shear]
|
||||
ignored = ["openssl-sys"]
|
||||
|
||||
@@ -17,6 +17,7 @@ use std::time::Duration;
|
||||
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
|
||||
use crate::token_data::PlanType;
|
||||
use crate::token_data::TokenData;
|
||||
use crate::token_data::parse_id_token;
|
||||
|
||||
@@ -70,14 +71,9 @@ impl CodexAuth {
|
||||
Ok(access)
|
||||
}
|
||||
|
||||
/// Loads the available auth information from the auth.json or
|
||||
/// OPENAI_API_KEY environment variable.
|
||||
pub fn from_codex_home(
|
||||
codex_home: &Path,
|
||||
preferred_auth_method: AuthMode,
|
||||
originator: &str,
|
||||
) -> std::io::Result<Option<CodexAuth>> {
|
||||
load_auth(codex_home, true, preferred_auth_method, originator)
|
||||
/// Loads the available auth information from the auth.json.
|
||||
pub fn from_codex_home(codex_home: &Path) -> std::io::Result<Option<CodexAuth>> {
|
||||
load_auth(codex_home)
|
||||
}
|
||||
|
||||
pub async fn get_token_data(&self) -> Result<TokenData, std::io::Error> {
|
||||
@@ -136,13 +132,12 @@ impl CodexAuth {
|
||||
}
|
||||
|
||||
pub fn get_account_id(&self) -> Option<String> {
|
||||
self.get_current_token_data()
|
||||
.and_then(|t| t.account_id.clone())
|
||||
self.get_current_token_data().and_then(|t| t.account_id)
|
||||
}
|
||||
|
||||
pub fn get_plan_type(&self) -> Option<String> {
|
||||
pub(crate) fn get_plan_type(&self) -> Option<PlanType> {
|
||||
self.get_current_token_data()
|
||||
.and_then(|t| t.id_token.chatgpt_plan_type.as_ref().map(|p| p.as_string()))
|
||||
.and_then(|t| t.id_token.chatgpt_plan_type)
|
||||
}
|
||||
|
||||
fn get_current_auth_json(&self) -> Option<AuthDotJson> {
|
||||
@@ -151,7 +146,7 @@ impl CodexAuth {
|
||||
}
|
||||
|
||||
fn get_current_token_data(&self) -> Option<TokenData> {
|
||||
self.get_current_auth_json().and_then(|t| t.tokens.clone())
|
||||
self.get_current_auth_json().and_then(|t| t.tokens)
|
||||
}
|
||||
|
||||
/// Consider this private to integration tests.
|
||||
@@ -173,7 +168,7 @@ impl CodexAuth {
|
||||
mode: AuthMode::ChatGPT,
|
||||
auth_file: PathBuf::new(),
|
||||
auth_dot_json,
|
||||
client: crate::default_client::create_client("codex_cli_rs"),
|
||||
client: crate::default_client::create_client(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,19 +183,17 @@ impl CodexAuth {
|
||||
}
|
||||
|
||||
pub fn from_api_key(api_key: &str) -> Self {
|
||||
Self::from_api_key_with_client(
|
||||
api_key,
|
||||
crate::default_client::create_client(crate::default_client::DEFAULT_ORIGINATOR),
|
||||
)
|
||||
Self::from_api_key_with_client(api_key, crate::default_client::create_client())
|
||||
}
|
||||
}
|
||||
|
||||
pub const OPENAI_API_KEY_ENV_VAR: &str = "OPENAI_API_KEY";
|
||||
|
||||
fn read_openai_api_key_from_env() -> Option<String> {
|
||||
pub fn read_openai_api_key_from_env() -> Option<String> {
|
||||
env::var(OPENAI_API_KEY_ENV_VAR)
|
||||
.ok()
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(|value| value.trim().to_string())
|
||||
.filter(|value| !value.is_empty())
|
||||
}
|
||||
|
||||
pub fn get_auth_file(codex_home: &Path) -> PathBuf {
|
||||
@@ -218,7 +211,7 @@ pub fn logout(codex_home: &Path) -> std::io::Result<bool> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Writes an `auth.json` that contains only the API key. Intended for CLI use.
|
||||
/// Writes an `auth.json` that contains only the API key.
|
||||
pub fn login_with_api_key(codex_home: &Path, api_key: &str) -> std::io::Result<()> {
|
||||
let auth_dot_json = AuthDotJson {
|
||||
openai_api_key: Some(api_key.to_string()),
|
||||
@@ -228,29 +221,11 @@ pub fn login_with_api_key(codex_home: &Path, api_key: &str) -> std::io::Result<(
|
||||
write_auth_json(&get_auth_file(codex_home), &auth_dot_json)
|
||||
}
|
||||
|
||||
fn load_auth(
|
||||
codex_home: &Path,
|
||||
include_env_var: bool,
|
||||
preferred_auth_method: AuthMode,
|
||||
originator: &str,
|
||||
) -> std::io::Result<Option<CodexAuth>> {
|
||||
// First, check to see if there is a valid auth.json file. If not, we fall
|
||||
// back to AuthMode::ApiKey using the OPENAI_API_KEY environment variable
|
||||
// (if it is set).
|
||||
fn load_auth(codex_home: &Path) -> std::io::Result<Option<CodexAuth>> {
|
||||
let auth_file = get_auth_file(codex_home);
|
||||
let client = crate::default_client::create_client(originator);
|
||||
let client = crate::default_client::create_client();
|
||||
let auth_dot_json = match try_read_auth_json(&auth_file) {
|
||||
Ok(auth) => auth,
|
||||
// If auth.json does not exist, try to read the OPENAI_API_KEY from the
|
||||
// environment variable.
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound && include_env_var => {
|
||||
return match read_openai_api_key_from_env() {
|
||||
Some(api_key) => Ok(Some(CodexAuth::from_api_key_with_client(&api_key, client))),
|
||||
None => Ok(None),
|
||||
};
|
||||
}
|
||||
// Though if auth.json exists but is malformed, do not fall back to the
|
||||
// env var because the user may be expecting to use AuthMode::ChatGPT.
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
}
|
||||
@@ -262,32 +237,11 @@ fn load_auth(
|
||||
last_refresh,
|
||||
} = auth_dot_json;
|
||||
|
||||
// If the auth.json has an API key AND does not appear to be on a plan that
|
||||
// should prefer AuthMode::ChatGPT, use AuthMode::ApiKey.
|
||||
// Prefer AuthMode.ApiKey if it's set in the auth.json.
|
||||
if let Some(api_key) = &auth_json_api_key {
|
||||
// Should any of these be AuthMode::ChatGPT with the api_key set?
|
||||
// Does AuthMode::ChatGPT indicate that there is an auth.json that is
|
||||
// "refreshable" even if we are using the API key for auth?
|
||||
match &tokens {
|
||||
Some(tokens) => {
|
||||
if tokens.should_use_api_key(preferred_auth_method, tokens.is_openai_email()) {
|
||||
return Ok(Some(CodexAuth::from_api_key_with_client(api_key, client)));
|
||||
} else {
|
||||
// Ignore the API key and fall through to ChatGPT auth.
|
||||
}
|
||||
}
|
||||
None => {
|
||||
// We have an API key but no tokens in the auth.json file.
|
||||
// Perhaps the user ran `codex login --api-key <KEY>` or updated
|
||||
// auth.json by hand. Either way, let's assume they are trying
|
||||
// to use their API key.
|
||||
return Ok(Some(CodexAuth::from_api_key_with_client(api_key, client)));
|
||||
}
|
||||
}
|
||||
return Ok(Some(CodexAuth::from_api_key_with_client(api_key, client)));
|
||||
}
|
||||
|
||||
// For the AuthMode::ChatGPT variant, perhaps neither api_key nor
|
||||
// openai_api_key should exist?
|
||||
Ok(Some(CodexAuth {
|
||||
api_key: None,
|
||||
mode: AuthMode::ChatGPT,
|
||||
@@ -337,10 +291,10 @@ async fn update_tokens(
|
||||
let tokens = auth_dot_json.tokens.get_or_insert_with(TokenData::default);
|
||||
tokens.id_token = parse_id_token(&id_token).map_err(std::io::Error::other)?;
|
||||
if let Some(access_token) = access_token {
|
||||
tokens.access_token = access_token.to_string();
|
||||
tokens.access_token = access_token;
|
||||
}
|
||||
if let Some(refresh_token) = refresh_token {
|
||||
tokens.refresh_token = refresh_token.to_string();
|
||||
tokens.refresh_token = refresh_token;
|
||||
}
|
||||
auth_dot_json.last_refresh = Some(Utc::now());
|
||||
write_auth_json(auth_file, &auth_dot_json)?;
|
||||
@@ -417,7 +371,6 @@ use std::sync::RwLock;
|
||||
/// Internal cached auth state.
|
||||
#[derive(Clone, Debug)]
|
||||
struct CachedAuth {
|
||||
preferred_auth_mode: AuthMode,
|
||||
auth: Option<CodexAuth>,
|
||||
}
|
||||
|
||||
@@ -473,9 +426,7 @@ mod tests {
|
||||
auth_dot_json,
|
||||
auth_file: _,
|
||||
..
|
||||
} = super::load_auth(codex_home.path(), false, AuthMode::ChatGPT, "codex_cli_rs")
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
} = super::load_auth(codex_home.path()).unwrap().unwrap();
|
||||
assert_eq!(None, api_key);
|
||||
assert_eq!(AuthMode::ChatGPT, mode);
|
||||
|
||||
@@ -504,88 +455,6 @@ mod tests {
|
||||
)
|
||||
}
|
||||
|
||||
/// Even if the OPENAI_API_KEY is set in auth.json, if the plan is not in
|
||||
/// [`TokenData::is_plan_that_should_use_api_key`], it should use
|
||||
/// [`AuthMode::ChatGPT`].
|
||||
#[tokio::test]
|
||||
async fn pro_account_with_api_key_still_uses_chatgpt_auth() {
|
||||
let codex_home = tempdir().unwrap();
|
||||
let fake_jwt = write_auth_file(
|
||||
AuthFileParams {
|
||||
openai_api_key: Some("sk-test-key".to_string()),
|
||||
chatgpt_plan_type: "pro".to_string(),
|
||||
},
|
||||
codex_home.path(),
|
||||
)
|
||||
.expect("failed to write auth file");
|
||||
|
||||
let CodexAuth {
|
||||
api_key,
|
||||
mode,
|
||||
auth_dot_json,
|
||||
auth_file: _,
|
||||
..
|
||||
} = super::load_auth(codex_home.path(), false, AuthMode::ChatGPT, "codex_cli_rs")
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(None, api_key);
|
||||
assert_eq!(AuthMode::ChatGPT, mode);
|
||||
|
||||
let guard = auth_dot_json.lock().unwrap();
|
||||
let auth_dot_json = guard.as_ref().expect("AuthDotJson should exist");
|
||||
assert_eq!(
|
||||
&AuthDotJson {
|
||||
openai_api_key: None,
|
||||
tokens: Some(TokenData {
|
||||
id_token: IdTokenInfo {
|
||||
email: Some("user@example.com".to_string()),
|
||||
chatgpt_plan_type: Some(PlanType::Known(KnownPlan::Pro)),
|
||||
raw_jwt: fake_jwt,
|
||||
},
|
||||
access_token: "test-access-token".to_string(),
|
||||
refresh_token: "test-refresh-token".to_string(),
|
||||
account_id: None,
|
||||
}),
|
||||
last_refresh: Some(
|
||||
DateTime::parse_from_rfc3339(LAST_REFRESH)
|
||||
.unwrap()
|
||||
.with_timezone(&Utc)
|
||||
),
|
||||
},
|
||||
auth_dot_json
|
||||
)
|
||||
}
|
||||
|
||||
/// If the OPENAI_API_KEY is set in auth.json and it is an enterprise
|
||||
/// account, then it should use [`AuthMode::ApiKey`].
|
||||
#[tokio::test]
|
||||
async fn enterprise_account_with_api_key_uses_apikey_auth() {
|
||||
let codex_home = tempdir().unwrap();
|
||||
write_auth_file(
|
||||
AuthFileParams {
|
||||
openai_api_key: Some("sk-test-key".to_string()),
|
||||
chatgpt_plan_type: "enterprise".to_string(),
|
||||
},
|
||||
codex_home.path(),
|
||||
)
|
||||
.expect("failed to write auth file");
|
||||
|
||||
let CodexAuth {
|
||||
api_key,
|
||||
mode,
|
||||
auth_dot_json,
|
||||
auth_file: _,
|
||||
..
|
||||
} = super::load_auth(codex_home.path(), false, AuthMode::ChatGPT, "codex_cli_rs")
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(Some("sk-test-key".to_string()), api_key);
|
||||
assert_eq!(AuthMode::ApiKey, mode);
|
||||
|
||||
let guard = auth_dot_json.lock().expect("should unwrap");
|
||||
assert!(guard.is_none(), "auth_dot_json should be None");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn loads_api_key_from_auth_json() {
|
||||
let dir = tempdir().unwrap();
|
||||
@@ -596,9 +465,7 @@ mod tests {
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let auth = super::load_auth(dir.path(), false, AuthMode::ChatGPT, "codex_cli_rs")
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let auth = super::load_auth(dir.path()).unwrap().unwrap();
|
||||
assert_eq!(auth.mode, AuthMode::ApiKey);
|
||||
assert_eq!(auth.api_key, Some("sk-test-key".to_string()));
|
||||
|
||||
@@ -680,7 +547,6 @@ mod tests {
|
||||
#[derive(Debug)]
|
||||
pub struct AuthManager {
|
||||
codex_home: PathBuf,
|
||||
originator: String,
|
||||
inner: RwLock<CachedAuth>,
|
||||
}
|
||||
|
||||
@@ -689,30 +555,19 @@ impl AuthManager {
|
||||
/// preferred auth method. Errors loading auth are swallowed; `auth()` will
|
||||
/// simply return `None` in that case so callers can treat it as an
|
||||
/// unauthenticated state.
|
||||
pub fn new(codex_home: PathBuf, preferred_auth_mode: AuthMode, originator: String) -> Self {
|
||||
let auth = CodexAuth::from_codex_home(&codex_home, preferred_auth_mode, &originator)
|
||||
.ok()
|
||||
.flatten();
|
||||
pub fn new(codex_home: PathBuf) -> Self {
|
||||
let auth = CodexAuth::from_codex_home(&codex_home).ok().flatten();
|
||||
Self {
|
||||
codex_home,
|
||||
originator,
|
||||
inner: RwLock::new(CachedAuth {
|
||||
preferred_auth_mode,
|
||||
auth,
|
||||
}),
|
||||
inner: RwLock::new(CachedAuth { auth }),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an AuthManager with a specific CodexAuth, for testing only.
|
||||
pub fn from_auth_for_testing(auth: CodexAuth) -> Arc<Self> {
|
||||
let preferred_auth_mode = auth.mode;
|
||||
let cached = CachedAuth {
|
||||
preferred_auth_mode,
|
||||
auth: Some(auth),
|
||||
};
|
||||
let cached = CachedAuth { auth: Some(auth) };
|
||||
Arc::new(Self {
|
||||
codex_home: PathBuf::new(),
|
||||
originator: "codex_cli_rs".to_string(),
|
||||
inner: RwLock::new(cached),
|
||||
})
|
||||
}
|
||||
@@ -722,21 +577,10 @@ impl AuthManager {
|
||||
self.inner.read().ok().and_then(|c| c.auth.clone())
|
||||
}
|
||||
|
||||
/// Preferred auth method used when (re)loading.
|
||||
pub fn preferred_auth_method(&self) -> AuthMode {
|
||||
self.inner
|
||||
.read()
|
||||
.map(|c| c.preferred_auth_mode)
|
||||
.unwrap_or(AuthMode::ApiKey)
|
||||
}
|
||||
|
||||
/// Force a reload using the existing preferred auth method. Returns
|
||||
/// Force a reload of the auth information from auth.json. Returns
|
||||
/// whether the auth value changed.
|
||||
pub fn reload(&self) -> bool {
|
||||
let preferred = self.preferred_auth_method();
|
||||
let new_auth = CodexAuth::from_codex_home(&self.codex_home, preferred, &self.originator)
|
||||
.ok()
|
||||
.flatten();
|
||||
let new_auth = CodexAuth::from_codex_home(&self.codex_home).ok().flatten();
|
||||
if let Ok(mut guard) = self.inner.write() {
|
||||
let changed = !AuthManager::auths_equal(&guard.auth, &new_auth);
|
||||
guard.auth = new_auth;
|
||||
@@ -755,12 +599,8 @@ impl AuthManager {
|
||||
}
|
||||
|
||||
/// Convenience constructor returning an `Arc` wrapper.
|
||||
pub fn shared(
|
||||
codex_home: PathBuf,
|
||||
preferred_auth_mode: AuthMode,
|
||||
originator: String,
|
||||
) -> Arc<Self> {
|
||||
Arc::new(Self::new(codex_home, preferred_auth_mode, originator))
|
||||
pub fn shared(codex_home: PathBuf) -> Arc<Self> {
|
||||
Arc::new(Self::new(codex_home))
|
||||
}
|
||||
|
||||
/// Attempt to refresh the current auth token (if any). On success, reload
|
||||
|
||||
@@ -41,6 +41,7 @@ use crate::model_provider_info::WireApi;
|
||||
use crate::openai_model_info::get_model_info;
|
||||
use crate::openai_tools::create_tools_json_for_responses_api;
|
||||
use crate::protocol::TokenUsage;
|
||||
use crate::token_data::PlanType;
|
||||
use crate::util::backoff;
|
||||
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
|
||||
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
||||
@@ -60,7 +61,7 @@ struct Error {
|
||||
message: Option<String>,
|
||||
|
||||
// Optional fields available on "usage_limit_reached" and "usage_not_included" errors
|
||||
plan_type: Option<String>,
|
||||
plan_type: Option<PlanType>,
|
||||
resets_in_seconds: Option<u64>,
|
||||
}
|
||||
|
||||
@@ -84,7 +85,7 @@ impl ModelClient {
|
||||
summary: ReasoningSummaryConfig,
|
||||
conversation_id: ConversationId,
|
||||
) -> Self {
|
||||
let client = create_client(&config.responses_originator_header);
|
||||
let client = create_client();
|
||||
|
||||
Self {
|
||||
config,
|
||||
@@ -239,10 +240,10 @@ impl ModelClient {
|
||||
let res = req_builder.send().await;
|
||||
if let Ok(resp) = &res {
|
||||
trace!(
|
||||
"Response status: {}, request-id: {}",
|
||||
"Response status: {}, cf-ray: {}",
|
||||
resp.status(),
|
||||
resp.headers()
|
||||
.get("x-request-id")
|
||||
.get("cf-ray")
|
||||
.map(|v| v.to_str().unwrap_or_default())
|
||||
.unwrap_or_default()
|
||||
);
|
||||
@@ -304,7 +305,7 @@ impl ModelClient {
|
||||
// token.
|
||||
let plan_type = error
|
||||
.plan_type
|
||||
.or_else(|| auth.and_then(|a| a.get_plan_type()));
|
||||
.or_else(|| auth.as_ref().and_then(|a| a.get_plan_type()));
|
||||
let resets_in_seconds = error.resets_in_seconds;
|
||||
return Err(CodexErr::UsageLimitReached(UsageLimitReachedError {
|
||||
plan_type,
|
||||
@@ -1037,4 +1038,37 @@ mod tests {
|
||||
let delay = try_parse_retry_after(&err);
|
||||
assert_eq!(delay, Some(Duration::from_secs_f64(1.898)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn error_response_deserializes_old_schema_known_plan_type_and_serializes_back() {
|
||||
use crate::token_data::KnownPlan;
|
||||
use crate::token_data::PlanType;
|
||||
|
||||
let json = r#"{"error":{"type":"usage_limit_reached","plan_type":"pro","resets_in_seconds":3600}}"#;
|
||||
let resp: ErrorResponse =
|
||||
serde_json::from_str(json).expect("should deserialize old schema");
|
||||
|
||||
assert!(matches!(
|
||||
resp.error.plan_type,
|
||||
Some(PlanType::Known(KnownPlan::Pro))
|
||||
));
|
||||
|
||||
let plan_json = serde_json::to_string(&resp.error.plan_type).expect("serialize plan_type");
|
||||
assert_eq!(plan_json, "\"pro\"");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn error_response_deserializes_old_schema_unknown_plan_type_and_serializes_back() {
|
||||
use crate::token_data::PlanType;
|
||||
|
||||
let json =
|
||||
r#"{"error":{"type":"usage_limit_reached","plan_type":"vip","resets_in_seconds":60}}"#;
|
||||
let resp: ErrorResponse =
|
||||
serde_json::from_str(json).expect("should deserialize old schema");
|
||||
|
||||
assert!(matches!(resp.error.plan_type, Some(PlanType::Unknown(ref s)) if s == "vip"));
|
||||
|
||||
let plan_json = serde_json::to_string(&resp.error.plan_type).expect("serialize plan_type");
|
||||
assert_eq!(plan_json, "\"vip\"");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,20 +10,22 @@ use std::time::Duration;
|
||||
|
||||
use crate::AuthManager;
|
||||
use crate::event_mapping::map_response_item_to_event_messages;
|
||||
use crate::rollout::RolloutItem;
|
||||
use crate::rollout::recorder::RolloutItemSliceExt;
|
||||
use async_channel::Receiver;
|
||||
use async_channel::Sender;
|
||||
use codex_apply_patch::ApplyPatchAction;
|
||||
use codex_apply_patch::MaybeApplyPatchVerified;
|
||||
use codex_apply_patch::maybe_parse_apply_patch_verified;
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
use codex_protocol::protocol::CompactedItem;
|
||||
use codex_protocol::protocol::ConversationPathResponseEvent;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::TaskStartedEvent;
|
||||
use codex_protocol::protocol::TurnAbortReason;
|
||||
use codex_protocol::protocol::TurnAbortedEvent;
|
||||
use codex_protocol::protocol::TurnContextItem;
|
||||
use futures::prelude::*;
|
||||
use mcp_types::CallToolResult;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json;
|
||||
use tokio::sync::oneshot;
|
||||
@@ -46,7 +48,6 @@ use crate::client_common::ResponseEvent;
|
||||
use crate::config::Config;
|
||||
use crate::config_types::ShellEnvironmentPolicy;
|
||||
use crate::conversation_history::ConversationHistory;
|
||||
use crate::conversation_manager::InitialHistory;
|
||||
use crate::environment_context::EnvironmentContext;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result as CodexResult;
|
||||
@@ -105,11 +106,13 @@ use crate::protocol::TokenUsageInfo;
|
||||
use crate::protocol::TurnDiffEvent;
|
||||
use crate::protocol::WebSearchBeginEvent;
|
||||
use crate::rollout::RolloutRecorder;
|
||||
use crate::rollout::RolloutRecorderParams;
|
||||
use crate::safety::SafetyCheck;
|
||||
use crate::safety::assess_command_safety;
|
||||
use crate::safety::assess_safety_for_untrusted_command;
|
||||
use crate::shell;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use crate::unified_exec::UnifiedExecSessionManager;
|
||||
use crate::user_instructions::UserInstructions;
|
||||
use crate::user_notification::UserNotification;
|
||||
use crate::util::backoff;
|
||||
@@ -122,6 +125,7 @@ use codex_protocol::models::LocalShellAction;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::models::ShellToolCallParams;
|
||||
use codex_protocol::protocol::InitialHistory;
|
||||
|
||||
// A convenience extension trait for acquiring mutex locks where poisoning is
|
||||
// unrecoverable and should abort the program. This avoids scattered `.unwrap()`
|
||||
@@ -207,12 +211,7 @@ impl Codex {
|
||||
let conversation_id = session.conversation_id;
|
||||
|
||||
// This task will run until Op::Shutdown is received.
|
||||
tokio::spawn(submission_loop(
|
||||
session.clone(),
|
||||
turn_context,
|
||||
config,
|
||||
rx_sub,
|
||||
));
|
||||
tokio::spawn(submission_loop(session, turn_context, config, rx_sub));
|
||||
let codex = Codex {
|
||||
next_id: AtomicU64::new(0),
|
||||
tx_sub,
|
||||
@@ -277,6 +276,7 @@ pub(crate) struct Session {
|
||||
/// Manager for external MCP servers/tools.
|
||||
mcp_connection_manager: McpConnectionManager,
|
||||
session_manager: ExecSessionManager,
|
||||
unified_exec_manager: UnifiedExecSessionManager,
|
||||
|
||||
/// External notifier command (will be passed as args to exec()). When
|
||||
/// `None` this feature is disabled.
|
||||
@@ -377,9 +377,18 @@ impl Session {
|
||||
return Err(anyhow::anyhow!("cwd is not absolute: {cwd:?}"));
|
||||
}
|
||||
|
||||
let conversation_id = match &initial_history {
|
||||
InitialHistory::New | InitialHistory::Forked(_) => ConversationId::default(),
|
||||
InitialHistory::Resumed(resumed_history) => resumed_history.conversation_id,
|
||||
let (conversation_id, rollout_params) = match &initial_history {
|
||||
InitialHistory::New | InitialHistory::Forked(_) => {
|
||||
let conversation_id = ConversationId::default();
|
||||
(
|
||||
conversation_id,
|
||||
RolloutRecorderParams::new(conversation_id, user_instructions.clone()),
|
||||
)
|
||||
}
|
||||
InitialHistory::Resumed(resumed_history) => (
|
||||
resumed_history.conversation_id,
|
||||
RolloutRecorderParams::resume(resumed_history.rollout_path.clone()),
|
||||
),
|
||||
};
|
||||
|
||||
// Error messages to dispatch after SessionConfigured is sent.
|
||||
@@ -391,7 +400,7 @@ impl Session {
|
||||
// - spin up MCP connection manager
|
||||
// - perform default shell discovery
|
||||
// - load history metadata
|
||||
let rollout_fut = RolloutRecorder::new(&config, conversation_id, user_instructions.clone());
|
||||
let rollout_fut = RolloutRecorder::new(&config, rollout_params);
|
||||
|
||||
let mcp_fut = McpConnectionManager::new(config.mcp_servers.clone());
|
||||
let default_shell_fut = shell::default_user_shell();
|
||||
@@ -405,6 +414,7 @@ impl Session {
|
||||
error!("failed to initialize rollout recorder: {e:#}");
|
||||
anyhow::anyhow!("failed to initialize rollout recorder: {e:#}")
|
||||
})?;
|
||||
let rollout_path = rollout_recorder.rollout_path.clone();
|
||||
// Create the mutable state for the Session.
|
||||
let state = State {
|
||||
history: ConversationHistory::new(),
|
||||
@@ -452,12 +462,12 @@ impl Session {
|
||||
tools_config: ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &config.model_family,
|
||||
approval_policy,
|
||||
sandbox_policy: sandbox_policy.clone(),
|
||||
include_plan_tool: config.include_plan_tool,
|
||||
include_apply_patch_tool: config.include_apply_patch_tool,
|
||||
include_web_search_request: config.tools_web_search_request,
|
||||
use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
|
||||
include_view_image_tool: config.include_view_image_tool,
|
||||
experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
|
||||
}),
|
||||
user_instructions,
|
||||
base_instructions,
|
||||
@@ -471,6 +481,7 @@ impl Session {
|
||||
tx_event: tx_event.clone(),
|
||||
mcp_connection_manager,
|
||||
session_manager: ExecSessionManager::default(),
|
||||
unified_exec_manager: UnifiedExecSessionManager::default(),
|
||||
notify,
|
||||
state: Mutex::new(state),
|
||||
rollout: Mutex::new(Some(rollout_recorder)),
|
||||
@@ -481,19 +492,20 @@ impl Session {
|
||||
|
||||
// Dispatch the SessionConfiguredEvent first and then report any errors.
|
||||
// If resuming, include converted initial messages in the payload so UIs can render them immediately.
|
||||
let initial_messages = Some(
|
||||
sess.apply_initial_history(&turn_context, initial_history.clone())
|
||||
.await,
|
||||
);
|
||||
let initial_messages = initial_history.get_event_msgs();
|
||||
sess.record_initial_history(&turn_context, initial_history)
|
||||
.await;
|
||||
|
||||
let events = std::iter::once(Event {
|
||||
id: INITIAL_SUBMIT_ID.to_owned(),
|
||||
msg: EventMsg::SessionConfigured(SessionConfiguredEvent {
|
||||
session_id: conversation_id,
|
||||
model,
|
||||
reasoning_effort: model_reasoning_effort,
|
||||
history_log_id,
|
||||
history_entry_count,
|
||||
initial_messages,
|
||||
rollout_path,
|
||||
}),
|
||||
})
|
||||
.chain(post_session_configured_error_events.into_iter());
|
||||
@@ -521,82 +533,42 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
async fn apply_initial_history(
|
||||
async fn record_initial_history(
|
||||
&self,
|
||||
turn_context: &TurnContext,
|
||||
conversation_history: InitialHistory,
|
||||
) -> Vec<EventMsg> {
|
||||
) {
|
||||
match conversation_history {
|
||||
InitialHistory::New => self.record_initial_history_new(turn_context).await,
|
||||
InitialHistory::Forked(items) => {
|
||||
self.record_conversation_items_internal(&items, true).await;
|
||||
items
|
||||
.into_iter()
|
||||
.flat_map(|ri| {
|
||||
map_response_item_to_event_messages(&ri, self.show_raw_agent_reasoning)
|
||||
})
|
||||
.filter(|m| matches!(m, EventMsg::UserMessage(_)))
|
||||
.collect()
|
||||
InitialHistory::New => {
|
||||
// Build and record initial items (user instructions + environment context)
|
||||
let items = self.build_initial_context(turn_context);
|
||||
self.record_conversation_items(&items).await;
|
||||
}
|
||||
InitialHistory::Resumed(resumed_history) => {
|
||||
self.record_initial_history_resumed(resumed_history.history)
|
||||
.await
|
||||
InitialHistory::Resumed(_) | InitialHistory::Forked(_) => {
|
||||
let rollout_items = conversation_history.get_rollout_items();
|
||||
let persist = matches!(conversation_history, InitialHistory::Forked(_));
|
||||
|
||||
// Always add response items to conversation history
|
||||
let response_items = conversation_history.get_response_items();
|
||||
if !response_items.is_empty() {
|
||||
self.record_into_history(&response_items);
|
||||
}
|
||||
|
||||
// If persisting, persist all rollout items as-is (recorder filters)
|
||||
if persist && !rollout_items.is_empty() {
|
||||
self.persist_rollout_items(&rollout_items).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn record_initial_history_new(&self, turn_context: &TurnContext) -> Vec<EventMsg> {
|
||||
// record the initial user instructions and environment context,
|
||||
// regardless of whether we restored items.
|
||||
// TODO: Those items shouldn't be "user messages" IMO. Maybe developer messages.
|
||||
let mut conversation_items = Vec::<ResponseItem>::with_capacity(2);
|
||||
if let Some(user_instructions) = turn_context.user_instructions.as_deref() {
|
||||
conversation_items.push(UserInstructions::new(user_instructions.to_string()).into());
|
||||
}
|
||||
conversation_items.push(ResponseItem::from(EnvironmentContext::new(
|
||||
Some(turn_context.cwd.clone()),
|
||||
Some(turn_context.approval_policy),
|
||||
Some(turn_context.sandbox_policy.clone()),
|
||||
Some(self.user_shell.clone()),
|
||||
)));
|
||||
for item in conversation_items {
|
||||
self.record_conversation_item(item).await;
|
||||
}
|
||||
vec![]
|
||||
}
|
||||
|
||||
async fn record_initial_history_from_items(&self, items: Vec<ResponseItem>) {
|
||||
self.record_conversation_items_internal(&items, false).await;
|
||||
}
|
||||
|
||||
async fn record_initial_history_resumed(&self, items: Vec<RolloutItem>) -> Vec<EventMsg> {
|
||||
// Record transcript (without persisting again)
|
||||
let responses: Vec<ResponseItem> = items.as_slice().get_response_items();
|
||||
if !responses.is_empty() {
|
||||
self.record_conversation_items_internal(&responses, true)
|
||||
.await;
|
||||
}
|
||||
|
||||
items.as_slice().get_events()
|
||||
}
|
||||
|
||||
/// Sends the given event to the client and records it to the rollout (if enabled).
|
||||
/// Any send/record errors are logged and swallowed.
|
||||
/// Persist the event to rollout and send it to clients.
|
||||
pub(crate) async fn send_event(&self, event: Event) {
|
||||
let event_to_record = event.clone();
|
||||
// Persist the event into rollout (recorder filters as needed)
|
||||
let rollout_items = vec![RolloutItem::EventMsg(event.msg.clone())];
|
||||
self.persist_rollout_items(&rollout_items).await;
|
||||
if let Err(e) = self.tx_event.send(event).await {
|
||||
error!("failed to send event: {e}");
|
||||
}
|
||||
let recorder = {
|
||||
let guard = self.rollout.lock_unchecked();
|
||||
guard.as_ref().cloned()
|
||||
};
|
||||
if let Some(rec) = recorder
|
||||
&& let Err(e) = rec
|
||||
.record_items(crate::rollout::RolloutItem::Event(event_to_record))
|
||||
.await
|
||||
{
|
||||
error!("failed to record rollout event: {e:#}");
|
||||
error!("failed to send tool call event: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -684,61 +656,70 @@ impl Session {
|
||||
state.approved_commands.insert(cmd);
|
||||
}
|
||||
|
||||
/// Records items to both the rollout and the chat completions/ZDR
|
||||
/// transcript, if enabled.
|
||||
/// Records input items: always append to conversation history and
|
||||
/// persist these response items to rollout.
|
||||
async fn record_conversation_items(&self, items: &[ResponseItem]) {
|
||||
self.record_conversation_items_internal(items, true).await;
|
||||
self.record_into_history(items);
|
||||
self.persist_rollout_response_items(items).await;
|
||||
}
|
||||
|
||||
async fn record_conversation_item(&self, item: ResponseItem) {
|
||||
let items = [item];
|
||||
self.record_conversation_items_internal(&items, true).await;
|
||||
/// Append ResponseItems to the in-memory conversation history only.
|
||||
fn record_into_history(&self, items: &[ResponseItem]) {
|
||||
self.state
|
||||
.lock_unchecked()
|
||||
.history
|
||||
.record_items(items.iter());
|
||||
}
|
||||
|
||||
async fn record_conversation_items_internal(&self, items: &[ResponseItem], persist: bool) {
|
||||
debug!("Recording items for conversation: {items:?}");
|
||||
if persist {
|
||||
// Record snapshot of these items into rollout
|
||||
for item in items {
|
||||
self.record_state_snapshot(RolloutItem::ResponseItem(item.clone()))
|
||||
.await;
|
||||
}
|
||||
async fn persist_rollout_response_items(&self, items: &[ResponseItem]) {
|
||||
let rollout_items: Vec<RolloutItem> = items
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(RolloutItem::ResponseItem)
|
||||
.collect();
|
||||
self.persist_rollout_items(&rollout_items).await;
|
||||
}
|
||||
|
||||
fn build_initial_context(&self, turn_context: &TurnContext) -> Vec<ResponseItem> {
|
||||
let mut items = Vec::<ResponseItem>::with_capacity(2);
|
||||
if let Some(user_instructions) = turn_context.user_instructions.as_deref() {
|
||||
items.push(UserInstructions::new(user_instructions.to_string()).into());
|
||||
}
|
||||
|
||||
self.state.lock_unchecked().history.record_items(items);
|
||||
items
|
||||
}
|
||||
|
||||
async fn record_state_snapshot(&self, item: RolloutItem) {
|
||||
async fn persist_rollout_items(&self, items: &[RolloutItem]) {
|
||||
let recorder = {
|
||||
let guard = self.rollout.lock_unchecked();
|
||||
guard.as_ref().cloned()
|
||||
};
|
||||
|
||||
if let Some(rec) = recorder
|
||||
&& let Err(e) = rec.record_items(item).await
|
||||
&& let Err(e) = rec.record_items(items).await
|
||||
{
|
||||
error!("failed to record rollout items: {e:#}");
|
||||
}
|
||||
}
|
||||
|
||||
/// Records a user input into conversation history AND a corresponding UserMessage event in rollout.
|
||||
/// Does not send events to the UI.
|
||||
async fn record_user_input(&self, sub_id: &str, response_item: ResponseItem) {
|
||||
// Record the message/tool input in conversation history/rollout state
|
||||
self.record_conversation_item(response_item.clone()).await;
|
||||
/// Record a user input item to conversation history and also persist a
|
||||
/// corresponding UserMessage EventMsg to rollout.
|
||||
async fn record_input_and_rollout_usermsg(&self, response_input: &ResponseInputItem) {
|
||||
let response_item: ResponseItem = response_input.clone().into();
|
||||
// Add to conversation history and persist response item to rollout
|
||||
self.record_conversation_items(std::slice::from_ref(&response_item))
|
||||
.await;
|
||||
|
||||
// Derive and record a UserMessage event alongside it in the rollout
|
||||
let user_events =
|
||||
map_response_item_to_event_messages(&response_item, self.show_raw_agent_reasoning)
|
||||
.into_iter()
|
||||
.filter(|m| matches!(m, EventMsg::UserMessage(_)));
|
||||
|
||||
for msg in user_events {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg,
|
||||
};
|
||||
self.record_state_snapshot(RolloutItem::Event(event)).await;
|
||||
// Derive user message events and persist only UserMessage to rollout
|
||||
let msgs =
|
||||
map_response_item_to_event_messages(&response_item, self.show_raw_agent_reasoning);
|
||||
let user_msgs: Vec<RolloutItem> = msgs
|
||||
.into_iter()
|
||||
.filter_map(|m| match m {
|
||||
EventMsg::UserMessage(ev) => Some(RolloutItem::EventMsg(EventMsg::UserMessage(ev))),
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
if !user_msgs.is_empty() {
|
||||
self.persist_rollout_items(&user_msgs).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1079,7 +1060,7 @@ impl AgentTask {
|
||||
id: self.sub_id,
|
||||
msg: EventMsg::TurnAborted(TurnAbortedEvent { reason }),
|
||||
};
|
||||
let sess = self.sess.clone();
|
||||
let sess = self.sess;
|
||||
tokio::spawn(async move {
|
||||
sess.send_event(event).await;
|
||||
});
|
||||
@@ -1115,10 +1096,10 @@ async fn submission_loop(
|
||||
let provider = prev.client.get_provider();
|
||||
|
||||
// Effective model + family
|
||||
let (effective_model, effective_family) = if let Some(m) = model {
|
||||
let (effective_model, effective_family) = if let Some(ref m) = model {
|
||||
let fam =
|
||||
find_family_for_model(&m).unwrap_or_else(|| config.model_family.clone());
|
||||
(m, fam)
|
||||
find_family_for_model(m).unwrap_or_else(|| config.model_family.clone());
|
||||
(m.clone(), fam)
|
||||
} else {
|
||||
(prev.client.get_model(), prev.client.get_model_family())
|
||||
};
|
||||
@@ -1155,12 +1136,12 @@ async fn submission_loop(
|
||||
let tools_config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &effective_family,
|
||||
approval_policy: new_approval_policy,
|
||||
sandbox_policy: new_sandbox_policy.clone(),
|
||||
include_plan_tool: config.include_plan_tool,
|
||||
include_apply_patch_tool: config.include_apply_patch_tool,
|
||||
include_web_search_request: config.tools_web_search_request,
|
||||
use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
|
||||
include_view_image_tool: config.include_view_image_tool,
|
||||
experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
|
||||
});
|
||||
|
||||
let new_turn_context = TurnContext {
|
||||
@@ -1176,25 +1157,18 @@ async fn submission_loop(
|
||||
|
||||
// Install the new persistent context for subsequent tasks/turns.
|
||||
turn_context = Arc::new(new_turn_context);
|
||||
if cwd.is_some() || approval_policy.is_some() || sandbox_policy.is_some() {
|
||||
sess.record_conversation_item(ResponseItem::from(EnvironmentContext::new(
|
||||
cwd,
|
||||
approval_policy,
|
||||
sandbox_policy,
|
||||
// Shell is not configurable from turn to turn
|
||||
None,
|
||||
)))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
Op::UserInput { items } => {
|
||||
// attempt to inject input into current task
|
||||
if let Err(items) = sess.inject_input(items) {
|
||||
// no current task, spawn a new one
|
||||
let task =
|
||||
AgentTask::spawn(sess.clone(), Arc::clone(&turn_context), sub.id, items);
|
||||
sess.set_task(task);
|
||||
}
|
||||
submit_user_input(
|
||||
turn_context.cwd.clone(),
|
||||
turn_context.approval_policy,
|
||||
turn_context.sandbox_policy.clone(),
|
||||
&sess,
|
||||
&turn_context,
|
||||
sub.id.clone(),
|
||||
items,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Op::UserTurn {
|
||||
items,
|
||||
@@ -1239,13 +1213,14 @@ async fn submission_loop(
|
||||
tools_config: ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy,
|
||||
sandbox_policy: sandbox_policy.clone(),
|
||||
include_plan_tool: config.include_plan_tool,
|
||||
include_apply_patch_tool: config.include_apply_patch_tool,
|
||||
include_web_search_request: config.tools_web_search_request,
|
||||
use_streamable_shell_tool: config
|
||||
.use_experimental_streamable_shell_tool,
|
||||
include_view_image_tool: config.include_view_image_tool,
|
||||
experimental_unified_exec_tool: config
|
||||
.use_experimental_unified_exec_tool,
|
||||
}),
|
||||
user_instructions: turn_context.user_instructions.clone(),
|
||||
base_instructions: turn_context.base_instructions.clone(),
|
||||
@@ -1254,11 +1229,16 @@ async fn submission_loop(
|
||||
shell_environment_policy: turn_context.shell_environment_policy.clone(),
|
||||
cwd,
|
||||
};
|
||||
// TODO: record the new environment context in the conversation history
|
||||
// no current task, spawn a new one with the per‑turn context
|
||||
let task =
|
||||
AgentTask::spawn(sess.clone(), Arc::new(fresh_turn_context), sub.id, items);
|
||||
sess.set_task(task);
|
||||
submit_user_input(
|
||||
fresh_turn_context.cwd.clone(),
|
||||
fresh_turn_context.approval_policy,
|
||||
fresh_turn_context.sandbox_policy.clone(),
|
||||
&sess,
|
||||
&Arc::new(fresh_turn_context),
|
||||
sub.id.clone(),
|
||||
items,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
Op::ExecApproval { id, decision } => match decision {
|
||||
@@ -1286,7 +1266,7 @@ async fn submission_loop(
|
||||
|
||||
Op::GetHistoryEntryRequest { offset, log_id } => {
|
||||
let config = config.clone();
|
||||
let sess_for_spawn = sess.clone();
|
||||
let sess_clone = sess.clone();
|
||||
let sub_id = sub.id.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
@@ -1314,7 +1294,7 @@ async fn submission_loop(
|
||||
),
|
||||
};
|
||||
|
||||
sess_for_spawn.send_event(event).await;
|
||||
sess_clone.send_event(event).await;
|
||||
});
|
||||
}
|
||||
Op::ListMcpTools => {
|
||||
@@ -1392,23 +1372,29 @@ async fn submission_loop(
|
||||
sess.send_event(event).await;
|
||||
break;
|
||||
}
|
||||
Op::GetConversationPath => {
|
||||
Op::GetPath => {
|
||||
let sub_id = sub.id.clone();
|
||||
// Ensure rollout file is flushed so consumers can read it immediately.
|
||||
let rec_opt = { sess.rollout.lock_unchecked().as_ref().cloned() };
|
||||
if let Some(rec) = rec_opt {
|
||||
let _ = rec.flush().await;
|
||||
// Flush rollout writes before returning the path so readers observe a consistent file.
|
||||
let (path, rec_opt) = {
|
||||
let guard = sess.rollout.lock_unchecked();
|
||||
match guard.as_ref() {
|
||||
Some(rec) => (rec.get_rollout_path(), Some(rec.clone())),
|
||||
None => {
|
||||
error!("rollout recorder not found");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
};
|
||||
if let Some(rec) = rec_opt
|
||||
&& let Err(e) = rec.flush().await
|
||||
{
|
||||
warn!("failed to flush rollout recorder before GetHistory: {e}");
|
||||
}
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::ConversationHistory(ConversationPathResponseEvent {
|
||||
msg: EventMsg::ConversationPath(ConversationPathResponseEvent {
|
||||
conversation_id: sess.conversation_id,
|
||||
path: sess
|
||||
.rollout
|
||||
.lock_unchecked()
|
||||
.as_ref()
|
||||
.map(|r| r.path().to_path_buf())
|
||||
.unwrap_or_default(),
|
||||
path,
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
@@ -1443,10 +1429,6 @@ async fn run_task(
|
||||
if input.is_empty() {
|
||||
return;
|
||||
}
|
||||
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
|
||||
// Record the user's input and corresponding event into the rollout
|
||||
let user_input_response: ResponseItem = ResponseItem::from(initial_input_for_turn.clone());
|
||||
sess.record_user_input(&sub_id, user_input_response).await;
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::TaskStarted(TaskStartedEvent {
|
||||
@@ -1455,6 +1437,10 @@ async fn run_task(
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
|
||||
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
|
||||
sess.record_input_and_rollout_usermsg(&initial_input_for_turn)
|
||||
.await;
|
||||
|
||||
let mut last_agent_message: Option<String> = None;
|
||||
// Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains
|
||||
// many turns, from the perspective of the user, it is a single turn.
|
||||
@@ -1469,9 +1455,7 @@ async fn run_task(
|
||||
.into_iter()
|
||||
.map(ResponseItem::from)
|
||||
.collect::<Vec<ResponseItem>>();
|
||||
for item in pending_input.iter() {
|
||||
sess.record_user_input(&sub_id, item.clone()).await;
|
||||
}
|
||||
sess.record_conversation_items(&pending_input).await;
|
||||
|
||||
// Construct the input that we will send to the model. When using the
|
||||
// Chat completions API (or ZDR clients), the model needs the full
|
||||
@@ -1598,9 +1582,8 @@ async fn run_task(
|
||||
|
||||
// Only attempt to take the lock if there is something to record.
|
||||
if !items_to_record_in_conversation_history.is_empty() {
|
||||
for item in items_to_record_in_conversation_history.iter().cloned() {
|
||||
sess.record_conversation_item(item).await;
|
||||
}
|
||||
sess.record_conversation_items(&items_to_record_in_conversation_history)
|
||||
.await;
|
||||
}
|
||||
|
||||
if responses.is_empty() {
|
||||
@@ -1754,7 +1737,7 @@ async fn try_run_turn(
|
||||
}
|
||||
})
|
||||
.map(|call_id| ResponseItem::CustomToolCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
call_id,
|
||||
output: "aborted".to_string(),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
@@ -1770,6 +1753,15 @@ async fn try_run_turn(
|
||||
})
|
||||
};
|
||||
|
||||
let rollout_item = RolloutItem::TurnContext(TurnContextItem {
|
||||
cwd: turn_context.cwd.clone(),
|
||||
approval_policy: turn_context.approval_policy,
|
||||
sandbox_policy: turn_context.sandbox_policy.clone(),
|
||||
model: turn_context.client.get_model(),
|
||||
effort: turn_context.client.get_reasoning_effort(),
|
||||
summary: turn_context.client.get_reasoning_summary(),
|
||||
});
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
let mut stream = turn_context.client.clone().stream(&prompt).await?;
|
||||
|
||||
let mut output = Vec::new();
|
||||
@@ -1811,11 +1803,13 @@ async fn try_run_turn(
|
||||
output.push(ProcessedResponseItem { item, response });
|
||||
}
|
||||
ResponseEvent::WebSearchCallBegin { call_id } => {
|
||||
sess.send_event(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::WebSearchBegin(WebSearchBeginEvent { call_id }),
|
||||
})
|
||||
.await;
|
||||
let _ = sess
|
||||
.tx_event
|
||||
.send(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::WebSearchBegin(WebSearchBeginEvent { call_id }),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
ResponseEvent::Completed {
|
||||
response_id: _,
|
||||
@@ -1831,11 +1825,12 @@ async fn try_run_turn(
|
||||
st.token_info = info.clone();
|
||||
info
|
||||
};
|
||||
sess.send_event(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::TokenCount(crate::protocol::TokenCountEvent { info }),
|
||||
})
|
||||
.await;
|
||||
let _ = sess
|
||||
.send_event(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::TokenCount(crate::protocol::TokenCountEvent { info }),
|
||||
})
|
||||
.await;
|
||||
|
||||
let unified_diff = turn_diff_tracker.get_unified_diff();
|
||||
if let Ok(Some(unified_diff)) = unified_diff {
|
||||
@@ -1949,10 +1944,14 @@ async fn run_compact_task(
|
||||
|
||||
sess.remove_task(&sub_id);
|
||||
|
||||
{
|
||||
let rollout_item = {
|
||||
let mut state = sess.state.lock_unchecked();
|
||||
state.history.keep_last_messages(1);
|
||||
}
|
||||
RolloutItem::Compacted(CompactedItem {
|
||||
message: state.history.last_agent_message(),
|
||||
})
|
||||
};
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
@@ -2086,6 +2085,72 @@ async fn handle_response_item(
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
async fn handle_unified_exec_tool_call(
|
||||
sess: &Session,
|
||||
call_id: String,
|
||||
session_id: Option<String>,
|
||||
arguments: Vec<String>,
|
||||
timeout_ms: Option<u64>,
|
||||
) -> ResponseInputItem {
|
||||
let parsed_session_id = if let Some(session_id) = session_id {
|
||||
match session_id.parse::<i32>() {
|
||||
Ok(parsed) => Some(parsed),
|
||||
Err(output) => {
|
||||
return ResponseInputItem::FunctionCallOutput {
|
||||
call_id: call_id.to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: format!("invalid session_id: {session_id} due to error {output}"),
|
||||
success: Some(false),
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let request = crate::unified_exec::UnifiedExecRequest {
|
||||
session_id: parsed_session_id,
|
||||
input_chunks: &arguments,
|
||||
timeout_ms,
|
||||
};
|
||||
|
||||
let result = sess.unified_exec_manager.handle_request(request).await;
|
||||
|
||||
let output_payload = match result {
|
||||
Ok(value) => {
|
||||
#[derive(Serialize)]
|
||||
struct SerializedUnifiedExecResult<'a> {
|
||||
session_id: Option<String>,
|
||||
output: &'a str,
|
||||
}
|
||||
|
||||
match serde_json::to_string(&SerializedUnifiedExecResult {
|
||||
session_id: value.session_id.map(|id| id.to_string()),
|
||||
output: &value.output,
|
||||
}) {
|
||||
Ok(serialized) => FunctionCallOutputPayload {
|
||||
content: serialized,
|
||||
success: Some(true),
|
||||
},
|
||||
Err(err) => FunctionCallOutputPayload {
|
||||
content: format!("failed to serialize unified exec output: {err}"),
|
||||
success: Some(false),
|
||||
},
|
||||
}
|
||||
}
|
||||
Err(err) => FunctionCallOutputPayload {
|
||||
content: format!("unified exec failed: {err}"),
|
||||
success: Some(false),
|
||||
},
|
||||
};
|
||||
|
||||
ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: output_payload,
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_function_call(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
@@ -2113,6 +2178,38 @@ async fn handle_function_call(
|
||||
)
|
||||
.await
|
||||
}
|
||||
"unified_exec" => {
|
||||
#[derive(Deserialize)]
|
||||
struct UnifiedExecArgs {
|
||||
input: Vec<String>,
|
||||
#[serde(default)]
|
||||
session_id: Option<String>,
|
||||
#[serde(default)]
|
||||
timeout_ms: Option<u64>,
|
||||
}
|
||||
|
||||
let args = match serde_json::from_str::<UnifiedExecArgs>(&arguments) {
|
||||
Ok(args) => args,
|
||||
Err(err) => {
|
||||
return ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: format!("failed to parse function arguments: {err}"),
|
||||
success: Some(false),
|
||||
},
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
handle_unified_exec_tool_call(
|
||||
sess,
|
||||
call_id,
|
||||
args.session_id,
|
||||
args.input,
|
||||
args.timeout_ms,
|
||||
)
|
||||
.await
|
||||
}
|
||||
"view_image" => {
|
||||
#[derive(serde::Deserialize)]
|
||||
struct SeeImageArgs {
|
||||
@@ -2591,6 +2688,20 @@ async fn handle_sandbox_error(
|
||||
let sub_id = exec_command_context.sub_id.clone();
|
||||
let cwd = exec_command_context.cwd.clone();
|
||||
|
||||
// if the command timed out, we can simply return this failure to the model
|
||||
if matches!(error, SandboxErr::Timeout) {
|
||||
return ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: format!(
|
||||
"command timed out after {} milliseconds",
|
||||
params.timeout_duration().as_millis()
|
||||
),
|
||||
success: Some(false),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Early out if either the user never wants to be asked for approval, or
|
||||
// we're letting the model manage escalation requests. Otherwise, continue
|
||||
match turn_context.approval_policy {
|
||||
@@ -2608,20 +2719,6 @@ async fn handle_sandbox_error(
|
||||
AskForApproval::UnlessTrusted | AskForApproval::OnFailure => (),
|
||||
}
|
||||
|
||||
// similarly, if the command timed out, we can simply return this failure to the model
|
||||
if matches!(error, SandboxErr::Timeout) {
|
||||
return ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: format!(
|
||||
"command timed out after {} milliseconds",
|
||||
params.timeout_duration().as_millis()
|
||||
),
|
||||
success: Some(false),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Note that when `error` is `SandboxErr::Denied`, it could be a false
|
||||
// positive. That is, it may have exited with a non-zero exit code, not
|
||||
// because the sandbox denied it, but because that is its expected behavior,
|
||||
@@ -2716,6 +2813,29 @@ async fn handle_sandbox_error(
|
||||
}
|
||||
}
|
||||
|
||||
async fn submit_user_input(
|
||||
cwd: PathBuf,
|
||||
approval_policy: AskForApproval,
|
||||
sandbox_policy: SandboxPolicy,
|
||||
sess: &Arc<Session>,
|
||||
turn_context: &Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
items: Vec<InputItem>,
|
||||
) {
|
||||
sess.record_conversation_items(&[ResponseItem::from(EnvironmentContext::new(
|
||||
Some(cwd),
|
||||
Some(approval_policy),
|
||||
Some(sandbox_policy),
|
||||
Some(sess.user_shell.clone()),
|
||||
))])
|
||||
.await;
|
||||
if let Err(items) = sess.inject_input(items) {
|
||||
// no current task, spawn a new one
|
||||
let task = AgentTask::spawn(Arc::clone(sess), Arc::clone(turn_context), sub_id, items);
|
||||
sess.set_task(task);
|
||||
}
|
||||
}
|
||||
|
||||
fn format_exec_output_str(exec_output: &ExecToolCallOutput) -> String {
|
||||
let ExecToolCallOutput {
|
||||
aggregated_output, ..
|
||||
@@ -2880,6 +3000,15 @@ async fn drain_to_completed(
|
||||
sub_id: &str,
|
||||
prompt: &Prompt,
|
||||
) -> CodexResult<()> {
|
||||
let rollout_item = RolloutItem::TurnContext(TurnContextItem {
|
||||
cwd: turn_context.cwd.clone(),
|
||||
approval_policy: turn_context.approval_policy,
|
||||
sandbox_policy: turn_context.sandbox_policy.clone(),
|
||||
model: turn_context.client.get_model(),
|
||||
effort: turn_context.client.get_reasoning_effort(),
|
||||
summary: turn_context.client.get_reasoning_summary(),
|
||||
});
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
let mut stream = turn_context.client.clone().stream(prompt).await?;
|
||||
loop {
|
||||
let maybe_event = stream.next().await;
|
||||
@@ -2910,11 +3039,13 @@ async fn drain_to_completed(
|
||||
info
|
||||
};
|
||||
|
||||
sess.send_event(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::TokenCount(crate::protocol::TokenCountEvent { info }),
|
||||
})
|
||||
.await;
|
||||
sess.tx_event
|
||||
.send(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::TokenCount(crate::protocol::TokenCountEvent { info }),
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
@@ -3011,7 +3142,7 @@ mod tests {
|
||||
exit_code: 0,
|
||||
stdout: StreamOutput::new(String::new()),
|
||||
stderr: StreamOutput::new(String::new()),
|
||||
aggregated_output: StreamOutput::new(full.clone()),
|
||||
aggregated_output: StreamOutput::new(full),
|
||||
duration: StdDuration::from_secs(1),
|
||||
};
|
||||
|
||||
@@ -3045,7 +3176,7 @@ mod tests {
|
||||
fn model_truncation_respects_byte_budget() {
|
||||
// Construct a large output (about 100kB) so byte budget dominates
|
||||
let big_line = "x".repeat(100);
|
||||
let full = std::iter::repeat_n(big_line.clone(), 1000)
|
||||
let full = std::iter::repeat_n(big_line, 1000)
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
|
||||
@@ -15,11 +15,11 @@ use crate::model_provider_info::built_in_model_providers;
|
||||
use crate::openai_model_info::get_model_info;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use anyhow::Context;
|
||||
use codex_protocol::config_types::ReasoningEffort;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use codex_protocol::config_types::SandboxMode;
|
||||
use codex_protocol::config_types::Verbosity;
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
use codex_protocol::mcp_protocol::Tools;
|
||||
use codex_protocol::mcp_protocol::UserSavedConfig;
|
||||
use dirs::home_dir;
|
||||
@@ -32,15 +32,14 @@ use toml::Value as TomlValue;
|
||||
use toml_edit::DocumentMut;
|
||||
|
||||
const OPENAI_DEFAULT_MODEL: &str = "gpt-5";
|
||||
pub const GPT5_HIGH_MODEL: &str = "gpt-5-high";
|
||||
|
||||
/// Maximum number of bytes of the documentation that will be embedded. Larger
|
||||
/// files are *silently truncated* to this size so we do not take up too much of
|
||||
/// the context window.
|
||||
pub(crate) const PROJECT_DOC_MAX_BYTES: usize = 32 * 1024; // 32 KiB
|
||||
|
||||
const CONFIG_TOML_FILE: &str = "config.toml";
|
||||
|
||||
const DEFAULT_RESPONSES_ORIGINATOR_HEADER: &str = "codex_cli_rs";
|
||||
pub(crate) const CONFIG_TOML_FILE: &str = "config.toml";
|
||||
|
||||
/// Application configuration loaded from disk and merged with overrides.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
@@ -131,9 +130,6 @@ pub struct Config {
|
||||
/// output will be hyperlinked using the specified URI scheme.
|
||||
pub file_opener: UriBasedFileOpener,
|
||||
|
||||
/// Collection of settings that are specific to the TUI.
|
||||
pub tui: Tui,
|
||||
|
||||
/// Path to the `codex-linux-sandbox` executable. This must be set if
|
||||
/// [`crate::exec::SandboxType::LinuxSeccomp`] is used. Note that this
|
||||
/// cannot be set in the config file: it must be set in code via
|
||||
@@ -169,16 +165,17 @@ pub struct Config {
|
||||
|
||||
pub tools_web_search_request: bool,
|
||||
|
||||
/// The value for the `originator` header included with Responses API requests.
|
||||
pub responses_originator_header: String,
|
||||
|
||||
/// If set to `true`, the API key will be signed with the `originator` header.
|
||||
pub preferred_auth_method: AuthMode,
|
||||
|
||||
pub use_experimental_streamable_shell_tool: bool,
|
||||
|
||||
/// If set to `true`, used only the experimental unified exec tool.
|
||||
pub use_experimental_unified_exec_tool: bool,
|
||||
|
||||
/// Include the `view_image` tool that lets the agent attach a local image path to context.
|
||||
pub include_view_image_tool: bool,
|
||||
|
||||
/// The active profile name used to derive this `Config` (if any).
|
||||
pub active_profile: Option<String>,
|
||||
|
||||
/// When true, disables burst-paste detection for typed input entirely.
|
||||
/// All characters are inserted as they are received, and no buffering
|
||||
/// or placeholder replacement will occur for fast keypress bursts.
|
||||
@@ -262,17 +259,7 @@ pub fn load_config_as_toml(codex_home: &Path) -> std::io::Result<TomlValue> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Patch `CODEX_HOME/config.toml` project state.
|
||||
/// Use with caution.
|
||||
pub fn set_project_trusted(codex_home: &Path, project_path: &Path) -> anyhow::Result<()> {
|
||||
let config_path = codex_home.join(CONFIG_TOML_FILE);
|
||||
// Parse existing config if present; otherwise start a new document.
|
||||
let mut doc = match std::fs::read_to_string(config_path.clone()) {
|
||||
Ok(s) => s.parse::<DocumentMut>()?,
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => DocumentMut::new(),
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
|
||||
fn set_project_trusted_inner(doc: &mut DocumentMut, project_path: &Path) -> anyhow::Result<()> {
|
||||
// Ensure we render a human-friendly structure:
|
||||
//
|
||||
// [projects]
|
||||
@@ -288,14 +275,26 @@ pub fn set_project_trusted(codex_home: &Path, project_path: &Path) -> anyhow::Re
|
||||
// Ensure top-level `projects` exists as a non-inline, explicit table. If it
|
||||
// exists but was previously represented as a non-table (e.g., inline),
|
||||
// replace it with an explicit table.
|
||||
let mut created_projects_table = false;
|
||||
{
|
||||
let root = doc.as_table_mut();
|
||||
let needs_table = !root.contains_key("projects")
|
||||
|| root.get("projects").and_then(|i| i.as_table()).is_none();
|
||||
if needs_table {
|
||||
root.insert("projects", toml_edit::table());
|
||||
created_projects_table = true;
|
||||
// If `projects` exists but isn't a standard table (e.g., it's an inline table),
|
||||
// convert it to an explicit table while preserving existing entries.
|
||||
let existing_projects = root.get("projects").cloned();
|
||||
if existing_projects.as_ref().is_none_or(|i| !i.is_table()) {
|
||||
let mut projects_tbl = toml_edit::Table::new();
|
||||
projects_tbl.set_implicit(true);
|
||||
|
||||
// If there was an existing inline table, migrate its entries to explicit tables.
|
||||
if let Some(inline_tbl) = existing_projects.as_ref().and_then(|i| i.as_inline_table()) {
|
||||
for (k, v) in inline_tbl.iter() {
|
||||
if let Some(inner_tbl) = v.as_inline_table() {
|
||||
let new_tbl = inner_tbl.clone().into_table();
|
||||
projects_tbl.insert(k, toml_edit::Item::Table(new_tbl));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
root.insert("projects", toml_edit::Item::Table(projects_tbl));
|
||||
}
|
||||
}
|
||||
let Some(projects_tbl) = doc["projects"].as_table_mut() else {
|
||||
@@ -304,12 +303,6 @@ pub fn set_project_trusted(codex_home: &Path, project_path: &Path) -> anyhow::Re
|
||||
));
|
||||
};
|
||||
|
||||
// If we created the `projects` table ourselves, keep it implicit so we
|
||||
// don't render a standalone `[projects]` header.
|
||||
if created_projects_table {
|
||||
projects_tbl.set_implicit(true);
|
||||
}
|
||||
|
||||
// Ensure the per-project entry is its own explicit table. If it exists but
|
||||
// is not a table (e.g., an inline table), replace it with an explicit table.
|
||||
let needs_proj_table = !projects_tbl.contains_key(project_key.as_str())
|
||||
@@ -328,6 +321,21 @@ pub fn set_project_trusted(codex_home: &Path, project_path: &Path) -> anyhow::Re
|
||||
};
|
||||
proj_tbl.set_implicit(false);
|
||||
proj_tbl["trust_level"] = toml_edit::value("trusted");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Patch `CODEX_HOME/config.toml` project state.
|
||||
/// Use with caution.
|
||||
pub fn set_project_trusted(codex_home: &Path, project_path: &Path) -> anyhow::Result<()> {
|
||||
let config_path = codex_home.join(CONFIG_TOML_FILE);
|
||||
// Parse existing config if present; otherwise start a new document.
|
||||
let mut doc = match std::fs::read_to_string(config_path.clone()) {
|
||||
Ok(s) => s.parse::<DocumentMut>()?,
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => DocumentMut::new(),
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
|
||||
set_project_trusted_inner(&mut doc, project_path)?;
|
||||
|
||||
// ensure codex_home exists
|
||||
std::fs::create_dir_all(codex_home)?;
|
||||
@@ -342,6 +350,107 @@ pub fn set_project_trusted(codex_home: &Path, project_path: &Path) -> anyhow::Re
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ensure_profile_table<'a>(
|
||||
doc: &'a mut DocumentMut,
|
||||
profile_name: &str,
|
||||
) -> anyhow::Result<&'a mut toml_edit::Table> {
|
||||
let mut created_profiles_table = false;
|
||||
{
|
||||
let root = doc.as_table_mut();
|
||||
let needs_table = !root.contains_key("profiles")
|
||||
|| root
|
||||
.get("profiles")
|
||||
.and_then(|item| item.as_table())
|
||||
.is_none();
|
||||
if needs_table {
|
||||
root.insert("profiles", toml_edit::table());
|
||||
created_profiles_table = true;
|
||||
}
|
||||
}
|
||||
|
||||
let Some(profiles_table) = doc["profiles"].as_table_mut() else {
|
||||
return Err(anyhow::anyhow!(
|
||||
"profiles table missing after initialization"
|
||||
));
|
||||
};
|
||||
|
||||
if created_profiles_table {
|
||||
profiles_table.set_implicit(true);
|
||||
}
|
||||
|
||||
let needs_profile_table = !profiles_table.contains_key(profile_name)
|
||||
|| profiles_table
|
||||
.get(profile_name)
|
||||
.and_then(|item| item.as_table())
|
||||
.is_none();
|
||||
if needs_profile_table {
|
||||
profiles_table.insert(profile_name, toml_edit::table());
|
||||
}
|
||||
|
||||
let Some(profile_table) = profiles_table
|
||||
.get_mut(profile_name)
|
||||
.and_then(|item| item.as_table_mut())
|
||||
else {
|
||||
return Err(anyhow::anyhow!(format!(
|
||||
"profile table missing for {profile_name}"
|
||||
)));
|
||||
};
|
||||
|
||||
profile_table.set_implicit(false);
|
||||
Ok(profile_table)
|
||||
}
|
||||
|
||||
// TODO(jif) refactor config persistence.
|
||||
pub async fn persist_model_selection(
|
||||
codex_home: &Path,
|
||||
active_profile: Option<&str>,
|
||||
model: &str,
|
||||
effort: Option<ReasoningEffort>,
|
||||
) -> anyhow::Result<()> {
|
||||
let config_path = codex_home.join(CONFIG_TOML_FILE);
|
||||
let serialized = match tokio::fs::read_to_string(&config_path).await {
|
||||
Ok(contents) => contents,
|
||||
Err(err) if err.kind() == std::io::ErrorKind::NotFound => String::new(),
|
||||
Err(err) => return Err(err.into()),
|
||||
};
|
||||
|
||||
let mut doc = if serialized.is_empty() {
|
||||
DocumentMut::new()
|
||||
} else {
|
||||
serialized.parse::<DocumentMut>()?
|
||||
};
|
||||
|
||||
if let Some(profile_name) = active_profile {
|
||||
let profile_table = ensure_profile_table(&mut doc, profile_name)?;
|
||||
profile_table["model"] = toml_edit::value(model);
|
||||
if let Some(effort) = effort {
|
||||
profile_table["model_reasoning_effort"] = toml_edit::value(effort.to_string());
|
||||
}
|
||||
} else {
|
||||
let table = doc.as_table_mut();
|
||||
table["model"] = toml_edit::value(model);
|
||||
if let Some(effort) = effort {
|
||||
table["model_reasoning_effort"] = toml_edit::value(effort.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(jif) refactor the home creation
|
||||
tokio::fs::create_dir_all(codex_home)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to create Codex home directory at {}",
|
||||
codex_home.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
tokio::fs::write(&config_path, doc.to_string())
|
||||
.await
|
||||
.with_context(|| format!("failed to persist config.toml at {}", config_path.display()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Apply a single dotted-path override onto a TOML value.
|
||||
fn apply_toml_override(root: &mut TomlValue, path: &str, value: TomlValue) {
|
||||
use toml::value::Table;
|
||||
@@ -386,7 +495,7 @@ fn apply_toml_override(root: &mut TomlValue, path: &str, value: TomlValue) {
|
||||
}
|
||||
|
||||
/// Base config deserialized from ~/.codex/config.toml.
|
||||
#[derive(Deserialize, Debug, Clone, Default)]
|
||||
#[derive(Deserialize, Debug, Clone, Default, PartialEq)]
|
||||
pub struct ConfigToml {
|
||||
/// Optional override of model selection.
|
||||
pub model: Option<String>,
|
||||
@@ -477,15 +586,10 @@ pub struct ConfigToml {
|
||||
pub experimental_instructions_file: Option<PathBuf>,
|
||||
|
||||
pub experimental_use_exec_command_tool: Option<bool>,
|
||||
|
||||
/// The value for the `originator` header included with Responses API requests.
|
||||
pub responses_originator_header_internal_override: Option<String>,
|
||||
pub experimental_use_unified_exec_tool: Option<bool>,
|
||||
|
||||
pub projects: Option<HashMap<String, ProjectConfig>>,
|
||||
|
||||
/// If set to `true`, the API key will be signed with the `originator` header.
|
||||
pub preferred_auth_method: Option<AuthMode>,
|
||||
|
||||
/// Nested tools section for feature toggles
|
||||
pub tools: Option<ToolsToml>,
|
||||
|
||||
@@ -523,7 +627,7 @@ pub struct ProjectConfig {
|
||||
pub trust_level: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, Default)]
|
||||
#[derive(Deserialize, Debug, Clone, Default, PartialEq)]
|
||||
pub struct ToolsToml {
|
||||
#[serde(default, alias = "web_search_request")]
|
||||
pub web_search: Option<bool>,
|
||||
@@ -661,7 +765,11 @@ impl Config {
|
||||
tools_web_search_request: override_tools_web_search_request,
|
||||
} = overrides;
|
||||
|
||||
let config_profile = match config_profile_key.as_ref().or(cfg.profile.as_ref()) {
|
||||
let active_profile_name = config_profile_key
|
||||
.as_ref()
|
||||
.or(cfg.profile.as_ref())
|
||||
.cloned();
|
||||
let config_profile = match active_profile_name.as_ref() {
|
||||
Some(key) => cfg
|
||||
.profiles
|
||||
.get(key)
|
||||
@@ -773,10 +881,6 @@ impl Config {
|
||||
Self::get_base_instructions(experimental_instructions_path, &resolved_cwd)?;
|
||||
let base_instructions = base_instructions.or(file_base_instructions);
|
||||
|
||||
let responses_originator_header: String = cfg
|
||||
.responses_originator_header_internal_override
|
||||
.unwrap_or(DEFAULT_RESPONSES_ORIGINATOR_HEADER.to_owned());
|
||||
|
||||
let config = Self {
|
||||
model,
|
||||
model_family,
|
||||
@@ -800,7 +904,6 @@ impl Config {
|
||||
codex_home,
|
||||
history,
|
||||
file_opener: cfg.file_opener.unwrap_or(UriBasedFileOpener::VsCode),
|
||||
tui: cfg.tui.unwrap_or_default(),
|
||||
codex_linux_sandbox_exe,
|
||||
|
||||
hide_agent_reasoning: cfg.hide_agent_reasoning.unwrap_or(false),
|
||||
@@ -826,12 +929,14 @@ impl Config {
|
||||
include_plan_tool: include_plan_tool.unwrap_or(false),
|
||||
include_apply_patch_tool: include_apply_patch_tool.unwrap_or(false),
|
||||
tools_web_search_request,
|
||||
responses_originator_header,
|
||||
preferred_auth_method: cfg.preferred_auth_method.unwrap_or(AuthMode::ChatGPT),
|
||||
use_experimental_streamable_shell_tool: cfg
|
||||
.experimental_use_exec_command_tool
|
||||
.unwrap_or(false),
|
||||
use_experimental_unified_exec_tool: cfg
|
||||
.experimental_use_unified_exec_tool
|
||||
.unwrap_or(false),
|
||||
include_view_image_tool,
|
||||
active_profile: active_profile_name,
|
||||
disable_paste_burst: cfg.disable_paste_burst.unwrap_or(false),
|
||||
};
|
||||
Ok(config)
|
||||
@@ -942,6 +1047,7 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
@@ -1032,6 +1138,145 @@ exclude_slash_tmp = true
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn persist_model_selection_updates_defaults() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
persist_model_selection(
|
||||
codex_home.path(),
|
||||
None,
|
||||
"gpt-5-high-new",
|
||||
Some(ReasoningEffort::High),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let serialized =
|
||||
tokio::fs::read_to_string(codex_home.path().join(CONFIG_TOML_FILE)).await?;
|
||||
let parsed: ConfigToml = toml::from_str(&serialized)?;
|
||||
|
||||
assert_eq!(parsed.model.as_deref(), Some("gpt-5-high-new"));
|
||||
assert_eq!(parsed.model_reasoning_effort, Some(ReasoningEffort::High));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn persist_model_selection_overwrites_existing_model() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
|
||||
|
||||
tokio::fs::write(
|
||||
&config_path,
|
||||
r#"
|
||||
model = "gpt-5"
|
||||
model_reasoning_effort = "medium"
|
||||
|
||||
[profiles.dev]
|
||||
model = "gpt-4.1"
|
||||
"#,
|
||||
)
|
||||
.await?;
|
||||
|
||||
persist_model_selection(
|
||||
codex_home.path(),
|
||||
None,
|
||||
"o4-mini",
|
||||
Some(ReasoningEffort::High),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let serialized = tokio::fs::read_to_string(config_path).await?;
|
||||
let parsed: ConfigToml = toml::from_str(&serialized)?;
|
||||
|
||||
assert_eq!(parsed.model.as_deref(), Some("o4-mini"));
|
||||
assert_eq!(parsed.model_reasoning_effort, Some(ReasoningEffort::High));
|
||||
assert_eq!(
|
||||
parsed
|
||||
.profiles
|
||||
.get("dev")
|
||||
.and_then(|profile| profile.model.as_deref()),
|
||||
Some("gpt-4.1"),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn persist_model_selection_updates_profile() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
persist_model_selection(
|
||||
codex_home.path(),
|
||||
Some("dev"),
|
||||
"gpt-5-high-new",
|
||||
Some(ReasoningEffort::Low),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let serialized =
|
||||
tokio::fs::read_to_string(codex_home.path().join(CONFIG_TOML_FILE)).await?;
|
||||
let parsed: ConfigToml = toml::from_str(&serialized)?;
|
||||
let profile = parsed
|
||||
.profiles
|
||||
.get("dev")
|
||||
.expect("profile should be created");
|
||||
|
||||
assert_eq!(profile.model.as_deref(), Some("gpt-5-high-new"));
|
||||
assert_eq!(profile.model_reasoning_effort, Some(ReasoningEffort::Low));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn persist_model_selection_updates_existing_profile() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
|
||||
|
||||
tokio::fs::write(
|
||||
&config_path,
|
||||
r#"
|
||||
[profiles.dev]
|
||||
model = "gpt-4"
|
||||
model_reasoning_effort = "medium"
|
||||
|
||||
[profiles.prod]
|
||||
model = "gpt-5"
|
||||
"#,
|
||||
)
|
||||
.await?;
|
||||
|
||||
persist_model_selection(
|
||||
codex_home.path(),
|
||||
Some("dev"),
|
||||
"o4-high",
|
||||
Some(ReasoningEffort::Medium),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let serialized = tokio::fs::read_to_string(config_path).await?;
|
||||
let parsed: ConfigToml = toml::from_str(&serialized)?;
|
||||
|
||||
let dev_profile = parsed
|
||||
.profiles
|
||||
.get("dev")
|
||||
.expect("dev profile should survive updates");
|
||||
assert_eq!(dev_profile.model.as_deref(), Some("o4-high"));
|
||||
assert_eq!(
|
||||
dev_profile.model_reasoning_effort,
|
||||
Some(ReasoningEffort::Medium)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
parsed
|
||||
.profiles
|
||||
.get("prod")
|
||||
.and_then(|profile| profile.model.as_deref()),
|
||||
Some("gpt-5"),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct PrecedenceTestFixture {
|
||||
cwd: TempDir,
|
||||
codex_home: TempDir,
|
||||
@@ -1190,7 +1435,6 @@ model_verbosity = "high"
|
||||
codex_home: fixture.codex_home(),
|
||||
history: History::default(),
|
||||
file_opener: UriBasedFileOpener::VsCode,
|
||||
tui: Tui::default(),
|
||||
codex_linux_sandbox_exe: None,
|
||||
hide_agent_reasoning: false,
|
||||
show_raw_agent_reasoning: false,
|
||||
@@ -1203,10 +1447,10 @@ model_verbosity = "high"
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
tools_web_search_request: false,
|
||||
responses_originator_header: "codex_cli_rs".to_string(),
|
||||
preferred_auth_method: AuthMode::ChatGPT,
|
||||
use_experimental_streamable_shell_tool: false,
|
||||
use_experimental_unified_exec_tool: false,
|
||||
include_view_image_tool: true,
|
||||
active_profile: Some("o3".to_string()),
|
||||
disable_paste_burst: false,
|
||||
},
|
||||
o3_profile_config
|
||||
@@ -1247,7 +1491,6 @@ model_verbosity = "high"
|
||||
codex_home: fixture.codex_home(),
|
||||
history: History::default(),
|
||||
file_opener: UriBasedFileOpener::VsCode,
|
||||
tui: Tui::default(),
|
||||
codex_linux_sandbox_exe: None,
|
||||
hide_agent_reasoning: false,
|
||||
show_raw_agent_reasoning: false,
|
||||
@@ -1260,10 +1503,10 @@ model_verbosity = "high"
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
tools_web_search_request: false,
|
||||
responses_originator_header: "codex_cli_rs".to_string(),
|
||||
preferred_auth_method: AuthMode::ChatGPT,
|
||||
use_experimental_streamable_shell_tool: false,
|
||||
use_experimental_unified_exec_tool: false,
|
||||
include_view_image_tool: true,
|
||||
active_profile: Some("gpt3".to_string()),
|
||||
disable_paste_burst: false,
|
||||
};
|
||||
|
||||
@@ -1319,7 +1562,6 @@ model_verbosity = "high"
|
||||
codex_home: fixture.codex_home(),
|
||||
history: History::default(),
|
||||
file_opener: UriBasedFileOpener::VsCode,
|
||||
tui: Tui::default(),
|
||||
codex_linux_sandbox_exe: None,
|
||||
hide_agent_reasoning: false,
|
||||
show_raw_agent_reasoning: false,
|
||||
@@ -1332,10 +1574,10 @@ model_verbosity = "high"
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
tools_web_search_request: false,
|
||||
responses_originator_header: "codex_cli_rs".to_string(),
|
||||
preferred_auth_method: AuthMode::ChatGPT,
|
||||
use_experimental_streamable_shell_tool: false,
|
||||
use_experimental_unified_exec_tool: false,
|
||||
include_view_image_tool: true,
|
||||
active_profile: Some("zdr".to_string()),
|
||||
disable_paste_burst: false,
|
||||
};
|
||||
|
||||
@@ -1377,7 +1619,6 @@ model_verbosity = "high"
|
||||
codex_home: fixture.codex_home(),
|
||||
history: History::default(),
|
||||
file_opener: UriBasedFileOpener::VsCode,
|
||||
tui: Tui::default(),
|
||||
codex_linux_sandbox_exe: None,
|
||||
hide_agent_reasoning: false,
|
||||
show_raw_agent_reasoning: false,
|
||||
@@ -1390,10 +1631,10 @@ model_verbosity = "high"
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
tools_web_search_request: false,
|
||||
responses_originator_header: "codex_cli_rs".to_string(),
|
||||
preferred_auth_method: AuthMode::ChatGPT,
|
||||
use_experimental_streamable_shell_tool: false,
|
||||
use_experimental_unified_exec_tool: false,
|
||||
include_view_image_tool: true,
|
||||
active_profile: Some("gpt5".to_string()),
|
||||
disable_paste_burst: false,
|
||||
};
|
||||
|
||||
@@ -1404,17 +1645,14 @@ model_verbosity = "high"
|
||||
|
||||
#[test]
|
||||
fn test_set_project_trusted_writes_explicit_tables() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let project_dir = TempDir::new().unwrap();
|
||||
let project_dir = Path::new("/some/path");
|
||||
let mut doc = DocumentMut::new();
|
||||
|
||||
// Call the function under test
|
||||
set_project_trusted(codex_home.path(), project_dir.path())?;
|
||||
set_project_trusted_inner(&mut doc, project_dir)?;
|
||||
|
||||
// Read back the generated config.toml and assert exact contents
|
||||
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
|
||||
let contents = std::fs::read_to_string(&config_path)?;
|
||||
let contents = doc.to_string();
|
||||
|
||||
let raw_path = project_dir.path().to_string_lossy();
|
||||
let raw_path = project_dir.to_string_lossy();
|
||||
let path_str = if raw_path.contains('\\') {
|
||||
format!("'{raw_path}'")
|
||||
} else {
|
||||
@@ -1432,12 +1670,10 @@ trust_level = "trusted"
|
||||
|
||||
#[test]
|
||||
fn test_set_project_trusted_converts_inline_to_explicit() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let project_dir = TempDir::new().unwrap();
|
||||
let project_dir = Path::new("/some/path");
|
||||
|
||||
// Seed config.toml with an inline project entry under [projects]
|
||||
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
|
||||
let raw_path = project_dir.path().to_string_lossy();
|
||||
let raw_path = project_dir.to_string_lossy();
|
||||
let path_str = if raw_path.contains('\\') {
|
||||
format!("'{raw_path}'")
|
||||
} else {
|
||||
@@ -1449,13 +1685,12 @@ trust_level = "trusted"
|
||||
{path_str} = {{ trust_level = "untrusted" }}
|
||||
"#
|
||||
);
|
||||
std::fs::create_dir_all(codex_home.path())?;
|
||||
std::fs::write(&config_path, initial)?;
|
||||
let mut doc = initial.parse::<DocumentMut>()?;
|
||||
|
||||
// Run the function; it should convert to explicit tables and set trusted
|
||||
set_project_trusted(codex_home.path(), project_dir.path())?;
|
||||
set_project_trusted_inner(&mut doc, project_dir)?;
|
||||
|
||||
let contents = std::fs::read_to_string(&config_path)?;
|
||||
let contents = doc.to_string();
|
||||
|
||||
// Assert exact output after conversion to explicit table
|
||||
let expected = format!(
|
||||
@@ -1470,5 +1705,37 @@ trust_level = "trusted"
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// No test enforcing the presence of a standalone [projects] header.
|
||||
#[test]
|
||||
fn test_set_project_trusted_migrates_top_level_inline_projects_preserving_entries()
|
||||
-> anyhow::Result<()> {
|
||||
let initial = r#"toplevel = "baz"
|
||||
projects = { "/Users/mbolin/code/codex4" = { trust_level = "trusted", foo = "bar" } , "/Users/mbolin/code/codex3" = { trust_level = "trusted" } }
|
||||
model = "foo""#;
|
||||
let mut doc = initial.parse::<DocumentMut>()?;
|
||||
|
||||
// Approve a new directory
|
||||
let new_project = Path::new("/Users/mbolin/code/codex2");
|
||||
set_project_trusted_inner(&mut doc, new_project)?;
|
||||
|
||||
let contents = doc.to_string();
|
||||
|
||||
// Since we created the [projects] table as part of migration, it is kept implicit.
|
||||
// Expect explicit per-project tables, preserving prior entries and appending the new one.
|
||||
let expected = r#"toplevel = "baz"
|
||||
model = "foo"
|
||||
|
||||
[projects."/Users/mbolin/code/codex4"]
|
||||
trust_level = "trusted"
|
||||
foo = "bar"
|
||||
|
||||
[projects."/Users/mbolin/code/codex3"]
|
||||
trust_level = "trusted"
|
||||
|
||||
[projects."/Users/mbolin/code/codex2"]
|
||||
trust_level = "trusted"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
582
codex-rs/core/src/config_edit.rs
Normal file
582
codex-rs/core/src/config_edit.rs
Normal file
@@ -0,0 +1,582 @@
|
||||
use crate::config::CONFIG_TOML_FILE;
|
||||
use anyhow::Result;
|
||||
use std::path::Path;
|
||||
use tempfile::NamedTempFile;
|
||||
use toml_edit::DocumentMut;
|
||||
|
||||
pub const CONFIG_KEY_MODEL: &str = "model";
|
||||
pub const CONFIG_KEY_EFFORT: &str = "model_reasoning_effort";
|
||||
|
||||
/// Persist overrides into `config.toml` using explicit key segments per
|
||||
/// override. This avoids ambiguity with keys that contain dots or spaces.
|
||||
pub async fn persist_overrides(
|
||||
codex_home: &Path,
|
||||
profile: Option<&str>,
|
||||
overrides: &[(&[&str], &str)],
|
||||
) -> Result<()> {
|
||||
let config_path = codex_home.join(CONFIG_TOML_FILE);
|
||||
|
||||
let mut doc = match tokio::fs::read_to_string(&config_path).await {
|
||||
Ok(s) => s.parse::<DocumentMut>()?,
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
|
||||
tokio::fs::create_dir_all(codex_home).await?;
|
||||
DocumentMut::new()
|
||||
}
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
|
||||
let effective_profile = if let Some(p) = profile {
|
||||
Some(p.to_owned())
|
||||
} else {
|
||||
doc.get("profile")
|
||||
.and_then(|i| i.as_str())
|
||||
.map(|s| s.to_string())
|
||||
};
|
||||
|
||||
for (segments, val) in overrides.iter().copied() {
|
||||
let value = toml_edit::value(val);
|
||||
if let Some(ref name) = effective_profile {
|
||||
if segments.first().copied() == Some("profiles") {
|
||||
apply_toml_edit_override_segments(&mut doc, segments, value);
|
||||
} else {
|
||||
let mut seg_buf: Vec<&str> = Vec::with_capacity(2 + segments.len());
|
||||
seg_buf.push("profiles");
|
||||
seg_buf.push(name.as_str());
|
||||
seg_buf.extend_from_slice(segments);
|
||||
apply_toml_edit_override_segments(&mut doc, &seg_buf, value);
|
||||
}
|
||||
} else {
|
||||
apply_toml_edit_override_segments(&mut doc, segments, value);
|
||||
}
|
||||
}
|
||||
|
||||
let tmp_file = NamedTempFile::new_in(codex_home)?;
|
||||
tokio::fs::write(tmp_file.path(), doc.to_string()).await?;
|
||||
tmp_file.persist(config_path)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Persist overrides where values may be optional. Any entries with `None`
|
||||
/// values are skipped. If all values are `None`, this becomes a no-op and
|
||||
/// returns `Ok(())` without touching the file.
|
||||
pub async fn persist_non_null_overrides(
|
||||
codex_home: &Path,
|
||||
profile: Option<&str>,
|
||||
overrides: &[(&[&str], Option<&str>)],
|
||||
) -> Result<()> {
|
||||
let filtered: Vec<(&[&str], &str)> = overrides
|
||||
.iter()
|
||||
.filter_map(|(k, v)| v.map(|vv| (*k, vv)))
|
||||
.collect();
|
||||
|
||||
if filtered.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
persist_overrides(codex_home, profile, &filtered).await
|
||||
}
|
||||
|
||||
/// Apply a single override onto a `toml_edit` document while preserving
|
||||
/// existing formatting/comments.
|
||||
/// The key is expressed as explicit segments to correctly handle keys that
|
||||
/// contain dots or spaces.
|
||||
fn apply_toml_edit_override_segments(
|
||||
doc: &mut DocumentMut,
|
||||
segments: &[&str],
|
||||
value: toml_edit::Item,
|
||||
) {
|
||||
use toml_edit::Item;
|
||||
|
||||
if segments.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut current = doc.as_table_mut();
|
||||
for seg in &segments[..segments.len() - 1] {
|
||||
if !current.contains_key(seg) {
|
||||
current[*seg] = Item::Table(toml_edit::Table::new());
|
||||
if let Some(t) = current[*seg].as_table_mut() {
|
||||
t.set_implicit(true);
|
||||
}
|
||||
}
|
||||
|
||||
let maybe_item = current.get_mut(seg);
|
||||
let Some(item) = maybe_item else { return };
|
||||
|
||||
if !item.is_table() {
|
||||
*item = Item::Table(toml_edit::Table::new());
|
||||
if let Some(t) = item.as_table_mut() {
|
||||
t.set_implicit(true);
|
||||
}
|
||||
}
|
||||
|
||||
let Some(tbl) = item.as_table_mut() else {
|
||||
return;
|
||||
};
|
||||
current = tbl;
|
||||
}
|
||||
|
||||
let last = segments[segments.len() - 1];
|
||||
current[last] = value;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::tempdir;
|
||||
|
||||
/// Verifies model and effort are written at top-level when no profile is set.
|
||||
#[tokio::test]
|
||||
async fn set_default_model_and_effort_top_level_when_no_profile() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], "gpt-5"),
|
||||
(&[CONFIG_KEY_EFFORT], "high"),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"model = "gpt-5"
|
||||
model_reasoning_effort = "high"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies values are written under the active profile when `profile` is set.
|
||||
#[tokio::test]
|
||||
async fn set_defaults_update_profile_when_profile_set() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed config with a profile selection but without profiles table
|
||||
let seed = "profile = \"o3\"\n";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], "o3"),
|
||||
(&[CONFIG_KEY_EFFORT], "minimal"),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"profile = "o3"
|
||||
|
||||
[profiles.o3]
|
||||
model = "o3"
|
||||
model_reasoning_effort = "minimal"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies profile names with dots/spaces are preserved via explicit segments.
|
||||
#[tokio::test]
|
||||
async fn set_defaults_update_profile_with_dot_and_space() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed config with a profile name that contains a dot and a space
|
||||
let seed = "profile = \"my.team name\"\n";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], "o3"),
|
||||
(&[CONFIG_KEY_EFFORT], "minimal"),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"profile = "my.team name"
|
||||
|
||||
[profiles."my.team name"]
|
||||
model = "o3"
|
||||
model_reasoning_effort = "minimal"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies explicit profile override writes under that profile even without active profile.
|
||||
#[tokio::test]
|
||||
async fn set_defaults_update_when_profile_override_supplied() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// No profile key in config.toml
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), "")
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Persist with an explicit profile override
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
Some("o3"),
|
||||
&[(&[CONFIG_KEY_MODEL], "o3"), (&[CONFIG_KEY_EFFORT], "high")],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"[profiles.o3]
|
||||
model = "o3"
|
||||
model_reasoning_effort = "high"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies nested tables are created as needed when applying overrides.
|
||||
#[tokio::test]
|
||||
async fn persist_overrides_creates_nested_tables() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&["a", "b", "c"], "v"),
|
||||
(&["x"], "y"),
|
||||
(&["profiles", "p1", CONFIG_KEY_MODEL], "gpt-5"),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"x = "y"
|
||||
|
||||
[a.b]
|
||||
c = "v"
|
||||
|
||||
[profiles.p1]
|
||||
model = "gpt-5"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies a scalar key becomes a table when nested keys are written.
|
||||
#[tokio::test]
|
||||
async fn persist_overrides_replaces_scalar_with_table() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
let seed = "foo = \"bar\"\n";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides(codex_home, None, &[(&["foo", "bar", "baz"], "ok")])
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"[foo.bar]
|
||||
baz = "ok"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies comments and spacing are preserved when writing under active profile.
|
||||
#[tokio::test]
|
||||
async fn set_defaults_preserve_comments() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed a config with comments and spacing we expect to preserve
|
||||
let seed = r#"# Global comment
|
||||
# Another line
|
||||
|
||||
profile = "o3"
|
||||
|
||||
# Profile settings
|
||||
[profiles.o3]
|
||||
# keep me
|
||||
existing = "keep"
|
||||
"#;
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Apply defaults; since profile is set, it should write under [profiles.o3]
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[(&[CONFIG_KEY_MODEL], "o3"), (&[CONFIG_KEY_EFFORT], "high")],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"# Global comment
|
||||
# Another line
|
||||
|
||||
profile = "o3"
|
||||
|
||||
# Profile settings
|
||||
[profiles.o3]
|
||||
# keep me
|
||||
existing = "keep"
|
||||
model = "o3"
|
||||
model_reasoning_effort = "high"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies comments and spacing are preserved when writing at top level.
|
||||
#[tokio::test]
|
||||
async fn set_defaults_preserve_global_comments() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed a config WITHOUT a profile, containing comments and spacing
|
||||
let seed = r#"# Top-level comments
|
||||
# should be preserved
|
||||
|
||||
existing = "keep"
|
||||
"#;
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Since there is no profile, the defaults should be written at top-level
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], "gpt-5"),
|
||||
(&[CONFIG_KEY_EFFORT], "minimal"),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"# Top-level comments
|
||||
# should be preserved
|
||||
|
||||
existing = "keep"
|
||||
model = "gpt-5"
|
||||
model_reasoning_effort = "minimal"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies errors on invalid TOML propagate and file is not clobbered.
|
||||
#[tokio::test]
|
||||
async fn persist_overrides_errors_on_parse_failure() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Write an intentionally invalid TOML file
|
||||
let invalid = "invalid = [unclosed";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), invalid)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Attempting to persist should return an error and must not clobber the file.
|
||||
let res = persist_overrides(codex_home, None, &[(&["x"], "y")]).await;
|
||||
assert!(res.is_err(), "expected parse error to propagate");
|
||||
|
||||
// File should be unchanged
|
||||
let contents = read_config(codex_home).await;
|
||||
assert_eq!(contents, invalid);
|
||||
}
|
||||
|
||||
/// Verifies changing model only preserves existing effort at top-level.
|
||||
#[tokio::test]
|
||||
async fn changing_only_model_preserves_existing_effort_top_level() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed with an effort value only
|
||||
let seed = "model_reasoning_effort = \"minimal\"\n";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Change only the model
|
||||
persist_overrides(codex_home, None, &[(&[CONFIG_KEY_MODEL], "o3")])
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"model_reasoning_effort = "minimal"
|
||||
model = "o3"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies changing effort only preserves existing model at top-level.
|
||||
#[tokio::test]
|
||||
async fn changing_only_effort_preserves_existing_model_top_level() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed with a model value only
|
||||
let seed = "model = \"gpt-5\"\n";
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
// Change only the effort
|
||||
persist_overrides(codex_home, None, &[(&[CONFIG_KEY_EFFORT], "high")])
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"model = "gpt-5"
|
||||
model_reasoning_effort = "high"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies changing model only preserves existing effort in active profile.
|
||||
#[tokio::test]
|
||||
async fn changing_only_model_preserves_effort_in_active_profile() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// Seed with an active profile and an existing effort under that profile
|
||||
let seed = r#"profile = "p1"
|
||||
|
||||
[profiles.p1]
|
||||
model_reasoning_effort = "low"
|
||||
"#;
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides(codex_home, None, &[(&[CONFIG_KEY_MODEL], "o4-mini")])
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"profile = "p1"
|
||||
|
||||
[profiles.p1]
|
||||
model_reasoning_effort = "low"
|
||||
model = "o4-mini"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies changing effort only preserves existing model in a profile override.
|
||||
#[tokio::test]
|
||||
async fn changing_only_effort_preserves_model_in_profile_override() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
// No active profile key; we'll target an explicit override
|
||||
let seed = r#"[profiles.team]
|
||||
model = "gpt-5"
|
||||
"#;
|
||||
tokio::fs::write(codex_home.join(CONFIG_TOML_FILE), seed)
|
||||
.await
|
||||
.expect("seed write");
|
||||
|
||||
persist_overrides(
|
||||
codex_home,
|
||||
Some("team"),
|
||||
&[(&[CONFIG_KEY_EFFORT], "minimal")],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"[profiles.team]
|
||||
model = "gpt-5"
|
||||
model_reasoning_effort = "minimal"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies `persist_non_null_overrides` skips `None` entries and writes only present values at top-level.
|
||||
#[tokio::test]
|
||||
async fn persist_non_null_skips_none_top_level() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
persist_non_null_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], Some("gpt-5")),
|
||||
(&[CONFIG_KEY_EFFORT], None),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = "model = \"gpt-5\"\n";
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
/// Verifies no-op behavior when all provided overrides are `None` (no file created/modified).
|
||||
#[tokio::test]
|
||||
async fn persist_non_null_noop_when_all_none() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
persist_non_null_overrides(
|
||||
codex_home,
|
||||
None,
|
||||
&[(&["a"], None), (&["profiles", "p", "x"], None)],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
// Should not create config.toml on a pure no-op
|
||||
assert!(!codex_home.join(CONFIG_TOML_FILE).exists());
|
||||
}
|
||||
|
||||
/// Verifies entries are written under the specified profile and `None` entries are skipped.
|
||||
#[tokio::test]
|
||||
async fn persist_non_null_respects_profile_override() {
|
||||
let tmpdir = tempdir().expect("tmp");
|
||||
let codex_home = tmpdir.path();
|
||||
|
||||
persist_non_null_overrides(
|
||||
codex_home,
|
||||
Some("team"),
|
||||
&[
|
||||
(&[CONFIG_KEY_MODEL], Some("o3")),
|
||||
(&[CONFIG_KEY_EFFORT], None),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.expect("persist");
|
||||
|
||||
let contents = read_config(codex_home).await;
|
||||
let expected = r#"[profiles.team]
|
||||
model = "o3"
|
||||
"#;
|
||||
assert_eq!(contents, expected);
|
||||
}
|
||||
|
||||
// Test helper moved to bottom per review guidance.
|
||||
async fn read_config(codex_home: &Path) -> String {
|
||||
let p = codex_home.join(CONFIG_TOML_FILE);
|
||||
tokio::fs::read_to_string(p).await.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
|
||||
/// Transcript of conversation history
|
||||
@@ -59,6 +60,26 @@ impl ConversationHistory {
|
||||
kept.reverse();
|
||||
self.items = kept;
|
||||
}
|
||||
|
||||
pub(crate) fn last_agent_message(&self) -> String {
|
||||
for item in self.items.iter().rev() {
|
||||
if let ResponseItem::Message { role, content, .. } = item
|
||||
&& role == "assistant"
|
||||
{
|
||||
return content
|
||||
.iter()
|
||||
.find_map(|ci| {
|
||||
if let ContentItem::OutputText { text } = ci {
|
||||
Some(text.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
}
|
||||
}
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Anything that is not a system message or "reasoning" message is considered
|
||||
|
||||
@@ -10,65 +10,16 @@ use crate::error::Result as CodexResult;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::SessionConfiguredEvent;
|
||||
use crate::rollout::RolloutItem;
|
||||
use crate::rollout::RolloutRecorder;
|
||||
use crate::rollout::recorder::RolloutItemSliceExt;
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::InitialHistory;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ResumedHistory {
|
||||
pub conversation_id: ConversationId,
|
||||
pub history: Vec<RolloutItem>,
|
||||
pub rollout_path: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum InitialHistory {
|
||||
New,
|
||||
Resumed(ResumedHistory),
|
||||
Forked(Vec<ResponseItem>),
|
||||
}
|
||||
|
||||
impl PartialEq for InitialHistory {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
match (self, other) {
|
||||
(InitialHistory::New, InitialHistory::New) => true,
|
||||
(InitialHistory::Forked(a), InitialHistory::Forked(b)) => a == b,
|
||||
(InitialHistory::Resumed(_), InitialHistory::Resumed(_)) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl InitialHistory {
|
||||
/// Return all response items contained in this initial history.
|
||||
pub fn get_response_items(&self) -> Vec<ResponseItem> {
|
||||
match self {
|
||||
InitialHistory::New => Vec::new(),
|
||||
InitialHistory::Forked(_) => Vec::new(),
|
||||
InitialHistory::Resumed(items) => {
|
||||
<[_] as RolloutItemSliceExt>::get_response_items(items.history.as_slice())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return all events contained in this initial history.
|
||||
pub fn get_events(&self) -> Vec<crate::protocol::EventMsg> {
|
||||
match self {
|
||||
InitialHistory::New => Vec::new(),
|
||||
InitialHistory::Forked(_) => Vec::new(),
|
||||
InitialHistory::Resumed(items) => {
|
||||
<[_] as RolloutItemSliceExt>::get_events(items.history.as_slice())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a newly created Codex conversation, including the first event
|
||||
/// (which is [`EventMsg::SessionConfigured`]).
|
||||
pub struct NewConversation {
|
||||
@@ -182,8 +133,15 @@ impl ConversationManager {
|
||||
self.finalize_spawn(codex, conversation_id).await
|
||||
}
|
||||
|
||||
pub async fn remove_conversation(&self, conversation_id: ConversationId) {
|
||||
self.conversations.write().await.remove(&conversation_id);
|
||||
/// Removes the conversation from the manager's internal map, though the
|
||||
/// conversation is stored as `Arc<CodexConversation>`, it is possible that
|
||||
/// other references to it exist elsewhere. Returns the conversation if the
|
||||
/// conversation was found and removed.
|
||||
pub async fn remove_conversation(
|
||||
&self,
|
||||
conversation_id: &ConversationId,
|
||||
) -> Option<Arc<CodexConversation>> {
|
||||
self.conversations.write().await.remove(conversation_id)
|
||||
}
|
||||
|
||||
/// Fork an existing conversation by dropping the last `drop_last_messages`
|
||||
@@ -192,74 +150,58 @@ impl ConversationManager {
|
||||
/// caller's `config`). The new conversation will have a fresh id.
|
||||
pub async fn fork_conversation(
|
||||
&self,
|
||||
base_rollout_path: PathBuf,
|
||||
_base_conversation_id: ConversationId,
|
||||
num_messages_to_drop: usize,
|
||||
config: Config,
|
||||
path: PathBuf,
|
||||
) -> CodexResult<NewConversation> {
|
||||
// Read prior responses from the rollout file (tolerate both tagged and legacy formats).
|
||||
let text = tokio::fs::read_to_string(&base_rollout_path)
|
||||
.await
|
||||
.map_err(|e| CodexErr::Io(std::io::Error::other(format!("read rollout: {e}"))))?;
|
||||
let mut responses: Vec<ResponseItem> = Vec::new();
|
||||
for line in text.lines() {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let v: serde_json::Value = match serde_json::from_str(line) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
// Only consider response items (legacy lines have no record_type)
|
||||
match v.get("record_type").and_then(|s| s.as_str()) {
|
||||
Some("response") | None => {
|
||||
if let Ok(item) = serde_json::from_value::<ResponseItem>(v) {
|
||||
responses.push(item);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let kept = truncate_after_dropping_last_messages(responses, num_messages_to_drop);
|
||||
// Compute the prefix up to the cut point.
|
||||
let history = RolloutRecorder::get_rollout_history(&path).await?;
|
||||
let history = truncate_after_dropping_last_messages(history, num_messages_to_drop);
|
||||
|
||||
// Spawn a new conversation with the computed initial history.
|
||||
let auth_manager = self.auth_manager.clone();
|
||||
let CodexSpawnOk {
|
||||
codex,
|
||||
conversation_id,
|
||||
} = Codex::spawn(config, auth_manager, kept).await?;
|
||||
} = Codex::spawn(config, auth_manager, history).await?;
|
||||
|
||||
self.finalize_spawn(codex, conversation_id).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a prefix of `items` obtained by dropping the last `n` user messages
|
||||
/// and all items that follow them.
|
||||
fn truncate_after_dropping_last_messages(items: Vec<ResponseItem>, n: usize) -> InitialHistory {
|
||||
fn truncate_after_dropping_last_messages(history: InitialHistory, n: usize) -> InitialHistory {
|
||||
if n == 0 {
|
||||
return InitialHistory::Forked(items);
|
||||
return InitialHistory::Forked(history.get_rollout_items());
|
||||
}
|
||||
|
||||
// Walk backwards counting only `user` Message items, find cut index.
|
||||
let mut count = 0usize;
|
||||
let mut cut_index = 0usize;
|
||||
for (idx, item) in items.iter().enumerate().rev() {
|
||||
if let ResponseItem::Message { role, .. } = item
|
||||
// Work directly on rollout items, and cut the vector at the nth-from-last user message input.
|
||||
let items: Vec<RolloutItem> = history.get_rollout_items();
|
||||
|
||||
// Find indices of user message inputs in rollout order.
|
||||
let mut user_positions: Vec<usize> = Vec::new();
|
||||
for (idx, item) in items.iter().enumerate() {
|
||||
if let RolloutItem::ResponseItem(ResponseItem::Message { role, .. }) = item
|
||||
&& role == "user"
|
||||
{
|
||||
count += 1;
|
||||
if count == n {
|
||||
// Cut everything from this user message to the end.
|
||||
cut_index = idx;
|
||||
break;
|
||||
}
|
||||
user_positions.push(idx);
|
||||
}
|
||||
}
|
||||
if cut_index == 0 {
|
||||
// No prefix remains after dropping; start a new conversation.
|
||||
|
||||
// If fewer than n user messages exist, treat as empty.
|
||||
if user_positions.len() < n {
|
||||
return InitialHistory::New;
|
||||
}
|
||||
|
||||
// Cut strictly before the nth-from-last user message (do not keep the nth itself).
|
||||
let cut_idx = user_positions[user_positions.len() - n];
|
||||
let rolled: Vec<RolloutItem> = items.into_iter().take(cut_idx).collect();
|
||||
|
||||
if rolled.is_empty() {
|
||||
InitialHistory::New
|
||||
} else {
|
||||
InitialHistory::Forked(items.into_iter().take(cut_index).collect())
|
||||
InitialHistory::Forked(rolled)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -314,13 +256,30 @@ mod tests {
|
||||
assistant_msg("a4"),
|
||||
];
|
||||
|
||||
let truncated = truncate_after_dropping_last_messages(items.clone(), 1);
|
||||
// Wrap as InitialHistory::Forked with response items only.
|
||||
let initial: Vec<RolloutItem> = items
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(RolloutItem::ResponseItem)
|
||||
.collect();
|
||||
let truncated = truncate_after_dropping_last_messages(InitialHistory::Forked(initial), 1);
|
||||
let got_items = truncated.get_rollout_items();
|
||||
let expected_items = vec![
|
||||
RolloutItem::ResponseItem(items[0].clone()),
|
||||
RolloutItem::ResponseItem(items[1].clone()),
|
||||
RolloutItem::ResponseItem(items[2].clone()),
|
||||
];
|
||||
assert_eq!(
|
||||
truncated,
|
||||
InitialHistory::Forked(vec![items[0].clone(), items[1].clone(), items[2].clone(),])
|
||||
serde_json::to_value(&got_items).unwrap(),
|
||||
serde_json::to_value(&expected_items).unwrap()
|
||||
);
|
||||
|
||||
let truncated2 = truncate_after_dropping_last_messages(items, 2);
|
||||
let initial2: Vec<RolloutItem> = items
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(RolloutItem::ResponseItem)
|
||||
.collect();
|
||||
let truncated2 = truncate_after_dropping_last_messages(InitialHistory::Forked(initial2), 2);
|
||||
assert!(matches!(truncated2, InitialHistory::New));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,38 +1,123 @@
|
||||
pub const DEFAULT_ORIGINATOR: &str = "codex_cli_rs";
|
||||
use reqwest::header::HeaderValue;
|
||||
use std::sync::LazyLock;
|
||||
use std::sync::Mutex;
|
||||
|
||||
pub fn get_codex_user_agent(originator: Option<&str>) -> String {
|
||||
/// Set this to add a suffix to the User-Agent string.
|
||||
///
|
||||
/// It is not ideal that we're using a global singleton for this.
|
||||
/// This is primarily designed to differentiate MCP clients from each other.
|
||||
/// Because there can only be one MCP server per process, it should be safe for this to be a global static.
|
||||
/// However, future users of this should use this with caution as a result.
|
||||
/// In addition, we want to be confident that this value is used for ALL clients and doing that requires a
|
||||
/// lot of wiring and it's easy to miss code paths by doing so.
|
||||
/// See https://github.com/openai/codex/pull/3388/files for an example of what that would look like.
|
||||
/// Finally, we want to make sure this is set for ALL mcp clients without needing to know a special env var
|
||||
/// or having to set data that they already specified in the mcp initialize request somewhere else.
|
||||
///
|
||||
/// A space is automatically added between the suffix and the rest of the User-Agent string.
|
||||
/// The full user agent string is returned from the mcp initialize response.
|
||||
/// Parenthesis will be added by Codex. This should only specify what goes inside of the parenthesis.
|
||||
pub static USER_AGENT_SUFFIX: LazyLock<Mutex<Option<String>>> = LazyLock::new(|| Mutex::new(None));
|
||||
|
||||
pub const CODEX_INTERNAL_ORIGINATOR_OVERRIDE_ENV_VAR: &str = "CODEX_INTERNAL_ORIGINATOR_OVERRIDE";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Originator {
|
||||
pub value: String,
|
||||
pub header_value: HeaderValue,
|
||||
}
|
||||
|
||||
pub static ORIGINATOR: LazyLock<Originator> = LazyLock::new(|| {
|
||||
let default = "codex_cli_rs";
|
||||
let value = std::env::var(CODEX_INTERNAL_ORIGINATOR_OVERRIDE_ENV_VAR)
|
||||
.unwrap_or_else(|_| default.to_string());
|
||||
|
||||
match HeaderValue::from_str(&value) {
|
||||
Ok(header_value) => Originator {
|
||||
value,
|
||||
header_value,
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!("Unable to turn originator override {value} into header value: {e}");
|
||||
Originator {
|
||||
value: default.to_string(),
|
||||
header_value: HeaderValue::from_static(default),
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
pub fn get_codex_user_agent() -> String {
|
||||
let build_version = env!("CARGO_PKG_VERSION");
|
||||
let os_info = os_info::get();
|
||||
format!(
|
||||
let prefix = format!(
|
||||
"{}/{build_version} ({} {}; {}) {}",
|
||||
originator.unwrap_or(DEFAULT_ORIGINATOR),
|
||||
ORIGINATOR.value.as_str(),
|
||||
os_info.os_type(),
|
||||
os_info.version(),
|
||||
os_info.architecture().unwrap_or("unknown"),
|
||||
crate::terminal::user_agent()
|
||||
)
|
||||
);
|
||||
let suffix = USER_AGENT_SUFFIX
|
||||
.lock()
|
||||
.ok()
|
||||
.and_then(|guard| guard.clone());
|
||||
let suffix = suffix
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.map_or_else(String::new, |value| format!(" ({value})"));
|
||||
|
||||
let candidate = format!("{prefix}{suffix}");
|
||||
sanitize_user_agent(candidate, &prefix)
|
||||
}
|
||||
|
||||
/// Sanitize the user agent string.
|
||||
///
|
||||
/// Invalid characters are replaced with an underscore.
|
||||
///
|
||||
/// If the user agent fails to parse, it falls back to fallback and then to ORIGINATOR.
|
||||
fn sanitize_user_agent(candidate: String, fallback: &str) -> String {
|
||||
if HeaderValue::from_str(candidate.as_str()).is_ok() {
|
||||
return candidate;
|
||||
}
|
||||
|
||||
let sanitized: String = candidate
|
||||
.chars()
|
||||
.map(|ch| if matches!(ch, ' '..='~') { ch } else { '_' })
|
||||
.collect();
|
||||
if !sanitized.is_empty() && HeaderValue::from_str(sanitized.as_str()).is_ok() {
|
||||
tracing::warn!(
|
||||
"Sanitized Codex user agent because provided suffix contained invalid header characters"
|
||||
);
|
||||
sanitized
|
||||
} else if HeaderValue::from_str(fallback).is_ok() {
|
||||
tracing::warn!(
|
||||
"Falling back to base Codex user agent because provided suffix could not be sanitized"
|
||||
);
|
||||
fallback.to_string()
|
||||
} else {
|
||||
tracing::warn!(
|
||||
"Falling back to default Codex originator because base user agent string is invalid"
|
||||
);
|
||||
ORIGINATOR.value.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a reqwest client with default `originator` and `User-Agent` headers set.
|
||||
pub fn create_client(originator: &str) -> reqwest::Client {
|
||||
pub fn create_client() -> reqwest::Client {
|
||||
use reqwest::header::HeaderMap;
|
||||
use reqwest::header::HeaderValue;
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
let originator_value = HeaderValue::from_str(originator)
|
||||
.unwrap_or_else(|_| HeaderValue::from_static(DEFAULT_ORIGINATOR));
|
||||
headers.insert("originator", originator_value);
|
||||
let ua = get_codex_user_agent(Some(originator));
|
||||
headers.insert("originator", ORIGINATOR.header_value.clone());
|
||||
let ua = get_codex_user_agent();
|
||||
|
||||
match reqwest::Client::builder()
|
||||
reqwest::Client::builder()
|
||||
// Set UA via dedicated helper to avoid header validation pitfalls
|
||||
.user_agent(ua)
|
||||
.default_headers(headers)
|
||||
.build()
|
||||
{
|
||||
Ok(client) => client,
|
||||
Err(_) => reqwest::Client::new(),
|
||||
}
|
||||
.unwrap_or_else(|_| reqwest::Client::new())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -41,7 +126,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_get_codex_user_agent() {
|
||||
let user_agent = get_codex_user_agent(None);
|
||||
let user_agent = get_codex_user_agent();
|
||||
assert!(user_agent.starts_with("codex_cli_rs/"));
|
||||
}
|
||||
|
||||
@@ -53,8 +138,7 @@ mod tests {
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
let originator = "test_originator";
|
||||
let client = create_client(originator);
|
||||
let client = create_client();
|
||||
|
||||
// Spin up a local mock server and capture a request.
|
||||
let server = MockServer::start().await;
|
||||
@@ -82,21 +166,43 @@ mod tests {
|
||||
let originator_header = headers
|
||||
.get("originator")
|
||||
.expect("originator header missing");
|
||||
assert_eq!(originator_header.to_str().unwrap(), originator);
|
||||
assert_eq!(originator_header.to_str().unwrap(), "codex_cli_rs");
|
||||
|
||||
// User-Agent matches the computed Codex UA for that originator
|
||||
let expected_ua = get_codex_user_agent(Some(originator));
|
||||
let expected_ua = get_codex_user_agent();
|
||||
let ua_header = headers
|
||||
.get("user-agent")
|
||||
.expect("user-agent header missing");
|
||||
assert_eq!(ua_header.to_str().unwrap(), expected_ua);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_suffix_is_sanitized() {
|
||||
let prefix = "codex_cli_rs/0.0.0";
|
||||
let suffix = "bad\rsuffix";
|
||||
|
||||
assert_eq!(
|
||||
sanitize_user_agent(format!("{prefix} ({suffix})"), prefix),
|
||||
"codex_cli_rs/0.0.0 (bad_suffix)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_suffix_is_sanitized2() {
|
||||
let prefix = "codex_cli_rs/0.0.0";
|
||||
let suffix = "bad\0suffix";
|
||||
|
||||
assert_eq!(
|
||||
sanitize_user_agent(format!("{prefix} ({suffix})"), prefix),
|
||||
"codex_cli_rs/0.0.0 (bad_suffix)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(target_os = "macos")]
|
||||
fn test_macos() {
|
||||
use regex_lite::Regex;
|
||||
let user_agent = get_codex_user_agent(None);
|
||||
let user_agent = get_codex_user_agent();
|
||||
let re = Regex::new(
|
||||
r"^codex_cli_rs/\d+\.\d+\.\d+ \(Mac OS \d+\.\d+\.\d+; (x86_64|arm64)\) (\S+)$",
|
||||
)
|
||||
|
||||
@@ -26,6 +26,7 @@ pub(crate) struct EnvironmentContext {
|
||||
pub approval_policy: Option<AskForApproval>,
|
||||
pub sandbox_mode: Option<SandboxMode>,
|
||||
pub network_access: Option<NetworkAccess>,
|
||||
pub writable_roots: Option<Vec<PathBuf>>,
|
||||
pub shell: Option<Shell>,
|
||||
}
|
||||
|
||||
@@ -57,6 +58,16 @@ impl EnvironmentContext {
|
||||
}
|
||||
None => None,
|
||||
},
|
||||
writable_roots: match sandbox_policy {
|
||||
Some(SandboxPolicy::WorkspaceWrite { writable_roots, .. }) => {
|
||||
if writable_roots.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(writable_roots)
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
shell,
|
||||
}
|
||||
}
|
||||
@@ -72,6 +83,7 @@ impl EnvironmentContext {
|
||||
/// <cwd>...</cwd>
|
||||
/// <approval_policy>...</approval_policy>
|
||||
/// <sandbox_mode>...</sandbox_mode>
|
||||
/// <writable_roots>...</writable_roots>
|
||||
/// <network_access>...</network_access>
|
||||
/// <shell>...</shell>
|
||||
/// </environment_context>
|
||||
@@ -94,6 +106,16 @@ impl EnvironmentContext {
|
||||
" <network_access>{network_access}</network_access>"
|
||||
));
|
||||
}
|
||||
if let Some(writable_roots) = self.writable_roots {
|
||||
lines.push(" <writable_roots>".to_string());
|
||||
for writable_root in writable_roots {
|
||||
lines.push(format!(
|
||||
" <root>{}</root>",
|
||||
writable_root.to_string_lossy()
|
||||
));
|
||||
}
|
||||
lines.push(" </writable_roots>".to_string());
|
||||
}
|
||||
if let Some(shell) = self.shell
|
||||
&& let Some(shell_name) = shell.name()
|
||||
{
|
||||
@@ -115,3 +137,77 @@ impl From<EnvironmentContext> for ResponseItem {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
fn workspace_write_policy(writable_roots: Vec<&str>, network_access: bool) -> SandboxPolicy {
|
||||
SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: writable_roots.into_iter().map(PathBuf::from).collect(),
|
||||
network_access,
|
||||
exclude_tmpdir_env_var: false,
|
||||
exclude_slash_tmp: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialize_workspace_write_environment_context() {
|
||||
let context = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(workspace_write_policy(vec!["/repo", "/tmp"], false)),
|
||||
None,
|
||||
);
|
||||
|
||||
let expected = r#"<environment_context>
|
||||
<cwd>/repo</cwd>
|
||||
<approval_policy>on-request</approval_policy>
|
||||
<sandbox_mode>workspace-write</sandbox_mode>
|
||||
<network_access>restricted</network_access>
|
||||
<writable_roots>
|
||||
<root>/repo</root>
|
||||
<root>/tmp</root>
|
||||
</writable_roots>
|
||||
</environment_context>"#;
|
||||
|
||||
assert_eq!(context.serialize_to_xml(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialize_read_only_environment_context() {
|
||||
let context = EnvironmentContext::new(
|
||||
None,
|
||||
Some(AskForApproval::Never),
|
||||
Some(SandboxPolicy::ReadOnly),
|
||||
None,
|
||||
);
|
||||
|
||||
let expected = r#"<environment_context>
|
||||
<approval_policy>never</approval_policy>
|
||||
<sandbox_mode>read-only</sandbox_mode>
|
||||
<network_access>restricted</network_access>
|
||||
</environment_context>"#;
|
||||
|
||||
assert_eq!(context.serialize_to_xml(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialize_full_access_environment_context() {
|
||||
let context = EnvironmentContext::new(
|
||||
None,
|
||||
Some(AskForApproval::OnFailure),
|
||||
Some(SandboxPolicy::DangerFullAccess),
|
||||
None,
|
||||
);
|
||||
|
||||
let expected = r#"<environment_context>
|
||||
<approval_policy>on-failure</approval_policy>
|
||||
<sandbox_mode>danger-full-access</sandbox_mode>
|
||||
<network_access>enabled</network_access>
|
||||
</environment_context>"#;
|
||||
|
||||
assert_eq!(context.serialize_to_xml(), expected);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use crate::token_data::KnownPlan;
|
||||
use crate::token_data::PlanType;
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
use reqwest::StatusCode;
|
||||
use serde_json;
|
||||
@@ -127,38 +129,58 @@ pub enum CodexErr {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UsageLimitReachedError {
|
||||
pub plan_type: Option<String>,
|
||||
pub resets_in_seconds: Option<u64>,
|
||||
pub(crate) plan_type: Option<PlanType>,
|
||||
pub(crate) resets_in_seconds: Option<u64>,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for UsageLimitReachedError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// Base message differs slightly for legacy ChatGPT Plus plan users.
|
||||
if let Some(plan_type) = &self.plan_type
|
||||
&& plan_type == "plus"
|
||||
{
|
||||
write!(
|
||||
f,
|
||||
"You've hit your usage limit. Upgrade to Pro (https://openai.com/chatgpt/pricing) or try again"
|
||||
)?;
|
||||
if let Some(secs) = self.resets_in_seconds {
|
||||
let reset_duration = format_reset_duration(secs);
|
||||
write!(f, " in {reset_duration}.")?;
|
||||
} else {
|
||||
write!(f, " later.")?;
|
||||
let message = match self.plan_type.as_ref() {
|
||||
Some(PlanType::Known(KnownPlan::Plus)) => format!(
|
||||
"You've hit your usage limit. Upgrade to Pro (https://openai.com/chatgpt/pricing){}",
|
||||
retry_suffix_after_or(self.resets_in_seconds)
|
||||
),
|
||||
Some(PlanType::Known(KnownPlan::Team)) | Some(PlanType::Known(KnownPlan::Business)) => {
|
||||
format!(
|
||||
"You've hit your usage limit. To get more access now, send a request to your admin{}",
|
||||
retry_suffix_after_or(self.resets_in_seconds)
|
||||
)
|
||||
}
|
||||
} else {
|
||||
write!(f, "You've hit your usage limit.")?;
|
||||
|
||||
if let Some(secs) = self.resets_in_seconds {
|
||||
let reset_duration = format_reset_duration(secs);
|
||||
write!(f, " Try again in {reset_duration}.")?;
|
||||
} else {
|
||||
write!(f, " Try again later.")?;
|
||||
Some(PlanType::Known(KnownPlan::Free)) => {
|
||||
"To use Codex with your ChatGPT plan, upgrade to Plus: https://openai.com/chatgpt/pricing."
|
||||
.to_string()
|
||||
}
|
||||
}
|
||||
Some(PlanType::Known(KnownPlan::Pro))
|
||||
| Some(PlanType::Known(KnownPlan::Enterprise))
|
||||
| Some(PlanType::Known(KnownPlan::Edu)) => format!(
|
||||
"You've hit your usage limit.{}",
|
||||
retry_suffix(self.resets_in_seconds)
|
||||
),
|
||||
Some(PlanType::Unknown(_)) | None => format!(
|
||||
"You've hit your usage limit.{}",
|
||||
retry_suffix(self.resets_in_seconds)
|
||||
),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
write!(f, "{message}")
|
||||
}
|
||||
}
|
||||
|
||||
fn retry_suffix(resets_in_seconds: Option<u64>) -> String {
|
||||
if let Some(secs) = resets_in_seconds {
|
||||
let reset_duration = format_reset_duration(secs);
|
||||
format!(" Try again in {reset_duration}.")
|
||||
} else {
|
||||
" Try again later.".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn retry_suffix_after_or(resets_in_seconds: Option<u64>) -> String {
|
||||
if let Some(secs) = resets_in_seconds {
|
||||
let reset_duration = format_reset_duration(secs);
|
||||
format!(" or try again in {reset_duration}.")
|
||||
} else {
|
||||
" or try again later.".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -237,7 +259,7 @@ mod tests {
|
||||
#[test]
|
||||
fn usage_limit_reached_error_formats_plus_plan() {
|
||||
let err = UsageLimitReachedError {
|
||||
plan_type: Some("plus".to_string()),
|
||||
plan_type: Some(PlanType::Known(KnownPlan::Plus)),
|
||||
resets_in_seconds: None,
|
||||
};
|
||||
assert_eq!(
|
||||
@@ -246,6 +268,18 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn usage_limit_reached_error_formats_free_plan() {
|
||||
let err = UsageLimitReachedError {
|
||||
plan_type: Some(PlanType::Known(KnownPlan::Free)),
|
||||
resets_in_seconds: Some(3600),
|
||||
};
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"To use Codex with your ChatGPT plan, upgrade to Plus: https://openai.com/chatgpt/pricing."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn usage_limit_reached_error_formats_default_when_none() {
|
||||
let err = UsageLimitReachedError {
|
||||
@@ -258,10 +292,34 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn usage_limit_reached_error_formats_team_plan() {
|
||||
let err = UsageLimitReachedError {
|
||||
plan_type: Some(PlanType::Known(KnownPlan::Team)),
|
||||
resets_in_seconds: Some(3600),
|
||||
};
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"You've hit your usage limit. To get more access now, send a request to your admin or try again in 1 hour."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn usage_limit_reached_error_formats_business_plan_without_reset() {
|
||||
let err = UsageLimitReachedError {
|
||||
plan_type: Some(PlanType::Known(KnownPlan::Business)),
|
||||
resets_in_seconds: None,
|
||||
};
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"You've hit your usage limit. To get more access now, send a request to your admin or try again later."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn usage_limit_reached_error_formats_default_for_other_plans() {
|
||||
let err = UsageLimitReachedError {
|
||||
plan_type: Some("pro".to_string()),
|
||||
plan_type: Some(PlanType::Known(KnownPlan::Pro)),
|
||||
resets_in_seconds: None,
|
||||
};
|
||||
assert_eq!(
|
||||
@@ -285,7 +343,7 @@ mod tests {
|
||||
#[test]
|
||||
fn usage_limit_reached_includes_hours_and_minutes() {
|
||||
let err = UsageLimitReachedError {
|
||||
plan_type: Some("plus".to_string()),
|
||||
plan_type: Some(PlanType::Known(KnownPlan::Plus)),
|
||||
resets_in_seconds: Some(3 * 3600 + 32 * 60),
|
||||
};
|
||||
assert_eq!(
|
||||
|
||||
@@ -25,31 +25,56 @@ pub(crate) fn map_response_item_to_event_messages(
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let events: Vec<EventMsg> = content
|
||||
.iter()
|
||||
.filter_map(|content_item| match content_item {
|
||||
ContentItem::OutputText { text } => {
|
||||
Some(EventMsg::AgentMessage(AgentMessageEvent {
|
||||
message: text.clone(),
|
||||
}))
|
||||
}
|
||||
let mut events: Vec<EventMsg> = Vec::new();
|
||||
let mut message_parts: Vec<String> = Vec::new();
|
||||
let mut images: Vec<String> = Vec::new();
|
||||
let mut kind: Option<InputMessageKind> = None;
|
||||
|
||||
for content_item in content.iter() {
|
||||
match content_item {
|
||||
ContentItem::InputText { text } => {
|
||||
let trimmed = text.trim_start();
|
||||
let kind = if trimmed.starts_with("<environment_context>") {
|
||||
Some(InputMessageKind::EnvironmentContext)
|
||||
} else if trimmed.starts_with("<user_instructions>") {
|
||||
Some(InputMessageKind::UserInstructions)
|
||||
} else {
|
||||
Some(InputMessageKind::Plain)
|
||||
};
|
||||
Some(EventMsg::UserMessage(UserMessageEvent {
|
||||
message: text.clone(),
|
||||
kind,
|
||||
}))
|
||||
if kind.is_none() {
|
||||
let trimmed = text.trim_start();
|
||||
kind = if trimmed.starts_with("<environment_context>") {
|
||||
Some(InputMessageKind::EnvironmentContext)
|
||||
} else if trimmed.starts_with("<user_instructions>") {
|
||||
Some(InputMessageKind::UserInstructions)
|
||||
} else {
|
||||
Some(InputMessageKind::Plain)
|
||||
};
|
||||
}
|
||||
message_parts.push(text.clone());
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
ContentItem::InputImage { image_url } => {
|
||||
images.push(image_url.clone());
|
||||
}
|
||||
ContentItem::OutputText { text } => {
|
||||
events.push(EventMsg::AgentMessage(AgentMessageEvent {
|
||||
message: text.clone(),
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !message_parts.is_empty() || !images.is_empty() {
|
||||
let message = if message_parts.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
message_parts.join("")
|
||||
};
|
||||
let images = if images.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(images)
|
||||
};
|
||||
|
||||
events.push(EventMsg::UserMessage(UserMessageEvent {
|
||||
message,
|
||||
kind,
|
||||
images,
|
||||
}));
|
||||
}
|
||||
|
||||
events
|
||||
}
|
||||
|
||||
@@ -96,3 +121,47 @@ pub(crate) fn map_response_item_to_event_messages(
|
||||
| ResponseItem::Other => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::map_response_item_to_event_messages;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::InputMessageKind;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn maps_user_message_with_text_and_two_images() {
|
||||
let img1 = "https://example.com/one.png".to_string();
|
||||
let img2 = "https://example.com/two.jpg".to_string();
|
||||
|
||||
let item = ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![
|
||||
ContentItem::InputText {
|
||||
text: "Hello world".to_string(),
|
||||
},
|
||||
ContentItem::InputImage {
|
||||
image_url: img1.clone(),
|
||||
},
|
||||
ContentItem::InputImage {
|
||||
image_url: img2.clone(),
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let events = map_response_item_to_event_messages(&item, false);
|
||||
assert_eq!(events.len(), 1, "expected a single user message event");
|
||||
|
||||
match &events[0] {
|
||||
EventMsg::UserMessage(user) => {
|
||||
assert_eq!(user.message, "Hello world");
|
||||
assert!(matches!(user.kind, Some(InputMessageKind::Plain)));
|
||||
assert_eq!(user.images, Some(vec![img1, img2]));
|
||||
}
|
||||
other => panic!("expected UserMessage, got {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,6 +24,9 @@ pub(crate) struct ExecCommandSession {
|
||||
|
||||
/// JoinHandle for the child wait task.
|
||||
wait_handle: StdMutex<Option<JoinHandle<()>>>,
|
||||
|
||||
/// Tracks whether the underlying process has exited.
|
||||
exit_status: std::sync::Arc<std::sync::atomic::AtomicBool>,
|
||||
}
|
||||
|
||||
impl ExecCommandSession {
|
||||
@@ -34,6 +37,7 @@ impl ExecCommandSession {
|
||||
reader_handle: JoinHandle<()>,
|
||||
writer_handle: JoinHandle<()>,
|
||||
wait_handle: JoinHandle<()>,
|
||||
exit_status: std::sync::Arc<std::sync::atomic::AtomicBool>,
|
||||
) -> Self {
|
||||
Self {
|
||||
writer_tx,
|
||||
@@ -42,6 +46,7 @@ impl ExecCommandSession {
|
||||
reader_handle: StdMutex::new(Some(reader_handle)),
|
||||
writer_handle: StdMutex::new(Some(writer_handle)),
|
||||
wait_handle: StdMutex::new(Some(wait_handle)),
|
||||
exit_status,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,6 +57,10 @@ impl ExecCommandSession {
|
||||
pub(crate) fn output_receiver(&self) -> broadcast::Receiver<Vec<u8>> {
|
||||
self.output_tx.subscribe()
|
||||
}
|
||||
|
||||
pub(crate) fn has_exited(&self) -> bool {
|
||||
self.exit_status.load(std::sync::atomic::Ordering::SeqCst)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ExecCommandSession {
|
||||
|
||||
@@ -6,6 +6,7 @@ mod session_manager;
|
||||
|
||||
pub use exec_command_params::ExecCommandParams;
|
||||
pub use exec_command_params::WriteStdinParams;
|
||||
pub(crate) use exec_command_session::ExecCommandSession;
|
||||
pub use responses_api::EXEC_COMMAND_TOOL_NAME;
|
||||
pub use responses_api::WRITE_STDIN_TOOL_NAME;
|
||||
pub use responses_api::create_exec_command_tool_for_responses_api;
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::io::ErrorKind;
|
||||
use std::io::Read;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::AtomicU32;
|
||||
|
||||
use portable_pty::CommandBuilder;
|
||||
@@ -19,6 +20,7 @@ use crate::exec_command::exec_command_params::ExecCommandParams;
|
||||
use crate::exec_command::exec_command_params::WriteStdinParams;
|
||||
use crate::exec_command::exec_command_session::ExecCommandSession;
|
||||
use crate::exec_command::session_id::SessionId;
|
||||
use crate::truncate::truncate_middle;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
@@ -327,11 +329,14 @@ async fn create_exec_command_session(
|
||||
|
||||
// Keep the child alive until it exits, then signal exit code.
|
||||
let (exit_tx, exit_rx) = oneshot::channel::<i32>();
|
||||
let exit_status = Arc::new(AtomicBool::new(false));
|
||||
let wait_exit_status = exit_status.clone();
|
||||
let wait_handle = tokio::task::spawn_blocking(move || {
|
||||
let code = match child.wait() {
|
||||
Ok(status) => status.exit_code() as i32,
|
||||
Err(_) => -1,
|
||||
};
|
||||
wait_exit_status.store(true, std::sync::atomic::Ordering::SeqCst);
|
||||
let _ = exit_tx.send(code);
|
||||
});
|
||||
|
||||
@@ -343,116 +348,11 @@ async fn create_exec_command_session(
|
||||
reader_handle,
|
||||
writer_handle,
|
||||
wait_handle,
|
||||
exit_status,
|
||||
);
|
||||
Ok((session, exit_rx))
|
||||
}
|
||||
|
||||
/// Truncate the middle of a UTF-8 string to at most `max_bytes` bytes,
|
||||
/// preserving the beginning and the end. Returns the possibly truncated
|
||||
/// string and `Some(original_token_count)` (estimated at 4 bytes/token)
|
||||
/// if truncation occurred; otherwise returns the original string and `None`.
|
||||
fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>) {
|
||||
// No truncation needed
|
||||
if s.len() <= max_bytes {
|
||||
return (s.to_string(), None);
|
||||
}
|
||||
let est_tokens = (s.len() as u64).div_ceil(4);
|
||||
if max_bytes == 0 {
|
||||
// Cannot keep any content; still return a full marker (never truncated).
|
||||
return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
|
||||
}
|
||||
|
||||
// Helper to truncate a string to a given byte length on a char boundary.
|
||||
fn truncate_on_boundary(input: &str, max_len: usize) -> &str {
|
||||
if input.len() <= max_len {
|
||||
return input;
|
||||
}
|
||||
let mut end = max_len;
|
||||
while end > 0 && !input.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
&input[..end]
|
||||
}
|
||||
|
||||
// Given a left/right budget, prefer newline boundaries; otherwise fall back
|
||||
// to UTF-8 char boundaries.
|
||||
fn pick_prefix_end(s: &str, left_budget: usize) -> usize {
|
||||
if let Some(head) = s.get(..left_budget)
|
||||
&& let Some(i) = head.rfind('\n')
|
||||
{
|
||||
return i + 1; // keep the newline so suffix starts on a fresh line
|
||||
}
|
||||
truncate_on_boundary(s, left_budget).len()
|
||||
}
|
||||
|
||||
fn pick_suffix_start(s: &str, right_budget: usize) -> usize {
|
||||
let start_tail = s.len().saturating_sub(right_budget);
|
||||
if let Some(tail) = s.get(start_tail..)
|
||||
&& let Some(i) = tail.find('\n')
|
||||
{
|
||||
return start_tail + i + 1; // start after newline
|
||||
}
|
||||
// Fall back to a char boundary at or after start_tail.
|
||||
let mut idx = start_tail.min(s.len());
|
||||
while idx < s.len() && !s.is_char_boundary(idx) {
|
||||
idx += 1;
|
||||
}
|
||||
idx
|
||||
}
|
||||
|
||||
// Refine marker length and budgets until stable. Marker is never truncated.
|
||||
let mut guess_tokens = est_tokens; // worst-case: everything truncated
|
||||
for _ in 0..4 {
|
||||
let marker = format!("…{guess_tokens} tokens truncated…");
|
||||
let marker_len = marker.len();
|
||||
let keep_budget = max_bytes.saturating_sub(marker_len);
|
||||
if keep_budget == 0 {
|
||||
// No room for any content within the cap; return a full, untruncated marker
|
||||
// that reflects the entire truncated content.
|
||||
return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
|
||||
}
|
||||
|
||||
let left_budget = keep_budget / 2;
|
||||
let right_budget = keep_budget - left_budget;
|
||||
let prefix_end = pick_prefix_end(s, left_budget);
|
||||
let mut suffix_start = pick_suffix_start(s, right_budget);
|
||||
if suffix_start < prefix_end {
|
||||
suffix_start = prefix_end;
|
||||
}
|
||||
let kept_content_bytes = prefix_end + (s.len() - suffix_start);
|
||||
let truncated_content_bytes = s.len().saturating_sub(kept_content_bytes);
|
||||
let new_tokens = (truncated_content_bytes as u64).div_ceil(4);
|
||||
if new_tokens == guess_tokens {
|
||||
let mut out = String::with_capacity(marker_len + kept_content_bytes + 1);
|
||||
out.push_str(&s[..prefix_end]);
|
||||
out.push_str(&marker);
|
||||
// Place marker on its own line for symmetry when we keep line boundaries.
|
||||
out.push('\n');
|
||||
out.push_str(&s[suffix_start..]);
|
||||
return (out, Some(est_tokens));
|
||||
}
|
||||
guess_tokens = new_tokens;
|
||||
}
|
||||
|
||||
// Fallback: use last guess to build output.
|
||||
let marker = format!("…{guess_tokens} tokens truncated…");
|
||||
let marker_len = marker.len();
|
||||
let keep_budget = max_bytes.saturating_sub(marker_len);
|
||||
if keep_budget == 0 {
|
||||
return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
|
||||
}
|
||||
let left_budget = keep_budget / 2;
|
||||
let right_budget = keep_budget - left_budget;
|
||||
let prefix_end = pick_prefix_end(s, left_budget);
|
||||
let suffix_start = pick_suffix_start(s, right_budget);
|
||||
let mut out = String::with_capacity(marker_len + prefix_end + (s.len() - suffix_start) + 1);
|
||||
out.push_str(&s[..prefix_end]);
|
||||
out.push_str(&marker);
|
||||
out.push('\n');
|
||||
out.push_str(&s[suffix_start..]);
|
||||
(out, Some(est_tokens))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -616,50 +516,4 @@ Output:
|
||||
abc"#;
|
||||
assert_eq!(expected, text);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_middle_no_newlines_fallback() {
|
||||
// A long string with no newlines that exceeds the cap.
|
||||
let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
|
||||
let max_bytes = 16; // force truncation
|
||||
let (out, original) = truncate_middle(s, max_bytes);
|
||||
// For very small caps, we return the full, untruncated marker,
|
||||
// even if it exceeds the cap.
|
||||
assert_eq!(out, "…16 tokens truncated…");
|
||||
// Original string length is 62 bytes => ceil(62/4) = 16 tokens.
|
||||
assert_eq!(original, Some(16));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_middle_prefers_newline_boundaries() {
|
||||
// Build a multi-line string of 20 numbered lines (each "NNN\n").
|
||||
let mut s = String::new();
|
||||
for i in 1..=20 {
|
||||
s.push_str(&format!("{i:03}\n"));
|
||||
}
|
||||
// Total length: 20 lines * 4 bytes per line = 80 bytes.
|
||||
assert_eq!(s.len(), 80);
|
||||
|
||||
// Choose a cap that forces truncation while leaving room for
|
||||
// a few lines on each side after accounting for the marker.
|
||||
let max_bytes = 64;
|
||||
// Expect exact output: first 4 lines, marker, last 4 lines, and correct token estimate (80/4 = 20).
|
||||
assert_eq!(
|
||||
truncate_middle(&s, max_bytes),
|
||||
(
|
||||
r#"001
|
||||
002
|
||||
003
|
||||
004
|
||||
…12 tokens truncated…
|
||||
017
|
||||
018
|
||||
019
|
||||
020
|
||||
"#
|
||||
.to_string(),
|
||||
Some(20)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codex_protocol::mcp_protocol::GitSha;
|
||||
use codex_protocol::protocol::GitInfo;
|
||||
use futures::future::join_all;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
@@ -43,19 +44,6 @@ pub fn get_git_repo_root(base_dir: &Path) -> Option<PathBuf> {
|
||||
/// Timeout for git commands to prevent freezing on large repositories
|
||||
const GIT_COMMAND_TIMEOUT: TokioDuration = TokioDuration::from_secs(5);
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct GitInfo {
|
||||
/// Current commit hash (SHA)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub commit_hash: Option<String>,
|
||||
/// Current branch name
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub branch: Option<String>,
|
||||
/// Repository URL (if available from remote)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub repository_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct GitDiffToRemote {
|
||||
pub sha: GitSha,
|
||||
@@ -814,7 +802,7 @@ mod tests {
|
||||
async fn resolve_root_git_project_for_trust_regular_repo_returns_repo_root() {
|
||||
let temp_dir = TempDir::new().expect("Failed to create temp dir");
|
||||
let repo_path = create_test_git_repo(&temp_dir).await;
|
||||
let expected = std::fs::canonicalize(&repo_path).unwrap().to_path_buf();
|
||||
let expected = std::fs::canonicalize(&repo_path).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
resolve_root_git_project_for_trust(&repo_path),
|
||||
@@ -822,10 +810,7 @@ mod tests {
|
||||
);
|
||||
let nested = repo_path.join("sub/dir");
|
||||
std::fs::create_dir_all(&nested).unwrap();
|
||||
assert_eq!(
|
||||
resolve_root_git_project_for_trust(&nested),
|
||||
Some(expected.clone())
|
||||
);
|
||||
assert_eq!(resolve_root_git_project_for_trust(&nested), Some(expected));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
68
codex-rs/core/src/internal_storage.rs
Normal file
68
codex-rs/core/src/internal_storage.rs
Normal file
@@ -0,0 +1,68 @@
|
||||
use anyhow::Context;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
pub(crate) const INTERNAL_STORAGE_FILE: &str = "internal_storage.json";
|
||||
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
|
||||
pub struct InternalStorage {
|
||||
#[serde(skip)]
|
||||
storage_path: PathBuf,
|
||||
#[serde(default)]
|
||||
pub gpt_5_high_model_prompt_seen: bool,
|
||||
}
|
||||
|
||||
// TODO(jif) generalise all the file writers and build proper async channel inserters.
|
||||
impl InternalStorage {
|
||||
pub fn load(codex_home: &Path) -> Self {
|
||||
let storage_path = codex_home.join(INTERNAL_STORAGE_FILE);
|
||||
|
||||
match std::fs::read_to_string(&storage_path) {
|
||||
Ok(serialized) => match serde_json::from_str::<Self>(&serialized) {
|
||||
Ok(mut storage) => {
|
||||
storage.storage_path = storage_path;
|
||||
storage
|
||||
}
|
||||
Err(error) => {
|
||||
tracing::warn!("failed to parse internal storage: {error:?}");
|
||||
Self::empty(storage_path)
|
||||
}
|
||||
},
|
||||
Err(error) => {
|
||||
tracing::warn!("failed to read internal storage: {error:?}");
|
||||
Self::empty(storage_path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn empty(storage_path: PathBuf) -> Self {
|
||||
Self {
|
||||
storage_path,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn persist(&self) -> anyhow::Result<()> {
|
||||
let serialized = serde_json::to_string_pretty(self)?;
|
||||
|
||||
if let Some(parent) = self.storage_path.parent() {
|
||||
tokio::fs::create_dir_all(parent).await.with_context(|| {
|
||||
format!(
|
||||
"failed to create internal storage directory at {}",
|
||||
parent.display()
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
tokio::fs::write(&self.storage_path, serialized)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to persist internal storage at {}",
|
||||
self.storage_path.display()
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,7 @@ mod codex_conversation;
|
||||
pub mod token_data;
|
||||
pub use codex_conversation::CodexConversation;
|
||||
pub mod config;
|
||||
pub mod config_edit;
|
||||
pub mod config_profile;
|
||||
pub mod config_types;
|
||||
mod conversation_history;
|
||||
@@ -27,6 +28,7 @@ mod exec_command;
|
||||
pub mod exec_env;
|
||||
mod flags;
|
||||
pub mod git_info;
|
||||
pub mod internal_storage;
|
||||
mod is_safe_command;
|
||||
pub mod landlock;
|
||||
mod mcp_connection_manager;
|
||||
@@ -34,6 +36,8 @@ mod mcp_tool_call;
|
||||
mod message_history;
|
||||
mod model_provider_info;
|
||||
pub mod parse_command;
|
||||
mod truncate;
|
||||
mod unified_exec;
|
||||
mod user_instructions;
|
||||
pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID;
|
||||
pub use model_provider_info::ModelProviderInfo;
|
||||
@@ -42,6 +46,7 @@ pub use model_provider_info::built_in_model_providers;
|
||||
pub use model_provider_info::create_oss_provider_with_base_url;
|
||||
mod conversation_manager;
|
||||
mod event_mapping;
|
||||
pub use codex_protocol::protocol::InitialHistory;
|
||||
pub use conversation_manager::ConversationManager;
|
||||
pub use conversation_manager::NewConversation;
|
||||
// Re-export common auth types for workspace consumers
|
||||
@@ -61,13 +66,16 @@ pub mod spawn;
|
||||
pub mod terminal;
|
||||
mod tool_apply_patch;
|
||||
pub mod turn_diff_tracker;
|
||||
pub use rollout::ARCHIVED_SESSIONS_SUBDIR;
|
||||
pub use rollout::RolloutRecorder;
|
||||
pub use rollout::SESSIONS_SUBDIR;
|
||||
pub use rollout::SessionMeta;
|
||||
pub use rollout::list::ConversationItem;
|
||||
pub use rollout::list::ConversationsPage;
|
||||
pub use rollout::list::Cursor;
|
||||
mod user_notification;
|
||||
pub mod util;
|
||||
|
||||
pub use apply_patch::CODEX_APPLY_PATCH_ARG1;
|
||||
pub use safety::get_platform_sandbox;
|
||||
// Re-export the protocol types from the standalone `codex-protocol` crate so existing
|
||||
|
||||
@@ -163,6 +163,10 @@ impl McpConnectionManager {
|
||||
name: "codex-mcp-client".to_owned(),
|
||||
version: env!("CARGO_PKG_VERSION").to_owned(),
|
||||
title: Some("Codex".into()),
|
||||
// This field is used by Codex when it is an MCP
|
||||
// server: it should not be used when Codex is
|
||||
// an MCP client.
|
||||
user_agent: None,
|
||||
},
|
||||
protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(),
|
||||
};
|
||||
|
||||
@@ -103,7 +103,7 @@ pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
|
||||
slug, "gpt-4.1",
|
||||
needs_special_apply_patch_instructions: true,
|
||||
)
|
||||
} else if slug.starts_with("gpt-oss") {
|
||||
} else if slug.starts_with("gpt-oss") || slug.starts_with("openai/gpt-oss") {
|
||||
model_family!(slug, "gpt-oss", apply_patch_tool_type: Some(ApplyPatchToolType::Function))
|
||||
} else if slug.starts_with("gpt-4o") {
|
||||
simple_model_family!(slug, "gpt-4o")
|
||||
|
||||
@@ -80,7 +80,10 @@ pub struct ModelProviderInfo {
|
||||
/// the connection as lost.
|
||||
pub stream_idle_timeout_ms: Option<u64>,
|
||||
|
||||
/// Whether this provider requires some form of standard authentication (API key, ChatGPT token).
|
||||
/// Does this provider require an OpenAI API Key or ChatGPT login token? If true,
|
||||
/// user is presented with login screen on first run, and login preference and token/key
|
||||
/// are stored in auth.json. If false (which is the default), login screen is skipped,
|
||||
/// and API key (if needed) comes from the "env_key" environment variable.
|
||||
#[serde(default)]
|
||||
pub requires_openai_auth: bool,
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ pub(crate) fn get_model_info(model_family: &ModelFamily) -> Option<ModelInfo> {
|
||||
max_output_tokens: 4_096,
|
||||
}),
|
||||
|
||||
"gpt-5" => Some(ModelInfo {
|
||||
_ if slug.starts_with("gpt-5") => Some(ModelInfo {
|
||||
context_window: 272_000,
|
||||
max_output_tokens: 128_000,
|
||||
}),
|
||||
|
||||
@@ -8,7 +8,6 @@ use std::collections::HashMap;
|
||||
use crate::model_family::ModelFamily;
|
||||
use crate::plan_tool::PLAN_TOOL;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use crate::tool_apply_patch::ApplyPatchToolType;
|
||||
use crate::tool_apply_patch::create_apply_patch_freeform_tool;
|
||||
use crate::tool_apply_patch::create_apply_patch_json_tool;
|
||||
@@ -58,7 +57,7 @@ pub(crate) enum OpenAiTool {
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ConfigShellToolType {
|
||||
DefaultShell,
|
||||
ShellWithRequest { sandbox_policy: SandboxPolicy },
|
||||
ShellWithRequest,
|
||||
LocalShell,
|
||||
StreamableShell,
|
||||
}
|
||||
@@ -70,17 +69,18 @@ pub(crate) struct ToolsConfig {
|
||||
pub apply_patch_tool_type: Option<ApplyPatchToolType>,
|
||||
pub web_search_request: bool,
|
||||
pub include_view_image_tool: bool,
|
||||
pub experimental_unified_exec_tool: bool,
|
||||
}
|
||||
|
||||
pub(crate) struct ToolsConfigParams<'a> {
|
||||
pub(crate) model_family: &'a ModelFamily,
|
||||
pub(crate) approval_policy: AskForApproval,
|
||||
pub(crate) sandbox_policy: SandboxPolicy,
|
||||
pub(crate) include_plan_tool: bool,
|
||||
pub(crate) include_apply_patch_tool: bool,
|
||||
pub(crate) include_web_search_request: bool,
|
||||
pub(crate) use_streamable_shell_tool: bool,
|
||||
pub(crate) include_view_image_tool: bool,
|
||||
pub(crate) experimental_unified_exec_tool: bool,
|
||||
}
|
||||
|
||||
impl ToolsConfig {
|
||||
@@ -88,12 +88,12 @@ impl ToolsConfig {
|
||||
let ToolsConfigParams {
|
||||
model_family,
|
||||
approval_policy,
|
||||
sandbox_policy,
|
||||
include_plan_tool,
|
||||
include_apply_patch_tool,
|
||||
include_web_search_request,
|
||||
use_streamable_shell_tool,
|
||||
include_view_image_tool,
|
||||
experimental_unified_exec_tool,
|
||||
} = params;
|
||||
let mut shell_type = if *use_streamable_shell_tool {
|
||||
ConfigShellToolType::StreamableShell
|
||||
@@ -103,9 +103,7 @@ impl ToolsConfig {
|
||||
ConfigShellToolType::DefaultShell
|
||||
};
|
||||
if matches!(approval_policy, AskForApproval::OnRequest) && !use_streamable_shell_tool {
|
||||
shell_type = ConfigShellToolType::ShellWithRequest {
|
||||
sandbox_policy: sandbox_policy.clone(),
|
||||
}
|
||||
shell_type = ConfigShellToolType::ShellWithRequest;
|
||||
}
|
||||
|
||||
let apply_patch_tool_type = match model_family.apply_patch_tool_type {
|
||||
@@ -126,6 +124,7 @@ impl ToolsConfig {
|
||||
apply_patch_tool_type,
|
||||
web_search_request: *include_web_search_request,
|
||||
include_view_image_tool: *include_view_image_tool,
|
||||
experimental_unified_exec_tool: *experimental_unified_exec_tool,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -200,7 +199,56 @@ fn create_shell_tool() -> OpenAiTool {
|
||||
})
|
||||
}
|
||||
|
||||
fn create_shell_tool_for_sandbox(sandbox_policy: &SandboxPolicy) -> OpenAiTool {
|
||||
fn create_unified_exec_tool() -> OpenAiTool {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
"input".to_string(),
|
||||
JsonSchema::Array {
|
||||
items: Box::new(JsonSchema::String { description: None }),
|
||||
description: Some(
|
||||
"When no session_id is provided, treat the array as the command and arguments \
|
||||
to launch. When session_id is set, concatenate the strings (in order) and write \
|
||||
them to the session's stdin."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"session_id".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some(
|
||||
"Identifier for an existing interactive session. If omitted, a new command \
|
||||
is spawned."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"timeout_ms".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some(
|
||||
"Maximum time in milliseconds to wait for output after writing the input."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
|
||||
OpenAiTool::Function(ResponsesApiTool {
|
||||
name: "unified_exec".to_string(),
|
||||
description:
|
||||
"Runs a command in a PTY. Provide a session_id to reuse an existing interactive session.".to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
required: Some(vec!["input".to_string()]),
|
||||
additional_properties: Some(false),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
const SHELL_TOOL_DESCRIPTION: &str = r#"Runs a shell command and returns its output"#;
|
||||
|
||||
fn create_shell_tool_for_request() -> OpenAiTool {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
"command".to_string(),
|
||||
@@ -212,82 +260,29 @@ fn create_shell_tool_for_sandbox(sandbox_policy: &SandboxPolicy) -> OpenAiTool {
|
||||
properties.insert(
|
||||
"workdir".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some("The working directory to execute the command in".to_string()),
|
||||
description: Some("Working directory to execute the command in.".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"timeout_ms".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some("The timeout for the command in milliseconds".to_string()),
|
||||
description: Some("Timeout for the command in milliseconds.".to_string()),
|
||||
},
|
||||
);
|
||||
|
||||
if matches!(sandbox_policy, SandboxPolicy::WorkspaceWrite { .. }) {
|
||||
properties.insert(
|
||||
properties.insert(
|
||||
"with_escalated_permissions".to_string(),
|
||||
JsonSchema::Boolean {
|
||||
description: Some("Whether to request escalated permissions. Set to true if command needs to be run without sandbox restrictions".to_string()),
|
||||
description: Some("Request escalated permissions, only for when a command would otherwise be blocked by the sandbox.".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
properties.insert(
|
||||
"justification".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some("Only set if with_escalated_permissions is true. 1-sentence explanation of why we want to run this command.".to_string()),
|
||||
description: Some("Required if and only if with_escalated_permissions == true. One sentence explaining why escalation is needed (e.g., write outside CWD, network fetch, git commit).".to_string()),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let description = match sandbox_policy {
|
||||
SandboxPolicy::WorkspaceWrite {
|
||||
network_access,
|
||||
writable_roots,
|
||||
..
|
||||
} => {
|
||||
format!(
|
||||
r#"
|
||||
The shell tool is used to execute shell commands.
|
||||
- When invoking the shell tool, your call will be running in a sandbox, and some shell commands will require escalated privileges:
|
||||
- Types of actions that require escalated privileges:
|
||||
- Writing files other than those in the writable roots
|
||||
- writable roots:
|
||||
{}{}
|
||||
- Examples of commands that require escalated privileges:
|
||||
- git commit
|
||||
- npm install or pnpm install
|
||||
- cargo build
|
||||
- cargo test
|
||||
- When invoking a command that will require escalated privileges:
|
||||
- Provide the with_escalated_permissions parameter with the boolean value true
|
||||
- Include a short, 1 sentence explanation for why we need to run with_escalated_permissions in the justification parameter."#,
|
||||
writable_roots.iter().map(|wr| format!(" - {}", wr.to_string_lossy())).collect::<Vec<String>>().join("\n"),
|
||||
if !network_access {
|
||||
"\n - Commands that require network access\n"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
)
|
||||
}
|
||||
SandboxPolicy::DangerFullAccess => {
|
||||
"Runs a shell command and returns its output.".to_string()
|
||||
}
|
||||
SandboxPolicy::ReadOnly => {
|
||||
r#"
|
||||
The shell tool is used to execute shell commands.
|
||||
- When invoking the shell tool, your call will be running in a sandbox, and some shell commands (including apply_patch) will require escalated permissions:
|
||||
- Types of actions that require escalated privileges:
|
||||
- Writing files
|
||||
- Applying patches
|
||||
- Examples of commands that require escalated privileges:
|
||||
- apply_patch
|
||||
- git commit
|
||||
- npm install or pnpm install
|
||||
- cargo build
|
||||
- cargo test
|
||||
- When invoking a command that will require escalated privileges:
|
||||
- Provide the with_escalated_permissions parameter with the boolean value true
|
||||
- Include a short, 1 sentence explanation for why we need to run with_escalated_permissions in the justification parameter"#.to_string()
|
||||
}
|
||||
};
|
||||
let description = SHELL_TOOL_DESCRIPTION.to_string();
|
||||
|
||||
OpenAiTool::Function(ResponsesApiTool {
|
||||
name: "shell".to_string(),
|
||||
@@ -300,7 +295,6 @@ The shell tool is used to execute shell commands.
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn create_view_image_tool() -> OpenAiTool {
|
||||
// Support only local filesystem path.
|
||||
let mut properties = BTreeMap::new();
|
||||
@@ -534,23 +528,27 @@ pub(crate) fn get_openai_tools(
|
||||
) -> Vec<OpenAiTool> {
|
||||
let mut tools: Vec<OpenAiTool> = Vec::new();
|
||||
|
||||
match &config.shell_type {
|
||||
ConfigShellToolType::DefaultShell => {
|
||||
tools.push(create_shell_tool());
|
||||
}
|
||||
ConfigShellToolType::ShellWithRequest { sandbox_policy } => {
|
||||
tools.push(create_shell_tool_for_sandbox(sandbox_policy));
|
||||
}
|
||||
ConfigShellToolType::LocalShell => {
|
||||
tools.push(OpenAiTool::LocalShell {});
|
||||
}
|
||||
ConfigShellToolType::StreamableShell => {
|
||||
tools.push(OpenAiTool::Function(
|
||||
crate::exec_command::create_exec_command_tool_for_responses_api(),
|
||||
));
|
||||
tools.push(OpenAiTool::Function(
|
||||
crate::exec_command::create_write_stdin_tool_for_responses_api(),
|
||||
));
|
||||
if config.experimental_unified_exec_tool {
|
||||
tools.push(create_unified_exec_tool());
|
||||
} else {
|
||||
match &config.shell_type {
|
||||
ConfigShellToolType::DefaultShell => {
|
||||
tools.push(create_shell_tool());
|
||||
}
|
||||
ConfigShellToolType::ShellWithRequest => {
|
||||
tools.push(create_shell_tool_for_request());
|
||||
}
|
||||
ConfigShellToolType::LocalShell => {
|
||||
tools.push(OpenAiTool::LocalShell {});
|
||||
}
|
||||
ConfigShellToolType::StreamableShell => {
|
||||
tools.push(OpenAiTool::Function(
|
||||
crate::exec_command::create_exec_command_tool_for_responses_api(),
|
||||
));
|
||||
tools.push(OpenAiTool::Function(
|
||||
crate::exec_command::create_write_stdin_tool_for_responses_api(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -577,10 +575,8 @@ pub(crate) fn get_openai_tools(
|
||||
if config.include_view_image_tool {
|
||||
tools.push(create_view_image_tool());
|
||||
}
|
||||
|
||||
if let Some(mcp_tools) = mcp_tools {
|
||||
// Ensure deterministic ordering to maximize prompt cache hits.
|
||||
// HashMap iteration order is non-deterministic, so sort by fully-qualified tool name.
|
||||
let mut entries: Vec<(String, mcp_types::Tool)> = mcp_tools.into_iter().collect();
|
||||
entries.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
|
||||
@@ -636,18 +632,18 @@ mod tests {
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: true,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
let tools = get_openai_tools(&config, Some(HashMap::new()));
|
||||
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&["local_shell", "update_plan", "web_search", "view_image"],
|
||||
&["unified_exec", "update_plan", "web_search", "view_image"],
|
||||
);
|
||||
}
|
||||
|
||||
@@ -657,18 +653,18 @@ mod tests {
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: true,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
let tools = get_openai_tools(&config, Some(HashMap::new()));
|
||||
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&["shell", "update_plan", "web_search", "view_image"],
|
||||
&["unified_exec", "update_plan", "web_search", "view_image"],
|
||||
);
|
||||
}
|
||||
|
||||
@@ -678,12 +674,12 @@ mod tests {
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
let tools = get_openai_tools(
|
||||
&config,
|
||||
@@ -726,7 +722,7 @@ mod tests {
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&[
|
||||
"shell",
|
||||
"unified_exec",
|
||||
"web_search",
|
||||
"view_image",
|
||||
"test_server/do_something_cool",
|
||||
@@ -783,12 +779,12 @@ mod tests {
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: false,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
|
||||
// Intentionally construct a map with keys that would sort alphabetically.
|
||||
@@ -841,11 +837,11 @@ mod tests {
|
||||
]);
|
||||
|
||||
let tools = get_openai_tools(&config, Some(tools_map));
|
||||
// Expect shell first, followed by MCP tools sorted by fully-qualified name.
|
||||
// Expect unified_exec first, followed by MCP tools sorted by fully-qualified name.
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&[
|
||||
"shell",
|
||||
"unified_exec",
|
||||
"view_image",
|
||||
"test_server/cool",
|
||||
"test_server/do",
|
||||
@@ -860,12 +856,12 @@ mod tests {
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
|
||||
let tools = get_openai_tools(
|
||||
@@ -893,7 +889,7 @@ mod tests {
|
||||
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&["shell", "web_search", "view_image", "dash/search"],
|
||||
&["unified_exec", "web_search", "view_image", "dash/search"],
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
@@ -922,12 +918,12 @@ mod tests {
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
|
||||
let tools = get_openai_tools(
|
||||
@@ -953,7 +949,7 @@ mod tests {
|
||||
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&["shell", "web_search", "view_image", "dash/paginate"],
|
||||
&["unified_exec", "web_search", "view_image", "dash/paginate"],
|
||||
);
|
||||
assert_eq!(
|
||||
tools[3],
|
||||
@@ -979,12 +975,12 @@ mod tests {
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
|
||||
let tools = get_openai_tools(
|
||||
@@ -1008,7 +1004,10 @@ mod tests {
|
||||
)])),
|
||||
);
|
||||
|
||||
assert_eq_tool_names(&tools, &["shell", "web_search", "view_image", "dash/tags"]);
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&["unified_exec", "web_search", "view_image", "dash/tags"],
|
||||
);
|
||||
assert_eq!(
|
||||
tools[3],
|
||||
OpenAiTool::Function(ResponsesApiTool {
|
||||
@@ -1036,12 +1035,12 @@ mod tests {
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::ReadOnly,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: true,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: true,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
|
||||
let tools = get_openai_tools(
|
||||
@@ -1065,7 +1064,10 @@ mod tests {
|
||||
)])),
|
||||
);
|
||||
|
||||
assert_eq_tool_names(&tools, &["shell", "web_search", "view_image", "dash/value"]);
|
||||
assert_eq_tool_names(
|
||||
&tools,
|
||||
&["unified_exec", "web_search", "view_image", "dash/value"],
|
||||
);
|
||||
assert_eq!(
|
||||
tools[3],
|
||||
OpenAiTool::Function(ResponsesApiTool {
|
||||
@@ -1086,13 +1088,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_shell_tool_for_sandbox_workspace_write() {
|
||||
let sandbox_policy = SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec!["workspace".into()],
|
||||
network_access: false,
|
||||
exclude_tmpdir_env_var: false,
|
||||
exclude_slash_tmp: false,
|
||||
};
|
||||
let tool = super::create_shell_tool_for_sandbox(&sandbox_policy);
|
||||
let tool = super::create_shell_tool_for_request();
|
||||
let OpenAiTool::Function(ResponsesApiTool {
|
||||
description, name, ..
|
||||
}) = &tool
|
||||
@@ -1101,29 +1097,13 @@ mod tests {
|
||||
};
|
||||
assert_eq!(name, "shell");
|
||||
|
||||
let expected = r#"
|
||||
The shell tool is used to execute shell commands.
|
||||
- When invoking the shell tool, your call will be running in a sandbox, and some shell commands will require escalated privileges:
|
||||
- Types of actions that require escalated privileges:
|
||||
- Writing files other than those in the writable roots
|
||||
- writable roots:
|
||||
- workspace
|
||||
- Commands that require network access
|
||||
|
||||
- Examples of commands that require escalated privileges:
|
||||
- git commit
|
||||
- npm install or pnpm install
|
||||
- cargo build
|
||||
- cargo test
|
||||
- When invoking a command that will require escalated privileges:
|
||||
- Provide the with_escalated_permissions parameter with the boolean value true
|
||||
- Include a short, 1 sentence explanation for why we need to run with_escalated_permissions in the justification parameter."#;
|
||||
let expected = super::SHELL_TOOL_DESCRIPTION;
|
||||
assert_eq!(description, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shell_tool_for_sandbox_readonly() {
|
||||
let tool = super::create_shell_tool_for_sandbox(&SandboxPolicy::ReadOnly);
|
||||
let tool = super::create_shell_tool_for_request();
|
||||
let OpenAiTool::Function(ResponsesApiTool {
|
||||
description, name, ..
|
||||
}) = &tool
|
||||
@@ -1132,27 +1112,13 @@ The shell tool is used to execute shell commands.
|
||||
};
|
||||
assert_eq!(name, "shell");
|
||||
|
||||
let expected = r#"
|
||||
The shell tool is used to execute shell commands.
|
||||
- When invoking the shell tool, your call will be running in a sandbox, and some shell commands (including apply_patch) will require escalated permissions:
|
||||
- Types of actions that require escalated privileges:
|
||||
- Writing files
|
||||
- Applying patches
|
||||
- Examples of commands that require escalated privileges:
|
||||
- apply_patch
|
||||
- git commit
|
||||
- npm install or pnpm install
|
||||
- cargo build
|
||||
- cargo test
|
||||
- When invoking a command that will require escalated privileges:
|
||||
- Provide the with_escalated_permissions parameter with the boolean value true
|
||||
- Include a short, 1 sentence explanation for why we need to run with_escalated_permissions in the justification parameter"#;
|
||||
let expected = super::SHELL_TOOL_DESCRIPTION;
|
||||
assert_eq!(description, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shell_tool_for_sandbox_danger_full_access() {
|
||||
let tool = super::create_shell_tool_for_sandbox(&SandboxPolicy::DangerFullAccess);
|
||||
let tool = super::create_shell_tool_for_request();
|
||||
let OpenAiTool::Function(ResponsesApiTool {
|
||||
description, name, ..
|
||||
}) = &tool
|
||||
@@ -1161,6 +1127,7 @@ The shell tool is used to execute shell commands.
|
||||
};
|
||||
assert_eq!(name, "shell");
|
||||
|
||||
assert_eq!(description, "Runs a shell command and returns its output.");
|
||||
let expected = super::SHELL_TOOL_DESCRIPTION;
|
||||
assert_eq!(description, expected);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -868,7 +868,7 @@ pub fn parse_command_impl(command: &[String]) -> Vec<ParsedCommand> {
|
||||
let parts = if contains_connectors(&normalized) {
|
||||
split_on_connectors(&normalized)
|
||||
} else {
|
||||
vec![normalized.clone()]
|
||||
vec![normalized]
|
||||
};
|
||||
|
||||
// Preserve left-to-right execution order for all commands, including bash -c/-lc
|
||||
@@ -1201,10 +1201,7 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
|
||||
name,
|
||||
}
|
||||
} else {
|
||||
ParsedCommand::Read {
|
||||
cmd: cmd.clone(),
|
||||
name,
|
||||
}
|
||||
ParsedCommand::Read { cmd, name }
|
||||
}
|
||||
} else {
|
||||
ParsedCommand::Read {
|
||||
@@ -1215,10 +1212,7 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
|
||||
}
|
||||
ParsedCommand::ListFiles { path, cmd, .. } => {
|
||||
if had_connectors {
|
||||
ParsedCommand::ListFiles {
|
||||
cmd: cmd.clone(),
|
||||
path,
|
||||
}
|
||||
ParsedCommand::ListFiles { cmd, path }
|
||||
} else {
|
||||
ParsedCommand::ListFiles {
|
||||
cmd: shlex_join(&script_tokens),
|
||||
@@ -1230,11 +1224,7 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
|
||||
query, path, cmd, ..
|
||||
} => {
|
||||
if had_connectors {
|
||||
ParsedCommand::Search {
|
||||
cmd: cmd.clone(),
|
||||
query,
|
||||
path,
|
||||
}
|
||||
ParsedCommand::Search { cmd, query, path }
|
||||
} else {
|
||||
ParsedCommand::Search {
|
||||
cmd: shlex_join(&script_tokens),
|
||||
|
||||
@@ -26,7 +26,7 @@ const PROJECT_DOC_SEPARATOR: &str = "\n\n--- project-doc ---\n\n";
|
||||
|
||||
/// Combines `Config::instructions` and `AGENTS.md` (if present) into a single
|
||||
/// string of instructions.
|
||||
pub(crate) async fn get_user_instructions(config: &Config) -> Option<String> {
|
||||
pub async fn get_user_instructions(config: &Config) -> Option<String> {
|
||||
match read_project_docs(config).await {
|
||||
Ok(Some(project_doc)) => match &config.user_instructions {
|
||||
Some(original_instructions) => Some(format!(
|
||||
@@ -115,7 +115,7 @@ pub fn discover_project_doc_paths(config: &Config) -> std::io::Result<Vec<PathBu
|
||||
// Build chain from cwd upwards and detect git root.
|
||||
let mut chain: Vec<PathBuf> = vec![dir.clone()];
|
||||
let mut git_root: Option<PathBuf> = None;
|
||||
let mut cursor = dir.clone();
|
||||
let mut cursor = dir;
|
||||
while let Some(parent) = cursor.parent() {
|
||||
let git_marker = cursor.join(".git");
|
||||
let git_exists = match std::fs::metadata(&git_marker) {
|
||||
|
||||
@@ -10,7 +10,9 @@ use time::macros::format_description;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::SESSIONS_SUBDIR;
|
||||
use super::recorder::SessionMetaWithGit;
|
||||
use crate::protocol::EventMsg;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::RolloutLine;
|
||||
|
||||
/// Returned page of conversation summaries.
|
||||
#[derive(Debug, Default, PartialEq)]
|
||||
@@ -35,7 +37,7 @@ pub struct ConversationItem {
|
||||
}
|
||||
|
||||
/// Hard cap to bound worst‑case work per request.
|
||||
const MAX_SCAN_FILES: usize = 10_000;
|
||||
const MAX_SCAN_FILES: usize = 100;
|
||||
const HEAD_RECORD_LIMIT: usize = 10;
|
||||
|
||||
/// Pagination cursor identifying a file by timestamp and UUID.
|
||||
@@ -168,10 +170,14 @@ async fn traverse_directories_for_paths(
|
||||
if items.len() == page_size {
|
||||
break 'outer;
|
||||
}
|
||||
let head = read_first_jsonl_records(&path, HEAD_RECORD_LIMIT)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
if should_include_session(&head) {
|
||||
// Read head and simultaneously detect message events within the same
|
||||
// first N JSONL records to avoid a second file read.
|
||||
let (head, saw_session_meta, saw_user_event) =
|
||||
read_head_and_flags(&path, HEAD_RECORD_LIMIT)
|
||||
.await
|
||||
.unwrap_or((Vec::new(), false, false));
|
||||
// Apply filters: must have session meta and at least one user message event
|
||||
if saw_session_meta && saw_user_event {
|
||||
items.push(ConversationItem { path, head });
|
||||
}
|
||||
}
|
||||
@@ -276,16 +282,19 @@ fn parse_timestamp_uuid_from_filename(name: &str) -> Option<(OffsetDateTime, Uui
|
||||
Some((ts, uuid))
|
||||
}
|
||||
|
||||
async fn read_first_jsonl_records(
|
||||
async fn read_head_and_flags(
|
||||
path: &Path,
|
||||
max_records: usize,
|
||||
) -> io::Result<Vec<serde_json::Value>> {
|
||||
) -> io::Result<(Vec<serde_json::Value>, bool, bool)> {
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
|
||||
let file = tokio::fs::File::open(path).await?;
|
||||
let reader = tokio::io::BufReader::new(file);
|
||||
let mut lines = reader.lines();
|
||||
let mut head: Vec<serde_json::Value> = Vec::new();
|
||||
let mut saw_session_meta = false;
|
||||
let mut saw_user_event = false;
|
||||
|
||||
while head.len() < max_records {
|
||||
let line_opt = lines.next_line().await?;
|
||||
let Some(line) = line_opt else { break };
|
||||
@@ -293,43 +302,35 @@ async fn read_first_jsonl_records(
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if let Ok(v) = serde_json::from_str::<serde_json::Value>(trimmed) {
|
||||
head.push(v);
|
||||
|
||||
let parsed: Result<RolloutLine, _> = serde_json::from_str(trimmed);
|
||||
let Ok(rollout_line) = parsed else { continue };
|
||||
|
||||
match rollout_line.item {
|
||||
RolloutItem::SessionMeta(session_meta_line) => {
|
||||
if let Ok(val) = serde_json::to_value(session_meta_line) {
|
||||
head.push(val);
|
||||
saw_session_meta = true;
|
||||
}
|
||||
}
|
||||
RolloutItem::ResponseItem(item) => {
|
||||
if let Ok(val) = serde_json::to_value(item) {
|
||||
head.push(val);
|
||||
}
|
||||
}
|
||||
RolloutItem::TurnContext(_) => {
|
||||
// Not included in `head`; skip.
|
||||
}
|
||||
RolloutItem::Compacted(_) => {
|
||||
// Not included in `head`; skip.
|
||||
}
|
||||
RolloutItem::EventMsg(ev) => {
|
||||
if matches!(ev, EventMsg::UserMessage(_)) {
|
||||
saw_user_event = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(head)
|
||||
}
|
||||
|
||||
/// Return true if this conversation should be included in the listing.
|
||||
///
|
||||
/// Current rule: include only when the first JSON object is a session meta record
|
||||
/// (i.e., has `{"record_type": "session_meta", ...}`), which is how rollout
|
||||
/// files are written. Empty or malformed heads are excluded.
|
||||
fn should_include_session(head: &[serde_json::Value]) -> bool {
|
||||
let Some(first) = head.first() else {
|
||||
return false;
|
||||
};
|
||||
passes_session_meta_filter(first)
|
||||
}
|
||||
|
||||
/// Validate that the first record is a fully‑formed session meta line.
|
||||
///
|
||||
/// Requirements:
|
||||
/// - `record_type == "session_meta"`
|
||||
/// - Remaining fields (after removing `record_type`) deserialize into
|
||||
/// `SessionMetaWithGit`.
|
||||
fn passes_session_meta_filter(first: &serde_json::Value) -> bool {
|
||||
let Some(obj) = first.as_object() else {
|
||||
return false;
|
||||
};
|
||||
let record_type = obj.get("record_type").and_then(|v| v.as_str());
|
||||
if record_type != Some("session_meta") {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Remove the marker field and validate the remainder matches SessionMetaWithGit
|
||||
let mut cleaned = obj.clone();
|
||||
cleaned.remove("record_type");
|
||||
let val = serde_json::Value::Object(cleaned);
|
||||
serde_json::from_value::<SessionMetaWithGit>(val).is_ok()
|
||||
Ok((head, saw_session_meta, saw_user_event))
|
||||
}
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
//! Rollout module: persistence and discovery of session rollout files.
|
||||
|
||||
pub(crate) const SESSIONS_SUBDIR: &str = "sessions";
|
||||
pub const SESSIONS_SUBDIR: &str = "sessions";
|
||||
pub const ARCHIVED_SESSIONS_SUBDIR: &str = "archived_sessions";
|
||||
|
||||
pub mod list;
|
||||
pub(crate) mod policy;
|
||||
pub mod recorder;
|
||||
|
||||
pub use recorder::RolloutItem;
|
||||
pub use codex_protocol::protocol::SessionMeta;
|
||||
pub use recorder::RolloutRecorder;
|
||||
pub use recorder::SessionMeta;
|
||||
pub use recorder::SessionStateSnapshot;
|
||||
pub use recorder::RolloutRecorderParams;
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests;
|
||||
|
||||
@@ -1,10 +1,23 @@
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::RolloutItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::Event;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
|
||||
/// Whether a rollout `item` should be persisted in rollout files.
|
||||
#[inline]
|
||||
pub(crate) fn is_persisted_response_item(item: &RolloutItem) -> bool {
|
||||
match item {
|
||||
RolloutItem::ResponseItem(item) => should_persist_response_item(item),
|
||||
RolloutItem::EventMsg(ev) => should_persist_event_msg(ev),
|
||||
// Persist Codex executive markers so we can analyze flows (e.g., compaction, API turns).
|
||||
RolloutItem::Compacted(_) | RolloutItem::TurnContext(_) | RolloutItem::SessionMeta(_) => {
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether a `ResponseItem` should be persisted in rollout files.
|
||||
#[inline]
|
||||
pub(crate) fn is_persisted_response_item(item: &ResponseItem) -> bool {
|
||||
pub(crate) fn should_persist_response_item(item: &ResponseItem) -> bool {
|
||||
match item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::Reasoning { .. }
|
||||
@@ -17,41 +30,43 @@ pub(crate) fn is_persisted_response_item(item: &ResponseItem) -> bool {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_persisted_event(event: &Event) -> bool {
|
||||
match event.msg {
|
||||
EventMsg::ExecApprovalRequest(_)
|
||||
| EventMsg::ApplyPatchApprovalRequest(_)
|
||||
| EventMsg::AgentReasoningDelta(_)
|
||||
| EventMsg::AgentReasoningRawContentDelta(_)
|
||||
| EventMsg::ExecCommandOutputDelta(_)
|
||||
| EventMsg::GetHistoryEntryResponse(_)
|
||||
| EventMsg::AgentMessageDelta(_)
|
||||
| EventMsg::TaskStarted(_)
|
||||
| EventMsg::TaskComplete(_)
|
||||
| EventMsg::McpToolCallBegin(_)
|
||||
| EventMsg::McpToolCallEnd(_)
|
||||
| EventMsg::WebSearchBegin(_)
|
||||
| EventMsg::WebSearchEnd(_)
|
||||
| EventMsg::ExecCommandBegin(_)
|
||||
| EventMsg::ExecCommandEnd(_)
|
||||
| EventMsg::PatchApplyBegin(_)
|
||||
| EventMsg::PatchApplyEnd(_)
|
||||
| EventMsg::TurnDiff(_)
|
||||
| EventMsg::BackgroundEvent(_)
|
||||
| EventMsg::McpListToolsResponse(_)
|
||||
| EventMsg::ListCustomPromptsResponse(_)
|
||||
| EventMsg::ShutdownComplete
|
||||
| EventMsg::ConversationHistory(_)
|
||||
| EventMsg::PlanUpdate(_)
|
||||
| EventMsg::TurnAborted(_)
|
||||
| EventMsg::StreamError(_)
|
||||
| EventMsg::Error(_)
|
||||
| EventMsg::AgentReasoningSectionBreak(_)
|
||||
| EventMsg::SessionConfigured(_) => false,
|
||||
/// Whether an `EventMsg` should be persisted in rollout files.
|
||||
#[inline]
|
||||
pub(crate) fn should_persist_event_msg(ev: &EventMsg) -> bool {
|
||||
match ev {
|
||||
EventMsg::UserMessage(_)
|
||||
| EventMsg::AgentMessage(_)
|
||||
| EventMsg::AgentReasoning(_)
|
||||
| EventMsg::AgentReasoningRawContent(_)
|
||||
| EventMsg::TokenCount(_) => true,
|
||||
EventMsg::Error(_)
|
||||
| EventMsg::TaskStarted(_)
|
||||
| EventMsg::TaskComplete(_)
|
||||
| EventMsg::AgentMessageDelta(_)
|
||||
| EventMsg::AgentReasoningDelta(_)
|
||||
| EventMsg::AgentReasoningRawContentDelta(_)
|
||||
| EventMsg::AgentReasoningSectionBreak(_)
|
||||
| EventMsg::SessionConfigured(_)
|
||||
| EventMsg::McpToolCallBegin(_)
|
||||
| EventMsg::McpToolCallEnd(_)
|
||||
| EventMsg::WebSearchBegin(_)
|
||||
| EventMsg::WebSearchEnd(_)
|
||||
| EventMsg::ExecCommandBegin(_)
|
||||
| EventMsg::ExecCommandOutputDelta(_)
|
||||
| EventMsg::ExecCommandEnd(_)
|
||||
| EventMsg::ExecApprovalRequest(_)
|
||||
| EventMsg::ApplyPatchApprovalRequest(_)
|
||||
| EventMsg::BackgroundEvent(_)
|
||||
| EventMsg::StreamError(_)
|
||||
| EventMsg::PatchApplyBegin(_)
|
||||
| EventMsg::PatchApplyEnd(_)
|
||||
| EventMsg::TurnDiff(_)
|
||||
| EventMsg::GetHistoryEntryResponse(_)
|
||||
| EventMsg::McpListToolsResponse(_)
|
||||
| EventMsg::ListCustomPromptsResponse(_)
|
||||
| EventMsg::PlanUpdate(_)
|
||||
| EventMsg::TurnAborted(_)
|
||||
| EventMsg::ShutdownComplete
|
||||
| EventMsg::ConversationPath(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,10 +7,9 @@ use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
use codex_protocol::protocol::Event;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use time::OffsetDateTime;
|
||||
use time::format_description::FormatItem;
|
||||
use time::macros::format_description;
|
||||
@@ -27,33 +26,17 @@ use super::list::Cursor;
|
||||
use super::list::get_conversations;
|
||||
use super::policy::is_persisted_response_item;
|
||||
use crate::config::Config;
|
||||
use crate::conversation_manager::InitialHistory;
|
||||
use crate::conversation_manager::ResumedHistory;
|
||||
use crate::git_info::GitInfo;
|
||||
use crate::default_client::ORIGINATOR;
|
||||
use crate::git_info::collect_git_info;
|
||||
use crate::rollout::policy::is_persisted_event;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::InitialHistory;
|
||||
use codex_protocol::protocol::ResumedHistory;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::RolloutLine;
|
||||
use codex_protocol::protocol::SessionMeta;
|
||||
use codex_protocol::protocol::SessionMetaLine;
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Default, Debug)]
|
||||
pub struct SessionMeta {
|
||||
pub id: ConversationId,
|
||||
pub timestamp: String,
|
||||
pub cwd: String,
|
||||
pub originator: String,
|
||||
pub cli_version: String,
|
||||
pub instructions: Option<String>,
|
||||
}
|
||||
|
||||
// SessionMetaWithGit is used in writes and reads; ensure it implements Debug.
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct SessionMetaWithGit {
|
||||
#[serde(flatten)]
|
||||
meta: SessionMeta,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
git: Option<GitInfo>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone, Debug)]
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
pub struct SessionStateSnapshot {}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
@@ -78,99 +61,45 @@ pub struct SavedSession {
|
||||
#[derive(Clone)]
|
||||
pub struct RolloutRecorder {
|
||||
tx: Sender<RolloutCmd>,
|
||||
path: PathBuf,
|
||||
pub(crate) rollout_path: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(tag = "record_type", rename_all = "snake_case")]
|
||||
enum TaggedLine {
|
||||
Response {
|
||||
#[serde(flatten)]
|
||||
item: ResponseItem,
|
||||
#[derive(Clone)]
|
||||
pub enum RolloutRecorderParams {
|
||||
Create {
|
||||
conversation_id: ConversationId,
|
||||
instructions: Option<String>,
|
||||
},
|
||||
Event {
|
||||
#[serde(flatten)]
|
||||
event: Event,
|
||||
Resume {
|
||||
path: PathBuf,
|
||||
},
|
||||
SessionMeta {
|
||||
#[serde(flatten)]
|
||||
meta: SessionMetaWithGit,
|
||||
},
|
||||
PrevSessionMeta {
|
||||
#[serde(flatten)]
|
||||
meta: SessionMetaWithGit,
|
||||
},
|
||||
State {
|
||||
#[serde(flatten)]
|
||||
state: SessionStateSnapshot,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
struct TimestampedLine {
|
||||
timestamp: String,
|
||||
#[serde(flatten)]
|
||||
record: TaggedLine,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum RolloutItem {
|
||||
ResponseItem(ResponseItem),
|
||||
Event(Event),
|
||||
SessionMeta(SessionMetaWithGit),
|
||||
}
|
||||
|
||||
impl From<ResponseItem> for RolloutItem {
|
||||
fn from(item: ResponseItem) -> Self {
|
||||
RolloutItem::ResponseItem(item)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Event> for RolloutItem {
|
||||
fn from(event: Event) -> Self {
|
||||
RolloutItem::Event(event)
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience helpers to extract typed items from a list of rollout items.
|
||||
pub trait RolloutItemSliceExt {
|
||||
fn get_response_items(&self) -> Vec<ResponseItem>;
|
||||
fn get_events(&self) -> Vec<EventMsg>;
|
||||
}
|
||||
|
||||
impl RolloutItemSliceExt for [RolloutItem] {
|
||||
fn get_response_items(&self) -> Vec<ResponseItem> {
|
||||
self.iter()
|
||||
.filter_map(|it| match it {
|
||||
RolloutItem::ResponseItem(ri) => Some(ri.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn get_events(&self) -> Vec<EventMsg> {
|
||||
self.iter()
|
||||
.filter_map(|it| match it {
|
||||
RolloutItem::Event(ev) => Some(ev.msg.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
enum RolloutCmd {
|
||||
AddResponseItems(Vec<ResponseItem>),
|
||||
AddEvents(Vec<Event>),
|
||||
AddSessionMeta(SessionMetaWithGit),
|
||||
Flush { ack: oneshot::Sender<()> },
|
||||
Shutdown { ack: oneshot::Sender<()> },
|
||||
AddItems(Vec<RolloutItem>),
|
||||
/// Ensure all prior writes are processed; respond when flushed.
|
||||
Flush {
|
||||
ack: oneshot::Sender<()>,
|
||||
},
|
||||
Shutdown {
|
||||
ack: oneshot::Sender<()>,
|
||||
},
|
||||
}
|
||||
|
||||
impl RolloutRecorderParams {
|
||||
pub fn new(conversation_id: ConversationId, instructions: Option<String>) -> Self {
|
||||
Self::Create {
|
||||
conversation_id,
|
||||
instructions,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn resume(path: PathBuf) -> Self {
|
||||
Self::Resume { path }
|
||||
}
|
||||
}
|
||||
|
||||
impl RolloutRecorder {
|
||||
pub fn path(&self) -> &Path {
|
||||
&self.path
|
||||
}
|
||||
#[allow(dead_code)]
|
||||
/// List conversations (rollout files) under the provided Codex home directory.
|
||||
pub async fn list_conversations(
|
||||
codex_home: &Path,
|
||||
@@ -183,157 +112,145 @@ impl RolloutRecorder {
|
||||
/// Attempt to create a new [`RolloutRecorder`]. If the sessions directory
|
||||
/// cannot be created or the rollout file cannot be opened we return the
|
||||
/// error so the caller can decide whether to disable persistence.
|
||||
pub async fn new(
|
||||
config: &Config,
|
||||
conversation_id: ConversationId,
|
||||
instructions: Option<String>,
|
||||
) -> std::io::Result<Self> {
|
||||
let LogFileInfo {
|
||||
file,
|
||||
conversation_id: session_id,
|
||||
timestamp,
|
||||
path,
|
||||
} = create_log_file(config, conversation_id)?;
|
||||
|
||||
let timestamp_format: &[FormatItem] =
|
||||
format_description!("[year]-[month]-[day]T[hour]:[minute]:[second]Z");
|
||||
let timestamp = timestamp
|
||||
.to_offset(time::UtcOffset::UTC)
|
||||
.format(timestamp_format)
|
||||
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
|
||||
|
||||
let cwd = config.cwd.to_path_buf();
|
||||
|
||||
let (tx, rx) = mpsc::channel(100);
|
||||
|
||||
tokio::task::spawn(rollout_writer(
|
||||
tokio::fs::File::from_std(file),
|
||||
rx,
|
||||
Some(SessionMeta {
|
||||
timestamp,
|
||||
id: session_id,
|
||||
cwd: config.cwd.to_string_lossy().to_string(),
|
||||
originator: config.responses_originator_header.clone(),
|
||||
cli_version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
pub async fn new(config: &Config, params: RolloutRecorderParams) -> std::io::Result<Self> {
|
||||
let (file, rollout_path, meta) = match params {
|
||||
RolloutRecorderParams::Create {
|
||||
conversation_id,
|
||||
instructions,
|
||||
}),
|
||||
cwd,
|
||||
));
|
||||
} => {
|
||||
let LogFileInfo {
|
||||
file,
|
||||
path,
|
||||
conversation_id: session_id,
|
||||
timestamp,
|
||||
} = create_log_file(config, conversation_id)?;
|
||||
|
||||
Ok(Self { tx, path })
|
||||
let timestamp_format: &[FormatItem] = format_description!(
|
||||
"[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
|
||||
);
|
||||
let timestamp = timestamp
|
||||
.to_offset(time::UtcOffset::UTC)
|
||||
.format(timestamp_format)
|
||||
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
|
||||
|
||||
(
|
||||
tokio::fs::File::from_std(file),
|
||||
path,
|
||||
Some(SessionMeta {
|
||||
id: session_id,
|
||||
timestamp,
|
||||
cwd: config.cwd.clone(),
|
||||
originator: ORIGINATOR.value.clone(),
|
||||
cli_version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
instructions,
|
||||
}),
|
||||
)
|
||||
}
|
||||
RolloutRecorderParams::Resume { path } => (
|
||||
tokio::fs::OpenOptions::new()
|
||||
.append(true)
|
||||
.open(&path)
|
||||
.await?,
|
||||
path,
|
||||
None,
|
||||
),
|
||||
};
|
||||
|
||||
// Clone the cwd for the spawned task to collect git info asynchronously
|
||||
let cwd = config.cwd.clone();
|
||||
|
||||
// A reasonably-sized bounded channel. If the buffer fills up the send
|
||||
// future will yield, which is fine – we only need to ensure we do not
|
||||
// perform *blocking* I/O on the caller's thread.
|
||||
let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
|
||||
|
||||
// Spawn a Tokio task that owns the file handle and performs async
|
||||
// writes. Using `tokio::fs::File` keeps everything on the async I/O
|
||||
// driver instead of blocking the runtime.
|
||||
tokio::task::spawn(rollout_writer(file, rx, meta, cwd));
|
||||
|
||||
Ok(Self { tx, rollout_path })
|
||||
}
|
||||
|
||||
pub(crate) async fn record_items(&self, item: RolloutItem) -> std::io::Result<()> {
|
||||
match item {
|
||||
RolloutItem::ResponseItem(item) => self.record_response_item(&item).await,
|
||||
RolloutItem::Event(event) => self.record_event(&event).await,
|
||||
RolloutItem::SessionMeta(meta) => self.record_session_meta(&meta).await,
|
||||
pub(crate) async fn record_items(&self, items: &[RolloutItem]) -> std::io::Result<()> {
|
||||
let mut filtered = Vec::new();
|
||||
for item in items {
|
||||
// Note that function calls may look a bit strange if they are
|
||||
// "fully qualified MCP tool calls," so we could consider
|
||||
// reformatting them in that case.
|
||||
if is_persisted_response_item(item) {
|
||||
filtered.push(item.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Ensure all writes up to this point have been processed by the writer task.
|
||||
///
|
||||
/// This is a sequencing barrier for readers that plan to open and read the
|
||||
/// rollout file immediately after calling this method. The background writer
|
||||
/// processes the channel serially; when it dequeues `Flush`, all prior
|
||||
/// `AddResponseItems`/`AddEvents`/`AddSessionMeta` have already been written
|
||||
/// via `write_line`, which calls `file.flush()` (OS‐buffer flush).
|
||||
pub async fn flush(&self) -> std::io::Result<()> {
|
||||
let (tx_done, rx_done) = oneshot::channel();
|
||||
self.tx
|
||||
.send(RolloutCmd::Flush { ack: tx_done })
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout flush: {e}")))?;
|
||||
rx_done
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed waiting for rollout flush: {e}")))
|
||||
}
|
||||
|
||||
async fn record_response_item(&self, item: &ResponseItem) -> std::io::Result<()> {
|
||||
// Note that function calls may look a bit strange if they are
|
||||
// "fully qualified MCP tool calls," so we could consider
|
||||
// reformatting them in that case.
|
||||
if !is_persisted_response_item(item) {
|
||||
if filtered.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
self.tx
|
||||
.send(RolloutCmd::AddResponseItems(vec![item.clone()]))
|
||||
.send(RolloutCmd::AddItems(filtered))
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout items: {e}")))
|
||||
}
|
||||
|
||||
async fn record_event(&self, event: &Event) -> std::io::Result<()> {
|
||||
if !is_persisted_event(event) {
|
||||
return Ok(());
|
||||
}
|
||||
/// Flush all queued writes and wait until they are committed by the writer task.
|
||||
pub async fn flush(&self) -> std::io::Result<()> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(RolloutCmd::AddEvents(vec![event.clone()]))
|
||||
.send(RolloutCmd::Flush { ack: tx })
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout event: {e}")))
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout flush: {e}")))?;
|
||||
rx.await
|
||||
.map_err(|e| IoError::other(format!("failed waiting for rollout flush: {e}")))
|
||||
}
|
||||
|
||||
async fn record_session_meta(&self, meta: &SessionMetaWithGit) -> std::io::Result<()> {
|
||||
self.tx
|
||||
.send(RolloutCmd::AddSessionMeta(meta.clone()))
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout session meta: {e}")))
|
||||
}
|
||||
|
||||
pub async fn get_rollout_history(path: &Path) -> std::io::Result<InitialHistory> {
|
||||
pub(crate) async fn get_rollout_history(path: &Path) -> std::io::Result<InitialHistory> {
|
||||
info!("Resuming rollout from {path:?}");
|
||||
tracing::error!("Resuming rollout from {path:?}");
|
||||
let text = tokio::fs::read_to_string(path).await?;
|
||||
let mut lines = text.lines();
|
||||
let first_line = lines
|
||||
.next()
|
||||
.ok_or_else(|| IoError::other("empty session file"))?;
|
||||
let conversation_id = if let Ok(TimestampedLine {
|
||||
record: TaggedLine::SessionMeta { meta },
|
||||
..
|
||||
}) = serde_json::from_str::<TimestampedLine>(first_line)
|
||||
{
|
||||
Some(meta.meta.id)
|
||||
} else if let Ok(meta) = serde_json::from_str::<SessionMetaWithGit>(first_line) {
|
||||
Some(meta.meta.id)
|
||||
} else if let Ok(meta) = serde_json::from_str::<SessionMeta>(first_line) {
|
||||
Some(meta.id)
|
||||
} else {
|
||||
return Err(IoError::other(
|
||||
"failed to parse first line of rollout file as SessionMeta",
|
||||
));
|
||||
};
|
||||
if text.trim().is_empty() {
|
||||
return Err(IoError::other("empty session file"));
|
||||
}
|
||||
|
||||
let mut items: Vec<RolloutItem> = Vec::new();
|
||||
for line in lines {
|
||||
let mut conversation_id: Option<ConversationId> = None;
|
||||
for line in text.lines() {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
match serde_json::from_str::<TimestampedLine>(line) {
|
||||
Ok(TimestampedLine {
|
||||
record: TaggedLine::State { .. },
|
||||
..
|
||||
}) => {}
|
||||
Ok(TimestampedLine {
|
||||
record: TaggedLine::Event { event },
|
||||
..
|
||||
}) => items.push(RolloutItem::Event(event)),
|
||||
Ok(TimestampedLine {
|
||||
record: TaggedLine::SessionMeta { meta },
|
||||
..
|
||||
})
|
||||
| Ok(TimestampedLine {
|
||||
record: TaggedLine::PrevSessionMeta { meta },
|
||||
..
|
||||
}) => items.push(RolloutItem::SessionMeta(meta)),
|
||||
Ok(TimestampedLine {
|
||||
record: TaggedLine::Response { item },
|
||||
..
|
||||
}) => {
|
||||
if is_persisted_response_item(&item) {
|
||||
let v: Value = match serde_json::from_str(line) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
warn!("failed to parse line as JSON: {line:?}, error: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Parse the rollout line structure
|
||||
match serde_json::from_value::<RolloutLine>(v.clone()) {
|
||||
Ok(rollout_line) => match rollout_line.item {
|
||||
RolloutItem::SessionMeta(session_meta_line) => {
|
||||
// Use the FIRST SessionMeta encountered in the file as the canonical
|
||||
// conversation id and main session information. Keep all items intact.
|
||||
if conversation_id.is_none() {
|
||||
conversation_id = Some(session_meta_line.meta.id);
|
||||
}
|
||||
items.push(RolloutItem::SessionMeta(session_meta_line));
|
||||
}
|
||||
RolloutItem::ResponseItem(item) => {
|
||||
items.push(RolloutItem::ResponseItem(item));
|
||||
}
|
||||
RolloutItem::Compacted(item) => {
|
||||
items.push(RolloutItem::Compacted(item));
|
||||
}
|
||||
RolloutItem::TurnContext(item) => {
|
||||
items.push(RolloutItem::TurnContext(item));
|
||||
}
|
||||
RolloutItem::EventMsg(_ev) => {
|
||||
items.push(RolloutItem::EventMsg(_ev));
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
warn!("failed to parse rollout line: {v:?}, error: {e}");
|
||||
}
|
||||
Err(_) => warn!("failed to parse rollout line: {line}"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -357,6 +274,10 @@ impl RolloutRecorder {
|
||||
}))
|
||||
}
|
||||
|
||||
pub(crate) fn get_rollout_path(&self) -> PathBuf {
|
||||
self.rollout_path.clone()
|
||||
}
|
||||
|
||||
pub async fn shutdown(&self) -> std::io::Result<()> {
|
||||
let (tx_done, rx_done) = oneshot::channel();
|
||||
match self.tx.send(RolloutCmd::Shutdown { ack: tx_done }).await {
|
||||
@@ -377,14 +298,14 @@ struct LogFileInfo {
|
||||
/// Opened file handle to the rollout file.
|
||||
file: File,
|
||||
|
||||
/// Full path to the rollout file.
|
||||
path: PathBuf,
|
||||
|
||||
/// Session ID (also embedded in filename).
|
||||
conversation_id: ConversationId,
|
||||
|
||||
/// Timestamp for the start of the session.
|
||||
timestamp: OffsetDateTime,
|
||||
|
||||
/// Full filesystem path to the rollout file.
|
||||
path: PathBuf,
|
||||
}
|
||||
|
||||
fn create_log_file(
|
||||
@@ -392,7 +313,8 @@ fn create_log_file(
|
||||
conversation_id: ConversationId,
|
||||
) -> std::io::Result<LogFileInfo> {
|
||||
// Resolve ~/.codex/sessions/YYYY/MM/DD and create it if missing.
|
||||
let timestamp = OffsetDateTime::now_utc();
|
||||
let timestamp = OffsetDateTime::now_local()
|
||||
.map_err(|e| IoError::other(format!("failed to get local time: {e}")))?;
|
||||
let mut dir = config.codex_home.clone();
|
||||
dir.push(SESSIONS_SUBDIR);
|
||||
dir.push(timestamp.year().to_string());
|
||||
@@ -418,9 +340,9 @@ fn create_log_file(
|
||||
|
||||
Ok(LogFileInfo {
|
||||
file,
|
||||
path,
|
||||
conversation_id,
|
||||
timestamp,
|
||||
path,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -435,43 +357,35 @@ async fn rollout_writer(
|
||||
// If we have a meta, collect git info asynchronously and write meta first
|
||||
if let Some(session_meta) = meta.take() {
|
||||
let git_info = collect_git_info(&cwd).await;
|
||||
let session_meta_with_git = SessionMetaWithGit {
|
||||
let session_meta_line = SessionMetaLine {
|
||||
meta: session_meta,
|
||||
git: git_info,
|
||||
};
|
||||
// Write the SessionMeta as the first item in the file
|
||||
|
||||
// Write the SessionMeta as the first item in the file, wrapped in a rollout line
|
||||
writer
|
||||
.write_tagged(TaggedLine::SessionMeta {
|
||||
meta: session_meta_with_git,
|
||||
})
|
||||
.write_rollout_item(RolloutItem::SessionMeta(session_meta_line))
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Process rollout commands
|
||||
while let Some(cmd) = rx.recv().await {
|
||||
match cmd {
|
||||
RolloutCmd::AddResponseItems(items) => {
|
||||
RolloutCmd::AddItems(items) => {
|
||||
for item in items {
|
||||
if is_persisted_response_item(&item) {
|
||||
writer.write_tagged(TaggedLine::Response { item }).await?;
|
||||
writer.write_rollout_item(item).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
RolloutCmd::AddEvents(events) => {
|
||||
for event in events {
|
||||
writer.write_tagged(TaggedLine::Event { event }).await?;
|
||||
}
|
||||
}
|
||||
// Sequencing barrier: by the time we handle `Flush`, all previously
|
||||
// queued writes have been applied and flushed to OS buffers.
|
||||
RolloutCmd::Flush { ack } => {
|
||||
// Ensure underlying file is flushed and then ack.
|
||||
if let Err(e) = writer.file.flush().await {
|
||||
let _ = ack.send(());
|
||||
return Err(e);
|
||||
}
|
||||
let _ = ack.send(());
|
||||
}
|
||||
RolloutCmd::AddSessionMeta(meta) => {
|
||||
writer
|
||||
.write_tagged(TaggedLine::PrevSessionMeta { meta })
|
||||
.await?;
|
||||
}
|
||||
RolloutCmd::Shutdown { ack } => {
|
||||
let _ = ack.send(());
|
||||
}
|
||||
@@ -486,6 +400,20 @@ struct JsonlWriter {
|
||||
}
|
||||
|
||||
impl JsonlWriter {
|
||||
async fn write_rollout_item(&mut self, rollout_item: RolloutItem) -> std::io::Result<()> {
|
||||
let timestamp_format: &[FormatItem] = format_description!(
|
||||
"[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
|
||||
);
|
||||
let timestamp = OffsetDateTime::now_utc()
|
||||
.format(timestamp_format)
|
||||
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
|
||||
|
||||
let line = RolloutLine {
|
||||
timestamp,
|
||||
item: rollout_item,
|
||||
};
|
||||
self.write_line(&line).await
|
||||
}
|
||||
async fn write_line(&mut self, item: &impl serde::Serialize) -> std::io::Result<()> {
|
||||
let mut json = serde_json::to_string(item)?;
|
||||
json.push('\n');
|
||||
@@ -493,12 +421,4 @@ impl JsonlWriter {
|
||||
self.file.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn write_tagged(&mut self, record: TaggedLine) -> std::io::Result<()> {
|
||||
let timestamp = time::OffsetDateTime::now_utc()
|
||||
.format(&time::format_description::well_known::Rfc3339)
|
||||
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
|
||||
let line = TimestampedLine { timestamp, record };
|
||||
self.write_line(&line).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,16 +41,31 @@ fn write_session_file(
|
||||
let mut file = File::create(file_path)?;
|
||||
|
||||
let meta = serde_json::json!({
|
||||
"record_type": "session_meta",
|
||||
"timestamp": ts_str,
|
||||
"id": uuid.to_string(),
|
||||
"cwd": "/",
|
||||
"originator": "test",
|
||||
"cli_version": "0.0.0",
|
||||
"instructions": null
|
||||
"type": "session_meta",
|
||||
"payload": {
|
||||
"id": uuid,
|
||||
"timestamp": ts_str,
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
}
|
||||
});
|
||||
writeln!(file, "{meta}")?;
|
||||
|
||||
// Include at least one user message event to satisfy listing filters
|
||||
let user_event = serde_json::json!({
|
||||
"timestamp": ts_str,
|
||||
"type": "event_msg",
|
||||
"payload": {
|
||||
"type": "user_message",
|
||||
"message": "Hello from user",
|
||||
"kind": "plain"
|
||||
}
|
||||
});
|
||||
writeln!(file, "{user_event}")?;
|
||||
|
||||
for i in 0..num_records {
|
||||
let rec = serde_json::json!({
|
||||
"record_type": "response",
|
||||
@@ -61,18 +76,6 @@ fn write_session_file(
|
||||
Ok((dt, uuid))
|
||||
}
|
||||
|
||||
fn expected_session_meta(ts: &str, uuid: Uuid) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"record_type": "session_meta",
|
||||
"timestamp": ts,
|
||||
"id": uuid.to_string(),
|
||||
"cwd": "/",
|
||||
"originator": "test",
|
||||
"cli_version": "0.0.0",
|
||||
"instructions": null
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_conversations_latest_first() {
|
||||
let temp = TempDir::new().unwrap();
|
||||
@@ -110,24 +113,30 @@ async fn test_list_conversations_latest_first() {
|
||||
.join("01")
|
||||
.join(format!("rollout-2025-01-01T12-00-00-{u1}.jsonl"));
|
||||
|
||||
let head_3 = vec![
|
||||
expected_session_meta("2025-01-03T12-00-00", u3),
|
||||
serde_json::json!({"record_type": "response", "index": 0}),
|
||||
serde_json::json!({"record_type": "response", "index": 1}),
|
||||
serde_json::json!({"record_type": "response", "index": 2}),
|
||||
];
|
||||
let head_2 = vec![
|
||||
expected_session_meta("2025-01-02T12-00-00", u2),
|
||||
serde_json::json!({"record_type": "response", "index": 0}),
|
||||
serde_json::json!({"record_type": "response", "index": 1}),
|
||||
serde_json::json!({"record_type": "response", "index": 2}),
|
||||
];
|
||||
let head_1 = vec![
|
||||
expected_session_meta("2025-01-01T12-00-00", u1),
|
||||
serde_json::json!({"record_type": "response", "index": 0}),
|
||||
serde_json::json!({"record_type": "response", "index": 1}),
|
||||
serde_json::json!({"record_type": "response", "index": 2}),
|
||||
];
|
||||
let head_3 = vec![serde_json::json!({
|
||||
"id": u3,
|
||||
"timestamp": "2025-01-03T12-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let head_2 = vec![serde_json::json!({
|
||||
"id": u2,
|
||||
"timestamp": "2025-01-02T12-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let head_1 = vec![serde_json::json!({
|
||||
"id": u1,
|
||||
"timestamp": "2025-01-01T12-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
|
||||
let expected_cursor: Cursor =
|
||||
serde_json::from_str(&format!("\"2025-01-01T12-00-00|{u1}\"")).unwrap();
|
||||
@@ -187,14 +196,22 @@ async fn test_pagination_cursor() {
|
||||
.join("03")
|
||||
.join("04")
|
||||
.join(format!("rollout-2025-03-04T09-00-00-{u4}.jsonl"));
|
||||
let head_5 = vec![
|
||||
expected_session_meta("2025-03-05T09-00-00", u5),
|
||||
serde_json::json!({"record_type": "response", "index": 0}),
|
||||
];
|
||||
let head_4 = vec![
|
||||
expected_session_meta("2025-03-04T09-00-00", u4),
|
||||
serde_json::json!({"record_type": "response", "index": 0}),
|
||||
];
|
||||
let head_5 = vec![serde_json::json!({
|
||||
"id": u5,
|
||||
"timestamp": "2025-03-05T09-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let head_4 = vec![serde_json::json!({
|
||||
"id": u4,
|
||||
"timestamp": "2025-03-04T09-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let expected_cursor1: Cursor =
|
||||
serde_json::from_str(&format!("\"2025-03-04T09-00-00|{u4}\"")).unwrap();
|
||||
let expected_page1 = ConversationsPage {
|
||||
@@ -229,14 +246,22 @@ async fn test_pagination_cursor() {
|
||||
.join("03")
|
||||
.join("02")
|
||||
.join(format!("rollout-2025-03-02T09-00-00-{u2}.jsonl"));
|
||||
let head_3 = vec![
|
||||
expected_session_meta("2025-03-03T09-00-00", u3),
|
||||
serde_json::json!({"record_type": "response", "index": 0}),
|
||||
];
|
||||
let head_2 = vec![
|
||||
expected_session_meta("2025-03-02T09-00-00", u2),
|
||||
serde_json::json!({"record_type": "response", "index": 0}),
|
||||
];
|
||||
let head_3 = vec![serde_json::json!({
|
||||
"id": u3,
|
||||
"timestamp": "2025-03-03T09-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let head_2 = vec![serde_json::json!({
|
||||
"id": u2,
|
||||
"timestamp": "2025-03-02T09-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let expected_cursor2: Cursor =
|
||||
serde_json::from_str(&format!("\"2025-03-02T09-00-00|{u2}\"")).unwrap();
|
||||
let expected_page2 = ConversationsPage {
|
||||
@@ -265,10 +290,14 @@ async fn test_pagination_cursor() {
|
||||
.join("03")
|
||||
.join("01")
|
||||
.join(format!("rollout-2025-03-01T09-00-00-{u1}.jsonl"));
|
||||
let head_1 = vec![
|
||||
expected_session_meta("2025-03-01T09-00-00", u1),
|
||||
serde_json::json!({"record_type": "response", "index": 0}),
|
||||
];
|
||||
let head_1 = vec![serde_json::json!({
|
||||
"id": u1,
|
||||
"timestamp": "2025-03-01T09-00-00",
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let expected_cursor3: Cursor =
|
||||
serde_json::from_str(&format!("\"2025-03-01T09-00-00|{u1}\"")).unwrap();
|
||||
let expected_page3 = ConversationsPage {
|
||||
@@ -276,7 +305,7 @@ async fn test_pagination_cursor() {
|
||||
path: p1,
|
||||
head: head_1,
|
||||
}],
|
||||
next_cursor: Some(expected_cursor3.clone()),
|
||||
next_cursor: Some(expected_cursor3),
|
||||
num_scanned_files: 5, // scanned 05, 04 (anchor), 03, 02 (anchor), 01
|
||||
reached_scan_cap: false,
|
||||
};
|
||||
@@ -304,15 +333,18 @@ async fn test_get_conversation_contents() {
|
||||
.join("04")
|
||||
.join("01")
|
||||
.join(format!("rollout-2025-04-01T10-30-00-{uuid}.jsonl"));
|
||||
let expected_head = vec![
|
||||
expected_session_meta(ts, uuid),
|
||||
serde_json::json!({"record_type": "response", "index": 0}),
|
||||
serde_json::json!({"record_type": "response", "index": 1}),
|
||||
];
|
||||
let expected_head = vec![serde_json::json!({
|
||||
"id": uuid,
|
||||
"timestamp": ts,
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})];
|
||||
let expected_cursor: Cursor = serde_json::from_str(&format!("\"{ts}|{uuid}\"")).unwrap();
|
||||
let expected_page = ConversationsPage {
|
||||
items: vec![ConversationItem {
|
||||
path: expected_path.clone(),
|
||||
path: expected_path,
|
||||
head: expected_head,
|
||||
}],
|
||||
next_cursor: Some(expected_cursor),
|
||||
@@ -322,10 +354,15 @@ async fn test_get_conversation_contents() {
|
||||
assert_eq!(page, expected_page);
|
||||
|
||||
// Entire file contents equality
|
||||
let meta = expected_session_meta(ts, uuid);
|
||||
let meta = serde_json::json!({"timestamp": ts, "type": "session_meta", "payload": {"id": uuid, "timestamp": ts, "instructions": null, "cwd": ".", "originator": "test_originator", "cli_version": "test_version"}});
|
||||
let user_event = serde_json::json!({
|
||||
"timestamp": ts,
|
||||
"type": "event_msg",
|
||||
"payload": {"type": "user_message", "message": "Hello from user", "kind": "plain"}
|
||||
});
|
||||
let rec0 = serde_json::json!({"record_type": "response", "index": 0});
|
||||
let rec1 = serde_json::json!({"record_type": "response", "index": 1});
|
||||
let expected_content = format!("{meta}\n{rec0}\n{rec1}\n");
|
||||
let expected_content = format!("{meta}\n{user_event}\n{rec0}\n{rec1}\n");
|
||||
assert_eq!(content, expected_content);
|
||||
}
|
||||
|
||||
@@ -357,7 +394,16 @@ async fn test_stable_ordering_same_second_pagination() {
|
||||
.join("07")
|
||||
.join("01")
|
||||
.join(format!("rollout-2025-07-01T00-00-00-{u2}.jsonl"));
|
||||
let head = |u: Uuid| -> Vec<serde_json::Value> { vec![expected_session_meta(ts, u)] };
|
||||
let head = |u: Uuid| -> Vec<serde_json::Value> {
|
||||
vec![serde_json::json!({
|
||||
"id": u,
|
||||
"timestamp": ts,
|
||||
"instructions": null,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
})]
|
||||
};
|
||||
let expected_cursor1: Cursor = serde_json::from_str(&format!("\"{ts}|{u2}\"")).unwrap();
|
||||
let expected_page1 = ConversationsPage {
|
||||
items: vec![
|
||||
@@ -391,7 +437,7 @@ async fn test_stable_ordering_same_second_pagination() {
|
||||
path: p1,
|
||||
head: head(u1),
|
||||
}],
|
||||
next_cursor: Some(expected_cursor2.clone()),
|
||||
next_cursor: Some(expected_cursor2),
|
||||
num_scanned_files: 3, // scanned u3, u2 (anchor), u1
|
||||
reached_scan_cap: false,
|
||||
};
|
||||
|
||||
@@ -293,7 +293,7 @@ mod tests {
|
||||
// With the parent dir explicitly added as a writable root, the
|
||||
// outside write should be permitted.
|
||||
let policy_with_parent = SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![parent.clone()],
|
||||
writable_roots: vec![parent],
|
||||
network_access: false,
|
||||
exclude_tmpdir_env_var: true,
|
||||
exclude_slash_tmp: true,
|
||||
|
||||
@@ -153,7 +153,7 @@ mod tests {
|
||||
// Build a policy that only includes the two test roots as writable and
|
||||
// does not automatically include defaults TMPDIR or /tmp.
|
||||
let policy = SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![root_with_git.clone(), root_without_git.clone()],
|
||||
writable_roots: vec![root_with_git, root_without_git],
|
||||
network_access: false,
|
||||
exclude_tmpdir_env_var: true,
|
||||
exclude_slash_tmp: true,
|
||||
|
||||
@@ -69,3 +69,8 @@
|
||||
; Added on top of Chrome profile
|
||||
; Needed for python multiprocessing on MacOS for the SemLock
|
||||
(allow ipc-posix-sem)
|
||||
|
||||
; needed to look up user info, see https://crbug.com/792228
|
||||
(allow mach-lookup
|
||||
(global-name "com.apple.system.opendirectoryd.libinfo")
|
||||
)
|
||||
|
||||
@@ -326,10 +326,7 @@ mod tests {
|
||||
.format_default_shell_invocation(input.iter().map(|s| s.to_string()).collect());
|
||||
let expected_cmd = expected_cmd
|
||||
.iter()
|
||||
.map(|s| {
|
||||
s.replace("BASHRC_PATH", bashrc_path.to_str().unwrap())
|
||||
.to_string()
|
||||
})
|
||||
.map(|s| s.replace("BASHRC_PATH", bashrc_path.to_str().unwrap()))
|
||||
.collect();
|
||||
|
||||
assert_eq!(actual_cmd, Some(expected_cmd));
|
||||
@@ -435,10 +432,7 @@ mod macos_tests {
|
||||
.format_default_shell_invocation(input.iter().map(|s| s.to_string()).collect());
|
||||
let expected_cmd = expected_cmd
|
||||
.iter()
|
||||
.map(|s| {
|
||||
s.replace("ZSHRC_PATH", zshrc_path.to_str().unwrap())
|
||||
.to_string()
|
||||
})
|
||||
.map(|s| s.replace("ZSHRC_PATH", zshrc_path.to_str().unwrap()))
|
||||
.collect();
|
||||
|
||||
assert_eq!(actual_cmd, Some(expected_cmd));
|
||||
|
||||
@@ -3,8 +3,6 @@ use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Default)]
|
||||
pub struct TokenData {
|
||||
/// Flat info parsed from the JWT in auth.json.
|
||||
@@ -22,36 +20,6 @@ pub struct TokenData {
|
||||
pub account_id: Option<String>,
|
||||
}
|
||||
|
||||
impl TokenData {
|
||||
/// Returns true if this is a plan that should use the traditional
|
||||
/// "metered" billing via an API key.
|
||||
pub(crate) fn should_use_api_key(
|
||||
&self,
|
||||
preferred_auth_method: AuthMode,
|
||||
is_openai_email: bool,
|
||||
) -> bool {
|
||||
if preferred_auth_method == AuthMode::ApiKey {
|
||||
return true;
|
||||
}
|
||||
// If the email is an OpenAI email, use AuthMode::ChatGPT unless preferred_auth_method is AuthMode::ApiKey.
|
||||
if is_openai_email {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.id_token
|
||||
.chatgpt_plan_type
|
||||
.as_ref()
|
||||
.is_none_or(|plan| plan.is_plan_that_should_use_api_key())
|
||||
}
|
||||
|
||||
pub fn is_openai_email(&self) -> bool {
|
||||
self.id_token
|
||||
.email
|
||||
.as_deref()
|
||||
.is_some_and(|email| email.trim().to_ascii_lowercase().ends_with("@openai.com"))
|
||||
}
|
||||
}
|
||||
|
||||
/// Flat subset of useful claims in id_token from auth.json.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
|
||||
pub struct IdTokenInfo {
|
||||
@@ -79,28 +47,6 @@ pub(crate) enum PlanType {
|
||||
Unknown(String),
|
||||
}
|
||||
|
||||
impl PlanType {
|
||||
fn is_plan_that_should_use_api_key(&self) -> bool {
|
||||
match self {
|
||||
Self::Known(known) => {
|
||||
use KnownPlan::*;
|
||||
!matches!(known, Free | Plus | Pro | Team)
|
||||
}
|
||||
Self::Unknown(_) => {
|
||||
// Unknown plans should use the API key.
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_string(&self) -> String {
|
||||
match self {
|
||||
Self::Known(known) => format!("{known:?}").to_lowercase(),
|
||||
Self::Unknown(s) => s.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub(crate) enum KnownPlan {
|
||||
|
||||
180
codex-rs/core/src/truncate.rs
Normal file
180
codex-rs/core/src/truncate.rs
Normal file
@@ -0,0 +1,180 @@
|
||||
//! Utilities for truncating large chunks of output while preserving a prefix
|
||||
//! and suffix on UTF-8 boundaries.
|
||||
|
||||
/// Truncate the middle of a UTF-8 string to at most `max_bytes` bytes,
|
||||
/// preserving the beginning and the end. Returns the possibly truncated
|
||||
/// string and `Some(original_token_count)` (estimated at 4 bytes/token)
|
||||
/// if truncation occurred; otherwise returns the original string and `None`.
|
||||
pub(crate) fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>) {
|
||||
if s.len() <= max_bytes {
|
||||
return (s.to_string(), None);
|
||||
}
|
||||
|
||||
let est_tokens = (s.len() as u64).div_ceil(4);
|
||||
if max_bytes == 0 {
|
||||
return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
|
||||
}
|
||||
|
||||
fn truncate_on_boundary(input: &str, max_len: usize) -> &str {
|
||||
if input.len() <= max_len {
|
||||
return input;
|
||||
}
|
||||
let mut end = max_len;
|
||||
while end > 0 && !input.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
&input[..end]
|
||||
}
|
||||
|
||||
fn pick_prefix_end(s: &str, left_budget: usize) -> usize {
|
||||
if let Some(head) = s.get(..left_budget)
|
||||
&& let Some(i) = head.rfind('\n')
|
||||
{
|
||||
return i + 1;
|
||||
}
|
||||
truncate_on_boundary(s, left_budget).len()
|
||||
}
|
||||
|
||||
fn pick_suffix_start(s: &str, right_budget: usize) -> usize {
|
||||
let start_tail = s.len().saturating_sub(right_budget);
|
||||
if let Some(tail) = s.get(start_tail..)
|
||||
&& let Some(i) = tail.find('\n')
|
||||
{
|
||||
return start_tail + i + 1;
|
||||
}
|
||||
|
||||
let mut idx = start_tail.min(s.len());
|
||||
while idx < s.len() && !s.is_char_boundary(idx) {
|
||||
idx += 1;
|
||||
}
|
||||
idx
|
||||
}
|
||||
|
||||
let mut guess_tokens = est_tokens;
|
||||
for _ in 0..4 {
|
||||
let marker = format!("…{guess_tokens} tokens truncated…");
|
||||
let marker_len = marker.len();
|
||||
let keep_budget = max_bytes.saturating_sub(marker_len);
|
||||
if keep_budget == 0 {
|
||||
return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
|
||||
}
|
||||
|
||||
let left_budget = keep_budget / 2;
|
||||
let right_budget = keep_budget - left_budget;
|
||||
let prefix_end = pick_prefix_end(s, left_budget);
|
||||
let mut suffix_start = pick_suffix_start(s, right_budget);
|
||||
if suffix_start < prefix_end {
|
||||
suffix_start = prefix_end;
|
||||
}
|
||||
|
||||
let kept_content_bytes = prefix_end + (s.len() - suffix_start);
|
||||
let truncated_content_bytes = s.len().saturating_sub(kept_content_bytes);
|
||||
let new_tokens = (truncated_content_bytes as u64).div_ceil(4);
|
||||
|
||||
if new_tokens == guess_tokens {
|
||||
let mut out = String::with_capacity(marker_len + kept_content_bytes + 1);
|
||||
out.push_str(&s[..prefix_end]);
|
||||
out.push_str(&marker);
|
||||
out.push('\n');
|
||||
out.push_str(&s[suffix_start..]);
|
||||
return (out, Some(est_tokens));
|
||||
}
|
||||
|
||||
guess_tokens = new_tokens;
|
||||
}
|
||||
|
||||
let marker = format!("…{guess_tokens} tokens truncated…");
|
||||
let marker_len = marker.len();
|
||||
let keep_budget = max_bytes.saturating_sub(marker_len);
|
||||
if keep_budget == 0 {
|
||||
return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
|
||||
}
|
||||
|
||||
let left_budget = keep_budget / 2;
|
||||
let right_budget = keep_budget - left_budget;
|
||||
let prefix_end = pick_prefix_end(s, left_budget);
|
||||
let suffix_start = pick_suffix_start(s, right_budget);
|
||||
|
||||
let mut out = String::with_capacity(marker_len + prefix_end + (s.len() - suffix_start) + 1);
|
||||
out.push_str(&s[..prefix_end]);
|
||||
out.push_str(&marker);
|
||||
out.push('\n');
|
||||
out.push_str(&s[suffix_start..]);
|
||||
(out, Some(est_tokens))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::truncate_middle;
|
||||
|
||||
#[test]
|
||||
fn truncate_middle_no_newlines_fallback() {
|
||||
let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ*";
|
||||
let max_bytes = 32;
|
||||
let (out, original) = truncate_middle(s, max_bytes);
|
||||
assert!(out.starts_with("abc"));
|
||||
assert!(out.contains("tokens truncated"));
|
||||
assert!(out.ends_with("XYZ*"));
|
||||
assert_eq!(original, Some((s.len() as u64).div_ceil(4)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_middle_prefers_newline_boundaries() {
|
||||
let mut s = String::new();
|
||||
for i in 1..=20 {
|
||||
s.push_str(&format!("{i:03}\n"));
|
||||
}
|
||||
assert_eq!(s.len(), 80);
|
||||
|
||||
let max_bytes = 64;
|
||||
let (out, tokens) = truncate_middle(&s, max_bytes);
|
||||
assert!(out.starts_with("001\n002\n003\n004\n"));
|
||||
assert!(out.contains("tokens truncated"));
|
||||
assert!(out.ends_with("017\n018\n019\n020\n"));
|
||||
assert_eq!(tokens, Some(20));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_middle_handles_utf8_content() {
|
||||
let s = "😀😀😀😀😀😀😀😀😀😀\nsecond line with ascii text\n";
|
||||
let max_bytes = 32;
|
||||
let (out, tokens) = truncate_middle(s, max_bytes);
|
||||
|
||||
assert!(out.contains("tokens truncated"));
|
||||
assert!(!out.contains('\u{fffd}'));
|
||||
assert_eq!(tokens, Some((s.len() as u64).div_ceil(4)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_middle_prefers_newline_boundaries_2() {
|
||||
// Build a multi-line string of 20 numbered lines (each "NNN\n").
|
||||
let mut s = String::new();
|
||||
for i in 1..=20 {
|
||||
s.push_str(&format!("{i:03}\n"));
|
||||
}
|
||||
// Total length: 20 lines * 4 bytes per line = 80 bytes.
|
||||
assert_eq!(s.len(), 80);
|
||||
|
||||
// Choose a cap that forces truncation while leaving room for
|
||||
// a few lines on each side after accounting for the marker.
|
||||
let max_bytes = 64;
|
||||
// Expect exact output: first 4 lines, marker, last 4 lines, and correct token estimate (80/4 = 20).
|
||||
assert_eq!(
|
||||
truncate_middle(&s, max_bytes),
|
||||
(
|
||||
r#"001
|
||||
002
|
||||
003
|
||||
004
|
||||
…12 tokens truncated…
|
||||
017
|
||||
018
|
||||
019
|
||||
020
|
||||
"#
|
||||
.to_string(),
|
||||
Some(20)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -678,7 +678,7 @@ index {left_oid}..{right_oid}
|
||||
let dest = dir.path().join("dest.txt");
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
let mv = HashMap::from([(
|
||||
src.clone(),
|
||||
src,
|
||||
FileChange::Update {
|
||||
unified_diff: "".into(),
|
||||
move_path: Some(dest.clone()),
|
||||
|
||||
22
codex-rs/core/src/unified_exec/errors.rs
Normal file
22
codex-rs/core/src/unified_exec/errors.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum UnifiedExecError {
|
||||
#[error("Failed to create unified exec session: {pty_error}")]
|
||||
CreateSession {
|
||||
#[source]
|
||||
pty_error: anyhow::Error,
|
||||
},
|
||||
#[error("Unknown session id {session_id}")]
|
||||
UnknownSessionId { session_id: i32 },
|
||||
#[error("failed to write to stdin")]
|
||||
WriteToStdin,
|
||||
#[error("missing command line for unified exec request")]
|
||||
MissingCommandLine,
|
||||
}
|
||||
|
||||
impl UnifiedExecError {
|
||||
pub(crate) fn create_session(error: anyhow::Error) -> Self {
|
||||
Self::CreateSession { pty_error: error }
|
||||
}
|
||||
}
|
||||
633
codex-rs/core/src/unified_exec/mod.rs
Normal file
633
codex-rs/core/src/unified_exec/mod.rs
Normal file
@@ -0,0 +1,633 @@
|
||||
use portable_pty::CommandBuilder;
|
||||
use portable_pty::PtySize;
|
||||
use portable_pty::native_pty_system;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::VecDeque;
|
||||
use std::io::ErrorKind;
|
||||
use std::io::Read;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::AtomicI32;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::Instant;
|
||||
|
||||
use crate::exec_command::ExecCommandSession;
|
||||
use crate::truncate::truncate_middle;
|
||||
|
||||
mod errors;
|
||||
|
||||
pub(crate) use errors::UnifiedExecError;
|
||||
|
||||
const DEFAULT_TIMEOUT_MS: u64 = 1_000;
|
||||
const MAX_TIMEOUT_MS: u64 = 60_000;
|
||||
const UNIFIED_EXEC_OUTPUT_MAX_BYTES: usize = 128 * 1024; // 128 KiB
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct UnifiedExecRequest<'a> {
|
||||
pub session_id: Option<i32>,
|
||||
pub input_chunks: &'a [String],
|
||||
pub timeout_ms: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub(crate) struct UnifiedExecResult {
|
||||
pub session_id: Option<i32>,
|
||||
pub output: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub(crate) struct UnifiedExecSessionManager {
|
||||
next_session_id: AtomicI32,
|
||||
sessions: Mutex<HashMap<i32, ManagedUnifiedExecSession>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ManagedUnifiedExecSession {
|
||||
session: ExecCommandSession,
|
||||
output_buffer: OutputBuffer,
|
||||
/// Notifies waiters whenever new output has been appended to
|
||||
/// `output_buffer`, allowing clients to poll for fresh data.
|
||||
output_notify: Arc<Notify>,
|
||||
output_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct OutputBufferState {
|
||||
chunks: VecDeque<Vec<u8>>,
|
||||
total_bytes: usize,
|
||||
}
|
||||
|
||||
impl OutputBufferState {
|
||||
fn push_chunk(&mut self, chunk: Vec<u8>) {
|
||||
self.total_bytes = self.total_bytes.saturating_add(chunk.len());
|
||||
self.chunks.push_back(chunk);
|
||||
|
||||
let mut excess = self
|
||||
.total_bytes
|
||||
.saturating_sub(UNIFIED_EXEC_OUTPUT_MAX_BYTES);
|
||||
|
||||
while excess > 0 {
|
||||
match self.chunks.front_mut() {
|
||||
Some(front) if excess >= front.len() => {
|
||||
excess -= front.len();
|
||||
self.total_bytes = self.total_bytes.saturating_sub(front.len());
|
||||
self.chunks.pop_front();
|
||||
}
|
||||
Some(front) => {
|
||||
front.drain(..excess);
|
||||
self.total_bytes = self.total_bytes.saturating_sub(excess);
|
||||
break;
|
||||
}
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn drain(&mut self) -> Vec<Vec<u8>> {
|
||||
let drained: Vec<Vec<u8>> = self.chunks.drain(..).collect();
|
||||
self.total_bytes = 0;
|
||||
drained
|
||||
}
|
||||
}
|
||||
|
||||
type OutputBuffer = Arc<Mutex<OutputBufferState>>;
|
||||
type OutputHandles = (OutputBuffer, Arc<Notify>);
|
||||
|
||||
impl ManagedUnifiedExecSession {
|
||||
fn new(session: ExecCommandSession) -> Self {
|
||||
let output_buffer = Arc::new(Mutex::new(OutputBufferState::default()));
|
||||
let output_notify = Arc::new(Notify::new());
|
||||
let mut receiver = session.output_receiver();
|
||||
let buffer_clone = Arc::clone(&output_buffer);
|
||||
let notify_clone = Arc::clone(&output_notify);
|
||||
let output_task = tokio::spawn(async move {
|
||||
while let Ok(chunk) = receiver.recv().await {
|
||||
let mut guard = buffer_clone.lock().await;
|
||||
guard.push_chunk(chunk);
|
||||
drop(guard);
|
||||
notify_clone.notify_waiters();
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
session,
|
||||
output_buffer,
|
||||
output_notify,
|
||||
output_task,
|
||||
}
|
||||
}
|
||||
|
||||
fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
|
||||
self.session.writer_sender()
|
||||
}
|
||||
|
||||
fn output_handles(&self) -> OutputHandles {
|
||||
(
|
||||
Arc::clone(&self.output_buffer),
|
||||
Arc::clone(&self.output_notify),
|
||||
)
|
||||
}
|
||||
|
||||
fn has_exited(&self) -> bool {
|
||||
self.session.has_exited()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ManagedUnifiedExecSession {
|
||||
fn drop(&mut self) {
|
||||
self.output_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
impl UnifiedExecSessionManager {
|
||||
pub async fn handle_request(
|
||||
&self,
|
||||
request: UnifiedExecRequest<'_>,
|
||||
) -> Result<UnifiedExecResult, UnifiedExecError> {
|
||||
let (timeout_ms, timeout_warning) = match request.timeout_ms {
|
||||
Some(requested) if requested > MAX_TIMEOUT_MS => (
|
||||
MAX_TIMEOUT_MS,
|
||||
Some(format!(
|
||||
"Warning: requested timeout {requested}ms exceeds maximum of {MAX_TIMEOUT_MS}ms; clamping to {MAX_TIMEOUT_MS}ms.\n"
|
||||
)),
|
||||
),
|
||||
Some(requested) => (requested, None),
|
||||
None => (DEFAULT_TIMEOUT_MS, None),
|
||||
};
|
||||
|
||||
let mut new_session: Option<ManagedUnifiedExecSession> = None;
|
||||
let session_id;
|
||||
let writer_tx;
|
||||
let output_buffer;
|
||||
let output_notify;
|
||||
|
||||
if let Some(existing_id) = request.session_id {
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
match sessions.get(&existing_id) {
|
||||
Some(session) => {
|
||||
if session.has_exited() {
|
||||
sessions.remove(&existing_id);
|
||||
return Err(UnifiedExecError::UnknownSessionId {
|
||||
session_id: existing_id,
|
||||
});
|
||||
}
|
||||
let (buffer, notify) = session.output_handles();
|
||||
session_id = existing_id;
|
||||
writer_tx = session.writer_sender();
|
||||
output_buffer = buffer;
|
||||
output_notify = notify;
|
||||
}
|
||||
None => {
|
||||
return Err(UnifiedExecError::UnknownSessionId {
|
||||
session_id: existing_id,
|
||||
});
|
||||
}
|
||||
}
|
||||
drop(sessions);
|
||||
} else {
|
||||
let command = request.input_chunks.to_vec();
|
||||
let new_id = self.next_session_id.fetch_add(1, Ordering::SeqCst);
|
||||
let session = create_unified_exec_session(&command).await?;
|
||||
let managed_session = ManagedUnifiedExecSession::new(session);
|
||||
let (buffer, notify) = managed_session.output_handles();
|
||||
writer_tx = managed_session.writer_sender();
|
||||
output_buffer = buffer;
|
||||
output_notify = notify;
|
||||
session_id = new_id;
|
||||
new_session = Some(managed_session);
|
||||
};
|
||||
|
||||
if request.session_id.is_some() {
|
||||
let joined_input = request.input_chunks.join(" ");
|
||||
if !joined_input.is_empty() && writer_tx.send(joined_input.into_bytes()).await.is_err()
|
||||
{
|
||||
return Err(UnifiedExecError::WriteToStdin);
|
||||
}
|
||||
}
|
||||
|
||||
let mut collected: Vec<u8> = Vec::with_capacity(4096);
|
||||
let start = Instant::now();
|
||||
let deadline = start + Duration::from_millis(timeout_ms);
|
||||
|
||||
loop {
|
||||
let drained_chunks;
|
||||
let mut wait_for_output = None;
|
||||
{
|
||||
let mut guard = output_buffer.lock().await;
|
||||
drained_chunks = guard.drain();
|
||||
if drained_chunks.is_empty() {
|
||||
wait_for_output = Some(output_notify.notified());
|
||||
}
|
||||
}
|
||||
|
||||
if drained_chunks.is_empty() {
|
||||
let remaining = deadline.saturating_duration_since(Instant::now());
|
||||
if remaining == Duration::ZERO {
|
||||
break;
|
||||
}
|
||||
|
||||
let notified = wait_for_output.unwrap_or_else(|| output_notify.notified());
|
||||
tokio::pin!(notified);
|
||||
tokio::select! {
|
||||
_ = &mut notified => {}
|
||||
_ = tokio::time::sleep(remaining) => break,
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
for chunk in drained_chunks {
|
||||
collected.extend_from_slice(&chunk);
|
||||
}
|
||||
|
||||
if Instant::now() >= deadline {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let (output, _maybe_tokens) = truncate_middle(
|
||||
&String::from_utf8_lossy(&collected),
|
||||
UNIFIED_EXEC_OUTPUT_MAX_BYTES,
|
||||
);
|
||||
let output = if let Some(warning) = timeout_warning {
|
||||
format!("{warning}{output}")
|
||||
} else {
|
||||
output
|
||||
};
|
||||
|
||||
let should_store_session = if let Some(session) = new_session.as_ref() {
|
||||
!session.has_exited()
|
||||
} else if request.session_id.is_some() {
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
if let Some(existing) = sessions.get(&session_id) {
|
||||
if existing.has_exited() {
|
||||
sessions.remove(&session_id);
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
if should_store_session {
|
||||
if let Some(session) = new_session {
|
||||
self.sessions.lock().await.insert(session_id, session);
|
||||
}
|
||||
Ok(UnifiedExecResult {
|
||||
session_id: Some(session_id),
|
||||
output,
|
||||
})
|
||||
} else {
|
||||
Ok(UnifiedExecResult {
|
||||
session_id: None,
|
||||
output,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_unified_exec_session(
|
||||
command: &[String],
|
||||
) -> Result<ExecCommandSession, UnifiedExecError> {
|
||||
if command.is_empty() {
|
||||
return Err(UnifiedExecError::MissingCommandLine);
|
||||
}
|
||||
|
||||
let pty_system = native_pty_system();
|
||||
|
||||
let pair = pty_system
|
||||
.openpty(PtySize {
|
||||
rows: 24,
|
||||
cols: 80,
|
||||
pixel_width: 0,
|
||||
pixel_height: 0,
|
||||
})
|
||||
.map_err(UnifiedExecError::create_session)?;
|
||||
|
||||
// Safe thanks to the check at the top of the function.
|
||||
let mut command_builder = CommandBuilder::new(command[0].clone());
|
||||
for arg in &command[1..] {
|
||||
command_builder.arg(arg);
|
||||
}
|
||||
|
||||
let mut child = pair
|
||||
.slave
|
||||
.spawn_command(command_builder)
|
||||
.map_err(UnifiedExecError::create_session)?;
|
||||
let killer = child.clone_killer();
|
||||
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
|
||||
let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256);
|
||||
|
||||
let mut reader = pair
|
||||
.master
|
||||
.try_clone_reader()
|
||||
.map_err(UnifiedExecError::create_session)?;
|
||||
let output_tx_clone = output_tx.clone();
|
||||
let reader_handle = tokio::task::spawn_blocking(move || {
|
||||
let mut buf = [0u8; 8192];
|
||||
loop {
|
||||
match reader.read(&mut buf) {
|
||||
Ok(0) => break,
|
||||
Ok(n) => {
|
||||
let _ = output_tx_clone.send(buf[..n].to_vec());
|
||||
}
|
||||
Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
|
||||
Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
|
||||
std::thread::sleep(Duration::from_millis(5));
|
||||
continue;
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let writer = pair
|
||||
.master
|
||||
.take_writer()
|
||||
.map_err(UnifiedExecError::create_session)?;
|
||||
let writer = Arc::new(StdMutex::new(writer));
|
||||
let writer_handle = tokio::spawn({
|
||||
let writer = writer.clone();
|
||||
async move {
|
||||
while let Some(bytes) = writer_rx.recv().await {
|
||||
let writer = writer.clone();
|
||||
let _ = tokio::task::spawn_blocking(move || {
|
||||
if let Ok(mut guard) = writer.lock() {
|
||||
use std::io::Write;
|
||||
let _ = guard.write_all(&bytes);
|
||||
let _ = guard.flush();
|
||||
}
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let exit_status = Arc::new(AtomicBool::new(false));
|
||||
let wait_exit_status = Arc::clone(&exit_status);
|
||||
let wait_handle = tokio::task::spawn_blocking(move || {
|
||||
let _ = child.wait();
|
||||
wait_exit_status.store(true, Ordering::SeqCst);
|
||||
});
|
||||
|
||||
Ok(ExecCommandSession::new(
|
||||
writer_tx,
|
||||
output_tx,
|
||||
killer,
|
||||
reader_handle,
|
||||
writer_handle,
|
||||
wait_handle,
|
||||
exit_status,
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn push_chunk_trims_only_excess_bytes() {
|
||||
let mut buffer = OutputBufferState::default();
|
||||
buffer.push_chunk(vec![b'a'; UNIFIED_EXEC_OUTPUT_MAX_BYTES]);
|
||||
buffer.push_chunk(vec![b'b']);
|
||||
buffer.push_chunk(vec![b'c']);
|
||||
|
||||
assert_eq!(buffer.total_bytes, UNIFIED_EXEC_OUTPUT_MAX_BYTES);
|
||||
assert_eq!(buffer.chunks.len(), 3);
|
||||
assert_eq!(
|
||||
buffer.chunks.front().unwrap().len(),
|
||||
UNIFIED_EXEC_OUTPUT_MAX_BYTES - 2
|
||||
);
|
||||
assert_eq!(buffer.chunks.pop_back().unwrap(), vec![b'c']);
|
||||
assert_eq!(buffer.chunks.pop_back().unwrap(), vec![b'b']);
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn unified_exec_persists_across_requests_jif() -> Result<(), UnifiedExecError> {
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
|
||||
let open_shell = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: None,
|
||||
input_chunks: &["bash".to_string(), "-i".to_string()],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
let session_id = open_shell.session_id.expect("expected session_id");
|
||||
|
||||
manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_id),
|
||||
input_chunks: &[
|
||||
"export".to_string(),
|
||||
"CODEX_INTERACTIVE_SHELL_VAR=codex\n".to_string(),
|
||||
],
|
||||
timeout_ms: Some(2_500),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let out_2 = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_id),
|
||||
input_chunks: &["echo $CODEX_INTERACTIVE_SHELL_VAR\n".to_string()],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
assert!(out_2.output.contains("codex"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn multi_unified_exec_sessions() -> Result<(), UnifiedExecError> {
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
|
||||
let shell_a = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: None,
|
||||
input_chunks: &["/bin/bash".to_string(), "-i".to_string()],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
let session_a = shell_a.session_id.expect("expected session id");
|
||||
|
||||
manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_a),
|
||||
input_chunks: &["export CODEX_INTERACTIVE_SHELL_VAR=codex\n".to_string()],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let out_2 = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: None,
|
||||
input_chunks: &[
|
||||
"echo".to_string(),
|
||||
"$CODEX_INTERACTIVE_SHELL_VAR\n".to_string(),
|
||||
],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
assert!(!out_2.output.contains("codex"));
|
||||
|
||||
let out_3 = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_a),
|
||||
input_chunks: &["echo $CODEX_INTERACTIVE_SHELL_VAR\n".to_string()],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
assert!(out_3.output.contains("codex"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test]
|
||||
async fn unified_exec_timeouts() -> Result<(), UnifiedExecError> {
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
|
||||
let open_shell = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: None,
|
||||
input_chunks: &["bash".to_string(), "-i".to_string()],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
let session_id = open_shell.session_id.expect("expected session id");
|
||||
|
||||
manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_id),
|
||||
input_chunks: &[
|
||||
"export".to_string(),
|
||||
"CODEX_INTERACTIVE_SHELL_VAR=codex\n".to_string(),
|
||||
],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let out_2 = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_id),
|
||||
input_chunks: &["sleep 5 && echo $CODEX_INTERACTIVE_SHELL_VAR\n".to_string()],
|
||||
timeout_ms: Some(10),
|
||||
})
|
||||
.await?;
|
||||
assert!(!out_2.output.contains("codex"));
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(7)).await;
|
||||
|
||||
let empty = Vec::new();
|
||||
let out_3 = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_id),
|
||||
input_chunks: &empty,
|
||||
timeout_ms: Some(100),
|
||||
})
|
||||
.await?;
|
||||
|
||||
assert!(out_3.output.contains("codex"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test]
|
||||
async fn requests_with_large_timeout_are_capped() -> Result<(), UnifiedExecError> {
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
|
||||
let result = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: None,
|
||||
input_chunks: &["echo".to_string(), "codex".to_string()],
|
||||
timeout_ms: Some(120_000),
|
||||
})
|
||||
.await?;
|
||||
|
||||
assert!(result.output.starts_with(
|
||||
"Warning: requested timeout 120000ms exceeds maximum of 60000ms; clamping to 60000ms.\n"
|
||||
));
|
||||
assert!(result.output.contains("codex"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test]
|
||||
async fn completed_commands_do_not_persist_sessions() -> Result<(), UnifiedExecError> {
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
let result = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: None,
|
||||
input_chunks: &["/bin/echo".to_string(), "codex".to_string()],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
|
||||
assert!(result.session_id.is_none());
|
||||
assert!(result.output.contains("codex"));
|
||||
|
||||
assert!(manager.sessions.lock().await.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn reusing_completed_session_returns_unknown_session() -> Result<(), UnifiedExecError> {
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
|
||||
let open_shell = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: None,
|
||||
input_chunks: &["/bin/bash".to_string(), "-i".to_string()],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
let session_id = open_shell.session_id.expect("expected session id");
|
||||
|
||||
manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_id),
|
||||
input_chunks: &["exit\n".to_string()],
|
||||
timeout_ms: Some(1_500),
|
||||
})
|
||||
.await?;
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||
|
||||
let err = manager
|
||||
.handle_request(UnifiedExecRequest {
|
||||
session_id: Some(session_id),
|
||||
input_chunks: &[],
|
||||
timeout_ms: Some(100),
|
||||
})
|
||||
.await
|
||||
.expect_err("expected unknown session error");
|
||||
|
||||
match err {
|
||||
UnifiedExecError::UnknownSessionId { session_id: err_id } => {
|
||||
assert_eq!(err_id, session_id);
|
||||
}
|
||||
other => panic!("expected UnknownSessionId, got {other:?}"),
|
||||
}
|
||||
|
||||
assert!(!manager.sessions.lock().await.contains_key(&session_id));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,6 @@
|
||||
use assert_cmd::Command as AssertCommand;
|
||||
use codex_core::RolloutRecorder;
|
||||
use codex_core::protocol::GitInfo;
|
||||
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
@@ -77,6 +79,22 @@ async fn chat_mode_stream_cli() {
|
||||
assert_eq!(hi_lines, 1, "Expected exactly one line with 'hi'");
|
||||
|
||||
server.verify().await;
|
||||
|
||||
// Verify a new session rollout was created and is discoverable via list_conversations
|
||||
let page = RolloutRecorder::list_conversations(home.path(), 10, None)
|
||||
.await
|
||||
.expect("list conversations");
|
||||
assert!(
|
||||
!page.items.is_empty(),
|
||||
"expected at least one session to be listed"
|
||||
);
|
||||
// First line of head must be the SessionMeta payload (id/timestamp)
|
||||
let head0 = page.items[0].head.first().expect("missing head record");
|
||||
assert!(head0.get("id").is_some(), "head[0] missing id");
|
||||
assert!(
|
||||
head0.get("timestamp").is_some(),
|
||||
"head[0] missing timestamp"
|
||||
);
|
||||
}
|
||||
|
||||
/// Verify that passing `-c experimental_instructions_file=...` to the CLI
|
||||
@@ -297,8 +315,10 @@ async fn integration_creates_and_checks_session_file() {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("message")
|
||||
&& let Some(c) = item.get("content")
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("response_item")
|
||||
&& let Some(payload) = item.get("payload")
|
||||
&& payload.get("type").and_then(|t| t.as_str()) == Some("message")
|
||||
&& let Some(c) = payload.get("content")
|
||||
&& c.to_string().contains(&marker)
|
||||
{
|
||||
matching_path = Some(path.to_path_buf());
|
||||
@@ -361,9 +381,16 @@ async fn integration_creates_and_checks_session_file() {
|
||||
.unwrap_or_else(|_| panic!("missing session meta line"));
|
||||
let meta: serde_json::Value = serde_json::from_str(meta_line)
|
||||
.unwrap_or_else(|_| panic!("Failed to parse session meta line as JSON"));
|
||||
assert!(meta.get("id").is_some(), "SessionMeta missing id");
|
||||
assert_eq!(
|
||||
meta.get("type").and_then(|v| v.as_str()),
|
||||
Some("session_meta")
|
||||
);
|
||||
let payload = meta
|
||||
.get("payload")
|
||||
.unwrap_or_else(|| panic!("Missing payload in meta line"));
|
||||
assert!(payload.get("id").is_some(), "SessionMeta missing id");
|
||||
assert!(
|
||||
meta.get("timestamp").is_some(),
|
||||
payload.get("timestamp").is_some(),
|
||||
"SessionMeta missing timestamp"
|
||||
);
|
||||
|
||||
@@ -375,8 +402,10 @@ async fn integration_creates_and_checks_session_file() {
|
||||
let Ok(item) = serde_json::from_str::<serde_json::Value>(line) else {
|
||||
continue;
|
||||
};
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("message")
|
||||
&& let Some(c) = item.get("content")
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("response_item")
|
||||
&& let Some(payload) = item.get("payload")
|
||||
&& payload.get("type").and_then(|t| t.as_str()) == Some("message")
|
||||
&& let Some(c) = payload.get("content")
|
||||
&& c.to_string().contains(&marker)
|
||||
{
|
||||
found_message = true;
|
||||
@@ -589,7 +618,7 @@ async fn integration_git_info_unit_test() {
|
||||
|
||||
// 5. Test serialization to ensure it works in SessionMeta
|
||||
let serialized = serde_json::to_string(&git_info).unwrap();
|
||||
let deserialized: codex_core::git_info::GitInfo = serde_json::from_str(&serialized).unwrap();
|
||||
let deserialized: GitInfo = serde_json::from_str(&serialized).unwrap();
|
||||
|
||||
assert_eq!(git_info.commit_hash, deserialized.commit_hash);
|
||||
assert_eq!(git_info.branch, deserialized.branch);
|
||||
|
||||
@@ -4,14 +4,12 @@ use codex_core::ModelProviderInfo;
|
||||
use codex_core::NewConversation;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::built_in_model_providers;
|
||||
use codex_core::protocol::AgentMessageEvent;
|
||||
use codex_core::project_doc::get_user_instructions;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::InputMessageKind;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::UserMessageEvent;
|
||||
use codex_core::shell::default_user_shell;
|
||||
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::load_sse_fixture_with_id;
|
||||
use core_test_support::wait_for_event;
|
||||
@@ -126,16 +124,21 @@ async fn resume_includes_initial_messages_and_sends_prior_items() {
|
||||
let tmpdir = TempDir::new().unwrap();
|
||||
let session_path = tmpdir.path().join("resume-session.jsonl");
|
||||
let mut f = std::fs::File::create(&session_path).unwrap();
|
||||
let convo_id = Uuid::new_v4();
|
||||
writeln!(
|
||||
f,
|
||||
"{}",
|
||||
json!({
|
||||
"record_type": "session_meta",
|
||||
"id": Uuid::new_v4(),
|
||||
"timestamp": "2024-01-01T00:00:00Z",
|
||||
"cwd": tmpdir.path().to_string_lossy(),
|
||||
"originator": "test",
|
||||
"cli_version": "0.0.0-test"
|
||||
"timestamp": "2024-01-01T00:00:00.000Z",
|
||||
"type": "session_meta",
|
||||
"payload": {
|
||||
"id": convo_id,
|
||||
"timestamp": "2024-01-01T00:00:00Z",
|
||||
"instructions": "be nice",
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version"
|
||||
}
|
||||
})
|
||||
)
|
||||
.unwrap();
|
||||
@@ -148,30 +151,17 @@ async fn resume_includes_initial_messages_and_sends_prior_items() {
|
||||
text: "resumed user message".to_string(),
|
||||
}],
|
||||
};
|
||||
let mut prior_user_obj = serde_json::to_value(&prior_user)
|
||||
.unwrap()
|
||||
.as_object()
|
||||
.unwrap()
|
||||
.clone();
|
||||
prior_user_obj.insert("record_type".to_string(), serde_json::json!("response"));
|
||||
prior_user_obj.insert(
|
||||
"timestamp".to_string(),
|
||||
serde_json::json!("2025-01-01T00:00:00Z"),
|
||||
);
|
||||
writeln!(f, "{}", serde_json::Value::Object(prior_user_obj)).unwrap();
|
||||
|
||||
// Also include a matching user message event to preserve ordering at resume
|
||||
let prior_user_event = EventMsg::UserMessage(UserMessageEvent {
|
||||
message: "resumed user message".to_string(),
|
||||
kind: Some(InputMessageKind::Plain),
|
||||
});
|
||||
let prior_user_event_line = serde_json::json!({
|
||||
"timestamp": "2025-01-01T00:00:00Z",
|
||||
"record_type": "event",
|
||||
"id": "resume-0",
|
||||
"msg": prior_user_event,
|
||||
});
|
||||
writeln!(f, "{prior_user_event_line}").unwrap();
|
||||
let prior_user_json = serde_json::to_value(&prior_user).unwrap();
|
||||
writeln!(
|
||||
f,
|
||||
"{}",
|
||||
json!({
|
||||
"timestamp": "2024-01-01T00:00:01.000Z",
|
||||
"type": "response_item",
|
||||
"payload": prior_user_json
|
||||
})
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Prior item: system message (excluded from API history)
|
||||
let prior_system = codex_protocol::models::ResponseItem::Message {
|
||||
@@ -181,17 +171,17 @@ async fn resume_includes_initial_messages_and_sends_prior_items() {
|
||||
text: "resumed system instruction".to_string(),
|
||||
}],
|
||||
};
|
||||
let mut prior_system_obj = serde_json::to_value(&prior_system)
|
||||
.unwrap()
|
||||
.as_object()
|
||||
.unwrap()
|
||||
.clone();
|
||||
prior_system_obj.insert("record_type".to_string(), serde_json::json!("response"));
|
||||
prior_system_obj.insert(
|
||||
"timestamp".to_string(),
|
||||
serde_json::json!("2025-01-01T00:00:00Z"),
|
||||
);
|
||||
writeln!(f, "{}", serde_json::Value::Object(prior_system_obj)).unwrap();
|
||||
let prior_system_json = serde_json::to_value(&prior_system).unwrap();
|
||||
writeln!(
|
||||
f,
|
||||
"{}",
|
||||
json!({
|
||||
"timestamp": "2024-01-01T00:00:02.000Z",
|
||||
"type": "response_item",
|
||||
"payload": prior_system_json
|
||||
})
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Prior item: assistant message
|
||||
let prior_item = codex_protocol::models::ResponseItem::Message {
|
||||
@@ -201,27 +191,17 @@ async fn resume_includes_initial_messages_and_sends_prior_items() {
|
||||
text: "resumed assistant message".to_string(),
|
||||
}],
|
||||
};
|
||||
let mut prior_item_obj = serde_json::to_value(&prior_item)
|
||||
.unwrap()
|
||||
.as_object()
|
||||
.unwrap()
|
||||
.clone();
|
||||
prior_item_obj.insert("record_type".to_string(), serde_json::json!("response"));
|
||||
prior_item_obj.insert(
|
||||
"timestamp".to_string(),
|
||||
serde_json::json!("2025-01-01T00:00:00Z"),
|
||||
);
|
||||
writeln!(f, "{}", serde_json::Value::Object(prior_item_obj)).unwrap();
|
||||
let prior_item_event = EventMsg::AgentMessage(AgentMessageEvent {
|
||||
message: "resumed assistant message".to_string(),
|
||||
});
|
||||
let prior_event_line = serde_json::json!({
|
||||
"timestamp": "2025-01-01T00:00:00Z",
|
||||
"record_type": "event",
|
||||
"id": "resume-1",
|
||||
"msg": prior_item_event,
|
||||
});
|
||||
writeln!(f, "{prior_event_line}").unwrap();
|
||||
let prior_item_json = serde_json::to_value(&prior_item).unwrap();
|
||||
writeln!(
|
||||
f,
|
||||
"{}",
|
||||
json!({
|
||||
"timestamp": "2024-01-01T00:00:03.000Z",
|
||||
"type": "response_item",
|
||||
"payload": prior_item_json
|
||||
})
|
||||
)
|
||||
.unwrap();
|
||||
drop(f);
|
||||
|
||||
// Mock server that will receive the resumed request
|
||||
@@ -243,6 +223,8 @@ async fn resume_includes_initial_messages_and_sends_prior_items() {
|
||||
};
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
let cwd = TempDir::new().unwrap();
|
||||
config.cwd = cwd.path().to_path_buf();
|
||||
config.model_provider = model_provider;
|
||||
config.experimental_resume = Some(session_path.clone());
|
||||
// Also configure user instructions to ensure they are NOT delivered on resume.
|
||||
@@ -259,16 +241,13 @@ async fn resume_includes_initial_messages_and_sends_prior_items() {
|
||||
.await
|
||||
.expect("create new conversation");
|
||||
|
||||
// 1) Assert initial_messages contains the prior user + assistant messages as EventMsg entries
|
||||
// 1) Assert initial_messages only includes existing EventMsg entries; response items are not converted
|
||||
let initial_msgs = session_configured
|
||||
.initial_messages
|
||||
.clone()
|
||||
.expect("expected initial messages for resumed session");
|
||||
.expect("expected initial messages option for resumed session");
|
||||
let initial_json = serde_json::to_value(&initial_msgs).unwrap();
|
||||
let expected_initial_json = json!([
|
||||
{ "type": "user_message", "message": "resumed user message", "kind": "plain" },
|
||||
{ "type": "agent_message", "message": "resumed assistant message" }
|
||||
]);
|
||||
let expected_initial_json = json!([]);
|
||||
assert_eq!(initial_json, expected_initial_json);
|
||||
|
||||
// 2) Submit new input; the request body must include the prior item followed by the new user input.
|
||||
@@ -284,6 +263,29 @@ async fn resume_includes_initial_messages_and_sends_prior_items() {
|
||||
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_body = request.body_json::<serde_json::Value>().unwrap();
|
||||
|
||||
// Build expected environment context for this turn.
|
||||
let shell = default_user_shell().await;
|
||||
let shell_line = match shell.name() {
|
||||
Some(name) => format!(" <shell>{name}</shell>\n"),
|
||||
None => String::new(),
|
||||
};
|
||||
let expected_env_text_turn = format!(
|
||||
r#"<environment_context>
|
||||
<cwd>{}</cwd>
|
||||
<approval_policy>on-request</approval_policy>
|
||||
<sandbox_mode>read-only</sandbox_mode>
|
||||
<network_access>restricted</network_access>
|
||||
{}</environment_context>"#,
|
||||
cwd.path().to_string_lossy(),
|
||||
shell_line.as_str(),
|
||||
);
|
||||
let expected_env_msg_turn = json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": expected_env_text_turn } ]
|
||||
});
|
||||
|
||||
let expected_input = json!([
|
||||
{
|
||||
"type": "message",
|
||||
@@ -295,12 +297,14 @@ async fn resume_includes_initial_messages_and_sends_prior_items() {
|
||||
"role": "assistant",
|
||||
"content": [{ "type": "output_text", "text": "resumed assistant message" }]
|
||||
},
|
||||
expected_env_msg_turn,
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{ "type": "input_text", "text": "hello" }]
|
||||
}
|
||||
]);
|
||||
|
||||
assert_eq!(request_body["input"], expected_input);
|
||||
}
|
||||
|
||||
@@ -434,56 +438,6 @@ async fn includes_base_instructions_override_in_request() {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn originator_config_override_is_used() {
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
..built_in_model_providers()["openai"].clone()
|
||||
};
|
||||
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider = model_provider;
|
||||
config.responses_originator_header = "my_override".to_owned();
|
||||
|
||||
let conversation_manager =
|
||||
ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
|
||||
let codex = conversation_manager
|
||||
.new_conversation(config)
|
||||
.await
|
||||
.expect("create new conversation")
|
||||
.conversation;
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_originator = request.headers.get("originator").unwrap();
|
||||
assert_eq!(request_originator.to_str().unwrap(), "my_override");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn chatgpt_auth_sends_correct_request() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
@@ -563,82 +517,6 @@ async fn chatgpt_auth_sends_correct_request() {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn prefers_chatgpt_token_when_config_prefers_chatgpt() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
|
||||
// Expect ChatGPT base path and correct headers
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(header_regex("Authorization", r"Bearer Access-123"))
|
||||
.and(header_regex("chatgpt-account-id", r"acc-123"))
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
..built_in_model_providers()["openai"].clone()
|
||||
};
|
||||
|
||||
// Init session
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
// Write auth.json that contains both API key and ChatGPT tokens for a plan that should prefer ChatGPT.
|
||||
let _jwt = write_auth_json(
|
||||
&codex_home,
|
||||
Some("sk-test-key"),
|
||||
"pro",
|
||||
"Access-123",
|
||||
Some("acc-123"),
|
||||
);
|
||||
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider = model_provider;
|
||||
config.preferred_auth_method = AuthMode::ChatGPT;
|
||||
|
||||
let auth_manager = match CodexAuth::from_codex_home(
|
||||
codex_home.path(),
|
||||
config.preferred_auth_method,
|
||||
&config.responses_originator_header,
|
||||
) {
|
||||
Ok(Some(auth)) => codex_core::AuthManager::from_auth_for_testing(auth),
|
||||
Ok(None) => panic!("No CodexAuth found in codex_home"),
|
||||
Err(e) => panic!("Failed to load CodexAuth: {e}"),
|
||||
};
|
||||
let conversation_manager = ConversationManager::new(auth_manager);
|
||||
let NewConversation {
|
||||
conversation: codex,
|
||||
..
|
||||
} = conversation_manager
|
||||
.new_conversation(config)
|
||||
.await
|
||||
.expect("create new conversation");
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn prefers_apikey_when_config_prefers_apikey_even_with_chatgpt_tokens() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
@@ -683,13 +561,8 @@ async fn prefers_apikey_when_config_prefers_apikey_even_with_chatgpt_tokens() {
|
||||
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider = model_provider;
|
||||
config.preferred_auth_method = AuthMode::ApiKey;
|
||||
|
||||
let auth_manager = match CodexAuth::from_codex_home(
|
||||
codex_home.path(),
|
||||
config.preferred_auth_method,
|
||||
&config.responses_originator_header,
|
||||
) {
|
||||
let auth_manager = match CodexAuth::from_codex_home(codex_home.path()) {
|
||||
Ok(Some(auth)) => codex_core::AuthManager::from_auth_for_testing(auth),
|
||||
Ok(None) => panic!("No CodexAuth found in codex_home"),
|
||||
Err(e) => panic!("Failed to load CodexAuth: {e}"),
|
||||
@@ -994,7 +867,7 @@ async fn history_dedupes_streamed_and_final_messages_across_turns() {
|
||||
conversation: codex,
|
||||
..
|
||||
} = conversation_manager
|
||||
.new_conversation(config)
|
||||
.new_conversation(config.clone())
|
||||
.await
|
||||
.expect("create new conversation");
|
||||
|
||||
@@ -1029,34 +902,49 @@ async fn history_dedupes_streamed_and_final_messages_across_turns() {
|
||||
let requests = server.received_requests().await.unwrap();
|
||||
assert_eq!(requests.len(), 3, "expected 3 requests (one per turn)");
|
||||
|
||||
// Replace full-array compare with tail-only raw JSON compare using a single hard-coded value.
|
||||
let r3_tail_expected = json!([
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{"type":"input_text","text":"U1"}]
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type":"output_text","text":"Hey there!\n"}]
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{"type":"input_text","text":"U2"}]
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type":"output_text","text":"Hey there!\n"}]
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{"type":"input_text","text":"U3"}]
|
||||
}
|
||||
]);
|
||||
// Build expected environment context dynamically to avoid OS-dependent flakiness.
|
||||
let user_instructions = get_user_instructions(&config).await;
|
||||
let shell = default_user_shell().await;
|
||||
let shell_line = match shell.name() {
|
||||
Some(name) => format!(" <shell>{name}</shell>\n"),
|
||||
None => String::new(),
|
||||
};
|
||||
let expected_env_text = format!(
|
||||
r#"<environment_context>
|
||||
<cwd>{}</cwd>
|
||||
<approval_policy>on-request</approval_policy>
|
||||
<sandbox_mode>read-only</sandbox_mode>
|
||||
<network_access>restricted</network_access>
|
||||
{}</environment_context>"#,
|
||||
std::env::current_dir().unwrap().to_string_lossy(),
|
||||
shell_line.as_str(),
|
||||
);
|
||||
let expected_env_msg = json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": expected_env_text } ]
|
||||
});
|
||||
// Wrap user instructions in the XML container to match the raw/ingest view
|
||||
let expected_ui_text = format!(
|
||||
"<user_instructions>\n\n{}\n\n</user_instructions>",
|
||||
user_instructions.clone().unwrap()
|
||||
);
|
||||
let expected_ui_msg = json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": expected_ui_text } ]
|
||||
});
|
||||
|
||||
let expected_full = json!([
|
||||
expected_ui_msg,
|
||||
expected_env_msg.clone(),
|
||||
{"type":"message","role":"user","content":[{"type":"input_text","text":"U1"}]},
|
||||
{"type":"message","role":"assistant","content":[{"type":"output_text","text":"Hey there!\n"}]},
|
||||
expected_env_msg.clone(),
|
||||
{"type":"message","role":"user","content":[{"type":"input_text","text":"U2"}]},
|
||||
{"type":"message","role":"assistant","content":[{"type":"output_text","text":"Hey there!\n"}]},
|
||||
expected_env_msg,
|
||||
{"type":"message","role":"user","content":[{"type":"input_text","text":"U3"}]}]);
|
||||
|
||||
let r3_input_array = requests[2]
|
||||
.body_json::<serde_json::Value>()
|
||||
@@ -1065,12 +953,6 @@ async fn history_dedupes_streamed_and_final_messages_across_turns() {
|
||||
.and_then(|v| v.as_array())
|
||||
.cloned()
|
||||
.expect("r3 missing input array");
|
||||
// skipping earlier context and developer messages
|
||||
let tail_len = r3_tail_expected.as_array().unwrap().len();
|
||||
let actual_tail = &r3_input_array[r3_input_array.len() - tail_len..];
|
||||
assert_eq!(
|
||||
serde_json::Value::Array(actual_tail.to_vec()),
|
||||
r3_tail_expected,
|
||||
"request 3 tail mismatch",
|
||||
);
|
||||
|
||||
assert_eq!(json!(r3_input_array), expected_full);
|
||||
}
|
||||
|
||||
@@ -3,10 +3,13 @@
|
||||
use codex_core::CodexAuth;
|
||||
use codex_core::ConversationManager;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::NewConversation;
|
||||
use codex_core::built_in_model_providers;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::RolloutItem;
|
||||
use codex_core::protocol::RolloutLine;
|
||||
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::wait_for_event;
|
||||
@@ -142,11 +145,12 @@ async fn summarize_context_three_requests_and_instructions() {
|
||||
let mut config = load_default_config_for_test(&home);
|
||||
config.model_provider = model_provider;
|
||||
let conversation_manager = ConversationManager::with_auth(CodexAuth::from_api_key("dummy"));
|
||||
let codex = conversation_manager
|
||||
.new_conversation(config)
|
||||
.await
|
||||
.unwrap()
|
||||
.conversation;
|
||||
let NewConversation {
|
||||
conversation: codex,
|
||||
session_configured,
|
||||
..
|
||||
} = conversation_manager.new_conversation(config).await.unwrap();
|
||||
let rollout_path = session_configured.rollout_path;
|
||||
|
||||
// 1) Normal user input – should hit server once.
|
||||
codex
|
||||
@@ -248,4 +252,47 @@ async fn summarize_context_three_requests_and_instructions() {
|
||||
!messages.iter().any(|(_, t)| t.contains(SUMMARIZE_TRIGGER)),
|
||||
"third request should not include the summarize trigger"
|
||||
);
|
||||
|
||||
// Shut down Codex to flush rollout entries before inspecting the file.
|
||||
codex.submit(Op::Shutdown).await.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::ShutdownComplete)).await;
|
||||
|
||||
// Verify rollout contains APITurn entries for each API call and a Compacted entry.
|
||||
let text = std::fs::read_to_string(&rollout_path).unwrap_or_else(|e| {
|
||||
panic!(
|
||||
"failed to read rollout file {}: {e}",
|
||||
rollout_path.display()
|
||||
)
|
||||
});
|
||||
let mut api_turn_count = 0usize;
|
||||
let mut saw_compacted_summary = false;
|
||||
for line in text.lines() {
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let Ok(entry): Result<RolloutLine, _> = serde_json::from_str(trimmed) else {
|
||||
continue;
|
||||
};
|
||||
match entry.item {
|
||||
RolloutItem::TurnContext(_) => {
|
||||
api_turn_count += 1;
|
||||
}
|
||||
RolloutItem::Compacted(ci) => {
|
||||
if ci.message == SUMMARY_TEXT {
|
||||
saw_compacted_summary = true;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
api_turn_count == 3,
|
||||
"expected three APITurn entries in rollout"
|
||||
);
|
||||
assert!(
|
||||
saw_compacted_summary,
|
||||
"expected a Compacted entry containing the summarizer output"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
use codex_core::CodexAuth;
|
||||
use codex_core::ContentItem;
|
||||
use codex_core::ConversationManager;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::NewConversation;
|
||||
use codex_core::ResponseItem;
|
||||
use codex_core::built_in_model_providers;
|
||||
use codex_core::protocol::ConversationPathResponseEvent;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_core::protocol::RolloutItem;
|
||||
use codex_core::protocol::RolloutLine;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::wait_for_event;
|
||||
use tempfile::TempDir;
|
||||
@@ -72,105 +75,121 @@ async fn fork_conversation_twice_drops_to_first_message() {
|
||||
let _ = wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
}
|
||||
|
||||
// Request history from the base conversation.
|
||||
codex.submit(Op::GetConversationPath).await.unwrap();
|
||||
// Request history from the base conversation to obtain rollout path.
|
||||
codex.submit(Op::GetPath).await.unwrap();
|
||||
let base_history =
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::ConversationHistory(_))).await;
|
||||
|
||||
// Capture path/id from the base history and compute expected prefixes after each fork.
|
||||
let (base_conv_id, base_path) = match &base_history {
|
||||
EventMsg::ConversationHistory(ConversationPathResponseEvent {
|
||||
conversation_id,
|
||||
path,
|
||||
}) => (*conversation_id, path.clone()),
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::ConversationPath(_))).await;
|
||||
let base_path = match &base_history {
|
||||
EventMsg::ConversationPath(ConversationPathResponseEvent { path, .. }) => path.clone(),
|
||||
_ => panic!("expected ConversationHistory event"),
|
||||
};
|
||||
|
||||
// Read entries from rollout file.
|
||||
async fn read_response_entries(path: &std::path::Path) -> Vec<ResponseItem> {
|
||||
let text = tokio::fs::read_to_string(path).await.unwrap_or_default();
|
||||
let mut out = Vec::new();
|
||||
// GetHistory flushes before returning the path; no wait needed.
|
||||
|
||||
// Helper: read rollout items (excluding SessionMeta) from a JSONL path.
|
||||
let read_items = |p: &std::path::Path| -> Vec<RolloutItem> {
|
||||
let text = std::fs::read_to_string(p).expect("read rollout file");
|
||||
let mut items: Vec<RolloutItem> = Vec::new();
|
||||
for line in text.lines() {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
if let Ok(item) = serde_json::from_str::<ResponseItem>(line) {
|
||||
out.push(item);
|
||||
let v: serde_json::Value = serde_json::from_str(line).expect("jsonl line");
|
||||
let rl: RolloutLine = serde_json::from_value(v).expect("rollout line");
|
||||
match rl.item {
|
||||
RolloutItem::SessionMeta(_) => {}
|
||||
other => items.push(other),
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
let entries_after_three: Vec<ResponseItem> = read_response_entries(&base_path).await;
|
||||
// History layout for this test:
|
||||
// [0] user instructions,
|
||||
// [1] environment context,
|
||||
// [2] "first" user message,
|
||||
// [3] "second" user message,
|
||||
// [4] "third" user message.
|
||||
items
|
||||
};
|
||||
|
||||
// Fork 1: drops the last user message and everything after.
|
||||
let expected_after_first = vec![
|
||||
entries_after_three[0].clone(),
|
||||
entries_after_three[1].clone(),
|
||||
entries_after_three[2].clone(),
|
||||
entries_after_three[3].clone(),
|
||||
];
|
||||
// Compute expected prefixes after each fork by truncating base rollout at nth-from-last user input.
|
||||
let base_items = read_items(&base_path);
|
||||
let find_user_input_positions = |items: &[RolloutItem]| -> Vec<usize> {
|
||||
let mut pos = Vec::new();
|
||||
for (i, it) in items.iter().enumerate() {
|
||||
if let RolloutItem::ResponseItem(ResponseItem::Message { role, content, .. }) = it
|
||||
&& role == "user"
|
||||
{
|
||||
// Consider any user message as an input boundary; recorder stores both EventMsg and ResponseItem.
|
||||
// We specifically look for input items, which are represented as ContentItem::InputText.
|
||||
if content
|
||||
.iter()
|
||||
.any(|c| matches!(c, ContentItem::InputText { .. }))
|
||||
{
|
||||
pos.push(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
pos
|
||||
};
|
||||
let user_inputs = find_user_input_positions(&base_items);
|
||||
|
||||
// Fork 2: drops the last user message and everything after.
|
||||
// [0] user instructions,
|
||||
// [1] environment context,
|
||||
// [2] "first" user message,
|
||||
let expected_after_second = vec![
|
||||
entries_after_three[0].clone(),
|
||||
entries_after_three[1].clone(),
|
||||
entries_after_three[2].clone(),
|
||||
];
|
||||
// After dropping last user input (n=1), cut strictly before that input if present, else empty.
|
||||
let cut1 = user_inputs
|
||||
.get(user_inputs.len().saturating_sub(1))
|
||||
.copied()
|
||||
.unwrap_or(0);
|
||||
let expected_after_first: Vec<RolloutItem> = base_items[..cut1].to_vec();
|
||||
|
||||
// Fork once with n=1 → drops the last user message and everything after.
|
||||
// After dropping again (n=1 on fork1), compute expected relative to fork1's rollout.
|
||||
|
||||
// Fork once with n=1 → drops the last user input and everything after.
|
||||
let NewConversation {
|
||||
conversation: codex_fork1,
|
||||
..
|
||||
} = conversation_manager
|
||||
.fork_conversation(base_path.clone(), base_conv_id, 1, config_for_fork.clone())
|
||||
.fork_conversation(1, config_for_fork.clone(), base_path.clone())
|
||||
.await
|
||||
.expect("fork 1");
|
||||
|
||||
codex_fork1.submit(Op::GetConversationPath).await.unwrap();
|
||||
codex_fork1.submit(Op::GetPath).await.unwrap();
|
||||
let fork1_history = wait_for_event(&codex_fork1, |ev| {
|
||||
matches!(ev, EventMsg::ConversationHistory(_))
|
||||
matches!(ev, EventMsg::ConversationPath(_))
|
||||
})
|
||||
.await;
|
||||
let (fork1_id, fork1_path) = match &fork1_history {
|
||||
EventMsg::ConversationHistory(ConversationPathResponseEvent {
|
||||
conversation_id,
|
||||
path,
|
||||
}) => (*conversation_id, path.clone()),
|
||||
let fork1_path = match &fork1_history {
|
||||
EventMsg::ConversationPath(ConversationPathResponseEvent { path, .. }) => path.clone(),
|
||||
_ => panic!("expected ConversationHistory event after first fork"),
|
||||
};
|
||||
let entries_after_first_fork: Vec<ResponseItem> = read_response_entries(&fork1_path).await;
|
||||
assert_eq!(entries_after_first_fork, expected_after_first);
|
||||
|
||||
// GetHistory on fork1 flushed; the file is ready.
|
||||
let fork1_items = read_items(&fork1_path);
|
||||
pretty_assertions::assert_eq!(
|
||||
serde_json::to_value(&fork1_items).unwrap(),
|
||||
serde_json::to_value(&expected_after_first).unwrap()
|
||||
);
|
||||
|
||||
// Fork again with n=1 → drops the (new) last user message, leaving only the first.
|
||||
let NewConversation {
|
||||
conversation: codex_fork2,
|
||||
..
|
||||
} = conversation_manager
|
||||
.fork_conversation(fork1_path.clone(), fork1_id, 1, config_for_fork.clone())
|
||||
.fork_conversation(1, config_for_fork.clone(), fork1_path.clone())
|
||||
.await
|
||||
.expect("fork 2");
|
||||
|
||||
codex_fork2.submit(Op::GetConversationPath).await.unwrap();
|
||||
codex_fork2.submit(Op::GetPath).await.unwrap();
|
||||
let fork2_history = wait_for_event(&codex_fork2, |ev| {
|
||||
matches!(ev, EventMsg::ConversationHistory(_))
|
||||
matches!(ev, EventMsg::ConversationPath(_))
|
||||
})
|
||||
.await;
|
||||
let (_fork2_id, fork2_path) = match &fork2_history {
|
||||
EventMsg::ConversationHistory(ConversationPathResponseEvent {
|
||||
conversation_id,
|
||||
path,
|
||||
}) => (*conversation_id, path.clone()),
|
||||
let fork2_path = match &fork2_history {
|
||||
EventMsg::ConversationPath(ConversationPathResponseEvent { path, .. }) => path.clone(),
|
||||
_ => panic!("expected ConversationHistory event after second fork"),
|
||||
};
|
||||
let entries_after_second_fork: Vec<ResponseItem> = read_response_entries(&fork2_path).await;
|
||||
assert_eq!(entries_after_second_fork, expected_after_second);
|
||||
// GetHistory on fork2 flushed; the file is ready.
|
||||
let fork1_items = read_items(&fork1_path);
|
||||
let fork1_user_inputs = find_user_input_positions(&fork1_items);
|
||||
let cut_last_on_fork1 = fork1_user_inputs
|
||||
.get(fork1_user_inputs.len().saturating_sub(1))
|
||||
.copied()
|
||||
.unwrap_or(0);
|
||||
let expected_after_second: Vec<RolloutItem> = fork1_items[..cut_last_on_fork1].to_vec();
|
||||
let fork2_items = read_items(&fork2_path);
|
||||
pretty_assertions::assert_eq!(
|
||||
serde_json::to_value(&fork2_items).unwrap(),
|
||||
serde_json::to_value(&expected_after_second).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ mod exec;
|
||||
mod exec_stream_events;
|
||||
mod fork_conversation;
|
||||
mod live_cli;
|
||||
mod model_overrides;
|
||||
mod prompt_caching;
|
||||
mod seatbelt;
|
||||
mod stream_error_allows_next_turn;
|
||||
|
||||
92
codex-rs/core/tests/suite/model_overrides.rs
Normal file
92
codex-rs/core/tests/suite/model_overrides.rs
Normal file
@@ -0,0 +1,92 @@
|
||||
use codex_core::CodexAuth;
|
||||
use codex_core::ConversationManager;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol_config_types::ReasoningEffort;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::wait_for_event;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::TempDir;
|
||||
|
||||
const CONFIG_TOML: &str = "config.toml";
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn override_turn_context_does_not_persist_when_config_exists() {
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let config_path = codex_home.path().join(CONFIG_TOML);
|
||||
let initial_contents = "model = \"gpt-4o\"\n";
|
||||
tokio::fs::write(&config_path, initial_contents)
|
||||
.await
|
||||
.expect("seed config.toml");
|
||||
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model = "gpt-4o".to_string();
|
||||
|
||||
let conversation_manager =
|
||||
ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
|
||||
let codex = conversation_manager
|
||||
.new_conversation(config)
|
||||
.await
|
||||
.expect("create conversation")
|
||||
.conversation;
|
||||
|
||||
codex
|
||||
.submit(Op::OverrideTurnContext {
|
||||
cwd: None,
|
||||
approval_policy: None,
|
||||
sandbox_policy: None,
|
||||
model: Some("o3".to_string()),
|
||||
effort: Some(ReasoningEffort::High),
|
||||
summary: None,
|
||||
})
|
||||
.await
|
||||
.expect("submit override");
|
||||
|
||||
codex.submit(Op::Shutdown).await.expect("request shutdown");
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::ShutdownComplete)).await;
|
||||
|
||||
let contents = tokio::fs::read_to_string(&config_path)
|
||||
.await
|
||||
.expect("read config.toml after override");
|
||||
assert_eq!(contents, initial_contents);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn override_turn_context_does_not_create_config_file() {
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let config_path = codex_home.path().join(CONFIG_TOML);
|
||||
assert!(
|
||||
!config_path.exists(),
|
||||
"test setup should start without config"
|
||||
);
|
||||
|
||||
let config = load_default_config_for_test(&codex_home);
|
||||
|
||||
let conversation_manager =
|
||||
ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
|
||||
let codex = conversation_manager
|
||||
.new_conversation(config)
|
||||
.await
|
||||
.expect("create conversation")
|
||||
.conversation;
|
||||
|
||||
codex
|
||||
.submit(Op::OverrideTurnContext {
|
||||
cwd: None,
|
||||
approval_policy: None,
|
||||
sandbox_policy: None,
|
||||
model: Some("o3".to_string()),
|
||||
effort: Some(ReasoningEffort::Medium),
|
||||
summary: None,
|
||||
})
|
||||
.await
|
||||
.expect("submit override");
|
||||
|
||||
codex.submit(Op::Shutdown).await.expect("request shutdown");
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::ShutdownComplete)).await;
|
||||
|
||||
assert!(
|
||||
!config_path.exists(),
|
||||
"override should not create config.toml"
|
||||
);
|
||||
}
|
||||
@@ -270,8 +270,13 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests
|
||||
assert_eq!(requests.len(), 2, "expected two POST requests");
|
||||
|
||||
let shell = default_user_shell().await;
|
||||
let shell_line = match shell.name() {
|
||||
Some(name) => format!(" <shell>{name}</shell>\n"),
|
||||
None => String::new(),
|
||||
};
|
||||
|
||||
let expected_env_text = format!(
|
||||
// Per-turn environment context includes the shell tag.
|
||||
let expected_env_text_turn = format!(
|
||||
r#"<environment_context>
|
||||
<cwd>{}</cwd>
|
||||
<approval_policy>on-request</approval_policy>
|
||||
@@ -279,18 +284,15 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests
|
||||
<network_access>restricted</network_access>
|
||||
{}</environment_context>"#,
|
||||
cwd.path().to_string_lossy(),
|
||||
match shell.name() {
|
||||
Some(name) => format!(" <shell>{name}</shell>\n"),
|
||||
None => String::new(),
|
||||
}
|
||||
shell_line.as_str(),
|
||||
);
|
||||
let expected_ui_text =
|
||||
"<user_instructions>\n\nbe consistent and helpful\n\n</user_instructions>";
|
||||
|
||||
let expected_env_msg = serde_json::json!({
|
||||
let expected_env_msg_turn = serde_json::json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": expected_env_text } ]
|
||||
"content": [ { "type": "input_text", "text": expected_env_text_turn } ]
|
||||
});
|
||||
let expected_ui_msg = serde_json::json!({
|
||||
"type": "message",
|
||||
@@ -304,11 +306,29 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests
|
||||
"content": [ { "type": "input_text", "text": "hello 1" } ]
|
||||
});
|
||||
let body1 = requests[0].body_json::<serde_json::Value>().unwrap();
|
||||
let body1_input = body1["input"].as_array().unwrap();
|
||||
assert_eq!(
|
||||
body1["input"],
|
||||
serde_json::json!([expected_ui_msg, expected_env_msg, expected_user_message_1])
|
||||
serde_json::json!([
|
||||
expected_ui_msg,
|
||||
expected_env_msg_turn,
|
||||
expected_user_message_1
|
||||
])
|
||||
);
|
||||
|
||||
let env_texts: Vec<&str> = body1_input
|
||||
.iter()
|
||||
.filter_map(|msg| {
|
||||
msg.get("content")
|
||||
.and_then(|content| content.as_array())
|
||||
.and_then(|content| content.first())
|
||||
.and_then(|item| item.get("text"))
|
||||
.and_then(|text| text.as_str())
|
||||
})
|
||||
.filter(|text| text.starts_with("<environment_context>"))
|
||||
.collect();
|
||||
assert_eq!(env_texts, vec![expected_env_text_turn.as_str()]);
|
||||
|
||||
let expected_user_message_2 = serde_json::json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
@@ -318,7 +338,7 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests
|
||||
let expected_body2 = serde_json::json!(
|
||||
[
|
||||
body1["input"].as_array().unwrap().as_slice(),
|
||||
[expected_user_message_2].as_slice(),
|
||||
[expected_env_msg_turn, expected_user_message_2].as_slice(),
|
||||
]
|
||||
.concat()
|
||||
);
|
||||
@@ -423,14 +443,29 @@ async fn overrides_turn_context_but_keeps_cached_prefix_and_key_constant() {
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": "hello 2" } ]
|
||||
});
|
||||
let shell = default_user_shell().await;
|
||||
let shell_line = match shell.name() {
|
||||
Some(name) => format!(" <shell>{name}</shell>\n"),
|
||||
None => String::new(),
|
||||
};
|
||||
|
||||
// After overriding the turn context, the environment context should be emitted again
|
||||
// reflecting the new approval policy and sandbox settings. Omit cwd because it did
|
||||
// not change.
|
||||
let expected_env_text_2 = r#"<environment_context>
|
||||
let expected_env_text_2 = format!(
|
||||
r#"<environment_context>
|
||||
<cwd>{}</cwd>
|
||||
<approval_policy>never</approval_policy>
|
||||
<sandbox_mode>workspace-write</sandbox_mode>
|
||||
<network_access>enabled</network_access>
|
||||
</environment_context>"#;
|
||||
<writable_roots>
|
||||
<root>{}</root>
|
||||
</writable_roots>
|
||||
{}</environment_context>"#,
|
||||
cwd.path().to_string_lossy(),
|
||||
writable.path().to_string_lossy(),
|
||||
shell_line.as_str()
|
||||
);
|
||||
let expected_env_msg_2 = serde_json::json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
@@ -540,12 +575,165 @@ async fn per_turn_overrides_keep_cached_prefix_and_key_constant() {
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": "hello 2" } ]
|
||||
});
|
||||
let shell = default_user_shell().await;
|
||||
let shell_line = match shell.name() {
|
||||
Some(name) => format!(" <shell>{name}</shell>\n"),
|
||||
None => String::new(),
|
||||
};
|
||||
let expected_env_text_2 = format!(
|
||||
r#"<environment_context>
|
||||
<cwd>{}</cwd>
|
||||
<approval_policy>never</approval_policy>
|
||||
<sandbox_mode>workspace-write</sandbox_mode>
|
||||
<network_access>enabled</network_access>
|
||||
<writable_roots>
|
||||
<root>{}</root>
|
||||
</writable_roots>
|
||||
{}</environment_context>"#,
|
||||
new_cwd.path().to_string_lossy(),
|
||||
writable.path().to_string_lossy(),
|
||||
shell_line.as_str()
|
||||
);
|
||||
let expected_env_msg_2 = serde_json::json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": expected_env_text_2 } ]
|
||||
});
|
||||
let expected_body2 = serde_json::json!(
|
||||
[
|
||||
body1["input"].as_array().unwrap().as_slice(),
|
||||
[expected_user_message_2].as_slice(),
|
||||
[expected_env_msg_2, expected_user_message_2].as_slice(),
|
||||
]
|
||||
.concat()
|
||||
);
|
||||
assert_eq!(body2["input"], expected_body2);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn tools_stable_across_all_approval_policy_transitions() {
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
let sse = sse_completed("resp");
|
||||
let template = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse, "text/event-stream");
|
||||
|
||||
// Build all transitions FROM each to each other (exclude self transitions)
|
||||
let policies = vec![
|
||||
AskForApproval::UnlessTrusted,
|
||||
AskForApproval::OnFailure,
|
||||
AskForApproval::OnRequest,
|
||||
AskForApproval::Never,
|
||||
];
|
||||
let mut transitions: Vec<(AskForApproval, AskForApproval)> = Vec::new();
|
||||
for &from in &policies {
|
||||
for &to in &policies {
|
||||
if from != to {
|
||||
transitions.push((from, to));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Expect 2 POSTs per transition
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(template)
|
||||
.expect((transitions.len() * 2) as u64)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
..built_in_model_providers()["openai"].clone()
|
||||
};
|
||||
|
||||
let cwd = TempDir::new().unwrap();
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.cwd = cwd.path().to_path_buf();
|
||||
config.model_provider = model_provider;
|
||||
config.user_instructions = Some("be consistent and helpful".to_string());
|
||||
// Keep tools stable and minimal
|
||||
config.include_plan_tool = false;
|
||||
config.include_apply_patch_tool = false;
|
||||
config.tools_web_search_request = false;
|
||||
config.use_experimental_unified_exec_tool = true; // policy-independent tool
|
||||
|
||||
let conversation_manager =
|
||||
ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
|
||||
let codex = conversation_manager
|
||||
.new_conversation(config)
|
||||
.await
|
||||
.expect("create new conversation")
|
||||
.conversation;
|
||||
|
||||
for (i, (from, to)) in transitions.iter().enumerate() {
|
||||
// Ensure a known starting policy for this pair
|
||||
codex
|
||||
.submit(Op::OverrideTurnContext {
|
||||
cwd: None,
|
||||
approval_policy: Some(*from),
|
||||
sandbox_policy: None,
|
||||
model: None,
|
||||
effort: None,
|
||||
summary: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: format!("turn {i}-a"),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
// Override to the target policy and send next turn
|
||||
codex
|
||||
.submit(Op::OverrideTurnContext {
|
||||
cwd: None,
|
||||
approval_policy: Some(*to),
|
||||
sandbox_policy: None,
|
||||
model: None,
|
||||
effort: None,
|
||||
summary: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: format!("turn {i}-b"),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
}
|
||||
|
||||
// Verify tool arrays are identical across each pair of requests
|
||||
let requests = server.received_requests().await.unwrap();
|
||||
assert_eq!(
|
||||
requests.len(),
|
||||
transitions.len() * 2,
|
||||
"expected 2 requests per transition"
|
||||
);
|
||||
|
||||
for i in 0..transitions.len() {
|
||||
let body_a = requests[2 * i].body_json::<serde_json::Value>().unwrap();
|
||||
let body_b = requests[2 * i + 1]
|
||||
.body_json::<serde_json::Value>()
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
body_a["tools"], body_b["tools"],
|
||||
"tools changed between requests for transition #{i}: {:?}",
|
||||
transitions[i]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -159,6 +159,41 @@ async fn read_only_forbids_all_writes() {
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Verify that user lookups via `pwd.getpwuid(os.getuid())` work under the
|
||||
/// seatbelt sandbox. Prior to allowing the necessary mach‑lookup for
|
||||
/// OpenDirectory libinfo, this would fail with `KeyError: getpwuid(): uid not found`.
|
||||
#[tokio::test]
|
||||
async fn python_getpwuid_works_under_seatbelt() {
|
||||
if std::env::var(CODEX_SANDBOX_ENV_VAR) == Ok("seatbelt".to_string()) {
|
||||
eprintln!("{CODEX_SANDBOX_ENV_VAR} is set to 'seatbelt', skipping test.");
|
||||
return;
|
||||
}
|
||||
|
||||
// ReadOnly is sufficient here since we are only exercising user lookup.
|
||||
let policy = SandboxPolicy::ReadOnly;
|
||||
|
||||
let mut child = spawn_command_under_seatbelt(
|
||||
vec![
|
||||
"python3".to_string(),
|
||||
"-c".to_string(),
|
||||
// Print the passwd struct; success implies lookup worked.
|
||||
"import pwd, os; print(pwd.getpwuid(os.getuid()))".to_string(),
|
||||
],
|
||||
&policy,
|
||||
std::env::current_dir().expect("should be able to get current dir"),
|
||||
StdioPolicy::RedirectForShellTool,
|
||||
HashMap::new(),
|
||||
)
|
||||
.await
|
||||
.expect("should be able to spawn python under seatbelt");
|
||||
|
||||
let status = child
|
||||
.wait()
|
||||
.await
|
||||
.expect("should be able to wait for child process");
|
||||
assert!(status.success(), "python exited with {status:?}");
|
||||
}
|
||||
|
||||
#[expect(clippy::expect_used)]
|
||||
fn create_test_scenario(tmp: &TempDir) -> TestScenario {
|
||||
let repo_parent = tmp.path().to_path_buf();
|
||||
|
||||
@@ -25,7 +25,6 @@ codex-common = { path = "../common", features = [
|
||||
"sandbox_summary",
|
||||
] }
|
||||
codex-core = { path = "../core" }
|
||||
codex-login = { path = "../login" }
|
||||
codex-ollama = { path = "../ollama" }
|
||||
codex-protocol = { path = "../protocol" }
|
||||
owo-colors = "4.2.0"
|
||||
|
||||
@@ -280,7 +280,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
parsed_cmd: _,
|
||||
}) => {
|
||||
self.call_id_to_command.insert(
|
||||
call_id.clone(),
|
||||
call_id,
|
||||
ExecCommandBegin {
|
||||
command: command.clone(),
|
||||
},
|
||||
@@ -382,7 +382,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
// Store metadata so we can calculate duration later when we
|
||||
// receive the corresponding PatchApplyEnd event.
|
||||
self.call_id_to_patch.insert(
|
||||
call_id.clone(),
|
||||
call_id,
|
||||
PatchApplyBegin {
|
||||
start_time: Instant::now(),
|
||||
auto_approved,
|
||||
@@ -520,9 +520,11 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
let SessionConfiguredEvent {
|
||||
session_id: conversation_id,
|
||||
model,
|
||||
reasoning_effort: _,
|
||||
history_log_id: _,
|
||||
history_entry_count: _,
|
||||
initial_messages: _,
|
||||
rollout_path: _,
|
||||
} = session_configured_event;
|
||||
|
||||
ts_println!(
|
||||
@@ -558,7 +560,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
}
|
||||
},
|
||||
EventMsg::ShutdownComplete => return CodexStatus::Shutdown,
|
||||
EventMsg::ConversationHistory(_) => {}
|
||||
EventMsg::ConversationPath(_) => {}
|
||||
EventMsg::UserMessage(_) => {}
|
||||
}
|
||||
CodexStatus::Running
|
||||
|
||||
@@ -187,11 +187,8 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let conversation_manager = ConversationManager::new(AuthManager::shared(
|
||||
config.codex_home.clone(),
|
||||
config.preferred_auth_method,
|
||||
config.responses_originator_header.clone(),
|
||||
));
|
||||
let conversation_manager =
|
||||
ConversationManager::new(AuthManager::shared(config.codex_home.clone()));
|
||||
let NewConversation {
|
||||
conversation_id: _,
|
||||
conversation,
|
||||
|
||||
@@ -61,7 +61,7 @@ pub(crate) async fn run_e2e_exec_test(cwd: &Path, response_streams: Vec<String>)
|
||||
.context("should find binary for codex-exec")
|
||||
.expect("should find binary for codex-exec")
|
||||
.current_dir(cwd.clone())
|
||||
.env("CODEX_HOME", cwd.clone())
|
||||
.env("CODEX_HOME", cwd)
|
||||
.env("OPENAI_API_KEY", "dummy")
|
||||
.env("OPENAI_BASE_URL", format!("{uri}/v1"))
|
||||
.arg("--skip-git-repo-check")
|
||||
|
||||
@@ -88,7 +88,7 @@ impl ExecvChecker {
|
||||
let mut program = valid_exec.program.to_string();
|
||||
for system_path in valid_exec.system_path {
|
||||
if is_executable_file(&system_path) {
|
||||
program = system_path.to_string();
|
||||
program = system_path;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -196,7 +196,7 @@ system_path=[{fake_cp:?}]
|
||||
let checker = setup(&fake_cp);
|
||||
let exec_call = ExecCall {
|
||||
program: "cp".into(),
|
||||
args: vec![source.clone(), dest.clone()],
|
||||
args: vec![source, dest.clone()],
|
||||
};
|
||||
let valid_exec = match checker.r#match(&exec_call)? {
|
||||
MatchedExec::Match { exec } => exec,
|
||||
@@ -207,7 +207,7 @@ system_path=[{fake_cp:?}]
|
||||
assert_eq!(
|
||||
checker.check(valid_exec.clone(), &cwd, &[], &[]),
|
||||
Err(ReadablePathNotInReadableFolders {
|
||||
file: source_path.clone(),
|
||||
file: source_path,
|
||||
folders: vec![]
|
||||
}),
|
||||
);
|
||||
@@ -229,7 +229,7 @@ system_path=[{fake_cp:?}]
|
||||
// Both readable and writeable folders specified.
|
||||
assert_eq!(
|
||||
checker.check(
|
||||
valid_exec.clone(),
|
||||
valid_exec,
|
||||
&cwd,
|
||||
std::slice::from_ref(&root_path),
|
||||
std::slice::from_ref(&root_path)
|
||||
@@ -241,7 +241,7 @@ system_path=[{fake_cp:?}]
|
||||
// folders.
|
||||
let exec_call_folders_as_args = ExecCall {
|
||||
program: "cp".into(),
|
||||
args: vec![root.clone(), root.clone()],
|
||||
args: vec![root.clone(), root],
|
||||
};
|
||||
let valid_exec_call_folders_as_args = match checker.r#match(&exec_call_folders_as_args)? {
|
||||
MatchedExec::Match { exec } => exec,
|
||||
@@ -254,7 +254,7 @@ system_path=[{fake_cp:?}]
|
||||
std::slice::from_ref(&root_path),
|
||||
std::slice::from_ref(&root_path)
|
||||
),
|
||||
Ok(cp.clone()),
|
||||
Ok(cp),
|
||||
);
|
||||
|
||||
// Specify a parent of a readable folder as input.
|
||||
|
||||
@@ -104,7 +104,7 @@ impl PolicyBuilder {
|
||||
info!("adding program spec: {program_spec:?}");
|
||||
let name = program_spec.program.clone();
|
||||
let mut programs = self.programs.borrow_mut();
|
||||
programs.insert(name.clone(), program_spec);
|
||||
programs.insert(name, program_spec);
|
||||
}
|
||||
|
||||
fn add_forbidden_substrings(&self, substrings: &[String]) {
|
||||
|
||||
@@ -31,6 +31,13 @@ install:
|
||||
rustup show active-toolchain
|
||||
cargo fetch
|
||||
|
||||
# Run `cargo nextest` since it's faster than `cargo test`, though including
|
||||
# --no-fail-fast is important to ensure all tests are run.
|
||||
#
|
||||
# Run `cargo install cargo-nextest` if you don't have it installed.
|
||||
test:
|
||||
cargo nextest run --no-fail-fast
|
||||
|
||||
# Run the MCP server
|
||||
mcp-server-run *args:
|
||||
cargo run -p codex-mcp-server -- "$@"
|
||||
|
||||
@@ -15,9 +15,7 @@ path = "src/lib.rs"
|
||||
workspace = true
|
||||
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
anyhow = "1"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
codex-common = { path = "../common", features = ["cli"] }
|
||||
codex-core = { path = "../core" }
|
||||
landlock = "0.4.1"
|
||||
libc = "0.2.175"
|
||||
|
||||
@@ -17,7 +17,6 @@ serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
sha2 = "0.10"
|
||||
tempfile = "3"
|
||||
thiserror = "2.0.16"
|
||||
tiny_http = "0.12"
|
||||
tokio = { version = "1", features = [
|
||||
"io-std",
|
||||
@@ -31,5 +30,4 @@ urlencoding = "2.1"
|
||||
webbrowser = "1.0"
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = "1.4.1"
|
||||
tempfile = "3"
|
||||
|
||||
@@ -16,6 +16,7 @@ use base64::Engine;
|
||||
use chrono::Utc;
|
||||
use codex_core::auth::AuthDotJson;
|
||||
use codex_core::auth::get_auth_file;
|
||||
use codex_core::default_client::ORIGINATOR;
|
||||
use codex_core::token_data::TokenData;
|
||||
use codex_core::token_data::parse_id_token;
|
||||
use rand::RngCore;
|
||||
@@ -35,19 +36,17 @@ pub struct ServerOptions {
|
||||
pub port: u16,
|
||||
pub open_browser: bool,
|
||||
pub force_state: Option<String>,
|
||||
pub originator: String,
|
||||
}
|
||||
|
||||
impl ServerOptions {
|
||||
pub fn new(codex_home: PathBuf, client_id: String, originator: String) -> Self {
|
||||
pub fn new(codex_home: PathBuf, client_id: String) -> Self {
|
||||
Self {
|
||||
codex_home,
|
||||
client_id: client_id.to_string(),
|
||||
client_id,
|
||||
issuer: DEFAULT_ISSUER.to_string(),
|
||||
port: DEFAULT_PORT,
|
||||
open_browser: true,
|
||||
force_state: None,
|
||||
originator,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -103,14 +102,7 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
|
||||
let server = Arc::new(server);
|
||||
|
||||
let redirect_uri = format!("http://localhost:{actual_port}/auth/callback");
|
||||
let auth_url = build_authorize_url(
|
||||
&opts.issuer,
|
||||
&opts.client_id,
|
||||
&redirect_uri,
|
||||
&pkce,
|
||||
&state,
|
||||
&opts.originator,
|
||||
);
|
||||
let auth_url = build_authorize_url(&opts.issuer, &opts.client_id, &redirect_uri, &pkce, &state);
|
||||
|
||||
if opts.open_browser {
|
||||
let _ = webbrowser::open(&auth_url);
|
||||
@@ -134,7 +126,7 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
|
||||
let shutdown_notify = Arc::new(tokio::sync::Notify::new());
|
||||
let server_handle = {
|
||||
let shutdown_notify = shutdown_notify.clone();
|
||||
let server = server.clone();
|
||||
let server = server;
|
||||
tokio::spawn(async move {
|
||||
let result = loop {
|
||||
tokio::select! {
|
||||
@@ -311,7 +303,6 @@ fn build_authorize_url(
|
||||
redirect_uri: &str,
|
||||
pkce: &PkceCodes,
|
||||
state: &str,
|
||||
originator: &str,
|
||||
) -> String {
|
||||
let query = vec![
|
||||
("response_type", "code"),
|
||||
@@ -323,7 +314,7 @@ fn build_authorize_url(
|
||||
("id_token_add_organizations", "true"),
|
||||
("codex_cli_simplified_flow", "true"),
|
||||
("state", state),
|
||||
("originator", originator),
|
||||
("originator", ORIGINATOR.value.as_str()),
|
||||
];
|
||||
let qs = query
|
||||
.into_iter()
|
||||
|
||||
@@ -102,7 +102,6 @@ async fn end_to_end_login_flow_persists_auth_json() {
|
||||
port: 0,
|
||||
open_browser: false,
|
||||
force_state: Some(state),
|
||||
originator: "test_originator".to_string(),
|
||||
};
|
||||
let server = run_login_server(opts).unwrap();
|
||||
let login_port = server.actual_port;
|
||||
@@ -161,7 +160,6 @@ async fn creates_missing_codex_home_dir() {
|
||||
port: 0,
|
||||
open_browser: false,
|
||||
force_state: Some(state),
|
||||
originator: "test_originator".to_string(),
|
||||
};
|
||||
let server = run_login_server(opts).unwrap();
|
||||
let login_port = server.actual_port;
|
||||
@@ -202,7 +200,6 @@ async fn cancels_previous_login_server_when_port_is_in_use() {
|
||||
port: 0,
|
||||
open_browser: false,
|
||||
force_state: Some("cancel_state".to_string()),
|
||||
originator: "test_originator".to_string(),
|
||||
};
|
||||
|
||||
let first_server = run_login_server(first_opts).unwrap();
|
||||
@@ -221,7 +218,6 @@ async fn cancels_previous_login_server_when_port_is_in_use() {
|
||||
port: login_port,
|
||||
open_browser: false,
|
||||
force_state: Some("cancel_state_2".to_string()),
|
||||
originator: "test_originator".to_string(),
|
||||
};
|
||||
|
||||
let second_server = run_login_server(second_opts).unwrap();
|
||||
|
||||
@@ -64,6 +64,9 @@ async fn main() -> Result<()> {
|
||||
name: "codex-mcp-client".to_owned(),
|
||||
version: env!("CARGO_PKG_VERSION").to_owned(),
|
||||
title: Some("Codex".to_string()),
|
||||
// This field is used by Codex when it is an MCP server: it should
|
||||
// not be used when Codex is an MCP client.
|
||||
user_agent: None,
|
||||
},
|
||||
protocol_version: MCP_SCHEMA_VERSION.to_owned(),
|
||||
};
|
||||
|
||||
@@ -26,7 +26,6 @@ schemars = "0.8.22"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
shlex = "1.3.0"
|
||||
strum_macros = "0.27.2"
|
||||
tokio = { version = "1", features = [
|
||||
"io-std",
|
||||
"macros",
|
||||
@@ -41,8 +40,9 @@ uuid = { version = "1", features = ["serde", "v4"] }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = "2"
|
||||
base64 = "0.22"
|
||||
mcp_test_support = { path = "tests/common" }
|
||||
os_info = "3.12.0"
|
||||
pretty_assertions = "1.4.1"
|
||||
tempfile = "3"
|
||||
tokio-test = "0.4"
|
||||
wiremock = "0.6"
|
||||
|
||||
@@ -11,10 +11,16 @@ use codex_core::NewConversation;
|
||||
use codex_core::RolloutRecorder;
|
||||
use codex_core::SessionMeta;
|
||||
use codex_core::auth::CLIENT_ID;
|
||||
use codex_core::auth::get_auth_file;
|
||||
use codex_core::auth::login_with_api_key;
|
||||
use codex_core::auth::try_read_auth_json;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::config::ConfigOverrides;
|
||||
use codex_core::config::ConfigToml;
|
||||
use codex_core::config::load_config_as_toml;
|
||||
use codex_core::config_edit::CONFIG_KEY_EFFORT;
|
||||
use codex_core::config_edit::CONFIG_KEY_MODEL;
|
||||
use codex_core::config_edit::persist_non_null_overrides;
|
||||
use codex_core::default_client::get_codex_user_agent;
|
||||
use codex_core::exec::ExecParams;
|
||||
use codex_core::exec_env::create_env;
|
||||
@@ -35,7 +41,8 @@ use codex_protocol::mcp_protocol::AddConversationListenerParams;
|
||||
use codex_protocol::mcp_protocol::AddConversationSubscriptionResponse;
|
||||
use codex_protocol::mcp_protocol::ApplyPatchApprovalParams;
|
||||
use codex_protocol::mcp_protocol::ApplyPatchApprovalResponse;
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
use codex_protocol::mcp_protocol::ArchiveConversationParams;
|
||||
use codex_protocol::mcp_protocol::ArchiveConversationResponse;
|
||||
use codex_protocol::mcp_protocol::AuthStatusChangeNotification;
|
||||
use codex_protocol::mcp_protocol::ClientRequest;
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
@@ -53,6 +60,8 @@ use codex_protocol::mcp_protocol::InterruptConversationParams;
|
||||
use codex_protocol::mcp_protocol::InterruptConversationResponse;
|
||||
use codex_protocol::mcp_protocol::ListConversationsParams;
|
||||
use codex_protocol::mcp_protocol::ListConversationsResponse;
|
||||
use codex_protocol::mcp_protocol::LoginApiKeyParams;
|
||||
use codex_protocol::mcp_protocol::LoginApiKeyResponse;
|
||||
use codex_protocol::mcp_protocol::LoginChatGptCompleteNotification;
|
||||
use codex_protocol::mcp_protocol::LoginChatGptResponse;
|
||||
use codex_protocol::mcp_protocol::NewConversationParams;
|
||||
@@ -65,6 +74,9 @@ use codex_protocol::mcp_protocol::SendUserMessageResponse;
|
||||
use codex_protocol::mcp_protocol::SendUserTurnParams;
|
||||
use codex_protocol::mcp_protocol::SendUserTurnResponse;
|
||||
use codex_protocol::mcp_protocol::ServerNotification;
|
||||
use codex_protocol::mcp_protocol::SetDefaultModelParams;
|
||||
use codex_protocol::mcp_protocol::SetDefaultModelResponse;
|
||||
use codex_protocol::mcp_protocol::UserInfoResponse;
|
||||
use codex_protocol::mcp_protocol::UserSavedConfig;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
@@ -73,12 +85,16 @@ use codex_protocol::protocol::USER_MESSAGE_BEGIN;
|
||||
use mcp_types::JSONRPCErrorError;
|
||||
use mcp_types::RequestId;
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::OsStr;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::select;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
use uuid::Uuid;
|
||||
|
||||
// Duration before a ChatGPT login attempt is abandoned.
|
||||
@@ -142,6 +158,9 @@ impl CodexMessageProcessor {
|
||||
ClientRequest::ResumeConversation { request_id, params } => {
|
||||
self.handle_resume_conversation(request_id, params).await;
|
||||
}
|
||||
ClientRequest::ArchiveConversation { request_id, params } => {
|
||||
self.archive_conversation(request_id, params).await;
|
||||
}
|
||||
ClientRequest::SendUserMessage { request_id, params } => {
|
||||
self.send_user_message(request_id, params).await;
|
||||
}
|
||||
@@ -160,6 +179,9 @@ impl CodexMessageProcessor {
|
||||
ClientRequest::GitDiffToRemote { request_id, params } => {
|
||||
self.git_diff_to_origin(request_id, params.cwd).await;
|
||||
}
|
||||
ClientRequest::LoginApiKey { request_id, params } => {
|
||||
self.login_api_key(request_id, params).await;
|
||||
}
|
||||
ClientRequest::LoginChatGpt { request_id } => {
|
||||
self.login_chatgpt(request_id).await;
|
||||
}
|
||||
@@ -175,25 +197,60 @@ impl CodexMessageProcessor {
|
||||
ClientRequest::GetUserSavedConfig { request_id } => {
|
||||
self.get_user_saved_config(request_id).await;
|
||||
}
|
||||
ClientRequest::SetDefaultModel { request_id, params } => {
|
||||
self.set_default_model(request_id, params).await;
|
||||
}
|
||||
ClientRequest::GetUserAgent { request_id } => {
|
||||
self.get_user_agent(request_id).await;
|
||||
}
|
||||
ClientRequest::UserInfo { request_id } => {
|
||||
self.get_user_info(request_id).await;
|
||||
}
|
||||
ClientRequest::ExecOneOffCommand { request_id, params } => {
|
||||
self.exec_one_off_command(request_id, params).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn login_api_key(&mut self, request_id: RequestId, params: LoginApiKeyParams) {
|
||||
{
|
||||
let mut guard = self.active_login.lock().await;
|
||||
if let Some(active) = guard.take() {
|
||||
active.drop();
|
||||
}
|
||||
}
|
||||
|
||||
match login_with_api_key(&self.config.codex_home, ¶ms.api_key) {
|
||||
Ok(()) => {
|
||||
self.auth_manager.reload();
|
||||
self.outgoing
|
||||
.send_response(request_id, LoginApiKeyResponse {})
|
||||
.await;
|
||||
|
||||
let payload = AuthStatusChangeNotification {
|
||||
auth_method: self.auth_manager.auth().map(|auth| auth.mode),
|
||||
};
|
||||
self.outgoing
|
||||
.send_server_notification(ServerNotification::AuthStatusChange(payload))
|
||||
.await;
|
||||
}
|
||||
Err(err) => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: format!("failed to save api key: {err}"),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn login_chatgpt(&mut self, request_id: RequestId) {
|
||||
let config = self.config.as_ref();
|
||||
|
||||
let opts = LoginServerOptions {
|
||||
open_browser: false,
|
||||
..LoginServerOptions::new(
|
||||
config.codex_home.clone(),
|
||||
CLIENT_ID.to_string(),
|
||||
config.responses_originator_header.clone(),
|
||||
)
|
||||
..LoginServerOptions::new(config.codex_home.clone(), CLIENT_ID.to_string())
|
||||
};
|
||||
|
||||
enum LoginChatGptReply {
|
||||
@@ -341,7 +398,7 @@ impl CodexMessageProcessor {
|
||||
.await;
|
||||
|
||||
// Send auth status change notification reflecting the current auth mode
|
||||
// after logout (which may fall back to API key via env var).
|
||||
// after logout.
|
||||
let current_auth_method = self.auth_manager.auth().map(|auth| auth.mode);
|
||||
let payload = AuthStatusChangeNotification {
|
||||
auth_method: current_auth_method,
|
||||
@@ -356,7 +413,6 @@ impl CodexMessageProcessor {
|
||||
request_id: RequestId,
|
||||
params: codex_protocol::mcp_protocol::GetAuthStatusParams,
|
||||
) {
|
||||
let preferred_auth_method: AuthMode = self.auth_manager.preferred_auth_method();
|
||||
let include_token = params.include_token.unwrap_or(false);
|
||||
let do_refresh = params.refresh_token.unwrap_or(false);
|
||||
|
||||
@@ -364,6 +420,11 @@ impl CodexMessageProcessor {
|
||||
tracing::warn!("failed to refresh token while getting auth status: {err}");
|
||||
}
|
||||
|
||||
// Determine whether auth is required based on the active model provider.
|
||||
// If a custom provider is configured with `requires_openai_auth == false`,
|
||||
// then no auth step is required; otherwise, default to requiring auth.
|
||||
let requires_openai_auth = Some(self.config.model_provider.requires_openai_auth);
|
||||
|
||||
let response = match self.auth_manager.auth() {
|
||||
Some(auth) => {
|
||||
let (reported_auth_method, token_opt) = match auth.get_token().await {
|
||||
@@ -379,14 +440,14 @@ impl CodexMessageProcessor {
|
||||
};
|
||||
codex_protocol::mcp_protocol::GetAuthStatusResponse {
|
||||
auth_method: reported_auth_method,
|
||||
preferred_auth_method,
|
||||
auth_token: token_opt,
|
||||
requires_openai_auth,
|
||||
}
|
||||
}
|
||||
None => codex_protocol::mcp_protocol::GetAuthStatusResponse {
|
||||
auth_method: None,
|
||||
preferred_auth_method,
|
||||
auth_token: None,
|
||||
requires_openai_auth,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -394,7 +455,7 @@ impl CodexMessageProcessor {
|
||||
}
|
||||
|
||||
async fn get_user_agent(&self, request_id: RequestId) {
|
||||
let user_agent = get_codex_user_agent(Some(&self.config.responses_originator_header));
|
||||
let user_agent = get_codex_user_agent();
|
||||
let response = GetUserAgentResponse { user_agent };
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
@@ -434,6 +495,52 @@ impl CodexMessageProcessor {
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
|
||||
async fn get_user_info(&self, request_id: RequestId) {
|
||||
// Read alleged user email from auth.json (best-effort; not verified).
|
||||
let auth_path = get_auth_file(&self.config.codex_home);
|
||||
let alleged_user_email = match try_read_auth_json(&auth_path) {
|
||||
Ok(auth) => auth.tokens.and_then(|t| t.id_token.email),
|
||||
Err(_) => None,
|
||||
};
|
||||
|
||||
let response = UserInfoResponse { alleged_user_email };
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
|
||||
async fn set_default_model(&self, request_id: RequestId, params: SetDefaultModelParams) {
|
||||
let SetDefaultModelParams {
|
||||
model,
|
||||
reasoning_effort,
|
||||
} = params;
|
||||
let effort_str = reasoning_effort.map(|effort| effort.to_string());
|
||||
|
||||
let overrides: [(&[&str], Option<&str>); 2] = [
|
||||
(&[CONFIG_KEY_MODEL], model.as_deref()),
|
||||
(&[CONFIG_KEY_EFFORT], effort_str.as_deref()),
|
||||
];
|
||||
|
||||
match persist_non_null_overrides(
|
||||
&self.config.codex_home,
|
||||
self.config.active_profile.as_deref(),
|
||||
&overrides,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
let response = SetDefaultModelResponse {};
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
Err(err) => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: format!("failed to persist overrides: {err}"),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn exec_one_off_command(&self, request_id: RequestId, params: ExecOneOffCommandParams) {
|
||||
tracing::debug!("ExecOneOffCommand params: {params:?}");
|
||||
|
||||
@@ -528,6 +635,8 @@ impl CodexMessageProcessor {
|
||||
let response = NewConversationResponse {
|
||||
conversation_id,
|
||||
model: session_configured.model,
|
||||
reasoning_effort: session_configured.reasoning_effort,
|
||||
rollout_path: session_configured.rollout_path,
|
||||
};
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
@@ -669,6 +778,141 @@ impl CodexMessageProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
async fn archive_conversation(&self, request_id: RequestId, params: ArchiveConversationParams) {
|
||||
let ArchiveConversationParams {
|
||||
conversation_id,
|
||||
rollout_path,
|
||||
} = params;
|
||||
|
||||
// Verify that the rollout path is in the sessions directory or else
|
||||
// a malicious client could specify an arbitrary path.
|
||||
let rollout_folder = self.config.codex_home.join(codex_core::SESSIONS_SUBDIR);
|
||||
let canonical_rollout_path = tokio::fs::canonicalize(&rollout_path).await;
|
||||
let canonical_rollout_path = if let Ok(path) = canonical_rollout_path
|
||||
&& path.starts_with(&rollout_folder)
|
||||
{
|
||||
path
|
||||
} else {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: format!(
|
||||
"rollout path `{}` must be in sessions directory",
|
||||
rollout_path.display()
|
||||
),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
};
|
||||
|
||||
let required_suffix = format!("{}.jsonl", conversation_id.0);
|
||||
let Some(file_name) = canonical_rollout_path.file_name().map(OsStr::to_owned) else {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: format!(
|
||||
"rollout path `{}` missing file name",
|
||||
rollout_path.display()
|
||||
),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
};
|
||||
|
||||
if !file_name
|
||||
.to_string_lossy()
|
||||
.ends_with(required_suffix.as_str())
|
||||
{
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: format!(
|
||||
"rollout path `{}` does not match conversation id {conversation_id}",
|
||||
rollout_path.display()
|
||||
),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
}
|
||||
|
||||
let removed_conversation = self
|
||||
.conversation_manager
|
||||
.remove_conversation(&conversation_id)
|
||||
.await;
|
||||
if let Some(conversation) = removed_conversation {
|
||||
info!("conversation {conversation_id} was active; shutting down");
|
||||
let conversation_clone = conversation.clone();
|
||||
let notify = Arc::new(tokio::sync::Notify::new());
|
||||
let notify_clone = notify.clone();
|
||||
|
||||
// Establish the listener for ShutdownComplete before submitting
|
||||
// Shutdown so it is not missed.
|
||||
let is_shutdown = tokio::spawn(async move {
|
||||
loop {
|
||||
select! {
|
||||
_ = notify_clone.notified() => {
|
||||
break;
|
||||
}
|
||||
event = conversation_clone.next_event() => {
|
||||
if let Ok(event) = event && matches!(event.msg, EventMsg::ShutdownComplete) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Request shutdown.
|
||||
match conversation.submit(Op::Shutdown).await {
|
||||
Ok(_) => {
|
||||
// Successfully submitted Shutdown; wait before proceeding.
|
||||
select! {
|
||||
_ = is_shutdown => {
|
||||
// Normal shutdown: proceed with archive.
|
||||
}
|
||||
_ = tokio::time::sleep(Duration::from_secs(10)) => {
|
||||
warn!("conversation {conversation_id} shutdown timed out; proceeding with archive");
|
||||
notify.notify_one();
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
error!("failed to submit Shutdown to conversation {conversation_id}: {err}");
|
||||
notify.notify_one();
|
||||
// Perhaps we lost a shutdown race, so let's continue to
|
||||
// clean up the .jsonl file.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Move the .jsonl file to the archived sessions subdir.
|
||||
let result: std::io::Result<()> = async {
|
||||
let archive_folder = self
|
||||
.config
|
||||
.codex_home
|
||||
.join(codex_core::ARCHIVED_SESSIONS_SUBDIR);
|
||||
tokio::fs::create_dir_all(&archive_folder).await?;
|
||||
tokio::fs::rename(&canonical_rollout_path, &archive_folder.join(&file_name)).await?;
|
||||
Ok(())
|
||||
}
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(()) => {
|
||||
let response = ArchiveConversationResponse {};
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
Err(err) => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: format!("failed to archive conversation: {err}"),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_user_message(&self, request_id: RequestId, params: SendUserMessageParams) {
|
||||
let SendUserMessageParams {
|
||||
conversation_id,
|
||||
@@ -1110,10 +1354,7 @@ fn extract_conversation_summary(
|
||||
head: &[serde_json::Value],
|
||||
) -> Option<ConversationSummary> {
|
||||
let session_meta = match head.first() {
|
||||
Some(first_line) => match serde_json::from_value::<SessionMeta>(first_line.clone()) {
|
||||
Ok(session_meta) => session_meta,
|
||||
Err(..) => return None,
|
||||
},
|
||||
Some(first_line) => serde_json::from_value::<SessionMeta>(first_line.clone()).ok()?,
|
||||
None => return None,
|
||||
};
|
||||
|
||||
@@ -1171,6 +1412,10 @@ mod tests {
|
||||
json!({
|
||||
"id": conversation_id.0,
|
||||
"timestamp": timestamp,
|
||||
"cwd": "/",
|
||||
"originator": "codex",
|
||||
"cli_version": "0.0.0",
|
||||
"instructions": null
|
||||
}),
|
||||
json!({
|
||||
"type": "message",
|
||||
|
||||
@@ -222,7 +222,7 @@ async fn run_codex_tool_session_inner(
|
||||
}
|
||||
EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => {
|
||||
let text = match last_agent_message {
|
||||
Some(msg) => msg.clone(),
|
||||
Some(msg) => msg,
|
||||
None => "".to_string(),
|
||||
};
|
||||
let result = CallToolResult {
|
||||
@@ -277,7 +277,7 @@ async fn run_codex_tool_session_inner(
|
||||
| EventMsg::GetHistoryEntryResponse(_)
|
||||
| EventMsg::PlanUpdate(_)
|
||||
| EventMsg::TurnAborted(_)
|
||||
| EventMsg::ConversationHistory(_)
|
||||
| EventMsg::ConversationPath(_)
|
||||
| EventMsg::UserMessage(_)
|
||||
| EventMsg::ShutdownComplete => {
|
||||
// For now, we do not do anything extra for these
|
||||
|
||||
@@ -14,6 +14,8 @@ use codex_protocol::mcp_protocol::ConversationId;
|
||||
use codex_core::AuthManager;
|
||||
use codex_core::ConversationManager;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::default_client::USER_AGENT_SUFFIX;
|
||||
use codex_core::default_client::get_codex_user_agent;
|
||||
use codex_core::protocol::Submission;
|
||||
use mcp_types::CallToolRequestParams;
|
||||
use mcp_types::CallToolResult;
|
||||
@@ -54,11 +56,7 @@ impl MessageProcessor {
|
||||
config: Arc<Config>,
|
||||
) -> Self {
|
||||
let outgoing = Arc::new(outgoing);
|
||||
let auth_manager = AuthManager::shared(
|
||||
config.codex_home.clone(),
|
||||
config.preferred_auth_method,
|
||||
config.responses_originator_header.clone(),
|
||||
);
|
||||
let auth_manager = AuthManager::shared(config.codex_home.clone());
|
||||
let conversation_manager = Arc::new(ConversationManager::new(auth_manager.clone()));
|
||||
let codex_message_processor = CodexMessageProcessor::new(
|
||||
auth_manager,
|
||||
@@ -211,6 +209,14 @@ impl MessageProcessor {
|
||||
return;
|
||||
}
|
||||
|
||||
let client_info = params.client_info;
|
||||
let name = client_info.name;
|
||||
let version = client_info.version;
|
||||
let user_agent_suffix = format!("{name}; {version}");
|
||||
if let Ok(mut suffix) = USER_AGENT_SUFFIX.lock() {
|
||||
*suffix = Some(user_agent_suffix);
|
||||
}
|
||||
|
||||
self.initialized = true;
|
||||
|
||||
// Build a minimal InitializeResult. Fill with placeholders.
|
||||
@@ -231,6 +237,7 @@ impl MessageProcessor {
|
||||
name: "codex-mcp-server".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
title: Some("Codex".to_string()),
|
||||
user_agent: Some(get_codex_user_agent()),
|
||||
},
|
||||
};
|
||||
|
||||
@@ -524,7 +531,6 @@ impl MessageProcessor {
|
||||
|
||||
// Spawn the long-running reply handler.
|
||||
tokio::spawn({
|
||||
let codex = codex.clone();
|
||||
let outgoing = outgoing.clone();
|
||||
let prompt = prompt.clone();
|
||||
let running_requests_id_to_codex_uuid = running_requests_id_to_codex_uuid.clone();
|
||||
|
||||
@@ -258,10 +258,12 @@ pub(crate) struct OutgoingError {
|
||||
mod tests {
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
use codex_protocol::config_types::ReasoningEffort;
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
use codex_protocol::mcp_protocol::LoginChatGptCompleteNotification;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tempfile::NamedTempFile;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::*;
|
||||
@@ -272,14 +274,17 @@ mod tests {
|
||||
let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx);
|
||||
|
||||
let conversation_id = ConversationId::new();
|
||||
let rollout_file = NamedTempFile::new().unwrap();
|
||||
let event = Event {
|
||||
id: "1".to_string(),
|
||||
msg: EventMsg::SessionConfigured(SessionConfiguredEvent {
|
||||
session_id: conversation_id,
|
||||
model: "gpt-4o".to_string(),
|
||||
reasoning_effort: ReasoningEffort::default(),
|
||||
history_log_id: 1,
|
||||
history_entry_count: 1000,
|
||||
initial_messages: None,
|
||||
rollout_path: rollout_file.path().to_path_buf(),
|
||||
}),
|
||||
};
|
||||
|
||||
@@ -296,7 +301,7 @@ mod tests {
|
||||
let Ok(expected_params) = serde_json::to_value(&event) else {
|
||||
panic!("Event must serialize");
|
||||
};
|
||||
assert_eq!(params, Some(expected_params.clone()));
|
||||
assert_eq!(params, Some(expected_params));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -305,12 +310,15 @@ mod tests {
|
||||
let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx);
|
||||
|
||||
let conversation_id = ConversationId::new();
|
||||
let rollout_file = NamedTempFile::new().unwrap();
|
||||
let session_configured_event = SessionConfiguredEvent {
|
||||
session_id: conversation_id,
|
||||
model: "gpt-4o".to_string(),
|
||||
reasoning_effort: ReasoningEffort::default(),
|
||||
history_log_id: 1,
|
||||
history_entry_count: 1000,
|
||||
initial_messages: None,
|
||||
rollout_path: rollout_file.path().to_path_buf(),
|
||||
};
|
||||
let event = Event {
|
||||
id: "1".to_string(),
|
||||
@@ -337,9 +345,11 @@ mod tests {
|
||||
"msg": {
|
||||
"session_id": session_configured_event.session_id,
|
||||
"model": session_configured_event.model,
|
||||
"reasoning_effort": session_configured_event.reasoning_effort,
|
||||
"history_log_id": session_configured_event.history_log_id,
|
||||
"history_entry_count": session_configured_event.history_entry_count,
|
||||
"type": "session_configured",
|
||||
"rollout_path": rollout_file.path().to_path_buf(),
|
||||
}
|
||||
});
|
||||
assert_eq!(params.unwrap(), expected_params);
|
||||
|
||||
@@ -13,16 +13,14 @@ codex-core = { path = "../../../core" }
|
||||
codex-mcp-server = { path = "../.." }
|
||||
codex-protocol = { path = "../../../protocol" }
|
||||
mcp-types = { path = "../../../mcp-types" }
|
||||
os_info = "3.12.0"
|
||||
pretty_assertions = "1.4.1"
|
||||
serde = { version = "1" }
|
||||
serde_json = "1"
|
||||
shlex = "1.3.0"
|
||||
tempfile = "3"
|
||||
tokio = { version = "1", features = [
|
||||
"io-std",
|
||||
"macros",
|
||||
"process",
|
||||
"rt-multi-thread",
|
||||
] }
|
||||
uuid = { version = "1", features = ["serde", "v4"] }
|
||||
wiremock = "0.6"
|
||||
|
||||
@@ -13,15 +13,18 @@ use anyhow::Context;
|
||||
use assert_cmd::prelude::*;
|
||||
use codex_mcp_server::CodexToolCallParam;
|
||||
use codex_protocol::mcp_protocol::AddConversationListenerParams;
|
||||
use codex_protocol::mcp_protocol::ArchiveConversationParams;
|
||||
use codex_protocol::mcp_protocol::CancelLoginChatGptParams;
|
||||
use codex_protocol::mcp_protocol::GetAuthStatusParams;
|
||||
use codex_protocol::mcp_protocol::InterruptConversationParams;
|
||||
use codex_protocol::mcp_protocol::ListConversationsParams;
|
||||
use codex_protocol::mcp_protocol::LoginApiKeyParams;
|
||||
use codex_protocol::mcp_protocol::NewConversationParams;
|
||||
use codex_protocol::mcp_protocol::RemoveConversationListenerParams;
|
||||
use codex_protocol::mcp_protocol::ResumeConversationParams;
|
||||
use codex_protocol::mcp_protocol::SendUserMessageParams;
|
||||
use codex_protocol::mcp_protocol::SendUserTurnParams;
|
||||
use codex_protocol::mcp_protocol::SetDefaultModelParams;
|
||||
|
||||
use mcp_types::CallToolRequestParams;
|
||||
use mcp_types::ClientCapabilities;
|
||||
@@ -53,6 +56,18 @@ pub struct McpProcess {
|
||||
|
||||
impl McpProcess {
|
||||
pub async fn new(codex_home: &Path) -> anyhow::Result<Self> {
|
||||
Self::new_with_env(codex_home, &[]).await
|
||||
}
|
||||
|
||||
/// Creates a new MCP process, allowing tests to override or remove
|
||||
/// specific environment variables for the child process only.
|
||||
///
|
||||
/// Pass a tuple of (key, Some(value)) to set/override, or (key, None) to
|
||||
/// remove a variable from the child's environment.
|
||||
pub async fn new_with_env(
|
||||
codex_home: &Path,
|
||||
env_overrides: &[(&str, Option<&str>)],
|
||||
) -> anyhow::Result<Self> {
|
||||
// Use assert_cmd to locate the binary path and then switch to tokio::process::Command
|
||||
let std_cmd = StdCommand::cargo_bin("codex-mcp-server")
|
||||
.context("should find binary for codex-mcp-server")?;
|
||||
@@ -67,6 +82,17 @@ impl McpProcess {
|
||||
cmd.env("CODEX_HOME", codex_home);
|
||||
cmd.env("RUST_LOG", "debug");
|
||||
|
||||
for (k, v) in env_overrides {
|
||||
match v {
|
||||
Some(val) => {
|
||||
cmd.env(k, val);
|
||||
}
|
||||
None => {
|
||||
cmd.env_remove(k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut process = cmd
|
||||
.kill_on_drop(true)
|
||||
.spawn()
|
||||
@@ -114,6 +140,7 @@ impl McpProcess {
|
||||
name: "elicitation test".into(),
|
||||
title: Some("Elicitation Test".into()),
|
||||
version: "0.0.0".into(),
|
||||
user_agent: None,
|
||||
},
|
||||
protocol_version: mcp_types::MCP_SCHEMA_VERSION.into(),
|
||||
};
|
||||
@@ -128,6 +155,14 @@ impl McpProcess {
|
||||
.await?;
|
||||
|
||||
let initialized = self.read_jsonrpc_message().await?;
|
||||
let os_info = os_info::get();
|
||||
let user_agent = format!(
|
||||
"codex_cli_rs/0.0.0 ({} {}; {}) {} (elicitation test; 0.0.0)",
|
||||
os_info.os_type(),
|
||||
os_info.version(),
|
||||
os_info.architecture().unwrap_or("unknown"),
|
||||
codex_core::terminal::user_agent()
|
||||
);
|
||||
assert_eq!(
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
@@ -141,7 +176,8 @@ impl McpProcess {
|
||||
"serverInfo": {
|
||||
"name": "codex-mcp-server",
|
||||
"title": "Codex",
|
||||
"version": "0.0.0"
|
||||
"version": "0.0.0",
|
||||
"user_agent": user_agent
|
||||
},
|
||||
"protocolVersion": mcp_types::MCP_SCHEMA_VERSION
|
||||
})
|
||||
@@ -186,6 +222,15 @@ impl McpProcess {
|
||||
self.send_request("newConversation", params).await
|
||||
}
|
||||
|
||||
/// Send an `archiveConversation` JSON-RPC request.
|
||||
pub async fn send_archive_conversation_request(
|
||||
&mut self,
|
||||
params: ArchiveConversationParams,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = Some(serde_json::to_value(params)?);
|
||||
self.send_request("archiveConversation", params).await
|
||||
}
|
||||
|
||||
/// Send an `addConversationListener` JSON-RPC request.
|
||||
pub async fn send_add_conversation_listener_request(
|
||||
&mut self,
|
||||
@@ -252,6 +297,20 @@ impl McpProcess {
|
||||
self.send_request("getUserAgent", None).await
|
||||
}
|
||||
|
||||
/// Send a `userInfo` JSON-RPC request.
|
||||
pub async fn send_user_info_request(&mut self) -> anyhow::Result<i64> {
|
||||
self.send_request("userInfo", None).await
|
||||
}
|
||||
|
||||
/// Send a `setDefaultModel` JSON-RPC request.
|
||||
pub async fn send_set_default_model_request(
|
||||
&mut self,
|
||||
params: SetDefaultModelParams,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = Some(serde_json::to_value(params)?);
|
||||
self.send_request("setDefaultModel", params).await
|
||||
}
|
||||
|
||||
/// Send a `listConversations` JSON-RPC request.
|
||||
pub async fn send_list_conversations_request(
|
||||
&mut self,
|
||||
@@ -270,6 +329,15 @@ impl McpProcess {
|
||||
self.send_request("resumeConversation", params).await
|
||||
}
|
||||
|
||||
/// Send a `loginApiKey` JSON-RPC request.
|
||||
pub async fn send_login_api_key_request(
|
||||
&mut self,
|
||||
params: LoginApiKeyParams,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = Some(serde_json::to_value(params)?);
|
||||
self.send_request("loginApiKey", params).await
|
||||
}
|
||||
|
||||
/// Send a `loginChatGpt` JSON-RPC request.
|
||||
pub async fn send_login_chat_gpt_request(&mut self) -> anyhow::Result<i64> {
|
||||
self.send_request("loginChatGpt", None).await
|
||||
|
||||
105
codex-rs/mcp-server/tests/suite/archive_conversation.rs
Normal file
105
codex-rs/mcp-server/tests/suite/archive_conversation.rs
Normal file
@@ -0,0 +1,105 @@
|
||||
use std::path::Path;
|
||||
|
||||
use codex_core::ARCHIVED_SESSIONS_SUBDIR;
|
||||
use codex_protocol::mcp_protocol::ArchiveConversationParams;
|
||||
use codex_protocol::mcp_protocol::ArchiveConversationResponse;
|
||||
use codex_protocol::mcp_protocol::NewConversationParams;
|
||||
use codex_protocol::mcp_protocol::NewConversationResponse;
|
||||
use mcp_test_support::McpProcess;
|
||||
use mcp_test_support::to_response;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::RequestId;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::timeout;
|
||||
|
||||
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn archive_conversation_moves_rollout_into_archived_directory() {
|
||||
let codex_home = TempDir::new().expect("create temp dir");
|
||||
create_config_toml(codex_home.path()).expect("write config.toml");
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path())
|
||||
.await
|
||||
.expect("spawn mcp process");
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
|
||||
.await
|
||||
.expect("initialize timeout")
|
||||
.expect("initialize request");
|
||||
|
||||
let new_request_id = mcp
|
||||
.send_new_conversation_request(NewConversationParams {
|
||||
model: Some("mock-model".to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.expect("send newConversation");
|
||||
let new_response: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(new_request_id)),
|
||||
)
|
||||
.await
|
||||
.expect("newConversation timeout")
|
||||
.expect("newConversation response");
|
||||
|
||||
let NewConversationResponse {
|
||||
conversation_id,
|
||||
rollout_path,
|
||||
..
|
||||
} = to_response::<NewConversationResponse>(new_response)
|
||||
.expect("deserialize newConversation response");
|
||||
|
||||
assert!(
|
||||
rollout_path.exists(),
|
||||
"expected rollout path {} to exist",
|
||||
rollout_path.display()
|
||||
);
|
||||
|
||||
let archive_request_id = mcp
|
||||
.send_archive_conversation_request(ArchiveConversationParams {
|
||||
conversation_id,
|
||||
rollout_path: rollout_path.clone(),
|
||||
})
|
||||
.await
|
||||
.expect("send archiveConversation");
|
||||
let archive_response: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(archive_request_id)),
|
||||
)
|
||||
.await
|
||||
.expect("archiveConversation timeout")
|
||||
.expect("archiveConversation response");
|
||||
|
||||
let _: ArchiveConversationResponse =
|
||||
to_response::<ArchiveConversationResponse>(archive_response)
|
||||
.expect("deserialize archiveConversation response");
|
||||
|
||||
let archived_directory = codex_home.path().join(ARCHIVED_SESSIONS_SUBDIR);
|
||||
let archived_rollout_path =
|
||||
archived_directory.join(rollout_path.file_name().unwrap_or_else(|| {
|
||||
panic!("rollout path {} missing file name", rollout_path.display())
|
||||
}));
|
||||
|
||||
assert!(
|
||||
!rollout_path.exists(),
|
||||
"expected rollout path {} to be moved",
|
||||
rollout_path.display()
|
||||
);
|
||||
assert!(
|
||||
archived_rollout_path.exists(),
|
||||
"expected archived rollout path {} to exist",
|
||||
archived_rollout_path.display()
|
||||
);
|
||||
}
|
||||
|
||||
fn create_config_toml(codex_home: &Path) -> std::io::Result<()> {
|
||||
let config_toml = codex_home.join("config.toml");
|
||||
std::fs::write(config_toml, config_contents())
|
||||
}
|
||||
|
||||
fn config_contents() -> &'static str {
|
||||
r#"model = "mock-model"
|
||||
approval_policy = "never"
|
||||
sandbox_mode = "read-only"
|
||||
"#
|
||||
}
|
||||
@@ -1,9 +1,10 @@
|
||||
use std::path::Path;
|
||||
|
||||
use codex_core::auth::login_with_api_key;
|
||||
use codex_protocol::mcp_protocol::AuthMode;
|
||||
use codex_protocol::mcp_protocol::GetAuthStatusParams;
|
||||
use codex_protocol::mcp_protocol::GetAuthStatusResponse;
|
||||
use codex_protocol::mcp_protocol::LoginApiKeyParams;
|
||||
use codex_protocol::mcp_protocol::LoginApiKeyResponse;
|
||||
use mcp_test_support::McpProcess;
|
||||
use mcp_test_support::to_response;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
@@ -36,12 +37,31 @@ stream_max_retries = 0
|
||||
)
|
||||
}
|
||||
|
||||
async fn login_with_api_key_via_request(mcp: &mut McpProcess, api_key: &str) {
|
||||
let request_id = mcp
|
||||
.send_login_api_key_request(LoginApiKeyParams {
|
||||
api_key: api_key.to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap_or_else(|e| panic!("send loginApiKey: {e}"));
|
||||
|
||||
let resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
|
||||
)
|
||||
.await
|
||||
.unwrap_or_else(|e| panic!("loginApiKey timeout: {e}"))
|
||||
.unwrap_or_else(|e| panic!("loginApiKey response: {e}"));
|
||||
let _: LoginApiKeyResponse =
|
||||
to_response(resp).unwrap_or_else(|e| panic!("deserialize login response: {e}"));
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn get_auth_status_no_auth() {
|
||||
let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}"));
|
||||
create_config_toml(codex_home.path()).expect("write config.toml");
|
||||
create_config_toml(codex_home.path()).unwrap_or_else(|err| panic!("write config.toml: {err}"));
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path())
|
||||
let mut mcp = McpProcess::new_with_env(codex_home.path(), &[("OPENAI_API_KEY", None)])
|
||||
.await
|
||||
.expect("spawn mcp process");
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
|
||||
@@ -72,8 +92,7 @@ async fn get_auth_status_no_auth() {
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn get_auth_status_with_api_key() {
|
||||
let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}"));
|
||||
create_config_toml(codex_home.path()).expect("write config.toml");
|
||||
login_with_api_key(codex_home.path(), "sk-test-key").expect("seed api key");
|
||||
create_config_toml(codex_home.path()).unwrap_or_else(|err| panic!("write config.toml: {err}"));
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path())
|
||||
.await
|
||||
@@ -83,6 +102,8 @@ async fn get_auth_status_with_api_key() {
|
||||
.expect("init timeout")
|
||||
.expect("init failed");
|
||||
|
||||
login_with_api_key_via_request(&mut mcp, "sk-test-key").await;
|
||||
|
||||
let request_id = mcp
|
||||
.send_get_auth_status_request(GetAuthStatusParams {
|
||||
include_token: Some(true),
|
||||
@@ -101,14 +122,12 @@ async fn get_auth_status_with_api_key() {
|
||||
let status: GetAuthStatusResponse = to_response(resp).expect("deserialize status");
|
||||
assert_eq!(status.auth_method, Some(AuthMode::ApiKey));
|
||||
assert_eq!(status.auth_token, Some("sk-test-key".to_string()));
|
||||
assert_eq!(status.preferred_auth_method, AuthMode::ChatGPT);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn get_auth_status_with_api_key_no_include_token() {
|
||||
let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}"));
|
||||
create_config_toml(codex_home.path()).expect("write config.toml");
|
||||
login_with_api_key(codex_home.path(), "sk-test-key").expect("seed api key");
|
||||
create_config_toml(codex_home.path()).unwrap_or_else(|err| panic!("write config.toml: {err}"));
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path())
|
||||
.await
|
||||
@@ -118,6 +137,8 @@ async fn get_auth_status_with_api_key_no_include_token() {
|
||||
.expect("init timeout")
|
||||
.expect("init failed");
|
||||
|
||||
login_with_api_key_via_request(&mut mcp, "sk-test-key").await;
|
||||
|
||||
// Build params via struct so None field is omitted in wire JSON.
|
||||
let params = GetAuthStatusParams {
|
||||
include_token: None,
|
||||
@@ -138,5 +159,4 @@ async fn get_auth_status_with_api_key_no_include_token() {
|
||||
let status: GetAuthStatusResponse = to_response(resp).expect("deserialize status");
|
||||
assert_eq!(status.auth_method, Some(AuthMode::ApiKey));
|
||||
assert!(status.auth_token.is_none(), "token must be omitted");
|
||||
assert_eq!(status.preferred_auth_method, AuthMode::ChatGPT);
|
||||
}
|
||||
|
||||
@@ -90,6 +90,8 @@ async fn test_codex_jsonrpc_conversation_flow() {
|
||||
let NewConversationResponse {
|
||||
conversation_id,
|
||||
model,
|
||||
reasoning_effort: _,
|
||||
rollout_path: _,
|
||||
} = new_conv_resp;
|
||||
assert_eq!(model, "mock-model");
|
||||
|
||||
|
||||
@@ -59,6 +59,8 @@ async fn test_conversation_create_and_send_message_ok() {
|
||||
let NewConversationResponse {
|
||||
conversation_id,
|
||||
model,
|
||||
reasoning_effort: _,
|
||||
rollout_path: _,
|
||||
} = to_response::<NewConversationResponse>(new_conv_resp)
|
||||
.expect("deserialize newConversation response");
|
||||
assert_eq!(model, "o3");
|
||||
|
||||
@@ -156,23 +156,45 @@ fn create_fake_rollout(codex_home: &Path, filename_ts: &str, meta_rfc3339: &str,
|
||||
|
||||
let file_path = dir.join(format!("rollout-{filename_ts}-{uuid}.jsonl"));
|
||||
let mut lines = Vec::new();
|
||||
// Meta line with timestamp (flattened meta in payload for new schema)
|
||||
lines.push(
|
||||
json!({
|
||||
"record_type": "session_meta",
|
||||
"id": uuid,
|
||||
"timestamp": meta_rfc3339,
|
||||
"cwd": codex_home.to_string_lossy(),
|
||||
"originator": "test",
|
||||
"cli_version": "0.0.0-test"
|
||||
"type": "session_meta",
|
||||
"payload": {
|
||||
"id": uuid,
|
||||
"timestamp": meta_rfc3339,
|
||||
"cwd": "/",
|
||||
"originator": "codex",
|
||||
"cli_version": "0.0.0",
|
||||
"instructions": null
|
||||
}
|
||||
})
|
||||
.to_string(),
|
||||
);
|
||||
// Minimal user message entry as a persisted response item
|
||||
// Minimal user message entry as a persisted response item (with envelope timestamp)
|
||||
lines.push(
|
||||
json!({
|
||||
"type":"message",
|
||||
"role":"user",
|
||||
"content":[{"type":"input_text","text": preview}]
|
||||
"timestamp": meta_rfc3339,
|
||||
"type":"response_item",
|
||||
"payload": {
|
||||
"type":"message",
|
||||
"role":"user",
|
||||
"content":[{"type":"input_text","text": preview}]
|
||||
}
|
||||
})
|
||||
.to_string(),
|
||||
);
|
||||
// Add a matching user message event line to satisfy filters
|
||||
lines.push(
|
||||
json!({
|
||||
"timestamp": meta_rfc3339,
|
||||
"type":"event_msg",
|
||||
"payload": {
|
||||
"type":"user_message",
|
||||
"message": preview,
|
||||
"kind": "plain"
|
||||
}
|
||||
})
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_core::auth::login_with_api_key;
|
||||
use codex_login::login_with_api_key;
|
||||
use codex_protocol::mcp_protocol::CancelLoginChatGptParams;
|
||||
use codex_protocol::mcp_protocol::CancelLoginChatGptResponse;
|
||||
use codex_protocol::mcp_protocol::GetAuthStatusParams;
|
||||
@@ -46,7 +46,7 @@ async fn logout_chatgpt_removes_auth() {
|
||||
login_with_api_key(codex_home.path(), "sk-test-key").expect("seed api key");
|
||||
assert!(codex_home.path().join("auth.json").exists());
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path())
|
||||
let mut mcp = McpProcess::new_with_env(codex_home.path(), &[("OPENAI_API_KEY", None)])
|
||||
.await
|
||||
.expect("spawn mcp process");
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
|
||||
@@ -95,7 +95,7 @@ async fn logout_chatgpt_removes_auth() {
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn login_and_cancel_chatgpt() {
|
||||
let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}"));
|
||||
create_config_toml(codex_home.path()).expect("write config.toml");
|
||||
create_config_toml(codex_home.path()).unwrap_or_else(|err| panic!("write config.toml: {err}"));
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path())
|
||||
.await
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
// Aggregates all former standalone integration tests as modules.
|
||||
mod archive_conversation;
|
||||
mod auth;
|
||||
mod codex_message_processor_flow;
|
||||
mod codex_tool;
|
||||
@@ -8,4 +9,6 @@ mod interrupt;
|
||||
mod list_resume;
|
||||
mod login;
|
||||
mod send_message;
|
||||
mod set_default_model;
|
||||
mod user_agent;
|
||||
mod user_info;
|
||||
|
||||
62
codex-rs/mcp-server/tests/suite/set_default_model.rs
Normal file
62
codex-rs/mcp-server/tests/suite/set_default_model.rs
Normal file
@@ -0,0 +1,62 @@
|
||||
use codex_core::config::ConfigToml;
|
||||
use codex_protocol::config_types::ReasoningEffort;
|
||||
use codex_protocol::mcp_protocol::SetDefaultModelParams;
|
||||
use codex_protocol::mcp_protocol::SetDefaultModelResponse;
|
||||
use mcp_test_support::McpProcess;
|
||||
use mcp_test_support::to_response;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::RequestId;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::timeout;
|
||||
|
||||
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn set_default_model_persists_overrides() {
|
||||
let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}"));
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path())
|
||||
.await
|
||||
.expect("spawn mcp process");
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
|
||||
.await
|
||||
.expect("init timeout")
|
||||
.expect("init failed");
|
||||
|
||||
let params = SetDefaultModelParams {
|
||||
model: Some("o4-mini".to_string()),
|
||||
reasoning_effort: Some(ReasoningEffort::High),
|
||||
};
|
||||
|
||||
let request_id = mcp
|
||||
.send_set_default_model_request(params)
|
||||
.await
|
||||
.expect("send setDefaultModel");
|
||||
|
||||
let resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
|
||||
)
|
||||
.await
|
||||
.expect("setDefaultModel timeout")
|
||||
.expect("setDefaultModel response");
|
||||
|
||||
let _: SetDefaultModelResponse =
|
||||
to_response(resp).expect("deserialize setDefaultModel response");
|
||||
|
||||
let config_path = codex_home.path().join("config.toml");
|
||||
let config_contents = tokio::fs::read_to_string(&config_path)
|
||||
.await
|
||||
.expect("read config.toml");
|
||||
let config_toml: ConfigToml = toml::from_str(&config_contents).expect("parse config.toml");
|
||||
|
||||
assert_eq!(
|
||||
ConfigToml {
|
||||
model: Some("o4-mini".to_string()),
|
||||
model_reasoning_effort: Some(ReasoningEffort::High),
|
||||
..Default::default()
|
||||
},
|
||||
config_toml,
|
||||
);
|
||||
}
|
||||
@@ -1,5 +1,3 @@
|
||||
use codex_core::default_client::DEFAULT_ORIGINATOR;
|
||||
use codex_core::default_client::get_codex_user_agent;
|
||||
use codex_protocol::mcp_protocol::GetUserAgentResponse;
|
||||
use mcp_test_support::McpProcess;
|
||||
use mcp_test_support::to_response;
|
||||
@@ -35,11 +33,18 @@ async fn get_user_agent_returns_current_codex_user_agent() {
|
||||
.expect("getUserAgent timeout")
|
||||
.expect("getUserAgent response");
|
||||
|
||||
let os_info = os_info::get();
|
||||
let user_agent = format!(
|
||||
"codex_cli_rs/0.0.0 ({} {}; {}) {} (elicitation test; 0.0.0)",
|
||||
os_info.os_type(),
|
||||
os_info.version(),
|
||||
os_info.architecture().unwrap_or("unknown"),
|
||||
codex_core::terminal::user_agent()
|
||||
);
|
||||
|
||||
let received: GetUserAgentResponse =
|
||||
to_response(response).expect("deserialize getUserAgent response");
|
||||
let expected = GetUserAgentResponse {
|
||||
user_agent: get_codex_user_agent(Some(DEFAULT_ORIGINATOR)),
|
||||
};
|
||||
let expected = GetUserAgentResponse { user_agent };
|
||||
|
||||
assert_eq!(received, expected);
|
||||
}
|
||||
|
||||
78
codex-rs/mcp-server/tests/suite/user_info.rs
Normal file
78
codex-rs/mcp-server/tests/suite/user_info.rs
Normal file
@@ -0,0 +1,78 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
use base64::Engine;
|
||||
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
|
||||
use codex_core::auth::AuthDotJson;
|
||||
use codex_core::auth::get_auth_file;
|
||||
use codex_core::auth::write_auth_json;
|
||||
use codex_core::token_data::IdTokenInfo;
|
||||
use codex_core::token_data::TokenData;
|
||||
use codex_protocol::mcp_protocol::UserInfoResponse;
|
||||
use mcp_test_support::McpProcess;
|
||||
use mcp_test_support::to_response;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::RequestId;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::timeout;
|
||||
|
||||
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn user_info_returns_email_from_auth_json() {
|
||||
let codex_home = TempDir::new().expect("create tempdir");
|
||||
|
||||
let auth_path = get_auth_file(codex_home.path());
|
||||
let mut id_token = IdTokenInfo::default();
|
||||
id_token.email = Some("user@example.com".to_string());
|
||||
id_token.raw_jwt = encode_id_token_with_email("user@example.com").expect("encode id token");
|
||||
|
||||
let auth = AuthDotJson {
|
||||
openai_api_key: None,
|
||||
tokens: Some(TokenData {
|
||||
id_token,
|
||||
access_token: "access".to_string(),
|
||||
refresh_token: "refresh".to_string(),
|
||||
account_id: None,
|
||||
}),
|
||||
last_refresh: None,
|
||||
};
|
||||
write_auth_json(&auth_path, &auth).expect("write auth.json");
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path())
|
||||
.await
|
||||
.expect("spawn mcp process");
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
|
||||
.await
|
||||
.expect("initialize timeout")
|
||||
.expect("initialize request");
|
||||
|
||||
let request_id = mcp.send_user_info_request().await.expect("send userInfo");
|
||||
let response: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
|
||||
)
|
||||
.await
|
||||
.expect("userInfo timeout")
|
||||
.expect("userInfo response");
|
||||
|
||||
let received: UserInfoResponse = to_response(response).expect("deserialize userInfo response");
|
||||
let expected = UserInfoResponse {
|
||||
alleged_user_email: Some("user@example.com".to_string()),
|
||||
};
|
||||
|
||||
assert_eq!(received, expected);
|
||||
}
|
||||
|
||||
fn encode_id_token_with_email(email: &str) -> anyhow::Result<String> {
|
||||
let header_b64 = URL_SAFE_NO_PAD.encode(
|
||||
serde_json::to_vec(&json!({ "alg": "none", "typ": "JWT" }))
|
||||
.context("serialize jwt header")?,
|
||||
);
|
||||
let payload =
|
||||
serde_json::to_vec(&json!({ "email": email })).context("serialize jwt payload")?;
|
||||
let payload_b64 = URL_SAFE_NO_PAD.encode(payload);
|
||||
Ok(format!("{header_b64}.{payload_b64}.signature"))
|
||||
}
|
||||
21
codex-rs/mcp-types/check_lib_rs.py
Executable file
21
codex-rs/mcp-types/check_lib_rs.py
Executable file
@@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def main() -> int:
|
||||
crate_dir = Path(__file__).resolve().parent
|
||||
generator = crate_dir / "generate_mcp_types.py"
|
||||
|
||||
result = subprocess.run(
|
||||
[sys.executable, str(generator), "--check"],
|
||||
cwd=crate_dir,
|
||||
check=False,
|
||||
)
|
||||
return result.returncode
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -5,15 +5,19 @@ import argparse
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
)
|
||||
from difflib import unified_diff
|
||||
from pathlib import Path
|
||||
from shutil import copy2
|
||||
|
||||
# Helper first so it is defined when other functions call it.
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
SCHEMA_VERSION = "2025-06-18"
|
||||
JSONRPC_VERSION = "2.0"
|
||||
|
||||
@@ -43,16 +47,31 @@ def main() -> int:
|
||||
default_schema_file = (
|
||||
Path(__file__).resolve().parent / "schema" / SCHEMA_VERSION / "schema.json"
|
||||
)
|
||||
default_lib_rs = Path(__file__).resolve().parent / "src/lib.rs"
|
||||
parser.add_argument(
|
||||
"schema_file",
|
||||
nargs="?",
|
||||
default=default_schema_file,
|
||||
help="schema.json file to process",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--check",
|
||||
action="store_true",
|
||||
help="Regenerate lib.rs in a sandbox and ensure the checked-in file matches",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
schema_file = args.schema_file
|
||||
schema_file = Path(args.schema_file)
|
||||
crate_dir = Path(__file__).resolve().parent
|
||||
|
||||
lib_rs = Path(__file__).resolve().parent / "src/lib.rs"
|
||||
if args.check:
|
||||
return run_check(schema_file, crate_dir, default_lib_rs)
|
||||
|
||||
generate_lib_rs(schema_file, default_lib_rs, fmt=True)
|
||||
return 0
|
||||
|
||||
|
||||
def generate_lib_rs(schema_file: Path, lib_rs: Path, fmt: bool) -> None:
|
||||
lib_rs.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
global DEFINITIONS # Allow helper functions to access the schema.
|
||||
|
||||
@@ -117,9 +136,7 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }}
|
||||
|
||||
for req_name in CLIENT_REQUEST_TYPE_NAMES:
|
||||
defn = definitions[req_name]
|
||||
method_const = (
|
||||
defn.get("properties", {}).get("method", {}).get("const", req_name)
|
||||
)
|
||||
method_const = defn.get("properties", {}).get("method", {}).get("const", req_name)
|
||||
payload_type = f"<{req_name} as ModelContextProtocolRequest>::Params"
|
||||
try_from_impl_lines.append(f' "{method_const}" => {{\n')
|
||||
try_from_impl_lines.append(
|
||||
@@ -128,9 +145,7 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }}
|
||||
try_from_impl_lines.append(
|
||||
f" let params: {payload_type} = serde_json::from_value(params_json)?;\n"
|
||||
)
|
||||
try_from_impl_lines.append(
|
||||
f" Ok(ClientRequest::{req_name}(params))\n"
|
||||
)
|
||||
try_from_impl_lines.append(f" Ok(ClientRequest::{req_name}(params))\n")
|
||||
try_from_impl_lines.append(" },\n")
|
||||
|
||||
try_from_impl_lines.append(
|
||||
@@ -144,9 +159,7 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }}
|
||||
|
||||
# Generate TryFrom for ServerNotification
|
||||
notif_impl_lines: list[str] = []
|
||||
notif_impl_lines.append(
|
||||
"impl TryFrom<JSONRPCNotification> for ServerNotification {\n"
|
||||
)
|
||||
notif_impl_lines.append("impl TryFrom<JSONRPCNotification> for ServerNotification {\n")
|
||||
notif_impl_lines.append(" type Error = serde_json::Error;\n")
|
||||
notif_impl_lines.append(
|
||||
" fn try_from(n: JSONRPCNotification) -> std::result::Result<Self, Self::Error> {\n"
|
||||
@@ -155,9 +168,7 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }}
|
||||
|
||||
for notif_name in SERVER_NOTIFICATION_TYPE_NAMES:
|
||||
n_def = definitions[notif_name]
|
||||
method_const = (
|
||||
n_def.get("properties", {}).get("method", {}).get("const", notif_name)
|
||||
)
|
||||
method_const = n_def.get("properties", {}).get("method", {}).get("const", notif_name)
|
||||
payload_type = f"<{notif_name} as ModelContextProtocolNotification>::Params"
|
||||
notif_impl_lines.append(f' "{method_const}" => {{\n')
|
||||
# params may be optional
|
||||
@@ -167,9 +178,7 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }}
|
||||
notif_impl_lines.append(
|
||||
f" let params: {payload_type} = serde_json::from_value(params_json)?;\n"
|
||||
)
|
||||
notif_impl_lines.append(
|
||||
f" Ok(ServerNotification::{notif_name}(params))\n"
|
||||
)
|
||||
notif_impl_lines.append(f" Ok(ServerNotification::{notif_name}(params))\n")
|
||||
notif_impl_lines.append(" },\n")
|
||||
|
||||
notif_impl_lines.append(
|
||||
@@ -185,13 +194,70 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }}
|
||||
for chunk in out:
|
||||
f.write(chunk)
|
||||
|
||||
subprocess.check_call(
|
||||
["cargo", "fmt", "--", "--config", "imports_granularity=Item"],
|
||||
cwd=lib_rs.parent.parent,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
if fmt:
|
||||
subprocess.check_call(
|
||||
["cargo", "fmt", "--", "--config", "imports_granularity=Item"],
|
||||
cwd=lib_rs.parent.parent,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
|
||||
return 0
|
||||
|
||||
def run_check(schema_file: Path, crate_dir: Path, checked_in_lib: Path) -> int:
|
||||
config_path = crate_dir.parent / "rustfmt.toml"
|
||||
eprint(f"Running --check with schema {schema_file}")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
eprint(f"Created temporary workspace at {tmp_path}")
|
||||
manifest_path = tmp_path / "Cargo.toml"
|
||||
eprint(f"Copying Cargo.toml into {manifest_path}")
|
||||
copy2(crate_dir / "Cargo.toml", manifest_path)
|
||||
manifest_text = manifest_path.read_text(encoding="utf-8")
|
||||
manifest_text = manifest_text.replace(
|
||||
"version = { workspace = true }",
|
||||
'version = "0.0.0"',
|
||||
)
|
||||
manifest_text = manifest_text.replace("\n[lints]\nworkspace = true\n", "\n")
|
||||
manifest_path.write_text(manifest_text, encoding="utf-8")
|
||||
src_dir = tmp_path / "src"
|
||||
src_dir.mkdir(parents=True, exist_ok=True)
|
||||
eprint(f"Generating lib.rs into {src_dir}")
|
||||
generated_lib = src_dir / "lib.rs"
|
||||
|
||||
generate_lib_rs(schema_file, generated_lib, fmt=False)
|
||||
|
||||
eprint("Formatting generated lib.rs with rustfmt")
|
||||
subprocess.check_call(
|
||||
[
|
||||
"rustfmt",
|
||||
"--config-path",
|
||||
str(config_path),
|
||||
str(generated_lib),
|
||||
],
|
||||
cwd=tmp_path,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
|
||||
eprint("Comparing generated lib.rs with checked-in version")
|
||||
checked_in_contents = checked_in_lib.read_text(encoding="utf-8")
|
||||
generated_contents = generated_lib.read_text(encoding="utf-8")
|
||||
|
||||
if checked_in_contents == generated_contents:
|
||||
eprint("lib.rs matches checked-in version")
|
||||
return 0
|
||||
|
||||
diff = unified_diff(
|
||||
checked_in_contents.splitlines(keepends=True),
|
||||
generated_contents.splitlines(keepends=True),
|
||||
fromfile=str(checked_in_lib),
|
||||
tofile=str(generated_lib),
|
||||
)
|
||||
diff_text = "".join(diff)
|
||||
eprint("Generated lib.rs does not match the checked-in version. Diff:")
|
||||
if diff_text:
|
||||
eprint(diff_text, end="")
|
||||
eprint("Re-run generate_mcp_types.py without --check to update src/lib.rs.")
|
||||
return 1
|
||||
|
||||
|
||||
def add_definition(name: str, definition: dict[str, Any], out: list[str]) -> None:
|
||||
@@ -265,8 +331,11 @@ class StructField:
|
||||
name: str
|
||||
type_name: str
|
||||
serde: str | None = None
|
||||
comment: str | None = None
|
||||
|
||||
def append(self, out: list[str], supports_const: bool) -> None:
|
||||
if self.comment:
|
||||
out.append(f" // {self.comment}\n")
|
||||
if self.serde:
|
||||
out.append(f" {self.serde}\n")
|
||||
if self.viz == "const":
|
||||
@@ -312,6 +381,18 @@ def define_struct(
|
||||
else:
|
||||
fields.append(StructField("pub", rs_prop.name, prop_type, rs_prop.serde))
|
||||
|
||||
# Special-case: add Codex-specific user_agent to Implementation
|
||||
if name == "Implementation":
|
||||
fields.append(
|
||||
StructField(
|
||||
"pub",
|
||||
"user_agent",
|
||||
"Option<String>",
|
||||
'#[serde(default, skip_serializing_if = "Option::is_none")]',
|
||||
"This is an extra field that the Codex MCP server sends as part of InitializeResult.",
|
||||
)
|
||||
)
|
||||
|
||||
if implements_request_trait(name):
|
||||
add_trait_impl(name, "ModelContextProtocolRequest", fields, out)
|
||||
elif implements_notification_trait(name):
|
||||
@@ -406,15 +487,11 @@ def define_untagged_enum(name: str, type_list: list[str], out: list[str]) -> Non
|
||||
case "integer":
|
||||
out.append(" Integer(i64),\n")
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Unknown type in untagged enum: {simple_type} in {name}"
|
||||
)
|
||||
raise ValueError(f"Unknown type in untagged enum: {simple_type} in {name}")
|
||||
out.append("}\n\n")
|
||||
|
||||
|
||||
def define_any_of(
|
||||
name: str, list_of_refs: list[Any], description: str | None = None
|
||||
) -> list[str]:
|
||||
def define_any_of(name: str, list_of_refs: list[Any], description: str | None = None) -> list[str]:
|
||||
"""Generate a Rust enum for a JSON-Schema `anyOf` union.
|
||||
|
||||
For most types we simply map each `$ref` inside the `anyOf` list to a
|
||||
@@ -479,9 +556,7 @@ def define_any_of(
|
||||
if name == "ClientRequest":
|
||||
payload_type = f"<{ref_name} as ModelContextProtocolRequest>::Params"
|
||||
else:
|
||||
payload_type = (
|
||||
f"<{ref_name} as ModelContextProtocolNotification>::Params"
|
||||
)
|
||||
payload_type = f"<{ref_name} as ModelContextProtocolNotification>::Params"
|
||||
|
||||
# Determine the wire value for `method` so we can annotate the
|
||||
# variant appropriately. If for some reason the schema does not
|
||||
@@ -489,9 +564,7 @@ def define_any_of(
|
||||
# least compile (although deserialization will likely fail).
|
||||
request_def = DEFINITIONS.get(ref_name, {})
|
||||
method_const = (
|
||||
request_def.get("properties", {})
|
||||
.get("method", {})
|
||||
.get("const", ref_name)
|
||||
request_def.get("properties", {}).get("method", {}).get("const", ref_name)
|
||||
)
|
||||
|
||||
out.append(f' #[serde(rename = "{method_const}")]\n')
|
||||
@@ -541,7 +614,7 @@ def map_type(
|
||||
if type_prop == "string":
|
||||
if const_prop := typedef.get("const", None):
|
||||
assert isinstance(const_prop, str)
|
||||
return f'&\'static str = "{const_prop }"'
|
||||
return f'&\'static str = "{const_prop}"'
|
||||
else:
|
||||
return "String"
|
||||
elif type_prop == "integer":
|
||||
@@ -617,7 +690,7 @@ def rust_prop_name(name: str, is_optional: bool) -> RustProp:
|
||||
serde_annotations.append('skip_serializing_if = "Option::is_none"')
|
||||
|
||||
if serde_annotations:
|
||||
serde_str = f'#[serde({", ".join(serde_annotations)})]'
|
||||
serde_str = f"#[serde({', '.join(serde_annotations)})]"
|
||||
else:
|
||||
serde_str = None
|
||||
return RustProp(prop_name, serde_str)
|
||||
@@ -625,9 +698,7 @@ def rust_prop_name(name: str, is_optional: bool) -> RustProp:
|
||||
|
||||
def to_snake_case(name: str) -> str:
|
||||
"""Convert a camelCase or PascalCase name to snake_case."""
|
||||
snake_case = name[0].lower() + "".join(
|
||||
"_" + c.lower() if c.isupper() else c for c in name[1:]
|
||||
)
|
||||
snake_case = name[0].lower() + "".join("_" + c.lower() if c.isupper() else c for c in name[1:])
|
||||
if snake_case != name:
|
||||
return snake_case
|
||||
else:
|
||||
@@ -663,5 +734,9 @@ def emit_doc_comment(text: str | None, out: list[str]) -> None:
|
||||
out.append(f"/// {line.rstrip()}\n")
|
||||
|
||||
|
||||
def eprint(*args: Any, **kwargs: Any) -> None:
|
||||
print(*args, file=sys.stderr, **kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
||||
@@ -487,6 +487,9 @@ pub struct Implementation {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub version: String,
|
||||
// This is an extra field that the Codex MCP server sends as part of InitializeResult.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub user_agent: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, TS)]
|
||||
|
||||
@@ -62,6 +62,7 @@ fn deserialize_initialize_request() {
|
||||
name: "acme-client".into(),
|
||||
title: Some("Acme".to_string()),
|
||||
version: "1.2.3".into(),
|
||||
user_agent: None,
|
||||
},
|
||||
protocol_version: "2025-06-18".into(),
|
||||
}
|
||||
|
||||
@@ -24,9 +24,7 @@ tokio = { version = "1", features = [
|
||||
"rt-multi-thread",
|
||||
"signal",
|
||||
] }
|
||||
toml = "0.9.5"
|
||||
tracing = { version = "0.1.41", features = ["log"] }
|
||||
wiremock = "0.6"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
||||
|
||||
@@ -16,6 +16,7 @@ path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
mcp-types = { path = "../mcp-types" }
|
||||
codex-protocol = { path = "../protocol" }
|
||||
ts-rs = "11"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
|
||||
@@ -16,29 +16,37 @@ pub fn generate_ts(out_dir: &Path, prettier: Option<&Path>) -> Result<()> {
|
||||
ensure_dir(out_dir)?;
|
||||
|
||||
// Generate TS bindings
|
||||
mcp_types::InitializeResult::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::ConversationId::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::InputItem::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::ClientRequest::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::ServerRequest::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::NewConversationResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::ListConversationsResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::ResumeConversationResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::ArchiveConversationResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::AddConversationSubscriptionResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::RemoveConversationSubscriptionResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::SendUserMessageResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::SendUserTurnResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::InterruptConversationResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::GitDiffToRemoteResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::LoginApiKeyParams::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::LoginApiKeyResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::LoginChatGptResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::LoginChatGptCompleteNotification::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::CancelLoginChatGptResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::LogoutChatGptResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::GetAuthStatusResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::ApplyPatchApprovalResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::ExecCommandApprovalResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::GetUserSavedConfigResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::SetDefaultModelResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::GetUserAgentResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::UserInfoResponse::export_all_to(out_dir)?;
|
||||
|
||||
// All notification types reachable from this enum will be generated by
|
||||
// induction, so they do not need to be listed individually.
|
||||
codex_protocol::mcp_protocol::ServerNotification::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::ListConversationsResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::ResumeConversationResponse::export_all_to(out_dir)?;
|
||||
|
||||
generate_index_ts(out_dir)?;
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ icu_locale_core = "2.0.0"
|
||||
mcp-types = { path = "../mcp-types" }
|
||||
mime_guess = "2.0.5"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_bytes = "0.11"
|
||||
serde_json = "1"
|
||||
serde_with = { version = "3.14.0", features = ["macros", "base64"] }
|
||||
strum = "0.27.2"
|
||||
@@ -29,3 +28,8 @@ uuid = { version = "1", features = ["serde", "v4"] }
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = "1.4.1"
|
||||
tempfile = "3"
|
||||
|
||||
[package.metadata.cargo-shear]
|
||||
# Required because the not imported as strum_macros in non-nightly builds.
|
||||
ignored = ["strum"]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user