Mirror of https://github.com/openai/codex.git
Synced 2026-02-08 01:43:46 +00:00
Compare commits

53 Commits

| SHA1 |
|---|
| 7a3e4b607f |
| d3ff668f68 |
| 81caee3400 |
| 51dd5af807 |
| 6372ba9d5f |
| bdfdebcfa1 |
| 62a73b6d58 |
| be4364bb80 |
| 0d3e673019 |
| 41a317321d |
| 051bf81df9 |
| a70f5b0b3c |
| 224c4867dd |
| c9c6560685 |
| 634764ece9 |
| 5bc3e325a6 |
| 4156060416 |
| 98122cbad0 |
| 7b21b443bb |
| 93dec9045e |
| 69898e3dba |
| c4af304c77 |
| 5b7707dfb1 |
| 59d6937550 |
| 932a5a446f |
| 484f6f4c26 |
| 5522663f92 |
| 98e171258c |
| da667b1f56 |
| 1e29774fce |
| 9ce6bbc43e |
| 7520d8ba58 |
| 0318f30ed8 |
| be212db0c8 |
| 5b022c2904 |
| e21ce6c5de |
| 267c05fb30 |
| 634650dd25 |
| 8a0c2e5841 |
| 0f8bb4579b |
| 35fd69a9f0 |
| ccba737d26 |
| 75076aabfe |
| f6b563ec64 |
| 357e4c902b |
| ef8b8ebc94 |
| 54b290ec1d |
| efd0c21b9b |
| 61e81af887 |
| f07b8aa591 |
| 5f3f70203c |
| 21c6d40a44 |
| a9b5e8a136 |
@@ -14,6 +14,26 @@ RUN apt-get update && \
    pkg-config clang musl-tools libssl-dev just && \
    rm -rf /var/lib/apt/lists/*

# Install Bazel via Bazelisk (mirrors bazelbuild/setup-bazelisk@v3).
ARG BAZELISK_VERSION=latest
RUN arch="$(uname -m)" && \
    case "$arch" in \
        x86_64) arch="amd64" ;; \
        aarch64) arch="arm64" ;; \
        *) echo "Unsupported architecture: $arch" >&2; exit 1 ;; \
    esac && \
    if [ "$BAZELISK_VERSION" = "latest" ]; then \
        url="https://github.com/bazelbuild/bazelisk/releases/latest/download/bazelisk-linux-${arch}"; \
    else \
        url="https://github.com/bazelbuild/bazelisk/releases/download/v${BAZELISK_VERSION}/bazelisk-linux-${arch}"; \
    fi && \
    curl -fsSL "$url" -o /usr/local/bin/bazelisk && \
    chmod +x /usr/local/bin/bazelisk && \
    ln -s /usr/local/bin/bazelisk /usr/local/bin/bazel

# Install dotslash.
RUN curl -LSfs "https://github.com/facebook/dotslash/releases/download/v0.5.8/dotslash-ubuntu-22.04.$(uname -m).tar.gz" | tar fxz - -C /usr/local/bin

# Ubuntu 24.04 ships with user 'ubuntu' already created with UID 1000.
USER ubuntu
.gitignore (vendored, 1 change)
@@ -9,6 +9,7 @@ node_modules

# build
dist/
bazel-*
build/
out/
storybook-static/
@@ -77,11 +77,11 @@ If you don’t have the tool:
- Prefer deep equals comparisons whenever possible. Perform `assert_eq!()` on entire objects, rather than individual fields.
- Avoid mutating process environment in tests; prefer passing environment-derived flags or dependencies from above.

-### Spawning workspace binaries in tests (Cargo vs Buck2)
+### Spawning workspace binaries in tests (Cargo vs Bazel)

- Prefer `codex_utils_cargo_bin::cargo_bin("...")` over `assert_cmd::Command::cargo_bin(...)` or `escargot` when tests need to spawn first-party binaries.
-- Under Buck2, `CARGO_BIN_EXE_*` may be project-relative (e.g. `buck-out/...`), which breaks if a test changes its working directory. `codex_utils_cargo_bin::cargo_bin` resolves to an absolute path first.
-- When locating fixture files under Buck2, avoid `env!("CARGO_MANIFEST_DIR")` (Buck codegen sets it to `"."`). Prefer deriving paths from `codex_utils_cargo_bin::buck_project_root()` when needed.
+- Under Bazel, binaries and resources may live under runfiles; use `codex_utils_cargo_bin::cargo_bin` to resolve absolute paths that remain stable after `chdir`.
+- When locating fixture files or test resources under Bazel, avoid `env!("CARGO_MANIFEST_DIR")`. Prefer `codex_utils_cargo_bin::find_resource!` so paths resolve correctly under both Cargo and Bazel runfiles.
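A minimal test sketch of the `codex_utils_cargo_bin::cargo_bin` guidance above (not part of this diff). It assumes the helper returns a path-like value accepted by `Command::new`; the binary name and flag are illustrative:

```rust
use std::process::Command;

#[test]
fn spawns_first_party_binary() {
    // `cargo_bin` resolves an absolute path up front, so a later change of
    // working directory inside the test cannot break the lookup.
    let bin = codex_utils_cargo_bin::cargo_bin("codex");
    let output = Command::new(bin)
        .arg("--version")
        .output()
        .expect("failed to spawn codex");
    assert!(output.status.success());
}
```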
### Integration tests (core)

codex-rs/Cargo.lock (generated, 72 changes)
@@ -819,6 +819,8 @@ version = "1.2.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "deec109607ca693028562ed836a5f1c4b8bd77755c4e132fc5ce11b0b6211ae7"
dependencies = [
 "jobserver",
 "libc",
 "shlex",
]

@@ -1127,6 +1129,7 @@ dependencies = [
 "codex-common",
 "codex-core",
 "codex-git",
 "codex-utils-cargo-bin",
 "serde",
 "serde_json",
 "tempfile",
@@ -1153,7 +1156,6 @@ dependencies = [
 "codex-execpolicy",
 "codex-login",
 "codex-mcp-server",
 "codex-process-hardening",
 "codex-protocol",
 "codex-responses-api-proxy",
 "codex-rmcp-client",
@@ -1163,7 +1165,6 @@ dependencies = [
 "codex-utils-absolute-path",
 "codex-utils-cargo-bin",
 "codex-windows-sandbox",
 "ctor 0.5.0",
 "libc",
 "owo-colors",
 "predicates",
@@ -1197,6 +1198,7 @@ dependencies = [
 "tracing",
 "tracing-opentelemetry",
 "tracing-subscriber",
 "zstd",
]

[[package]]
@@ -1298,7 +1300,6 @@ dependencies = [
 "dunce",
 "encoding_rs",
 "env-flags",
 "escargot",
 "eventsource-stream",
 "futures",
 "http 1.3.1",
@@ -1348,6 +1349,19 @@ dependencies = [
 "which",
 "wildmatch",
 "wiremock",
 "zstd",
]

[[package]]
name = "codex-debug-client"
version = "0.0.0"
dependencies = [
 "anyhow",
 "clap",
 "codex-app-server-protocol",
 "pretty_assertions",
 "serde",
 "serde_json",
]

[[package]]
@@ -1605,10 +1619,12 @@ dependencies = [
 "opentelemetry-otlp",
 "opentelemetry-semantic-conventions",
 "opentelemetry_sdk",
 "pretty_assertions",
 "reqwest",
 "serde",
 "serde_json",
 "strum_macros 0.27.2",
 "thiserror 2.0.17",
 "tokio",
 "tracing",
 "tracing-opentelemetry",
@@ -1876,6 +1892,7 @@ name = "codex-utils-cargo-bin"
version = "0.0.0"
dependencies = [
 "assert_cmd",
 "path-absolutize",
 "thiserror 2.0.17",
]

@@ -2790,17 +2807,6 @@ version = "3.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dea2df4cf52843e0452895c455a1a2cfbb842a1e7329671acf418fdc53ed4c59"

[[package]]
name = "escargot"
version = "0.5.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11c3aea32bc97b500c9ca6a72b768a26e558264303d101d3409cf6d57a9ed0cf"
dependencies = [
 "log",
 "serde",
 "serde_json",
]

[[package]]
name = "event-listener"
version = "5.4.0"
@@ -3924,6 +3930,16 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"

[[package]]
name = "jobserver"
version = "0.1.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33"
dependencies = [
 "getrandom 0.3.3",
 "libc",
]

[[package]]
name = "js-sys"
version = "0.3.77"
@@ -8809,6 +8825,34 @@ dependencies = [
 "syn 2.0.104",
]

[[package]]
name = "zstd"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a"
dependencies = [
 "zstd-safe",
]

[[package]]
name = "zstd-safe"
version = "7.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d"
dependencies = [
 "zstd-sys",
]

[[package]]
name = "zstd-sys"
version = "2.0.16+zstd.1.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748"
dependencies = [
 "cc",
 "pkg-config",
]

[[package]]
name = "zune-core"
version = "0.4.12"
codex-rs/Cargo.toml
@@ -6,6 +6,7 @@ members = [
    "app-server",
    "app-server-protocol",
    "app-server-test-client",
    "debug-client",
    "apply-patch",
    "arg0",
    "feedback",
@@ -134,7 +135,6 @@ dunce = "1.0.4"
encoding_rs = "0.8.35"
env-flags = "0.1.1"
env_logger = "0.11.5"
escargot = "0.5"
eventsource-stream = "0.2.3"
futures = { version = "0.3", default-features = false }
http = "1.3.1"
@@ -218,6 +218,7 @@ tracing-subscriber = "0.3.22"
tracing-test = "0.2.5"
tree-sitter = "0.25.10"
tree-sitter-bash = "0.25"
zstd = "0.13"
tree-sitter-highlight = "0.25.10"
ts-rs = "11"
tui-scrollbar = "0.2.1"
@@ -109,6 +109,10 @@ client_request_definitions! {
        params: v2::ThreadResumeParams,
        response: v2::ThreadResumeResponse,
    },
    ThreadFork => "thread/fork" {
        params: v2::ThreadForkParams,
        response: v2::ThreadForkResponse,
    },
    ThreadArchive => "thread/archive" {
        params: v2::ThreadArchiveParams,
        response: v2::ThreadArchiveResponse,
@@ -121,6 +125,10 @@ client_request_definitions! {
        params: v2::ThreadListParams,
        response: v2::ThreadListResponse,
    },
    ThreadLoadedList => "thread/loaded/list" {
        params: v2::ThreadLoadedListParams,
        response: v2::ThreadLoadedListResponse,
    },
    SkillsList => "skills/list" {
        params: v2::SkillsListParams,
        response: v2::SkillsListResponse,
@@ -197,6 +205,11 @@ client_request_definitions! {
        response: v2::ConfigWriteResponse,
    },

    ConfigRequirementsRead => "configRequirements/read" {
        params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>,
        response: v2::ConfigRequirementsReadResponse,
    },

    GetAccount => "account/read" {
        params: v2::GetAccountParams,
        response: v2::GetAccountResponse,
@@ -221,6 +234,11 @@ client_request_definitions! {
        params: v1::ResumeConversationParams,
        response: v1::ResumeConversationResponse,
    },
    /// Fork a recorded Codex conversation into a new session.
    ForkConversation {
        params: v1::ForkConversationParams,
        response: v1::ForkConversationResponse,
    },
    ArchiveConversation {
        params: v1::ArchiveConversationParams,
        response: v1::ArchiveConversationResponse,
@@ -711,6 +729,22 @@ mod tests {
        Ok(())
    }

    #[test]
    fn serialize_config_requirements_read() -> Result<()> {
        let request = ClientRequest::ConfigRequirementsRead {
            request_id: RequestId::Integer(1),
            params: None,
        };
        assert_eq!(
            json!({
                "method": "configRequirements/read",
                "id": 1,
            }),
            serde_json::to_value(&request)?,
        );
        Ok(())
    }

    #[test]
    fn serialize_account_login_api_key() -> Result<()> {
        let request = ClientRequest::LoginAccount {
@@ -83,6 +83,15 @@ pub struct ResumeConversationResponse {
    pub rollout_path: PathBuf,
}

#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
pub struct ForkConversationResponse {
    pub conversation_id: ThreadId,
    pub model: String,
    pub initial_messages: Option<Vec<EventMsg>>,
    pub rollout_path: PathBuf,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(untagged)]
pub enum GetConversationSummaryParams {
@@ -148,6 +157,14 @@ pub struct ResumeConversationParams {
    pub overrides: Option<NewConversationParams>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
pub struct ForkConversationParams {
    pub path: Option<PathBuf>,
    pub conversation_id: Option<ThreadId>,
    pub overrides: Option<NewConversationParams>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
pub struct AddConversationSubscriptionResponse {
@@ -453,6 +453,22 @@ pub struct ConfigReadResponse {
    pub layers: Option<Vec<ConfigLayer>>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
pub struct ConfigRequirements {
    pub allowed_approval_policies: Option<Vec<AskForApproval>>,
    pub allowed_sandbox_modes: Option<Vec<SandboxMode>>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
pub struct ConfigRequirementsReadResponse {
    /// Null if no requirements are configured (e.g. no requirements.toml/MDM entries).
    pub requirements: Option<ConfigRequirements>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
@@ -1064,6 +1080,47 @@ pub struct ThreadResumeResponse {
    pub reasoning_effort: Option<ReasoningEffort>,
}

#[derive(Serialize, Deserialize, Debug, Default, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
/// There are two ways to fork a thread:
/// 1. By thread_id: load the thread from disk by thread_id and fork it into a new thread.
/// 2. By path: load the thread from disk by path and fork it into a new thread.
///
/// If using path, the thread_id param will be ignored.
///
/// Prefer using thread_id whenever possible.
pub struct ThreadForkParams {
    pub thread_id: String,

    /// [UNSTABLE] Specify the rollout path to fork from.
    /// If specified, the thread_id param will be ignored.
    pub path: Option<PathBuf>,

    /// Configuration overrides for the forked thread, if any.
    pub model: Option<String>,
    pub model_provider: Option<String>,
    pub cwd: Option<String>,
    pub approval_policy: Option<AskForApproval>,
    pub sandbox: Option<SandboxMode>,
    pub config: Option<HashMap<String, serde_json::Value>>,
    pub base_instructions: Option<String>,
    pub developer_instructions: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
pub struct ThreadForkResponse {
    pub thread: Thread,
    pub model: String,
    pub model_provider: String,
    pub cwd: PathBuf,
    pub approval_policy: AskForApproval,
    pub sandbox: SandboxPolicy,
    pub reasoning_effort: Option<ReasoningEffort>,
}
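// Illustrative sketch (not part of this diff): constructing fork params by
// thread id. `ThreadForkParams` derives `Default`, so the override fields can
// stay unset; "thr_123" is a placeholder value.
fn example_fork_params() -> ThreadForkParams {
    ThreadForkParams {
        thread_id: "thr_123".to_string(),
        // Leave `path` unset so the thread is located by id, as the doc
        // comment above recommends.
        ..Default::default()
    }
}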
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
@@ -1123,6 +1180,27 @@ pub struct ThreadListResponse {
    pub next_cursor: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
pub struct ThreadLoadedListParams {
    /// Opaque pagination cursor returned by a previous call.
    pub cursor: Option<String>,
    /// Optional page size; defaults to no limit.
    pub limit: Option<u32>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
pub struct ThreadLoadedListResponse {
    /// Thread ids for sessions currently loaded in memory.
    pub data: Vec<String>,
    /// Opaque cursor to pass to the next call to continue after the last item.
    /// if None, there are no more items to return.
    pub next_cursor: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
@@ -1238,7 +1316,7 @@ pub struct Thread {
    pub source: SessionSource,
    /// Optional Git metadata captured when the thread was created.
    pub git_info: Option<GitInfo>,
-    /// Only populated on `thread/resume` and `thread/rollback` responses.
+    /// Only populated on `thread/resume`, `thread/rollback`, `thread/fork` responses.
    /// For all other responses and notifications returning a Thread,
    /// the turns field will be an empty list.
    pub turns: Vec<Turn>,
@@ -1314,7 +1392,7 @@ impl From<CoreTokenUsage> for TokenUsageBreakdown {
#[ts(export_to = "v2/")]
pub struct Turn {
    pub id: String,
-    /// Only populated on a `thread/resume` response.
+    /// Only populated on a `thread/resume` or `thread/fork` response.
    /// For all other responses and notifications returning a Turn,
    /// the items field will be an empty list.
    pub items: Vec<ThreadItem>,
@@ -1460,6 +1538,7 @@ pub enum UserInput {
    Text { text: String },
    Image { url: String },
    LocalImage { path: PathBuf },
    Skill { name: String, path: PathBuf },
}

impl UserInput {
@@ -1468,6 +1547,7 @@ impl UserInput {
            UserInput::Text { text } => CoreUserInput::Text { text },
            UserInput::Image { url } => CoreUserInput::Image { image_url: url },
            UserInput::LocalImage { path } => CoreUserInput::LocalImage { path },
            UserInput::Skill { name, path } => CoreUserInput::Skill { name, path },
        }
    }
}
@@ -1478,6 +1558,7 @@ impl From<CoreUserInput> for UserInput {
            CoreUserInput::Text { text } => UserInput::Text { text },
            CoreUserInput::Image { image_url } => UserInput::Image { url: image_url },
            CoreUserInput::LocalImage { path } => UserInput::LocalImage { path },
            CoreUserInput::Skill { name, path } => UserInput::Skill { name, path },
            _ => unreachable!("unsupported user input variant"),
        }
    }
@@ -2063,6 +2144,10 @@ mod tests {
            CoreUserInput::LocalImage {
                path: PathBuf::from("local/image.png"),
            },
            CoreUserInput::Skill {
                name: "skill-creator".to_string(),
                path: PathBuf::from("/repo/.codex/skills/skill-creator/SKILL.md"),
            },
        ],
    });

@@ -2080,6 +2165,10 @@ mod tests {
            UserInput::LocalImage {
                path: PathBuf::from("local/image.png"),
            },
            UserInput::Skill {
                name: "skill-creator".to_string(),
                path: PathBuf::from("/repo/.codex/skills/skill-creator/SKILL.md"),
            },
        ],
    }
    );
@@ -41,7 +41,7 @@ Use the thread APIs to create, list, or archive conversations. Drive a conversat
## Lifecycle Overview

- Initialize once: Immediately after launching the codex app-server process, send an `initialize` request with your client metadata, then emit an `initialized` notification. Any other request before this handshake gets rejected.
-- Start (or resume) a thread: Call `thread/start` to open a fresh conversation. The response returns the thread object and you’ll also get a `thread/started` notification. If you’re continuing an existing conversation, call `thread/resume` with its ID instead.
+- Start (or resume) a thread: Call `thread/start` to open a fresh conversation. The response returns the thread object and you’ll also get a `thread/started` notification. If you’re continuing an existing conversation, call `thread/resume` with its ID instead. If you want to branch from an existing conversation, call `thread/fork` to create a new thread id with copied history.
- Begin a turn: To send user input, call `turn/start` with the target `threadId` and the user's input. Optional fields let you override model, cwd, sandbox policy, etc. This immediately returns the new turn object and triggers a `turn/started` notification.
- Stream events: After `turn/start`, keep reading JSON-RPC notifications on stdout. You’ll see `item/started`, `item/completed`, deltas like `item/agentMessage/delta`, tool progress, etc. These represent streaming model output plus any side effects (commands, tool calls, reasoning notes).
- Finish the turn: When the model is done (or the turn is interrupted via making the `turn/interrupt` call), the server sends `turn/completed` with the final turn state and token usage.
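A minimal sketch of the initialize handshake described above, assuming newline-delimited JSON-RPC written to the server's stdin (the `clientInfo` shape is illustrative, not taken from this diff):

```rust
use std::io::Write;

// Writes one JSON-RPC message per line to the app-server's stdin.
fn send(stdin: &mut impl Write, msg: serde_json::Value) -> std::io::Result<()> {
    writeln!(stdin, "{msg}")
}

fn handshake(stdin: &mut impl Write) -> std::io::Result<()> {
    // 1. `initialize` request carrying client metadata.
    send(
        stdin,
        serde_json::json!({
            "jsonrpc": "2.0",
            "id": 0,
            "method": "initialize",
            "params": { "clientInfo": { "name": "example-client", "version": "0.1.0" } }
        }),
    )?;
    // 2. `initialized` notification; only after this are other requests accepted.
    send(stdin, serde_json::json!({ "jsonrpc": "2.0", "method": "initialized" }))
}
```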
@@ -72,7 +72,9 @@ Example (from OpenAI's official VSCode extension):

- `thread/start` — create a new thread; emits `thread/started` and auto-subscribes you to turn/item events for that thread.
- `thread/resume` — reopen an existing thread by id so subsequent `turn/start` calls append to it.
- `thread/fork` — fork an existing thread into a new thread id by copying the stored history; emits `thread/started` and auto-subscribes you to turn/item events for the new thread.
- `thread/list` — page through stored rollouts; supports cursor-based pagination and optional `modelProviders` filtering.
- `thread/loaded/list` — list the thread ids currently loaded in memory.
- `thread/archive` — move a thread’s rollout file into the archived directory; returns `{}` on success.
- `thread/rollback` — drop the last N turns from the agent’s in-memory context and persist a rollback marker in the rollout so future resumes see the pruned history; returns the updated `thread` (with `turns` populated) on success.
- `turn/start` — add user input to a thread and begin Codex generation; responds with the initial `turn` object and streams `turn/started`, `item/*`, and `turn/completed` notifications.
@@ -88,6 +90,7 @@ Example (from OpenAI's official VSCode extension):
- `config/read` — fetch the effective config on disk after resolving config layering.
- `config/value/write` — write a single config key/value to the user's config.toml on disk.
- `config/batchWrite` — apply multiple config edits atomically to the user's config.toml on disk.
- `configRequirements/read` — fetch the loaded requirements allow-lists from `requirements.toml` and/or MDM (or `null` if none are configured).

### Example: Start or resume a thread

@@ -120,6 +123,14 @@ To continue a stored session, call `thread/resume` with the `thread.id` you prev
{ "id": 11, "result": { "thread": { "id": "thr_123", … } } }
```

To branch from a stored session, call `thread/fork` with the `thread.id`. This creates a new thread id and emits a `thread/started` notification for it:

```json
{ "method": "thread/fork", "id": 12, "params": { "threadId": "thr_123" } }
{ "id": 12, "result": { "thread": { "id": "thr_456", … } } }
{ "method": "thread/started", "params": { "thread": { … } } }
```

### Example: List threads (with pagination & filters)

`thread/list` lets you render a history UI. Pass any combination of:
@@ -146,6 +157,17 @@ Example:

When `nextCursor` is `null`, you’ve reached the final page.
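A hedged client-side sketch of that cursor loop (the `send_request` closure is assumed, not part of this diff); the same pattern applies to `thread/loaded/list` below:

```rust
// Accumulates every page from a cursor-paginated list method.
fn list_all_threads(
    mut send_request: impl FnMut(serde_json::Value) -> serde_json::Value,
) -> Vec<serde_json::Value> {
    let mut items = Vec::new();
    let mut cursor: Option<String> = None;
    loop {
        let result = send_request(serde_json::json!({ "cursor": cursor, "limit": 25 }));
        if let Some(page) = result["data"].as_array() {
            items.extend(page.iter().cloned());
        }
        match result["nextCursor"].as_str() {
            Some(next) => cursor = Some(next.to_string()),
            None => break, // null nextCursor: final page reached
        }
    }
    items
}
```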
### Example: List loaded threads

`thread/loaded/list` returns thread ids currently loaded in memory. This is useful when you want to check which sessions are active without scanning rollouts on disk.

```json
{ "method": "thread/loaded/list", "id": 21 }
{ "id": 21, "result": {
  "data": ["thr_123", "thr_456"]
} }
```

### Example: Archive a thread

Use `thread/archive` to move the persisted rollout (stored as a JSONL file on disk) into the archived sessions directory.
@@ -200,13 +222,14 @@ You can optionally specify config overrides on the new turn. If specified, these

### Example: Start a turn (invoke a skill)

-Invoke a skill by sending a text input that begins with `$<skill-name>`.
+Invoke a skill explicitly by including `$<skill-name>` in the text input and adding a `skill` input item alongside it.

```json
{ "method": "turn/start", "id": 33, "params": {
  "threadId": "thr_123",
  "input": [
-    { "type": "text", "text": "$skill-creator Add a new skill for triaging flaky CI and include step-by-step usage." }
+    { "type": "text", "text": "$skill-creator Add a new skill for triaging flaky CI and include step-by-step usage." },
+    { "type": "skill", "name": "skill-creator", "path": "/Users/me/.codex/skills/skill-creator/SKILL.md" }
  ]
} }
{ "id": 33, "result": { "turn": {
@@ -428,7 +451,23 @@ UI guidance for IDEs: surface an approval dialog as soon as the request arrives.

## Skills

-Skills are invoked by sending a text input that starts with `$<skill-name>`. The rest of the text is passed to the skill as its input.
+Invoke a skill by including `$<skill-name>` in the text input. Add a `skill` input item (recommended) so the backend injects full skill instructions instead of relying on the model to resolve the name.

```json
{
  "method": "turn/start",
  "id": 101,
  "params": {
    "threadId": "thread-1",
    "input": [
      { "type": "text", "text": "$skill-creator Add a new skill for triaging flaky CI." },
      { "type": "skill", "name": "skill-creator", "path": "/Users/me/.codex/skills/skill-creator/SKILL.md" }
    ]
  }
}
```

If you omit the `skill` item, the model will still parse the `$<skill-name>` marker and try to locate the skill, which can add latency.

Example:
@@ -28,6 +28,8 @@ use codex_app_server_protocol::ConversationSummary;
use codex_app_server_protocol::ExecOneOffCommandResponse;
use codex_app_server_protocol::FeedbackUploadParams;
use codex_app_server_protocol::FeedbackUploadResponse;
use codex_app_server_protocol::ForkConversationParams;
use codex_app_server_protocol::ForkConversationResponse;
use codex_app_server_protocol::FuzzyFileSearchParams;
use codex_app_server_protocol::FuzzyFileSearchResponse;
use codex_app_server_protocol::GetAccountParams;
@@ -86,9 +88,13 @@ use codex_app_server_protocol::SkillsListResponse;
use codex_app_server_protocol::Thread;
use codex_app_server_protocol::ThreadArchiveParams;
use codex_app_server_protocol::ThreadArchiveResponse;
use codex_app_server_protocol::ThreadForkParams;
use codex_app_server_protocol::ThreadForkResponse;
use codex_app_server_protocol::ThreadItem;
use codex_app_server_protocol::ThreadListParams;
use codex_app_server_protocol::ThreadListResponse;
use codex_app_server_protocol::ThreadLoadedListParams;
use codex_app_server_protocol::ThreadLoadedListResponse;
use codex_app_server_protocol::ThreadResumeParams;
use codex_app_server_protocol::ThreadResumeResponse;
use codex_app_server_protocol::ThreadRollbackParams;
@@ -124,6 +130,7 @@ use codex_core::config::ConfigService;
use codex_core::config::edit::ConfigEditsBuilder;
use codex_core::config::types::McpServerTransportConfig;
use codex_core::default_client::get_codex_user_agent;
use codex_core::error::CodexErr;
use codex_core::exec::ExecParams;
use codex_core::exec_env::create_env;
use codex_core::features::Feature;
@@ -367,6 +374,9 @@ impl CodexMessageProcessor {
            ClientRequest::ThreadResume { request_id, params } => {
                self.thread_resume(request_id, params).await;
            }
            ClientRequest::ThreadFork { request_id, params } => {
                self.thread_fork(request_id, params).await;
            }
            ClientRequest::ThreadArchive { request_id, params } => {
                self.thread_archive(request_id, params).await;
            }
@@ -376,6 +386,9 @@ impl CodexMessageProcessor {
            ClientRequest::ThreadList { request_id, params } => {
                self.thread_list(request_id, params).await;
            }
            ClientRequest::ThreadLoadedList { request_id, params } => {
                self.thread_loaded_list(request_id, params).await;
            }
            ClientRequest::SkillsList { request_id, params } => {
                self.skills_list(request_id, params).await;
            }
@@ -433,6 +446,9 @@ impl CodexMessageProcessor {
            ClientRequest::ResumeConversation { request_id, params } => {
                self.handle_resume_conversation(request_id, params).await;
            }
            ClientRequest::ForkConversation { request_id, params } => {
                self.handle_fork_conversation(request_id, params).await;
            }
            ClientRequest::ArchiveConversation { request_id, params } => {
                self.archive_conversation(request_id, params).await;
            }
@@ -510,6 +526,9 @@ impl CodexMessageProcessor {
            | ClientRequest::ConfigBatchWrite { .. } => {
                warn!("Config request reached CodexMessageProcessor unexpectedly");
            }
            ClientRequest::ConfigRequirementsRead { .. } => {
                warn!("ConfigRequirementsRead request reached CodexMessageProcessor unexpectedly");
            }
            ClientRequest::GetAccountRateLimits {
                request_id,
                params: _,
@@ -582,7 +601,7 @@ impl CodexMessageProcessor {
            .await;

        let payload = AuthStatusChangeNotification {
-            auth_method: self.auth_manager.auth().map(|auth| auth.mode),
+            auth_method: self.auth_manager.auth_cached().map(|auth| auth.mode),
        };
        self.outgoing
            .send_server_notification(ServerNotification::AuthStatusChange(payload))
@@ -612,7 +631,7 @@ impl CodexMessageProcessor {
            .await;

        let payload_v2 = AccountUpdatedNotification {
-            auth_mode: self.auth_manager.auth().map(|auth| auth.mode),
+            auth_mode: self.auth_manager.auth_cached().map(|auth| auth.mode),
        };
        self.outgoing
            .send_server_notification(ServerNotification::AccountUpdated(payload_v2))
@@ -704,7 +723,7 @@ impl CodexMessageProcessor {
        auth_manager.reload();

        // Notify clients with the actual current auth mode.
-        let current_auth_method = auth_manager.auth().map(|a| a.mode);
+        let current_auth_method = auth_manager.auth_cached().map(|a| a.mode);
        let payload = AuthStatusChangeNotification {
            auth_method: current_auth_method,
        };
@@ -794,7 +813,7 @@ impl CodexMessageProcessor {
        auth_manager.reload();

        // Notify clients with the actual current auth mode.
-        let current_auth_method = auth_manager.auth().map(|a| a.mode);
+        let current_auth_method = auth_manager.auth_cached().map(|a| a.mode);
        let payload_v2 = AccountUpdatedNotification {
            auth_mode: current_auth_method,
        };
@@ -906,7 +925,7 @@ impl CodexMessageProcessor {
        }

        // Reflect the current auth method after logout (likely None).
-        Ok(self.auth_manager.auth().map(|auth| auth.mode))
+        Ok(self.auth_manager.auth_cached().map(|auth| auth.mode))
    }

    async fn logout_v1(&mut self, request_id: RequestId) {
@@ -973,10 +992,10 @@ impl CodexMessageProcessor {
                requires_openai_auth: Some(false),
            }
        } else {
-            match self.auth_manager.auth() {
+            match self.auth_manager.auth().await {
                Some(auth) => {
                    let auth_mode = auth.mode;
-                    let (reported_auth_method, token_opt) = match auth.get_token().await {
+                    let (reported_auth_method, token_opt) = match auth.get_token() {
                        Ok(token) if !token.is_empty() => {
                            let tok = if include_token { Some(token) } else { None };
                            (Some(auth_mode), tok)
@@ -1021,7 +1040,7 @@ impl CodexMessageProcessor {
            return;
        }

-        let account = match self.auth_manager.auth() {
+        let account = match self.auth_manager.auth_cached() {
            Some(auth) => Some(match auth.mode {
                AuthMode::ApiKey => Account::ApiKey {},
                AuthMode::ChatGPT => {
@@ -1075,7 +1094,7 @@ impl CodexMessageProcessor {
    }

    async fn fetch_account_rate_limits(&self) -> Result<CoreRateLimitSnapshot, JSONRPCErrorError> {
-        let Some(auth) = self.auth_manager.auth() else {
+        let Some(auth) = self.auth_manager.auth().await else {
            return Err(JSONRPCErrorError {
                code: INVALID_REQUEST_ERROR_CODE,
                message: "codex account authentication required to read rate limits".to_string(),
@@ -1092,7 +1111,6 @@ impl CodexMessageProcessor {
        }

        let client = BackendClient::from_auth(self.config.chatgpt_base_url.clone(), &auth)
            .await
            .map_err(|err| JSONRPCErrorError {
                code: INTERNAL_ERROR_CODE,
                message: format!("failed to construct backend client: {err}"),
@@ -1132,7 +1150,10 @@ impl CodexMessageProcessor {

    async fn get_user_info(&self, request_id: RequestId) {
        // Read alleged user email from cached auth (best-effort; not verified).
-        let alleged_user_email = self.auth_manager.auth().and_then(|a| a.get_account_email());
+        let alleged_user_email = self
+            .auth_manager
+            .auth_cached()
+            .and_then(|a| a.get_account_email());

        let response = UserInfoResponse { alleged_user_email };
        self.outgoing.send_response(request_id, response).await;
@@ -1588,6 +1609,61 @@ impl CodexMessageProcessor {
        self.outgoing.send_response(request_id, response).await;
    }

    async fn thread_loaded_list(&self, request_id: RequestId, params: ThreadLoadedListParams) {
        let ThreadLoadedListParams { cursor, limit } = params;
        let mut data = self
            .thread_manager
            .list_thread_ids()
            .await
            .into_iter()
            .map(|thread_id| thread_id.to_string())
            .collect::<Vec<_>>();

        if data.is_empty() {
            let response = ThreadLoadedListResponse {
                data,
                next_cursor: None,
            };
            self.outgoing.send_response(request_id, response).await;
            return;
        }

        data.sort();
        let total = data.len();
        let start = match cursor {
            Some(cursor) => {
                let cursor = match ThreadId::from_string(&cursor) {
                    Ok(id) => id.to_string(),
                    Err(_) => {
                        let error = JSONRPCErrorError {
                            code: INVALID_REQUEST_ERROR_CODE,
                            message: format!("invalid cursor: {cursor}"),
                            data: None,
                        };
                        self.outgoing.send_error(request_id, error).await;
                        return;
                    }
                };
                match data.binary_search(&cursor) {
                    Ok(idx) => idx + 1,
                    Err(idx) => idx,
                }
            }
            None => 0,
        };

        let effective_limit = limit.unwrap_or(total as u32).max(1) as usize;
        let end = start.saturating_add(effective_limit).min(total);
        let page = data[start..end].to_vec();
        let next_cursor = page.last().filter(|_| end < total).cloned();

        let response = ThreadLoadedListResponse {
            data: page,
            next_cursor,
        };
        self.outgoing.send_response(request_id, response).await;
    }

    async fn thread_resume(&mut self, request_id: RequestId, params: ThreadResumeParams) {
        let ThreadResumeParams {
            thread_id,
@@ -1793,6 +1869,198 @@ impl CodexMessageProcessor {
        }
    }
    async fn thread_fork(&mut self, request_id: RequestId, params: ThreadForkParams) {
        let ThreadForkParams {
            thread_id,
            path,
            model,
            model_provider,
            cwd,
            approval_policy,
            sandbox,
            config: cli_overrides,
            base_instructions,
            developer_instructions,
        } = params;

        let overrides_requested = model.is_some()
            || model_provider.is_some()
            || cwd.is_some()
            || approval_policy.is_some()
            || sandbox.is_some()
            || cli_overrides.is_some()
            || base_instructions.is_some()
            || developer_instructions.is_some();

        let config = if overrides_requested {
            let overrides = self.build_thread_config_overrides(
                model,
                model_provider,
                cwd,
                approval_policy,
                sandbox,
                base_instructions,
                developer_instructions,
            );

            // Persist windows sandbox feature.
            let mut cli_overrides = cli_overrides.unwrap_or_default();
            if cfg!(windows) && self.config.features.enabled(Feature::WindowsSandbox) {
                cli_overrides.insert(
                    "features.experimental_windows_sandbox".to_string(),
                    serde_json::json!(true),
                );
            }

            match derive_config_from_params(&self.cli_overrides, Some(cli_overrides), overrides)
                .await
            {
                Ok(config) => config,
                Err(err) => {
                    let error = JSONRPCErrorError {
                        code: INVALID_REQUEST_ERROR_CODE,
                        message: format!("error deriving config: {err}"),
                        data: None,
                    };
                    self.outgoing.send_error(request_id, error).await;
                    return;
                }
            }
        } else {
            self.config.as_ref().clone()
        };

        let rollout_path = if let Some(path) = path {
            path
        } else {
            let existing_thread_id = match ThreadId::from_string(&thread_id) {
                Ok(id) => id,
                Err(err) => {
                    let error = JSONRPCErrorError {
                        code: INVALID_REQUEST_ERROR_CODE,
                        message: format!("invalid thread id: {err}"),
                        data: None,
                    };
                    self.outgoing.send_error(request_id, error).await;
                    return;
                }
            };

            match find_thread_path_by_id_str(
                &self.config.codex_home,
                &existing_thread_id.to_string(),
            )
            .await
            {
                Ok(Some(p)) => p,
                Ok(None) => {
                    self.send_invalid_request_error(
                        request_id,
                        format!("no rollout found for thread id {existing_thread_id}"),
                    )
                    .await;
                    return;
                }
                Err(err) => {
                    self.send_invalid_request_error(
                        request_id,
                        format!("failed to locate thread id {existing_thread_id}: {err}"),
                    )
                    .await;
                    return;
                }
            }
        };

        let fallback_model_provider = config.model_provider_id.clone();

        let NewThread {
            thread_id,
            session_configured,
            ..
        } = match self
            .thread_manager
            .fork_thread(usize::MAX, config, rollout_path.clone())
            .await
        {
            Ok(thread) => thread,
            Err(err) => {
                let (code, message) = match err {
                    CodexErr::Io(_) | CodexErr::Json(_) => (
                        INVALID_REQUEST_ERROR_CODE,
                        format!("failed to load rollout `{}`: {err}", rollout_path.display()),
                    ),
                    CodexErr::InvalidRequest(message) => (INVALID_REQUEST_ERROR_CODE, message),
                    _ => (INTERNAL_ERROR_CODE, format!("error forking thread: {err}")),
                };
                let error = JSONRPCErrorError {
                    code,
                    message,
                    data: None,
                };
                self.outgoing.send_error(request_id, error).await;
                return;
            }
        };

        let SessionConfiguredEvent {
            rollout_path,
            initial_messages,
            ..
        } = session_configured;
        // Auto-attach a conversation listener when forking a thread.
        if let Err(err) = self
            .attach_conversation_listener(thread_id, false, ApiVersion::V2)
            .await
        {
            tracing::warn!(
                "failed to attach listener for thread {}: {}",
                thread_id,
                err.message
            );
        }

        let mut thread = match read_summary_from_rollout(
            rollout_path.as_path(),
            fallback_model_provider.as_str(),
        )
        .await
        {
            Ok(summary) => summary_to_thread(summary),
            Err(err) => {
                self.send_internal_error(
                    request_id,
                    format!(
                        "failed to load rollout `{}` for thread {thread_id}: {err}",
                        rollout_path.display()
                    ),
                )
                .await;
                return;
            }
        };
        thread.turns = initial_messages
            .as_deref()
            .map_or_else(Vec::new, build_turns_from_event_msgs);

        let response = ThreadForkResponse {
            thread: thread.clone(),
            model: session_configured.model,
            model_provider: session_configured.model_provider_id,
            cwd: session_configured.cwd,
            approval_policy: session_configured.approval_policy.into(),
            sandbox: session_configured.sandbox_policy.into(),
            reasoning_effort: session_configured.reasoning_effort,
        };

        self.outgoing.send_response(request_id, response).await;

        let notif = ThreadStartedNotification { thread };
        self.outgoing
            .send_server_notification(ServerNotification::ThreadStarted(notif))
            .await;
    }

    async fn get_thread_summary(
        &self,
        request_id: RequestId,
@@ -2416,6 +2684,166 @@ impl CodexMessageProcessor {
        }
    }
    async fn handle_fork_conversation(
        &self,
        request_id: RequestId,
        params: ForkConversationParams,
    ) {
        let ForkConversationParams {
            path,
            conversation_id,
            overrides,
        } = params;

        // Derive a Config using the same logic as new conversation, honoring overrides if provided.
        let config = match overrides {
            Some(overrides) => {
                let NewConversationParams {
                    model,
                    model_provider,
                    profile,
                    cwd,
                    approval_policy,
                    sandbox: sandbox_mode,
                    config: cli_overrides,
                    base_instructions,
                    developer_instructions,
                    compact_prompt,
                    include_apply_patch_tool,
                } = overrides;

                // Persist windows sandbox feature.
                let mut cli_overrides = cli_overrides.unwrap_or_default();
                if cfg!(windows) && self.config.features.enabled(Feature::WindowsSandbox) {
                    cli_overrides.insert(
                        "features.experimental_windows_sandbox".to_string(),
                        serde_json::json!(true),
                    );
                }

                let overrides = ConfigOverrides {
                    model,
                    config_profile: profile,
                    cwd: cwd.map(PathBuf::from),
                    approval_policy,
                    sandbox_mode,
                    model_provider,
                    codex_linux_sandbox_exe: self.codex_linux_sandbox_exe.clone(),
                    base_instructions,
                    developer_instructions,
                    compact_prompt,
                    include_apply_patch_tool,
                    ..Default::default()
                };

                derive_config_from_params(&self.cli_overrides, Some(cli_overrides), overrides).await
            }
            None => Ok(self.config.as_ref().clone()),
        };
        let config = match config {
            Ok(cfg) => cfg,
            Err(err) => {
                self.send_invalid_request_error(
                    request_id,
                    format!("error deriving config: {err}"),
                )
                .await;
                return;
            }
        };

        let rollout_path = if let Some(path) = path {
            path
        } else if let Some(conversation_id) = conversation_id {
            match find_thread_path_by_id_str(&self.config.codex_home, &conversation_id.to_string())
                .await
            {
                Ok(Some(found_path)) => found_path,
                Ok(None) => {
                    self.send_invalid_request_error(
                        request_id,
                        format!("no rollout found for conversation id {conversation_id}"),
                    )
                    .await;
                    return;
                }
                Err(err) => {
                    self.send_invalid_request_error(
                        request_id,
                        format!("failed to locate conversation id {conversation_id}: {err}"),
                    )
                    .await;
                    return;
                }
            }
        } else {
            self.send_invalid_request_error(
                request_id,
                "either path or conversation id must be provided".to_string(),
            )
            .await;
            return;
        };

        let NewThread {
            thread_id,
            session_configured,
            ..
        } = match self
            .thread_manager
            .fork_thread(usize::MAX, config, rollout_path.clone())
            .await
        {
            Ok(thread) => thread,
            Err(err) => {
                let (code, message) = match err {
                    CodexErr::Io(_) | CodexErr::Json(_) => (
                        INVALID_REQUEST_ERROR_CODE,
                        format!("failed to load rollout `{}`: {err}", rollout_path.display()),
                    ),
                    CodexErr::InvalidRequest(message) => (INVALID_REQUEST_ERROR_CODE, message),
                    _ => (
                        INTERNAL_ERROR_CODE,
                        format!("error forking conversation: {err}"),
                    ),
                };
                let error = JSONRPCErrorError {
                    code,
                    message,
                    data: None,
                };
                self.outgoing.send_error(request_id, error).await;
                return;
            }
        };

        self.outgoing
            .send_server_notification(ServerNotification::SessionConfigured(
                SessionConfiguredNotification {
                    session_id: session_configured.session_id,
                    model: session_configured.model.clone(),
                    reasoning_effort: session_configured.reasoning_effort,
                    history_log_id: session_configured.history_log_id,
                    history_entry_count: session_configured.history_entry_count,
                    initial_messages: session_configured.initial_messages.clone(),
                    rollout_path: session_configured.rollout_path.clone(),
                },
            ))
            .await;
        let initial_messages = session_configured
            .initial_messages
            .map(|msgs| msgs.into_iter().collect());

        // Reply with conversation id + model and initial messages (when present)
        let response = ForkConversationResponse {
            conversation_id: thread_id,
            model: session_configured.model.clone(),
            initial_messages,
            rollout_path: session_configured.rollout_path.clone(),
        };
        self.outgoing.send_response(request_id, response).await;
    }

    async fn send_invalid_request_error(&self, request_id: RequestId, message: String) {
        let error = JSONRPCErrorError {
            code: INVALID_REQUEST_ERROR_CODE,
@@ -3,13 +3,18 @@ use crate::error_code::INVALID_REQUEST_ERROR_CODE;
use codex_app_server_protocol::ConfigBatchWriteParams;
use codex_app_server_protocol::ConfigReadParams;
use codex_app_server_protocol::ConfigReadResponse;
use codex_app_server_protocol::ConfigRequirements;
use codex_app_server_protocol::ConfigRequirementsReadResponse;
use codex_app_server_protocol::ConfigValueWriteParams;
use codex_app_server_protocol::ConfigWriteErrorCode;
use codex_app_server_protocol::ConfigWriteResponse;
use codex_app_server_protocol::JSONRPCErrorError;
use codex_app_server_protocol::SandboxMode;
use codex_core::config::ConfigService;
use codex_core::config::ConfigServiceError;
use codex_core::config_loader::ConfigRequirementsToml;
use codex_core::config_loader::LoaderOverrides;
use codex_core::config_loader::SandboxModeRequirement as CoreSandboxModeRequirement;
use serde_json::json;
use std::path::PathBuf;
use toml::Value as TomlValue;
@@ -37,6 +42,19 @@ impl ConfigApi {
        self.service.read(params).await.map_err(map_error)
    }

    pub(crate) async fn config_requirements_read(
        &self,
    ) -> Result<ConfigRequirementsReadResponse, JSONRPCErrorError> {
        let requirements = self
            .service
            .read_requirements()
            .await
            .map_err(map_error)?
            .map(map_requirements_toml_to_api);

        Ok(ConfigRequirementsReadResponse { requirements })
    }

    pub(crate) async fn write_value(
        &self,
        params: ConfigValueWriteParams,
@@ -52,6 +70,32 @@ impl ConfigApi {
    }
}

fn map_requirements_toml_to_api(requirements: ConfigRequirementsToml) -> ConfigRequirements {
    ConfigRequirements {
        allowed_approval_policies: requirements.allowed_approval_policies.map(|policies| {
            policies
                .into_iter()
                .map(codex_app_server_protocol::AskForApproval::from)
                .collect()
        }),
        allowed_sandbox_modes: requirements.allowed_sandbox_modes.map(|modes| {
            modes
                .into_iter()
                .filter_map(map_sandbox_mode_requirement_to_api)
                .collect()
        }),
    }
}

fn map_sandbox_mode_requirement_to_api(mode: CoreSandboxModeRequirement) -> Option<SandboxMode> {
    match mode {
        CoreSandboxModeRequirement::ReadOnly => Some(SandboxMode::ReadOnly),
        CoreSandboxModeRequirement::WorkspaceWrite => Some(SandboxMode::WorkspaceWrite),
        CoreSandboxModeRequirement::DangerFullAccess => Some(SandboxMode::DangerFullAccess),
        CoreSandboxModeRequirement::ExternalSandbox => None,
    }
}

fn map_error(err: ConfigServiceError) -> JSONRPCErrorError {
    if let Some(code) = err.write_error_code() {
        return config_write_error(code, err.to_string());
@@ -73,3 +117,38 @@ fn config_write_error(code: ConfigWriteErrorCode, message: impl Into<String>) ->
        })),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use codex_protocol::protocol::AskForApproval as CoreAskForApproval;
    use pretty_assertions::assert_eq;

    #[test]
    fn map_requirements_toml_to_api_converts_core_enums() {
        let requirements = ConfigRequirementsToml {
            allowed_approval_policies: Some(vec![
                CoreAskForApproval::Never,
                CoreAskForApproval::OnRequest,
            ]),
            allowed_sandbox_modes: Some(vec![
                CoreSandboxModeRequirement::ReadOnly,
                CoreSandboxModeRequirement::ExternalSandbox,
            ]),
        };

        let mapped = map_requirements_toml_to_api(requirements);

        assert_eq!(
            mapped.allowed_approval_policies,
            Some(vec![
                codex_app_server_protocol::AskForApproval::Never,
                codex_app_server_protocol::AskForApproval::OnRequest,
            ])
        );
        assert_eq!(
            mapped.allowed_sandbox_modes,
            Some(vec![SandboxMode::ReadOnly]),
        );
    }
}
@@ -158,6 +158,12 @@ impl MessageProcessor {
            ClientRequest::ConfigBatchWrite { request_id, params } => {
                self.handle_config_batch_write(request_id, params).await;
            }
            ClientRequest::ConfigRequirementsRead {
                request_id,
                params: _,
            } => {
                self.handle_config_requirements_read(request_id).await;
            }
            other => {
                self.codex_message_processor.process_request(other).await;
            }
@@ -210,4 +216,11 @@ impl MessageProcessor {
            Err(error) => self.outgoing.send_error(request_id, error).await,
        }
    }

    async fn handle_config_requirements_read(&self, request_id: RequestId) {
        match self.config_api.config_requirements_read().await {
            Ok(response) => self.outgoing.send_response(request_id, response).await,
            Err(error) => self.outgoing.send_error(request_id, error).await,
        }
    }
}
@@ -18,8 +18,9 @@ pub use core_test_support::test_path_buf_with_windows;
pub use core_test_support::test_tmp_path;
pub use core_test_support::test_tmp_path_buf;
pub use mcp_process::McpProcess;
pub use mock_model_server::create_mock_chat_completions_server;
pub use mock_model_server::create_mock_chat_completions_server_unchecked;
pub use mock_model_server::create_mock_responses_server_repeating_assistant;
pub use mock_model_server::create_mock_responses_server_sequence;
pub use mock_model_server::create_mock_responses_server_sequence_unchecked;
pub use models_cache::write_models_cache;
pub use models_cache::write_models_cache_with_models;
pub use responses::create_apply_patch_sse_response;
@@ -21,6 +21,7 @@ use codex_app_server_protocol::ConfigBatchWriteParams;
use codex_app_server_protocol::ConfigReadParams;
use codex_app_server_protocol::ConfigValueWriteParams;
use codex_app_server_protocol::FeedbackUploadParams;
use codex_app_server_protocol::ForkConversationParams;
use codex_app_server_protocol::GetAccountParams;
use codex_app_server_protocol::GetAuthStatusParams;
use codex_app_server_protocol::InitializeParams;
@@ -43,7 +44,9 @@ use codex_app_server_protocol::SendUserTurnParams;
use codex_app_server_protocol::ServerRequest;
use codex_app_server_protocol::SetDefaultModelParams;
use codex_app_server_protocol::ThreadArchiveParams;
use codex_app_server_protocol::ThreadForkParams;
use codex_app_server_protocol::ThreadListParams;
use codex_app_server_protocol::ThreadLoadedListParams;
use codex_app_server_protocol::ThreadResumeParams;
use codex_app_server_protocol::ThreadRollbackParams;
use codex_app_server_protocol::ThreadStartParams;
@@ -60,7 +63,7 @@ pub struct McpProcess {
    process: Child,
    stdin: ChildStdin,
    stdout: BufReader<ChildStdout>,
-    pending_user_messages: VecDeque<JSONRPCNotification>,
+    pending_messages: VecDeque<JSONRPCMessage>,
}

impl McpProcess {
@@ -127,7 +130,7 @@ impl McpProcess {
            process,
            stdin,
            stdout,
-            pending_user_messages: VecDeque::new(),
+            pending_messages: VecDeque::new(),
        })
    }
@@ -308,6 +311,15 @@ impl McpProcess {
        self.send_request("thread/resume", params).await
    }

    /// Send a `thread/fork` JSON-RPC request.
    pub async fn send_thread_fork_request(
        &mut self,
        params: ThreadForkParams,
    ) -> anyhow::Result<i64> {
        let params = Some(serde_json::to_value(params)?);
        self.send_request("thread/fork", params).await
    }

    /// Send a `thread/archive` JSON-RPC request.
    pub async fn send_thread_archive_request(
        &mut self,
@@ -335,6 +347,15 @@ impl McpProcess {
        self.send_request("thread/list", params).await
    }

    /// Send a `thread/loaded/list` JSON-RPC request.
    pub async fn send_thread_loaded_list_request(
        &mut self,
        params: ThreadLoadedListParams,
    ) -> anyhow::Result<i64> {
        let params = Some(serde_json::to_value(params)?);
        self.send_request("thread/loaded/list", params).await
    }

    /// Send a `model/list` JSON-RPC request.
    pub async fn send_list_models_request(
        &mut self,
@@ -353,6 +374,15 @@ impl McpProcess {
        self.send_request("resumeConversation", params).await
    }

    /// Send a `forkConversation` JSON-RPC request.
    pub async fn send_fork_conversation_request(
        &mut self,
        params: ForkConversationParams,
    ) -> anyhow::Result<i64> {
        let params = Some(serde_json::to_value(params)?);
        self.send_request("forkConversation", params).await
    }

    /// Send a `loginApiKey` JSON-RPC request.
    pub async fn send_login_api_key_request(
        &mut self,
@@ -544,27 +574,16 @@ impl McpProcess {
    pub async fn read_stream_until_request_message(&mut self) -> anyhow::Result<ServerRequest> {
        eprintln!("in read_stream_until_request_message()");

-        loop {
-            let message = self.read_jsonrpc_message().await?;
+        let message = self
+            .read_stream_until_message(|message| matches!(message, JSONRPCMessage::Request(_)))
+            .await?;

-            match message {
-                JSONRPCMessage::Notification(notification) => {
-                    eprintln!("notification: {notification:?}");
-                    self.enqueue_user_message(notification);
-                }
-                JSONRPCMessage::Request(jsonrpc_request) => {
-                    return jsonrpc_request.try_into().with_context(
-                        || "failed to deserialize ServerRequest from JSONRPCRequest",
-                    );
-                }
-                JSONRPCMessage::Error(_) => {
-                    anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}");
-                }
-                JSONRPCMessage::Response(_) => {
-                    anyhow::bail!("unexpected JSONRPCMessage::Response: {message:?}");
-                }
-            }
-        }
+        let JSONRPCMessage::Request(jsonrpc_request) = message else {
+            unreachable!("expected JSONRPCMessage::Request, got {message:?}");
+        };
+        jsonrpc_request
+            .try_into()
+            .with_context(|| "failed to deserialize ServerRequest from JSONRPCRequest")
    }

    pub async fn read_stream_until_response_message(
@@ -573,52 +592,32 @@
    ) -> anyhow::Result<JSONRPCResponse> {
        eprintln!("in read_stream_until_response_message({request_id:?})");

-        loop {
-            let message = self.read_jsonrpc_message().await?;
-            match message {
-                JSONRPCMessage::Notification(notification) => {
-                    eprintln!("notification: {notification:?}");
-                    self.enqueue_user_message(notification);
-                }
-                JSONRPCMessage::Request(_) => {
-                    anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}");
-                }
-                JSONRPCMessage::Error(_) => {
-                    anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}");
-                }
-                JSONRPCMessage::Response(jsonrpc_response) => {
-                    if jsonrpc_response.id == request_id {
-                        return Ok(jsonrpc_response);
-                    }
-                }
-            }
-        }
+        let message = self
+            .read_stream_until_message(|message| {
+                Self::message_request_id(message) == Some(&request_id)
+            })
+            .await?;
+
+        let JSONRPCMessage::Response(response) = message else {
+            unreachable!("expected JSONRPCMessage::Response, got {message:?}");
+        };
+        Ok(response)
    }

    pub async fn read_stream_until_error_message(
        &mut self,
        request_id: RequestId,
    ) -> anyhow::Result<JSONRPCError> {
-        loop {
-            let message = self.read_jsonrpc_message().await?;
-            match message {
-                JSONRPCMessage::Notification(notification) => {
-                    eprintln!("notification: {notification:?}");
-                    self.enqueue_user_message(notification);
-                }
-                JSONRPCMessage::Request(_) => {
-                    anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}");
-                }
-                JSONRPCMessage::Response(_) => {
-                    // Keep scanning; we're waiting for an error with matching id.
-                }
-                JSONRPCMessage::Error(err) => {
-                    if err.id == request_id {
-                        return Ok(err);
-                    }
-                }
-            }
-        }
+        let message = self
+            .read_stream_until_message(|message| {
+                Self::message_request_id(message) == Some(&request_id)
+            })
+            .await?;
+
+        let JSONRPCMessage::Error(err) = message else {
+            unreachable!("expected JSONRPCMessage::Error, got {message:?}");
+        };
+        Ok(err)
    }

    pub async fn read_stream_until_notification_message(
@@ -627,46 +626,64 @@
    ) -> anyhow::Result<JSONRPCNotification> {
        eprintln!("in read_stream_until_notification_message({method})");

-        if let Some(notification) = self.take_pending_notification_by_method(method) {
-            return Ok(notification);
+        let message = self
+            .read_stream_until_message(|message| {
+                matches!(
+                    message,
+                    JSONRPCMessage::Notification(notification) if notification.method == method
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
|
||||
let JSONRPCMessage::Notification(notification) = message else {
|
||||
unreachable!("expected JSONRPCMessage::Notification, got {message:?}");
|
||||
};
|
||||
Ok(notification)
|
||||
}
|
||||
|
||||
/// Clears any buffered messages so future reads only consider new stream items.
|
||||
///
|
||||
/// We call this when e.g. we want to validate against the next turn and no longer care about
|
||||
/// messages buffered from the prior turn.
|
||||
pub fn clear_message_buffer(&mut self) {
|
||||
self.pending_messages.clear();
|
||||
}
|
||||
|
||||
/// Reads the stream until a message matches `predicate`, buffering any non-matching messages
|
||||
/// for later reads.
|
||||
async fn read_stream_until_message<F>(&mut self, predicate: F) -> anyhow::Result<JSONRPCMessage>
|
||||
where
|
||||
F: Fn(&JSONRPCMessage) -> bool,
|
||||
{
|
||||
if let Some(message) = self.take_pending_message(&predicate) {
|
||||
return Ok(message);
|
||||
}
|
||||
|
||||
loop {
|
||||
let message = self.read_jsonrpc_message().await?;
|
||||
match message {
|
||||
JSONRPCMessage::Notification(notification) => {
|
||||
if notification.method == method {
|
||||
return Ok(notification);
|
||||
}
|
||||
self.enqueue_user_message(notification);
|
||||
}
|
||||
JSONRPCMessage::Request(_) => {
|
||||
anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}");
|
||||
}
|
||||
JSONRPCMessage::Error(_) => {
|
||||
anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}");
|
||||
}
|
||||
JSONRPCMessage::Response(_) => {
|
||||
anyhow::bail!("unexpected JSONRPCMessage::Response: {message:?}");
|
||||
}
|
||||
if predicate(&message) {
|
||||
return Ok(message);
|
||||
}
|
||||
self.pending_messages.push_back(message);
|
||||
}
|
||||
}
|
||||
|
||||
fn take_pending_notification_by_method(&mut self, method: &str) -> Option<JSONRPCNotification> {
|
||||
if let Some(pos) = self
|
||||
.pending_user_messages
|
||||
.iter()
|
||||
.position(|notification| notification.method == method)
|
||||
{
|
||||
return self.pending_user_messages.remove(pos);
|
||||
fn take_pending_message<F>(&mut self, predicate: &F) -> Option<JSONRPCMessage>
|
||||
where
|
||||
F: Fn(&JSONRPCMessage) -> bool,
|
||||
{
|
||||
if let Some(pos) = self.pending_messages.iter().position(predicate) {
|
||||
return self.pending_messages.remove(pos);
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn enqueue_user_message(&mut self, notification: JSONRPCNotification) {
|
||||
if notification.method == "codex/event/user_message" {
|
||||
self.pending_user_messages.push_back(notification);
|
||||
fn message_request_id(message: &JSONRPCMessage) -> Option<&RequestId> {
|
||||
match message {
|
||||
JSONRPCMessage::Request(request) => Some(&request.id),
|
||||
JSONRPCMessage::Response(response) => Some(&response.id),
|
||||
JSONRPCMessage::Error(err) => Some(&err.id),
|
||||
JSONRPCMessage::Notification(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
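The refactor above replaces three bespoke read loops with one predicate-driven reader plus a message buffer. A minimal usage sketch, assuming a hypothetical test body (`demo` and its wiring are illustrative; the `McpProcess` methods are the ones added above):

```rust
// Non-matching messages read along the way are parked in `pending_messages`
// rather than dropped, so out-of-order arrivals no longer break tests.
async fn demo(mcp: &mut McpProcess) -> anyhow::Result<()> {
    let done = mcp
        .read_stream_until_notification_message("codex/event/task_complete")
        .await?;
    eprintln!("turn finished: {}", done.method);

    // Between turns, discard anything buffered from the prior turn.
    mcp.clear_message_buffer();
    Ok(())
}
```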
@@ -1,17 +1,18 @@
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;

use core_test_support::responses;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::Respond;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;
use wiremock::matchers::path_regex;

/// Create a mock server that will provide the responses, in order, for
/// requests to the `/v1/chat/completions` endpoint.
pub async fn create_mock_chat_completions_server(responses: Vec<String>) -> MockServer {
let server = MockServer::start().await;
/// requests to the `/v1/responses` endpoint.
pub async fn create_mock_responses_server_sequence(responses: Vec<String>) -> MockServer {
let server = responses::start_mock_server().await;

let num_calls = responses.len();
let seq_responder = SeqResponder {
@@ -20,7 +21,7 @@ pub async fn create_mock_chat_completions_server(responses: Vec<String>) -> Mock
};

Mock::given(method("POST"))
.and(path("/v1/chat/completions"))
.and(path_regex(".*/responses$"))
.respond_with(seq_responder)
.expect(num_calls as u64)
.mount(&server)
@@ -29,10 +30,10 @@ pub async fn create_mock_chat_completions_server(responses: Vec<String>) -> Mock
server
}

/// Same as `create_mock_chat_completions_server` but does not enforce an
/// Same as `create_mock_responses_server_sequence` but does not enforce an
/// expectation on the number of calls.
pub async fn create_mock_chat_completions_server_unchecked(responses: Vec<String>) -> MockServer {
let server = MockServer::start().await;
pub async fn create_mock_responses_server_sequence_unchecked(responses: Vec<String>) -> MockServer {
let server = responses::start_mock_server().await;

let seq_responder = SeqResponder {
num_calls: AtomicUsize::new(0),
@@ -40,7 +41,7 @@ pub async fn create_mock_chat_completions_server_unchecked(responses: Vec<String
};

Mock::given(method("POST"))
.and(path("/v1/chat/completions"))
.and(path_regex(".*/responses$"))
.respond_with(seq_responder)
.mount(&server)
.await;
@@ -57,10 +58,24 @@ impl Respond for SeqResponder {
fn respond(&self, _: &wiremock::Request) -> ResponseTemplate {
let call_num = self.num_calls.fetch_add(1, Ordering::SeqCst);
match self.responses.get(call_num) {
Some(response) => ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(response.clone(), "text/event-stream"),
Some(response) => responses::sse_response(response.clone()),
None => panic!("no response for {call_num}"),
}
}
}

/// Create a mock responses API server that returns the same assistant message for every request.
pub async fn create_mock_responses_server_repeating_assistant(message: &str) -> MockServer {
let server = responses::start_mock_server().await;
let body = responses::sse(vec![
responses::ev_response_created("resp-1"),
responses::ev_assistant_message("msg-1", message),
responses::ev_completed("resp-1"),
]);
Mock::given(method("POST"))
.and(path_regex(".*/responses$"))
.respond_with(responses::sse_response(body))
.mount(&server)
.await;
server
}

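For orientation, a sketch of how a test wires these helpers together (hypothetical snippet; both helper names come from the diff above, and the SSE bodies must now be Responses-API streams):

```rust
// One SSE body is played per POST to .../responses, in order; the checked
// variant also asserts the server saw exactly this many calls.
let server = create_mock_responses_server_sequence(vec![
    create_final_assistant_message_sse_response("first turn done")?,
    create_final_assistant_message_sse_response("second turn done")?,
])
.await;
// `server.uri()` is then written into the test config.toml as `base_url`.
```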
@@ -1,3 +1,4 @@
use core_test_support::responses;
use serde_json::json;
use std::path::Path;

@@ -14,85 +15,30 @@ pub fn create_shell_command_sse_response(
"workdir": workdir.map(|w| w.to_string_lossy()),
"timeout_ms": timeout_ms
}))?;
let tool_call = json!({
"choices": [
{
"delta": {
"tool_calls": [
{
"id": call_id,
"function": {
"name": "shell_command",
"arguments": tool_call_arguments
}
}
]
},
"finish_reason": "tool_calls"
}
]
});

let sse = format!(
"data: {}\n\ndata: DONE\n\n",
serde_json::to_string(&tool_call)?
);
Ok(sse)
Ok(responses::sse(vec![
responses::ev_response_created("resp-1"),
responses::ev_function_call(call_id, "shell_command", &tool_call_arguments),
responses::ev_completed("resp-1"),
]))
}

pub fn create_final_assistant_message_sse_response(message: &str) -> anyhow::Result<String> {
let assistant_message = json!({
"choices": [
{
"delta": {
"content": message
},
"finish_reason": "stop"
}
]
});

let sse = format!(
"data: {}\n\ndata: DONE\n\n",
serde_json::to_string(&assistant_message)?
);
Ok(sse)
Ok(responses::sse(vec![
responses::ev_response_created("resp-1"),
responses::ev_assistant_message("msg-1", message),
responses::ev_completed("resp-1"),
]))
}

pub fn create_apply_patch_sse_response(
patch_content: &str,
call_id: &str,
) -> anyhow::Result<String> {
// Use shell_command to call apply_patch with heredoc format
let command = format!("apply_patch <<'EOF'\n{patch_content}\nEOF");
let tool_call_arguments = serde_json::to_string(&json!({
"command": command
}))?;

let tool_call = json!({
"choices": [
{
"delta": {
"tool_calls": [
{
"id": call_id,
"function": {
"name": "shell_command",
"arguments": tool_call_arguments
}
}
]
},
"finish_reason": "tool_calls"
}
]
});

let sse = format!(
"data: {}\n\ndata: DONE\n\n",
serde_json::to_string(&tool_call)?
);
Ok(sse)
Ok(responses::sse(vec![
responses::ev_response_created("resp-1"),
responses::ev_apply_patch_shell_command_call_via_heredoc(call_id, patch_content),
responses::ev_completed("resp-1"),
]))
}

pub fn create_exec_command_sse_response(call_id: &str) -> anyhow::Result<String> {
@@ -108,28 +54,9 @@ pub fn create_exec_command_sse_response(call_id: &str) -> anyhow::Result<String>
"cmd": command.join(" "),
"yield_time_ms": 500
}))?;
let tool_call = json!({
"choices": [
{
"delta": {
"tool_calls": [
{
"id": call_id,
"function": {
"name": "exec_command",
"arguments": tool_call_arguments
}
}
]
},
"finish_reason": "tool_calls"
}
]
});

let sse = format!(
"data: {}\n\ndata: DONE\n\n",
serde_json::to_string(&tool_call)?
);
Ok(sse)
Ok(responses::sse(vec![
responses::ev_response_created("resp-1"),
responses::ev_function_call(call_id, "exec_command", &tool_call_arguments),
responses::ev_completed("resp-1"),
]))
}

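After this rewrite every builder emits the same three-event Responses stream instead of hand-rolled chat-completions `choices`/`delta` JSON. A sketch of the composed body (argument values are illustrative, the `responses::*` helpers are the ones used in the diff):

```rust
// created -> (function call or assistant message) -> completed
let body = responses::sse(vec![
    responses::ev_response_created("resp-1"),
    responses::ev_function_call("call-1", "exec_command", r#"{"cmd":"echo hi"}"#),
    responses::ev_completed("resp-1"),
]);
```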
@@ -37,7 +37,7 @@ model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "http://127.0.0.1:0/v1"
wire_api = "chat"
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
{requires_line}

@@ -1,7 +1,7 @@
use anyhow::Result;
use app_test_support::McpProcess;
use app_test_support::create_final_assistant_message_sse_response;
use app_test_support::create_mock_chat_completions_server;
use app_test_support::create_mock_responses_server_sequence;
use app_test_support::create_shell_command_sse_response;
use app_test_support::format_with_current_shell;
use app_test_support::to_response;
@@ -65,7 +65,7 @@ async fn test_codex_jsonrpc_conversation_flow() -> Result<()> {
)?,
create_final_assistant_message_sse_response("Enjoy your new git repo!")?,
];
let server = create_mock_chat_completions_server(responses).await;
let server = create_mock_responses_server_sequence(responses).await;
create_config_toml(&codex_home, &server.uri())?;

// Start MCP server and initialize.
@@ -197,7 +197,7 @@ async fn test_send_user_turn_changes_approval_policy_behavior() -> Result<()> {
)?,
create_final_assistant_message_sse_response("done 2")?,
];
let server = create_mock_chat_completions_server(responses).await;
let server = create_mock_responses_server_sequence(responses).await;
create_config_toml(&codex_home, &server.uri())?;

// Start MCP server and initialize.
@@ -363,7 +363,7 @@ async fn test_send_user_turn_updates_sandbox_and_cwd_between_turns() -> Result<(
)?,
create_final_assistant_message_sse_response("done second")?,
];
let server = create_mock_chat_completions_server(responses).await;
let server = create_mock_responses_server_sequence(responses).await;
create_config_toml(&codex_home, &server.uri())?;

let mut mcp = McpProcess::new(&codex_home).await?;
@@ -430,6 +430,7 @@ async fn test_send_user_turn_updates_sandbox_and_cwd_between_turns() -> Result<(
mcp.read_stream_until_notification_message("codex/event/task_complete"),
)
.await??;
mcp.clear_message_buffer();

let second_turn_id = mcp
.send_send_user_turn_request(SendUserTurnParams {
@@ -499,7 +500,7 @@ model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "chat"
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
"#

@@ -1,7 +1,6 @@
use anyhow::Result;
use app_test_support::McpProcess;
use app_test_support::create_final_assistant_message_sse_response;
use app_test_support::create_mock_chat_completions_server;
use app_test_support::to_response;
use codex_app_server_protocol::AddConversationListenerParams;
use codex_app_server_protocol::AddConversationSubscriptionResponse;
@@ -12,6 +11,7 @@ use codex_app_server_protocol::NewConversationResponse;
use codex_app_server_protocol::RequestId;
use codex_app_server_protocol::SendUserMessageParams;
use codex_app_server_protocol::SendUserMessageResponse;
use core_test_support::responses;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::path::Path;
@@ -23,8 +23,9 @@ const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_conversation_create_and_send_message_ok() -> Result<()> {
// Mock server – we won't strictly rely on it, but provide one to satisfy any model wiring.
let responses = vec![create_final_assistant_message_sse_response("Done")?];
let server = create_mock_chat_completions_server(responses).await;
let response_body = create_final_assistant_message_sse_response("Done")?;
let server = responses::start_mock_server().await;
let response_mock = responses::mount_sse_sequence(&server, vec![response_body]).await;

// Temporary Codex home with config pointing at the mock server.
let codex_home = TempDir::new()?;
@@ -86,32 +87,30 @@ async fn test_conversation_create_and_send_message_ok() -> Result<()> {
.await??;
let _ok: SendUserMessageResponse = to_response::<SendUserMessageResponse>(send_resp)?;

// avoid race condition by waiting for the mock server to receive the chat.completions request
// Avoid race condition by waiting for the mock server to receive the responses request.
let deadline = std::time::Instant::now() + DEFAULT_READ_TIMEOUT;
let requests = loop {
let requests = server.received_requests().await.unwrap_or_default();
let requests = response_mock.requests();
if !requests.is_empty() {
break requests;
}
if std::time::Instant::now() >= deadline {
panic!("mock server did not receive the chat.completions request in time");
panic!("mock server did not receive the responses request in time");
}
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
};

// Verify the outbound request body matches expectations for Chat Completions.
// Verify the outbound request body matches expectations for Responses.
let request = requests
.first()
.expect("mock server should have received at least one request");
let body = request.body_json::<serde_json::Value>()?;
let body = request.body_json();
assert_eq!(body["model"], json!("o3"));
assert!(body["stream"].as_bool().unwrap_or(false));
let messages = body["messages"]
.as_array()
.expect("messages should be array");
let last = messages.last().expect("at least one message");
assert_eq!(last["role"], json!("user"));
assert_eq!(last["content"], json!("Hello"));
let user_texts = request.message_input_texts("user");
assert!(
user_texts.iter().any(|text| text == "Hello"),
"expected user input to include Hello, got {user_texts:?}"
);

drop(server);
Ok(())
@@ -133,7 +132,7 @@ model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "chat"
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
"#

codex-rs/app-server/tests/suite/fork_thread.rs (new file, 140 lines)
@@ -0,0 +1,140 @@
use anyhow::Result;
use app_test_support::McpProcess;
use app_test_support::create_fake_rollout;
use app_test_support::to_response;
use codex_app_server_protocol::ForkConversationParams;
use codex_app_server_protocol::ForkConversationResponse;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::NewConversationParams; // reused for overrides shape
use codex_app_server_protocol::RequestId;
use codex_app_server_protocol::ServerNotification;
use codex_app_server_protocol::SessionConfiguredNotification;
use codex_core::protocol::EventMsg;
use pretty_assertions::assert_eq;
use tempfile::TempDir;
use tokio::time::timeout;

const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn fork_conversation_creates_new_rollout() -> Result<()> {
let codex_home = TempDir::new()?;

let preview = "Hello A";
let conversation_id = create_fake_rollout(
codex_home.path(),
"2025-01-02T12-00-00",
"2025-01-02T12:00:00Z",
preview,
Some("openai"),
None,
)?;

let original_path = codex_home
.path()
.join("sessions")
.join("2025")
.join("01")
.join("02")
.join(format!(
"rollout-2025-01-02T12-00-00-{conversation_id}.jsonl"
));
assert!(
original_path.exists(),
"expected original rollout to exist at {}",
original_path.display()
);
let original_contents = std::fs::read_to_string(&original_path)?;

let mut mcp = McpProcess::new(codex_home.path()).await?;
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;

let fork_req_id = mcp
.send_fork_conversation_request(ForkConversationParams {
path: Some(original_path.clone()),
conversation_id: None,
overrides: Some(NewConversationParams {
model: Some("o3".to_string()),
..Default::default()
}),
})
.await?;

// Expect a sessionConfigured notification for the forked session.
let notification: JSONRPCNotification = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_notification_message("sessionConfigured"),
)
.await??;
let session_configured: ServerNotification = notification.try_into()?;
let ServerNotification::SessionConfigured(SessionConfiguredNotification {
model,
session_id,
rollout_path,
initial_messages: session_initial_messages,
..
}) = session_configured
else {
unreachable!("expected sessionConfigured notification");
};

assert_eq!(model, "o3");
assert_ne!(
session_id.to_string(),
conversation_id,
"expected a new conversation id when forking"
);
assert_ne!(
rollout_path, original_path,
"expected a new rollout path when forking"
);
assert!(
rollout_path.exists(),
"expected forked rollout to exist at {}",
rollout_path.display()
);

let session_initial_messages =
session_initial_messages.expect("expected initial messages when forking from rollout");
match session_initial_messages.as_slice() {
[EventMsg::UserMessage(message)] => {
assert_eq!(message.message, preview);
}
other => panic!("unexpected initial messages from rollout fork: {other:#?}"),
}

// Then the response for forkConversation.
let fork_resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(fork_req_id)),
)
.await??;
let ForkConversationResponse {
conversation_id: forked_id,
model: forked_model,
initial_messages: response_initial_messages,
rollout_path: response_rollout_path,
} = to_response::<ForkConversationResponse>(fork_resp)?;

assert_eq!(forked_model, "o3");
assert_eq!(response_rollout_path, rollout_path);
assert_ne!(forked_id.to_string(), conversation_id);

let response_initial_messages =
response_initial_messages.expect("expected initial messages in fork response");
match response_initial_messages.as_slice() {
[EventMsg::UserMessage(message)] => {
assert_eq!(message.message, preview);
}
other => panic!("unexpected initial messages in fork response: {other:#?}"),
}

let after_contents = std::fs::read_to_string(&original_path)?;
assert_eq!(
after_contents, original_contents,
"fork should not mutate the original rollout file"
);

Ok(())
}
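For orientation: the fixture written by `create_fake_rollout` lands under `sessions/YYYY/MM/DD/` inside the Codex home, which is why the test reconstructs the path from the date components. A condensed sketch of the same join chain used above (purely illustrative):

```rust
// Equivalent to the multi-join in the test body.
let expected = codex_home
    .path()
    .join("sessions/2025/01/02")
    .join(format!("rollout-2025-01-02T12-00-00-{conversation_id}.jsonl"));
```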
@@ -18,7 +18,7 @@ use tempfile::TempDir;
use tokio::time::timeout;

use app_test_support::McpProcess;
use app_test_support::create_mock_chat_completions_server;
use app_test_support::create_mock_responses_server_sequence;
use app_test_support::create_shell_command_sse_response;
use app_test_support::to_response;

@@ -56,7 +56,7 @@ async fn shell_command_interruption() -> anyhow::Result<()> {
std::fs::create_dir(&working_directory)?;

// Create mock server with a single SSE response: the long sleep command
let server = create_mock_chat_completions_server(vec![create_shell_command_sse_response(
let server = create_mock_responses_server_sequence(vec![create_shell_command_sse_response(
shell_command.clone(),
Some(&working_directory),
Some(10_000), // 10 seconds timeout in ms
@@ -153,7 +153,7 @@ model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "chat"
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
"#

@@ -32,7 +32,7 @@ model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "http://127.0.0.1:0/v1"
wire_api = "chat"
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
"#,

@@ -3,6 +3,7 @@ mod auth;
mod codex_message_processor_flow;
mod config;
mod create_thread;
mod fork_thread;
mod fuzzy_file_search;
mod interrupt;
mod list_resume;

@@ -1,7 +1,5 @@
use anyhow::Result;
use app_test_support::McpProcess;
use app_test_support::create_final_assistant_message_sse_response;
use app_test_support::create_mock_chat_completions_server;
use app_test_support::to_response;
use codex_app_server_protocol::AddConversationListenerParams;
use codex_app_server_protocol::AddConversationSubscriptionResponse;
@@ -17,6 +15,7 @@ use codex_protocol::ThreadId;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::RawResponseItemEvent;
use core_test_support::responses;
use pretty_assertions::assert_eq;
use std::path::Path;
use tempfile::TempDir;
@@ -26,13 +25,21 @@ const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs

#[tokio::test]
async fn test_send_message_success() -> Result<()> {
// Spin up a mock completions server that immediately ends the Codex turn.
// Spin up a mock responses server that immediately ends the Codex turn.
// Two Codex turns hit the mock model (session start + send-user-message). Provide two SSE responses.
let responses = vec![
create_final_assistant_message_sse_response("Done")?,
create_final_assistant_message_sse_response("Done")?,
];
let server = create_mock_chat_completions_server(responses).await;
let server = responses::start_mock_server().await;
let body1 = responses::sse(vec![
responses::ev_response_created("resp-1"),
responses::ev_assistant_message("msg-1", "Done"),
responses::ev_completed("resp-1"),
]);
let body2 = responses::sse(vec![
responses::ev_response_created("resp-2"),
responses::ev_assistant_message("msg-2", "Done"),
responses::ev_completed("resp-2"),
]);
let _response_mock1 = responses::mount_sse_once(&server, body1).await;
let _response_mock2 = responses::mount_sse_once(&server, body2).await;

// Create a temporary Codex home with config pointing at the mock server.
let codex_home = TempDir::new()?;
@@ -135,8 +142,13 @@ async fn send_message(

#[tokio::test]
async fn test_send_message_raw_notifications_opt_in() -> Result<()> {
let responses = vec![create_final_assistant_message_sse_response("Done")?];
let server = create_mock_chat_completions_server(responses).await;
let server = responses::start_mock_server().await;
let body = responses::sse(vec![
responses::ev_response_created("resp-1"),
responses::ev_assistant_message("msg-1", "Done"),
responses::ev_completed("resp-1"),
]);
let _response_mock = responses::mount_sse_once(&server, body).await;

let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), &server.uri())?;
@@ -259,7 +271,7 @@ model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "chat"
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
"#
@@ -269,6 +281,7 @@ stream_max_retries = 0

#[expect(clippy::expect_used)]
async fn read_raw_response_item(mcp: &mut McpProcess, conversation_id: ThreadId) -> ResponseItem {
// TODO: Switch to rawResponseItem/completed once we migrate to app server v2 in codex web.
loop {
let raw_notification: JSONRPCNotification = timeout(
DEFAULT_READ_TIMEOUT,

@@ -67,7 +67,7 @@ model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "http://127.0.0.1:0/v1"
wire_api = "chat"
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
{requires_line}

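A sketch of the per-turn mock pattern these tests now use (names are the ones introduced above; as the name suggests, each `mount_sse_once` body appears to be served for a single request, so a two-turn test mounts two bodies with distinct response ids):

```rust
let server = responses::start_mock_server().await;
let body = responses::sse(vec![
    responses::ev_response_created("resp-1"),
    responses::ev_assistant_message("msg-1", "Done"),
    responses::ev_completed("resp-1"),
]);
// Keep the guard alive for the test's duration.
let _mock = responses::mount_sse_once(&server, body).await;
```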
@@ -5,7 +5,9 @@ mod output_schema;
mod rate_limits;
mod review;
mod thread_archive;
mod thread_fork;
mod thread_list;
mod thread_loaded_list;
mod thread_resume;
mod thread_rollback;
mod thread_start;

@@ -1,7 +1,6 @@
use anyhow::Result;
use app_test_support::McpProcess;
use app_test_support::create_final_assistant_message_sse_response;
use app_test_support::create_mock_chat_completions_server_unchecked;
use app_test_support::create_mock_responses_server_repeating_assistant;
use app_test_support::to_response;
use codex_app_server_protocol::ItemCompletedNotification;
use codex_app_server_protocol::ItemStartedNotification;
@@ -44,10 +43,7 @@ async fn review_start_runs_review_turn_and_emits_code_review_item() -> Result<()
"overall_confidence_score": 0.75
})
.to_string();
let responses = vec![create_final_assistant_message_sse_response(
&review_payload,
)?];
let server = create_mock_chat_completions_server_unchecked(responses).await;
let server = create_mock_responses_server_repeating_assistant(&review_payload).await;

let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), &server.uri())?;
@@ -135,7 +131,7 @@ async fn review_start_runs_review_turn_and_emits_code_review_item() -> Result<()

#[tokio::test]
async fn review_start_rejects_empty_base_branch() -> Result<()> {
let server = create_mock_chat_completions_server_unchecked(vec![]).await;
let server = create_mock_responses_server_repeating_assistant("Done").await;
let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), &server.uri())?;

@@ -176,10 +172,7 @@ async fn review_start_with_detached_delivery_returns_new_thread_id() -> Result<(
"overall_confidence_score": 0.5
})
.to_string();
let responses = vec![create_final_assistant_message_sse_response(
&review_payload,
)?];
let server = create_mock_chat_completions_server_unchecked(responses).await;
let server = create_mock_responses_server_repeating_assistant(&review_payload).await;

let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), &server.uri())?;
@@ -219,7 +212,7 @@ async fn review_start_with_detached_delivery_returns_new_thread_id() -> Result<(

#[tokio::test]
async fn review_start_rejects_empty_commit_sha() -> Result<()> {
let server = create_mock_chat_completions_server_unchecked(vec![]).await;
let server = create_mock_responses_server_repeating_assistant("Done").await;
let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), &server.uri())?;

@@ -254,7 +247,7 @@ async fn review_start_rejects_empty_commit_sha() -> Result<()> {

#[tokio::test]
async fn review_start_rejects_empty_custom_instructions() -> Result<()> {
let server = create_mock_chat_completions_server_unchecked(vec![]).await;
let server = create_mock_responses_server_repeating_assistant("Done").await;
let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), &server.uri())?;

@@ -320,7 +313,7 @@ model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider"
base_url = "{server_uri}/v1"
wire_api = "chat"
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
"#

codex-rs/app-server/tests/suite/v2/thread_fork.rs (new file, 140 lines)
@@ -0,0 +1,140 @@
use anyhow::Result;
use app_test_support::McpProcess;
use app_test_support::create_fake_rollout;
use app_test_support::create_mock_responses_server_repeating_assistant;
use app_test_support::to_response;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::RequestId;
use codex_app_server_protocol::SessionSource;
use codex_app_server_protocol::ThreadForkParams;
use codex_app_server_protocol::ThreadForkResponse;
use codex_app_server_protocol::ThreadItem;
use codex_app_server_protocol::ThreadStartedNotification;
use codex_app_server_protocol::TurnStatus;
use codex_app_server_protocol::UserInput;
use pretty_assertions::assert_eq;
use std::path::Path;
use tempfile::TempDir;
use tokio::time::timeout;

const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

#[tokio::test]
async fn thread_fork_creates_new_thread_and_emits_started() -> Result<()> {
let server = create_mock_responses_server_repeating_assistant("Done").await;
let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), &server.uri())?;

let preview = "Saved user message";
let conversation_id = create_fake_rollout(
codex_home.path(),
"2025-01-05T12-00-00",
"2025-01-05T12:00:00Z",
preview,
Some("mock_provider"),
None,
)?;

let original_path = codex_home
.path()
.join("sessions")
.join("2025")
.join("01")
.join("05")
.join(format!(
"rollout-2025-01-05T12-00-00-{conversation_id}.jsonl"
));
assert!(
original_path.exists(),
"expected original rollout to exist at {}",
original_path.display()
);
let original_contents = std::fs::read_to_string(&original_path)?;

let mut mcp = McpProcess::new(codex_home.path()).await?;
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;

let fork_id = mcp
.send_thread_fork_request(ThreadForkParams {
thread_id: conversation_id.clone(),
..Default::default()
})
.await?;
let fork_resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(fork_id)),
)
.await??;
let ThreadForkResponse { thread, .. } = to_response::<ThreadForkResponse>(fork_resp)?;

let after_contents = std::fs::read_to_string(&original_path)?;
assert_eq!(
after_contents, original_contents,
"fork should not mutate the original rollout file"
);

assert_ne!(thread.id, conversation_id);
assert_eq!(thread.preview, preview);
assert_eq!(thread.model_provider, "mock_provider");
assert!(thread.path.is_absolute());
assert_ne!(thread.path, original_path);
assert!(thread.cwd.is_absolute());
assert_eq!(thread.source, SessionSource::VsCode);

assert_eq!(
thread.turns.len(),
1,
"expected forked thread to include one turn"
);
let turn = &thread.turns[0];
assert_eq!(turn.status, TurnStatus::Completed);
assert_eq!(turn.items.len(), 1, "expected user message item");
match &turn.items[0] {
ThreadItem::UserMessage { content, .. } => {
assert_eq!(
content,
&vec![UserInput::Text {
text: preview.to_string()
}]
);
}
other => panic!("expected user message item, got {other:?}"),
}

// A corresponding thread/started notification should arrive.
let notif: JSONRPCNotification = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_notification_message("thread/started"),
)
.await??;
let started: ThreadStartedNotification =
serde_json::from_value(notif.params.expect("params must be present"))?;
assert_eq!(started.thread, thread);

Ok(())
}

// Helper to create a config.toml pointing at the mock model server.
fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> {
let config_toml = codex_home.join("config.toml");
std::fs::write(
config_toml,
format!(
r#"
model = "mock-model"
approval_policy = "never"
sandbox_mode = "read-only"

model_provider = "mock_provider"

[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
"#
),
)
}
codex-rs/app-server/tests/suite/v2/thread_loaded_list.rs (new file, 139 lines)
@@ -0,0 +1,139 @@
use anyhow::Result;
use app_test_support::McpProcess;
use app_test_support::create_mock_responses_server_repeating_assistant;
use app_test_support::to_response;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::RequestId;
use codex_app_server_protocol::ThreadLoadedListParams;
use codex_app_server_protocol::ThreadLoadedListResponse;
use codex_app_server_protocol::ThreadStartParams;
use codex_app_server_protocol::ThreadStartResponse;
use pretty_assertions::assert_eq;
use std::path::Path;
use tempfile::TempDir;
use tokio::time::timeout;

const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

#[tokio::test]
async fn thread_loaded_list_returns_loaded_thread_ids() -> Result<()> {
let server = create_mock_responses_server_repeating_assistant("Done").await;
let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), &server.uri())?;

let mut mcp = McpProcess::new(codex_home.path()).await?;
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;

let thread_id = start_thread(&mut mcp).await?;

let list_id = mcp
.send_thread_loaded_list_request(ThreadLoadedListParams::default())
.await?;
let resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(list_id)),
)
.await??;
let ThreadLoadedListResponse {
mut data,
next_cursor,
} = to_response::<ThreadLoadedListResponse>(resp)?;
data.sort();
assert_eq!(data, vec![thread_id]);
assert_eq!(next_cursor, None);

Ok(())
}

#[tokio::test]
async fn thread_loaded_list_paginates() -> Result<()> {
let server = create_mock_responses_server_repeating_assistant("Done").await;
let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), &server.uri())?;

let mut mcp = McpProcess::new(codex_home.path()).await?;
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;

let first = start_thread(&mut mcp).await?;
let second = start_thread(&mut mcp).await?;

let mut expected = [first, second];
expected.sort();

let list_id = mcp
.send_thread_loaded_list_request(ThreadLoadedListParams {
cursor: None,
limit: Some(1),
})
.await?;
let resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(list_id)),
)
.await??;
let ThreadLoadedListResponse {
data: first_page,
next_cursor,
} = to_response::<ThreadLoadedListResponse>(resp)?;
assert_eq!(first_page, vec![expected[0].clone()]);
assert_eq!(next_cursor, Some(expected[0].clone()));

let list_id = mcp
.send_thread_loaded_list_request(ThreadLoadedListParams {
cursor: next_cursor,
limit: Some(1),
})
.await?;
let resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(list_id)),
)
.await??;
let ThreadLoadedListResponse {
data: second_page,
next_cursor,
} = to_response::<ThreadLoadedListResponse>(resp)?;
assert_eq!(second_page, vec![expected[1].clone()]);
assert_eq!(next_cursor, None);

Ok(())
}

fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> {
let config_toml = codex_home.join("config.toml");
std::fs::write(
config_toml,
format!(
r#"
model = "mock-model"
approval_policy = "never"
sandbox_mode = "read-only"

model_provider = "mock_provider"

[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
"#
),
)
}

async fn start_thread(mcp: &mut McpProcess) -> Result<String> {
let req_id = mcp
.send_thread_start_request(ThreadStartParams {
model: Some("gpt-5.1".to_string()),
..Default::default()
})
.await?;
let resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(req_id)),
)
.await??;
let ThreadStartResponse { thread, .. } = to_response::<ThreadStartResponse>(resp)?;
Ok(thread.id)
}
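Generalizing the two-page walk above: a cursor loop that drains `thread/loaded/list` one page at a time until `next_cursor` comes back `None`. A sketch using the same request/response types as the test (timeout handling elided):

```rust
let mut cursor = None;
let mut all_ids = Vec::new();
loop {
    let id = mcp
        .send_thread_loaded_list_request(ThreadLoadedListParams { cursor, limit: Some(1) })
        .await?;
    let resp = mcp
        .read_stream_until_response_message(RequestId::Integer(id))
        .await?;
    let page = to_response::<ThreadLoadedListResponse>(resp)?;
    all_ids.extend(page.data);
    match page.next_cursor {
        Some(next) => cursor = Some(next), // request the next page
        None => break,                     // no more pages
    }
}
```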
@@ -1,7 +1,7 @@
use anyhow::Result;
use app_test_support::McpProcess;
use app_test_support::create_fake_rollout;
use app_test_support::create_mock_chat_completions_server;
use app_test_support::create_mock_responses_server_repeating_assistant;
use app_test_support::to_response;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::RequestId;
@@ -23,7 +23,7 @@ const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs

#[tokio::test]
async fn thread_resume_returns_original_thread() -> Result<()> {
let server = create_mock_chat_completions_server(vec![]).await;
let server = create_mock_responses_server_repeating_assistant("Done").await;
let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), &server.uri())?;

@@ -66,7 +66,7 @@ async fn thread_resume_returns_original_thread() -> Result<()> {

#[tokio::test]
async fn thread_resume_returns_rollout_history() -> Result<()> {
let server = create_mock_chat_completions_server(vec![]).await;
let server = create_mock_responses_server_repeating_assistant("Done").await;
let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), &server.uri())?;

@@ -130,7 +130,7 @@ async fn thread_resume_returns_rollout_history() -> Result<()> {

#[tokio::test]
async fn thread_resume_prefers_path_over_thread_id() -> Result<()> {
let server = create_mock_chat_completions_server(vec![]).await;
let server = create_mock_responses_server_repeating_assistant("Done").await;
let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), &server.uri())?;

@@ -174,7 +174,7 @@ async fn thread_resume_prefers_path_over_thread_id() -> Result<()> {

#[tokio::test]
async fn thread_resume_supports_history_and_overrides() -> Result<()> {
let server = create_mock_chat_completions_server(vec![]).await;
let server = create_mock_responses_server_repeating_assistant("Done").await;
let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), &server.uri())?;

@@ -247,7 +247,7 @@ model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "chat"
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
"#

@@ -1,7 +1,7 @@
use anyhow::Result;
use app_test_support::McpProcess;
use app_test_support::create_final_assistant_message_sse_response;
use app_test_support::create_mock_chat_completions_server_unchecked;
use app_test_support::create_mock_responses_server_sequence_unchecked;
use app_test_support::to_response;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::RequestId;
@@ -28,7 +28,7 @@ async fn thread_rollback_drops_last_turns_and_persists_to_rollout() -> Result<()
create_final_assistant_message_sse_response("Done")?,
create_final_assistant_message_sse_response("Done")?,
];
let server = create_mock_chat_completions_server_unchecked(responses).await;
let server = create_mock_responses_server_sequence_unchecked(responses).await;

let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), &server.uri())?;
@@ -168,7 +168,7 @@ model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "chat"
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
"#

@@ -1,6 +1,6 @@
use anyhow::Result;
use app_test_support::McpProcess;
use app_test_support::create_mock_chat_completions_server;
use app_test_support::create_mock_responses_server_repeating_assistant;
use app_test_support::to_response;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCResponse;
@@ -17,7 +17,7 @@ const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs
#[tokio::test]
async fn thread_start_creates_thread_and_emits_started() -> Result<()> {
// Provide a mock server and config so model wiring is valid.
let server = create_mock_chat_completions_server(vec![]).await;
let server = create_mock_responses_server_repeating_assistant("Done").await;

let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), &server.uri())?;
@@ -85,7 +85,7 @@ model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "chat"
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
"#

@@ -2,7 +2,7 @@

use anyhow::Result;
use app_test_support::McpProcess;
use app_test_support::create_mock_chat_completions_server;
use app_test_support::create_mock_responses_server_sequence;
use app_test_support::create_shell_command_sse_response;
use app_test_support::to_response;
use codex_app_server_protocol::JSONRPCNotification;
@@ -41,7 +41,7 @@ async fn turn_interrupt_aborts_running_turn() -> Result<()> {
std::fs::create_dir(&working_directory)?;

// Mock server: long-running shell command then (after abort) nothing else needed.
let server = create_mock_chat_completions_server(vec![create_shell_command_sse_response(
let server = create_mock_responses_server_sequence(vec![create_shell_command_sse_response(
shell_command.clone(),
Some(&working_directory),
Some(10_000),
@@ -135,7 +135,7 @@ model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "chat"
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
"#

@@ -3,8 +3,8 @@ use app_test_support::McpProcess;
|
||||
use app_test_support::create_apply_patch_sse_response;
|
||||
use app_test_support::create_exec_command_sse_response;
|
||||
use app_test_support::create_final_assistant_message_sse_response;
|
||||
use app_test_support::create_mock_chat_completions_server;
|
||||
use app_test_support::create_mock_chat_completions_server_unchecked;
|
||||
use app_test_support::create_mock_responses_server_sequence;
|
||||
use app_test_support::create_mock_responses_server_sequence_unchecked;
|
||||
use app_test_support::create_shell_command_sse_response;
|
||||
use app_test_support::format_with_current_shell_display;
|
||||
use app_test_support::to_response;
|
||||
@@ -50,7 +50,7 @@ async fn turn_start_emits_notifications_and_accepts_model_override() -> Result<(
|
||||
create_final_assistant_message_sse_response("Done")?,
|
||||
create_final_assistant_message_sse_response("Done")?,
|
||||
];
|
||||
let server = create_mock_chat_completions_server_unchecked(responses).await;
|
||||
let server = create_mock_responses_server_sequence_unchecked(responses).await;
|
||||
|
||||
let codex_home = TempDir::new()?;
|
||||
create_config_toml(codex_home.path(), &server.uri(), "never")?;
|
||||
@@ -157,7 +157,7 @@ async fn turn_start_accepts_local_image_input() -> Result<()> {
|
||||
];
|
||||
// Use the unchecked variant because the request payload includes a LocalImage
|
||||
// which the strict matcher does not currently cover.
|
||||
let server = create_mock_chat_completions_server_unchecked(responses).await;
|
||||
let server = create_mock_responses_server_sequence_unchecked(responses).await;
|
||||
|
||||
let codex_home = TempDir::new()?;
|
||||
create_config_toml(codex_home.path(), &server.uri(), "never")?;
@@ -233,7 +233,7 @@ async fn turn_start_exec_approval_toggle_v2() -> Result<()> {
)?,
create_final_assistant_message_sse_response("done 2")?,
];
let server = create_mock_chat_completions_server(responses).await;
let server = create_mock_responses_server_sequence(responses).await;
// Default approval is untrusted to force elicitation on first turn.
create_config_toml(codex_home.as_path(), &server.uri(), "untrusted")?;

@@ -357,7 +357,7 @@ async fn turn_start_exec_approval_decline_v2() -> Result<()> {
)?,
create_final_assistant_message_sse_response("done")?,
];
let server = create_mock_chat_completions_server(responses).await;
let server = create_mock_responses_server_sequence(responses).await;
create_config_toml(codex_home.as_path(), &server.uri(), "untrusted")?;

let mut mcp = McpProcess::new(codex_home.as_path()).await?;
@@ -503,7 +503,7 @@ async fn turn_start_updates_sandbox_and_cwd_between_turns_v2() -> Result<()> {
)?,
create_final_assistant_message_sse_response("done second")?,
];
let server = create_mock_chat_completions_server(responses).await;
let server = create_mock_responses_server_sequence(responses).await;
create_config_toml(&codex_home, &server.uri(), "untrusted")?;

let mut mcp = McpProcess::new(&codex_home).await?;
@@ -554,6 +554,7 @@ async fn turn_start_updates_sandbox_and_cwd_between_turns_v2() -> Result<()> {
mcp.read_stream_until_notification_message("codex/event/task_complete"),
)
.await??;
mcp.clear_message_buffer();

// second turn with workspace-write and second_cwd, ensure exec begins in second_cwd
let second_turn = mcp
@@ -636,7 +637,7 @@ async fn turn_start_file_change_approval_v2() -> Result<()> {
create_apply_patch_sse_response(patch, "patch-call")?,
create_final_assistant_message_sse_response("patch applied")?,
];
let server = create_mock_chat_completions_server(responses).await;
let server = create_mock_responses_server_sequence(responses).await;
create_config_toml(&codex_home, &server.uri(), "untrusted")?;

let mut mcp = McpProcess::new(&codex_home).await?;
@@ -812,7 +813,7 @@ async fn turn_start_file_change_approval_accept_for_session_persists_v2() -> Res
create_apply_patch_sse_response(patch_2, "patch-call-2")?,
create_final_assistant_message_sse_response("patch 2 applied")?,
];
let server = create_mock_chat_completions_server(responses).await;
let server = create_mock_responses_server_sequence(responses).await;
create_config_toml(&codex_home, &server.uri(), "untrusted")?;

let mut mcp = McpProcess::new(&codex_home).await?;
@@ -986,7 +987,7 @@ async fn turn_start_file_change_approval_decline_v2() -> Result<()> {
create_apply_patch_sse_response(patch, "patch-call")?,
create_final_assistant_message_sse_response("patch declined")?,
];
let server = create_mock_chat_completions_server(responses).await;
let server = create_mock_responses_server_sequence(responses).await;
create_config_toml(&codex_home, &server.uri(), "untrusted")?;

let mut mcp = McpProcess::new(&codex_home).await?;
@@ -1124,7 +1125,7 @@ async fn command_execution_notifications_include_process_id() -> Result<()> {
create_exec_command_sse_response("uexec-1")?,
create_final_assistant_message_sse_response("done")?,
];
let server = create_mock_chat_completions_server(responses).await;
let server = create_mock_responses_server_sequence(responses).await;
let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), &server.uri(), "never")?;
let config_toml = codex_home.path().join("config.toml");
@@ -1263,7 +1264,7 @@ model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "chat"
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
"#

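The recurring swap in the hunks above moves these tests from the Chat Completions mock server to the Responses mock server, and the test config's wire protocol follows. A hedged sketch of the resulting config string, reconstructed only from the fields visible in the hunk (the server URI is a placeholder):

let server_uri = "http://127.0.0.1:4010"; // placeholder for the mock server's URI
let config = format!(
    r#"
model_provider = "mock_provider"

[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
"#
);
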
@@ -1,3 +1,4 @@
use codex_utils_cargo_bin::find_resource;
use pretty_assertions::assert_eq;
use std::collections::BTreeMap;
use std::fs;
@@ -8,7 +9,7 @@ use tempfile::tempdir;

#[test]
fn test_apply_patch_scenarios() -> anyhow::Result<()> {
let scenarios_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures/scenarios");
let scenarios_dir = find_resource!("tests/fixtures/scenarios")?;
for scenario in fs::read_dir(scenarios_dir)? {
let scenario = scenario?;
let path = scenario.path();

@@ -73,8 +73,8 @@ impl Client {
})
}

pub async fn from_auth(base_url: impl Into<String>, auth: &CodexAuth) -> Result<Self> {
let token = auth.get_token().await.map_err(anyhow::Error::from)?;
pub fn from_auth(base_url: impl Into<String>, auth: &CodexAuth) -> Result<Self> {
let token = auth.get_token().map_err(anyhow::Error::from)?;
let mut client = Self::new(base_url)?
.with_user_agent(get_codex_user_agent())
.with_bearer_token(token);

@@ -12,6 +12,7 @@ anyhow = { workspace = true }
clap = { workspace = true, features = ["derive"] }
codex-common = { workspace = true, features = ["cli"] }
codex-core = { workspace = true }
codex-utils-cargo-bin = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tokio = { workspace = true, features = ["full"] }

@@ -1,4 +1,4 @@
use codex_core::CodexAuth;
use codex_core::AuthManager;
use std::path::Path;
use std::sync::LazyLock;
use std::sync::RwLock;
@@ -23,9 +23,10 @@ pub async fn init_chatgpt_token_from_auth(
codex_home: &Path,
auth_credentials_store_mode: AuthCredentialsStoreMode,
) -> std::io::Result<()> {
let auth = CodexAuth::from_auth_storage(codex_home, auth_credentials_store_mode)?;
if let Some(auth) = auth {
let token_data = auth.get_token_data().await?;
let auth_manager =
AuthManager::new(codex_home.to_path_buf(), false, auth_credentials_store_mode);
if let Some(auth) = auth_manager.auth().await {
let token_data = auth.get_token_data()?;
set_chatgpt_token_data(token_data);
}
Ok(())

@@ -1,6 +1,6 @@
use codex_chatgpt::apply_command::apply_diff_from_task;
use codex_chatgpt::get_task::GetTaskResponse;
use std::path::Path;
use codex_utils_cargo_bin::find_resource;
use tempfile::TempDir;
use tokio::process::Command;

@@ -68,8 +68,8 @@ async fn create_temp_git_repo() -> anyhow::Result<TempDir> {
}

async fn mock_get_task_with_fixture() -> anyhow::Result<GetTaskResponse> {
let fixture_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/task_turn_fixture.json");
let fixture_content = std::fs::read_to_string(fixture_path)?;
let fixture_path = find_resource!("tests/task_turn_fixture.json")?;
let fixture_content = tokio::fs::read_to_string(fixture_path).await?;
let response: GetTaskResponse = serde_json::from_str(&fixture_content)?;
Ok(response)
}

@@ -30,7 +30,6 @@ codex-exec = { workspace = true }
codex-execpolicy = { workspace = true }
codex-login = { workspace = true }
codex-mcp-server = { workspace = true }
codex-process-hardening = { workspace = true }
codex-protocol = { workspace = true }
codex-responses-api-proxy = { workspace = true }
codex-rmcp-client = { workspace = true }
@@ -38,7 +37,6 @@ codex-stdio-to-uds = { workspace = true }
codex-tui = { workspace = true }
codex-tui2 = { workspace = true }
codex-utils-absolute-path = { workspace = true }
ctor = { workspace = true }
libc = { workspace = true }
owo-colors = { workspace = true }
regex-lite = { workspace = true }

@@ -155,7 +155,7 @@ pub async fn run_login_status(cli_config_overrides: CliConfigOverrides) -> ! {

match CodexAuth::from_auth_storage(&config.codex_home, config.cli_auth_credentials_store_mode) {
Ok(Some(auth)) => match auth.mode {
AuthMode::ApiKey => match auth.get_token().await {
AuthMode::ApiKey => match auth.get_token() {
Ok(api_key) => {
eprintln!("Logged in using an API key - {}", safe_format_key(&api_key));
std::process::exit(0);

@@ -418,14 +418,6 @@ fn stage_str(stage: codex_core::features::Stage) -> &'static str {
}
}

/// As early as possible in the process lifecycle, apply hardening measures. We
/// skip this in debug builds to avoid interfering with debugging.
#[ctor::ctor]
#[cfg(not(debug_assertions))]
fn pre_main_hardening() {
codex_process_hardening::pre_main_hardening();
}

fn main() -> anyhow::Result<()> {
arg0_dispatch_or_else(|codex_linux_sandbox_exe| async move {
cli_main(codex_linux_sandbox_exe).await?;

@@ -10,7 +10,6 @@ pub use cli::Cli;
use anyhow::anyhow;
use chrono::Utc;
use codex_cloud_tasks_client::TaskStatus;
use codex_login::AuthManager;
use owo_colors::OwoColorize;
use owo_colors::Stream;
use std::cmp::Ordering;
@@ -65,7 +64,11 @@ async fn init_backend(user_agent_suffix: &str) -> anyhow::Result<BackendContext>
append_error_log(format!("startup: base_url={base_url} path_style={style}"));

let auth_manager = util::load_auth_manager().await;
let auth = match auth_manager.as_ref().and_then(AuthManager::auth) {
let auth = match auth_manager.as_ref() {
Some(manager) => manager.auth().await,
None => None,
};
let auth = match auth {
Some(auth) => auth,
None => {
eprintln!(
@@ -79,7 +82,7 @@ async fn init_backend(user_agent_suffix: &str) -> anyhow::Result<BackendContext>
append_error_log(format!("auth: mode=ChatGPT account_id={acc}"));
}

let token = match auth.get_token().await {
let token = match auth.get_token() {
Ok(t) if !t.is_empty() => t,
_ => {
eprintln!(

@@ -85,8 +85,8 @@ pub async fn build_chatgpt_headers() -> HeaderMap {
HeaderValue::from_str(&ua).unwrap_or(HeaderValue::from_static("codex-cli")),
);
if let Some(am) = load_auth_manager().await
&& let Some(auth) = am.auth()
&& let Ok(tok) = auth.get_token().await
&& let Some(auth) = am.auth().await
&& let Ok(tok) = auth.get_token()
&& !tok.is_empty()
{
let v = format!("Bearer {tok}");

@@ -10,6 +10,7 @@ use crate::provider::WireApi;
use crate::sse::chat::spawn_chat_stream;
use crate::telemetry::SseTelemetry;
use codex_client::HttpTransport;
use codex_client::RequestCompression;
use codex_client::RequestTelemetry;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ReasoningItemContent;
@@ -80,7 +81,13 @@ impl<T: HttpTransport, A: AuthProvider> ChatClient<T, A> {
extra_headers: HeaderMap,
) -> Result<ResponseStream, ApiError> {
self.streaming
.stream(self.path(), body, extra_headers, spawn_chat_stream)
.stream(
self.path(),
body,
extra_headers,
RequestCompression::None,
spawn_chat_stream,
)
.await
}
}

@@ -9,9 +9,11 @@ use crate::provider::Provider;
use crate::provider::WireApi;
use crate::requests::ResponsesRequest;
use crate::requests::ResponsesRequestBuilder;
use crate::requests::responses::Compression;
use crate::sse::spawn_response_stream;
use crate::telemetry::SseTelemetry;
use codex_client::HttpTransport;
use codex_client::RequestCompression;
use codex_client::RequestTelemetry;
use codex_protocol::protocol::SessionSource;
use http::HeaderMap;
@@ -33,6 +35,7 @@ pub struct ResponsesOptions {
pub conversation_id: Option<String>,
pub session_source: Option<SessionSource>,
pub extra_headers: HeaderMap,
pub compression: Compression,
}

impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
@@ -56,7 +59,8 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
&self,
request: ResponsesRequest,
) -> Result<ResponseStream, ApiError> {
self.stream(request.body, request.headers).await
self.stream(request.body, request.headers, request.compression)
.await
}

#[instrument(level = "trace", skip_all, err)]
@@ -75,6 +79,7 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
conversation_id,
session_source,
extra_headers,
compression,
} = options;

let request = ResponsesRequestBuilder::new(model, &prompt.instructions, &prompt.input)
@@ -88,6 +93,7 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
.session_source(session_source)
.store_override(store_override)
.extra_headers(extra_headers)
.compression(compression)
.build(self.streaming.provider())?;

self.stream_request(request).await
@@ -104,9 +110,21 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
&self,
body: Value,
extra_headers: HeaderMap,
compression: Compression,
) -> Result<ResponseStream, ApiError> {
let compression = match compression {
Compression::None => RequestCompression::None,
Compression::Zstd => RequestCompression::Zstd,
};

self.streaming
.stream(self.path(), body, extra_headers, spawn_response_stream)
.stream(
self.path(),
body,
extra_headers,
compression,
spawn_response_stream,
)
.await
}
}

@@ -6,6 +6,7 @@ use crate::provider::Provider;
use crate::telemetry::SseTelemetry;
use crate::telemetry::run_with_request_telemetry;
use codex_client::HttpTransport;
use codex_client::RequestCompression;
use codex_client::RequestTelemetry;
use codex_client::StreamResponse;
use http::HeaderMap;
@@ -52,6 +53,7 @@ impl<T: HttpTransport, A: AuthProvider> StreamingClient<T, A> {
path: &str,
body: Value,
extra_headers: HeaderMap,
compression: RequestCompression,
spawner: fn(StreamResponse, Duration, Option<Arc<dyn SseTelemetry>>) -> ResponseStream,
) -> Result<ResponseStream, ApiError> {
let builder = || {
@@ -62,6 +64,7 @@ impl<T: HttpTransport, A: AuthProvider> StreamingClient<T, A> {
http::HeaderValue::from_static("text/event-stream"),
);
req.body = Some(body.clone());
req.compression = compression;
add_auth_headers(&self.auth, req)
};

@@ -1,4 +1,5 @@
use codex_client::Request;
use codex_client::RequestCompression;
use codex_client::RetryOn;
use codex_client::RetryPolicy;
use http::Method;
@@ -87,6 +88,7 @@ impl Provider {
url: self.url_for_path(path),
headers: self.headers.clone(),
body: None,
compression: RequestCompression::None,
timeout: None,
}
}

@@ -11,10 +11,18 @@ use codex_protocol::protocol::SessionSource;
use http::HeaderMap;
use serde_json::Value;

#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum Compression {
#[default]
None,
Zstd,
}

/// Assembled request body plus headers for a Responses stream request.
pub struct ResponsesRequest {
pub body: Value,
pub headers: HeaderMap,
pub compression: Compression,
}

#[derive(Default)]
@@ -32,6 +40,7 @@ pub struct ResponsesRequestBuilder<'a> {
session_source: Option<SessionSource>,
store_override: Option<bool>,
headers: HeaderMap,
compression: Compression,
}

impl<'a> ResponsesRequestBuilder<'a> {
@@ -94,6 +103,11 @@ impl<'a> ResponsesRequestBuilder<'a> {
self
}

pub fn compression(mut self, compression: Compression) -> Self {
self.compression = compression;
self
}

pub fn build(self, provider: &Provider) -> Result<ResponsesRequest, ApiError> {
let model = self
.model
@@ -138,7 +152,11 @@ impl<'a> ResponsesRequestBuilder<'a> {
insert_header(&mut headers, "x-openai-subagent", &subagent);
}

Ok(ResponsesRequest { body, headers })
Ok(ResponsesRequest {
body,
headers,
compression: self.compression,
})
}
}

@@ -11,6 +11,7 @@ use codex_api::Provider;
use codex_api::ResponsesClient;
use codex_api::ResponsesOptions;
use codex_api::WireApi;
use codex_api::requests::responses::Compression;
use codex_client::HttpTransport;
use codex_client::Request;
use codex_client::Response;
@@ -229,7 +230,9 @@ async fn responses_client_uses_responses_path_for_responses_wire() -> Result<()>
let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), NoAuth);

let body = serde_json::json!({ "echo": true });
let _stream = client.stream(body, HeaderMap::new()).await?;
let _stream = client
.stream(body, HeaderMap::new(), Compression::None)
.await?;

let requests = state.take_stream_requests();
assert_path_ends_with(&requests, "/responses");
@@ -243,7 +246,9 @@ async fn responses_client_uses_chat_path_for_chat_wire() -> Result<()> {
let client = ResponsesClient::new(transport, provider("openai", WireApi::Chat), NoAuth);

let body = serde_json::json!({ "echo": true });
let _stream = client.stream(body, HeaderMap::new()).await?;
let _stream = client
.stream(body, HeaderMap::new(), Compression::None)
.await?;

let requests = state.take_stream_requests();
assert_path_ends_with(&requests, "/chat/completions");
@@ -258,7 +263,9 @@ async fn streaming_client_adds_auth_headers() -> Result<()> {
let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), auth);

let body = serde_json::json!({ "model": "gpt-test" });
let _stream = client.stream(body, HeaderMap::new()).await?;
let _stream = client
.stream(body, HeaderMap::new(), Compression::None)
.await?;

let requests = state.take_stream_requests();
assert_eq!(requests.len(), 1);

@@ -9,6 +9,7 @@ use codex_api::Provider;
use codex_api::ResponseEvent;
use codex_api::ResponsesClient;
use codex_api::WireApi;
use codex_api::requests::responses::Compression;
use codex_client::HttpTransport;
use codex_client::Request;
use codex_client::Response;
@@ -124,7 +125,11 @@ async fn responses_stream_parses_items_and_completed_end_to_end() -> Result<()>
let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), NoAuth);

let mut stream = client
.stream(serde_json::json!({"echo": true}), HeaderMap::new())
.stream(
serde_json::json!({"echo": true}),
HeaderMap::new(),
Compression::None,
)
.await?;

let mut events = Vec::new();
@@ -189,7 +194,11 @@ async fn responses_stream_aggregates_output_text_deltas() -> Result<()> {
let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), NoAuth);

let stream = client
.stream(serde_json::json!({"echo": true}), HeaderMap::new())
.stream(
serde_json::json!({"echo": true}),
HeaderMap::new(),
Compression::None,
)
.await?;

let mut stream = stream.aggregate();

@@ -19,6 +19,7 @@ thiserror = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "time", "sync"] }
tracing = { workspace = true }
tracing-opentelemetry = { workspace = true }
zstd = { workspace = true }

[lints]
workspace = true

@@ -104,6 +104,13 @@ impl CodexRequestBuilder {
self.map(|builder| builder.json(value))
}

pub fn body<B>(self, body: B) -> Self
where
B: Into<reqwest::Body>,
{
self.map(|builder| builder.body(body))
}

pub async fn send(self) -> Result<Response, reqwest::Error> {
let headers = trace_headers();

@@ -11,6 +11,7 @@ pub use crate::default_client::CodexRequestBuilder;
pub use crate::error::StreamError;
pub use crate::error::TransportError;
pub use crate::request::Request;
pub use crate::request::RequestCompression;
pub use crate::request::Response;
pub use crate::retry::RetryOn;
pub use crate::retry::RetryPolicy;

@@ -5,12 +5,20 @@ use serde::Serialize;
use serde_json::Value;
use std::time::Duration;

#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum RequestCompression {
#[default]
None,
Zstd,
}

#[derive(Debug, Clone)]
pub struct Request {
pub method: Method,
pub url: String,
pub headers: HeaderMap,
pub body: Option<Value>,
pub compression: RequestCompression,
pub timeout: Option<Duration>,
}

@@ -21,6 +29,7 @@ impl Request {
url,
headers: HeaderMap::new(),
body: None,
compression: RequestCompression::None,
timeout: None,
}
}
@@ -29,6 +38,11 @@ impl Request {
self.body = serde_json::to_value(body).ok();
self
}

pub fn with_compression(mut self, compression: RequestCompression) -> Self {
self.compression = compression;
self
}
}

#[derive(Debug, Clone)]

@@ -2,6 +2,7 @@ use crate::default_client::CodexHttpClient;
use crate::default_client::CodexRequestBuilder;
use crate::error::TransportError;
use crate::request::Request;
use crate::request::RequestCompression;
use crate::request::Response;
use async_trait::async_trait;
use bytes::Bytes;
@@ -41,18 +42,70 @@ impl ReqwestTransport {
}

fn build(&self, req: Request) -> Result<CodexRequestBuilder, TransportError> {
let mut builder = self
.client
.request(
Method::from_bytes(req.method.as_str().as_bytes()).unwrap_or(Method::GET),
&req.url,
)
.headers(req.headers);
if let Some(timeout) = req.timeout {
let Request {
method,
url,
mut headers,
body,
compression,
timeout,
} = req;

let mut builder = self.client.request(
Method::from_bytes(method.as_str().as_bytes()).unwrap_or(Method::GET),
&url,
);

if let Some(timeout) = timeout {
builder = builder.timeout(timeout);
}
if let Some(body) = req.body {
builder = builder.json(&body);

if let Some(body) = body {
if compression != RequestCompression::None {
if headers.contains_key(http::header::CONTENT_ENCODING) {
return Err(TransportError::Build(
"request compression was requested but content-encoding is already set"
.to_string(),
));
}

let json = serde_json::to_vec(&body)
.map_err(|err| TransportError::Build(err.to_string()))?;
let pre_compression_bytes = json.len();
let compression_start = std::time::Instant::now();
let (compressed, content_encoding) = match compression {
RequestCompression::None => unreachable!("guarded by compression != None"),
RequestCompression::Zstd => (
zstd::stream::encode_all(std::io::Cursor::new(json), 3)
.map_err(|err| TransportError::Build(err.to_string()))?,
http::HeaderValue::from_static("zstd"),
),
};
let post_compression_bytes = compressed.len();
let compression_duration = compression_start.elapsed();

// Ensure the server knows to unpack the request body.
headers.insert(http::header::CONTENT_ENCODING, content_encoding);
if !headers.contains_key(http::header::CONTENT_TYPE) {
headers.insert(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static("application/json"),
);
}

tracing::info!(
pre_compression_bytes,
post_compression_bytes,
compression_duration_ms = compression_duration.as_millis(),
"Compressed request body with zstd"
);

builder = builder.headers(headers).body(compressed);
} else {
builder = builder.headers(headers).json(&body);
}
} else {
builder = builder.headers(headers);
}
Ok(builder)
}

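The transport change above is the heart of the feature: when compression is requested, the JSON body is serialized, zstd-encoded at level 3, and sent with a Content-Encoding: zstd header. A minimal standalone sketch of just that path, assuming only the zstd and serde_json crates (the function name is illustrative, not from the diff):

fn compress_json_body(body: &serde_json::Value) -> std::io::Result<Vec<u8>> {
    // Serialize first; serde_json errors convert into std::io::Error.
    let json = serde_json::to_vec(body)?;
    // Level 3 matches the level used in the diff above.
    zstd::stream::encode_all(std::io::Cursor::new(json), 3)
}
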
@@ -122,11 +122,11 @@ keyring = { workspace = true, features = ["sync-secret-service"] }
assert_cmd = { workspace = true }
assert_matches = { workspace = true }
codex-arg0 = { workspace = true }
codex-core = { path = ".", features = ["deterministic_process_ids"] }
codex-core = { path = ".", default-features = false, features = ["deterministic_process_ids"] }
codex-otel = { workspace = true, features = ["disable-default-metrics-exporter"] }
codex-utils-cargo-bin = { workspace = true }
core_test_support = { workspace = true }
ctor = { workspace = true }
escargot = { workspace = true }
image = { workspace = true, features = ["jpeg", "png"] }
maplit = { workspace = true }
predicates = { workspace = true }
@@ -137,6 +137,7 @@ tracing-subscriber = { workspace = true }
tracing-test = { workspace = true, features = ["no-env-filter"] }
walkdir = { workspace = true }
wiremock = { workspace = true }
zstd = { workspace = true }

[package.metadata.cargo-shear]
ignored = ["openssl-sys"]

File diff suppressed because one or more lines are too long

@@ -100,7 +100,7 @@ fn extract_request_id(headers: Option<&HeaderMap>) -> Option<String> {
})
}

pub(crate) async fn auth_provider_from_auth(
pub(crate) fn auth_provider_from_auth(
auth: Option<CodexAuth>,
provider: &ModelProviderInfo,
) -> crate::error::Result<CoreAuthProvider> {
@@ -119,7 +119,7 @@ pub(crate) async fn auth_provider_from_auth(
}

if let Some(auth) = auth {
let token = auth.get_token().await?;
let token = auth.get_token()?;
Ok(CoreAuthProvider {
token: Some(token),
account_id: auth.get_account_id(),

@@ -8,12 +8,10 @@ use serde::Serialize;
use serial_test::serial;
use std::env;
use std::fmt::Debug;
use std::io::ErrorKind;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;

use codex_app_server_protocol::AuthMode;
use codex_protocol::config_types::ForcedLoginMethod;
@@ -77,10 +75,6 @@ impl RefreshTokenError {
Self::Transient(_) => None,
}
}

fn other_with_message(message: impl Into<String>) -> Self {
Self::Transient(std::io::Error::other(message.into()))
}
}

impl From<RefreshTokenError> for std::io::Error {
@@ -93,40 +87,6 @@ impl From<RefreshTokenError> for std::io::Error {
}

impl CodexAuth {
pub async fn refresh_token(&self) -> Result<String, RefreshTokenError> {
tracing::info!("Refreshing token");

let token_data = self.get_current_token_data().ok_or_else(|| {
RefreshTokenError::Transient(std::io::Error::other("Token data is not available."))
})?;
let token = token_data.refresh_token;

let refresh_response = try_refresh_token(token, &self.client).await?;

let updated = update_tokens(
&self.storage,
refresh_response.id_token,
refresh_response.access_token,
refresh_response.refresh_token,
)
.await
.map_err(RefreshTokenError::from)?;

if let Ok(mut auth_lock) = self.auth_dot_json.lock() {
*auth_lock = Some(updated.clone());
}

let access = match updated.tokens {
Some(t) => t.access_token,
None => {
return Err(RefreshTokenError::other_with_message(
"Token data is not available after refresh.",
));
}
};
Ok(access)
}

/// Loads the available auth information from auth storage.
pub fn from_auth_storage(
codex_home: &Path,
@@ -135,62 +95,23 @@ impl CodexAuth {
load_auth(codex_home, false, auth_credentials_store_mode)
}

pub async fn get_token_data(&self) -> Result<TokenData, std::io::Error> {
pub fn get_token_data(&self) -> Result<TokenData, std::io::Error> {
let auth_dot_json: Option<AuthDotJson> = self.get_current_auth_json();
match auth_dot_json {
Some(AuthDotJson {
tokens: Some(mut tokens),
last_refresh: Some(last_refresh),
tokens: Some(tokens),
last_refresh: Some(_),
..
}) => {
if last_refresh < Utc::now() - chrono::Duration::days(TOKEN_REFRESH_INTERVAL) {
let refresh_result = tokio::time::timeout(
Duration::from_secs(60),
try_refresh_token(tokens.refresh_token.clone(), &self.client),
)
.await;
let refresh_response = match refresh_result {
Ok(Ok(response)) => response,
Ok(Err(err)) => return Err(err.into()),
Err(_) => {
return Err(std::io::Error::new(
ErrorKind::TimedOut,
"timed out while refreshing OpenAI API key",
));
}
};

let updated_auth_dot_json = update_tokens(
&self.storage,
refresh_response.id_token,
refresh_response.access_token,
refresh_response.refresh_token,
)
.await?;

tokens = updated_auth_dot_json
.tokens
.clone()
.ok_or(std::io::Error::other(
"Token data is not available after refresh.",
))?;

#[expect(clippy::unwrap_used)]
let mut auth_lock = self.auth_dot_json.lock().unwrap();
*auth_lock = Some(updated_auth_dot_json);
}

Ok(tokens)
}
}) => Ok(tokens),
_ => Err(std::io::Error::other("Token data is not available.")),
}
}

pub async fn get_token(&self) -> Result<String, std::io::Error> {
pub fn get_token(&self) -> Result<String, std::io::Error> {
match self.mode {
AuthMode::ApiKey => Ok(self.api_key.clone().unwrap_or_default()),
AuthMode::ChatGPT => {
let id_token = self.get_token_data().await?.access_token;
let id_token = self.get_token_data()?.access_token;
Ok(id_token)
}
}
@@ -338,7 +259,7 @@ pub fn load_auth_dot_json(
storage.load()
}

pub async fn enforce_login_restrictions(config: &Config) -> std::io::Result<()> {
pub fn enforce_login_restrictions(config: &Config) -> std::io::Result<()> {
let Some(auth) = load_auth(
&config.codex_home,
true,
@@ -376,7 +297,7 @@ pub async fn enforce_login_restrictions(config: &Config) -> std::io::Result<()>
return Ok(());
}

let token_data = match auth.get_token_data().await {
let token_data = match auth.get_token_data() {
Ok(data) => data,
Err(err) => {
return logout_with_message(
@@ -525,6 +446,7 @@ async fn try_refresh_token(
Ok(refresh_response)
} else {
let body = response.text().await.unwrap_or_default();
tracing::error!("Failed to refresh token: {status}: {body}");
if status == StatusCode::UNAUTHORIZED {
let failed = classify_refresh_token_failure(&body);
Err(RefreshTokenError::Permanent(failed))
@@ -623,6 +545,89 @@ struct CachedAuth {
auth: Option<CodexAuth>,
}

enum UnauthorizedRecoveryStep {
Reload,
RefreshToken,
Done,
}

enum ReloadOutcome {
Reloaded,
Skipped,
}

// UnauthorizedRecovery is a state machine that attempts to refresh authentication when requests
// to the API fail with a 401 status code.
// The client calls next() every time it encounters a 401 error, once per retry.
// For API-key-based authentication, we don't do anything and let the error bubble up to the user.
// For ChatGPT-based authentication, we:
// 1. Attempt to reload the auth data from disk. We only reload if the account id matches the one the current process is running as.
// 2. Attempt to refresh the token using the OAuth token refresh flow.
// If after both steps the server still responds with 401, we let the error bubble up to the user.
pub struct UnauthorizedRecovery {
manager: Arc<AuthManager>,
step: UnauthorizedRecoveryStep,
expected_account_id: Option<String>,
}

impl UnauthorizedRecovery {
fn new(manager: Arc<AuthManager>) -> Self {
let expected_account_id = manager
.auth_cached()
.as_ref()
.and_then(CodexAuth::get_account_id);
Self {
manager,
step: UnauthorizedRecoveryStep::Reload,
expected_account_id,
}
}

pub fn has_next(&self) -> bool {
if !self
.manager
.auth_cached()
.is_some_and(|auth| auth.mode == AuthMode::ChatGPT)
{
return false;
}

!matches!(self.step, UnauthorizedRecoveryStep::Done)
}

pub async fn next(&mut self) -> Result<(), RefreshTokenError> {
if !self.has_next() {
return Err(RefreshTokenError::Permanent(RefreshTokenFailedError::new(
RefreshTokenFailedReason::Other,
"No more recovery steps available.",
)));
}

match self.step {
UnauthorizedRecoveryStep::Reload => {
match self
.manager
.reload_if_account_id_matches(self.expected_account_id.as_deref())
{
ReloadOutcome::Reloaded => {
self.step = UnauthorizedRecoveryStep::RefreshToken;
}
ReloadOutcome::Skipped => {
self.manager.refresh_token().await?;
self.step = UnauthorizedRecoveryStep::Done;
}
}
}
UnauthorizedRecoveryStep::RefreshToken => {
self.manager.refresh_token().await?;
self.step = UnauthorizedRecoveryStep::Done;
}
UnauthorizedRecoveryStep::Done => {}
}
Ok(())
}
}

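To make the state machine's contract concrete, here is a hedged, self-contained sketch of how a caller might drive it on repeated 401s (fetch and FetchError are hypothetical stand-ins for the real transport, and the error mapping is illustrative):

enum FetchError { Unauthorized, Other(String) }

// Hypothetical request function standing in for the real HTTP call.
async fn fetch() -> Result<String, FetchError> { Ok("ok".to_string()) }

async fn fetch_with_recovery(manager: std::sync::Arc<AuthManager>) -> Result<String, FetchError> {
    let mut recovery = manager.unauthorized_recovery();
    loop {
        match fetch().await {
            Err(FetchError::Unauthorized) if recovery.has_next() => {
                // First next(): reload auth.json (only if the account id matches).
                // Second next(): run the OAuth token refresh flow. Then retry.
                recovery
                    .next()
                    .await
                    .map_err(|_| FetchError::Other("refresh failed".to_string()))?;
            }
            other => return other,
        }
    }
}
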
/// Central manager providing a single source of truth for auth.json derived
/// authentication data. It loads once (or on preference change) and then
/// hands out cloned `CodexAuth` values so the rest of the program has a
@@ -689,28 +694,53 @@ impl AuthManager {
})
}

/// Current cached auth (clone). May be `None` if not logged in or load failed.
pub fn auth(&self) -> Option<CodexAuth> {
/// Current cached auth (clone) without attempting a refresh.
pub fn auth_cached(&self) -> Option<CodexAuth> {
self.inner.read().ok().and_then(|c| c.auth.clone())
}

/// Current cached auth (clone). May be `None` if not logged in or load failed.
/// Refreshes cached ChatGPT tokens if they are stale before returning.
pub async fn auth(&self) -> Option<CodexAuth> {
let auth = self.auth_cached()?;
if let Err(err) = self.refresh_if_stale(&auth).await {
tracing::error!("Failed to refresh token: {}", err);
return Some(auth);
}
self.auth_cached()
}

/// Force a reload of the auth information from auth.json. Returns
/// whether the auth value changed.
pub fn reload(&self) -> bool {
let new_auth = load_auth(
&self.codex_home,
self.enable_codex_api_key_env,
self.auth_credentials_store_mode,
)
.ok()
.flatten();
if let Ok(mut guard) = self.inner.write() {
let changed = !AuthManager::auths_equal(&guard.auth, &new_auth);
guard.auth = new_auth;
changed
} else {
false
tracing::info!("Reloading auth");
let new_auth = self.load_auth_from_storage();
self.set_auth(new_auth)
}

fn reload_if_account_id_matches(&self, expected_account_id: Option<&str>) -> ReloadOutcome {
let expected_account_id = match expected_account_id {
Some(account_id) => account_id,
None => {
tracing::info!("Skipping auth reload because no account id is available.");
return ReloadOutcome::Skipped;
}
};

let new_auth = self.load_auth_from_storage();
let new_account_id = new_auth.as_ref().and_then(CodexAuth::get_account_id);

if new_account_id.as_deref() != Some(expected_account_id) {
let found_account_id = new_account_id.as_deref().unwrap_or("unknown");
tracing::info!(
"Skipping auth reload due to account id mismatch (expected: {expected_account_id}, found: {found_account_id})"
);
return ReloadOutcome::Skipped;
}

tracing::info!("Reloading auth for account {expected_account_id}");
self.set_auth(new_auth);
ReloadOutcome::Reloaded
}

fn auths_equal(a: &Option<CodexAuth>, b: &Option<CodexAuth>) -> bool {
@@ -721,6 +751,27 @@ impl AuthManager {
}
}

fn load_auth_from_storage(&self) -> Option<CodexAuth> {
load_auth(
&self.codex_home,
self.enable_codex_api_key_env,
self.auth_credentials_store_mode,
)
.ok()
.flatten()
}

fn set_auth(&self, new_auth: Option<CodexAuth>) -> bool {
if let Ok(mut guard) = self.inner.write() {
let changed = !AuthManager::auths_equal(&guard.auth, &new_auth);
tracing::info!("Reloaded auth, changed: {changed}");
guard.auth = new_auth;
changed
} else {
false
}
}

/// Convenience constructor returning an `Arc` wrapper.
pub fn shared(
codex_home: PathBuf,
@@ -734,26 +785,27 @@ impl AuthManager {
))
}

pub fn unauthorized_recovery(self: &Arc<Self>) -> UnauthorizedRecovery {
UnauthorizedRecovery::new(Arc::clone(self))
}

/// Attempt to refresh the current auth token (if any). On success, reload
/// the auth state from disk so other components observe refreshed token.
/// If the token refresh fails in a permanent (non‑transient) way, logs out
/// to clear invalid auth state.
pub async fn refresh_token(&self) -> Result<Option<String>, RefreshTokenError> {
let auth = match self.auth() {
Some(a) => a,
None => return Ok(None),
/// If the token refresh fails, returns the error to the caller.
pub async fn refresh_token(&self) -> Result<(), RefreshTokenError> {
tracing::info!("Refreshing token");

let auth = match self.auth_cached() {
Some(auth) => auth,
None => return Ok(()),
};
match auth.refresh_token().await {
Ok(token) => {
// Reload to pick up persisted changes.
self.reload();
Ok(Some(token))
}
Err(e) => {
tracing::error!("Failed to refresh token: {}", e);
Err(e)
}
}
let token_data = auth.get_current_token_data().ok_or_else(|| {
RefreshTokenError::Transient(std::io::Error::other("Token data is not available."))
})?;
self.refresh_tokens(&auth, token_data.refresh_token).await?;
// Reload to pick up persisted changes.
self.reload();
Ok(())
}

/// Log out by deleting the on‑disk auth.json (if present). Returns Ok(true)
@@ -768,7 +820,51 @@ impl AuthManager {
}

pub fn get_auth_mode(&self) -> Option<AuthMode> {
self.auth().map(|a| a.mode)
self.auth_cached().map(|a| a.mode)
}

async fn refresh_if_stale(&self, auth: &CodexAuth) -> Result<bool, RefreshTokenError> {
if auth.mode != AuthMode::ChatGPT {
return Ok(false);
}

let auth_dot_json = match auth.get_current_auth_json() {
Some(auth_dot_json) => auth_dot_json,
None => return Ok(false),
};
let tokens = match auth_dot_json.tokens {
Some(tokens) => tokens,
None => return Ok(false),
};
let last_refresh = match auth_dot_json.last_refresh {
Some(last_refresh) => last_refresh,
None => return Ok(false),
};
if last_refresh >= Utc::now() - chrono::Duration::days(TOKEN_REFRESH_INTERVAL) {
return Ok(false);
}
self.refresh_tokens(auth, tokens.refresh_token).await?;
self.reload();
Ok(true)
}

async fn refresh_tokens(
&self,
auth: &CodexAuth,
refresh_token: String,
) -> Result<(), RefreshTokenError> {
let refresh_response = try_refresh_token(refresh_token, &auth.client).await?;

update_tokens(
&auth.storage,
refresh_response.id_token,
refresh_response.access_token,
refresh_response.refresh_token,
)
.await
.map_err(RefreshTokenError::from)?;

Ok(())
}
}

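The staleness rule above compares the last refresh timestamp against a day-based interval. A minimal standalone sketch of that check, assuming TOKEN_REFRESH_INTERVAL is a number of days as the diff suggests:

use chrono::{DateTime, Utc};

fn is_stale(last_refresh: DateTime<Utc>, token_refresh_interval_days: i64) -> bool {
    // Stale when the last refresh is older than the interval.
    last_refresh < Utc::now() - chrono::Duration::days(token_refresh_interval_days)
}
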
@@ -930,7 +1026,7 @@ mod tests {
assert_eq!(auth.mode, AuthMode::ApiKey);
assert_eq!(auth.api_key, Some("sk-test-key".to_string()));

assert!(auth.get_token_data().await.is_err());
assert!(auth.get_token_data().is_err());
}

#[test]
@@ -1058,7 +1154,6 @@ mod tests {
let config = build_config(codex_home.path(), Some(ForcedLoginMethod::Chatgpt), None).await;

let err = super::enforce_login_restrictions(&config)
.await
.expect_err("expected method mismatch to error");
assert!(err.to_string().contains("ChatGPT login is required"));
assert!(
@@ -1084,7 +1179,6 @@ mod tests {
let config = build_config(codex_home.path(), None, Some("org_mine".to_string())).await;

let err = super::enforce_login_restrictions(&config)
.await
.expect_err("expected workspace mismatch to error");
assert!(err.to_string().contains("workspace org_mine"));
assert!(
@@ -1109,9 +1203,7 @@ mod tests {

let config = build_config(codex_home.path(), None, Some("org_mine".to_string())).await;

super::enforce_login_restrictions(&config)
.await
.expect("matching workspace should succeed");
super::enforce_login_restrictions(&config).expect("matching workspace should succeed");
assert!(
codex_home.path().join("auth.json").exists(),
"auth.json should remain when restrictions pass"
@@ -1127,9 +1219,7 @@ mod tests {

let config = build_config(codex_home.path(), None, Some("org_mine".to_string())).await;

super::enforce_login_restrictions(&config)
.await
.expect("matching workspace should succeed");
super::enforce_login_restrictions(&config).expect("matching workspace should succeed");
assert!(
codex_home.path().join("auth.json").exists(),
"auth.json should remain when restrictions pass"
@@ -1145,7 +1235,6 @@ mod tests {
let config = build_config(codex_home.path(), Some(ForcedLoginMethod::Chatgpt), None).await;

let err = super::enforce_login_restrictions(&config)
.await
.expect_err("environment API key should not satisfy forced ChatGPT login");
assert!(
err.to_string()

@@ -2,6 +2,7 @@ use std::sync::Arc;

use crate::api_bridge::auth_provider_from_auth;
use crate::api_bridge::map_api_error;
use crate::auth::UnauthorizedRecovery;
use codex_api::AggregateStreamExt;
use codex_api::ChatClient as ApiChatClient;
use codex_api::CompactClient as ApiCompactClient;
@@ -17,8 +18,10 @@ use codex_api::TransportError;
use codex_api::common::Reasoning;
use codex_api::create_text_param_for_request;
use codex_api::error::ApiError;
use codex_api::requests::responses::Compression;
use codex_app_server_protocol::AuthMode;
use codex_otel::otel_manager::OtelManager;
use codex_otel::OtelManager;

use codex_protocol::ThreadId;
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
use codex_protocol::models::ResponseItem;
@@ -47,6 +50,7 @@ use crate::default_client::build_reqwest_client;
use crate::error::CodexErr;
use crate::error::Result;
use crate::features::FEATURES;
use crate::features::Feature;
use crate::flags::CODEX_RS_SSE_FIXTURE;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;
@@ -153,13 +157,18 @@ impl ModelClient {
let conversation_id = self.conversation_id.to_string();
let session_source = self.session_source.clone();

let mut refreshed = false;
let mut auth_recovery = auth_manager
.as_ref()
.map(super::auth::AuthManager::unauthorized_recovery);
loop {
let auth = auth_manager.as_ref().and_then(|m| m.auth());
let auth = match auth_manager.as_ref() {
Some(manager) => manager.auth().await,
None => None,
};
let api_provider = self
.provider
.to_api_provider(auth.as_ref().map(|a| a.mode))?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider).await?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider)?;
let transport = ReqwestTransport::new(build_reqwest_client());
let (request_telemetry, sse_telemetry) = self.build_streaming_telemetry();
let client = ApiChatClient::new(transport, api_provider, api_auth)
@@ -179,7 +188,7 @@ impl ModelClient {
Err(ApiError::Transport(TransportError::Http { status, .. }))
if status == StatusCode::UNAUTHORIZED =>
{
handle_unauthorized(status, &mut refreshed, &auth_manager, &auth).await?;
handle_unauthorized(status, &mut auth_recovery).await?;
continue;
}
Err(err) => return Err(map_api_error(err)),
@@ -241,15 +250,34 @@ impl ModelClient {
let conversation_id = self.conversation_id.to_string();
let session_source = self.session_source.clone();

let mut refreshed = false;
let mut auth_recovery = auth_manager
.as_ref()
.map(super::auth::AuthManager::unauthorized_recovery);
loop {
let auth = auth_manager.as_ref().and_then(|m| m.auth());
let auth = match auth_manager.as_ref() {
Some(manager) => manager.auth().await,
None => None,
};
let api_provider = self
.provider
.to_api_provider(auth.as_ref().map(|a| a.mode))?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider).await?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider)?;
let transport = ReqwestTransport::new(build_reqwest_client());
let (request_telemetry, sse_telemetry) = self.build_streaming_telemetry();
let compression = if self
.config
.features
.enabled(Feature::EnableRequestCompression)
&& auth
.as_ref()
.is_some_and(|auth| auth.mode == AuthMode::ChatGPT)
&& self.provider.is_openai()
{
Compression::Zstd
} else {
Compression::None
};

let client = ApiResponsesClient::new(transport, api_provider, api_auth)
.with_telemetry(Some(request_telemetry), Some(sse_telemetry));

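Paraphrasing the gate above: zstd request compression is enabled only when the feature flag is on, the session uses ChatGPT auth, and the provider is OpenAI. A standalone sketch of that decision (parameter names are stand-ins for the real config and auth lookups):

fn should_compress(feature_enabled: bool, is_chatgpt_auth: bool, is_openai_provider: bool) -> bool {
    feature_enabled && is_chatgpt_auth && is_openai_provider
}
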
@@ -262,6 +290,7 @@ impl ModelClient {
conversation_id: Some(conversation_id.clone()),
session_source: Some(session_source.clone()),
extra_headers: beta_feature_headers(&self.config),
compression,
};

let stream_result = client
@@ -275,7 +304,7 @@ impl ModelClient {
Err(ApiError::Transport(TransportError::Http { status, .. }))
if status == StatusCode::UNAUTHORIZED =>
{
handle_unauthorized(status, &mut refreshed, &auth_manager, &auth).await?;
handle_unauthorized(status, &mut auth_recovery).await?;
continue;
}
Err(err) => return Err(map_api_error(err)),
@@ -327,11 +356,14 @@ impl ModelClient {
return Ok(Vec::new());
}
let auth_manager = self.auth_manager.clone();
let auth = auth_manager.as_ref().and_then(|m| m.auth());
let auth = match auth_manager.as_ref() {
Some(manager) => manager.auth().await,
None => None,
};
let api_provider = self
.provider
.to_api_provider(auth.as_ref().map(|a| a.mode))?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider).await?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider)?;
let transport = ReqwestTransport::new(build_reqwest_client());
let request_telemetry = self.build_request_telemetry();
let client = ApiCompactClient::new(transport, api_provider, api_auth)
@@ -483,29 +515,19 @@ where
/// the mapped `CodexErr` is returned to the caller.
async fn handle_unauthorized(
status: StatusCode,
refreshed: &mut bool,
auth_manager: &Option<Arc<AuthManager>>,
auth: &Option<crate::auth::CodexAuth>,
auth_recovery: &mut Option<UnauthorizedRecovery>,
) -> Result<()> {
if *refreshed {
return Err(map_unauthorized_status(status));
}

if let Some(manager) = auth_manager.as_ref()
&& let Some(auth) = auth.as_ref()
&& auth.mode == AuthMode::ChatGPT
if let Some(recovery) = auth_recovery
&& recovery.has_next()
{
match manager.refresh_token().await {
Ok(_) => {
*refreshed = true;
Ok(())
}
return match recovery.next().await {
Ok(_) => Ok(()),
Err(RefreshTokenError::Permanent(failed)) => Err(CodexErr::RefreshTokenFailed(failed)),
Err(RefreshTokenError::Transient(other)) => Err(CodexErr::Io(other)),
}
} else {
Err(map_unauthorized_status(status))
};
}

Err(map_unauthorized_status(status))
}

fn map_unauthorized_status(status: StatusCode) -> CodexErr {

@@ -7,6 +7,7 @@ use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;

use crate::AuthManager;
use crate::CodexAuth;
use crate::SandboxState;
use crate::agent::AgentControl;
use crate::agent::AgentStatus;
@@ -152,7 +153,7 @@ use crate::user_instructions::UserInstructions;
use crate::user_notification::UserNotification;
use crate::util::backoff;
use codex_async_utils::OrCancelExt;
use codex_otel::otel_manager::OtelManager;
use codex_otel::OtelManager;
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseInputItem;
@@ -633,17 +634,32 @@ impl Session {
}
maybe_push_chat_wire_api_deprecation(&config, &mut post_session_configured_events);

let auth = auth_manager.auth().await;
let auth = auth.as_ref();
let otel_manager = OtelManager::new(
conversation_id,
session_configuration.model.as_str(),
session_configuration.model.as_str(),
auth_manager.auth().and_then(|a| a.get_account_id()),
auth_manager.auth().and_then(|a| a.get_account_email()),
auth_manager.auth().map(|a| a.mode),
auth.and_then(CodexAuth::get_account_id),
auth.and_then(CodexAuth::get_account_email),
auth.map(|a| a.mode),
config.otel.log_user_prompt,
terminal::user_agent(),
session_configuration.session_source.clone(),
);
config.features.emit_metrics(&otel_manager);
otel_manager.counter(
"codex.session.started",
1,
&[(
"is_git",
if get_git_repo_root(&session_configuration.cwd).is_some() {
"true"
} else {
"false"
},
)],
);

otel_manager.conversation_starts(
config.model_provider.name.as_str(),
@@ -1243,12 +1259,10 @@ impl Session {
);
}
RolloutItem::Compacted(compacted) => {
let snapshot = history.get_history();
// TODO(jif) clean
if let Some(replacement) = &compacted.replacement_history {
history.replace(replacement.clone());
} else {
let user_messages = collect_user_messages(&snapshot);
let user_messages = collect_user_messages(history.raw_items());
let rebuilt = compact::build_compacted_history(
self.build_initial_context(turn_context),
&user_messages,
@@ -1263,7 +1277,7 @@ impl Session {
_ => {}
}
}
history.get_history()
history.raw_items().to_vec()
}

/// Append ResponseItems to the in-memory conversation history only.
@@ -1756,6 +1770,7 @@ mod handlers {
use codex_protocol::protocol::TurnAbortReason;
use codex_protocol::protocol::WarningEvent;

use crate::context_manager::is_user_turn_boundary;
use codex_protocol::user_input::UserInput;
use codex_rmcp_client::ElicitationAction;
use codex_rmcp_client::ElicitationResponse;
@@ -2093,7 +2108,10 @@ mod handlers {

let mut history = sess.clone_history().await;
history.drop_last_n_user_turns(num_turns);
sess.replace_history(history.get_history()).await;

// Replace with the raw items. We don't want to replace with a normalized
// version of the history.
sess.replace_history(history.raw_items().to_vec()).await;
sess.recompute_token_usage(turn_context.as_ref()).await;

sess.send_event_raw_flushed(Event {
@@ -2110,6 +2128,17 @@ mod handlers {
.terminate_all_processes()
.await;
info!("Shutting down Codex instance");
let history = sess.clone_history().await;
let turn_count = history
.raw_items()
.iter()
.filter(|item| is_user_turn_boundary(item))
.count();
sess.services.otel_manager.counter(
"conversation.turn.count",
i64::try_from(turn_count).unwrap_or(0),
&[],
);

// Gracefully flush and shutdown rollout recorder on session end so tests
// that inspect the rollout file do not race with the background writer.

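The new shutdown metric above counts user-turn boundaries in the raw history. A generic sketch of that counting step (the item type and predicate are stand-ins for the real ones):

fn count_user_turns<T>(items: &[T], is_user_turn_boundary: impl Fn(&T) -> bool) -> i64 {
    let turns = items.iter().filter(|item| is_user_turn_boundary(item)).count();
    // Mirror the diff: fall back to 0 if the count doesn't fit in i64.
    i64::try_from(turns).unwrap_or(0)
}
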
@@ -2367,7 +2396,7 @@ pub(crate) async fn run_task(
let turn_input: Vec<ResponseItem> = {
sess.record_conversation_items(&turn_context, &pending_input)
.await;
sess.clone_history().await.get_history_for_prompt()
sess.clone_history().await.for_prompt()
};

let turn_input_messages = turn_input
@@ -2495,7 +2524,7 @@ async fn run_turn(

let mut retries = 0;
loop {
match try_run_turn(
let err = match try_run_turn(
Arc::clone(&router),
Arc::clone(&sess),
Arc::clone(&turn_context),
@@ -2505,17 +2534,10 @@ async fn run_turn(
)
.await
{
// todo(aibrahim): map special cases and ? on other errors
Ok(output) => return Ok(output),
Err(CodexErr::TurnAborted) => {
return Err(CodexErr::TurnAborted);
}
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
Err(e @ CodexErr::Fatal(_)) => return Err(e),
Err(e @ CodexErr::ContextWindowExceeded) => {
Err(CodexErr::ContextWindowExceeded) => {
sess.set_total_tokens_full(&turn_context).await;
return Err(e);
return Err(CodexErr::ContextWindowExceeded);
}
Err(CodexErr::UsageLimitReached(e)) => {
let rate_limits = e.rate_limits.clone();
@@ -2524,39 +2546,38 @@ async fn run_turn(
}
return Err(CodexErr::UsageLimitReached(e));
}
Err(CodexErr::UsageNotIncluded) => return Err(CodexErr::UsageNotIncluded),
Err(e @ CodexErr::QuotaExceeded) => return Err(e),
Err(e @ CodexErr::InvalidImageRequest()) => return Err(e),
Err(e @ CodexErr::InvalidRequest(_)) => return Err(e),
Err(e @ CodexErr::RefreshTokenFailed(_)) => return Err(e),
Err(e) => {
// Use the configured provider-specific stream retry budget.
let max_retries = turn_context.client.get_provider().stream_max_retries();
if retries < max_retries {
retries += 1;
let delay = match e {
CodexErr::Stream(_, Some(delay)) => delay,
_ => backoff(retries),
};
warn!(
"stream disconnected - retrying turn ({retries}/{max_retries} in {delay:?})...",
);
Err(err) => err,
};

// Surface retry information to any UI/front‑end so the
// user understands what is happening instead of staring
// at a seemingly frozen screen.
sess.notify_stream_error(
&turn_context,
format!("Reconnecting... {retries}/{max_retries}"),
e,
)
.await;
if !err.is_retryable() {
return Err(err);
}

tokio::time::sleep(delay).await;
} else {
return Err(e);
// Use the configured provider-specific stream retry budget.
let max_retries = turn_context.client.get_provider().stream_max_retries();
if retries < max_retries {
retries += 1;
let delay = match &err {
CodexErr::Stream(_, requested_delay) => {
requested_delay.unwrap_or_else(|| backoff(retries))
}
}
_ => backoff(retries),
};
warn!("stream disconnected - retrying turn ({retries}/{max_retries} in {delay:?})...",);

// Surface retry information to any UI/front‑end so the
// user understands what is happening instead of staring
// at a seemingly frozen screen.
sess.notify_stream_error(
&turn_context,
format!("Reconnecting... {retries}/{max_retries}"),
err,
)
.await;

tokio::time::sleep(delay).await;
} else {
return Err(err);
}
}
}

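The rework above collapses the per-variant retry matches into a single is_retryable() check plus an exponential backoff loop. A self-contained sketch of that shape (attempt, is_retryable, and backoff are illustrative stand-ins; assumes a tokio runtime):

use std::time::Duration;

async fn attempt() -> Result<(), String> {
    Err("stream disconnected".to_string()) // stand-in for try_run_turn
}

fn is_retryable(_err: &str) -> bool {
    true // the real check inspects the error variant
}

fn backoff(retries: u64) -> Duration {
    // Exponential backoff: 200ms, 400ms, 800ms, ... (capped exponent).
    Duration::from_millis(200 * (1u64 << retries.min(6)))
}

async fn run_with_retries(max_retries: u64) -> Result<(), String> {
    let mut retries = 0;
    loop {
        let err = match attempt().await {
            Ok(output) => return Ok(output),
            Err(err) => err,
        };
        if !is_retryable(&err) || retries >= max_retries {
            return Err(err);
        }
        retries += 1;
        tokio::time::sleep(backoff(retries)).await;
    }
}
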
@@ -2841,6 +2862,7 @@ pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -
#[cfg(test)]
pub(crate) use tests::make_session_and_context;

use crate::git_info::get_git_repo_root;
#[cfg(test)]
pub(crate) use tests::make_session_and_context_with_rx;

@@ -2915,8 +2937,8 @@ mod tests {
}))
.await;

let actual = session.state.lock().await.clone_history().get_history();
assert_eq!(expected, actual);
let history = session.state.lock().await.clone_history();
assert_eq!(expected, history.raw_items());
}

#[tokio::test]
@@ -3005,8 +3027,8 @@ mod tests {
.record_initial_history(InitialHistory::Forked(rollout_items))
.await;

let actual = session.state.lock().await.clone_history().get_history();
assert_eq!(expected, actual);
let history = session.state.lock().await.clone_history();
assert_eq!(expected, history.raw_items());
}

#[tokio::test]
@@ -3062,8 +3084,8 @@ mod tests {
expected.extend(initial_context);
expected.extend(turn_1);

let actual = sess.clone_history().await.get_history();
assert_eq!(expected, actual);
let history = sess.clone_history().await;
assert_eq!(expected, history.raw_items());
}

#[tokio::test]
@@ -3088,8 +3110,8 @@ mod tests {
let rollback_event = wait_for_thread_rolled_back(&rx).await;
assert_eq!(rollback_event.num_turns, 99);

let actual = sess.clone_history().await.get_history();
assert_eq!(initial_context, actual);
let history = sess.clone_history().await;
assert_eq!(initial_context, history.raw_items());
}

#[tokio::test]
@@ -3109,8 +3131,8 @@ mod tests {
Some(CodexErrorInfo::ThreadRollbackFailed)
);

let actual = sess.clone_history().await.get_history();
assert_eq!(initial_context, actual);
let history = sess.clone_history().await;
assert_eq!(initial_context, history.raw_items());
}

#[tokio::test]
@@ -3130,8 +3152,8 @@ mod tests {
Some(CodexErrorInfo::ThreadRollbackFailed)
);

let actual = sess.clone_history().await.get_history();
assert_eq!(initial_context, actual);
let history = sess.clone_history().await;
assert_eq!(initial_context, history.raw_items());
}

#[tokio::test]
@@ -3632,8 +3654,8 @@ mod tests {
.record_model_warning("too many unified exec processes", &turn_context)
.await;

let mut history = session.clone_history().await;
let history_items = history.get_history();
let history = session.clone_history().await;
let history_items = history.raw_items();
let last = history_items.last().expect("warning recorded");

match last {
@@ -3779,8 +3801,9 @@ mod tests {
}
}

let history = sess.clone_history().await.get_history();
let _ = history;
// TODO(jif) investigate what this is.
let history = sess.clone_history().await;
let _ = history.raw_items();
}

#[tokio::test]
@@ -3869,7 +3892,7 @@ mod tests {
rollout_items.push(RolloutItem::ResponseItem(assistant1.clone()));

let summary1 = "summary one";
let snapshot1 = live_history.get_history();
let snapshot1 = live_history.clone().for_prompt();
let user_messages1 = collect_user_messages(&snapshot1);
let rebuilt1 = compact::build_compacted_history(
session.build_initial_context(turn_context),
@@ -3903,7 +3926,7 @@ mod tests {
rollout_items.push(RolloutItem::ResponseItem(assistant2.clone()));

let summary2 = "summary two";
let snapshot2 = live_history.get_history();
let snapshot2 = live_history.clone().for_prompt();
let user_messages2 = collect_user_messages(&snapshot2);
let rebuilt2 = compact::build_compacted_history(
session.build_initial_context(turn_context),
@@ -3936,7 +3959,7 @@ mod tests {
live_history.record_items(std::iter::once(&assistant3), turn_context.truncation_policy);
rollout_items.push(RolloutItem::ResponseItem(assistant3.clone()));

(rollout_items, live_history.get_history())
(rollout_items, live_history.for_prompt())
}

#[tokio::test]
@@ -95,9 +95,11 @@ async fn run_compact_task_inner(
sess.persist_rollout_items(&[rollout_item]).await;

loop {
let turn_input = history.get_history_for_prompt();
// Clone is required because of the loop
let turn_input = history.clone().for_prompt();
let turn_input_len = turn_input.len();
let prompt = Prompt {
input: turn_input.clone(),
input: turn_input,
..Default::default()
};
let attempt_result = drain_to_completed(&sess, turn_context.as_ref(), &prompt).await;
@@ -119,7 +121,7 @@ async fn run_compact_task_inner(
return;
}
Err(e @ CodexErr::ContextWindowExceeded) => {
if turn_input.len() > 1 {
if turn_input_len > 1 {
// Trim from the beginning to preserve cache (prefix-based) and keep recent messages intact.
error!(
"Context window exceeded while compacting; removing oldest history item. Error: {e}"
@@ -155,15 +157,15 @@ async fn run_compact_task_inner(
}
}

let history_snapshot = sess.clone_history().await.get_history();
let summary_suffix =
get_last_assistant_message_from_turn(&history_snapshot).unwrap_or_default();
let history_snapshot = sess.clone_history().await;
let history_items = history_snapshot.raw_items();
let summary_suffix = get_last_assistant_message_from_turn(history_items).unwrap_or_default();
let summary_text = format!("{SUMMARY_PREFIX}\n{summary_suffix}");
let user_messages = collect_user_messages(&history_snapshot);
let user_messages = collect_user_messages(history_items);

let initial_context = sess.build_initial_context(turn_context.as_ref());
let mut new_history = build_compacted_history(initial_context, &user_messages, &summary_text);
let ghost_snapshots: Vec<ResponseItem> = history_snapshot
let ghost_snapshots: Vec<ResponseItem> = history_items
.iter()
.filter(|item| matches!(item, ResponseItem::GhostSnapshot { .. }))
.cloned()
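Because `for_prompt()` now consumes the manager, the compaction loop clones before each attempt and keeps only the input length around for the overflow check. A rough sketch of that shape, with the types reduced to a vector (`remove_first_item` and the error variant here mirror the hunk but are simplified):

```rust
// Simplified stand-ins for ContextManager and CodexErr.
#[derive(Clone)]
struct History {
    items: Vec<String>,
}

impl History {
    // Consumes self, mirroring the new `for_prompt()` signature.
    fn for_prompt(self) -> Vec<String> {
        self.items
    }

    fn remove_first_item(&mut self) {
        if !self.items.is_empty() {
            self.items.remove(0);
        }
    }
}

enum CompactError {
    ContextWindowExceeded,
}

fn compact(mut history: History) -> Result<Vec<String>, CompactError> {
    loop {
        // Clone is required because `for_prompt()` takes ownership
        // and the loop may need the history again on overflow.
        let turn_input = history.clone().for_prompt();
        let turn_input_len = turn_input.len();
        match attempt_compaction(turn_input) {
            Ok(summary) => return Ok(summary),
            Err(CompactError::ContextWindowExceeded) if turn_input_len > 1 => {
                // Trim from the front to preserve the cached prefix
                // and keep the most recent messages intact.
                history.remove_first_item();
            }
            Err(e) => return Err(e),
        }
    }
}

fn attempt_compaction(input: Vec<String>) -> Result<Vec<String>, CompactError> {
    // Placeholder for the model round-trip.
    if input.len() > 8 {
        Err(CompactError::ContextWindowExceeded)
    } else {
        Ok(input)
    }
}

fn main() {
    let history = History {
        items: (0..12).map(|i| format!("msg {i}")).collect(),
    };
    let compacted = compact(history).expect("eventually fits");
    assert_eq!(compacted.len(), 8);
}
```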
@@ -40,9 +40,18 @@ async fn run_remote_compact_task_inner_impl(
sess: &Arc<Session>,
turn_context: &Arc<TurnContext>,
) -> CodexResult<()> {
let mut history = sess.clone_history().await;
let history = sess.clone_history().await;

// Required to keep `/undo` available after compaction
let ghost_snapshots: Vec<ResponseItem> = history
.raw_items()
.iter()
.filter(|item| matches!(item, ResponseItem::GhostSnapshot { .. }))
.cloned()
.collect();

let prompt = Prompt {
input: history.get_history_for_prompt(),
input: history.for_prompt(),
tools: vec![],
parallel_tool_calls: false,
base_instructions_override: turn_context.base_instructions.clone(),
@@ -53,13 +62,6 @@ async fn run_remote_compact_task_inner_impl(
.client
.compact_conversation_history(&prompt)
.await?;
// Required to keep `/undo` available after compaction
let ghost_snapshots: Vec<ResponseItem> = history
.get_history()
.iter()
.filter(|item| matches!(item, ResponseItem::GhostSnapshot { .. }))
.cloned()
.collect();

if !ghost_snapshots.is_empty() {
new_history.extend(ghost_snapshots);
@@ -1,25 +1,26 @@
use std::fmt;
use std::sync::Arc;

use crate::config_loader::RequirementSource;
use thiserror::Error;

#[derive(Debug, Error, PartialEq, Eq)]
pub enum ConstraintError {
#[error("value `{candidate}` is not in the allowed set {allowed}")]
InvalidValue { candidate: String, allowed: String },
#[error(
"invalid value for `{field_name}`: `{candidate}` is not in the allowed set {allowed} (set by {requirement_source})"
)]
InvalidValue {
field_name: &'static str,
candidate: String,
allowed: String,
requirement_source: RequirementSource,
},

#[error("field `{field_name}` cannot be empty")]
EmptyField { field_name: String },
}

impl ConstraintError {
pub fn invalid_value(candidate: impl Into<String>, allowed: impl Into<String>) -> Self {
Self::InvalidValue {
candidate: candidate.into(),
allowed: allowed.into(),
}
}

pub fn empty_field(field_name: impl Into<String>) -> Self {
Self::EmptyField {
field_name: field_name.into(),
@@ -63,24 +64,6 @@ impl<T: Send + Sync> Constrained<T> {
}
}

pub fn allow_only(value: T) -> Self
where
T: PartialEq + Send + Sync + fmt::Debug + Clone + 'static,
{
#[expect(clippy::expect_used)]
Self::new(value.clone(), move |candidate| {
if *candidate == value {
Ok(())
} else {
Err(ConstraintError::invalid_value(
format!("{candidate:?}"),
format!("{value:?}"),
))
}
})
.expect("initial value should always be valid")
}

/// Allow any value of T, using T's Default as the initial value.
pub fn allow_any_from_default() -> Self
where
@@ -89,22 +72,6 @@ impl<T: Send + Sync> Constrained<T> {
Self::allow_any(T::default())
}

pub fn allow_values(initial_value: T, allowed: Vec<T>) -> ConstraintResult<Self>
where
T: PartialEq + Send + Sync + fmt::Debug + 'static,
{
Self::new(initial_value, move |candidate| {
if allowed.contains(candidate) {
Ok(())
} else {
Err(ConstraintError::invalid_value(
format!("{candidate:?}"),
format!("{allowed:?}"),
))
}
})
}

pub fn get(&self) -> &T {
&self.value
}
@@ -154,6 +121,15 @@ mod tests {
use super::*;
use pretty_assertions::assert_eq;

fn invalid_value(candidate: impl Into<String>, allowed: impl Into<String>) -> ConstraintError {
ConstraintError::InvalidValue {
field_name: "<unknown>",
candidate: candidate.into(),
allowed: allowed.into(),
requirement_source: RequirementSource::Unknown,
}
}

#[test]
fn constrained_allow_any_accepts_any_value() {
let mut constrained = Constrained::allow_any(5);
@@ -173,17 +149,11 @@ mod tests {
if *value > 0 {
Ok(())
} else {
Err(ConstraintError::invalid_value(
value.to_string(),
"positive values",
))
Err(invalid_value(value.to_string(), "positive values"))
}
});

assert_eq!(
result,
Err(ConstraintError::invalid_value("0", "positive values"))
);
assert_eq!(result, Err(invalid_value("0", "positive values")));
}

#[test]
@@ -192,10 +162,7 @@ mod tests {
if *value > 0 {
Ok(())
} else {
Err(ConstraintError::invalid_value(
value.to_string(),
"positive values",
))
Err(invalid_value(value.to_string(), "positive values"))
}
})
.expect("initial value should be accepted");
@@ -203,7 +170,7 @@ mod tests {
let err = constrained
.set(-5)
.expect_err("negative values should be rejected");
assert_eq!(err, ConstraintError::invalid_value("-5", "positive values"));
assert_eq!(err, invalid_value("-5", "positive values"));
assert_eq!(constrained.value(), 1);
}

@@ -213,10 +180,7 @@ mod tests {
if *value > 0 {
Ok(())
} else {
Err(ConstraintError::invalid_value(
value.to_string(),
"positive values",
))
Err(invalid_value(value.to_string(), "positive values"))
}
})
.expect("initial value should be accepted");
@@ -227,7 +191,7 @@ mod tests {
let err = constrained
.can_set(&-1)
.expect_err("can_set should reject negative value");
assert_eq!(err, ConstraintError::invalid_value("-1", "positive values"));
assert_eq!(err, invalid_value("-1", "positive values"));
assert_eq!(constrained.value(), 1);
}
}
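With `allow_only` and `allow_values` removed, call sites build their validator through `Constrained::new` directly and report failures via the richer `InvalidValue` variant. A minimal sketch of that pattern using trimmed-down mirrors of the types above (the real ones also carry a `requirement_source`, omitted here for brevity):

```rust
// Trimmed-down mirrors; only what the example needs.
#[derive(Debug, PartialEq)]
enum ConstraintError {
    InvalidValue {
        field_name: &'static str,
        candidate: String,
        allowed: String,
    },
}

struct Constrained<T> {
    value: T,
    validator: Box<dyn Fn(&T) -> Result<(), ConstraintError>>,
}

impl<T> Constrained<T> {
    fn new(
        value: T,
        validator: impl Fn(&T) -> Result<(), ConstraintError> + 'static,
    ) -> Result<Self, ConstraintError> {
        // The initial value must itself satisfy the constraint.
        validator(&value)?;
        Ok(Self { value, validator: Box::new(validator) })
    }

    fn set(&mut self, candidate: T) -> Result<(), ConstraintError> {
        (self.validator)(&candidate)?;
        self.value = candidate;
        Ok(())
    }
}

fn main() -> Result<(), ConstraintError> {
    let allowed = vec!["on-request".to_string(), "never".to_string()];
    let allowed_list = format!("{allowed:?}");
    let mut policy = Constrained::new("on-request".to_string(), move |candidate: &String| {
        if allowed.contains(candidate) {
            Ok(())
        } else {
            Err(ConstraintError::InvalidValue {
                field_name: "approval_policy",
                candidate: candidate.clone(),
                allowed: allowed_list.clone(),
            })
        }
    })?;

    // Valid transitions succeed; anything outside the allow-list is
    // rejected with the offending field named in the error.
    assert!(policy.set("never".to_string()).is_ok());
    assert!(policy.set("on-failure".to_string()).is_err());
    Ok(())
}
```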
@@ -1444,6 +1444,7 @@ impl Config {
environment,
exporter,
trace_exporter,
metrics_exporter: OtelExporterKind::Statsig,
}
},
};
@@ -1503,6 +1504,15 @@ impl Config {
}
self.forced_auto_mode_downgraded_on_windows = !value;
}

pub fn set_windows_elevated_sandbox_globally(&mut self, value: bool) {
crate::safety::set_windows_elevated_sandbox_enabled(value);
if value {
self.features.enable(Feature::WindowsSandboxElevated);
} else {
self.features.disable(Feature::WindowsSandboxElevated);
}
}
}

fn default_review_model() -> String {
@@ -4,6 +4,7 @@ use crate::config::edit::ConfigEdit;
use crate::config::edit::ConfigEditsBuilder;
use crate::config_loader::ConfigLayerEntry;
use crate::config_loader::ConfigLayerStack;
use crate::config_loader::ConfigRequirementsToml;
use crate::config_loader::LoaderOverrides;
use crate::config_loader::load_config_layers_state;
use crate::config_loader::merge_toml_values;
@@ -157,6 +158,22 @@ impl ConfigService {
})
}

pub async fn read_requirements(
&self,
) -> Result<Option<ConfigRequirementsToml>, ConfigServiceError> {
let layers = self
.load_thread_agnostic_config()
.await
.map_err(|err| ConfigServiceError::io("failed to read configuration layers", err))?;

let requirements = layers.requirements_toml().clone();
if requirements.is_empty() {
Ok(None)
} else {
Ok(Some(requirements))
}
}

pub async fn write_value(
&self,
params: ConfigValueWriteParams,
@@ -306,6 +306,7 @@ pub struct OtelTlsConfig {
#[serde(rename_all = "kebab-case")]
pub enum OtelExporterKind {
None,
Statsig,
OtlpHttp {
endpoint: String,
#[serde(default)]
@@ -346,6 +347,7 @@ pub struct OtelConfig {
pub environment: String,
pub exporter: OtelExporterKind,
pub trace_exporter: OtelExporterKind,
pub metrics_exporter: OtelExporterKind,
}

impl Default for OtelConfig {
@@ -355,6 +357,7 @@ impl Default for OtelConfig {
environment: DEFAULT_OTEL_ENVIRONMENT.to_owned(),
exporter: OtelExporterKind::None,
trace_exporter: OtelExporterKind::None,
metrics_exporter: OtelExporterKind::Statsig,
}
}
}
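Since the enum is tagged `rename_all = "kebab-case"`, the new `metrics_exporter` field accepts the same kebab-case strings as the existing exporters. A small, hedged round-trip sketch with a reduced mirror of the enum (field-key casing on the surrounding struct is an assumption here):

```rust
use serde::Deserialize;

// Reduced mirror of OtelExporterKind; variants deserialize from
// kebab-case strings ("none", "statsig").
#[derive(Debug, Deserialize, PartialEq)]
#[serde(rename_all = "kebab-case")]
enum ExporterKind {
    None,
    Statsig,
}

#[derive(Debug, Deserialize, PartialEq)]
struct Otel {
    exporter: ExporterKind,
    trace_exporter: ExporterKind,
    metrics_exporter: ExporterKind,
}

fn main() {
    let parsed: Otel = toml::from_str(
        r#"
        exporter = "none"
        trace_exporter = "none"
        metrics_exporter = "statsig"
        "#,
    )
    .expect("valid otel table");
    // Mirrors the new default: metrics go to Statsig even when the
    // span/trace exporters are disabled.
    assert_eq!(parsed.metrics_exporter, ExporterKind::Statsig);
}
```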
@@ -1,11 +1,42 @@
use codex_protocol::config_types::SandboxMode;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::SandboxPolicy;
use codex_utils_absolute_path::AbsolutePathBuf;
use serde::Deserialize;
use std::fmt;

use crate::config::Constrained;
use crate::config::ConstraintError;

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RequirementSource {
Unknown,
MdmManagedPreferences { domain: String, key: String },
SystemRequirementsToml { file: AbsolutePathBuf },
LegacyManagedConfigTomlFromFile { file: AbsolutePathBuf },
LegacyManagedConfigTomlFromMdm,
}

impl fmt::Display for RequirementSource {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
RequirementSource::Unknown => write!(f, "<unspecified>"),
RequirementSource::MdmManagedPreferences { domain, key } => {
write!(f, "MDM {domain}:{key}")
}
RequirementSource::SystemRequirementsToml { file } => {
write!(f, "{}", file.as_path().display())
}
RequirementSource::LegacyManagedConfigTomlFromFile { file } => {
write!(f, "{}", file.as_path().display())
}
RequirementSource::LegacyManagedConfigTomlFromMdm => {
write!(f, "MDM managed_config.toml (legacy)")
}
}
}
}

/// Normalized version of [`ConfigRequirementsToml`] after deserialization and
/// normalization.
#[derive(Debug, Clone, PartialEq)]
@@ -30,6 +61,75 @@ pub struct ConfigRequirementsToml {
pub allowed_sandbox_modes: Option<Vec<SandboxModeRequirement>>,
}

/// Value paired with the requirement source it came from, for better error
/// messages.
#[derive(Debug, Clone, PartialEq)]
pub struct Sourced<T> {
pub value: T,
pub source: RequirementSource,
}

impl<T> Sourced<T> {
pub fn new(value: T, source: RequirementSource) -> Self {
Self { value, source }
}
}

impl<T> std::ops::Deref for Sourced<T> {
type Target = T;

fn deref(&self) -> &Self::Target {
&self.value
}
}

#[derive(Debug, Clone, Default, PartialEq)]
pub struct ConfigRequirementsWithSources {
pub allowed_approval_policies: Option<Sourced<Vec<AskForApproval>>>,
pub allowed_sandbox_modes: Option<Sourced<Vec<SandboxModeRequirement>>>,
}

impl ConfigRequirementsWithSources {
pub fn merge_unset_fields(&mut self, source: RequirementSource, other: ConfigRequirementsToml) {
// For every field in `other` that is `Some`, if the corresponding field
// in `self` is `None`, copy the value from `other` into `self`.
macro_rules! fill_missing_take {
($base:expr, $other:expr, $source:expr, { $($field:ident),+ $(,)? }) => {
// Destructure without `..` so adding fields to `ConfigRequirementsToml`
// forces this merge logic to be updated.
let ConfigRequirementsToml { $($field: _,)+ } = &$other;

$(
if $base.$field.is_none()
&& let Some(value) = $other.$field.take()
{
$base.$field = Some(Sourced::new(value, $source.clone()));
}
)+
};
}

let mut other = other;
fill_missing_take!(
self,
other,
source,
{ allowed_approval_policies, allowed_sandbox_modes }
);
}

pub fn into_toml(self) -> ConfigRequirementsToml {
let ConfigRequirementsWithSources {
allowed_approval_policies,
allowed_sandbox_modes,
} = self;
ConfigRequirementsToml {
allowed_approval_policies: allowed_approval_policies.map(|sourced| sourced.value),
allowed_sandbox_modes: allowed_sandbox_modes.map(|sourced| sourced.value),
}
}
}

/// Currently, `external-sandbox` is not supported in config.toml, but it is
/// supported through programmatic use.
#[derive(Deserialize, Debug, Clone, Copy, PartialEq)]
@@ -58,40 +158,41 @@ impl From<SandboxMode> for SandboxModeRequirement {
}

impl ConfigRequirementsToml {
/// For every field in `other` that is `Some`, if the corresponding field in
/// `self` is `None`, copy the value from `other` into `self`.
pub fn merge_unset_fields(&mut self, mut other: ConfigRequirementsToml) {
macro_rules! fill_missing_take {
($base:expr, $other:expr, { $($field:ident),+ $(,)? }) => {
$(
if $base.$field.is_none() {
if let Some(value) = $other.$field.take() {
$base.$field = Some(value);
}
}
)+
};
}

fill_missing_take!(self, other, { allowed_approval_policies, allowed_sandbox_modes });
pub fn is_empty(&self) -> bool {
self.allowed_approval_policies.is_none() && self.allowed_sandbox_modes.is_none()
}
}

impl TryFrom<ConfigRequirementsToml> for ConfigRequirements {
impl TryFrom<ConfigRequirementsWithSources> for ConfigRequirements {
type Error = ConstraintError;

fn try_from(toml: ConfigRequirementsToml) -> Result<Self, Self::Error> {
let ConfigRequirementsToml {
fn try_from(toml: ConfigRequirementsWithSources) -> Result<Self, Self::Error> {
let ConfigRequirementsWithSources {
allowed_approval_policies,
allowed_sandbox_modes,
} = toml;

let approval_policy: Constrained<AskForApproval> = match allowed_approval_policies {
Some(policies) => {
if let Some(first) = policies.first() {
Constrained::allow_values(*first, policies)?
} else {
Some(Sourced {
value: policies,
source: requirement_source,
}) => {
let Some(initial_value) = policies.first().copied() else {
return Err(ConstraintError::empty_field("allowed_approval_policies"));
}
};

Constrained::new(initial_value, move |candidate| {
if policies.contains(candidate) {
Ok(())
} else {
Err(ConstraintError::InvalidValue {
field_name: "approval_policy",
candidate: format!("{candidate:?}"),
allowed: format!("{policies:?}"),
requirement_source: requirement_source.clone(),
})
}
})?
}
None => Constrained::allow_any_from_default(),
};
@@ -105,12 +206,17 @@ impl TryFrom<ConfigRequirementsToml> for ConfigRequirements {
// format to allow specifying those parameters.
let default_sandbox_policy = SandboxPolicy::ReadOnly;
let sandbox_policy: Constrained<SandboxPolicy> = match allowed_sandbox_modes {
Some(modes) => {
Some(Sourced {
value: modes,
source: requirement_source,
}) => {
if !modes.contains(&SandboxModeRequirement::ReadOnly) {
return Err(ConstraintError::invalid_value(
"allowed_sandbox_modes",
"must include 'read-only' to allow any SandboxPolicy",
));
return Err(ConstraintError::InvalidValue {
field_name: "allowed_sandbox_modes",
candidate: format!("{modes:?}"),
allowed: "must include 'read-only' to allow any SandboxPolicy".to_string(),
requirement_source,
});
};

Constrained::new(default_sandbox_policy, move |candidate| {
@@ -127,10 +233,12 @@ impl TryFrom<ConfigRequirementsToml> for ConfigRequirements {
if modes.contains(&mode) {
Ok(())
} else {
Err(ConstraintError::invalid_value(
format!("{candidate:?}"),
format!("{modes:?}"),
))
Err(ConstraintError::InvalidValue {
field_name: "sandbox_mode",
candidate: format!("{mode:?}"),
allowed: format!("{modes:?}"),
requirement_source: requirement_source.clone(),
})
}
})?
}
@@ -152,45 +260,168 @@ mod tests {
use pretty_assertions::assert_eq;
use toml::from_str;

fn with_unknown_source(toml: ConfigRequirementsToml) -> ConfigRequirementsWithSources {
let ConfigRequirementsToml {
allowed_approval_policies,
allowed_sandbox_modes,
} = toml;
ConfigRequirementsWithSources {
allowed_approval_policies: allowed_approval_policies
.map(|value| Sourced::new(value, RequirementSource::Unknown)),
allowed_sandbox_modes: allowed_sandbox_modes
.map(|value| Sourced::new(value, RequirementSource::Unknown)),
}
}

#[test]
fn merge_unset_fields_only_fills_missing_values() -> Result<()> {
fn merge_unset_fields_copies_every_field_and_sets_sources() {
let mut target = ConfigRequirementsWithSources::default();
let source = RequirementSource::LegacyManagedConfigTomlFromMdm;

let allowed_approval_policies = vec![AskForApproval::UnlessTrusted, AskForApproval::Never];
let allowed_sandbox_modes = vec![
SandboxModeRequirement::WorkspaceWrite,
SandboxModeRequirement::DangerFullAccess,
];

// Intentionally constructed without `..Default::default()` so adding a new field to
// `ConfigRequirementsToml` forces this test to be updated.
let other = ConfigRequirementsToml {
allowed_approval_policies: Some(allowed_approval_policies.clone()),
allowed_sandbox_modes: Some(allowed_sandbox_modes.clone()),
};

target.merge_unset_fields(source.clone(), other);

assert_eq!(
target,
ConfigRequirementsWithSources {
allowed_approval_policies: Some(Sourced::new(
allowed_approval_policies,
source.clone()
)),
allowed_sandbox_modes: Some(Sourced::new(allowed_sandbox_modes, source)),
}
);
}

#[test]
fn merge_unset_fields_fills_missing_values() -> Result<()> {
let source: ConfigRequirementsToml = from_str(
r#"
allowed_approval_policies = ["on-request"]
"#,
)?;

let mut empty_target: ConfigRequirementsToml = from_str(
r#"
# intentionally left unset
"#,
)?;
empty_target.merge_unset_fields(source.clone());
assert_eq!(
empty_target.allowed_approval_policies,
Some(vec![AskForApproval::OnRequest])
);
let source_location = RequirementSource::MdmManagedPreferences {
domain: "com.codex".to_string(),
key: "allowed_approval_policies".to_string(),
};

let mut populated_target: ConfigRequirementsToml = from_str(
let mut empty_target = ConfigRequirementsWithSources::default();
empty_target.merge_unset_fields(source_location.clone(), source);
assert_eq!(
empty_target,
ConfigRequirementsWithSources {
allowed_approval_policies: Some(Sourced::new(
vec![AskForApproval::OnRequest],
source_location,
)),
allowed_sandbox_modes: None,
}
);
Ok(())
}

#[test]
fn merge_unset_fields_does_not_overwrite_existing_values() -> Result<()> {
let existing_source = RequirementSource::LegacyManagedConfigTomlFromMdm;
let mut populated_target = ConfigRequirementsWithSources::default();
let populated_requirements: ConfigRequirementsToml = from_str(
r#"
allowed_approval_policies = ["never"]
"#,
)?;
populated_target.merge_unset_fields(source);
populated_target.merge_unset_fields(existing_source.clone(), populated_requirements);

let source: ConfigRequirementsToml = from_str(
r#"
allowed_approval_policies = ["on-request"]
"#,
)?;
let source_location = RequirementSource::MdmManagedPreferences {
domain: "com.codex".to_string(),
key: "allowed_approval_policies".to_string(),
};
populated_target.merge_unset_fields(source_location, source);

assert_eq!(
populated_target.allowed_approval_policies,
Some(vec![AskForApproval::Never])
populated_target,
ConfigRequirementsWithSources {
allowed_approval_policies: Some(Sourced::new(
vec![AskForApproval::Never],
existing_source,
)),
allowed_sandbox_modes: None,
}
);
Ok(())
}

#[test]
fn constraint_error_includes_requirement_source() -> Result<()> {
let source: ConfigRequirementsToml = from_str(
r#"
allowed_approval_policies = ["on-request"]
allowed_sandbox_modes = ["read-only"]
"#,
)?;

let requirements_toml_file = if cfg!(windows) {
"C:\\etc\\codex\\requirements.toml"
} else {
"/etc/codex/requirements.toml"
};
let requirements_toml_file = AbsolutePathBuf::from_absolute_path(requirements_toml_file)?;
let source_location = RequirementSource::SystemRequirementsToml {
file: requirements_toml_file,
};

let mut target = ConfigRequirementsWithSources::default();
target.merge_unset_fields(source_location.clone(), source);
let requirements = ConfigRequirements::try_from(target)?;

assert_eq!(
requirements.approval_policy.can_set(&AskForApproval::Never),
Err(ConstraintError::InvalidValue {
field_name: "approval_policy",
candidate: "Never".into(),
allowed: "[OnRequest]".into(),
requirement_source: source_location.clone(),
})
);
assert_eq!(
requirements
.sandbox_policy
.can_set(&SandboxPolicy::DangerFullAccess),
Err(ConstraintError::InvalidValue {
field_name: "sandbox_mode",
candidate: "DangerFullAccess".into(),
allowed: "[ReadOnly]".into(),
requirement_source: source_location,
})
);

Ok(())
}

#[test]
fn deserialize_allowed_approval_policies() -> Result<()> {
let toml_str = r#"
allowed_approval_policies = ["untrusted", "on-request"]
"#;
let config: ConfigRequirementsToml = from_str(toml_str)?;
let requirements = ConfigRequirements::try_from(config)?;
let requirements: ConfigRequirements = with_unknown_source(config).try_into()?;

assert_eq!(
requirements.approval_policy.value(),
@@ -208,8 +439,10 @@ mod tests {
.approval_policy
.can_set(&AskForApproval::OnFailure),
Err(ConstraintError::InvalidValue {
field_name: "approval_policy",
candidate: "OnFailure".into(),
allowed: "[UnlessTrusted, OnRequest]".into(),
requirement_source: RequirementSource::Unknown,
})
);
assert!(
@@ -221,8 +454,10 @@ mod tests {
assert_eq!(
requirements.approval_policy.can_set(&AskForApproval::Never),
Err(ConstraintError::InvalidValue {
field_name: "approval_policy",
candidate: "Never".into(),
allowed: "[UnlessTrusted, OnRequest]".into(),
requirement_source: RequirementSource::Unknown,
})
);
assert!(
@@ -241,7 +476,7 @@ mod tests {
allowed_sandbox_modes = ["read-only", "workspace-write"]
"#;
let config: ConfigRequirementsToml = from_str(toml_str)?;
let requirements = ConfigRequirements::try_from(config)?;
let requirements: ConfigRequirements = with_unknown_source(config).try_into()?;

let root = if cfg!(windows) { "C:\\repo" } else { "/repo" };
assert!(
@@ -266,8 +501,10 @@ mod tests {
.sandbox_policy
.can_set(&SandboxPolicy::DangerFullAccess),
Err(ConstraintError::InvalidValue {
field_name: "sandbox_mode",
candidate: "DangerFullAccess".into(),
allowed: "[ReadOnly, WorkspaceWrite]".into(),
requirement_source: RequirementSource::Unknown,
})
);
assert_eq!(
@@ -277,8 +514,10 @@ mod tests {
network_access: NetworkAccess::Restricted,
}),
Err(ConstraintError::InvalidValue {
candidate: "ExternalSandbox { network_access: Restricted }".into(),
field_name: "sandbox_mode",
candidate: "ExternalSandbox".into(),
allowed: "[ReadOnly, WorkspaceWrite]".into(),
requirement_source: RequirementSource::Unknown,
})
);
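The `Sourced<T>` wrapper plus `merge_unset_fields` give a simple first-writer-wins precedence across requirement sources. A compact illustration of that rule with stand-in types (the real merge iterates fields via the macro shown above, and carries approval policies rather than strings):

```rust
#[derive(Debug, Clone, PartialEq)]
struct Sourced<T> {
    value: T,
    source: String,
}

#[derive(Default)]
struct Requirements {
    allowed_policies: Option<Sourced<Vec<String>>>,
}

impl Requirements {
    // First writer wins: later sources only fill fields still unset.
    fn merge_unset_fields(&mut self, source: &str, incoming: Option<Vec<String>>) {
        if self.allowed_policies.is_none() {
            if let Some(value) = incoming {
                self.allowed_policies = Some(Sourced {
                    value,
                    source: source.to_string(),
                });
            }
        }
    }
}

fn main() {
    let mut reqs = Requirements::default();
    reqs.merge_unset_fields("mdm", Some(vec!["never".into()]));
    // A later layer cannot overwrite what MDM already pinned, and
    // errors can cite the winning source verbatim.
    reqs.merge_unset_fields("/etc/codex/requirements.toml", Some(vec!["on-request".into()]));
    assert_eq!(reqs.allowed_policies.unwrap().source, "mdm");
}
```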
@@ -1,4 +1,6 @@
use super::config_requirements::ConfigRequirementsToml;
use super::config_requirements::ConfigRequirementsWithSources;
use super::config_requirements::RequirementSource;
use base64::Engine;
use base64::prelude::BASE64_STANDARD;
use core_foundation::base::TCFType;
@@ -13,6 +15,13 @@ const MANAGED_PREFERENCES_APPLICATION_ID: &str = "com.openai.codex";
const MANAGED_PREFERENCES_CONFIG_KEY: &str = "config_toml_base64";
const MANAGED_PREFERENCES_REQUIREMENTS_KEY: &str = "requirements_toml_base64";

pub(super) fn managed_preferences_requirements_source() -> RequirementSource {
RequirementSource::MdmManagedPreferences {
domain: MANAGED_PREFERENCES_APPLICATION_ID.to_string(),
key: MANAGED_PREFERENCES_REQUIREMENTS_KEY.to_string(),
}
}

pub(crate) async fn load_managed_admin_config_layer(
override_base64: Option<&str>,
) -> io::Result<Option<TomlValue>> {
@@ -47,21 +56,26 @@ fn load_managed_admin_config() -> io::Result<Option<TomlValue>> {
}

pub(crate) async fn load_managed_admin_requirements_toml(
target: &mut ConfigRequirementsToml,
target: &mut ConfigRequirementsWithSources,
override_base64: Option<&str>,
) -> io::Result<()> {
if let Some(encoded) = override_base64 {
let trimmed = encoded.trim();
if !trimmed.is_empty() {
target.merge_unset_fields(parse_managed_requirements_base64(trimmed)?);
if trimmed.is_empty() {
return Ok(());
}

target.merge_unset_fields(
managed_preferences_requirements_source(),
parse_managed_requirements_base64(trimmed)?,
);
return Ok(());
}

match task::spawn_blocking(load_managed_admin_requirements).await {
Ok(result) => {
if let Some(requirements) = result? {
target.merge_unset_fields(requirements);
target.merge_unset_fields(managed_preferences_requirements_source(), requirements);
}
Ok(())
}
@@ -12,7 +12,7 @@ mod tests;

use crate::config::CONFIG_TOML_FILE;
use crate::config::ConfigToml;
use crate::config_loader::config_requirements::ConfigRequirementsToml;
use crate::config_loader::config_requirements::ConfigRequirementsWithSources;
use crate::config_loader::layer_io::LoadedConfigLayers;
use codex_app_server_protocol::ConfigLayerSource;
use codex_protocol::config_types::SandboxMode;
@@ -25,6 +25,9 @@ use std::path::Path;
use toml::Value as TomlValue;

pub use config_requirements::ConfigRequirements;
pub use config_requirements::ConfigRequirementsToml;
pub use config_requirements::RequirementSource;
pub use config_requirements::SandboxModeRequirement;
pub use merge::merge_toml_values;
pub use state::ConfigLayerEntry;
pub use state::ConfigLayerStack;
@@ -76,7 +79,7 @@ pub async fn load_config_layers_state(
cli_overrides: &[(String, TomlValue)],
overrides: LoaderOverrides,
) -> io::Result<ConfigLayerStack> {
let mut config_requirements_toml = ConfigRequirementsToml::default();
let mut config_requirements_toml = ConfigRequirementsWithSources::default();

#[cfg(target_os = "macos")]
macos::load_managed_admin_requirements_toml(
@@ -201,7 +204,11 @@ pub async fn load_config_layers_state(
));
}

ConfigLayerStack::new(layers, config_requirements_toml.try_into()?)
ConfigLayerStack::new(
layers,
config_requirements_toml.clone().try_into()?,
config_requirements_toml.into_toml(),
)
}

/// Attempts to load a config.toml file from `config_toml`.
@@ -253,9 +260,11 @@ async fn load_config_toml_for_required_layer(
/// If available, apply requirements from `/etc/codex/requirements.toml` to
/// `config_requirements_toml` by filling in any unset fields.
async fn load_requirements_toml(
config_requirements_toml: &mut ConfigRequirementsToml,
config_requirements_toml: &mut ConfigRequirementsWithSources,
requirements_toml_file: impl AsRef<Path>,
) -> io::Result<()> {
let requirements_toml_file =
AbsolutePathBuf::from_absolute_path(requirements_toml_file.as_ref())?;
match tokio::fs::read_to_string(&requirements_toml_file).await {
Ok(contents) => {
let requirements_config: ConfigRequirementsToml =
@@ -268,7 +277,12 @@ async fn load_requirements_toml(
),
)
})?;
config_requirements_toml.merge_unset_fields(requirements_config);
config_requirements_toml.merge_unset_fields(
RequirementSource::SystemRequirementsToml {
file: requirements_toml_file.clone(),
},
requirements_config,
);
}
Err(e) => {
if e.kind() != io::ErrorKind::NotFound {
@@ -287,7 +301,7 @@ async fn load_requirements_toml(
}

async fn load_requirements_from_legacy_scheme(
config_requirements_toml: &mut ConfigRequirementsToml,
config_requirements_toml: &mut ConfigRequirementsWithSources,
loaded_config_layers: LoadedConfigLayers,
) -> io::Result<()> {
// In this implementation, earlier layers cannot be overwritten by later
@@ -297,12 +311,16 @@ async fn load_requirements_from_legacy_scheme(
managed_config,
managed_config_from_mdm,
} = loaded_config_layers;
for config in [
managed_config_from_mdm,
managed_config.map(|c| c.managed_config),
]
.into_iter()
.flatten()

for (source, config) in managed_config_from_mdm
.map(|config| (RequirementSource::LegacyManagedConfigTomlFromMdm, config))
.into_iter()
.chain(managed_config.map(|c| {
(
RequirementSource::LegacyManagedConfigTomlFromFile { file: c.file },
c.managed_config,
)
}))
{
let legacy_config: LegacyManagedConfigToml =
config.try_into().map_err(|err: toml::de::Error| {
@@ -313,7 +331,7 @@ async fn load_requirements_from_legacy_scheme(
})?;

let new_requirements_toml = ConfigRequirementsToml::from(legacy_config);
config_requirements_toml.merge_unset_fields(new_requirements_toml);
config_requirements_toml.merge_unset_fields(source, new_requirements_toml);
}

Ok(())
@@ -556,7 +574,14 @@ impl From<LegacyManagedConfigToml> for ConfigRequirementsToml {
config_requirements_toml.allowed_approval_policies = Some(vec![approval_policy]);
}
if let Some(sandbox_mode) = sandbox_mode {
config_requirements_toml.allowed_sandbox_modes = Some(vec![sandbox_mode.into()]);
let required_mode: SandboxModeRequirement = sandbox_mode.into();
// Allowing read-only is a requirement for Codex to function correctly.
// So in this backfill path, we append read-only if it's not already specified.
let mut allowed_modes = vec![SandboxModeRequirement::ReadOnly];
if required_mode != SandboxModeRequirement::ReadOnly {
allowed_modes.push(required_mode);
}
config_requirements_toml.allowed_sandbox_modes = Some(allowed_modes);
}
config_requirements_toml
}
@@ -604,4 +629,22 @@ foo = "xyzzy"
assert_eq!(normalized_toml_value, TomlValue::Table(expected_toml_value));
Ok(())
}

#[test]
fn legacy_managed_config_backfill_includes_read_only_sandbox_mode() {
let legacy = LegacyManagedConfigToml {
approval_policy: None,
sandbox_mode: Some(SandboxMode::WorkspaceWrite),
};

let requirements = ConfigRequirementsToml::from(legacy);

assert_eq!(
requirements.allowed_sandbox_modes,
Some(vec![
SandboxModeRequirement::ReadOnly,
SandboxModeRequirement::WorkspaceWrite
])
);
}
}
@@ -1,4 +1,5 @@
use crate::config_loader::ConfigRequirements;
use crate::config_loader::ConfigRequirementsToml;

use super::fingerprint::record_origins;
use super::fingerprint::version_for_toml;
@@ -86,18 +87,25 @@ pub struct ConfigLayerStack {
/// Constraints that must be enforced when deriving a [Config] from the
/// layers.
requirements: ConfigRequirements,

/// Raw requirements data as loaded from requirements.toml/MDM/legacy
/// sources. This preserves the original allow-lists so they can be
/// surfaced via APIs.
requirements_toml: ConfigRequirementsToml,
}

impl ConfigLayerStack {
pub fn new(
layers: Vec<ConfigLayerEntry>,
requirements: ConfigRequirements,
requirements_toml: ConfigRequirementsToml,
) -> std::io::Result<Self> {
let user_layer_index = verify_layer_ordering(&layers)?;
Ok(Self {
layers,
user_layer_index,
requirements,
requirements_toml,
})
}

@@ -111,6 +119,10 @@ impl ConfigLayerStack {
&self.requirements
}

pub fn requirements_toml(&self) -> &ConfigRequirementsToml {
&self.requirements_toml
}

/// Creates a new [ConfigLayerStack] using the specified values to inject a
/// "user layer" into the stack. If such a layer already exists, it is
/// replaced; otherwise, it is inserted into the stack at the appropriate
@@ -131,6 +143,7 @@ impl ConfigLayerStack {
layers,
user_layer_index: self.user_layer_index,
requirements: self.requirements.clone(),
requirements_toml: self.requirements_toml.clone(),
}
}
None => {
@@ -151,6 +164,7 @@ impl ConfigLayerStack {
layers,
user_layer_index: Some(user_layer_index),
requirements: self.requirements.clone(),
requirements_toml: self.requirements_toml.clone(),
}
}
}
@@ -5,7 +5,7 @@ use crate::config::ConfigBuilder;
use crate::config::ConfigOverrides;
use crate::config_loader::ConfigLayerEntry;
use crate::config_loader::ConfigRequirements;
use crate::config_loader::config_requirements::ConfigRequirementsToml;
use crate::config_loader::config_requirements::ConfigRequirementsWithSources;
use crate::config_loader::fingerprint::version_for_toml;
use crate::config_loader::load_requirements_toml;
use codex_protocol::protocol::AskForApproval;
@@ -315,11 +315,14 @@ allowed_approval_policies = ["never", "on-request"]
)
.await?;

let mut config_requirements_toml = ConfigRequirementsToml::default();
let mut config_requirements_toml = ConfigRequirementsWithSources::default();
load_requirements_toml(&mut config_requirements_toml, &requirements_file).await?;

assert_eq!(
config_requirements_toml.allowed_approval_policies,
config_requirements_toml
.allowed_approval_policies
.as_deref()
.cloned(),
Some(vec![AskForApproval::Never, AskForApproval::OnRequest])
);
@@ -67,17 +67,18 @@ impl ContextManager {
}
}

pub(crate) fn get_history(&mut self) -> Vec<ResponseItem> {
/// Returns the history prepared for sending to the model. This applies proper
/// normalization and drops unsuited items.
pub(crate) fn for_prompt(mut self) -> Vec<ResponseItem> {
self.normalize_history();
self.contents()
self.items
.retain(|item| !matches!(item, ResponseItem::GhostSnapshot { .. }));
self.items
}

// Returns the history prepared for sending to the model.
// With extra response items filtered out and GhostCommits removed.
pub(crate) fn get_history_for_prompt(&mut self) -> Vec<ResponseItem> {
let mut history = self.get_history();
Self::remove_ghost_snapshots(&mut history);
history
/// Returns raw items in the history.
pub(crate) fn raw_items(&self) -> &[ResponseItem] {
&self.items
}

// Estimate token usage using byte-based heuristics from the truncation helpers.
@@ -168,9 +169,7 @@ impl ContextManager {
return;
}

// Keep behavior consistent with call sites that previously operated on `get_history()`:
// normalize first (call/output invariants), then truncate based on the normalized view.
let snapshot = self.get_history();
let snapshot = self.items.clone();
let user_positions = user_message_positions(&snapshot);
let Some(&first_user_idx) = user_positions.first() else {
self.replace(snapshot);
@@ -250,15 +249,6 @@ impl ContextManager {
normalize::remove_orphan_outputs(&mut self.items);
}

/// Returns a clone of the contents in the transcript.
fn contents(&self) -> Vec<ResponseItem> {
self.items.clone()
}

fn remove_ghost_snapshots(items: &mut Vec<ResponseItem>) {
items.retain(|item| !matches!(item, ResponseItem::GhostSnapshot { .. }));
}

fn process_item(&self, item: &ResponseItem, policy: TruncationPolicy) -> ResponseItem {
let policy_with_serialization_budget = policy.mul(1.2);
match item {
@@ -332,7 +322,7 @@ fn is_session_prefix(text: &str) -> bool {
lowered.starts_with("<environment_context>")
}

fn is_user_turn_boundary(item: &ResponseItem) -> bool {
pub(crate) fn is_user_turn_boundary(item: &ResponseItem) -> bool {
let ResponseItem::Message { role, content, .. } = item else {
return false;
};
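The split between the borrowing `raw_items()` accessor and the consuming `for_prompt()` shows up at call sites as clone-then-consume. A small sketch of the intended usage, with the manager reduced to a plain vector (the real one also normalizes call/output pairs):

```rust
#[derive(Clone, Debug, PartialEq)]
enum Item {
    Message(String),
    GhostSnapshot,
}

#[derive(Clone, Default)]
struct ContextManager {
    items: Vec<Item>,
}

impl ContextManager {
    /// Borrow the raw transcript, ghost snapshots included.
    fn raw_items(&self) -> &[Item] {
        &self.items
    }

    /// Consume the manager and return the model-facing view
    /// with ghost snapshots dropped.
    fn for_prompt(mut self) -> Vec<Item> {
        self.items.retain(|i| !matches!(i, Item::GhostSnapshot));
        self.items
    }
}

fn main() {
    let history = ContextManager {
        items: vec![Item::Message("hi".into()), Item::GhostSnapshot],
    };
    // Inspect without consuming…
    assert_eq!(history.raw_items().len(), 2);
    // …then clone when both the raw view and the prompt view are needed.
    let prompt = history.clone().for_prompt();
    assert_eq!(prompt, vec![Item::Message("hi".into())]);
}
```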
@@ -101,7 +101,7 @@ fn filters_non_api_messages() {
let a = assistant_msg("hello");
h.record_items([&u, &a], policy);

let items = h.contents();
let items = h.raw_items();
assert_eq!(
items,
vec![
@@ -160,8 +160,8 @@ fn get_history_for_prompt_drops_ghost_commits() {
let items = vec![ResponseItem::GhostSnapshot {
ghost_commit: GhostCommit::new("ghost-1".to_string(), None, Vec::new(), Vec::new()),
}];
let mut history = create_history_with_items(items);
let filtered = history.get_history_for_prompt();
let history = create_history_with_items(items);
let filtered = history.for_prompt();
assert_eq!(filtered, vec![]);
}

@@ -184,7 +184,7 @@ fn remove_first_item_removes_matching_output_for_function_call() {
];
let mut h = create_history_with_items(items);
h.remove_first_item();
assert_eq!(h.contents(), vec![]);
assert_eq!(h.raw_items(), vec![]);
}

#[test]
@@ -206,7 +206,7 @@ fn remove_first_item_removes_matching_call_for_output() {
];
let mut h = create_history_with_items(items);
h.remove_first_item();
assert_eq!(h.contents(), vec![]);
assert_eq!(h.raw_items(), vec![]);
}

#[test]
@@ -234,7 +234,7 @@ fn remove_first_item_handles_local_shell_pair() {
];
let mut h = create_history_with_items(items);
h.remove_first_item();
assert_eq!(h.contents(), vec![]);
assert_eq!(h.raw_items(), vec![]);
}

#[test]
@@ -250,7 +250,7 @@ fn drop_last_n_user_turns_preserves_prefix() {
let mut history = create_history_with_items(items);
history.drop_last_n_user_turns(1);
assert_eq!(
history.get_history(),
history.for_prompt(),
vec![
assistant_msg("session prefix item"),
user_msg("u1"),
@@ -267,7 +267,7 @@ fn drop_last_n_user_turns_preserves_prefix() {
]);
history.drop_last_n_user_turns(99);
assert_eq!(
history.get_history(),
history.for_prompt(),
vec![assistant_msg("session prefix item")]
);
}
@@ -307,7 +307,7 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() {
assistant_msg("turn 1 assistant"),
];

assert_eq!(history.get_history(), expected_prefix_and_first_turn);
assert_eq!(history.for_prompt(), expected_prefix_and_first_turn);

let expected_prefix_only = vec![
user_input_text_msg("<environment_context>ctx</environment_context>"),
@@ -337,7 +337,7 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() {
assistant_msg("turn 2 assistant"),
]);
history.drop_last_n_user_turns(2);
assert_eq!(history.get_history(), expected_prefix_only);
assert_eq!(history.for_prompt(), expected_prefix_only);

let mut history = create_history_with_items(vec![
user_input_text_msg("<environment_context>ctx</environment_context>"),
@@ -355,7 +355,7 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() {
assistant_msg("turn 2 assistant"),
]);
history.drop_last_n_user_turns(3);
assert_eq!(history.get_history(), expected_prefix_only);
assert_eq!(history.for_prompt(), expected_prefix_only);
}

#[test]
@@ -375,7 +375,7 @@ fn remove_first_item_handles_custom_tool_pair() {
];
let mut h = create_history_with_items(items);
h.remove_first_item();
assert_eq!(h.contents(), vec![]);
assert_eq!(h.raw_items(), vec![]);
}

#[test]
@@ -402,8 +402,8 @@ fn normalization_retains_local_shell_outputs() {
},
];

let mut history = create_history_with_items(items.clone());
let normalized = history.get_history();
let history = create_history_with_items(items.clone());
let normalized = history.for_prompt();
assert_eq!(normalized, items);
}

@@ -607,7 +607,7 @@ fn normalize_adds_missing_output_for_function_call() {
h.normalize_history();

assert_eq!(
h.contents(),
h.raw_items(),
vec![
ResponseItem::FunctionCall {
id: None,
@@ -641,7 +641,7 @@ fn normalize_adds_missing_output_for_custom_tool_call() {
h.normalize_history();

assert_eq!(
h.contents(),
h.raw_items(),
vec![
ResponseItem::CustomToolCall {
id: None,
@@ -678,7 +678,7 @@ fn normalize_adds_missing_output_for_local_shell_call_with_id() {
h.normalize_history();

assert_eq!(
h.contents(),
h.raw_items(),
vec![
ResponseItem::LocalShellCall {
id: None,
@@ -717,7 +717,7 @@ fn normalize_removes_orphan_function_call_output() {

h.normalize_history();

assert_eq!(h.contents(), vec![]);
assert_eq!(h.raw_items(), vec![]);
}

#[cfg(not(debug_assertions))]
@@ -731,7 +731,7 @@ fn normalize_removes_orphan_custom_tool_call_output() {

h.normalize_history();

assert_eq!(h.contents(), vec![]);
assert_eq!(h.raw_items(), vec![]);
}

#[cfg(not(debug_assertions))]
@@ -780,7 +780,7 @@ fn normalize_mixed_inserts_and_removals() {
h.normalize_history();

assert_eq!(
h.contents(),
h.raw_items(),
vec![
ResponseItem::FunctionCall {
id: None,
@@ -840,7 +840,7 @@ fn normalize_adds_missing_output_for_function_call_inserts_output() {
let mut h = create_history_with_items(items);
h.normalize_history();
assert_eq!(
h.contents(),
h.raw_items(),
vec![
ResponseItem::FunctionCall {
id: None,
@@ -2,3 +2,4 @@ mod history;
mod normalize;

pub(crate) use history::ContextManager;
pub(crate) use history::is_user_turn_boundary;
@@ -22,7 +22,7 @@ use tokio::task::JoinError;
pub type Result<T> = std::result::Result<T, CodexErr>;

/// Limit UI error messages to a reasonable size while keeping useful context.
const ERROR_MESSAGE_UI_MAX_BYTES: usize = 2 * 1024; // 4 KiB
const ERROR_MESSAGE_UI_MAX_BYTES: usize = 2 * 1024; // 2 KiB

#[derive(Error, Debug)]
pub enum SandboxErr {
@@ -181,6 +181,43 @@ impl From<CancelErr> for CodexErr {
}
}

impl CodexErr {
pub fn is_retryable(&self) -> bool {
match self {
CodexErr::TurnAborted
| CodexErr::Interrupted
| CodexErr::EnvVar(_)
| CodexErr::Fatal(_)
| CodexErr::UsageNotIncluded
| CodexErr::QuotaExceeded
| CodexErr::InvalidImageRequest()
| CodexErr::InvalidRequest(_)
| CodexErr::RefreshTokenFailed(_)
| CodexErr::UnsupportedOperation(_)
| CodexErr::Sandbox(_)
| CodexErr::LandlockSandboxExecutableNotProvided
| CodexErr::RetryLimit(_)
| CodexErr::ContextWindowExceeded
| CodexErr::ThreadNotFound(_)
| CodexErr::Spawn
| CodexErr::SessionConfiguredNotFirstEvent
| CodexErr::UsageLimitReached(_) => false,
CodexErr::Stream(..)
| CodexErr::Timeout
| CodexErr::UnexpectedStatus(_)
| CodexErr::ResponseStreamFailed(_)
| CodexErr::ConnectionFailed(_)
| CodexErr::InternalServerError
| CodexErr::InternalAgentDied
| CodexErr::Io(_)
| CodexErr::Json(_)
| CodexErr::TokioJoin(_) => true,
#[cfg(target_os = "linux")]
CodexErr::LandlockRuleset(_) | CodexErr::LandlockPathFd(_) => false,
}
}
}

#[derive(Debug)]
pub struct ConnectionFailedError {
pub source: reqwest::Error,
@@ -29,9 +29,16 @@ where
"HOME", "LOGNAME", "PATH", "SHELL", "USER", "USERNAME", "TMPDIR", "TEMP", "TMP",
];
let allow: HashSet<&str> = CORE_VARS.iter().copied().collect();
vars.into_iter()
.filter(|(k, _)| allow.contains(k.as_str()))
.collect()
let is_core_var = |name: &str| {
if cfg!(target_os = "windows") {
CORE_VARS
.iter()
.any(|allowed| allowed.eq_ignore_ascii_case(name))
} else {
allow.contains(name)
}
};
vars.into_iter().filter(|(k, _)| is_core_var(k)).collect()
}
};

@@ -198,6 +205,30 @@ mod tests {
assert_eq!(result, expected);
}

#[test]
#[cfg(target_os = "windows")]
fn test_core_inherit_respects_case_insensitive_names_on_windows() {
let vars = make_vars(&[
("Path", "C:\\Windows\\System32"),
("TEMP", "C:\\Temp"),
("FOO", "bar"),
]);

let policy = ShellEnvironmentPolicy {
inherit: ShellEnvironmentPolicyInherit::Core,
ignore_default_excludes: true,
..Default::default()
};

let result = populate_env(vars, &policy);
let expected: HashMap<String, String> = hashmap! {
"Path".to_string() => "C:\\Windows\\System32".to_string(),
"TEMP".to_string() => "C:\\Temp".to_string(),
};

assert_eq!(result, expected);
}

#[test]
fn test_inherit_none() {
let vars = make_vars(&[("PATH", "/usr/bin"), ("HOME", "/home")]);
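The motivation for the hunk above: Windows treats environment variable names case-insensitively, so the core allow-list must match `Path` against `PATH`. A standalone sketch of just the predicate, with the platform switch lifted into a parameter for demonstration:

```rust
use std::collections::HashSet;

const CORE_VARS: [&str; 9] =
    ["HOME", "LOGNAME", "PATH", "SHELL", "USER", "USERNAME", "TMPDIR", "TEMP", "TMP"];

fn is_core_var(name: &str, windows: bool) -> bool {
    if windows {
        // Windows env names compare case-insensitively, so `Path`
        // must match the allow-listed `PATH`.
        CORE_VARS.iter().any(|allowed| allowed.eq_ignore_ascii_case(name))
    } else {
        // Elsewhere a plain set lookup is enough (built per call here
        // only to keep the sketch self-contained).
        let allow: HashSet<&str> = CORE_VARS.iter().copied().collect();
        allow.contains(name)
    }
}

fn main() {
    assert!(is_core_var("Path", true));
    assert!(!is_core_var("Path", false));
    assert!(is_core_var("PATH", false));
}
```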
@@ -421,6 +421,7 @@ mod tests {
use crate::config_loader::ConfigLayerEntry;
use crate::config_loader::ConfigLayerStack;
use crate::config_loader::ConfigRequirements;
use crate::config_loader::ConfigRequirementsToml;
use crate::features::Feature;
use crate::features::Features;
use codex_app_server_protocol::ConfigLayerSource;
@@ -441,7 +442,12 @@ mod tests {
ConfigLayerSource::Project { dot_codex_folder },
TomlValue::Table(Default::default()),
);
ConfigLayerStack::new(vec![layer], ConfigRequirements::default()).expect("ConfigLayerStack")
ConfigLayerStack::new(
vec![layer],
ConfigRequirements::default(),
ConfigRequirementsToml::default(),
)
.expect("ConfigLayerStack")
}

#[tokio::test]
@@ -573,7 +579,11 @@ mod tests {
TomlValue::Table(Default::default()),
),
];
let config_stack = ConfigLayerStack::new(layers, ConfigRequirements::default())?;
let config_stack = ConfigLayerStack::new(
layers,
ConfigRequirements::default(),
ConfigRequirementsToml::default(),
)?;

let policy = load_exec_policy(&config_stack).await?;
@@ -7,6 +7,7 @@

use crate::config::ConfigToml;
use crate::config::profile::ConfigProfile;
use codex_otel::OtelManager;
use serde::Deserialize;
use serde::Serialize;
use std::collections::BTreeMap;
@@ -89,6 +90,8 @@ pub enum Feature {
Tui2,
/// Enforce UTF8 output in Powershell.
PowershellUtf8,
/// Compress request bodies (zstd) when sending streaming requests to codex-backend.
EnableRequestCompression,
}

impl Feature {
@@ -191,6 +194,21 @@ impl Features {
.map(|usage| (usage.alias.as_str(), usage.feature))
}

pub fn emit_metrics(&self, otel: &OtelManager) {
for feature in FEATURES {
if self.enabled(feature.id) != feature.default_enabled {
otel.counter(
"codex.feature.state",
1,
&[
("feature", feature.key),
("value", &self.enabled(feature.id).to_string()),
],
);
}
}
}

/// Apply a table of key -> bool toggles (e.g. from TOML).
pub fn apply_map(&mut self, m: &BTreeMap<String, bool>) {
for (k, v) in m {
@@ -374,6 +392,12 @@ pub const FEATURES: &[FeatureSpec] = &[
stage: Stage::Experimental,
default_enabled: false,
},
FeatureSpec {
id: Feature::EnableRequestCompression,
key: "enable_request_compression",
stage: Stage::Experimental,
default_enabled: false,
},
FeatureSpec {
id: Feature::Tui2,
key: "tui2",
@@ -4,9 +4,6 @@ use thiserror::Error;
pub enum FunctionCallError {
#[error("{0}")]
RespondToModel(String),
#[error("{0}")]
#[allow(dead_code)] // TODO(jif) fix in a follow-up PR
Denied(String),
#[error("LocalShellCall without call_id or id")]
MissingLocalShellCallId,
#[error("Fatal error: {0}")]

@@ -51,6 +51,7 @@ pub mod token_data;
mod truncate;
mod unified_exec;
mod user_instructions;
pub mod windows_sandbox;
pub use model_provider_info::CHAT_WIRE_API_DEPRECATION_SUMMARY;
pub use model_provider_info::DEFAULT_LMSTUDIO_PORT;
pub use model_provider_info::DEFAULT_OLLAMA_PORT;
@@ -114,6 +115,8 @@ pub use command_safety::is_safe_command;
pub use exec_policy::ExecPolicyError;
pub use exec_policy::load_exec_policy;
pub use safety::get_platform_sandbox;
pub use safety::is_windows_elevated_sandbox_enabled;
pub use safety::set_windows_elevated_sandbox_enabled;
pub use safety::set_windows_sandbox_enabled;
// Re-export the protocol types from the standalone `codex-protocol` crate so existing
// `codex_core::protocol::...` references continue to work across the workspace.

@@ -12,6 +12,7 @@ use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::sync::TryLockError;
use tokio::time::timeout;
use tracing::error;

use super::cache;
@@ -21,6 +22,7 @@ use crate::api_bridge::map_api_error;
use crate::auth::AuthManager;
use crate::config::Config;
use crate::default_client::build_reqwest_client;
use crate::error::CodexErr;
use crate::error::Result as CoreResult;
use crate::features::Feature;
use crate::model_provider_info::ModelProviderInfo;
@@ -29,6 +31,7 @@ use crate::models_manager::model_presets::builtin_model_presets;

const MODEL_CACHE_FILE: &str = "models_cache.json";
const DEFAULT_MODEL_CACHE_TTL: Duration = Duration::from_secs(300);
const MODELS_REFRESH_TIMEOUT: Duration = Duration::from_secs(5);
const OPENAI_DEFAULT_API_MODEL: &str = "gpt-5.1-codex-max";
const OPENAI_DEFAULT_CHATGPT_MODEL: &str = "gpt-5.2-codex";
const CODEX_AUTO_BALANCED_MODEL: &str = "codex-auto-balanced";
@@ -98,17 +101,20 @@ impl ModelsManager {
if !remote_models_feature || self.auth_manager.get_auth_mode() == Some(AuthMode::ApiKey) {
return Ok(());
}
let auth = self.auth_manager.auth();
let auth = self.auth_manager.auth().await;
let api_provider = self.provider.to_api_provider(Some(AuthMode::ChatGPT))?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider).await?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider)?;
let transport = ReqwestTransport::new(build_reqwest_client());
let client = ModelsClient::new(transport, api_provider, api_auth);

let client_version = format_client_version_to_whole();
let (models, etag) = client
.list_models(&client_version, HeaderMap::new())
.await
.map_err(map_api_error)?;
let (models, etag) = timeout(
MODELS_REFRESH_TIMEOUT,
client.list_models(&client_version, HeaderMap::new()),
)
.await
.map_err(|_| CodexErr::Timeout)?
.map_err(map_api_error)?;

self.apply_remote_models(models.clone()).await;
*self.etag.write().await = etag.clone();

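The refresh above is now bounded by `MODELS_REFRESH_TIMEOUT` so a slow models endpoint cannot stall the caller. A minimal sketch of the `tokio::time::timeout` pattern used here (the function names below are stand-ins, not the crate's API):

```rust
use std::time::Duration;
use tokio::time::timeout;

// Stand-in for the real network call.
async fn fetch_models() -> Result<Vec<String>, String> {
    Ok(vec!["model-a".to_string()])
}

#[tokio::main]
async fn main() -> Result<(), String> {
    const REFRESH_TIMEOUT: Duration = Duration::from_secs(5);

    // `timeout` yields Err(Elapsed) if the future is still pending at the
    // deadline; the outer error is mapped first, then the inner Result.
    let models = timeout(REFRESH_TIMEOUT, fetch_models())
        .await
        .map_err(|_| "timed out".to_string())??;

    println!("{models:?}");
    Ok(())
}
```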
@@ -12,7 +12,7 @@ use crate::config::Config;
use crate::truncate::approx_bytes_for_tokens;
use tracing::warn;

const BASE_INSTRUCTIONS: &str = include_str!("../../prompt.md");
pub const BASE_INSTRUCTIONS: &str = include_str!("../../prompt.md");
const BASE_INSTRUCTIONS_WITH_APPLY_PATCH: &str =
include_str!("../../prompt_with_apply_patch_instructions.md");

@@ -6,7 +6,7 @@ use codex_otel::config::OtelExporter;
use codex_otel::config::OtelHttpProtocol;
use codex_otel::config::OtelSettings;
use codex_otel::config::OtelTlsConfig as OtelTlsSettings;
use codex_otel::otel_provider::OtelProvider;
use codex_otel::traces::otel_provider::OtelProvider;
use std::error::Error;

/// Build an OpenTelemetry provider from the app Config.
@@ -18,6 +18,7 @@ pub fn build_provider(
) -> Result<Option<OtelProvider>, Box<dyn Error>> {
let to_otel_exporter = |kind: &Kind| match kind {
Kind::None => OtelExporter::None,
Kind::Statsig => OtelExporter::Statsig,
Kind::OtlpHttp {
endpoint,
headers,
@@ -63,6 +64,11 @@ pub fn build_provider(

let exporter = to_otel_exporter(&config.otel.exporter);
let trace_exporter = to_otel_exporter(&config.otel.trace_exporter);
let metrics_exporter = if config.analytics {
to_otel_exporter(&config.otel.metrics_exporter)
} else {
OtelExporter::None
};

OtelProvider::from(&OtelSettings {
service_name: originator().value.to_owned(),
@@ -71,6 +77,7 @@ pub fn build_provider(
environment: config.otel.environment.to_string(),
exporter,
trace_exporter,
metrics_exporter,
})
}

@@ -45,12 +45,17 @@ pub(crate) fn user_message_positions_in_rollout(items: &[RolloutItem]) -> Vec<us
/// The boundary index is 0-based from the start of `items` (so `n_from_start = 0` returns
/// a prefix that excludes the first user message and everything after it).
///
/// If `n_from_start` is `usize::MAX`, this returns the full rollout (no truncation).
/// If fewer than or equal to `n_from_start` user messages exist, this returns an empty
/// vector (out of range).
pub(crate) fn truncate_rollout_before_nth_user_message_from_start(
items: &[RolloutItem],
n_from_start: usize,
) -> Vec<RolloutItem> {
if n_from_start == usize::MAX {
return items.to_vec();
}

let user_positions = user_message_positions_in_rollout(items);

// If fewer than or equal to n user messages exist, treat as empty (out of range).
@@ -139,6 +144,22 @@ mod tests {
assert_matches!(truncated2.as_slice(), []);
}

#[test]
fn truncation_max_keeps_full_rollout() {
let rollout = vec![
RolloutItem::ResponseItem(user_msg("u1")),
RolloutItem::ResponseItem(assistant_msg("a1")),
RolloutItem::ResponseItem(user_msg("u2")),
];

let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout, usize::MAX);

assert_eq!(
serde_json::to_value(&truncated).unwrap(),
serde_json::to_value(&rollout).unwrap()
);
}

#[test]
fn truncates_rollout_from_start_applies_thread_rollback_markers() {
let rollout_items = vec![

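To make the documented boundary semantics concrete, here is a standalone sketch over plain strings (the `user:` prefix stands in for real rollout user messages; this is illustrative, not the crate's implementation):

```rust
/// Keep items strictly before the nth user message (0-based).
/// `usize::MAX` keeps everything; out-of-range yields an empty Vec.
fn truncate_before_nth_user(items: &[&str], n: usize) -> Vec<String> {
    if n == usize::MAX {
        return items.iter().map(|s| s.to_string()).collect();
    }
    let user_positions: Vec<usize> = items
        .iter()
        .enumerate()
        .filter(|(_, s)| s.starts_with("user:"))
        .map(|(i, _)| i)
        .collect();
    match user_positions.get(n) {
        Some(&cut) => items[..cut].iter().map(|s| s.to_string()).collect(),
        None => Vec::new(), // fewer than or equal to n user messages exist
    }
}

fn main() {
    let items = ["user:u1", "assistant:a1", "user:u2"];
    assert!(truncate_before_nth_user(&items, 0).is_empty());
    assert_eq!(truncate_before_nth_user(&items, 1), vec!["user:u1", "assistant:a1"]);
    assert_eq!(truncate_before_nth_user(&items, usize::MAX).len(), 3);
    assert!(truncate_before_nth_user(&items, 2).is_empty()); // out of range
}
```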
@@ -328,6 +328,8 @@ Write the YAML frontmatter with `name` and `description`:
- Include all "when to use" information here, not in the body. The body is only loaded after triggering, so "When to Use This Skill" sections in the body are not helpful to Codex.
- Example description for a `docx` skill: "Comprehensive document creation, editing, and analysis with support for tracked changes, comments, formatting preservation, and text extraction. Use when Codex needs to work with professional documents (.docx files) for: (1) Creating new documents, (2) Modifying or editing content, (3) Working with tracked changes, (4) Adding comments, or any other document tasks"

Ensure the frontmatter is valid YAML. Keep `name` and `description` as single-line scalars. If either could be interpreted as YAML syntax, wrap it in quotes.

Do not include any other fields in YAML frontmatter.

##### Body

@@ -6,7 +6,7 @@ use crate::skills::model::SkillMetadata;
use crate::skills::system::system_cache_root_dir;
use codex_app_server_protocol::ConfigLayerSource;
use codex_protocol::protocol::SkillScope;
use dunce::canonicalize as normalize_path;
use dunce::canonicalize as canonicalize_path;
use serde::Deserialize;
use std::collections::HashSet;
use std::collections::VecDeque;
@@ -36,6 +36,9 @@ const SKILLS_DIR_NAME: &str = "skills";
const MAX_NAME_LEN: usize = 64;
const MAX_DESCRIPTION_LEN: usize = 1024;
const MAX_SHORT_DESCRIPTION_LEN: usize = MAX_DESCRIPTION_LEN;
// Traversal depth from the skills root.
const MAX_SCAN_DEPTH: usize = 6;
const MAX_SKILLS_DIRS_PER_ROOT: usize = 2000;

#[derive(Debug)]
enum SkillParseError {
@@ -165,7 +168,7 @@ pub(crate) fn skill_roots_from_layer_stack(
}

fn discover_skills_under_root(root: &Path, scope: SkillScope, outcome: &mut SkillLoadOutcome) {
let Ok(root) = normalize_path(root) else {
let Ok(root) = canonicalize_path(root) else {
return;
};

@@ -173,8 +176,38 @@ fn discover_skills_under_root(root: &Path, scope: SkillScope, outcome: &mut Skil
return;
}

let mut queue: VecDeque<PathBuf> = VecDeque::from([root]);
while let Some(dir) = queue.pop_front() {
fn enqueue_dir(
queue: &mut VecDeque<(PathBuf, usize)>,
visited_dirs: &mut HashSet<PathBuf>,
truncated_by_dir_limit: &mut bool,
path: PathBuf,
depth: usize,
) {
if depth > MAX_SCAN_DEPTH {
return;
}
if visited_dirs.len() >= MAX_SKILLS_DIRS_PER_ROOT {
*truncated_by_dir_limit = true;
return;
}
if visited_dirs.insert(path.clone()) {
queue.push_back((path, depth));
}
}

// Follow symlinks for user, admin, and repo skills. System skills are written by Codex itself.
let follow_symlinks = matches!(
scope,
SkillScope::Repo | SkillScope::User | SkillScope::Admin
);

let mut visited_dirs: HashSet<PathBuf> = HashSet::new();
visited_dirs.insert(root.clone());

let mut queue: VecDeque<(PathBuf, usize)> = VecDeque::from([(root.clone(), 0)]);
let mut truncated_by_dir_limit = false;

while let Some((dir, depth)) = queue.pop_front() {
let entries = match fs::read_dir(&dir) {
Ok(entries) => entries,
Err(e) => {
@@ -199,11 +232,64 @@ fn discover_skills_under_root(root: &Path, scope: SkillScope, outcome: &mut Skil
};

if file_type.is_symlink() {
if !follow_symlinks {
continue;
}

// Follow the symlink to determine what it points to.
let metadata = match fs::metadata(&path) {
Ok(metadata) => metadata,
Err(e) => {
error!(
"failed to stat skills entry {} (symlink): {e:#}",
path.display()
);
continue;
}
};

if metadata.is_dir() {
let Ok(resolved_dir) = canonicalize_path(&path) else {
continue;
};
enqueue_dir(
&mut queue,
&mut visited_dirs,
&mut truncated_by_dir_limit,
resolved_dir,
depth + 1,
);
continue;
}

if metadata.is_file() && file_name == SKILLS_FILENAME {
match parse_skill_file(&path, scope) {
Ok(skill) => outcome.skills.push(skill),
Err(err) => {
if scope != SkillScope::System {
outcome.errors.push(SkillError {
path,
message: err.to_string(),
});
}
}
}
}

continue;
}

if file_type.is_dir() {
queue.push_back(path);
let Ok(resolved_dir) = canonicalize_path(&path) else {
continue;
};
enqueue_dir(
&mut queue,
&mut visited_dirs,
&mut truncated_by_dir_limit,
resolved_dir,
depth + 1,
);
continue;
}

@@ -224,6 +310,14 @@ fn discover_skills_under_root(root: &Path, scope: SkillScope, outcome: &mut Skil
}
}
}

if truncated_by_dir_limit {
tracing::warn!(
"skills scan truncated after {} directories (root: {})",
MAX_SKILLS_DIRS_PER_ROOT,
root.display()
);
}
}

fn parse_skill_file(path: &Path, scope: SkillScope) -> Result<SkillMetadata, SkillParseError> {
@@ -253,7 +347,7 @@ fn parse_skill_file(path: &Path, scope: SkillScope) -> Result<SkillMetadata, Ski
)?;
}

let resolved_path = normalize_path(path).unwrap_or_else(|_| path.to_path_buf());
let resolved_path = canonicalize_path(path).unwrap_or_else(|_| path.to_path_buf());

Ok(SkillMetadata {
name,
@@ -316,6 +410,7 @@ mod tests {
use crate::config_loader::ConfigLayerEntry;
use crate::config_loader::ConfigLayerStack;
use crate::config_loader::ConfigRequirements;
use crate::config_loader::ConfigRequirementsToml;
use codex_protocol::protocol::SkillScope;
use codex_utils_absolute_path::AbsolutePathBuf;
use pretty_assertions::assert_eq;
@@ -350,7 +445,7 @@ mod tests {
}

fn normalized(path: &Path) -> PathBuf {
normalize_path(path).unwrap_or_else(|_| path.to_path_buf())
canonicalize_path(path).unwrap_or_else(|_| path.to_path_buf())
}

#[test]
@@ -377,7 +472,11 @@ mod tests {
TomlValue::Table(toml::map::Map::new()),
),
];
let stack = ConfigLayerStack::new(layers, ConfigRequirements::default())?;
let stack = ConfigLayerStack::new(
layers,
ConfigRequirements::default(),
ConfigRequirementsToml::default(),
)?;

let got = skill_roots_from_layer_stack(&stack)
.into_iter()
@@ -429,6 +528,243 @@ mod tests {
path
}

#[cfg(unix)]
fn symlink_dir(target: &Path, link: &Path) {
std::os::unix::fs::symlink(target, link).unwrap();
}

#[cfg(unix)]
fn symlink_file(target: &Path, link: &Path) {
std::os::unix::fs::symlink(target, link).unwrap();
}

#[tokio::test]
#[cfg(unix)]
async fn loads_skills_via_symlinked_subdir_for_user_scope() {
let codex_home = tempfile::tempdir().expect("tempdir");
let shared = tempfile::tempdir().expect("tempdir");

let shared_skill_path = write_skill_at(shared.path(), "demo", "linked-skill", "from link");

fs::create_dir_all(codex_home.path().join("skills")).unwrap();
symlink_dir(shared.path(), &codex_home.path().join("skills/shared"));

let cfg = make_config(&codex_home).await;
let outcome = load_skills(&cfg);

assert!(
outcome.errors.is_empty(),
"unexpected errors: {:?}",
outcome.errors
);
assert_eq!(
outcome.skills,
vec![SkillMetadata {
name: "linked-skill".to_string(),
description: "from link".to_string(),
short_description: None,
path: normalized(&shared_skill_path),
scope: SkillScope::User,
}]
);
}

#[tokio::test]
#[cfg(unix)]
async fn loads_skills_via_symlinked_skill_file_for_user_scope() {
let codex_home = tempfile::tempdir().expect("tempdir");
let shared = tempfile::tempdir().expect("tempdir");

let shared_skill_path =
write_skill_at(shared.path(), "demo", "linked-file-skill", "from link");

let skill_dir = codex_home.path().join("skills/demo");
fs::create_dir_all(&skill_dir).unwrap();
symlink_file(&shared_skill_path, &skill_dir.join(SKILLS_FILENAME));

let cfg = make_config(&codex_home).await;
let outcome = load_skills(&cfg);

assert!(
outcome.errors.is_empty(),
"unexpected errors: {:?}",
outcome.errors
);
assert_eq!(
outcome.skills,
vec![SkillMetadata {
name: "linked-file-skill".to_string(),
description: "from link".to_string(),
short_description: None,
path: normalized(&shared_skill_path),
scope: SkillScope::User,
}]
);
}

#[tokio::test]
#[cfg(unix)]
async fn does_not_loop_on_symlink_cycle_for_user_scope() {
let codex_home = tempfile::tempdir().expect("tempdir");

// Create a cycle:
// $CODEX_HOME/skills/cycle/loop -> $CODEX_HOME/skills/cycle
let cycle_dir = codex_home.path().join("skills/cycle");
fs::create_dir_all(&cycle_dir).unwrap();
symlink_dir(&cycle_dir, &cycle_dir.join("loop"));

let skill_path = write_skill_at(&cycle_dir, "demo", "cycle-skill", "still loads");

let cfg = make_config(&codex_home).await;
let outcome = load_skills(&cfg);

assert!(
outcome.errors.is_empty(),
"unexpected errors: {:?}",
outcome.errors
);
assert_eq!(
outcome.skills,
vec![SkillMetadata {
name: "cycle-skill".to_string(),
description: "still loads".to_string(),
short_description: None,
path: normalized(&skill_path),
scope: SkillScope::User,
}]
);
}

#[test]
#[cfg(unix)]
fn loads_skills_via_symlinked_subdir_for_admin_scope() {
let admin_root = tempfile::tempdir().expect("tempdir");
let shared = tempfile::tempdir().expect("tempdir");

let shared_skill_path =
write_skill_at(shared.path(), "demo", "admin-linked-skill", "from link");
fs::create_dir_all(admin_root.path()).unwrap();
symlink_dir(shared.path(), &admin_root.path().join("shared"));

let outcome = load_skills_from_roots([SkillRoot {
path: admin_root.path().to_path_buf(),
scope: SkillScope::Admin,
}]);

assert!(
outcome.errors.is_empty(),
"unexpected errors: {:?}",
outcome.errors
);
assert_eq!(
outcome.skills,
vec![SkillMetadata {
name: "admin-linked-skill".to_string(),
description: "from link".to_string(),
short_description: None,
path: normalized(&shared_skill_path),
scope: SkillScope::Admin,
}]
);
}

#[tokio::test]
#[cfg(unix)]
async fn loads_skills_via_symlinked_subdir_for_repo_scope() {
let codex_home = tempfile::tempdir().expect("tempdir");
let repo_dir = tempfile::tempdir().expect("tempdir");
mark_as_git_repo(repo_dir.path());
let shared = tempfile::tempdir().expect("tempdir");

let linked_skill_path =
write_skill_at(shared.path(), "demo", "repo-linked-skill", "from link");
let repo_skills_root = repo_dir
.path()
.join(REPO_ROOT_CONFIG_DIR_NAME)
.join(SKILLS_DIR_NAME);
fs::create_dir_all(&repo_skills_root).unwrap();
symlink_dir(shared.path(), &repo_skills_root.join("shared"));

let cfg = make_config_for_cwd(&codex_home, repo_dir.path().to_path_buf()).await;
let outcome = load_skills(&cfg);

assert!(
outcome.errors.is_empty(),
"unexpected errors: {:?}",
outcome.errors
);
assert_eq!(
outcome.skills,
vec![SkillMetadata {
name: "repo-linked-skill".to_string(),
description: "from link".to_string(),
short_description: None,
path: normalized(&linked_skill_path),
scope: SkillScope::Repo,
}]
);
}

#[tokio::test]
#[cfg(unix)]
async fn system_scope_ignores_symlinked_subdir() {
let codex_home = tempfile::tempdir().expect("tempdir");
let shared = tempfile::tempdir().expect("tempdir");

write_skill_at(shared.path(), "demo", "system-linked-skill", "from link");

let system_root = codex_home.path().join("skills/.system");
fs::create_dir_all(&system_root).unwrap();
symlink_dir(shared.path(), &system_root.join("shared"));

let cfg = make_config(&codex_home).await;
let outcome = load_skills(&cfg);

assert!(
outcome.errors.is_empty(),
"unexpected errors: {:?}",
outcome.errors
);
assert_eq!(outcome.skills.len(), 0);
}

#[tokio::test]
async fn respects_max_scan_depth_for_user_scope() {
let codex_home = tempfile::tempdir().expect("tempdir");

let within_depth_path = write_skill(
&codex_home,
"d0/d1/d2/d3/d4/d5",
"within-depth-skill",
"loads",
);
let _too_deep_path = write_skill(
&codex_home,
"d0/d1/d2/d3/d4/d5/d6",
"too-deep-skill",
"should not load",
);

let cfg = make_config(&codex_home).await;
let outcome = load_skills(&cfg);

assert!(
outcome.errors.is_empty(),
"unexpected errors: {:?}",
outcome.errors
);
assert_eq!(
outcome.skills,
vec![SkillMetadata {
name: "within-depth-skill".to_string(),
description: "loads".to_string(),
short_description: None,
path: normalized(&within_depth_path),
scope: SkillScope::User,
}]
);
}

#[tokio::test]
async fn loads_valid_skill() {
let codex_home = tempfile::tempdir().expect("tempdir");
@@ -1029,7 +1365,7 @@ mod tests {
outcome.errors
);
let expected_path =
normalize_path(&nested_skill_path).unwrap_or_else(|_| nested_skill_path.clone());
canonicalize_path(&nested_skill_path).unwrap_or_else(|_| nested_skill_path.clone());
assert_eq!(
vec![SkillMetadata {
name: "dupe-skill".to_string(),

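The rewritten scanner above bounds the walk three ways: canonicalized paths in a visited set break symlink cycles, a depth cap stops deep nesting, and a directory budget stops runaway trees. A simplified standalone sketch of that traversal shape (assumed simplifications: no scopes and no error reporting, unlike the real loader):

```rust
use std::collections::{HashSet, VecDeque};
use std::fs;
use std::path::PathBuf;

const MAX_DEPTH: usize = 6;
const MAX_DIRS: usize = 2000;

fn scan(root: PathBuf) -> Vec<PathBuf> {
    let mut visited: HashSet<PathBuf> = HashSet::new();
    let mut queue: VecDeque<(PathBuf, usize)> = VecDeque::new();
    let mut files = Vec::new();

    if let Ok(root) = fs::canonicalize(&root) {
        visited.insert(root.clone());
        queue.push_back((root, 0));
    }

    while let Some((dir, depth)) = queue.pop_front() {
        let Ok(entries) = fs::read_dir(&dir) else { continue };
        for entry in entries.flatten() {
            let path = entry.path();
            if path.is_dir() {
                // Canonicalize before the visited check so a symlink and its
                // target count as the same directory (this breaks cycles).
                let Ok(resolved) = fs::canonicalize(&path) else { continue };
                if depth + 1 <= MAX_DEPTH
                    && visited.len() < MAX_DIRS
                    && visited.insert(resolved.clone())
                {
                    queue.push_back((resolved, depth + 1));
                }
            } else {
                files.push(path);
            }
        }
    }
    files
}

fn main() {
    for f in scan(PathBuf::from(".")) {
        println!("{}", f.display());
    }
}
```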
@@ -66,10 +66,11 @@ pub(crate) async fn spawn_child_async(

#[cfg(unix)]
unsafe {
let set_process_group = matches!(stdio_policy, StdioPolicy::RedirectForShellTool);
#[cfg(target_os = "linux")]
let parent_pid = libc::getpid();
cmd.pre_exec(move || {
if libc::setpgid(0, 0) == -1 {
if set_process_group && libc::setpgid(0, 0) == -1 {
return Err(std::io::Error::last_os_error());
}

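For context on the change above: `setpgid(0, 0)` in `pre_exec` makes the child the leader of a fresh process group so the shell tool can signal the command and everything it forks, and the new flag restricts that to the shell-tool stdio policy. On Unix, stable Rust exposes the same effect without raw `libc` via `CommandExt::process_group`; a sketch of that idea (not the code above, which needs `pre_exec` for other reasons):

```rust
use std::os::unix::process::CommandExt;
use std::process::Command;

fn main() -> std::io::Result<()> {
    let mut cmd = Command::new("sleep");
    cmd.arg("1");

    // process_group(0) asks the OS to run setpgid(0, 0) in the child,
    // making it the leader of a new process group. Sending a signal to
    // -pgid then reaches the command and everything it spawns.
    cmd.process_group(0);

    let status = cmd.status()?;
    println!("exited: {status}");
    Ok(())
}
```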
@@ -10,7 +10,7 @@ use crate::skills::SkillsManager;
use crate::tools::sandboxing::ApprovalStore;
use crate::unified_exec::UnifiedExecProcessManager;
use crate::user_notification::UserNotifier;
use codex_otel::otel_manager::OtelManager;
use codex_otel::OtelManager;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;

@@ -119,8 +119,7 @@ pub(crate) async fn handle_output_item_done(
output.needs_follow_up = true;
}
// The tool request should be answered directly (or was denied); push that response into the transcript.
Err(FunctionCallError::RespondToModel(message))
| Err(FunctionCallError::Denied(message)) => {
Err(FunctionCallError::RespondToModel(message)) => {
let response = ResponseInputItem::FunctionCallOutput {
call_id: String::new(),
output: FunctionCallOutputPayload {

@@ -29,8 +29,18 @@ impl SessionTask for CompactTask {
session.as_ref(),
&ctx.client.get_provider(),
) {
let _ = session.services.otel_manager.counter(
"codex.task.compact",
1,
&[("type", "remote")],
);
crate::compact_remote::run_remote_compact_task(session, ctx).await
} else {
let _ = session.services.otel_manager.counter(
"codex.task.compact",
1,
&[("type", "local")],
);
crate::compact::run_compact_task(session, ctx, input).await
}

@@ -46,6 +46,12 @@ impl SessionTask for ReviewTask {
input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<String> {
let _ = session
.session
.services
.otel_manager
.counter("codex.task.review", 1, &[]);

// Start sub-codex conversation and get the receiver for events.
let output = match start_review_conversation(
session.clone(),
@@ -77,10 +83,6 @@ async fn start_review_conversation(
) -> Option<async_channel::Receiver<Event>> {
let config = ctx.client.config();
let mut sub_agent_config = config.as_ref().clone();
// Run with only reviewer rubric — drop outer user_instructions
sub_agent_config.user_instructions = None;
// Avoid loading project docs; reviewer only needs findings
sub_agent_config.project_doc_max_bytes = 0;
// Carry over review-only feature restrictions so the delegate cannot
// re-enable blocked tools (web search, view image).
sub_agent_config

@@ -38,6 +38,11 @@ impl SessionTask for UndoTask {
_input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<String> {
let _ = session
.session
.services
.otel_manager
.counter("codex.task.undo", 1, &[]);
let sess = session.clone_session();
sess.send_event(
ctx.as_ref(),
@@ -59,8 +64,8 @@ impl SessionTask for UndoTask {
return None;
}

let mut history = sess.clone_history().await;
let mut items = history.get_history();
let history = sess.clone_history().await;
let mut items = history.raw_items().to_vec();
let mut completed = UndoCompletedEvent {
success: false,
message: None,

@@ -58,6 +58,12 @@ impl SessionTask for UserShellCommandTask {
_input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<String> {
let _ = session
.session
.services
.otel_manager
.counter("codex.task.user_shell", 1, &[]);

let event = EventMsg::TaskStarted(TaskStartedEvent {
model_context_window: turn_context.client.get_model_context_window(),
});

@@ -132,6 +132,10 @@ impl ThreadManager {
self.state.models_manager.list_models(config).await
}

pub async fn list_thread_ids(&self) -> Vec<ThreadId> {
self.state.threads.read().await.keys().copied().collect()
}

pub async fn get_thread(&self, thread_id: ThreadId) -> CodexResult<Arc<CodexThread>> {
self.state.get_thread(thread_id).await
}
@@ -179,7 +183,7 @@ impl ThreadManager {
/// Fork an existing thread by taking messages up to the given position (not including
/// the message at the given position) and starting a new thread with identical
/// configuration (unless overridden by the caller's `config`). The new thread will have
/// a fresh id.
/// a fresh id. Pass `usize::MAX` to keep the full rollout history.
pub async fn fork_thread(
&self,
nth_user_message: usize,

@@ -100,7 +100,6 @@ pub(crate) enum ToolEmitter {
command: Vec<String>,
cwd: PathBuf,
source: ExecCommandSource,
interaction_input: Option<String>,
parsed_cmd: Vec<ParsedCommand>,
process_id: Option<String>,
},
@@ -141,7 +140,6 @@ impl ToolEmitter {
command: command.to_vec(),
cwd,
source,
interaction_input: None, // TODO(jif) drop this field in the protocol.
parsed_cmd,
process_id,
}
@@ -231,7 +229,6 @@ impl ToolEmitter {
command,
cwd,
source,
interaction_input,
parsed_cmd,
process_id,
},
@@ -244,7 +241,7 @@ impl ToolEmitter {
cwd.as_path(),
parsed_cmd,
*source,
interaction_input.as_deref(),
None,
process_id.as_deref(),
),
stage,

@@ -17,6 +17,7 @@ use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::events::ToolEmitter;
use crate::tools::events::ToolEventCtx;
use crate::tools::handlers::parse_arguments;
use crate::tools::orchestrator::ToolOrchestrator;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
@@ -87,11 +88,7 @@ impl ToolHandler for ApplyPatchHandler {

let patch_input = match payload {
ToolPayload::Function { arguments } => {
let args: ApplyPatchToolArgs = serde_json::from_str(&arguments).map_err(|e| {
FunctionCallError::RespondToModel(format!(
"failed to parse function arguments: {e:?}"
))
})?;
let args: ApplyPatchToolArgs = parse_arguments(&arguments)?;
args.input
}
ToolPayload::Custom { input } => input,

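This handler and the two that follow swap repeated inline `serde_json::from_str` error handling for a shared `parse_arguments` helper. Its real signature is not shown in this diff; a plausible reconstruction (hypothetical, including the local `FunctionCallError` stub) looks like:

```rust
use serde::de::DeserializeOwned;

#[derive(Debug)]
enum FunctionCallError {
    RespondToModel(String),
}

// Hypothetical reconstruction of the shared helper: deserialize the tool's
// JSON arguments, turning any parse failure into a message the model sees.
fn parse_arguments<T: DeserializeOwned>(arguments: &str) -> Result<T, FunctionCallError> {
    serde_json::from_str(arguments).map_err(|err| {
        FunctionCallError::RespondToModel(format!(
            "failed to parse function arguments: {err:?}"
        ))
    })
}

#[derive(serde::Deserialize, Debug)]
struct GrepArgs {
    pattern: String,
}

fn main() {
    let ok: Result<GrepArgs, _> = parse_arguments(r#"{"pattern": "foo"}"#);
    assert!(ok.is_ok());
    let bad: Result<GrepArgs, _> = parse_arguments("not json");
    assert!(bad.is_err());
}
```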
@@ -10,6 +10,7 @@ use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::handlers::parse_arguments;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

@@ -52,11 +53,7 @@ impl ToolHandler for GrepFilesHandler {
}
};

let args: GrepFilesArgs = serde_json::from_str(&arguments).map_err(|err| {
FunctionCallError::RespondToModel(format!(
"failed to parse function arguments: {err:?}"
))
})?;
let args: GrepFilesArgs = parse_arguments(&arguments)?;

let pattern = args.pattern.trim();
if pattern.is_empty() {

@@ -13,6 +13,7 @@ use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::handlers::parse_arguments;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

@@ -62,11 +63,7 @@ impl ToolHandler for ListDirHandler {
}
};

let args: ListDirArgs = serde_json::from_str(&arguments).map_err(|err| {
FunctionCallError::RespondToModel(format!(
"failed to parse function arguments: {err:?}"
))
})?;
let args: ListDirArgs = parse_arguments(&arguments)?;

let ListDirArgs {
dir_path,
@@ -125,6 +122,8 @@ async fn list_dir_slice(
return Ok(Vec::new());
}

entries.sort_unstable_by(|a, b| a.name.cmp(&b.name));

let start_index = offset - 1;
if start_index >= entries.len() {
return Err(FunctionCallError::RespondToModel(
@@ -135,11 +134,10 @@ async fn list_dir_slice(
let remaining_entries = entries.len() - start_index;
let capped_limit = limit.min(remaining_entries);
let end_index = start_index + capped_limit;
let mut selected_entries = entries[start_index..end_index].to_vec();
selected_entries.sort_unstable_by(|a, b| a.name.cmp(&b.name));
let selected_entries = &entries[start_index..end_index];
let mut formatted = Vec::with_capacity(selected_entries.len());

for entry in &selected_entries {
for entry in selected_entries {
formatted.push(format_entry_line(entry));
}

@@ -273,6 +271,7 @@ impl From<&FileType> for DirEntryKind {
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
use tempfile::tempdir;

#[tokio::test]
@@ -404,6 +403,44 @@ mod tests {
);
}

#[tokio::test]
async fn paginates_in_sorted_order() {
let temp = tempdir().expect("create tempdir");
let dir_path = temp.path();

let dir_a = dir_path.join("a");
let dir_b = dir_path.join("b");
tokio::fs::create_dir(&dir_a).await.expect("create a");
tokio::fs::create_dir(&dir_b).await.expect("create b");

tokio::fs::write(dir_a.join("a_child.txt"), b"a")
.await
.expect("write a child");
tokio::fs::write(dir_b.join("b_child.txt"), b"b")
.await
.expect("write b child");

let first_page = list_dir_slice(dir_path, 1, 2, 2)
.await
.expect("list page one");
assert_eq!(
first_page,
vec![
"a/".to_string(),
" a_child.txt".to_string(),
"More than 2 entries found".to_string()
]
);

let second_page = list_dir_slice(dir_path, 3, 2, 2)
.await
.expect("list page two");
assert_eq!(
second_page,
vec!["b/".to_string(), " b_child.txt".to_string()]
);
}

#[tokio::test]
async fn handles_large_limit_without_overflow() {
let temp = tempdir().expect("create tempdir");
@@ -450,7 +487,7 @@ mod tests {
}

#[tokio::test]
async fn bfs_truncation() -> anyhow::Result<()> {
async fn truncation_respects_sorted_order() -> anyhow::Result<()> {
let temp = tempdir()?;
let dir_path = temp.path();
let nested = dir_path.join("nested");
@@ -467,7 +504,7 @@ mod tests {
vec![
"nested/".to_string(),
" child.txt".to_string(),
"root.txt".to_string(),
" deeper/".to_string(),
"More than 3 entries found".to_string()
]
);

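The `list_dir_slice` change above moves the sort ahead of the offset/limit slicing: sorting only the selected page (as before) kept each page locally ordered but paginated the unsorted listing, so entries could repeat or go missing across pages. A minimal illustration of why the order of operations matters:

```rust
fn page(mut entries: Vec<&str>, offset: usize, limit: usize) -> Vec<&str> {
    // Correct order of operations: sort the full listing first, then slice.
    entries.sort_unstable();
    entries.into_iter().skip(offset).take(limit).collect()
}

fn main() {
    let entries = vec!["b", "d", "a", "c"];
    assert_eq!(page(entries.clone(), 0, 2), vec!["a", "b"]);
    assert_eq!(page(entries, 2, 2), vec!["c", "d"]);
    // Sorting each page *after* slicing would instead paginate the unsorted
    // order ("b", "d" | "a", "c") and merely tidy each page locally.
}
```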
Some files were not shown because too many files have changed in this diff.