mirror of
https://github.com/openai/codex.git
synced 2026-05-14 16:22:51 +00:00
Compare commits
12 Commits
kevinliu/f
...
codex/add-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c58b5c0910 | ||
|
|
0bd53ac078 | ||
|
|
9aa1331c92 | ||
|
|
555a172722 | ||
|
|
bd0f423d36 | ||
|
|
b0757a1c23 | ||
|
|
e98b35ac78 | ||
|
|
e6939025f5 | ||
|
|
f123cc6541 | ||
|
|
5f16fe8dda | ||
|
|
98b31b0390 | ||
|
|
9487ae4ce7 |
18
.vscode/launch.json
vendored
18
.vscode/launch.json
vendored
@@ -1,18 +0,0 @@
|
||||
{
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Cargo launch",
|
||||
"cargo": {
|
||||
"cwd": "${workspaceFolder}/codex-rs",
|
||||
"args": [
|
||||
"build",
|
||||
"--bin=codex-tui"
|
||||
]
|
||||
},
|
||||
"args": []
|
||||
}
|
||||
]
|
||||
}
|
||||
10
.vscode/settings.json
vendored
10
.vscode/settings.json
vendored
@@ -1,10 +0,0 @@
|
||||
{
|
||||
"rust-analyzer.checkOnSave": true,
|
||||
"rust-analyzer.check.command": "clippy",
|
||||
"rust-analyzer.check.extraArgs": ["--all-features", "--tests"],
|
||||
"rust-analyzer.rustfmt.extraArgs": ["--config", "imports_granularity=Item"],
|
||||
"[rust]": {
|
||||
"editor.defaultFormatter": "rust-lang.rust-analyzer",
|
||||
"editor.formatOnSave": true,
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,3 @@
|
||||
In the codex-rs folder where the rust code lives:
|
||||
|
||||
- Never add or modify any code related to `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR`. You operate in a sandbox where `CODEX_SANDBOX_NETWORK_DISABLED=1` will be set whenever you use the `shell` tool. Any existing code that uses `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR` was authored with this fact in mind. It is often used to early exit out of tests that the author knew you would not be able to run given your sandbox limitations.
|
||||
|
||||
Before creating a pull request with changes to `codex-rs`, run `just fmt` (in `codex-rs` directory) to format the code and `just fix` (in `codex-rs` directory) to fix any linter issues in the code, ensure the test suite passes by running `cargo test --all-features` in the `codex-rs` directory.
|
||||
|
||||
When making individual changes prefer running tests on individual files or projects first.
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
* current platform / architecture, an error is thrown.
|
||||
*/
|
||||
|
||||
import { spawnSync } from "child_process";
|
||||
import fs from "fs";
|
||||
import path from "path";
|
||||
import { fileURLToPath, pathToFileURL } from "url";
|
||||
@@ -34,7 +35,7 @@ const wantsNative = fs.existsSync(path.join(__dirname, "use-native")) ||
|
||||
: false);
|
||||
|
||||
// Try native binary if requested.
|
||||
if (wantsNative && process.platform !== 'win32') {
|
||||
if (wantsNative) {
|
||||
const { platform, arch } = process;
|
||||
|
||||
let targetTriple = null;
|
||||
@@ -73,76 +74,22 @@ if (wantsNative && process.platform !== 'win32') {
|
||||
}
|
||||
|
||||
const binaryPath = path.join(__dirname, "..", "bin", `codex-${targetTriple}`);
|
||||
|
||||
// Use an asynchronous spawn instead of spawnSync so that Node is able to
|
||||
// respond to signals (e.g. Ctrl-C / SIGINT) while the native binary is
|
||||
// executing. This allows us to forward those signals to the child process
|
||||
// and guarantees that when either the child terminates or the parent
|
||||
// receives a fatal signal, both processes exit in a predictable manner.
|
||||
const { spawn } = await import("child_process");
|
||||
|
||||
const child = spawn(binaryPath, process.argv.slice(2), {
|
||||
const result = spawnSync(binaryPath, process.argv.slice(2), {
|
||||
stdio: "inherit",
|
||||
});
|
||||
|
||||
child.on("error", (err) => {
|
||||
// Typically triggered when the binary is missing or not executable.
|
||||
// Re-throwing here will terminate the parent with a non-zero exit code
|
||||
// while still printing a helpful stack trace.
|
||||
// eslint-disable-next-line no-console
|
||||
console.error(err);
|
||||
process.exit(1);
|
||||
});
|
||||
const exitCode = typeof result.status === "number" ? result.status : 1;
|
||||
process.exit(exitCode);
|
||||
}
|
||||
|
||||
// Forward common termination signals to the child so that it shuts down
|
||||
// gracefully. In the handler we temporarily disable the default behavior of
|
||||
// exiting immediately; once the child has been signaled we simply wait for
|
||||
// its exit event which will in turn terminate the parent (see below).
|
||||
const forwardSignal = (signal) => {
|
||||
if (child.killed) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
child.kill(signal);
|
||||
} catch {
|
||||
/* ignore */
|
||||
}
|
||||
};
|
||||
// Fallback: execute the original JavaScript CLI.
|
||||
|
||||
["SIGINT", "SIGTERM", "SIGHUP"].forEach((sig) => {
|
||||
process.on(sig, () => forwardSignal(sig));
|
||||
});
|
||||
// Resolve the path to the compiled CLI bundle
|
||||
const cliPath = path.resolve(__dirname, "../dist/cli.js");
|
||||
const cliUrl = pathToFileURL(cliPath).href;
|
||||
|
||||
// When the child exits, mirror its termination reason in the parent so that
|
||||
// shell scripts and other tooling observe the correct exit status.
|
||||
// Wrap the lifetime of the child process in a Promise so that we can await
|
||||
// its termination in a structured way. The Promise resolves with an object
|
||||
// describing how the child exited: either via exit code or due to a signal.
|
||||
const childResult = await new Promise((resolve) => {
|
||||
child.on("exit", (code, signal) => {
|
||||
if (signal) {
|
||||
resolve({ type: "signal", signal });
|
||||
} else {
|
||||
resolve({ type: "code", exitCode: code ?? 1 });
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
if (childResult.type === "signal") {
|
||||
// Re-emit the same signal so that the parent terminates with the expected
|
||||
// semantics (this also sets the correct exit code of 128 + n).
|
||||
process.kill(process.pid, childResult.signal);
|
||||
} else {
|
||||
process.exit(childResult.exitCode);
|
||||
}
|
||||
} else {
|
||||
// Fallback: execute the original JavaScript CLI.
|
||||
|
||||
// Resolve the path to the compiled CLI bundle
|
||||
const cliPath = path.resolve(__dirname, "../dist/cli.js");
|
||||
const cliUrl = pathToFileURL(cliPath).href;
|
||||
|
||||
// Load and execute the CLI
|
||||
// Load and execute the CLI
|
||||
(async () => {
|
||||
try {
|
||||
await import(cliUrl);
|
||||
} catch (err) {
|
||||
@@ -150,4 +97,4 @@ if (wantsNative && process.platform !== 'win32') {
|
||||
console.error(err);
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
})();
|
||||
|
||||
68
codex-rs/Cargo.lock
generated
68
codex-rs/Cargo.lock
generated
@@ -399,15 +399,6 @@ version = "2.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6099cdc01846bc367c4e7dd630dc5966dccf36b652fae7a74e17b640411a91b2"
|
||||
|
||||
[[package]]
|
||||
name = "block-buffer"
|
||||
version = "0.10.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bstr"
|
||||
version = "1.12.0"
|
||||
@@ -680,7 +671,6 @@ dependencies = [
|
||||
"seccompiler",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha1",
|
||||
"strum_macros 0.27.1",
|
||||
"tempfile",
|
||||
"thiserror 2.0.12",
|
||||
@@ -693,7 +683,6 @@ dependencies = [
|
||||
"tree-sitter",
|
||||
"tree-sitter-bash",
|
||||
"uuid",
|
||||
"walkdir",
|
||||
"wildmatch",
|
||||
"wiremock",
|
||||
]
|
||||
@@ -799,7 +788,6 @@ dependencies = [
|
||||
"schemars 0.8.22",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"shlex",
|
||||
"tokio",
|
||||
"toml 0.9.1",
|
||||
"tracing",
|
||||
@@ -944,15 +932,6 @@ version = "0.8.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
|
||||
|
||||
[[package]]
|
||||
name = "cpufeatures"
|
||||
version = "0.2.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crc32fast"
|
||||
version = "1.4.2"
|
||||
@@ -1027,16 +1006,6 @@ version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5"
|
||||
|
||||
[[package]]
|
||||
name = "crypto-common"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ctor"
|
||||
version = "0.1.26"
|
||||
@@ -1187,16 +1156,6 @@ version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6184e33543162437515c2e2b48714794e37845ec9851711914eec9d308f6ebe8"
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
|
||||
dependencies = [
|
||||
"block-buffer",
|
||||
"crypto-common",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dirs"
|
||||
version = "6.0.0"
|
||||
@@ -1686,16 +1645,6 @@ dependencies = [
|
||||
"byteorder",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "generic-array"
|
||||
version = "0.14.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
|
||||
dependencies = [
|
||||
"typenum",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getopts"
|
||||
version = "0.2.23"
|
||||
@@ -3995,17 +3944,6 @@ dependencies = [
|
||||
"syn 2.0.104",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha1"
|
||||
version = "0.10.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sharded-slab"
|
||||
version = "0.1.7"
|
||||
@@ -4913,12 +4851,6 @@ dependencies = [
|
||||
"unicode-width 0.2.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.18.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f"
|
||||
|
||||
[[package]]
|
||||
name = "unicase"
|
||||
version = "2.8.1"
|
||||
|
||||
@@ -64,11 +64,7 @@ impl CliConfigOverrides {
|
||||
// `-c model=o3` without the quotes.
|
||||
let value: Value = match parse_toml_value(value_str) {
|
||||
Ok(v) => v,
|
||||
Err(_) => {
|
||||
// Strip leading/trailing quotes if present
|
||||
let trimmed = value_str.trim().trim_matches(|c| c == '"' || c == '\'');
|
||||
Value::String(trimmed.to_string())
|
||||
}
|
||||
Err(_) => Value::String(value_str.to_string()),
|
||||
};
|
||||
|
||||
Ok((key.to_string(), value))
|
||||
|
||||
@@ -92,32 +92,6 @@ http_headers = { "X-Example-Header" = "example-value" }
|
||||
env_http_headers = { "X-Example-Features": "EXAMPLE_FEATURES" }
|
||||
```
|
||||
|
||||
### Per-provider network tuning
|
||||
|
||||
The following optional settings control retry behaviour and streaming idle timeouts **per model provider**. They must be specified inside the corresponding `[model_providers.<id>]` block in `config.toml`. (Older releases accepted top‑level keys; those are now ignored.)
|
||||
|
||||
Example:
|
||||
|
||||
```toml
|
||||
[model_providers.openai]
|
||||
name = "OpenAI"
|
||||
base_url = "https://api.openai.com/v1"
|
||||
env_key = "OPENAI_API_KEY"
|
||||
# network tuning overrides (all optional; falls back to built‑in defaults)
|
||||
request_max_retries = 4 # retry failed HTTP requests
|
||||
stream_max_retries = 10 # retry dropped SSE streams
|
||||
stream_idle_timeout_ms = 300000 # 5m idle timeout
|
||||
```
|
||||
|
||||
#### request_max_retries
|
||||
How many times Codex will retry a failed HTTP request to the model provider. Defaults to `4`.
|
||||
|
||||
#### stream_max_retries
|
||||
Number of times Codex will attempt to reconnect when a streaming response is interrupted. Defaults to `10`.
|
||||
|
||||
#### stream_idle_timeout_ms
|
||||
How long Codex will wait for activity on a streaming response before treating the connection as lost. Defaults to `300_000` (5 minutes).
|
||||
|
||||
## model_provider
|
||||
|
||||
Identifies which provider to use from the `model_providers` map. Defaults to `"openai"`. You can override the `base_url` for the built-in `openai` provider via the `OPENAI_BASE_URL` environment variable.
|
||||
@@ -470,7 +444,7 @@ Currently, `"vscode"` is the default, though Codex does not verify VS Code is in
|
||||
|
||||
## hide_agent_reasoning
|
||||
|
||||
Codex intermittently emits "reasoning" events that show the model's internal "thinking" before it produces a final answer. Some users may find these events distracting, especially in CI logs or minimal terminal output.
|
||||
Codex intermittently emits "reasoning" events that show the model’s internal "thinking" before it produces a final answer. Some users may find these events distracting, especially in CI logs or minimal terminal output.
|
||||
|
||||
Setting `hide_agent_reasoning` to `true` suppresses these events in **both** the TUI as well as the headless `exec` sub-command:
|
||||
|
||||
|
||||
@@ -28,7 +28,6 @@ rand = "0.9"
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
sha1 = "0.10.6"
|
||||
strum_macros = "0.27.1"
|
||||
thiserror = "2.0.12"
|
||||
time = { version = "0.3", features = ["formatting", "local-offset", "macros"] }
|
||||
@@ -66,5 +65,4 @@ predicates = "3"
|
||||
pretty_assertions = "1.4.1"
|
||||
tempfile = "3"
|
||||
tokio-test = "0.4"
|
||||
walkdir = "2.5.0"
|
||||
wiremock = "0.6"
|
||||
|
||||
@@ -21,6 +21,8 @@ use crate::client_common::ResponseEvent;
|
||||
use crate::client_common::ResponseStream;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result;
|
||||
use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
|
||||
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
|
||||
use crate::models::ContentItem;
|
||||
use crate::models::ResponseItem;
|
||||
use crate::openai_tools::create_tools_json_for_chat_completions_api;
|
||||
@@ -119,7 +121,6 @@ pub(crate) async fn stream_chat_completions(
|
||||
);
|
||||
|
||||
let mut attempt = 0;
|
||||
let max_retries = provider.request_max_retries();
|
||||
loop {
|
||||
attempt += 1;
|
||||
|
||||
@@ -133,13 +134,9 @@ pub(crate) async fn stream_chat_completions(
|
||||
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(16);
|
||||
let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
|
||||
tokio::spawn(process_chat_sse(
|
||||
stream,
|
||||
tx_event,
|
||||
provider.stream_idle_timeout(),
|
||||
));
|
||||
tokio::spawn(process_chat_sse(stream, tx_event));
|
||||
return Ok(ResponseStream { rx_event });
|
||||
}
|
||||
Ok(res) => {
|
||||
@@ -149,7 +146,7 @@ pub(crate) async fn stream_chat_completions(
|
||||
return Err(CodexErr::UnexpectedStatus(status, body));
|
||||
}
|
||||
|
||||
if attempt > max_retries {
|
||||
if attempt > *OPENAI_REQUEST_MAX_RETRIES {
|
||||
return Err(CodexErr::RetryLimit(status));
|
||||
}
|
||||
|
||||
@@ -165,7 +162,7 @@ pub(crate) async fn stream_chat_completions(
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
Err(e) => {
|
||||
if attempt > max_retries {
|
||||
if attempt > *OPENAI_REQUEST_MAX_RETRIES {
|
||||
return Err(e.into());
|
||||
}
|
||||
let delay = backoff(attempt);
|
||||
@@ -178,15 +175,14 @@ pub(crate) async fn stream_chat_completions(
|
||||
/// Lightweight SSE processor for the Chat Completions streaming format. The
|
||||
/// output is mapped onto Codex's internal [`ResponseEvent`] so that the rest
|
||||
/// of the pipeline can stay agnostic of the underlying wire format.
|
||||
async fn process_chat_sse<S>(
|
||||
stream: S,
|
||||
tx_event: mpsc::Sender<Result<ResponseEvent>>,
|
||||
idle_timeout: Duration,
|
||||
) where
|
||||
async fn process_chat_sse<S>(stream: S, tx_event: mpsc::Sender<Result<ResponseEvent>>)
|
||||
where
|
||||
S: Stream<Item = Result<Bytes>> + Unpin,
|
||||
{
|
||||
let mut stream = stream.eventsource();
|
||||
|
||||
let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS;
|
||||
|
||||
// State to accumulate a function call across streaming chunks.
|
||||
// OpenAI may split the `arguments` string over multiple `delta` events
|
||||
// until the chunk whose `finish_reason` is `tool_calls` is emitted. We
|
||||
|
||||
@@ -15,7 +15,6 @@ use tokio_util::io::ReaderStream;
|
||||
use tracing::debug;
|
||||
use tracing::trace;
|
||||
use tracing::warn;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::chat_completions::AggregateStreamExt;
|
||||
use crate::chat_completions::stream_chat_completions;
|
||||
@@ -30,6 +29,8 @@ use crate::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result;
|
||||
use crate::flags::CODEX_RS_SSE_FIXTURE;
|
||||
use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
|
||||
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
|
||||
use crate::model_provider_info::ModelProviderInfo;
|
||||
use crate::model_provider_info::WireApi;
|
||||
use crate::models::ResponseItem;
|
||||
@@ -43,7 +44,6 @@ pub struct ModelClient {
|
||||
config: Arc<Config>,
|
||||
client: reqwest::Client,
|
||||
provider: ModelProviderInfo,
|
||||
session_id: Uuid,
|
||||
effort: ReasoningEffortConfig,
|
||||
summary: ReasoningSummaryConfig,
|
||||
}
|
||||
@@ -54,13 +54,11 @@ impl ModelClient {
|
||||
provider: ModelProviderInfo,
|
||||
effort: ReasoningEffortConfig,
|
||||
summary: ReasoningSummaryConfig,
|
||||
session_id: Uuid,
|
||||
) -> Self {
|
||||
Self {
|
||||
config,
|
||||
client: reqwest::Client::new(),
|
||||
provider,
|
||||
session_id,
|
||||
effort,
|
||||
summary,
|
||||
}
|
||||
@@ -111,7 +109,7 @@ impl ModelClient {
|
||||
if let Some(path) = &*CODEX_RS_SSE_FIXTURE {
|
||||
// short circuit for tests
|
||||
warn!(path, "Streaming from fixture");
|
||||
return stream_from_fixture(path, self.provider.clone()).await;
|
||||
return stream_from_fixture(path).await;
|
||||
}
|
||||
|
||||
let full_instructions = prompt.get_full_instructions(&self.config.model);
|
||||
@@ -127,7 +125,6 @@ impl ModelClient {
|
||||
reasoning,
|
||||
previous_response_id: prompt.prev_id.clone(),
|
||||
store: prompt.store,
|
||||
// TODO: make this configurable
|
||||
stream: true,
|
||||
};
|
||||
|
||||
@@ -138,7 +135,6 @@ impl ModelClient {
|
||||
);
|
||||
|
||||
let mut attempt = 0;
|
||||
let max_retries = self.provider.request_max_retries();
|
||||
loop {
|
||||
attempt += 1;
|
||||
|
||||
@@ -146,22 +142,17 @@ impl ModelClient {
|
||||
.provider
|
||||
.create_request_builder(&self.client)?
|
||||
.header("OpenAI-Beta", "responses=experimental")
|
||||
.header("session_id", self.session_id.to_string())
|
||||
.header(reqwest::header::ACCEPT, "text/event-stream")
|
||||
.json(&payload);
|
||||
|
||||
let res = req_builder.send().await;
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(16);
|
||||
|
||||
// spawn task to process SSE
|
||||
let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
|
||||
tokio::spawn(process_sse(
|
||||
stream,
|
||||
tx_event,
|
||||
self.provider.stream_idle_timeout(),
|
||||
));
|
||||
tokio::spawn(process_sse(stream, tx_event));
|
||||
|
||||
return Ok(ResponseStream { rx_event });
|
||||
}
|
||||
@@ -180,7 +171,7 @@ impl ModelClient {
|
||||
return Err(CodexErr::UnexpectedStatus(status, body));
|
||||
}
|
||||
|
||||
if attempt > max_retries {
|
||||
if attempt > *OPENAI_REQUEST_MAX_RETRIES {
|
||||
return Err(CodexErr::RetryLimit(status));
|
||||
}
|
||||
|
||||
@@ -197,7 +188,7 @@ impl ModelClient {
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
Err(e) => {
|
||||
if attempt > max_retries {
|
||||
if attempt > *OPENAI_REQUEST_MAX_RETRIES {
|
||||
return Err(e.into());
|
||||
}
|
||||
let delay = backoff(attempt);
|
||||
@@ -207,8 +198,8 @@ impl ModelClient {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_provider(&self) -> ModelProviderInfo {
|
||||
self.provider.clone()
|
||||
pub fn streaming_enabled(&self) -> bool {
|
||||
self.config.streaming_enabled
|
||||
}
|
||||
}
|
||||
|
||||
@@ -261,16 +252,14 @@ struct ResponseCompletedOutputTokensDetails {
|
||||
reasoning_tokens: u64,
|
||||
}
|
||||
|
||||
async fn process_sse<S>(
|
||||
stream: S,
|
||||
tx_event: mpsc::Sender<Result<ResponseEvent>>,
|
||||
idle_timeout: Duration,
|
||||
) where
|
||||
async fn process_sse<S>(stream: S, tx_event: mpsc::Sender<Result<ResponseEvent>>)
|
||||
where
|
||||
S: Stream<Item = Result<Bytes>> + Unpin,
|
||||
{
|
||||
let mut stream = stream.eventsource();
|
||||
|
||||
// If the stream stays completely silent for an extended period treat it as disconnected.
|
||||
let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS;
|
||||
// The response id returned from the "complete" message.
|
||||
let mut response_completed: Option<ResponseCompleted> = None;
|
||||
|
||||
@@ -404,11 +393,8 @@ async fn process_sse<S>(
|
||||
}
|
||||
|
||||
/// used in tests to stream from a text SSE file
|
||||
async fn stream_from_fixture(
|
||||
path: impl AsRef<Path>,
|
||||
provider: ModelProviderInfo,
|
||||
) -> Result<ResponseStream> {
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
|
||||
async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(16);
|
||||
let f = std::fs::File::open(path.as_ref())?;
|
||||
let lines = std::io::BufReader::new(f).lines();
|
||||
|
||||
@@ -421,11 +407,7 @@ async fn stream_from_fixture(
|
||||
|
||||
let rdr = std::io::Cursor::new(content);
|
||||
let stream = ReaderStream::new(rdr).map_err(CodexErr::Io);
|
||||
tokio::spawn(process_sse(
|
||||
stream,
|
||||
tx_event,
|
||||
provider.stream_idle_timeout(),
|
||||
));
|
||||
tokio::spawn(process_sse(stream, tx_event));
|
||||
Ok(ResponseStream { rx_event })
|
||||
}
|
||||
|
||||
@@ -445,10 +427,7 @@ mod tests {
|
||||
|
||||
/// Runs the SSE parser on pre-chunked byte slices and returns every event
|
||||
/// (including any final `Err` from a stream-closure check).
|
||||
async fn collect_events(
|
||||
chunks: &[&[u8]],
|
||||
provider: ModelProviderInfo,
|
||||
) -> Vec<Result<ResponseEvent>> {
|
||||
async fn collect_events(chunks: &[&[u8]]) -> Vec<Result<ResponseEvent>> {
|
||||
let mut builder = IoBuilder::new();
|
||||
for chunk in chunks {
|
||||
builder.read(chunk);
|
||||
@@ -457,7 +436,7 @@ mod tests {
|
||||
let reader = builder.build();
|
||||
let stream = ReaderStream::new(reader).map_err(CodexErr::Io);
|
||||
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent>>(16);
|
||||
tokio::spawn(process_sse(stream, tx, provider.stream_idle_timeout()));
|
||||
tokio::spawn(process_sse(stream, tx));
|
||||
|
||||
let mut events = Vec::new();
|
||||
while let Some(ev) = rx.recv().await {
|
||||
@@ -468,10 +447,7 @@ mod tests {
|
||||
|
||||
/// Builds an in-memory SSE stream from JSON fixtures and returns only the
|
||||
/// successfully parsed events (panics on internal channel errors).
|
||||
async fn run_sse(
|
||||
events: Vec<serde_json::Value>,
|
||||
provider: ModelProviderInfo,
|
||||
) -> Vec<ResponseEvent> {
|
||||
async fn run_sse(events: Vec<serde_json::Value>) -> Vec<ResponseEvent> {
|
||||
let mut body = String::new();
|
||||
for e in events {
|
||||
let kind = e
|
||||
@@ -487,7 +463,7 @@ mod tests {
|
||||
|
||||
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent>>(8);
|
||||
let stream = ReaderStream::new(std::io::Cursor::new(body)).map_err(CodexErr::Io);
|
||||
tokio::spawn(process_sse(stream, tx, provider.stream_idle_timeout()));
|
||||
tokio::spawn(process_sse(stream, tx));
|
||||
|
||||
let mut out = Vec::new();
|
||||
while let Some(ev) = rx.recv().await {
|
||||
@@ -532,25 +508,7 @@ mod tests {
|
||||
let sse2 = format!("event: response.output_item.done\ndata: {item2}\n\n");
|
||||
let sse3 = format!("event: response.completed\ndata: {completed}\n\n");
|
||||
|
||||
let provider = ModelProviderInfo {
|
||||
name: "test".to_string(),
|
||||
base_url: "https://test.com".to_string(),
|
||||
env_key: Some("TEST_API_KEY".to_string()),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(1000),
|
||||
};
|
||||
|
||||
let events = collect_events(
|
||||
&[sse1.as_bytes(), sse2.as_bytes(), sse3.as_bytes()],
|
||||
provider,
|
||||
)
|
||||
.await;
|
||||
let events = collect_events(&[sse1.as_bytes(), sse2.as_bytes(), sse3.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 3);
|
||||
|
||||
@@ -591,21 +549,8 @@ mod tests {
|
||||
.to_string();
|
||||
|
||||
let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n");
|
||||
let provider = ModelProviderInfo {
|
||||
name: "test".to_string(),
|
||||
base_url: "https://test.com".to_string(),
|
||||
env_key: Some("TEST_API_KEY".to_string()),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(1000),
|
||||
};
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes()], provider).await;
|
||||
let events = collect_events(&[sse1.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 2);
|
||||
|
||||
@@ -693,21 +638,7 @@ mod tests {
|
||||
let mut evs = vec![case.event];
|
||||
evs.push(completed.clone());
|
||||
|
||||
let provider = ModelProviderInfo {
|
||||
name: "test".to_string(),
|
||||
base_url: "https://test.com".to_string(),
|
||||
env_key: Some("TEST_API_KEY".to_string()),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(1000),
|
||||
};
|
||||
|
||||
let out = run_sse(evs, provider).await;
|
||||
let out = run_sse(evs).await;
|
||||
assert_eq!(out.len(), case.expected_len, "case {}", case.name);
|
||||
assert!(
|
||||
(case.expect_first)(&out[0]),
|
||||
|
||||
@@ -53,12 +53,14 @@ impl Prompt {
|
||||
pub enum ResponseEvent {
|
||||
Created,
|
||||
OutputItemDone(ResponseItem),
|
||||
/// Streaming text from an assistant message.
|
||||
OutputTextDelta(String),
|
||||
/// Streaming text from a reasoning summary.
|
||||
ReasoningSummaryDelta(String),
|
||||
Completed {
|
||||
response_id: String,
|
||||
token_usage: Option<TokenUsage>,
|
||||
},
|
||||
OutputTextDelta(String),
|
||||
ReasoningSummaryDelta(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
|
||||
@@ -49,7 +49,9 @@ use crate::exec::ExecToolCallOutput;
|
||||
use crate::exec::SandboxType;
|
||||
use crate::exec::process_exec_tool_call;
|
||||
use crate::exec_env::create_env;
|
||||
use crate::flags::OPENAI_STREAM_MAX_RETRIES;
|
||||
use crate::mcp_connection_manager::McpConnectionManager;
|
||||
use crate::mcp_connection_manager::try_parse_fully_qualified_tool_name;
|
||||
use crate::mcp_tool_call::handle_mcp_tool_call;
|
||||
use crate::models::ContentItem;
|
||||
use crate::models::FunctionCallOutputPayload;
|
||||
@@ -59,9 +61,7 @@ use crate::models::ResponseInputItem;
|
||||
use crate::models::ResponseItem;
|
||||
use crate::models::ShellToolCallParams;
|
||||
use crate::project_doc::get_user_instructions;
|
||||
use crate::protocol::AgentMessageDeltaEvent;
|
||||
use crate::protocol::AgentMessageEvent;
|
||||
use crate::protocol::AgentReasoningDeltaEvent;
|
||||
use crate::protocol::AgentReasoningEvent;
|
||||
use crate::protocol::ApplyPatchApprovalRequestEvent;
|
||||
use crate::protocol::AskForApproval;
|
||||
@@ -102,11 +102,8 @@ impl Codex {
|
||||
/// of `Codex` and the ID of the `SessionInitialized` event that was
|
||||
/// submitted to start the session.
|
||||
pub async fn spawn(config: Config, ctrl_c: Arc<Notify>) -> CodexResult<(Codex, String)> {
|
||||
// experimental resume path (undocumented)
|
||||
let resume_path = config.experimental_resume.clone();
|
||||
info!("resume_path: {resume_path:?}");
|
||||
let (tx_sub, rx_sub) = async_channel::bounded(64);
|
||||
let (tx_event, rx_event) = async_channel::bounded(1600);
|
||||
let (tx_event, rx_event) = async_channel::bounded(64);
|
||||
|
||||
let instructions = get_user_instructions(&config).await;
|
||||
let configure_session = Op::ConfigureSession {
|
||||
@@ -120,7 +117,6 @@ impl Codex {
|
||||
disable_response_storage: config.disable_response_storage,
|
||||
notify: config.notify.clone(),
|
||||
cwd: config.cwd.clone(),
|
||||
resume_path: resume_path.clone(),
|
||||
};
|
||||
|
||||
let config = Arc::new(config);
|
||||
@@ -310,30 +306,24 @@ impl Session {
|
||||
/// transcript, if enabled.
|
||||
async fn record_conversation_items(&self, items: &[ResponseItem]) {
|
||||
debug!("Recording items for conversation: {items:?}");
|
||||
self.record_state_snapshot(items).await;
|
||||
self.record_rollout_items(items).await;
|
||||
|
||||
if let Some(transcript) = self.state.lock().unwrap().zdr_transcript.as_mut() {
|
||||
transcript.record_items(items);
|
||||
}
|
||||
}
|
||||
|
||||
async fn record_state_snapshot(&self, items: &[ResponseItem]) {
|
||||
let snapshot = {
|
||||
let state = self.state.lock().unwrap();
|
||||
crate::rollout::SessionStateSnapshot {
|
||||
previous_response_id: state.previous_response_id.clone(),
|
||||
}
|
||||
};
|
||||
|
||||
/// Append the given items to the session's rollout transcript (if enabled)
|
||||
/// and persist them to disk.
|
||||
async fn record_rollout_items(&self, items: &[ResponseItem]) {
|
||||
// Clone the recorder outside of the mutex so we don't hold the lock
|
||||
// across an await point (MutexGuard is not Send).
|
||||
let recorder = {
|
||||
let guard = self.rollout.lock().unwrap();
|
||||
guard.as_ref().cloned()
|
||||
};
|
||||
|
||||
if let Some(rec) = recorder {
|
||||
if let Err(e) = rec.record_state(snapshot).await {
|
||||
error!("failed to record rollout state: {e:#}");
|
||||
}
|
||||
if let Err(e) = rec.record_items(items).await {
|
||||
error!("failed to record rollout items: {e:#}");
|
||||
}
|
||||
@@ -527,7 +517,7 @@ async fn submission_loop(
|
||||
ctrl_c: Arc<Notify>,
|
||||
) {
|
||||
// Generate a unique ID for the lifetime of this Codex session.
|
||||
let mut session_id = Uuid::new_v4();
|
||||
let session_id = Uuid::new_v4();
|
||||
|
||||
let mut sess: Option<Arc<Session>> = None;
|
||||
// shorthand - send an event when there is no active session
|
||||
@@ -580,11 +570,8 @@ async fn submission_loop(
|
||||
disable_response_storage,
|
||||
notify,
|
||||
cwd,
|
||||
resume_path,
|
||||
} => {
|
||||
info!(
|
||||
"Configuring session: model={model}; provider={provider:?}; resume={resume_path:?}"
|
||||
);
|
||||
info!("Configuring session: model={model}; provider={provider:?}");
|
||||
if !cwd.is_absolute() {
|
||||
let message = format!("cwd is not absolute: {cwd:?}");
|
||||
error!(message);
|
||||
@@ -597,48 +584,12 @@ async fn submission_loop(
|
||||
}
|
||||
return;
|
||||
}
|
||||
// Optionally resume an existing rollout.
|
||||
let mut restored_items: Option<Vec<ResponseItem>> = None;
|
||||
let mut restored_prev_id: Option<String> = None;
|
||||
let rollout_recorder: Option<RolloutRecorder> =
|
||||
if let Some(path) = resume_path.as_ref() {
|
||||
match RolloutRecorder::resume(path).await {
|
||||
Ok((rec, saved)) => {
|
||||
session_id = saved.session_id;
|
||||
restored_prev_id = saved.state.previous_response_id;
|
||||
if !saved.items.is_empty() {
|
||||
restored_items = Some(saved.items);
|
||||
}
|
||||
Some(rec)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("failed to resume rollout from {path:?}: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let rollout_recorder = match rollout_recorder {
|
||||
Some(rec) => Some(rec),
|
||||
None => match RolloutRecorder::new(&config, session_id, instructions.clone())
|
||||
.await
|
||||
{
|
||||
Ok(r) => Some(r),
|
||||
Err(e) => {
|
||||
warn!("failed to initialise rollout recorder: {e}");
|
||||
None
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let client = ModelClient::new(
|
||||
config.clone(),
|
||||
provider.clone(),
|
||||
model_reasoning_effort,
|
||||
model_reasoning_summary,
|
||||
session_id,
|
||||
);
|
||||
|
||||
// abort any current running session and clone its state
|
||||
@@ -692,6 +643,21 @@ async fn submission_loop(
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Attempt to create a RolloutRecorder *before* moving the
|
||||
// `instructions` value into the Session struct.
|
||||
// TODO: if ConfigureSession is sent twice, we will create an
|
||||
// overlapping rollout file. Consider passing RolloutRecorder
|
||||
// from above.
|
||||
let rollout_recorder =
|
||||
match RolloutRecorder::new(&config, session_id, instructions.clone()).await {
|
||||
Ok(r) => Some(r),
|
||||
Err(e) => {
|
||||
warn!("failed to initialise rollout recorder: {e}");
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
sess = Some(Arc::new(Session {
|
||||
client,
|
||||
tx_event: tx_event.clone(),
|
||||
@@ -709,19 +675,6 @@ async fn submission_loop(
|
||||
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
||||
}));
|
||||
|
||||
// Patch restored state into the newly created session.
|
||||
if let Some(sess_arc) = &sess {
|
||||
if restored_prev_id.is_some() || restored_items.is_some() {
|
||||
let mut st = sess_arc.state.lock().unwrap();
|
||||
st.previous_response_id = restored_prev_id;
|
||||
if let (Some(hist), Some(items)) =
|
||||
(st.zdr_transcript.as_mut(), restored_items.as_ref())
|
||||
{
|
||||
hist.record_items(items.iter());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Gather history metadata for SessionConfiguredEvent.
|
||||
let (history_log_id, history_entry_count) =
|
||||
crate::message_history::history_metadata(&config).await;
|
||||
@@ -790,8 +743,6 @@ async fn submission_loop(
|
||||
}
|
||||
}
|
||||
Op::AddToHistory { text } => {
|
||||
// TODO: What should we do if we got AddToHistory before ConfigureSession?
|
||||
// currently, if ConfigureSession has resume path, this history will be ignored
|
||||
let id = session_id;
|
||||
let config = config.clone();
|
||||
tokio::spawn(async move {
|
||||
@@ -967,17 +918,15 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
) => {
|
||||
items_to_record_in_conversation_history.push(item);
|
||||
let (content, success): (String, Option<bool>) = match result {
|
||||
Ok(CallToolResult {
|
||||
content,
|
||||
is_error,
|
||||
structured_content: _,
|
||||
}) => match serde_json::to_string(content) {
|
||||
Ok(content) => (content, *is_error),
|
||||
Err(e) => {
|
||||
warn!("Failed to serialize MCP tool call output: {e}");
|
||||
(e.to_string(), Some(true))
|
||||
Ok(CallToolResult { content, is_error }) => {
|
||||
match serde_json::to_string(content) {
|
||||
Ok(content) => (content, *is_error),
|
||||
Err(e) => {
|
||||
warn!("Failed to serialize MCP tool call output: {e}");
|
||||
(e.to_string(), Some(true))
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
Err(e) => (e.clone(), Some(true)),
|
||||
};
|
||||
items_to_record_in_conversation_history.push(
|
||||
@@ -1076,13 +1025,12 @@ async fn run_turn(
|
||||
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
|
||||
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
||||
Err(e) => {
|
||||
// Use the configured provider-specific stream retry budget.
|
||||
let max_retries = sess.client.get_provider().stream_max_retries();
|
||||
if retries < max_retries {
|
||||
if retries < *OPENAI_STREAM_MAX_RETRIES {
|
||||
retries += 1;
|
||||
let delay = backoff(retries);
|
||||
warn!(
|
||||
"stream disconnected - retrying turn ({retries}/{max_retries} in {delay:?})...",
|
||||
"stream disconnected - retrying turn ({retries}/{} in {delay:?})...",
|
||||
*OPENAI_STREAM_MAX_RETRIES
|
||||
);
|
||||
|
||||
// Surface retry information to any UI/front‑end so the
|
||||
@@ -1091,7 +1039,8 @@ async fn run_turn(
|
||||
sess.notify_background_event(
|
||||
&sub_id,
|
||||
format!(
|
||||
"stream error: {e}; retrying {retries}/{max_retries} in {delay:?}…"
|
||||
"stream error: {e}; retrying {retries}/{} in {:?}…",
|
||||
*OPENAI_STREAM_MAX_RETRIES, delay
|
||||
),
|
||||
)
|
||||
.await;
|
||||
@@ -1173,32 +1122,14 @@ async fn try_run_turn(
|
||||
let mut stream = sess.client.clone().stream(&prompt).await?;
|
||||
|
||||
let mut output = Vec::new();
|
||||
loop {
|
||||
// Poll the next item from the model stream. We must inspect *both* Ok and Err
|
||||
// cases so that transient stream failures (e.g., dropped SSE connection before
|
||||
// `response.completed`) bubble up and trigger the caller's retry logic.
|
||||
let event = stream.next().await;
|
||||
let Some(event) = event else {
|
||||
// Channel closed without yielding a final Completed event or explicit error.
|
||||
// Treat as a disconnected stream so the caller can retry.
|
||||
return Err(CodexErr::Stream(
|
||||
"stream closed before response.completed".into(),
|
||||
));
|
||||
};
|
||||
|
||||
let event = match event {
|
||||
Ok(ev) => ev,
|
||||
Err(e) => {
|
||||
// Propagate the underlying stream error to the caller (run_turn), which
|
||||
// will apply the configured `stream_max_retries` policy.
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
// Patch: buffer for non-streaming mode
|
||||
let mut assistant_message_buf = String::new();
|
||||
let streaming_enabled = sess.client.streaming_enabled();
|
||||
while let Some(event) = stream.next().await {
|
||||
let event = event?;
|
||||
match event {
|
||||
ResponseEvent::Created => {
|
||||
let mut state = sess.state.lock().unwrap();
|
||||
// We successfully created a new response and ensured that all pending calls were included so we can clear the pending call ids.
|
||||
state.pending_call_ids.clear();
|
||||
}
|
||||
ResponseEvent::OutputItemDone(item) => {
|
||||
@@ -1211,18 +1142,59 @@ async fn try_run_turn(
|
||||
_ => None,
|
||||
};
|
||||
if let Some(call_id) = call_id {
|
||||
// We just got a new call id so we need to make sure to respond to it in the next turn.
|
||||
let mut state = sess.state.lock().unwrap();
|
||||
state.pending_call_ids.insert(call_id.clone());
|
||||
}
|
||||
let response = handle_response_item(sess, sub_id, item.clone()).await?;
|
||||
|
||||
// Patch: buffer assistant message text if streaming is disabled
|
||||
if !streaming_enabled {
|
||||
if let ResponseItem::Message { role, content } = &item {
|
||||
if role == "assistant" {
|
||||
for c in content {
|
||||
if let ContentItem::OutputText { text } = c {
|
||||
assistant_message_buf.push_str(text);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let response = match &item {
|
||||
ResponseItem::Message { .. } | ResponseItem::Reasoning { .. } => None,
|
||||
_ => handle_response_item(sess, sub_id, item.clone()).await?,
|
||||
};
|
||||
output.push(ProcessedResponseItem { item, response });
|
||||
}
|
||||
ResponseEvent::OutputTextDelta(text) => {
|
||||
if streaming_enabled {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::AgentMessageDelta(AgentMessageEvent { message: text }),
|
||||
};
|
||||
sess.tx_event.send(event).await.ok();
|
||||
} else {
|
||||
assistant_message_buf.push_str(&text);
|
||||
}
|
||||
}
|
||||
ResponseEvent::ReasoningSummaryDelta(text) => {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::AgentReasoningDelta(AgentReasoningEvent { text }),
|
||||
};
|
||||
sess.tx_event.send(event).await.ok();
|
||||
}
|
||||
ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
} => {
|
||||
// Patch: emit full message if we buffered deltas
|
||||
if !streaming_enabled && !assistant_message_buf.is_empty() {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::AgentMessage(AgentMessageEvent {
|
||||
message: assistant_message_buf.clone(),
|
||||
}),
|
||||
};
|
||||
sess.tx_event.send(event).await.ok();
|
||||
}
|
||||
if let Some(token_usage) = token_usage {
|
||||
sess.tx_event
|
||||
.send(Event {
|
||||
@@ -1235,24 +1207,11 @@ async fn try_run_turn(
|
||||
|
||||
let mut state = sess.state.lock().unwrap();
|
||||
state.previous_response_id = Some(response_id);
|
||||
return Ok(output);
|
||||
}
|
||||
ResponseEvent::OutputTextDelta(delta) => {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }),
|
||||
};
|
||||
sess.tx_event.send(event).await.ok();
|
||||
}
|
||||
ResponseEvent::ReasoningSummaryDelta(delta) => {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta }),
|
||||
};
|
||||
sess.tx_event.send(event).await.ok();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
async fn handle_response_item(
|
||||
@@ -1355,13 +1314,13 @@ async fn handle_function_call(
|
||||
let params = match parse_container_exec_arguments(arguments, sess, &call_id) {
|
||||
Ok(params) => params,
|
||||
Err(output) => {
|
||||
return *output;
|
||||
return output;
|
||||
}
|
||||
};
|
||||
handle_container_exec_with_params(params, sess, sub_id, call_id).await
|
||||
}
|
||||
_ => {
|
||||
match sess.mcp_connection_manager.parse_tool_name(&name) {
|
||||
match try_parse_fully_qualified_tool_name(&name) {
|
||||
Some((server, tool_name)) => {
|
||||
// TODO(mbolin): Determine appropriate timeout for tool call.
|
||||
let timeout = None;
|
||||
@@ -1398,7 +1357,7 @@ fn parse_container_exec_arguments(
|
||||
arguments: String,
|
||||
sess: &Session,
|
||||
call_id: &str,
|
||||
) -> Result<ExecParams, Box<ResponseInputItem>> {
|
||||
) -> Result<ExecParams, ResponseInputItem> {
|
||||
// parse command
|
||||
match serde_json::from_str::<ShellToolCallParams>(&arguments) {
|
||||
Ok(shell_tool_call_params) => Ok(to_exec_params(shell_tool_call_params, sess)),
|
||||
@@ -1411,7 +1370,7 @@ fn parse_container_exec_arguments(
|
||||
success: None,
|
||||
},
|
||||
};
|
||||
Err(Box::new(output))
|
||||
Err(output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,15 +131,19 @@ pub struct Config {
|
||||
/// request using the Responses API.
|
||||
pub model_reasoning_summary: ReasoningSummary,
|
||||
|
||||
/// Whether to surface live streaming delta events in front-ends. When `true`
|
||||
/// (default) Codex will forward `AgentMessageDelta` / `AgentReasoningDelta`
|
||||
/// events and UIs may show a typing indicator. When `false` Codex UIs should
|
||||
/// ignore delta events and rely solely on the final aggregated
|
||||
/// `AgentMessage`/`AgentReasoning` items (legacy behaviour).
|
||||
pub streaming_enabled: bool,
|
||||
|
||||
/// When set to `true`, overrides the default heuristic and forces
|
||||
/// `model_supports_reasoning_summaries()` to return `true`.
|
||||
pub model_supports_reasoning_summaries: bool,
|
||||
|
||||
/// Base URL for requests to ChatGPT (as opposed to the OpenAI API).
|
||||
pub chatgpt_base_url: String,
|
||||
|
||||
/// Experimental rollout resume path (absolute path to .jsonl; undocumented).
|
||||
pub experimental_resume: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@@ -325,8 +329,12 @@ pub struct ConfigToml {
|
||||
/// Base URL for requests to ChatGPT (as opposed to the OpenAI API).
|
||||
pub chatgpt_base_url: Option<String>,
|
||||
|
||||
/// Experimental rollout resume path (absolute path to .jsonl; undocumented).
|
||||
pub experimental_resume: Option<PathBuf>,
|
||||
/// Whether to surface live streaming delta events in front-ends. When `true`
|
||||
/// (default) Codex will forward `AgentMessageDelta` / `AgentReasoningDelta`
|
||||
/// events and UIs may show a typing indicator. When `false` Codex UIs should
|
||||
/// ignore delta events and rely solely on the final aggregated
|
||||
/// `AgentMessage`/`AgentReasoning` items (legacy behaviour).
|
||||
pub streaming: Option<bool>,
|
||||
}
|
||||
|
||||
impl ConfigToml {
|
||||
@@ -454,9 +462,6 @@ impl Config {
|
||||
.as_ref()
|
||||
.map(|info| info.max_output_tokens)
|
||||
});
|
||||
|
||||
let experimental_resume = cfg.experimental_resume;
|
||||
|
||||
let config = Self {
|
||||
model,
|
||||
model_context_window,
|
||||
@@ -495,6 +500,7 @@ impl Config {
|
||||
.or(cfg.model_reasoning_summary)
|
||||
.unwrap_or_default(),
|
||||
|
||||
streaming_enabled: cfg.streaming.unwrap_or(true),
|
||||
model_supports_reasoning_summaries: cfg
|
||||
.model_supports_reasoning_summaries
|
||||
.unwrap_or(false),
|
||||
@@ -503,8 +509,6 @@ impl Config {
|
||||
.chatgpt_base_url
|
||||
.or(cfg.chatgpt_base_url)
|
||||
.unwrap_or("https://chatgpt.com/backend-api/".to_string()),
|
||||
|
||||
experimental_resume,
|
||||
};
|
||||
Ok(config)
|
||||
}
|
||||
@@ -693,9 +697,6 @@ name = "OpenAI using Chat Completions"
|
||||
base_url = "https://api.openai.com/v1"
|
||||
env_key = "OPENAI_API_KEY"
|
||||
wire_api = "chat"
|
||||
request_max_retries = 4 # retry failed HTTP requests
|
||||
stream_max_retries = 10 # retry dropped SSE streams
|
||||
stream_idle_timeout_ms = 300000 # 5m idle timeout
|
||||
|
||||
[profiles.o3]
|
||||
model = "o3"
|
||||
@@ -736,9 +737,6 @@ disable_response_storage = true
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(4),
|
||||
stream_max_retries: Some(10),
|
||||
stream_idle_timeout_ms: Some(300_000),
|
||||
};
|
||||
let model_provider_map = {
|
||||
let mut model_provider_map = built_in_model_providers();
|
||||
@@ -815,9 +813,9 @@ disable_response_storage = true
|
||||
hide_agent_reasoning: false,
|
||||
model_reasoning_effort: ReasoningEffort::High,
|
||||
model_reasoning_summary: ReasoningSummary::Detailed,
|
||||
streaming_enabled: true,
|
||||
model_supports_reasoning_summaries: false,
|
||||
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
|
||||
experimental_resume: None,
|
||||
},
|
||||
o3_profile_config
|
||||
);
|
||||
@@ -862,9 +860,9 @@ disable_response_storage = true
|
||||
hide_agent_reasoning: false,
|
||||
model_reasoning_effort: ReasoningEffort::default(),
|
||||
model_reasoning_summary: ReasoningSummary::default(),
|
||||
streaming_enabled: true,
|
||||
model_supports_reasoning_summaries: false,
|
||||
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
|
||||
experimental_resume: None,
|
||||
};
|
||||
|
||||
assert_eq!(expected_gpt3_profile_config, gpt3_profile_config);
|
||||
@@ -924,9 +922,9 @@ disable_response_storage = true
|
||||
hide_agent_reasoning: false,
|
||||
model_reasoning_effort: ReasoningEffort::default(),
|
||||
model_reasoning_summary: ReasoningSummary::default(),
|
||||
streaming_enabled: true,
|
||||
model_supports_reasoning_summaries: false,
|
||||
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
|
||||
experimental_resume: None,
|
||||
};
|
||||
|
||||
assert_eq!(expected_zdr_profile_config, zdr_profile_config);
|
||||
|
||||
@@ -150,7 +150,7 @@ pub type EnvironmentVariablePattern = WildMatchPattern<'*', '?'>;
|
||||
/// Deriving the `env` based on this policy works as follows:
|
||||
/// 1. Create an initial map based on the `inherit` policy.
|
||||
/// 2. If `ignore_default_excludes` is false, filter the map using the default
|
||||
/// exclude pattern(s), which are: `"*KEY*"` and `"*TOKEN*"`.
|
||||
/// exclude pattern(s), which are: "*KEY*" and "*TOKEN*".
|
||||
/// 3. If `exclude` is not empty, filter the map using the provided patterns.
|
||||
/// 4. Insert any entries from `r#set` into the map.
|
||||
/// 5. If non-empty, filter the map using the `include_only` patterns.
|
||||
@@ -228,3 +228,10 @@ pub enum ReasoningSummary {
|
||||
/// Option to disable reasoning summaries.
|
||||
None,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// NOTE: The canonical ConfigToml definition lives in `crate::config`.
|
||||
// Historically this file accidentally re-declared that struct, which caused
|
||||
// drift and confusion. The duplicate has been removed; please use
|
||||
// `codex_core::config::ConfigToml` instead.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -11,6 +11,14 @@ env_flags! {
|
||||
pub OPENAI_TIMEOUT_MS: Duration = Duration::from_millis(300_000), |value| {
|
||||
value.parse().map(Duration::from_millis)
|
||||
};
|
||||
pub OPENAI_REQUEST_MAX_RETRIES: u64 = 4;
|
||||
pub OPENAI_STREAM_MAX_RETRIES: u64 = 10;
|
||||
|
||||
// We generally don't want to disconnect; this updates the timeout to be five minutes
|
||||
// which matches the upstream typescript codex impl.
|
||||
pub OPENAI_STREAM_IDLE_TIMEOUT_MS: Duration = Duration::from_millis(300_000), |value| {
|
||||
value.parse().map(Duration::from_millis)
|
||||
};
|
||||
|
||||
/// Fixture path for offline tests (see client.rs).
|
||||
pub CODEX_RS_SSE_FIXTURE: Option<&str> = None;
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
//! `"<server><MCP_TOOL_NAME_DELIMITER><tool>"` as the key.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
@@ -17,13 +16,8 @@ use codex_mcp_client::McpClient;
|
||||
use mcp_types::ClientCapabilities;
|
||||
use mcp_types::Implementation;
|
||||
use mcp_types::Tool;
|
||||
|
||||
use serde_json::json;
|
||||
use sha1::Digest;
|
||||
use sha1::Sha1;
|
||||
use tokio::task::JoinSet;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::config_types::McpServerConfig;
|
||||
|
||||
@@ -32,8 +26,7 @@ use crate::config_types::McpServerConfig;
|
||||
///
|
||||
/// OpenAI requires tool names to conform to `^[a-zA-Z0-9_-]+$`, so we must
|
||||
/// choose a delimiter from this character set.
|
||||
const MCP_TOOL_NAME_DELIMITER: &str = "__";
|
||||
const MAX_TOOL_NAME_LENGTH: usize = 64;
|
||||
const MCP_TOOL_NAME_DELIMITER: &str = "__OAI_CODEX_MCP__";
|
||||
|
||||
/// Timeout for the `tools/list` request.
|
||||
const LIST_TOOLS_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
@@ -42,42 +35,16 @@ const LIST_TOOLS_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
/// spawned successfully.
|
||||
pub type ClientStartErrors = HashMap<String, anyhow::Error>;
|
||||
|
||||
fn qualify_tools(tools: Vec<ToolInfo>) -> HashMap<String, ToolInfo> {
|
||||
let mut used_names = HashSet::new();
|
||||
let mut qualified_tools = HashMap::new();
|
||||
for tool in tools {
|
||||
let mut qualified_name = format!(
|
||||
"{}{}{}",
|
||||
tool.server_name, MCP_TOOL_NAME_DELIMITER, tool.tool_name
|
||||
);
|
||||
if qualified_name.len() > MAX_TOOL_NAME_LENGTH {
|
||||
let mut hasher = Sha1::new();
|
||||
hasher.update(qualified_name.as_bytes());
|
||||
let sha1 = hasher.finalize();
|
||||
let sha1_str = format!("{sha1:x}");
|
||||
|
||||
// Truncate to make room for the hash suffix
|
||||
let prefix_len = MAX_TOOL_NAME_LENGTH - sha1_str.len();
|
||||
|
||||
qualified_name = format!("{}{}", &qualified_name[..prefix_len], sha1_str);
|
||||
}
|
||||
|
||||
if used_names.contains(&qualified_name) {
|
||||
warn!("skipping duplicated tool {}", qualified_name);
|
||||
continue;
|
||||
}
|
||||
|
||||
used_names.insert(qualified_name.clone());
|
||||
qualified_tools.insert(qualified_name, tool);
|
||||
}
|
||||
|
||||
qualified_tools
|
||||
fn fully_qualified_tool_name(server: &str, tool: &str) -> String {
|
||||
format!("{server}{MCP_TOOL_NAME_DELIMITER}{tool}")
|
||||
}
|
||||
|
||||
struct ToolInfo {
|
||||
server_name: String,
|
||||
tool_name: String,
|
||||
tool: Tool,
|
||||
pub(crate) fn try_parse_fully_qualified_tool_name(fq_name: &str) -> Option<(String, String)> {
|
||||
let (server, tool) = fq_name.split_once(MCP_TOOL_NAME_DELIMITER)?;
|
||||
if server.is_empty() || tool.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some((server.to_string(), tool.to_string()))
|
||||
}
|
||||
|
||||
/// A thin wrapper around a set of running [`McpClient`] instances.
|
||||
@@ -90,7 +57,7 @@ pub(crate) struct McpConnectionManager {
|
||||
clients: HashMap<String, std::sync::Arc<McpClient>>,
|
||||
|
||||
/// Fully qualified tool name -> tool instance.
|
||||
tools: HashMap<String, ToolInfo>,
|
||||
tools: HashMap<String, Tool>,
|
||||
}
|
||||
|
||||
impl McpConnectionManager {
|
||||
@@ -112,19 +79,9 @@ impl McpConnectionManager {
|
||||
|
||||
// Launch all configured servers concurrently.
|
||||
let mut join_set = JoinSet::new();
|
||||
let mut errors = ClientStartErrors::new();
|
||||
|
||||
for (server_name, cfg) in mcp_servers {
|
||||
// Validate server name before spawning
|
||||
if !is_valid_mcp_server_name(&server_name) {
|
||||
let error = anyhow::anyhow!(
|
||||
"invalid server name '{}': must match pattern ^[a-zA-Z0-9_-]+$",
|
||||
server_name
|
||||
);
|
||||
errors.insert(server_name, error);
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: Verify server name: require `^[a-zA-Z0-9_-]+$`?
|
||||
join_set.spawn(async move {
|
||||
let McpServerConfig { command, args, env } = cfg;
|
||||
let client_res = McpClient::new_stdio_client(command, args, env).await;
|
||||
@@ -136,14 +93,10 @@ impl McpConnectionManager {
|
||||
experimental: None,
|
||||
roots: None,
|
||||
sampling: None,
|
||||
// https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities
|
||||
// indicates this should be an empty object.
|
||||
elicitation: Some(json!({})),
|
||||
},
|
||||
client_info: Implementation {
|
||||
name: "codex-mcp-client".to_owned(),
|
||||
version: env!("CARGO_PKG_VERSION").to_owned(),
|
||||
title: Some("Codex".into()),
|
||||
},
|
||||
protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(),
|
||||
};
|
||||
@@ -164,6 +117,7 @@ impl McpConnectionManager {
|
||||
|
||||
let mut clients: HashMap<String, std::sync::Arc<McpClient>> =
|
||||
HashMap::with_capacity(join_set.len());
|
||||
let mut errors = ClientStartErrors::new();
|
||||
|
||||
while let Some(res) = join_set.join_next().await {
|
||||
let (server_name, client_res) = res?; // JoinError propagation
|
||||
@@ -178,9 +132,7 @@ impl McpConnectionManager {
|
||||
}
|
||||
}
|
||||
|
||||
let all_tools = list_all_tools(&clients).await?;
|
||||
|
||||
let tools = qualify_tools(all_tools);
|
||||
let tools = list_all_tools(&clients).await?;
|
||||
|
||||
Ok((Self { clients, tools }, errors))
|
||||
}
|
||||
@@ -188,10 +140,7 @@ impl McpConnectionManager {
|
||||
/// Returns a single map that contains **all** tools. Each key is the
|
||||
/// fully-qualified name for the tool.
|
||||
pub fn list_all_tools(&self) -> HashMap<String, Tool> {
|
||||
self.tools
|
||||
.iter()
|
||||
.map(|(name, tool)| (name.clone(), tool.tool.clone()))
|
||||
.collect()
|
||||
self.tools.clone()
|
||||
}
|
||||
|
||||
/// Invoke the tool indicated by the (server, tool) pair.
|
||||
@@ -213,19 +162,13 @@ impl McpConnectionManager {
|
||||
.await
|
||||
.with_context(|| format!("tool call failed for `{server}/{tool}`"))
|
||||
}
|
||||
|
||||
pub fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> {
|
||||
self.tools
|
||||
.get(tool_name)
|
||||
.map(|tool| (tool.server_name.clone(), tool.tool_name.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Query every server for its available tools and return a single map that
|
||||
/// contains **all** tools. Each key is the fully-qualified name for the tool.
|
||||
async fn list_all_tools(
|
||||
pub async fn list_all_tools(
|
||||
clients: &HashMap<String, std::sync::Arc<McpClient>>,
|
||||
) -> Result<Vec<ToolInfo>> {
|
||||
) -> Result<HashMap<String, Tool>> {
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
// Spawn one task per server so we can query them concurrently. This
|
||||
@@ -242,19 +185,18 @@ async fn list_all_tools(
|
||||
});
|
||||
}
|
||||
|
||||
let mut aggregated: Vec<ToolInfo> = Vec::with_capacity(join_set.len());
|
||||
let mut aggregated: HashMap<String, Tool> = HashMap::with_capacity(join_set.len());
|
||||
|
||||
while let Some(join_res) = join_set.join_next().await {
|
||||
let (server_name, list_result) = join_res?;
|
||||
let list_result = list_result?;
|
||||
|
||||
for tool in list_result.tools {
|
||||
let tool_info = ToolInfo {
|
||||
server_name: server_name.clone(),
|
||||
tool_name: tool.name.clone(),
|
||||
tool,
|
||||
};
|
||||
aggregated.push(tool_info);
|
||||
// TODO(mbolin): escape tool names that contain invalid characters.
|
||||
let fq_name = fully_qualified_tool_name(&server_name, &tool.name);
|
||||
if aggregated.insert(fq_name.clone(), tool).is_some() {
|
||||
panic!("tool name collision for '{fq_name}': suspicious");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -266,99 +208,3 @@ async fn list_all_tools(
|
||||
|
||||
Ok(aggregated)
|
||||
}
|
||||
|
||||
fn is_valid_mcp_server_name(server_name: &str) -> bool {
|
||||
!server_name.is_empty()
|
||||
&& server_name
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use mcp_types::ToolInputSchema;
|
||||
|
||||
fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo {
|
||||
ToolInfo {
|
||||
server_name: server_name.to_string(),
|
||||
tool_name: tool_name.to_string(),
|
||||
tool: Tool {
|
||||
annotations: None,
|
||||
description: Some(format!("Test tool: {tool_name}")),
|
||||
input_schema: ToolInputSchema {
|
||||
properties: None,
|
||||
required: None,
|
||||
r#type: "object".to_string(),
|
||||
},
|
||||
name: tool_name.to_string(),
|
||||
output_schema: None,
|
||||
title: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_qualify_tools_short_non_duplicated_names() {
|
||||
let tools = vec![
|
||||
create_test_tool("server1", "tool1"),
|
||||
create_test_tool("server1", "tool2"),
|
||||
];
|
||||
|
||||
let qualified_tools = qualify_tools(tools);
|
||||
|
||||
assert_eq!(qualified_tools.len(), 2);
|
||||
assert!(qualified_tools.contains_key("server1__tool1"));
|
||||
assert!(qualified_tools.contains_key("server1__tool2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_qualify_tools_duplicated_names_skipped() {
|
||||
let tools = vec![
|
||||
create_test_tool("server1", "duplicate_tool"),
|
||||
create_test_tool("server1", "duplicate_tool"),
|
||||
];
|
||||
|
||||
let qualified_tools = qualify_tools(tools);
|
||||
|
||||
// Only the first tool should remain, the second is skipped
|
||||
assert_eq!(qualified_tools.len(), 1);
|
||||
assert!(qualified_tools.contains_key("server1__duplicate_tool"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_qualify_tools_long_names_same_server() {
|
||||
let server_name = "my_server";
|
||||
|
||||
let tools = vec![
|
||||
create_test_tool(
|
||||
server_name,
|
||||
"extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits",
|
||||
),
|
||||
create_test_tool(
|
||||
server_name,
|
||||
"yet_another_extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits",
|
||||
),
|
||||
];
|
||||
|
||||
let qualified_tools = qualify_tools(tools);
|
||||
|
||||
assert_eq!(qualified_tools.len(), 2);
|
||||
|
||||
let mut keys: Vec<_> = qualified_tools.keys().cloned().collect();
|
||||
keys.sort();
|
||||
|
||||
assert_eq!(keys[0].len(), 64);
|
||||
assert_eq!(
|
||||
keys[0],
|
||||
"my_server__extremely_lena02e507efc5a9de88637e436690364fd4219e4ef"
|
||||
);
|
||||
|
||||
assert_eq!(keys[1].len(), 64);
|
||||
assert_eq!(
|
||||
keys[1],
|
||||
"my_server__yet_another_e1c3987bd9c50b826cbe1687966f79f0c602d19ca"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::collections::HashMap;
|
||||
use std::env::VarError;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::error::EnvVarError;
|
||||
use crate::openai_api_key::get_openai_api_key;
|
||||
@@ -17,9 +16,6 @@ use crate::openai_api_key::get_openai_api_key;
|
||||
/// Value for the `OpenAI-Originator` header that is sent with requests to
|
||||
/// OpenAI.
|
||||
const OPENAI_ORIGINATOR_HEADER: &str = "codex_cli_rs";
|
||||
const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000;
|
||||
const DEFAULT_STREAM_MAX_RETRIES: u64 = 10;
|
||||
const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
|
||||
|
||||
/// Wire protocol that the provider speaks. Most third-party services only
|
||||
/// implement the classic OpenAI Chat Completions JSON schema, whereas OpenAI
|
||||
@@ -30,7 +26,7 @@ const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum WireApi {
|
||||
/// The experimental "Responses" API exposed by OpenAI at `/v1/responses`.
|
||||
/// The experimental “Responses” API exposed by OpenAI at `/v1/responses`.
|
||||
Responses,
|
||||
|
||||
/// Regular Chat Completions compatible with `/v1/chat/completions`.
|
||||
@@ -68,16 +64,6 @@ pub struct ModelProviderInfo {
|
||||
/// value should be used. If the environment variable is not set, or the
|
||||
/// value is empty, the header will not be included in the request.
|
||||
pub env_http_headers: Option<HashMap<String, String>>,
|
||||
|
||||
/// Maximum number of times to retry a failed HTTP request to this provider.
|
||||
pub request_max_retries: Option<u64>,
|
||||
|
||||
/// Number of times to retry reconnecting a dropped streaming response before failing.
|
||||
pub stream_max_retries: Option<u64>,
|
||||
|
||||
/// Idle timeout (in milliseconds) to wait for activity on a streaming response before treating
|
||||
/// the connection as lost.
|
||||
pub stream_idle_timeout_ms: Option<u64>,
|
||||
}
|
||||
|
||||
impl ModelProviderInfo {
|
||||
@@ -175,25 +161,6 @@ impl ModelProviderInfo {
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Effective maximum number of request retries for this provider.
|
||||
pub fn request_max_retries(&self) -> u64 {
|
||||
self.request_max_retries
|
||||
.unwrap_or(DEFAULT_REQUEST_MAX_RETRIES)
|
||||
}
|
||||
|
||||
/// Effective maximum number of stream reconnection attempts for this provider.
|
||||
pub fn stream_max_retries(&self) -> u64 {
|
||||
self.stream_max_retries
|
||||
.unwrap_or(DEFAULT_STREAM_MAX_RETRIES)
|
||||
}
|
||||
|
||||
/// Effective idle timeout for streaming responses.
|
||||
pub fn stream_idle_timeout(&self) -> Duration {
|
||||
self.stream_idle_timeout_ms
|
||||
.map(Duration::from_millis)
|
||||
.unwrap_or(Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS))
|
||||
}
|
||||
}
|
||||
|
||||
/// Built-in default provider list.
|
||||
@@ -238,10 +205,6 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
// Use global defaults for retry/timeout unless overridden in config.toml.
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
},
|
||||
),
|
||||
]
|
||||
@@ -271,9 +234,6 @@ base_url = "http://localhost:11434/v1"
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
};
|
||||
|
||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||
@@ -299,9 +259,6 @@ query_params = { api-version = "2025-04-01-preview" }
|
||||
}),
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
};
|
||||
|
||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||
@@ -330,9 +287,6 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
|
||||
env_http_headers: Some(maplit::hashmap! {
|
||||
"X-Example-Env-Header".to_string() => "EXAMPLE_ENV_VAR".to_string(),
|
||||
}),
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
};
|
||||
|
||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||
|
||||
@@ -69,10 +69,6 @@ pub enum Op {
|
||||
/// `ConfigureSession` operation so that the business-logic layer can
|
||||
/// operate deterministically.
|
||||
cwd: std::path::PathBuf,
|
||||
|
||||
/// Path to a rollout file to resume from.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
resume_path: Option<std::path::PathBuf>,
|
||||
},
|
||||
|
||||
/// Abort current task.
|
||||
@@ -286,14 +282,14 @@ pub enum EventMsg {
|
||||
/// Agent text output message
|
||||
AgentMessage(AgentMessageEvent),
|
||||
|
||||
/// Agent text output delta message
|
||||
AgentMessageDelta(AgentMessageDeltaEvent),
|
||||
/// Incremental assistant text delta
|
||||
AgentMessageDelta(AgentMessageEvent),
|
||||
|
||||
/// Reasoning event from agent.
|
||||
AgentReasoning(AgentReasoningEvent),
|
||||
|
||||
/// Agent reasoning delta event from agent.
|
||||
AgentReasoningDelta(AgentReasoningDeltaEvent),
|
||||
/// Incremental reasoning text delta.
|
||||
AgentReasoningDelta(AgentReasoningEvent),
|
||||
|
||||
/// Ack the client's configure message.
|
||||
SessionConfigured(SessionConfiguredEvent),
|
||||
@@ -350,21 +346,11 @@ pub struct AgentMessageEvent {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct AgentMessageDeltaEvent {
|
||||
pub delta: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct AgentReasoningEvent {
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct AgentReasoningDeltaEvent {
|
||||
pub delta: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct McpToolCallBeginEvent {
|
||||
/// Identifier so this can be paired with the McpToolCallEnd event.
|
||||
|
||||
@@ -1,47 +1,33 @@
|
||||
//! Persist Codex session rollouts (.jsonl) so sessions can be replayed or inspected later.
|
||||
//! Functionality to persist a Codex conversation *rollout* – a linear list of
|
||||
//! [`ResponseItem`] objects exchanged during a session – to disk so that
|
||||
//! sessions can be replayed or inspected later (mirrors the behaviour of the
|
||||
//! upstream TypeScript implementation).
|
||||
|
||||
use std::fs::File;
|
||||
use std::fs::{self};
|
||||
use std::io::Error as IoError;
|
||||
use std::path::Path;
|
||||
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use time::OffsetDateTime;
|
||||
use time::format_description::FormatItem;
|
||||
use time::macros::format_description;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tokio::sync::mpsc::{self};
|
||||
use tracing::info;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::models::ResponseItem;
|
||||
|
||||
/// Folder inside `~/.codex` that holds saved rollouts.
|
||||
const SESSIONS_SUBDIR: &str = "sessions";
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Default)]
|
||||
pub struct SessionMeta {
|
||||
pub id: Uuid,
|
||||
pub timestamp: String,
|
||||
pub instructions: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
pub struct SessionStateSnapshot {
|
||||
pub previous_response_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
pub struct SavedSession {
|
||||
pub session: SessionMeta,
|
||||
#[serde(default)]
|
||||
pub items: Vec<ResponseItem>,
|
||||
#[serde(default)]
|
||||
pub state: SessionStateSnapshot,
|
||||
pub session_id: Uuid,
|
||||
#[derive(Serialize)]
|
||||
struct SessionMeta {
|
||||
id: String,
|
||||
timestamp: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
instructions: Option<String>,
|
||||
}
|
||||
|
||||
/// Records all [`ResponseItem`]s for a session and flushes them to disk after
|
||||
@@ -55,13 +41,7 @@ pub struct SavedSession {
|
||||
/// ```
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct RolloutRecorder {
|
||||
tx: Sender<RolloutCmd>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum RolloutCmd {
|
||||
AddItems(Vec<ResponseItem>),
|
||||
UpdateState(SessionStateSnapshot),
|
||||
tx: Sender<String>,
|
||||
}
|
||||
|
||||
impl RolloutRecorder {
|
||||
@@ -79,6 +59,7 @@ impl RolloutRecorder {
|
||||
timestamp,
|
||||
} = create_log_file(config, uuid)?;
|
||||
|
||||
// Build the static session metadata JSON first.
|
||||
let timestamp_format: &[FormatItem] = format_description!(
|
||||
"[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
|
||||
);
|
||||
@@ -88,29 +69,46 @@ impl RolloutRecorder {
|
||||
|
||||
let meta = SessionMeta {
|
||||
timestamp,
|
||||
id: session_id,
|
||||
id: session_id.to_string(),
|
||||
instructions,
|
||||
};
|
||||
|
||||
// A reasonably-sized bounded channel. If the buffer fills up the send
|
||||
// future will yield, which is fine – we only need to ensure we do not
|
||||
// perform *blocking* I/O on the caller’s thread.
|
||||
let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
|
||||
let (tx, mut rx) = mpsc::channel::<String>(256);
|
||||
|
||||
// Spawn a Tokio task that owns the file handle and performs async
|
||||
// writes. Using `tokio::fs::File` keeps everything on the async I/O
|
||||
// driver instead of blocking the runtime.
|
||||
tokio::task::spawn(rollout_writer(
|
||||
tokio::fs::File::from_std(file),
|
||||
rx,
|
||||
Some(meta),
|
||||
));
|
||||
tokio::task::spawn(async move {
|
||||
let mut file = tokio::fs::File::from_std(file);
|
||||
|
||||
Ok(Self { tx })
|
||||
while let Some(line) = rx.recv().await {
|
||||
// Write line + newline, then flush to disk.
|
||||
if let Err(e) = file.write_all(line.as_bytes()).await {
|
||||
tracing::warn!("rollout writer: failed to write line: {e}");
|
||||
break;
|
||||
}
|
||||
if let Err(e) = file.write_all(b"\n").await {
|
||||
tracing::warn!("rollout writer: failed to write newline: {e}");
|
||||
break;
|
||||
}
|
||||
if let Err(e) = file.flush().await {
|
||||
tracing::warn!("rollout writer: failed to flush: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let recorder = Self { tx };
|
||||
// Ensure SessionMeta is the first item in the file.
|
||||
recorder.record_item(&meta).await?;
|
||||
Ok(recorder)
|
||||
}
|
||||
|
||||
/// Append `items` to the rollout file.
|
||||
pub(crate) async fn record_items(&self, items: &[ResponseItem]) -> std::io::Result<()> {
|
||||
let mut filtered = Vec::new();
|
||||
for item in items {
|
||||
match item {
|
||||
// Note that function calls may look a bit strange if they are
|
||||
@@ -119,86 +117,27 @@ impl RolloutRecorder {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. } => filtered.push(item.clone()),
|
||||
| ResponseItem::FunctionCallOutput { .. } => {}
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {
|
||||
// These should never be serialized.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
self.record_item(item).await?;
|
||||
}
|
||||
if filtered.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
self.tx
|
||||
.send(RolloutCmd::AddItems(filtered))
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout items: {e}")))
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn record_state(&self, state: SessionStateSnapshot) -> std::io::Result<()> {
|
||||
async fn record_item(&self, item: &impl Serialize) -> std::io::Result<()> {
|
||||
// Serialize the item to JSON first so that the writer thread only has
|
||||
// to perform the actual write.
|
||||
let json = serde_json::to_string(item)
|
||||
.map_err(|e| IoError::other(format!("failed to serialize response items: {e}")))?;
|
||||
|
||||
self.tx
|
||||
.send(RolloutCmd::UpdateState(state))
|
||||
.send(json)
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout state: {e}")))
|
||||
}
|
||||
|
||||
pub async fn resume(path: &Path) -> std::io::Result<(Self, SavedSession)> {
|
||||
info!("Resuming rollout from {path:?}");
|
||||
let text = tokio::fs::read_to_string(path).await?;
|
||||
let mut lines = text.lines();
|
||||
let meta_line = lines
|
||||
.next()
|
||||
.ok_or_else(|| IoError::other("empty session file"))?;
|
||||
let session: SessionMeta = serde_json::from_str(meta_line)
|
||||
.map_err(|e| IoError::other(format!("failed to parse session meta: {e}")))?;
|
||||
let mut items = Vec::new();
|
||||
let mut state = SessionStateSnapshot::default();
|
||||
|
||||
for line in lines {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let v: Value = match serde_json::from_str(line) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if v.get("record_type")
|
||||
.and_then(|rt| rt.as_str())
|
||||
.map(|s| s == "state")
|
||||
.unwrap_or(false)
|
||||
{
|
||||
if let Ok(s) = serde_json::from_value::<SessionStateSnapshot>(v.clone()) {
|
||||
state = s
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if let Ok(item) = serde_json::from_value::<ResponseItem>(v.clone()) {
|
||||
match item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. } => items.push(item),
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let saved = SavedSession {
|
||||
session: session.clone(),
|
||||
items: items.clone(),
|
||||
state: state.clone(),
|
||||
session_id: session.id,
|
||||
};
|
||||
|
||||
let file = std::fs::OpenOptions::new()
|
||||
.append(true)
|
||||
.read(true)
|
||||
.open(path)?;
|
||||
|
||||
let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
|
||||
tokio::task::spawn(rollout_writer(tokio::fs::File::from_std(file), rx, None));
|
||||
info!("Resumed rollout successfully from {path:?}");
|
||||
Ok((Self { tx }, saved))
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout item: {e}")))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -214,16 +153,14 @@ struct LogFileInfo {
|
||||
}
|
||||
|
||||
fn create_log_file(config: &Config, session_id: Uuid) -> std::io::Result<LogFileInfo> {
|
||||
// Resolve ~/.codex/sessions/YYYY/MM/DD and create it if missing.
|
||||
let timestamp = OffsetDateTime::now_local()
|
||||
.map_err(|e| IoError::other(format!("failed to get local time: {e}")))?;
|
||||
// Resolve ~/.codex/sessions and create it if missing.
|
||||
let mut dir = config.codex_home.clone();
|
||||
dir.push(SESSIONS_SUBDIR);
|
||||
dir.push(timestamp.year().to_string());
|
||||
dir.push(format!("{:02}", u8::from(timestamp.month())));
|
||||
dir.push(format!("{:02}", timestamp.day()));
|
||||
fs::create_dir_all(&dir)?;
|
||||
|
||||
let timestamp = OffsetDateTime::now_local()
|
||||
.map_err(|e| IoError::other(format!("failed to get local time: {e}")))?;
|
||||
|
||||
// Custom format for YYYY-MM-DDThh-mm-ss. Use `-` instead of `:` for
|
||||
// compatibility with filesystems that do not allow colons in filenames.
|
||||
let format: &[FormatItem] =
|
||||
@@ -246,54 +183,3 @@ fn create_log_file(config: &Config, session_id: Uuid) -> std::io::Result<LogFile
|
||||
timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
async fn rollout_writer(
|
||||
mut file: tokio::fs::File,
|
||||
mut rx: mpsc::Receiver<RolloutCmd>,
|
||||
meta: Option<SessionMeta>,
|
||||
) {
|
||||
if let Some(meta) = meta {
|
||||
if let Ok(json) = serde_json::to_string(&meta) {
|
||||
let _ = file.write_all(json.as_bytes()).await;
|
||||
let _ = file.write_all(b"\n").await;
|
||||
let _ = file.flush().await;
|
||||
}
|
||||
}
|
||||
while let Some(cmd) = rx.recv().await {
|
||||
match cmd {
|
||||
RolloutCmd::AddItems(items) => {
|
||||
for item in items {
|
||||
match item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. } => {
|
||||
if let Ok(json) = serde_json::to_string(&item) {
|
||||
let _ = file.write_all(json.as_bytes()).await;
|
||||
let _ = file.write_all(b"\n").await;
|
||||
}
|
||||
}
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
|
||||
}
|
||||
}
|
||||
let _ = file.flush().await;
|
||||
}
|
||||
RolloutCmd::UpdateState(state) => {
|
||||
#[derive(Serialize)]
|
||||
struct StateLine<'a> {
|
||||
record_type: &'static str,
|
||||
#[serde(flatten)]
|
||||
state: &'a SessionStateSnapshot,
|
||||
}
|
||||
if let Ok(json) = serde_json::to_string(&StateLine {
|
||||
record_type: "state",
|
||||
state: &state,
|
||||
}) {
|
||||
let _ = file.write_all(json.as_bytes()).await;
|
||||
let _ = file.write_all(b"\n").await;
|
||||
let _ = file.flush().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,11 +2,7 @@
|
||||
|
||||
use assert_cmd::Command as AssertCommand;
|
||||
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
use tempfile::TempDir;
|
||||
use uuid::Uuid;
|
||||
use walkdir::WalkDir;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
@@ -62,6 +58,8 @@ async fn chat_mode_stream_cli() {
|
||||
.arg(&provider_override)
|
||||
.arg("-c")
|
||||
.arg("model_provider=\"mock\"")
|
||||
.arg("-c")
|
||||
.arg("streaming=false")
|
||||
.arg("-C")
|
||||
.arg(env!("CARGO_MANIFEST_DIR"))
|
||||
.arg("hello?");
|
||||
@@ -108,6 +106,8 @@ async fn responses_api_stream_cli() {
|
||||
.arg("--")
|
||||
.arg("exec")
|
||||
.arg("--skip-git-repo-check")
|
||||
.arg("-c")
|
||||
.arg("streaming=false")
|
||||
.arg("-C")
|
||||
.arg(env!("CARGO_MANIFEST_DIR"))
|
||||
.arg("hello?");
|
||||
@@ -122,10 +122,9 @@ async fn responses_api_stream_cli() {
|
||||
assert!(stdout.contains("fixture hello"));
|
||||
}
|
||||
|
||||
/// End-to-end: create a session (writes rollout), verify the file, then resume and confirm append.
|
||||
/// Tests chat completions with streaming enabled (streaming=true) through the CLI using a mock server.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn integration_creates_and_checks_session_file() {
|
||||
// Honor sandbox network restrictions for CI parity with the other tests.
|
||||
async fn chat_mode_streaming_enabled_cli() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
@@ -133,19 +132,30 @@ async fn integration_creates_and_checks_session_file() {
|
||||
return;
|
||||
}
|
||||
|
||||
// 1. Temp home so we read/write isolated session files.
|
||||
let server = MockServer::start().await;
|
||||
// Simulate streaming deltas: 'h' and 'i' as separate chunks
|
||||
let sse = concat!(
|
||||
"data: {\"choices\":[{\"delta\":{\"content\":\"h\"}}]}\n\n",
|
||||
"data: {\"choices\":[{\"delta\":{\"content\":\"i\"}}]}\n\n",
|
||||
"data: {\"choices\":[{\"delta\":{}}]}\n\n",
|
||||
"data: [DONE]\n\n"
|
||||
);
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/chat/completions"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse, "text/event-stream"),
|
||||
)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let home = TempDir::new().unwrap();
|
||||
|
||||
// 2. Unique marker we'll look for in the session log.
|
||||
let marker = format!("integration-test-{}", Uuid::new_v4());
|
||||
let prompt = format!("echo {marker}");
|
||||
|
||||
// 3. Use the same offline SSE fixture as responses_api_stream_cli so the test is hermetic.
|
||||
let fixture =
|
||||
std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/cli_responses_fixture.sse");
|
||||
|
||||
// 4. Run the codex CLI through cargo (ensures the right bin is built) and invoke `exec`,
|
||||
// which is what records a session.
|
||||
let provider_override = format!(
|
||||
"model_providers.mock={{ name = \"mock\", base_url = \"{}/v1\", env_key = \"PATH\", wire_api = \"chat\" }}",
|
||||
server.uri()
|
||||
);
|
||||
let mut cmd = AssertCommand::new("cargo");
|
||||
cmd.arg("run")
|
||||
.arg("-p")
|
||||
@@ -154,166 +164,95 @@ async fn integration_creates_and_checks_session_file() {
|
||||
.arg("--")
|
||||
.arg("exec")
|
||||
.arg("--skip-git-repo-check")
|
||||
.arg("-c")
|
||||
.arg(&provider_override)
|
||||
.arg("-c")
|
||||
.arg("model_provider=\"mock\"")
|
||||
.arg("-c")
|
||||
.arg("streaming=true")
|
||||
.arg("-C")
|
||||
.arg(env!("CARGO_MANIFEST_DIR"))
|
||||
.arg(&prompt);
|
||||
.arg("hello?");
|
||||
cmd.env("CODEX_HOME", home.path())
|
||||
.env("OPENAI_API_KEY", "dummy")
|
||||
.env("CODEX_RS_SSE_FIXTURE", &fixture)
|
||||
// Required for CLI arg parsing even though fixture short-circuits network usage.
|
||||
.env("OPENAI_BASE_URL", "http://unused.local");
|
||||
.env("OPENAI_BASE_URL", format!("{}/v1", server.uri()));
|
||||
|
||||
let output = cmd.output().unwrap();
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"codex-cli exec failed: {}",
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
|
||||
// Wait for sessions dir to appear.
|
||||
let sessions_dir = home.path().join("sessions");
|
||||
let dir_deadline = Instant::now() + Duration::from_secs(5);
|
||||
while !sessions_dir.exists() && Instant::now() < dir_deadline {
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
}
|
||||
assert!(sessions_dir.exists(), "sessions directory never appeared");
|
||||
|
||||
// Find the session file that contains `marker`.
|
||||
let deadline = Instant::now() + Duration::from_secs(10);
|
||||
let mut matching_path: Option<std::path::PathBuf> = None;
|
||||
while Instant::now() < deadline && matching_path.is_none() {
|
||||
for entry in WalkDir::new(&sessions_dir) {
|
||||
let entry = match entry {
|
||||
Ok(e) => e,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if !entry.file_type().is_file() {
|
||||
continue;
|
||||
}
|
||||
if !entry.file_name().to_string_lossy().ends_with(".jsonl") {
|
||||
continue;
|
||||
}
|
||||
let path = entry.path();
|
||||
let Ok(content) = std::fs::read_to_string(path) else {
|
||||
continue;
|
||||
};
|
||||
let mut lines = content.lines();
|
||||
if lines.next().is_none() {
|
||||
continue;
|
||||
}
|
||||
for line in lines {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let item: serde_json::Value = match serde_json::from_str(line) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("message") {
|
||||
if let Some(c) = item.get("content") {
|
||||
if c.to_string().contains(&marker) {
|
||||
matching_path = Some(path.to_path_buf());
|
||||
break;
|
||||
}
|
||||
}
|
||||
assert!(output.status.success());
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
// Assert that 'h' and 'i' are output as two separate chunks from stdout, not as a single chunk
|
||||
// We split the output on 'h' and 'i' and check their order and separation
|
||||
let mut chunks = Vec::new();
|
||||
let mut last = 0;
|
||||
for (idx, c) in stdout.char_indices() {
|
||||
if c == 'h' || c == 'i' {
|
||||
if last != idx {
|
||||
let chunk = &stdout[last..idx];
|
||||
if !chunk.trim().is_empty() {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
}
|
||||
}
|
||||
if matching_path.is_none() {
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
chunks.push(&stdout[idx..idx + c.len_utf8()]);
|
||||
last = idx + c.len_utf8();
|
||||
}
|
||||
}
|
||||
|
||||
let path = match matching_path {
|
||||
Some(p) => p,
|
||||
None => panic!("No session file containing the marker was found"),
|
||||
};
|
||||
|
||||
// Basic sanity checks on location and metadata.
|
||||
let rel = match path.strip_prefix(&sessions_dir) {
|
||||
Ok(r) => r,
|
||||
Err(_) => panic!("session file should live under sessions/"),
|
||||
};
|
||||
let comps: Vec<String> = rel
|
||||
.components()
|
||||
.map(|c| c.as_os_str().to_string_lossy().into_owned())
|
||||
if last < stdout.len() {
|
||||
let chunk = &stdout[last..];
|
||||
if !chunk.trim().is_empty() {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
}
|
||||
// Only keep the 'h' and 'i' chunks
|
||||
let delta_chunks: Vec<&str> = chunks
|
||||
.iter()
|
||||
.cloned()
|
||||
.filter(|s| *s == "h" || *s == "i")
|
||||
.collect();
|
||||
assert_eq!(
|
||||
comps.len(),
|
||||
4,
|
||||
"Expected sessions/YYYY/MM/DD/<file>, got {rel:?}"
|
||||
);
|
||||
let year = &comps[0];
|
||||
let month = &comps[1];
|
||||
let day = &comps[2];
|
||||
assert!(
|
||||
year.len() == 4 && year.chars().all(|c| c.is_ascii_digit()),
|
||||
"Year dir not 4-digit numeric: {year}"
|
||||
);
|
||||
assert!(
|
||||
month.len() == 2 && month.chars().all(|c| c.is_ascii_digit()),
|
||||
"Month dir not zero-padded 2-digit numeric: {month}"
|
||||
);
|
||||
assert!(
|
||||
day.len() == 2 && day.chars().all(|c| c.is_ascii_digit()),
|
||||
"Day dir not zero-padded 2-digit numeric: {day}"
|
||||
);
|
||||
if let Ok(m) = month.parse::<u8>() {
|
||||
assert!((1..=12).contains(&m), "Month out of range: {m}");
|
||||
}
|
||||
if let Ok(d) = day.parse::<u8>() {
|
||||
assert!((1..=31).contains(&d), "Day out of range: {d}");
|
||||
}
|
||||
|
||||
let content =
|
||||
std::fs::read_to_string(&path).unwrap_or_else(|_| panic!("Failed to read session file"));
|
||||
let mut lines = content.lines();
|
||||
let meta_line = lines
|
||||
.next()
|
||||
.ok_or("missing session meta line")
|
||||
.unwrap_or_else(|_| panic!("missing session meta line"));
|
||||
let meta: serde_json::Value = serde_json::from_str(meta_line)
|
||||
.unwrap_or_else(|_| panic!("Failed to parse session meta line as JSON"));
|
||||
assert!(meta.get("id").is_some(), "SessionMeta missing id");
|
||||
assert!(
|
||||
meta.get("timestamp").is_some(),
|
||||
"SessionMeta missing timestamp"
|
||||
delta_chunks,
|
||||
vec!["h", "i"],
|
||||
"Expected two separate delta chunks 'h' and 'i' from stdout"
|
||||
);
|
||||
|
||||
let mut found_message = false;
|
||||
for line in lines {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let Ok(item) = serde_json::from_str::<serde_json::Value>(line) else {
|
||||
continue;
|
||||
};
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("message") {
|
||||
if let Some(c) = item.get("content") {
|
||||
if c.to_string().contains(&marker) {
|
||||
found_message = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
assert!(
|
||||
found_message,
|
||||
"No message found in session file containing the marker"
|
||||
);
|
||||
server.verify().await;
|
||||
}
|
||||
|
||||
// Second run: resume and append.
|
||||
let orig_len = content.lines().count();
|
||||
let marker2 = format!("integration-resume-{}", Uuid::new_v4());
|
||||
let prompt2 = format!("echo {marker2}");
|
||||
// Cross‑platform safe resume override. On Windows, backslashes in a TOML string must be escaped
|
||||
// or the parse will fail and the raw literal (including quotes) may be preserved all the way down
|
||||
// to Config, which in turn breaks resume because the path is invalid. Normalize to forward slashes
|
||||
// to sidestep the issue.
|
||||
let resume_path_str = path.to_string_lossy().replace('\\', "/");
|
||||
let resume_override = format!("experimental_resume=\"{resume_path_str}\"");
|
||||
let mut cmd2 = AssertCommand::new("cargo");
|
||||
cmd2.arg("run")
|
||||
/// Tests responses API with streaming enabled (streaming=true) through the CLI using a local SSE fixture file.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_api_streaming_enabled_cli() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Create a fixture with two deltas: 'fixture ' and 'hello'
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
let fixture_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
|
||||
.join("tests/cli_responses_fixture_streaming.sse");
|
||||
let mut fixture_file = fs::File::create(&fixture_path).unwrap();
|
||||
writeln!(fixture_file, "event: response.created").unwrap();
|
||||
writeln!(
|
||||
fixture_file,
|
||||
"data: {{\"type\":\"response.created\",\"response\":{{\"id\":\"resp1\"}}}}\n"
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(fixture_file, "event: response.output_text.delta").unwrap();
|
||||
writeln!(fixture_file, "data: {{\"type\":\"response.output_text.delta\",\"delta\":\"fixture \",\"item_id\":\"msg1\"}}\n").unwrap();
|
||||
writeln!(fixture_file, "event: response.output_text.delta").unwrap();
|
||||
writeln!(fixture_file, "data: {{\"type\":\"response.output_text.delta\",\"delta\":\"hello\",\"item_id\":\"msg1\"}}\n").unwrap();
|
||||
writeln!(fixture_file, "event: response.output_text.done").unwrap();
|
||||
writeln!(fixture_file, "data: {{\"type\":\"response.output_text.done\",\"text\":\"fixture hello\",\"item_id\":\"msg1\"}}\n").unwrap();
|
||||
writeln!(fixture_file, "event: response.output_item.done").unwrap();
|
||||
writeln!(fixture_file, "data: {{\"type\":\"response.output_item.done\",\"item\":{{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{{\"type\":\"output_text\",\"text\":\"fixture hello\"}}]}}}}\n").unwrap();
|
||||
writeln!(fixture_file, "event: response.completed").unwrap();
|
||||
writeln!(fixture_file, "data: {{\"type\":\"response.completed\",\"response\":{{\"id\":\"resp1\",\"output\":[]}}}}\n").unwrap();
|
||||
|
||||
let home = TempDir::new().unwrap();
|
||||
let mut cmd = AssertCommand::new("cargo");
|
||||
cmd.arg("run")
|
||||
.arg("-p")
|
||||
.arg("codex-cli")
|
||||
.arg("--quiet")
|
||||
@@ -321,41 +260,49 @@ async fn integration_creates_and_checks_session_file() {
|
||||
.arg("exec")
|
||||
.arg("--skip-git-repo-check")
|
||||
.arg("-c")
|
||||
.arg(&resume_override)
|
||||
.arg("streaming=true")
|
||||
.arg("-C")
|
||||
.arg(env!("CARGO_MANIFEST_DIR"))
|
||||
.arg(&prompt2);
|
||||
cmd2.env("CODEX_HOME", home.path())
|
||||
.arg("hello?");
|
||||
cmd.env("CODEX_HOME", home.path())
|
||||
.env("OPENAI_API_KEY", "dummy")
|
||||
.env("CODEX_RS_SSE_FIXTURE", &fixture)
|
||||
.env("CODEX_RS_SSE_FIXTURE", &fixture_path)
|
||||
.env("OPENAI_BASE_URL", "http://unused.local");
|
||||
let output2 = cmd2.output().unwrap();
|
||||
assert!(output2.status.success(), "resume codex-cli run failed");
|
||||
|
||||
// The rollout writer runs on a background async task; give it a moment to flush.
|
||||
let mut new_len = orig_len;
|
||||
let deadline = Instant::now() + Duration::from_secs(5);
|
||||
let mut content2 = String::new();
|
||||
while Instant::now() < deadline {
|
||||
if let Ok(c) = std::fs::read_to_string(&path) {
|
||||
let count = c.lines().count();
|
||||
if count > orig_len {
|
||||
content2 = c;
|
||||
new_len = count;
|
||||
break;
|
||||
let output = cmd.output().unwrap();
|
||||
assert!(output.status.success());
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
// Assert that 'fixture ' and 'hello' are output as two separate chunks from stdout, not as a single chunk
|
||||
// We split the output on the known delta substrings and check their order and separation
|
||||
let mut chunks = Vec::new();
|
||||
let mut last = 0;
|
||||
for pat in ["fixture ", "hello"] {
|
||||
if let Some(idx) = stdout[last..].find(pat) {
|
||||
if last != last + idx {
|
||||
let chunk = &stdout[last..last + idx];
|
||||
if !chunk.trim().is_empty() {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
}
|
||||
chunks.push(&stdout[last + idx..last + idx + pat.len()]);
|
||||
last = last + idx + pat.len();
|
||||
}
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
}
|
||||
if content2.is_empty() {
|
||||
// last attempt
|
||||
content2 = std::fs::read_to_string(&path).unwrap();
|
||||
new_len = content2.lines().count();
|
||||
if last < stdout.len() {
|
||||
let chunk = &stdout[last..];
|
||||
if !chunk.trim().is_empty() {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
}
|
||||
assert!(new_len > orig_len, "rollout file did not grow after resume");
|
||||
assert!(content2.contains(&marker), "rollout lost original marker");
|
||||
assert!(
|
||||
content2.contains(&marker2),
|
||||
"rollout missing resumed marker"
|
||||
// Only keep the delta chunks
|
||||
let delta_chunks: Vec<&str> = chunks
|
||||
.iter()
|
||||
.cloned()
|
||||
.filter(|s| *s == "fixture " || *s == "hello")
|
||||
.collect();
|
||||
assert_eq!(
|
||||
delta_chunks,
|
||||
vec!["fixture ", "hello"],
|
||||
"Expected two separate delta chunks 'fixture ' and 'hello' from stdout"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
mod test_support;
|
||||
use tempfile::TempDir;
|
||||
use test_support::load_default_config_for_test;
|
||||
use test_support::load_sse_fixture_with_id;
|
||||
use tokio::time::timeout;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
/// Build minimal SSE stream with completed marker using the JSON fixture.
|
||||
fn sse_completed(id: &str) -> String {
|
||||
load_sse_fixture_with_id("tests/fixtures/completed_template.json", id)
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn includes_session_id_and_model_headers_in_request() {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
|
||||
// First request – must NOT include `previous_response_id`.
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
name: "openai".into(),
|
||||
base_url: format!("{}/v1", server.uri()),
|
||||
// Environment variable that should exist in the test environment.
|
||||
// ModelClient will return an error if the environment variable for the
|
||||
// provider is not set.
|
||||
env_key: Some("PATH".into()),
|
||||
env_key_instructions: None,
|
||||
wire_api: codex_core::WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: Some(
|
||||
[("originator".to_string(), "codex_cli_rs".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: None,
|
||||
};
|
||||
|
||||
// Init session
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider = model_provider;
|
||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
let (codex, _init_id) = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut current_session_id = None;
|
||||
// Wait for TaskComplete
|
||||
loop {
|
||||
let ev = timeout(Duration::from_secs(1), codex.next_event())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
if let EventMsg::SessionConfigured(SessionConfiguredEvent { session_id, .. }) = ev.msg {
|
||||
current_session_id = Some(session_id.to_string());
|
||||
}
|
||||
if matches!(ev.msg, EventMsg::TaskComplete(_)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// get request from the server
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_body = request.headers.get("session_id").unwrap();
|
||||
let originator = request.headers.get("originator").unwrap();
|
||||
|
||||
assert!(current_session_id.is_some());
|
||||
assert_eq!(request_body.to_str().unwrap(), ¤t_session_id.unwrap());
|
||||
assert_eq!(originator.to_str().unwrap(), "codex_cli_rs");
|
||||
}
|
||||
@@ -45,10 +45,22 @@ async fn spawn_codex() -> Result<Codex, CodexErr> {
|
||||
"OPENAI_API_KEY must be set for live tests"
|
||||
);
|
||||
|
||||
// Environment tweaks to keep the tests snappy and inexpensive while still
|
||||
// exercising retry/robustness logic.
|
||||
//
|
||||
// NOTE: Starting with the 2024 edition `std::env::set_var` is `unsafe`
|
||||
// because changing the process environment races with any other threads
|
||||
// that might be performing environment look-ups at the same time.
|
||||
// Restrict the unsafety to this tiny block that happens at the very
|
||||
// beginning of the test, before we spawn any background tasks that could
|
||||
// observe the environment.
|
||||
unsafe {
|
||||
std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "2");
|
||||
std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "2");
|
||||
}
|
||||
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider.request_max_retries = Some(2);
|
||||
config.model_provider.stream_max_retries = Some(2);
|
||||
let config = load_default_config_for_test(&codex_home);
|
||||
let (agent, _init_id) = Codex::spawn(config, std::sync::Arc::new(Notify::new())).await?;
|
||||
|
||||
Ok(agent)
|
||||
@@ -67,7 +79,7 @@ async fn live_streaming_and_prev_id_reset() {
|
||||
|
||||
let codex = spawn_codex().await.unwrap();
|
||||
|
||||
// ---------- Task 1 ----------
|
||||
// ---------- Task 1 ----------
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
@@ -101,7 +113,7 @@ async fn live_streaming_and_prev_id_reset() {
|
||||
"Agent did not stream any AgentMessage before TaskComplete"
|
||||
);
|
||||
|
||||
// ---------- Task 2 (same session) ----------
|
||||
// ---------- Task 2 (same session) ----------
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
|
||||
@@ -88,8 +88,13 @@ async fn keeps_previous_response_id_between_tasks() {
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
// Configure retry behavior explicitly to avoid mutating process-wide
|
||||
// environment variables.
|
||||
// Environment
|
||||
// Update environment – `set_var` is `unsafe` starting with the 2024
|
||||
// edition so we group the calls into a single `unsafe { … }` block.
|
||||
unsafe {
|
||||
std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
|
||||
std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "0");
|
||||
}
|
||||
let model_provider = ModelProviderInfo {
|
||||
name: "openai".into(),
|
||||
base_url: format!("{}/v1", server.uri()),
|
||||
@@ -102,10 +107,6 @@ async fn keeps_previous_response_id_between_tasks() {
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
// disable retries so we don't get duplicate calls in this test
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: None,
|
||||
};
|
||||
|
||||
// Init session
|
||||
|
||||
@@ -70,8 +70,19 @@ async fn retries_on_early_close() {
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
// Configure retry behavior explicitly to avoid mutating process-wide
|
||||
// environment variables.
|
||||
// Environment
|
||||
//
|
||||
// As of Rust 2024 `std::env::set_var` has been made `unsafe` because
|
||||
// mutating the process environment is inherently racy when other threads
|
||||
// are running. We therefore have to wrap every call in an explicit
|
||||
// `unsafe` block. These are limited to the test-setup section so the
|
||||
// scope is very small and clearly delineated.
|
||||
|
||||
unsafe {
|
||||
std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
|
||||
std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "1");
|
||||
std::env::set_var("OPENAI_STREAM_IDLE_TIMEOUT_MS", "2000");
|
||||
}
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
name: "openai".into(),
|
||||
@@ -85,10 +96,6 @@ async fn retries_on_early_close() {
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
// exercise retry path: first attempt yields incomplete stream, so allow 1 retry
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(1),
|
||||
stream_idle_timeout_ms: Some(2000),
|
||||
};
|
||||
|
||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
|
||||
@@ -51,10 +51,6 @@ pub struct Cli {
|
||||
#[arg(long = "color", value_enum, default_value_t = Color::Auto)]
|
||||
pub color: Color,
|
||||
|
||||
/// Print events to stdout as JSONL.
|
||||
#[arg(long = "json", default_value_t = false)]
|
||||
pub json: bool,
|
||||
|
||||
/// Specifies file where the last message from the agent should be written.
|
||||
#[arg(long = "output-last-message")]
|
||||
pub last_message_file: Option<PathBuf>,
|
||||
|
||||
@@ -1,37 +1,564 @@
|
||||
use codex_common::elapsed::format_elapsed;
|
||||
use codex_common::summarize_sandbox_policy;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::model_supports_reasoning_summaries;
|
||||
use codex_core::protocol::AgentMessageEvent;
|
||||
use codex_core::protocol::BackgroundEventEvent;
|
||||
use codex_core::protocol::ErrorEvent;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::ExecCommandBeginEvent;
|
||||
use codex_core::protocol::ExecCommandEndEvent;
|
||||
use codex_core::protocol::FileChange;
|
||||
use codex_core::protocol::McpToolCallBeginEvent;
|
||||
use codex_core::protocol::McpToolCallEndEvent;
|
||||
use codex_core::protocol::PatchApplyBeginEvent;
|
||||
use codex_core::protocol::PatchApplyEndEvent;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
use codex_core::protocol::TokenUsage;
|
||||
use owo_colors::OwoColorize;
|
||||
use owo_colors::Style;
|
||||
use shlex::try_join;
|
||||
use std::collections::HashMap;
|
||||
use std::io::Write;
|
||||
use std::io::{self};
|
||||
use std::time::Instant;
|
||||
|
||||
pub(crate) trait EventProcessor {
|
||||
/// Print summary of effective configuration and user prompt.
|
||||
fn print_config_summary(&mut self, config: &Config, prompt: &str);
|
||||
/// This should be configurable. When used in CI, users may not want to impose
|
||||
/// a limit so they can see the full transcript.
|
||||
const MAX_OUTPUT_LINES_FOR_EXEC_TOOL_CALL: usize = 20;
|
||||
|
||||
/// Handle a single event emitted by the agent.
|
||||
fn process_event(&mut self, event: Event);
|
||||
pub(crate) struct EventProcessor {
|
||||
call_id_to_command: HashMap<String, ExecCommandBegin>,
|
||||
call_id_to_patch: HashMap<String, PatchApplyBegin>,
|
||||
|
||||
/// Tracks in-flight MCP tool calls so we can calculate duration and print
|
||||
/// a concise summary when the corresponding `McpToolCallEnd` event is
|
||||
/// received.
|
||||
call_id_to_tool_call: HashMap<String, McpToolCallBegin>,
|
||||
|
||||
// To ensure that --color=never is respected, ANSI escapes _must_ be added
|
||||
// using .style() with one of these fields. If you need a new style, add a
|
||||
// new field here.
|
||||
bold: Style,
|
||||
italic: Style,
|
||||
dimmed: Style,
|
||||
|
||||
magenta: Style,
|
||||
red: Style,
|
||||
green: Style,
|
||||
cyan: Style,
|
||||
|
||||
/// Whether to include `AgentReasoning` events in the output.
|
||||
show_agent_reasoning: bool,
|
||||
/// Whether to surface streaming deltas (true = print deltas + suppress final message).
|
||||
streaming_enabled: bool,
|
||||
/// Internal: have we already printed the `codex` header for the current streaming turn?
|
||||
printed_agent_header: bool,
|
||||
/// Internal: have we already printed the `thinking` header for current streaming turn?
|
||||
printed_reasoning_header: bool,
|
||||
}
|
||||
|
||||
pub(crate) fn create_config_summary_entries(config: &Config) -> Vec<(&'static str, String)> {
|
||||
let mut entries = vec![
|
||||
("workdir", config.cwd.display().to_string()),
|
||||
("model", config.model.clone()),
|
||||
("provider", config.model_provider_id.clone()),
|
||||
("approval", format!("{:?}", config.approval_policy)),
|
||||
("sandbox", summarize_sandbox_policy(&config.sandbox_policy)),
|
||||
];
|
||||
if config.model_provider.wire_api == WireApi::Responses
|
||||
&& model_supports_reasoning_summaries(config)
|
||||
{
|
||||
entries.push((
|
||||
"reasoning effort",
|
||||
config.model_reasoning_effort.to_string(),
|
||||
));
|
||||
entries.push((
|
||||
"reasoning summaries",
|
||||
config.model_reasoning_summary.to_string(),
|
||||
));
|
||||
impl EventProcessor {
|
||||
pub(crate) fn create_with_ansi(
|
||||
with_ansi: bool,
|
||||
show_agent_reasoning: bool,
|
||||
streaming_enabled: bool,
|
||||
) -> Self {
|
||||
let call_id_to_command = HashMap::new();
|
||||
let call_id_to_patch = HashMap::new();
|
||||
let call_id_to_tool_call = HashMap::new();
|
||||
|
||||
if with_ansi {
|
||||
Self {
|
||||
call_id_to_command,
|
||||
call_id_to_patch,
|
||||
bold: Style::new().bold(),
|
||||
italic: Style::new().italic(),
|
||||
dimmed: Style::new().dimmed(),
|
||||
magenta: Style::new().magenta(),
|
||||
red: Style::new().red(),
|
||||
green: Style::new().green(),
|
||||
cyan: Style::new().cyan(),
|
||||
call_id_to_tool_call,
|
||||
show_agent_reasoning,
|
||||
streaming_enabled,
|
||||
printed_agent_header: false,
|
||||
printed_reasoning_header: false,
|
||||
}
|
||||
} else {
|
||||
Self {
|
||||
call_id_to_command,
|
||||
call_id_to_patch,
|
||||
bold: Style::new(),
|
||||
italic: Style::new(),
|
||||
dimmed: Style::new(),
|
||||
magenta: Style::new(),
|
||||
red: Style::new(),
|
||||
green: Style::new(),
|
||||
cyan: Style::new(),
|
||||
call_id_to_tool_call,
|
||||
show_agent_reasoning,
|
||||
streaming_enabled,
|
||||
printed_agent_header: false,
|
||||
printed_reasoning_header: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ExecCommandBegin {
|
||||
command: Vec<String>,
|
||||
start_time: Instant,
|
||||
}
|
||||
|
||||
/// Metadata captured when an `McpToolCallBegin` event is received.
|
||||
struct McpToolCallBegin {
|
||||
/// Formatted invocation string, e.g. `server.tool({"city":"sf"})`.
|
||||
invocation: String,
|
||||
/// Timestamp when the call started so we can compute duration later.
|
||||
start_time: Instant,
|
||||
}
|
||||
|
||||
struct PatchApplyBegin {
|
||||
start_time: Instant,
|
||||
auto_approved: bool,
|
||||
}
|
||||
|
||||
// Timestamped println helper. The timestamp is styled with self.dimmed.
|
||||
#[macro_export]
|
||||
macro_rules! ts_println {
|
||||
($self:ident, $($arg:tt)*) => {{
|
||||
let now = chrono::Utc::now();
|
||||
let formatted = now.format("[%Y-%m-%dT%H:%M:%S]");
|
||||
print!("{} ", formatted.style($self.dimmed));
|
||||
println!($($arg)*);
|
||||
}};
|
||||
}
|
||||
|
||||
impl EventProcessor {
|
||||
/// Print a concise summary of the effective configuration that will be used
|
||||
/// for the session. This mirrors the information shown in the TUI welcome
|
||||
/// screen.
|
||||
pub(crate) fn print_config_summary(&mut self, config: &Config, prompt: &str) {
|
||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
ts_println!(
|
||||
self,
|
||||
"OpenAI Codex v{} (research preview)\n--------",
|
||||
VERSION
|
||||
);
|
||||
|
||||
let mut entries = vec![
|
||||
("workdir", config.cwd.display().to_string()),
|
||||
("model", config.model.clone()),
|
||||
("provider", config.model_provider_id.clone()),
|
||||
("approval", format!("{:?}", config.approval_policy)),
|
||||
("sandbox", summarize_sandbox_policy(&config.sandbox_policy)),
|
||||
];
|
||||
if config.model_provider.wire_api == WireApi::Responses
|
||||
&& model_supports_reasoning_summaries(config)
|
||||
{
|
||||
entries.push((
|
||||
"reasoning effort",
|
||||
config.model_reasoning_effort.to_string(),
|
||||
));
|
||||
entries.push((
|
||||
"reasoning summaries",
|
||||
config.model_reasoning_summary.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
for (key, value) in entries {
|
||||
println!("{} {}", format!("{key}:").style(self.bold), value);
|
||||
}
|
||||
|
||||
println!("--------");
|
||||
|
||||
// Echo the prompt that will be sent to the agent so it is visible in the
|
||||
// transcript/logs before any events come in. Note the prompt may have been
|
||||
// read from stdin, so it may not be visible in the terminal otherwise.
|
||||
ts_println!(
|
||||
self,
|
||||
"{}\n{}",
|
||||
"User instructions:".style(self.bold).style(self.cyan),
|
||||
prompt
|
||||
);
|
||||
}
|
||||
|
||||
entries
|
||||
pub(crate) fn process_event(&mut self, event: Event) {
|
||||
let Event { id: _, msg } = event;
|
||||
match msg {
|
||||
EventMsg::Error(ErrorEvent { message }) => {
|
||||
let prefix = "ERROR:".style(self.red);
|
||||
ts_println!(self, "{prefix} {message}");
|
||||
}
|
||||
EventMsg::BackgroundEvent(BackgroundEventEvent { message }) => {
|
||||
ts_println!(self, "{}", message.style(self.dimmed));
|
||||
}
|
||||
EventMsg::TaskStarted | EventMsg::TaskComplete(_) => {
|
||||
// Reset streaming headers at start/end boundaries.
|
||||
if matches!(msg, EventMsg::TaskStarted) {
|
||||
self.printed_agent_header = false;
|
||||
self.printed_reasoning_header = false;
|
||||
}
|
||||
// Ignore.
|
||||
}
|
||||
EventMsg::TokenCount(TokenUsage { total_tokens, .. }) => {
|
||||
ts_println!(self, "tokens used: {total_tokens}");
|
||||
}
|
||||
EventMsg::AgentMessage(AgentMessageEvent { message }) => {
|
||||
if self.streaming_enabled {
|
||||
// Suppress full message when streaming; final markdown not printed in CLI.
|
||||
// If no deltas were seen, fall back to printing now.
|
||||
if !self.printed_agent_header {
|
||||
ts_println!(
|
||||
self,
|
||||
"{}\n{message}",
|
||||
"codex".style(self.bold).style(self.magenta)
|
||||
);
|
||||
}
|
||||
} else {
|
||||
ts_println!(
|
||||
self,
|
||||
"{}\n{message}",
|
||||
"codex".style(self.bold).style(self.magenta)
|
||||
);
|
||||
}
|
||||
}
|
||||
EventMsg::AgentMessageDelta(AgentMessageEvent { message }) => {
|
||||
if !self.streaming_enabled {
|
||||
// streaming disabled, ignore
|
||||
} else {
|
||||
if !self.printed_agent_header {
|
||||
ts_println!(self, "{}", "codex".style(self.bold).style(self.magenta));
|
||||
self.printed_agent_header = true;
|
||||
}
|
||||
print!("{message}");
|
||||
let _ = io::stdout().flush();
|
||||
}
|
||||
}
|
||||
EventMsg::ExecCommandBegin(ExecCommandBeginEvent {
|
||||
call_id,
|
||||
command,
|
||||
cwd,
|
||||
}) => {
|
||||
self.call_id_to_command.insert(
|
||||
call_id.clone(),
|
||||
ExecCommandBegin {
|
||||
command: command.clone(),
|
||||
start_time: Instant::now(),
|
||||
},
|
||||
);
|
||||
ts_println!(
|
||||
self,
|
||||
"{} {} in {}",
|
||||
"exec".style(self.magenta),
|
||||
escape_command(&command).style(self.bold),
|
||||
cwd.to_string_lossy(),
|
||||
);
|
||||
}
|
||||
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
|
||||
call_id,
|
||||
stdout,
|
||||
stderr,
|
||||
exit_code,
|
||||
}) => {
|
||||
let exec_command = self.call_id_to_command.remove(&call_id);
|
||||
let (duration, call) = if let Some(ExecCommandBegin {
|
||||
command,
|
||||
start_time,
|
||||
}) = exec_command
|
||||
{
|
||||
(
|
||||
format!(" in {}", format_elapsed(start_time)),
|
||||
format!("{}", escape_command(&command).style(self.bold)),
|
||||
)
|
||||
} else {
|
||||
("".to_string(), format!("exec('{call_id}')"))
|
||||
};
|
||||
|
||||
let output = if exit_code == 0 { stdout } else { stderr };
|
||||
let truncated_output = output
|
||||
.lines()
|
||||
.take(MAX_OUTPUT_LINES_FOR_EXEC_TOOL_CALL)
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
match exit_code {
|
||||
0 => {
|
||||
let title = format!("{call} succeeded{duration}:");
|
||||
ts_println!(self, "{}", title.style(self.green));
|
||||
}
|
||||
_ => {
|
||||
let title = format!("{call} exited {exit_code}{duration}:");
|
||||
ts_println!(self, "{}", title.style(self.red));
|
||||
}
|
||||
}
|
||||
println!("{}", truncated_output.style(self.dimmed));
|
||||
}
|
||||
EventMsg::McpToolCallBegin(McpToolCallBeginEvent {
|
||||
call_id,
|
||||
server,
|
||||
tool,
|
||||
arguments,
|
||||
}) => {
|
||||
// Build fully-qualified tool name: server.tool
|
||||
let fq_tool_name = format!("{server}.{tool}");
|
||||
|
||||
// Format arguments as compact JSON so they fit on one line.
|
||||
let args_str = arguments
|
||||
.as_ref()
|
||||
.map(|v: &serde_json::Value| {
|
||||
serde_json::to_string(v).unwrap_or_else(|_| v.to_string())
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let invocation = if args_str.is_empty() {
|
||||
format!("{fq_tool_name}()")
|
||||
} else {
|
||||
format!("{fq_tool_name}({args_str})")
|
||||
};
|
||||
|
||||
self.call_id_to_tool_call.insert(
|
||||
call_id.clone(),
|
||||
McpToolCallBegin {
|
||||
invocation: invocation.clone(),
|
||||
start_time: Instant::now(),
|
||||
},
|
||||
);
|
||||
|
||||
ts_println!(
|
||||
self,
|
||||
"{} {}",
|
||||
"tool".style(self.magenta),
|
||||
invocation.style(self.bold),
|
||||
);
|
||||
}
|
||||
EventMsg::McpToolCallEnd(tool_call_end_event) => {
|
||||
let is_success = tool_call_end_event.is_success();
|
||||
let McpToolCallEndEvent { call_id, result } = tool_call_end_event;
|
||||
// Retrieve start time and invocation for duration calculation and labeling.
|
||||
let info = self.call_id_to_tool_call.remove(&call_id);
|
||||
|
||||
let (duration, invocation) = if let Some(McpToolCallBegin {
|
||||
invocation,
|
||||
start_time,
|
||||
..
|
||||
}) = info
|
||||
{
|
||||
(format!(" in {}", format_elapsed(start_time)), invocation)
|
||||
} else {
|
||||
(String::new(), format!("tool('{call_id}')"))
|
||||
};
|
||||
|
||||
let status_str = if is_success { "success" } else { "failed" };
|
||||
let title_style = if is_success { self.green } else { self.red };
|
||||
let title = format!("{invocation} {status_str}{duration}:");
|
||||
|
||||
ts_println!(self, "{}", title.style(title_style));
|
||||
|
||||
if let Ok(res) = result {
|
||||
let val: serde_json::Value = res.into();
|
||||
let pretty =
|
||||
serde_json::to_string_pretty(&val).unwrap_or_else(|_| val.to_string());
|
||||
|
||||
for line in pretty.lines().take(MAX_OUTPUT_LINES_FOR_EXEC_TOOL_CALL) {
|
||||
println!("{}", line.style(self.dimmed));
|
||||
}
|
||||
}
|
||||
}
|
||||
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
||||
call_id,
|
||||
auto_approved,
|
||||
changes,
|
||||
}) => {
|
||||
// Store metadata so we can calculate duration later when we
|
||||
// receive the corresponding PatchApplyEnd event.
|
||||
self.call_id_to_patch.insert(
|
||||
call_id.clone(),
|
||||
PatchApplyBegin {
|
||||
start_time: Instant::now(),
|
||||
auto_approved,
|
||||
},
|
||||
);
|
||||
|
||||
ts_println!(
|
||||
self,
|
||||
"{} auto_approved={}:",
|
||||
"apply_patch".style(self.magenta),
|
||||
auto_approved,
|
||||
);
|
||||
|
||||
// Pretty-print the patch summary with colored diff markers so
|
||||
// it's easy to scan in the terminal output.
|
||||
for (path, change) in changes.iter() {
|
||||
match change {
|
||||
FileChange::Add { content } => {
|
||||
let header = format!(
|
||||
"{} {}",
|
||||
format_file_change(change),
|
||||
path.to_string_lossy()
|
||||
);
|
||||
println!("{}", header.style(self.magenta));
|
||||
for line in content.lines() {
|
||||
println!("{}", line.style(self.green));
|
||||
}
|
||||
}
|
||||
FileChange::Delete => {
|
||||
let header = format!(
|
||||
"{} {}",
|
||||
format_file_change(change),
|
||||
path.to_string_lossy()
|
||||
);
|
||||
println!("{}", header.style(self.magenta));
|
||||
}
|
||||
FileChange::Update {
|
||||
unified_diff,
|
||||
move_path,
|
||||
} => {
|
||||
let header = if let Some(dest) = move_path {
|
||||
format!(
|
||||
"{} {} -> {}",
|
||||
format_file_change(change),
|
||||
path.to_string_lossy(),
|
||||
dest.to_string_lossy()
|
||||
)
|
||||
} else {
|
||||
format!("{} {}", format_file_change(change), path.to_string_lossy())
|
||||
};
|
||||
println!("{}", header.style(self.magenta));
|
||||
|
||||
// Colorize diff lines. We keep file header lines
|
||||
// (--- / +++) without extra coloring so they are
|
||||
// still readable.
|
||||
for diff_line in unified_diff.lines() {
|
||||
if diff_line.starts_with('+') && !diff_line.starts_with("+++") {
|
||||
println!("{}", diff_line.style(self.green));
|
||||
} else if diff_line.starts_with('-')
|
||||
&& !diff_line.starts_with("---")
|
||||
{
|
||||
println!("{}", diff_line.style(self.red));
|
||||
} else {
|
||||
println!("{diff_line}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
EventMsg::PatchApplyEnd(PatchApplyEndEvent {
|
||||
call_id,
|
||||
stdout,
|
||||
stderr,
|
||||
success,
|
||||
}) => {
|
||||
let patch_begin = self.call_id_to_patch.remove(&call_id);
|
||||
|
||||
// Compute duration and summary label similar to exec commands.
|
||||
let (duration, label) = if let Some(PatchApplyBegin {
|
||||
start_time,
|
||||
auto_approved,
|
||||
}) = patch_begin
|
||||
{
|
||||
(
|
||||
format!(" in {}", format_elapsed(start_time)),
|
||||
format!("apply_patch(auto_approved={auto_approved})"),
|
||||
)
|
||||
} else {
|
||||
(String::new(), format!("apply_patch('{call_id}')"))
|
||||
};
|
||||
|
||||
let (exit_code, output, title_style) = if success {
|
||||
(0, stdout, self.green)
|
||||
} else {
|
||||
(1, stderr, self.red)
|
||||
};
|
||||
|
||||
let title = format!("{label} exited {exit_code}{duration}:");
|
||||
ts_println!(self, "{}", title.style(title_style));
|
||||
for line in output.lines() {
|
||||
println!("{}", line.style(self.dimmed));
|
||||
}
|
||||
}
|
||||
EventMsg::ExecApprovalRequest(_) => {
|
||||
// Should we exit?
|
||||
}
|
||||
EventMsg::ApplyPatchApprovalRequest(_) => {
|
||||
// Should we exit?
|
||||
}
|
||||
EventMsg::AgentReasoning(agent_reasoning_event) => {
|
||||
if self.show_agent_reasoning {
|
||||
if self.streaming_enabled {
|
||||
if !self.printed_reasoning_header {
|
||||
ts_println!(
|
||||
self,
|
||||
"{}\n{}",
|
||||
"thinking".style(self.italic).style(self.magenta),
|
||||
agent_reasoning_event.text
|
||||
);
|
||||
}
|
||||
} else {
|
||||
ts_println!(
|
||||
self,
|
||||
"{}\n{}",
|
||||
"thinking".style(self.italic).style(self.magenta),
|
||||
agent_reasoning_event.text
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
EventMsg::AgentReasoningDelta(agent_reasoning_event) => {
|
||||
if self.show_agent_reasoning && self.streaming_enabled {
|
||||
if !self.printed_reasoning_header {
|
||||
ts_println!(
|
||||
self,
|
||||
"{}",
|
||||
"thinking".style(self.italic).style(self.magenta)
|
||||
);
|
||||
self.printed_reasoning_header = true;
|
||||
}
|
||||
print!("{}", agent_reasoning_event.text);
|
||||
let _ = io::stdout().flush();
|
||||
}
|
||||
}
|
||||
EventMsg::SessionConfigured(session_configured_event) => {
|
||||
let SessionConfiguredEvent {
|
||||
session_id,
|
||||
model,
|
||||
history_log_id: _,
|
||||
history_entry_count: _,
|
||||
} = session_configured_event;
|
||||
|
||||
ts_println!(
|
||||
self,
|
||||
"{} {}",
|
||||
"codex session".style(self.magenta).style(self.bold),
|
||||
session_id.to_string().style(self.dimmed)
|
||||
);
|
||||
|
||||
ts_println!(self, "model: {}", model);
|
||||
println!();
|
||||
}
|
||||
EventMsg::GetHistoryEntryResponse(_) => {
|
||||
// Currently ignored in exec output.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn escape_command(command: &[String]) -> String {
|
||||
try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" "))
|
||||
}
|
||||
|
||||
fn format_file_change(change: &FileChange) -> &'static str {
|
||||
match change {
|
||||
FileChange::Add { .. } => "A",
|
||||
FileChange::Delete => "D",
|
||||
FileChange::Update {
|
||||
move_path: Some(_), ..
|
||||
} => "R",
|
||||
FileChange::Update {
|
||||
move_path: None, ..
|
||||
} => "M",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,520 +0,0 @@
|
||||
use codex_common::elapsed::format_elapsed;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::protocol::AgentMessageDeltaEvent;
|
||||
use codex_core::protocol::AgentMessageEvent;
|
||||
use codex_core::protocol::AgentReasoningDeltaEvent;
|
||||
use codex_core::protocol::BackgroundEventEvent;
|
||||
use codex_core::protocol::ErrorEvent;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::ExecCommandBeginEvent;
|
||||
use codex_core::protocol::ExecCommandEndEvent;
|
||||
use codex_core::protocol::FileChange;
|
||||
use codex_core::protocol::McpToolCallBeginEvent;
|
||||
use codex_core::protocol::McpToolCallEndEvent;
|
||||
use codex_core::protocol::PatchApplyBeginEvent;
|
||||
use codex_core::protocol::PatchApplyEndEvent;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
use codex_core::protocol::TokenUsage;
|
||||
use owo_colors::OwoColorize;
|
||||
use owo_colors::Style;
|
||||
use shlex::try_join;
|
||||
use std::collections::HashMap;
|
||||
use std::io::Write;
|
||||
use std::time::Instant;
|
||||
|
||||
use crate::event_processor::EventProcessor;
|
||||
use crate::event_processor::create_config_summary_entries;
|
||||
|
||||
/// This should be configurable. When used in CI, users may not want to impose
|
||||
/// a limit so they can see the full transcript.
|
||||
const MAX_OUTPUT_LINES_FOR_EXEC_TOOL_CALL: usize = 20;
|
||||
pub(crate) struct EventProcessorWithHumanOutput {
|
||||
call_id_to_command: HashMap<String, ExecCommandBegin>,
|
||||
call_id_to_patch: HashMap<String, PatchApplyBegin>,
|
||||
|
||||
/// Tracks in-flight MCP tool calls so we can calculate duration and print
|
||||
/// a concise summary when the corresponding `McpToolCallEnd` event is
|
||||
/// received.
|
||||
call_id_to_tool_call: HashMap<String, McpToolCallBegin>,
|
||||
|
||||
// To ensure that --color=never is respected, ANSI escapes _must_ be added
|
||||
// using .style() with one of these fields. If you need a new style, add a
|
||||
// new field here.
|
||||
bold: Style,
|
||||
italic: Style,
|
||||
dimmed: Style,
|
||||
|
||||
magenta: Style,
|
||||
red: Style,
|
||||
green: Style,
|
||||
cyan: Style,
|
||||
|
||||
/// Whether to include `AgentReasoning` events in the output.
|
||||
show_agent_reasoning: bool,
|
||||
answer_started: bool,
|
||||
reasoning_started: bool,
|
||||
}
|
||||
|
||||
impl EventProcessorWithHumanOutput {
|
||||
pub(crate) fn create_with_ansi(with_ansi: bool, config: &Config) -> Self {
|
||||
let call_id_to_command = HashMap::new();
|
||||
let call_id_to_patch = HashMap::new();
|
||||
let call_id_to_tool_call = HashMap::new();
|
||||
|
||||
if with_ansi {
|
||||
Self {
|
||||
call_id_to_command,
|
||||
call_id_to_patch,
|
||||
bold: Style::new().bold(),
|
||||
italic: Style::new().italic(),
|
||||
dimmed: Style::new().dimmed(),
|
||||
magenta: Style::new().magenta(),
|
||||
red: Style::new().red(),
|
||||
green: Style::new().green(),
|
||||
cyan: Style::new().cyan(),
|
||||
call_id_to_tool_call,
|
||||
show_agent_reasoning: !config.hide_agent_reasoning,
|
||||
answer_started: false,
|
||||
reasoning_started: false,
|
||||
}
|
||||
} else {
|
||||
Self {
|
||||
call_id_to_command,
|
||||
call_id_to_patch,
|
||||
bold: Style::new(),
|
||||
italic: Style::new(),
|
||||
dimmed: Style::new(),
|
||||
magenta: Style::new(),
|
||||
red: Style::new(),
|
||||
green: Style::new(),
|
||||
cyan: Style::new(),
|
||||
call_id_to_tool_call,
|
||||
show_agent_reasoning: !config.hide_agent_reasoning,
|
||||
answer_started: false,
|
||||
reasoning_started: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ExecCommandBegin {
|
||||
command: Vec<String>,
|
||||
start_time: Instant,
|
||||
}
|
||||
|
||||
/// Metadata captured when an `McpToolCallBegin` event is received.
|
||||
struct McpToolCallBegin {
|
||||
/// Formatted invocation string, e.g. `server.tool({"city":"sf"})`.
|
||||
invocation: String,
|
||||
/// Timestamp when the call started so we can compute duration later.
|
||||
start_time: Instant,
|
||||
}
|
||||
|
||||
struct PatchApplyBegin {
|
||||
start_time: Instant,
|
||||
auto_approved: bool,
|
||||
}
|
||||
|
||||
// Timestamped println helper. The timestamp is styled with self.dimmed.
|
||||
#[macro_export]
|
||||
macro_rules! ts_println {
|
||||
($self:ident, $($arg:tt)*) => {{
|
||||
let now = chrono::Utc::now();
|
||||
let formatted = now.format("[%Y-%m-%dT%H:%M:%S]");
|
||||
print!("{} ", formatted.style($self.dimmed));
|
||||
println!($($arg)*);
|
||||
}};
|
||||
}
|
||||
|
||||
impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
/// Print a concise summary of the effective configuration that will be used
|
||||
/// for the session. This mirrors the information shown in the TUI welcome
|
||||
/// screen.
|
||||
fn print_config_summary(&mut self, config: &Config, prompt: &str) {
|
||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
ts_println!(
|
||||
self,
|
||||
"OpenAI Codex v{} (research preview)\n--------",
|
||||
VERSION
|
||||
);
|
||||
|
||||
let entries = create_config_summary_entries(config);
|
||||
|
||||
for (key, value) in entries {
|
||||
println!("{} {}", format!("{key}:").style(self.bold), value);
|
||||
}
|
||||
|
||||
println!("--------");
|
||||
|
||||
// Echo the prompt that will be sent to the agent so it is visible in the
|
||||
// transcript/logs before any events come in. Note the prompt may have been
|
||||
// read from stdin, so it may not be visible in the terminal otherwise.
|
||||
ts_println!(
|
||||
self,
|
||||
"{}\n{}",
|
||||
"User instructions:".style(self.bold).style(self.cyan),
|
||||
prompt
|
||||
);
|
||||
}
|
||||
|
||||
fn process_event(&mut self, event: Event) {
|
||||
let Event { id: _, msg } = event;
|
||||
match msg {
|
||||
EventMsg::Error(ErrorEvent { message }) => {
|
||||
let prefix = "ERROR:".style(self.red);
|
||||
ts_println!(self, "{prefix} {message}");
|
||||
}
|
||||
EventMsg::BackgroundEvent(BackgroundEventEvent { message }) => {
|
||||
ts_println!(self, "{}", message.style(self.dimmed));
|
||||
}
|
||||
EventMsg::TaskStarted | EventMsg::TaskComplete(_) => {
|
||||
// Ignore.
|
||||
}
|
||||
EventMsg::TokenCount(TokenUsage { total_tokens, .. }) => {
|
||||
ts_println!(self, "tokens used: {total_tokens}");
|
||||
}
|
||||
EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }) => {
|
||||
if !self.answer_started {
|
||||
ts_println!(self, "{}\n", "codex".style(self.italic).style(self.magenta));
|
||||
self.answer_started = true;
|
||||
}
|
||||
print!("{delta}");
|
||||
#[allow(clippy::expect_used)]
|
||||
std::io::stdout().flush().expect("could not flush stdout");
|
||||
}
|
||||
EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta }) => {
|
||||
if !self.show_agent_reasoning {
|
||||
return;
|
||||
}
|
||||
if !self.reasoning_started {
|
||||
ts_println!(
|
||||
self,
|
||||
"{}\n",
|
||||
"thinking".style(self.italic).style(self.magenta),
|
||||
);
|
||||
self.reasoning_started = true;
|
||||
}
|
||||
print!("{delta}");
|
||||
#[allow(clippy::expect_used)]
|
||||
std::io::stdout().flush().expect("could not flush stdout");
|
||||
}
|
||||
EventMsg::AgentMessage(AgentMessageEvent { message }) => {
|
||||
// if answer_started is false, this means we haven't received any
|
||||
// delta. Thus, we need to print the message as a new answer.
|
||||
if !self.answer_started {
|
||||
ts_println!(
|
||||
self,
|
||||
"{}\n{}",
|
||||
"codex".style(self.italic).style(self.magenta),
|
||||
message,
|
||||
);
|
||||
} else {
|
||||
println!();
|
||||
self.answer_started = false;
|
||||
}
|
||||
}
|
||||
EventMsg::ExecCommandBegin(ExecCommandBeginEvent {
|
||||
call_id,
|
||||
command,
|
||||
cwd,
|
||||
}) => {
|
||||
self.call_id_to_command.insert(
|
||||
call_id.clone(),
|
||||
ExecCommandBegin {
|
||||
command: command.clone(),
|
||||
start_time: Instant::now(),
|
||||
},
|
||||
);
|
||||
ts_println!(
|
||||
self,
|
||||
"{} {} in {}",
|
||||
"exec".style(self.magenta),
|
||||
escape_command(&command).style(self.bold),
|
||||
cwd.to_string_lossy(),
|
||||
);
|
||||
}
|
||||
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
|
||||
call_id,
|
||||
stdout,
|
||||
stderr,
|
||||
exit_code,
|
||||
}) => {
|
||||
let exec_command = self.call_id_to_command.remove(&call_id);
|
||||
let (duration, call) = if let Some(ExecCommandBegin {
|
||||
command,
|
||||
start_time,
|
||||
}) = exec_command
|
||||
{
|
||||
(
|
||||
format!(" in {}", format_elapsed(start_time)),
|
||||
format!("{}", escape_command(&command).style(self.bold)),
|
||||
)
|
||||
} else {
|
||||
("".to_string(), format!("exec('{call_id}')"))
|
||||
};
|
||||
|
||||
let output = if exit_code == 0 { stdout } else { stderr };
|
||||
let truncated_output = output
|
||||
.lines()
|
||||
.take(MAX_OUTPUT_LINES_FOR_EXEC_TOOL_CALL)
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
match exit_code {
|
||||
0 => {
|
||||
let title = format!("{call} succeeded{duration}:");
|
||||
ts_println!(self, "{}", title.style(self.green));
|
||||
}
|
||||
_ => {
|
||||
let title = format!("{call} exited {exit_code}{duration}:");
|
||||
ts_println!(self, "{}", title.style(self.red));
|
||||
}
|
||||
}
|
||||
println!("{}", truncated_output.style(self.dimmed));
|
||||
}
|
||||
EventMsg::McpToolCallBegin(McpToolCallBeginEvent {
|
||||
call_id,
|
||||
server,
|
||||
tool,
|
||||
arguments,
|
||||
}) => {
|
||||
// Build fully-qualified tool name: server.tool
|
||||
let fq_tool_name = format!("{server}.{tool}");
|
||||
|
||||
// Format arguments as compact JSON so they fit on one line.
|
||||
let args_str = arguments
|
||||
.as_ref()
|
||||
.map(|v: &serde_json::Value| {
|
||||
serde_json::to_string(v).unwrap_or_else(|_| v.to_string())
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let invocation = if args_str.is_empty() {
|
||||
format!("{fq_tool_name}()")
|
||||
} else {
|
||||
format!("{fq_tool_name}({args_str})")
|
||||
};
|
||||
|
||||
self.call_id_to_tool_call.insert(
|
||||
call_id.clone(),
|
||||
McpToolCallBegin {
|
||||
invocation: invocation.clone(),
|
||||
start_time: Instant::now(),
|
||||
},
|
||||
);
|
||||
|
||||
ts_println!(
|
||||
self,
|
||||
"{} {}",
|
||||
"tool".style(self.magenta),
|
||||
invocation.style(self.bold),
|
||||
);
|
||||
}
|
||||
EventMsg::McpToolCallEnd(tool_call_end_event) => {
|
||||
let is_success = tool_call_end_event.is_success();
|
||||
let McpToolCallEndEvent { call_id, result } = tool_call_end_event;
|
||||
// Retrieve start time and invocation for duration calculation and labeling.
|
||||
let info = self.call_id_to_tool_call.remove(&call_id);
|
||||
|
||||
let (duration, invocation) = if let Some(McpToolCallBegin {
|
||||
invocation,
|
||||
start_time,
|
||||
..
|
||||
}) = info
|
||||
{
|
||||
(format!(" in {}", format_elapsed(start_time)), invocation)
|
||||
} else {
|
||||
(String::new(), format!("tool('{call_id}')"))
|
||||
};
|
||||
|
||||
let status_str = if is_success { "success" } else { "failed" };
|
||||
let title_style = if is_success { self.green } else { self.red };
|
||||
let title = format!("{invocation} {status_str}{duration}:");
|
||||
|
||||
ts_println!(self, "{}", title.style(title_style));
|
||||
|
||||
if let Ok(res) = result {
|
||||
let val: serde_json::Value = res.into();
|
||||
let pretty =
|
||||
serde_json::to_string_pretty(&val).unwrap_or_else(|_| val.to_string());
|
||||
|
||||
for line in pretty.lines().take(MAX_OUTPUT_LINES_FOR_EXEC_TOOL_CALL) {
|
||||
println!("{}", line.style(self.dimmed));
|
||||
}
|
||||
}
|
||||
}
|
||||
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
||||
call_id,
|
||||
auto_approved,
|
||||
changes,
|
||||
}) => {
|
||||
// Store metadata so we can calculate duration later when we
|
||||
// receive the corresponding PatchApplyEnd event.
|
||||
self.call_id_to_patch.insert(
|
||||
call_id.clone(),
|
||||
PatchApplyBegin {
|
||||
start_time: Instant::now(),
|
||||
auto_approved,
|
||||
},
|
||||
);
|
||||
|
||||
ts_println!(
|
||||
self,
|
||||
"{} auto_approved={}:",
|
||||
"apply_patch".style(self.magenta),
|
||||
auto_approved,
|
||||
);
|
||||
|
||||
// Pretty-print the patch summary with colored diff markers so
|
||||
// it's easy to scan in the terminal output.
|
||||
for (path, change) in changes.iter() {
|
||||
match change {
|
||||
FileChange::Add { content } => {
|
||||
let header = format!(
|
||||
"{} {}",
|
||||
format_file_change(change),
|
||||
path.to_string_lossy()
|
||||
);
|
||||
println!("{}", header.style(self.magenta));
|
||||
for line in content.lines() {
|
||||
println!("{}", line.style(self.green));
|
||||
}
|
||||
}
|
||||
FileChange::Delete => {
|
||||
let header = format!(
|
||||
"{} {}",
|
||||
format_file_change(change),
|
||||
path.to_string_lossy()
|
||||
);
|
||||
println!("{}", header.style(self.magenta));
|
||||
}
|
||||
FileChange::Update {
|
||||
unified_diff,
|
||||
move_path,
|
||||
} => {
|
||||
let header = if let Some(dest) = move_path {
|
||||
format!(
|
||||
"{} {} -> {}",
|
||||
format_file_change(change),
|
||||
path.to_string_lossy(),
|
||||
dest.to_string_lossy()
|
||||
)
|
||||
} else {
|
||||
format!("{} {}", format_file_change(change), path.to_string_lossy())
|
||||
};
|
||||
println!("{}", header.style(self.magenta));
|
||||
|
||||
// Colorize diff lines. We keep file header lines
|
||||
// (--- / +++) without extra coloring so they are
|
||||
// still readable.
|
||||
for diff_line in unified_diff.lines() {
|
||||
if diff_line.starts_with('+') && !diff_line.starts_with("+++") {
|
||||
println!("{}", diff_line.style(self.green));
|
||||
} else if diff_line.starts_with('-')
|
||||
&& !diff_line.starts_with("---")
|
||||
{
|
||||
println!("{}", diff_line.style(self.red));
|
||||
} else {
|
||||
println!("{diff_line}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
EventMsg::PatchApplyEnd(PatchApplyEndEvent {
|
||||
call_id,
|
||||
stdout,
|
||||
stderr,
|
||||
success,
|
||||
}) => {
|
||||
let patch_begin = self.call_id_to_patch.remove(&call_id);
|
||||
|
||||
// Compute duration and summary label similar to exec commands.
|
||||
let (duration, label) = if let Some(PatchApplyBegin {
|
||||
start_time,
|
||||
auto_approved,
|
||||
}) = patch_begin
|
||||
{
|
||||
(
|
||||
format!(" in {}", format_elapsed(start_time)),
|
||||
format!("apply_patch(auto_approved={auto_approved})"),
|
||||
)
|
||||
} else {
|
||||
(String::new(), format!("apply_patch('{call_id}')"))
|
||||
};
|
||||
|
||||
let (exit_code, output, title_style) = if success {
|
||||
(0, stdout, self.green)
|
||||
} else {
|
||||
(1, stderr, self.red)
|
||||
};
|
||||
|
||||
let title = format!("{label} exited {exit_code}{duration}:");
|
||||
ts_println!(self, "{}", title.style(title_style));
|
||||
for line in output.lines() {
|
||||
println!("{}", line.style(self.dimmed));
|
||||
}
|
||||
}
|
||||
EventMsg::ExecApprovalRequest(_) => {
|
||||
// Should we exit?
|
||||
}
|
||||
EventMsg::ApplyPatchApprovalRequest(_) => {
|
||||
// Should we exit?
|
||||
}
|
||||
EventMsg::AgentReasoning(agent_reasoning_event) => {
|
||||
if self.show_agent_reasoning {
|
||||
if !self.reasoning_started {
|
||||
ts_println!(
|
||||
self,
|
||||
"{}\n{}",
|
||||
"codex".style(self.italic).style(self.magenta),
|
||||
agent_reasoning_event.text,
|
||||
);
|
||||
} else {
|
||||
println!();
|
||||
self.reasoning_started = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
EventMsg::SessionConfigured(session_configured_event) => {
|
||||
let SessionConfiguredEvent {
|
||||
session_id,
|
||||
model,
|
||||
history_log_id: _,
|
||||
history_entry_count: _,
|
||||
} = session_configured_event;
|
||||
|
||||
ts_println!(
|
||||
self,
|
||||
"{} {}",
|
||||
"codex session".style(self.magenta).style(self.bold),
|
||||
session_id.to_string().style(self.dimmed)
|
||||
);
|
||||
|
||||
ts_println!(self, "model: {}", model);
|
||||
println!();
|
||||
}
|
||||
EventMsg::GetHistoryEntryResponse(_) => {
|
||||
// Currently ignored in exec output.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn escape_command(command: &[String]) -> String {
|
||||
try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" "))
|
||||
}
|
||||
|
||||
fn format_file_change(change: &FileChange) -> &'static str {
|
||||
match change {
|
||||
FileChange::Add { .. } => "A",
|
||||
FileChange::Delete => "D",
|
||||
FileChange::Update {
|
||||
move_path: Some(_), ..
|
||||
} => "R",
|
||||
FileChange::Update {
|
||||
move_path: None, ..
|
||||
} => "M",
|
||||
}
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use codex_core::config::Config;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::event_processor::EventProcessor;
|
||||
use crate::event_processor::create_config_summary_entries;
|
||||
|
||||
pub(crate) struct EventProcessorWithJsonOutput;
|
||||
|
||||
impl EventProcessorWithJsonOutput {
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
}
|
||||
|
||||
impl EventProcessor for EventProcessorWithJsonOutput {
|
||||
fn print_config_summary(&mut self, config: &Config, prompt: &str) {
|
||||
let entries = create_config_summary_entries(config)
|
||||
.into_iter()
|
||||
.map(|(key, value)| (key.to_string(), value))
|
||||
.collect::<HashMap<String, String>>();
|
||||
#[allow(clippy::expect_used)]
|
||||
let config_json =
|
||||
serde_json::to_string(&entries).expect("Failed to serialize config summary to JSON");
|
||||
println!("{config_json}");
|
||||
|
||||
let prompt_json = json!({
|
||||
"prompt": prompt,
|
||||
});
|
||||
println!("{prompt_json}");
|
||||
}
|
||||
|
||||
fn process_event(&mut self, event: Event) {
|
||||
match event.msg {
|
||||
EventMsg::AgentMessageDelta(_) | EventMsg::AgentReasoningDelta(_) => {
|
||||
// Suppress streaming events in JSON mode.
|
||||
}
|
||||
_ => {
|
||||
if let Ok(line) = serde_json::to_string(&event) {
|
||||
println!("{line}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,5 @@
|
||||
mod cli;
|
||||
mod event_processor;
|
||||
mod event_processor_with_human_output;
|
||||
mod event_processor_with_json_output;
|
||||
|
||||
use std::io::IsTerminal;
|
||||
use std::io::Read;
|
||||
@@ -21,15 +19,12 @@ use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::TaskCompleteEvent;
|
||||
use codex_core::util::is_inside_git_repo;
|
||||
use event_processor_with_human_output::EventProcessorWithHumanOutput;
|
||||
use event_processor_with_json_output::EventProcessorWithJsonOutput;
|
||||
use event_processor::EventProcessor;
|
||||
use tracing::debug;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
use crate::event_processor::EventProcessor;
|
||||
|
||||
pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
|
||||
let Cli {
|
||||
images,
|
||||
@@ -41,7 +36,6 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
skip_git_repo_check,
|
||||
color,
|
||||
last_message_file,
|
||||
json: json_mode,
|
||||
sandbox_mode: sandbox_mode_cli_arg,
|
||||
prompt,
|
||||
config_overrides,
|
||||
@@ -121,15 +115,12 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
};
|
||||
|
||||
let config = Config::load_with_cli_overrides(cli_kv_overrides, overrides)?;
|
||||
let mut event_processor: Box<dyn EventProcessor> = if json_mode {
|
||||
Box::new(EventProcessorWithJsonOutput::new())
|
||||
} else {
|
||||
Box::new(EventProcessorWithHumanOutput::create_with_ansi(
|
||||
stdout_with_ansi,
|
||||
&config,
|
||||
))
|
||||
};
|
||||
|
||||
println!("[DEBUG] streaming_enabled: {}", config.streaming_enabled);
|
||||
let mut event_processor = EventProcessor::create_with_ansi(
|
||||
stdout_with_ansi,
|
||||
!config.hide_agent_reasoning,
|
||||
config.streaming_enabled,
|
||||
);
|
||||
// Print the effective configuration and prompt so users can see what Codex
|
||||
// is using.
|
||||
event_processor.print_config_summary(&config, &prompt);
|
||||
|
||||
@@ -23,10 +23,3 @@ file-search *args:
|
||||
# format code
|
||||
fmt:
|
||||
cargo fmt -- --config imports_granularity=Item
|
||||
|
||||
fix:
|
||||
cargo clippy --fix --all-features --tests --allow-dirty
|
||||
|
||||
install:
|
||||
rustup show active-toolchain
|
||||
cargo fetch
|
||||
|
||||
@@ -57,12 +57,10 @@ async fn main() -> Result<()> {
|
||||
experimental: None,
|
||||
roots: None,
|
||||
sampling: None,
|
||||
elicitation: None,
|
||||
},
|
||||
client_info: Implementation {
|
||||
name: "codex-mcp-client".to_owned(),
|
||||
version: env!("CARGO_PKG_VERSION").to_owned(),
|
||||
title: Some("Codex".to_string()),
|
||||
},
|
||||
protocol_version: MCP_SCHEMA_VERSION.to_owned(),
|
||||
};
|
||||
|
||||
@@ -22,7 +22,6 @@ mcp-types = { path = "../mcp-types" }
|
||||
schemars = "0.8.22"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
shlex = "1.3.0"
|
||||
toml = "0.9"
|
||||
tracing = { version = "0.1.41", features = ["log"] }
|
||||
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
|
||||
|
||||
@@ -108,10 +108,7 @@ pub(crate) fn create_tool_for_codex_tool_call_param() -> Tool {
|
||||
|
||||
Tool {
|
||||
name: "codex".to_string(),
|
||||
title: Some("Codex".to_string()),
|
||||
input_schema: tool_input_schema,
|
||||
// TODO(mbolin): This should be defined.
|
||||
output_schema: None,
|
||||
description: Some(
|
||||
"Run a Codex session. Accepts configuration parameters matching the Codex Config struct.".to_string(),
|
||||
),
|
||||
@@ -182,7 +179,6 @@ mod tests {
|
||||
let tool_json = serde_json::to_value(&tool).expect("tool serializes");
|
||||
let expected_tool_json = serde_json::json!({
|
||||
"name": "codex",
|
||||
"title": "Codex",
|
||||
"description": "Run a Codex session. Accepts configuration parameters matching the Codex Config struct.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
|
||||
@@ -2,31 +2,33 @@
|
||||
//! Tokio task. Separated from `message_processor.rs` to keep that file small
|
||||
//! and to make future feature-growth easier to manage.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::codex_wrapper::init_codex;
|
||||
use codex_core::config::Config as CodexConfig;
|
||||
use codex_core::protocol::AgentMessageEvent;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::ExecApprovalRequestEvent;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::ReviewDecision;
|
||||
use codex_core::protocol::Submission;
|
||||
use codex_core::protocol::TaskCompleteEvent;
|
||||
use mcp_types::CallToolResult;
|
||||
use mcp_types::ContentBlock;
|
||||
use mcp_types::ElicitRequest;
|
||||
use mcp_types::ElicitRequestParamsRequestedSchema;
|
||||
use mcp_types::ModelContextProtocolRequest;
|
||||
use mcp_types::CallToolResultContent;
|
||||
use mcp_types::JSONRPC_VERSION;
|
||||
use mcp_types::JSONRPCMessage;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::RequestId;
|
||||
use mcp_types::TextContent;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use tracing::error;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
/// Convert a Codex [`Event`] to an MCP notification.
|
||||
fn codex_event_to_notification(event: &Event) -> JSONRPCMessage {
|
||||
#[expect(clippy::expect_used)]
|
||||
JSONRPCMessage::Notification(mcp_types::JSONRPCNotification {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
method: "codex/event".into(),
|
||||
params: Some(serde_json::to_value(event).expect("Event must serialize")),
|
||||
})
|
||||
}
|
||||
|
||||
/// Run a complete Codex session and stream events back to the client.
|
||||
///
|
||||
@@ -36,28 +38,34 @@ pub async fn run_codex_tool_session(
|
||||
id: RequestId,
|
||||
initial_prompt: String,
|
||||
config: CodexConfig,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
outgoing: Sender<JSONRPCMessage>,
|
||||
) {
|
||||
let (codex, first_event, _ctrl_c) = match init_codex(config).await {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: format!("Failed to start Codex session: {e}"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
outgoing.send_response(id.clone(), result.into()).await;
|
||||
let _ = outgoing
|
||||
.send(JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id,
|
||||
result: result.into(),
|
||||
}))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
let codex = Arc::new(codex);
|
||||
|
||||
// Send initial SessionConfigured event.
|
||||
outgoing.send_event_as_notification(&first_event).await;
|
||||
let _ = outgoing
|
||||
.send(codex_event_to_notification(&first_event))
|
||||
.await;
|
||||
|
||||
// Use the original MCP request ID as the `sub_id` for the Codex submission so that
|
||||
// any events emitted for this tool-call can be correlated with the
|
||||
@@ -68,7 +76,7 @@ pub async fn run_codex_tool_session(
|
||||
};
|
||||
|
||||
let submission = Submission {
|
||||
id: sub_id.clone(),
|
||||
id: sub_id,
|
||||
op: Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: initial_prompt.clone(),
|
||||
@@ -80,105 +88,95 @@ pub async fn run_codex_tool_session(
|
||||
tracing::error!("Failed to submit initial prompt: {e}");
|
||||
}
|
||||
|
||||
let mut last_agent_message: Option<String> = None;
|
||||
|
||||
// Stream events until the task needs to pause for user interaction or
|
||||
// completes.
|
||||
loop {
|
||||
match codex.next_event().await {
|
||||
Ok(event) => {
|
||||
outgoing.send_event_as_notification(&event).await;
|
||||
let _ = outgoing.send(codex_event_to_notification(&event)).await;
|
||||
|
||||
match event.msg {
|
||||
EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
|
||||
command,
|
||||
cwd,
|
||||
reason: _,
|
||||
}) => {
|
||||
let escaped_command = shlex::try_join(command.iter().map(|s| s.as_str()))
|
||||
.unwrap_or_else(|_| command.join(" "));
|
||||
let message = format!("Allow Codex to run `{escaped_command}` in {cwd:?}?");
|
||||
|
||||
let params = json!({
|
||||
// These fields are required so that `params`
|
||||
// conforms to ElicitRequestParams.
|
||||
"message": message,
|
||||
"requestedSchema": ElicitRequestParamsRequestedSchema {
|
||||
r#type: "object".to_string(),
|
||||
properties: json!({}),
|
||||
required: None,
|
||||
},
|
||||
|
||||
// These are additional fields the client can use to
|
||||
// correlate the request with the codex tool call.
|
||||
"codex_elicitation": "exec-approval",
|
||||
"codex_mcp_tool_call_id": sub_id,
|
||||
"codex_event_id": event.id,
|
||||
"codex_command": command,
|
||||
// Could convert it to base64 encoded bytes if we
|
||||
// don't want to use to_string_lossy() here?
|
||||
"codex_cwd": cwd.to_string_lossy().to_string()
|
||||
});
|
||||
let on_response = outgoing
|
||||
.send_request(ElicitRequest::METHOD, Some(params))
|
||||
match &event.msg {
|
||||
EventMsg::AgentMessage(AgentMessageEvent { message }) => {
|
||||
last_agent_message = Some(message.clone());
|
||||
}
|
||||
EventMsg::ExecApprovalRequest(_) => {
|
||||
let result = CallToolResult {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: "EXEC_APPROVAL_REQUIRED".to_string(),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: None,
|
||||
};
|
||||
let _ = outgoing
|
||||
.send(JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: id.clone(),
|
||||
result: result.into(),
|
||||
}))
|
||||
.await;
|
||||
|
||||
// Listen for the response on a separate task so we do
|
||||
// not block the main loop of this function.
|
||||
{
|
||||
let codex = codex.clone();
|
||||
let event_id = event.id.clone();
|
||||
tokio::spawn(async move {
|
||||
on_exec_approval_response(event_id, on_response, codex).await;
|
||||
});
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
EventMsg::ApplyPatchApprovalRequest(_) => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: "PATCH_APPROVAL_REQUIRED".to_string(),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: None,
|
||||
structured_content: None,
|
||||
};
|
||||
outgoing.send_response(id.clone(), result.into()).await;
|
||||
let _ = outgoing
|
||||
.send(JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: id.clone(),
|
||||
result: result.into(),
|
||||
}))
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => {
|
||||
let text = match last_agent_message {
|
||||
Some(msg) => msg.clone(),
|
||||
None => "".to_string(),
|
||||
EventMsg::TaskComplete(TaskCompleteEvent {
|
||||
last_agent_message: _,
|
||||
}) => {
|
||||
let result = if let Some(msg) = last_agent_message {
|
||||
CallToolResult {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: msg,
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: None,
|
||||
}
|
||||
} else {
|
||||
CallToolResult {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: String::new(),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: None,
|
||||
}
|
||||
};
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text,
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: None,
|
||||
structured_content: None,
|
||||
};
|
||||
outgoing.send_response(id.clone(), result.into()).await;
|
||||
let _ = outgoing
|
||||
.send(JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: id.clone(),
|
||||
result: result.into(),
|
||||
}))
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
EventMsg::SessionConfigured(_) => {
|
||||
tracing::error!("unexpected SessionConfigured event");
|
||||
}
|
||||
EventMsg::AgentMessageDelta(_) => {
|
||||
// TODO: think how we want to support this in the MCP
|
||||
}
|
||||
EventMsg::AgentReasoningDelta(_) => {
|
||||
// TODO: think how we want to support this in the MCP
|
||||
}
|
||||
EventMsg::AgentMessage(AgentMessageEvent { .. }) => {
|
||||
// TODO: think how we want to support this in the MCP
|
||||
}
|
||||
EventMsg::Error(_)
|
||||
| EventMsg::TaskStarted
|
||||
| EventMsg::TokenCount(_)
|
||||
| EventMsg::AgentReasoning(_)
|
||||
| EventMsg::AgentMessageDelta(_)
|
||||
| EventMsg::AgentReasoningDelta(_)
|
||||
| EventMsg::McpToolCallBegin(_)
|
||||
| EventMsg::McpToolCallEnd(_)
|
||||
| EventMsg::ExecCommandBegin(_)
|
||||
@@ -198,58 +196,22 @@ pub async fn run_codex_tool_session(
|
||||
}
|
||||
Err(e) => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: format!("Codex runtime error: {e}"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
// TODO(mbolin): Could present the error in a more
|
||||
// structured way.
|
||||
structured_content: None,
|
||||
};
|
||||
outgoing.send_response(id.clone(), result.into()).await;
|
||||
let _ = outgoing
|
||||
.send(JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: id.clone(),
|
||||
result: result.into(),
|
||||
}))
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn on_exec_approval_response(
|
||||
event_id: String,
|
||||
receiver: tokio::sync::oneshot::Receiver<mcp_types::Result>,
|
||||
codex: Arc<Codex>,
|
||||
) {
|
||||
let response = receiver.await;
|
||||
let value = match response {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
error!("request failed: {err:?}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Try to deserialize `value` and then make the appropriate call to `codex`.
|
||||
let response = match serde_json::from_value::<ExecApprovalResponse>(value) {
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
error!("failed to deserialize ExecApprovalResponse: {err}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(err) = codex
|
||||
.submit(Op::ExecApproval {
|
||||
id: event_id,
|
||||
decision: response.decision,
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!("failed to submit ExecApproval: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ExecApprovalResponse {
|
||||
pub decision: ReviewDecision,
|
||||
}
|
||||
|
||||
@@ -18,11 +18,8 @@ mod codex_tool_config;
|
||||
mod codex_tool_runner;
|
||||
mod json_to_toml;
|
||||
mod message_processor;
|
||||
mod outgoing_message;
|
||||
|
||||
use crate::message_processor::MessageProcessor;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
|
||||
/// Size of the bounded channels used to communicate between tasks. The value
|
||||
/// is a balance between throughput and memory usage – 128 messages should be
|
||||
@@ -38,7 +35,7 @@ pub async fn run_main(codex_linux_sandbox_exe: Option<PathBuf>) -> IoResult<()>
|
||||
|
||||
// Set up channels.
|
||||
let (incoming_tx, mut incoming_rx) = mpsc::channel::<JSONRPCMessage>(CHANNEL_CAPACITY);
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<OutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<JSONRPCMessage>(CHANNEL_CAPACITY);
|
||||
|
||||
// Task: read from stdin, push to `incoming_tx`.
|
||||
let stdin_reader_handle = tokio::spawn({
|
||||
@@ -66,15 +63,16 @@ pub async fn run_main(codex_linux_sandbox_exe: Option<PathBuf>) -> IoResult<()>
|
||||
|
||||
// Task: process incoming messages.
|
||||
let processor_handle = tokio::spawn({
|
||||
let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx);
|
||||
let mut processor = MessageProcessor::new(outgoing_message_sender, codex_linux_sandbox_exe);
|
||||
let mut processor = MessageProcessor::new(outgoing_tx.clone(), codex_linux_sandbox_exe);
|
||||
async move {
|
||||
while let Some(msg) = incoming_rx.recv().await {
|
||||
match msg {
|
||||
JSONRPCMessage::Request(r) => processor.process_request(r).await,
|
||||
JSONRPCMessage::Response(r) => processor.process_response(r).await,
|
||||
JSONRPCMessage::Request(r) => processor.process_request(r),
|
||||
JSONRPCMessage::Response(r) => processor.process_response(r),
|
||||
JSONRPCMessage::Notification(n) => processor.process_notification(n),
|
||||
JSONRPCMessage::BatchRequest(b) => processor.process_batch_request(b),
|
||||
JSONRPCMessage::Error(e) => processor.process_error(e),
|
||||
JSONRPCMessage::BatchResponse(b) => processor.process_batch_response(b),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,8 +83,7 @@ pub async fn run_main(codex_linux_sandbox_exe: Option<PathBuf>) -> IoResult<()>
|
||||
// Task: write outgoing messages to stdout.
|
||||
let stdout_writer_handle = tokio::spawn(async move {
|
||||
let mut stdout = io::stdout();
|
||||
while let Some(outgoing_message) = outgoing_rx.recv().await {
|
||||
let msg: JSONRPCMessage = outgoing_message.into();
|
||||
while let Some(msg) = outgoing_rx.recv().await {
|
||||
match serde_json::to_string(&msg) {
|
||||
Ok(json) => {
|
||||
if let Err(e) = stdout.write_all(json.as_bytes()).await {
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::codex_tool_config::CodexToolCallParam;
|
||||
use crate::codex_tool_config::create_tool_for_codex_tool_call_param;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
|
||||
use codex_core::config::Config as CodexConfig;
|
||||
use mcp_types::CallToolRequestParams;
|
||||
use mcp_types::CallToolResult;
|
||||
use mcp_types::CallToolResultContent;
|
||||
use mcp_types::ClientRequest;
|
||||
use mcp_types::ContentBlock;
|
||||
use mcp_types::JSONRPC_VERSION;
|
||||
use mcp_types::JSONRPCBatchRequest;
|
||||
use mcp_types::JSONRPCBatchResponse;
|
||||
use mcp_types::JSONRPCError;
|
||||
use mcp_types::JSONRPCErrorError;
|
||||
use mcp_types::JSONRPCMessage;
|
||||
use mcp_types::JSONRPCNotification;
|
||||
use mcp_types::JSONRPCRequest;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
@@ -22,10 +24,11 @@ use mcp_types::ServerCapabilitiesTools;
|
||||
use mcp_types::ServerNotification;
|
||||
use mcp_types::TextContent;
|
||||
use serde_json::json;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task;
|
||||
|
||||
pub(crate) struct MessageProcessor {
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
outgoing: mpsc::Sender<JSONRPCMessage>,
|
||||
initialized: bool,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
}
|
||||
@@ -34,17 +37,17 @@ impl MessageProcessor {
|
||||
/// Create a new `MessageProcessor`, retaining a handle to the outgoing
|
||||
/// `Sender` so handlers can enqueue messages to be written to stdout.
|
||||
pub(crate) fn new(
|
||||
outgoing: OutgoingMessageSender,
|
||||
outgoing: mpsc::Sender<JSONRPCMessage>,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
) -> Self {
|
||||
Self {
|
||||
outgoing: Arc::new(outgoing),
|
||||
outgoing,
|
||||
initialized: false,
|
||||
codex_linux_sandbox_exe,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn process_request(&mut self, request: JSONRPCRequest) {
|
||||
pub(crate) fn process_request(&mut self, request: JSONRPCRequest) {
|
||||
// Hold on to the ID so we can respond.
|
||||
let request_id = request.id.clone();
|
||||
|
||||
@@ -59,10 +62,10 @@ impl MessageProcessor {
|
||||
// Dispatch to a dedicated handler for each request type.
|
||||
match client_request {
|
||||
ClientRequest::InitializeRequest(params) => {
|
||||
self.handle_initialize(request_id, params).await;
|
||||
self.handle_initialize(request_id, params);
|
||||
}
|
||||
ClientRequest::PingRequest(params) => {
|
||||
self.handle_ping(request_id, params).await;
|
||||
self.handle_ping(request_id, params);
|
||||
}
|
||||
ClientRequest::ListResourcesRequest(params) => {
|
||||
self.handle_list_resources(params);
|
||||
@@ -86,10 +89,10 @@ impl MessageProcessor {
|
||||
self.handle_get_prompt(params);
|
||||
}
|
||||
ClientRequest::ListToolsRequest(params) => {
|
||||
self.handle_list_tools(request_id, params).await;
|
||||
self.handle_list_tools(request_id, params);
|
||||
}
|
||||
ClientRequest::CallToolRequest(params) => {
|
||||
self.handle_call_tool(request_id, params).await;
|
||||
self.handle_call_tool(request_id, params);
|
||||
}
|
||||
ClientRequest::SetLevelRequest(params) => {
|
||||
self.handle_set_level(params);
|
||||
@@ -101,10 +104,8 @@ impl MessageProcessor {
|
||||
}
|
||||
|
||||
/// Handle a standalone JSON-RPC response originating from the peer.
|
||||
pub(crate) async fn process_response(&mut self, response: JSONRPCResponse) {
|
||||
pub(crate) fn process_response(&mut self, response: JSONRPCResponse) {
|
||||
tracing::info!("<- response: {:?}", response);
|
||||
let JSONRPCResponse { id, result, .. } = response;
|
||||
self.outgoing.notify_client_response(id, result).await
|
||||
}
|
||||
|
||||
/// Handle a fire-and-forget JSON-RPC notification.
|
||||
@@ -144,12 +145,42 @@ impl MessageProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a batch of requests and/or notifications.
|
||||
pub(crate) fn process_batch_request(&mut self, batch: JSONRPCBatchRequest) {
|
||||
tracing::info!("<- batch request containing {} item(s)", batch.len());
|
||||
for item in batch {
|
||||
match item {
|
||||
mcp_types::JSONRPCBatchRequestItem::JSONRPCRequest(req) => {
|
||||
self.process_request(req);
|
||||
}
|
||||
mcp_types::JSONRPCBatchRequestItem::JSONRPCNotification(note) => {
|
||||
self.process_notification(note);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle an error object received from the peer.
|
||||
pub(crate) fn process_error(&mut self, err: JSONRPCError) {
|
||||
tracing::error!("<- error: {:?}", err);
|
||||
}
|
||||
|
||||
async fn handle_initialize(
|
||||
/// Handle a batch of responses/errors.
|
||||
pub(crate) fn process_batch_response(&mut self, batch: JSONRPCBatchResponse) {
|
||||
tracing::info!("<- batch response containing {} item(s)", batch.len());
|
||||
for item in batch {
|
||||
match item {
|
||||
mcp_types::JSONRPCBatchResponseItem::JSONRPCResponse(resp) => {
|
||||
self.process_response(resp);
|
||||
}
|
||||
mcp_types::JSONRPCBatchResponseItem::JSONRPCError(err) => {
|
||||
self.process_error(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_initialize(
|
||||
&mut self,
|
||||
id: RequestId,
|
||||
params: <mcp_types::InitializeRequest as ModelContextProtocolRequest>::Params,
|
||||
@@ -158,12 +189,19 @@ impl MessageProcessor {
|
||||
|
||||
if self.initialized {
|
||||
// Already initialised: send JSON-RPC error response.
|
||||
let error = JSONRPCErrorError {
|
||||
code: -32600, // Invalid Request
|
||||
message: "initialize called more than once".to_string(),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(id, error).await;
|
||||
let error_msg = JSONRPCMessage::Error(JSONRPCError {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id,
|
||||
error: JSONRPCErrorError {
|
||||
code: -32600, // Invalid Request
|
||||
message: "initialize called more than once".to_string(),
|
||||
data: None,
|
||||
},
|
||||
});
|
||||
|
||||
if let Err(e) = self.outgoing.try_send(error_msg) {
|
||||
tracing::error!("Failed to send initialization error: {e}");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -186,33 +224,37 @@ impl MessageProcessor {
|
||||
server_info: mcp_types::Implementation {
|
||||
name: "codex-mcp-server".to_string(),
|
||||
version: mcp_types::MCP_SCHEMA_VERSION.to_string(),
|
||||
title: Some("Codex".to_string()),
|
||||
},
|
||||
};
|
||||
|
||||
self.send_response::<mcp_types::InitializeRequest>(id, result)
|
||||
.await;
|
||||
self.send_response::<mcp_types::InitializeRequest>(id, result);
|
||||
}
|
||||
|
||||
async fn send_response<T>(&self, id: RequestId, result: T::Result)
|
||||
fn send_response<T>(&self, id: RequestId, result: T::Result)
|
||||
where
|
||||
T: ModelContextProtocolRequest,
|
||||
{
|
||||
// result has `Serialized` instance so should never fail
|
||||
#[expect(clippy::unwrap_used)]
|
||||
let result = serde_json::to_value(result).unwrap();
|
||||
self.outgoing.send_response(id, result).await;
|
||||
let response = JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id,
|
||||
result: serde_json::to_value(result).unwrap(),
|
||||
});
|
||||
|
||||
if let Err(e) = self.outgoing.try_send(response) {
|
||||
tracing::error!("Failed to send response: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_ping(
|
||||
fn handle_ping(
|
||||
&self,
|
||||
id: RequestId,
|
||||
params: <mcp_types::PingRequest as mcp_types::ModelContextProtocolRequest>::Params,
|
||||
) {
|
||||
tracing::info!("ping -> params: {:?}", params);
|
||||
let result = json!({});
|
||||
self.send_response::<mcp_types::PingRequest>(id, result)
|
||||
.await;
|
||||
self.send_response::<mcp_types::PingRequest>(id, result);
|
||||
}
|
||||
|
||||
fn handle_list_resources(
|
||||
@@ -265,7 +307,7 @@ impl MessageProcessor {
|
||||
tracing::info!("prompts/get -> params: {:?}", params);
|
||||
}
|
||||
|
||||
async fn handle_list_tools(
|
||||
fn handle_list_tools(
|
||||
&self,
|
||||
id: RequestId,
|
||||
params: <mcp_types::ListToolsRequest as mcp_types::ModelContextProtocolRequest>::Params,
|
||||
@@ -276,11 +318,10 @@ impl MessageProcessor {
|
||||
next_cursor: None,
|
||||
};
|
||||
|
||||
self.send_response::<mcp_types::ListToolsRequest>(id, result)
|
||||
.await;
|
||||
self.send_response::<mcp_types::ListToolsRequest>(id, result);
|
||||
}
|
||||
|
||||
async fn handle_call_tool(
|
||||
fn handle_call_tool(
|
||||
&self,
|
||||
id: RequestId,
|
||||
params: <mcp_types::CallToolRequest as mcp_types::ModelContextProtocolRequest>::Params,
|
||||
@@ -292,16 +333,14 @@ impl MessageProcessor {
|
||||
if name != "codex" {
|
||||
// Tool not found – return error result so the LLM can react.
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: format!("Unknown tool '{name}'"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result)
|
||||
.await;
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -311,7 +350,7 @@ impl MessageProcessor {
|
||||
Ok(cfg) => cfg,
|
||||
Err(e) => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_owned(),
|
||||
text: format!(
|
||||
"Failed to load Codex configuration from overrides: {e}"
|
||||
@@ -319,31 +358,27 @@ impl MessageProcessor {
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result)
|
||||
.await;
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result);
|
||||
return;
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_owned(),
|
||||
text: format!("Failed to parse configuration for Codex tool: {e}"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result)
|
||||
.await;
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result);
|
||||
return;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text:
|
||||
"Missing arguments for codex tool-call; the `prompt` field is required."
|
||||
@@ -351,10 +386,8 @@ impl MessageProcessor {
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result)
|
||||
.await;
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result);
|
||||
return;
|
||||
}
|
||||
};
|
||||
@@ -365,7 +398,7 @@ impl MessageProcessor {
|
||||
// Spawn an async task to handle the Codex session so that we do not
|
||||
// block the synchronous message-processing loop.
|
||||
task::spawn(async move {
|
||||
// Run the Codex session and stream events Fck to the client.
|
||||
// Run the Codex session and stream events back to the client.
|
||||
crate::codex_tool_runner::run_codex_tool_session(id, initial_prompt, config, outgoing)
|
||||
.await;
|
||||
});
|
||||
|
||||
@@ -1,165 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use codex_core::protocol::Event;
|
||||
use mcp_types::JSONRPC_VERSION;
|
||||
use mcp_types::JSONRPCError;
|
||||
use mcp_types::JSONRPCErrorError;
|
||||
use mcp_types::JSONRPCMessage;
|
||||
use mcp_types::JSONRPCNotification;
|
||||
use mcp_types::JSONRPCRequest;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::RequestId;
|
||||
use mcp_types::Result;
|
||||
use serde::Serialize;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::warn;
|
||||
|
||||
pub(crate) struct OutgoingMessageSender {
|
||||
next_request_id: AtomicI64,
|
||||
sender: mpsc::Sender<OutgoingMessage>,
|
||||
request_id_to_callback: Mutex<HashMap<RequestId, oneshot::Sender<Result>>>,
|
||||
}
|
||||
|
||||
impl OutgoingMessageSender {
|
||||
pub(crate) fn new(sender: mpsc::Sender<OutgoingMessage>) -> Self {
|
||||
Self {
|
||||
next_request_id: AtomicI64::new(0),
|
||||
sender,
|
||||
request_id_to_callback: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn send_request(
|
||||
&self,
|
||||
method: &str,
|
||||
params: Option<serde_json::Value>,
|
||||
) -> oneshot::Receiver<Result> {
|
||||
let id = RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::Relaxed));
|
||||
let outgoing_message_id = id.clone();
|
||||
let (tx_approve, rx_approve) = oneshot::channel();
|
||||
{
|
||||
let mut request_id_to_callback = self.request_id_to_callback.lock().await;
|
||||
request_id_to_callback.insert(id, tx_approve);
|
||||
}
|
||||
|
||||
let outgoing_message = OutgoingMessage::Request(OutgoingRequest {
|
||||
id: outgoing_message_id,
|
||||
method: method.to_string(),
|
||||
params,
|
||||
});
|
||||
let _ = self.sender.send(outgoing_message).await;
|
||||
rx_approve
|
||||
}
|
||||
|
||||
pub(crate) async fn notify_client_response(&self, id: RequestId, result: Result) {
|
||||
let entry = {
|
||||
let mut request_id_to_callback = self.request_id_to_callback.lock().await;
|
||||
request_id_to_callback.remove_entry(&id)
|
||||
};
|
||||
|
||||
match entry {
|
||||
Some((id, sender)) => {
|
||||
if let Err(err) = sender.send(result) {
|
||||
warn!("could not notify callback for {id:?} due to: {err:?}");
|
||||
}
|
||||
}
|
||||
None => {
|
||||
warn!("could not find callback for {id:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn send_response(&self, id: RequestId, result: Result) {
|
||||
let outgoing_message = OutgoingMessage::Response(OutgoingResponse { id, result });
|
||||
let _ = self.sender.send(outgoing_message).await;
|
||||
}
|
||||
|
||||
pub(crate) async fn send_event_as_notification(&self, event: &Event) {
|
||||
#[expect(clippy::expect_used)]
|
||||
let params = Some(serde_json::to_value(event).expect("Event must serialize"));
|
||||
let outgoing_message = OutgoingMessage::Notification(OutgoingNotification {
|
||||
method: "codex/event".to_string(),
|
||||
params,
|
||||
});
|
||||
let _ = self.sender.send(outgoing_message).await;
|
||||
}
|
||||
|
||||
pub(crate) async fn send_error(&self, id: RequestId, error: JSONRPCErrorError) {
|
||||
let outgoing_message = OutgoingMessage::Error(OutgoingError { id, error });
|
||||
let _ = self.sender.send(outgoing_message).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Outgoing message from the server to the client.
|
||||
pub(crate) enum OutgoingMessage {
|
||||
Request(OutgoingRequest),
|
||||
Notification(OutgoingNotification),
|
||||
Response(OutgoingResponse),
|
||||
Error(OutgoingError),
|
||||
}
|
||||
|
||||
impl From<OutgoingMessage> for JSONRPCMessage {
|
||||
fn from(val: OutgoingMessage) -> Self {
|
||||
use OutgoingMessage::*;
|
||||
match val {
|
||||
Request(OutgoingRequest { id, method, params }) => {
|
||||
JSONRPCMessage::Request(JSONRPCRequest {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id,
|
||||
method,
|
||||
params,
|
||||
})
|
||||
}
|
||||
Notification(OutgoingNotification { method, params }) => {
|
||||
JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
method,
|
||||
params,
|
||||
})
|
||||
}
|
||||
Response(OutgoingResponse { id, result }) => {
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id,
|
||||
result,
|
||||
})
|
||||
}
|
||||
Error(OutgoingError { id, error }) => JSONRPCMessage::Error(JSONRPCError {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id,
|
||||
error,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub(crate) struct OutgoingRequest {
|
||||
pub id: RequestId,
|
||||
pub method: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub(crate) struct OutgoingNotification {
|
||||
pub method: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub(crate) struct OutgoingResponse {
|
||||
pub id: RequestId,
|
||||
pub result: Result,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub(crate) struct OutgoingError {
|
||||
pub error: JSONRPCErrorError,
|
||||
pub id: RequestId,
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
Types for Model Context Protocol. Inspired by https://crates.io/crates/lsp-types.
|
||||
|
||||
As documented on https://modelcontextprotocol.io/specification/2025-06-18/basic:
|
||||
As documented on https://modelcontextprotocol.io/specification/2025-03-26/basic:
|
||||
|
||||
- TypeScript schema is the source of truth: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-06-18/schema.ts
|
||||
- JSON schema is amenable to automated tooling: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-06-18/schema.json
|
||||
- TypeScript schema is the source of truth: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.ts
|
||||
- JSON schema is amenable to automated tooling: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.json
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
# flake8: noqa: E501
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
@@ -14,13 +13,10 @@ from pathlib import Path
|
||||
# Helper first so it is defined when other functions call it.
|
||||
from typing import Any, Literal
|
||||
|
||||
SCHEMA_VERSION = "2025-06-18"
|
||||
SCHEMA_VERSION = "2025-03-26"
|
||||
JSONRPC_VERSION = "2.0"
|
||||
|
||||
STANDARD_DERIVE = "#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]\n"
|
||||
STANDARD_HASHABLE_DERIVE = (
|
||||
"#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Hash, Eq)]\n"
|
||||
)
|
||||
|
||||
# Will be populated with the schema's `definitions` map in `main()` so that
|
||||
# helper functions (for example `define_any_of`) can perform look-ups while
|
||||
@@ -30,27 +26,19 @@ DEFINITIONS: dict[str, Any] = {}
|
||||
CLIENT_REQUEST_TYPE_NAMES: list[str] = []
|
||||
# Concrete *Notification types that make up the ServerNotification enum.
|
||||
SERVER_NOTIFICATION_TYPE_NAMES: list[str] = []
|
||||
# Enum types that will need a `allow(clippy::large_enum_variant)` annotation in
|
||||
# order to compile without warnings.
|
||||
LARGE_ENUMS = {"ServerResult"}
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Embed, cluster and analyse text prompts via the OpenAI API.",
|
||||
)
|
||||
|
||||
default_schema_file = (
|
||||
Path(__file__).resolve().parent / "schema" / SCHEMA_VERSION / "schema.json"
|
||||
)
|
||||
parser.add_argument(
|
||||
"schema_file",
|
||||
nargs="?",
|
||||
default=default_schema_file,
|
||||
help="schema.json file to process",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
schema_file = args.schema_file
|
||||
num_args = len(sys.argv)
|
||||
if num_args == 1:
|
||||
schema_file = (
|
||||
Path(__file__).resolve().parent / "schema" / SCHEMA_VERSION / "schema.json"
|
||||
)
|
||||
elif num_args == 2:
|
||||
schema_file = Path(sys.argv[1])
|
||||
else:
|
||||
print("Usage: python3 codegen.py <schema.json>")
|
||||
return 1
|
||||
|
||||
lib_rs = Path(__file__).resolve().parent / "src/lib.rs"
|
||||
|
||||
@@ -209,8 +197,6 @@ def add_definition(name: str, definition: dict[str, Any], out: list[str]) -> Non
|
||||
if name.endswith("Result"):
|
||||
out.extend(f"impl From<{name}> for serde_json::Value {{\n")
|
||||
out.append(f" fn from(value: {name}) -> Self {{\n")
|
||||
out.append(" // Leave this as it should never fail\n")
|
||||
out.append(" #[expect(clippy::unwrap_used)]\n")
|
||||
out.append(" serde_json::to_value(value).unwrap()\n")
|
||||
out.append(" }\n")
|
||||
out.append("}\n\n")
|
||||
@@ -225,7 +211,20 @@ def add_definition(name: str, definition: dict[str, Any], out: list[str]) -> Non
|
||||
any_of = definition.get("anyOf", [])
|
||||
if any_of:
|
||||
assert isinstance(any_of, list)
|
||||
out.extend(define_any_of(name, any_of, description))
|
||||
if name == "JSONRPCMessage":
|
||||
# Special case for JSONRPCMessage because its definition in the
|
||||
# JSON schema does not quite match how we think about this type
|
||||
# definition in Rust.
|
||||
deep_copied_any_of = json.loads(json.dumps(any_of))
|
||||
deep_copied_any_of[2] = {
|
||||
"$ref": "#/definitions/JSONRPCBatchRequest",
|
||||
}
|
||||
deep_copied_any_of[5] = {
|
||||
"$ref": "#/definitions/JSONRPCBatchResponse",
|
||||
}
|
||||
out.extend(define_any_of(name, deep_copied_any_of, description))
|
||||
else:
|
||||
out.extend(define_any_of(name, any_of, description))
|
||||
return
|
||||
|
||||
type_prop = definition.get("type", None)
|
||||
@@ -394,7 +393,7 @@ def define_string_enum(
|
||||
|
||||
|
||||
def define_untagged_enum(name: str, type_list: list[str], out: list[str]) -> None:
|
||||
out.append(STANDARD_HASHABLE_DERIVE)
|
||||
out.append(STANDARD_DERIVE)
|
||||
out.append("#[serde(untagged)]\n")
|
||||
out.append(f"pub enum {name} {{\n")
|
||||
for simple_type in type_list:
|
||||
@@ -440,8 +439,6 @@ def define_any_of(
|
||||
if serde := get_serde_annotation_for_anyof_type(name):
|
||||
out.append(serde + "\n")
|
||||
|
||||
if name in LARGE_ENUMS:
|
||||
out.append("#[allow(clippy::large_enum_variant)]\n")
|
||||
out.append(f"pub enum {name} {{\n")
|
||||
|
||||
if name == "ClientRequest":
|
||||
@@ -599,8 +596,6 @@ def rust_prop_name(name: str, is_optional: bool) -> RustProp:
|
||||
prop_name = "r#type"
|
||||
elif name == "ref":
|
||||
prop_name = "r#ref"
|
||||
elif name == "enum":
|
||||
prop_name = "r#enum"
|
||||
elif snake_case := to_snake_case(name):
|
||||
prop_name = snake_case
|
||||
is_rename = True
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -10,7 +10,7 @@ use serde::Serialize;
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::convert::TryFrom;
|
||||
|
||||
pub const MCP_SCHEMA_VERSION: &str = "2025-06-18";
|
||||
pub const MCP_SCHEMA_VERSION: &str = "2025-03-26";
|
||||
pub const JSONRPC_VERSION: &str = "2.0";
|
||||
|
||||
/// Paired request/response types for the Model Context Protocol (MCP).
|
||||
@@ -35,12 +35,6 @@ fn default_jsonrpc() -> String {
|
||||
pub struct Annotations {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub audience: Option<Vec<Role>>,
|
||||
#[serde(
|
||||
rename = "lastModified",
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
pub last_modified: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub priority: Option<f64>,
|
||||
}
|
||||
@@ -56,14 +50,6 @@ pub struct AudioContent {
|
||||
pub r#type: String, // &'static str = "audio"
|
||||
}
|
||||
|
||||
/// Base interface for metadata with name (identifier) and title (display name) properties.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct BaseMetadata {
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct BlobResourceContents {
|
||||
pub blob: String,
|
||||
@@ -72,17 +58,6 @@ pub struct BlobResourceContents {
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct BooleanSchema {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub default: Option<bool>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub r#type: String, // &'static str = "boolean"
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum CallToolRequest {}
|
||||
|
||||
@@ -100,17 +75,29 @@ pub struct CallToolRequestParams {
|
||||
}
|
||||
|
||||
/// The server's response to a tool call.
|
||||
///
|
||||
/// Any errors that originate from the tool SHOULD be reported inside the result
|
||||
/// object, with `isError` set to true, _not_ as an MCP protocol-level error
|
||||
/// response. Otherwise, the LLM would not be able to see that an error occurred
|
||||
/// and self-correct.
|
||||
///
|
||||
/// However, any errors in _finding_ the tool, an error indicating that the
|
||||
/// server does not support tool calls, or any other exceptional conditions,
|
||||
/// should be reported as an MCP error response.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct CallToolResult {
|
||||
pub content: Vec<ContentBlock>,
|
||||
pub content: Vec<CallToolResultContent>,
|
||||
#[serde(rename = "isError", default, skip_serializing_if = "Option::is_none")]
|
||||
pub is_error: Option<bool>,
|
||||
#[serde(
|
||||
rename = "structuredContent",
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
pub structured_content: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum CallToolResultContent {
|
||||
TextContent(TextContent),
|
||||
ImageContent(ImageContent),
|
||||
AudioContent(AudioContent),
|
||||
EmbeddedResource(EmbeddedResource),
|
||||
}
|
||||
|
||||
impl From<CallToolResult> for serde_json::Value {
|
||||
@@ -140,8 +127,6 @@ pub struct CancelledNotificationParams {
|
||||
/// Capabilities a client may support. Known capabilities are defined here, in this schema, but this is not a closed set: any client can define its own, additional capabilities.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ClientCapabilities {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub elicitation: Option<serde_json::Value>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub experimental: Option<serde_json::Value>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
@@ -209,7 +194,6 @@ pub enum ClientResult {
|
||||
Result(Result),
|
||||
CreateMessageResult(CreateMessageResult),
|
||||
ListRootsResult(ListRootsResult),
|
||||
ElicitResult(ElicitResult),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
@@ -224,18 +208,9 @@ impl ModelContextProtocolRequest for CompleteRequest {
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct CompleteRequestParams {
|
||||
pub argument: CompleteRequestParamsArgument,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub context: Option<CompleteRequestParamsContext>,
|
||||
pub r#ref: CompleteRequestParamsRef,
|
||||
}
|
||||
|
||||
/// Additional, optional context for completions
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct CompleteRequestParamsContext {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub arguments: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// The argument's information
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct CompleteRequestParamsArgument {
|
||||
@@ -247,7 +222,7 @@ pub struct CompleteRequestParamsArgument {
|
||||
#[serde(untagged)]
|
||||
pub enum CompleteRequestParamsRef {
|
||||
PromptReference(PromptReference),
|
||||
ResourceTemplateReference(ResourceTemplateReference),
|
||||
ResourceReference(ResourceReference),
|
||||
}
|
||||
|
||||
/// The server's response to a completion/complete request
|
||||
@@ -273,16 +248,6 @@ impl From<CompleteResult> for serde_json::Value {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ContentBlock {
|
||||
TextContent(TextContent),
|
||||
ImageContent(ImageContent),
|
||||
AudioContent(AudioContent),
|
||||
ResourceLink(ResourceLink),
|
||||
EmbeddedResource(EmbeddedResource),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum CreateMessageRequest {}
|
||||
|
||||
@@ -360,48 +325,6 @@ impl From<CreateMessageResult> for serde_json::Value {
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct Cursor(String);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum ElicitRequest {}
|
||||
|
||||
impl ModelContextProtocolRequest for ElicitRequest {
|
||||
const METHOD: &'static str = "elicitation/create";
|
||||
type Params = ElicitRequestParams;
|
||||
type Result = ElicitResult;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ElicitRequestParams {
|
||||
pub message: String,
|
||||
#[serde(rename = "requestedSchema")]
|
||||
pub requested_schema: ElicitRequestParamsRequestedSchema,
|
||||
}
|
||||
|
||||
/// A restricted subset of JSON Schema.
|
||||
/// Only top-level properties are allowed, without nesting.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ElicitRequestParamsRequestedSchema {
|
||||
pub properties: serde_json::Value,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub required: Option<Vec<String>>,
|
||||
pub r#type: String, // &'static str = "object"
|
||||
}
|
||||
|
||||
/// The client's response to an elicitation request.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ElicitResult {
|
||||
pub action: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl From<ElicitResult> for serde_json::Value {
|
||||
fn from(value: ElicitResult) -> Self {
|
||||
// Leave this as it should never fail
|
||||
#[expect(clippy::unwrap_used)]
|
||||
serde_json::to_value(value).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// The contents of a resource, embedded into a prompt or tool call result.
|
||||
///
|
||||
/// It is up to the client how best to render embedded resources for the benefit
|
||||
@@ -423,18 +346,6 @@ pub enum EmbeddedResourceResource {
|
||||
|
||||
pub type EmptyResult = Result;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct EnumSchema {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
pub r#enum: Vec<String>,
|
||||
#[serde(rename = "enumNames", default, skip_serializing_if = "Option::is_none")]
|
||||
pub enum_names: Option<Vec<String>>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub r#type: String, // &'static str = "string"
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum GetPromptRequest {}
|
||||
|
||||
@@ -478,12 +389,10 @@ pub struct ImageContent {
|
||||
pub r#type: String, // &'static str = "image"
|
||||
}
|
||||
|
||||
/// Describes the name and version of an MCP implementation, with an optional title for UI representation.
|
||||
/// Describes the name and version of an MCP implementation.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct Implementation {
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub version: String,
|
||||
}
|
||||
|
||||
@@ -533,6 +442,24 @@ impl ModelContextProtocolNotification for InitializedNotification {
|
||||
type Params = Option<serde_json::Value>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum JSONRPCBatchRequestItem {
|
||||
JSONRPCRequest(JSONRPCRequest),
|
||||
JSONRPCNotification(JSONRPCNotification),
|
||||
}
|
||||
|
||||
pub type JSONRPCBatchRequest = Vec<JSONRPCBatchRequestItem>;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum JSONRPCBatchResponseItem {
|
||||
JSONRPCResponse(JSONRPCResponse),
|
||||
JSONRPCError(JSONRPCError),
|
||||
}
|
||||
|
||||
pub type JSONRPCBatchResponse = Vec<JSONRPCBatchResponseItem>;
|
||||
|
||||
/// A response to a request that indicates an error occurred.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct JSONRPCError {
|
||||
@@ -556,8 +483,10 @@ pub struct JSONRPCErrorError {
|
||||
pub enum JSONRPCMessage {
|
||||
Request(JSONRPCRequest),
|
||||
Notification(JSONRPCNotification),
|
||||
BatchRequest(JSONRPCBatchRequest),
|
||||
Response(JSONRPCResponse),
|
||||
Error(JSONRPCError),
|
||||
BatchResponse(JSONRPCBatchResponse),
|
||||
}
|
||||
|
||||
/// A notification which does not expect a response.
|
||||
@@ -848,19 +777,6 @@ pub struct Notification {
|
||||
pub params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct NumberSchema {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub maximum: Option<i64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub minimum: Option<i64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub r#type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct PaginatedRequest {
|
||||
pub method: String,
|
||||
@@ -901,17 +817,6 @@ impl ModelContextProtocolRequest for PingRequest {
|
||||
type Result = Result;
|
||||
}
|
||||
|
||||
/// Restricted schema definitions that only allow primitive types
|
||||
/// without nested objects or arrays.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum PrimitiveSchemaDefinition {
|
||||
StringSchema(StringSchema),
|
||||
NumberSchema(NumberSchema),
|
||||
BooleanSchema(BooleanSchema),
|
||||
EnumSchema(EnumSchema),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum ProgressNotification {}
|
||||
|
||||
@@ -931,7 +836,7 @@ pub struct ProgressNotificationParams {
|
||||
pub total: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Hash, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ProgressToken {
|
||||
String(String),
|
||||
@@ -946,8 +851,6 @@ pub struct Prompt {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
}
|
||||
|
||||
/// Describes an argument that a prompt can accept.
|
||||
@@ -958,8 +861,6 @@ pub struct PromptArgument {
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub required: Option<bool>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
@@ -976,16 +877,23 @@ impl ModelContextProtocolNotification for PromptListChangedNotification {
|
||||
/// resources from the MCP server.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct PromptMessage {
|
||||
pub content: ContentBlock,
|
||||
pub content: PromptMessageContent,
|
||||
pub role: Role,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum PromptMessageContent {
|
||||
TextContent(TextContent),
|
||||
ImageContent(ImageContent),
|
||||
AudioContent(AudioContent),
|
||||
EmbeddedResource(EmbeddedResource),
|
||||
}
|
||||
|
||||
/// Identifies a prompt.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct PromptReference {
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub r#type: String, // &'static str = "ref/prompt"
|
||||
}
|
||||
|
||||
@@ -1031,7 +939,7 @@ pub struct Request {
|
||||
pub params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Hash, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum RequestId {
|
||||
String(String),
|
||||
@@ -1050,8 +958,6 @@ pub struct Resource {
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub size: Option<i64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
@@ -1063,26 +969,6 @@ pub struct ResourceContents {
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
/// A resource that the server is capable of reading, included in a prompt or tool call result.
|
||||
///
|
||||
/// Note: resource links returned by tools are not guaranteed to appear in the results of `resources/list` requests.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ResourceLink {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub annotations: Option<Annotations>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(rename = "mimeType", default, skip_serializing_if = "Option::is_none")]
|
||||
pub mime_type: Option<String>,
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub size: Option<i64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub r#type: String, // &'static str = "resource_link"
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum ResourceListChangedNotification {}
|
||||
|
||||
@@ -1091,6 +977,13 @@ impl ModelContextProtocolNotification for ResourceListChangedNotification {
|
||||
type Params = Option<serde_json::Value>;
|
||||
}
|
||||
|
||||
/// A reference to a resource or resource template definition.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ResourceReference {
|
||||
pub r#type: String, // &'static str = "ref/resource"
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
/// A template description for resources available on the server.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ResourceTemplate {
|
||||
@@ -1101,19 +994,10 @@ pub struct ResourceTemplate {
|
||||
#[serde(rename = "mimeType", default, skip_serializing_if = "Option::is_none")]
|
||||
pub mime_type: Option<String>,
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
#[serde(rename = "uriTemplate")]
|
||||
pub uri_template: String,
|
||||
}
|
||||
|
||||
/// A reference to a resource or resource template definition.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ResourceTemplateReference {
|
||||
pub r#type: String, // &'static str = "ref/resource"
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum ResourceUpdatedNotification {}
|
||||
|
||||
@@ -1256,7 +1140,6 @@ pub enum ServerRequest {
|
||||
PingRequest(PingRequest),
|
||||
CreateMessageRequest(CreateMessageRequest),
|
||||
ListRootsRequest(ListRootsRequest),
|
||||
ElicitRequest(ElicitRequest),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
@@ -1289,21 +1172,6 @@ pub struct SetLevelRequestParams {
|
||||
pub level: LoggingLevel,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct StringSchema {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub format: Option<String>,
|
||||
#[serde(rename = "maxLength", default, skip_serializing_if = "Option::is_none")]
|
||||
pub max_length: Option<i64>,
|
||||
#[serde(rename = "minLength", default, skip_serializing_if = "Option::is_none")]
|
||||
pub min_length: Option<i64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub r#type: String, // &'static str = "string"
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum SubscribeRequest {}
|
||||
|
||||
@@ -1345,25 +1213,6 @@ pub struct Tool {
|
||||
#[serde(rename = "inputSchema")]
|
||||
pub input_schema: ToolInputSchema,
|
||||
pub name: String,
|
||||
#[serde(
|
||||
rename = "outputSchema",
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
pub output_schema: Option<ToolOutputSchema>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
}
|
||||
|
||||
/// An optional JSON Schema object defining the structure of the tool's output returned in
|
||||
/// the structuredContent field of a CallToolResult.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ToolOutputSchema {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub properties: Option<serde_json::Value>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub required: Option<Vec<String>>,
|
||||
pub r#type: String, // &'static str = "object"
|
||||
}
|
||||
|
||||
/// A JSON Schema object defining the expected parameters for the tool.
|
||||
|
||||
@@ -17,8 +17,8 @@ fn deserialize_initialize_request() {
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"capabilities": {},
|
||||
"clientInfo": { "name": "acme-client", "title": "Acme", "version": "1.2.3" },
|
||||
"protocolVersion": "2025-06-18"
|
||||
"clientInfo": { "name": "acme-client", "version": "1.2.3" },
|
||||
"protocolVersion": "2025-03-26"
|
||||
}
|
||||
}"#;
|
||||
|
||||
@@ -37,8 +37,8 @@ fn deserialize_initialize_request() {
|
||||
method: "initialize".into(),
|
||||
params: Some(json!({
|
||||
"capabilities": {},
|
||||
"clientInfo": { "name": "acme-client", "title": "Acme", "version": "1.2.3" },
|
||||
"protocolVersion": "2025-06-18"
|
||||
"clientInfo": { "name": "acme-client", "version": "1.2.3" },
|
||||
"protocolVersion": "2025-03-26"
|
||||
})),
|
||||
};
|
||||
|
||||
@@ -57,14 +57,12 @@ fn deserialize_initialize_request() {
|
||||
experimental: None,
|
||||
roots: None,
|
||||
sampling: None,
|
||||
elicitation: None,
|
||||
},
|
||||
client_info: Implementation {
|
||||
name: "acme-client".into(),
|
||||
title: Some("Acme".to_string()),
|
||||
version: "1.2.3".into(),
|
||||
},
|
||||
protocol_version: "2025-06-18".into(),
|
||||
protocol_version: "2025-03-26".into(),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
[toolchain]
|
||||
channel = "1.88.0"
|
||||
components = [ "clippy", "rustfmt", "rust-src"]
|
||||
@@ -18,16 +18,8 @@ use crossterm::event::KeyEvent;
|
||||
use crossterm::event::MouseEvent;
|
||||
use crossterm::event::MouseEventKind;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::mpsc::Receiver;
|
||||
use std::sync::mpsc::channel;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Time window for debouncing redraw requests.
|
||||
const REDRAW_DEBOUNCE: Duration = Duration::from_millis(10);
|
||||
|
||||
/// Top-level application state: which full-screen view is currently active.
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
@@ -54,9 +46,6 @@ pub(crate) struct App<'a> {
|
||||
|
||||
file_search: FileSearchManager,
|
||||
|
||||
/// True when a redraw has been scheduled but not yet executed.
|
||||
pending_redraw: Arc<AtomicBool>,
|
||||
|
||||
/// Stored parameters needed to instantiate the ChatWidget later, e.g.,
|
||||
/// after dismissing the Git-repo warning.
|
||||
chat_args: Option<ChatWidgetArgs>,
|
||||
@@ -71,7 +60,7 @@ struct ChatWidgetArgs {
|
||||
initial_images: Vec<PathBuf>,
|
||||
}
|
||||
|
||||
impl App<'_> {
|
||||
impl<'a> App<'a> {
|
||||
pub(crate) fn new(
|
||||
config: Config,
|
||||
initial_prompt: Option<String>,
|
||||
@@ -81,7 +70,6 @@ impl App<'_> {
|
||||
) -> Self {
|
||||
let (app_event_tx, app_event_rx) = channel();
|
||||
let app_event_tx = AppEventSender::new(app_event_tx);
|
||||
let pending_redraw = Arc::new(AtomicBool::new(false));
|
||||
let scroll_event_helper = ScrollEventHelper::new(app_event_tx.clone());
|
||||
|
||||
// Spawn a dedicated thread for reading the crossterm event loop and
|
||||
@@ -95,7 +83,7 @@ impl App<'_> {
|
||||
app_event_tx.send(AppEvent::KeyEvent(key_event));
|
||||
}
|
||||
crossterm::event::Event::Resize(_, _) => {
|
||||
app_event_tx.send(AppEvent::RequestRedraw);
|
||||
app_event_tx.send(AppEvent::Redraw);
|
||||
}
|
||||
crossterm::event::Event::Mouse(MouseEvent {
|
||||
kind: MouseEventKind::ScrollUp,
|
||||
@@ -164,7 +152,6 @@ impl App<'_> {
|
||||
app_state,
|
||||
config,
|
||||
file_search,
|
||||
pending_redraw,
|
||||
chat_args,
|
||||
}
|
||||
}
|
||||
@@ -175,28 +162,6 @@ impl App<'_> {
|
||||
self.app_event_tx.clone()
|
||||
}
|
||||
|
||||
/// Schedule a redraw if one is not already pending.
|
||||
#[allow(clippy::unwrap_used)]
|
||||
fn schedule_redraw(&self) {
|
||||
// Attempt to set the flag to `true`. If it was already `true`, another
|
||||
// redraw is already pending so we can return early.
|
||||
if self
|
||||
.pending_redraw
|
||||
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
let tx = self.app_event_tx.clone();
|
||||
let pending_redraw = self.pending_redraw.clone();
|
||||
thread::spawn(move || {
|
||||
thread::sleep(REDRAW_DEBOUNCE);
|
||||
tx.send(AppEvent::Redraw);
|
||||
pending_redraw.store(false, Ordering::SeqCst);
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) fn run(
|
||||
&mut self,
|
||||
terminal: &mut tui::Tui,
|
||||
@@ -204,13 +169,10 @@ impl App<'_> {
|
||||
) -> Result<()> {
|
||||
// Insert an event to trigger the first render.
|
||||
let app_event_tx = self.app_event_tx.clone();
|
||||
app_event_tx.send(AppEvent::RequestRedraw);
|
||||
app_event_tx.send(AppEvent::Redraw);
|
||||
|
||||
while let Ok(event) = self.app_event_rx.recv() {
|
||||
match event {
|
||||
AppEvent::RequestRedraw => {
|
||||
self.schedule_redraw();
|
||||
}
|
||||
AppEvent::Redraw => {
|
||||
self.draw_next_frame(terminal)?;
|
||||
}
|
||||
@@ -237,21 +199,7 @@ impl App<'_> {
|
||||
modifiers: crossterm::event::KeyModifiers::CONTROL,
|
||||
..
|
||||
} => {
|
||||
match &mut self.app_state {
|
||||
AppState::Chat { widget } => {
|
||||
if widget.composer_is_empty() {
|
||||
self.app_event_tx.send(AppEvent::ExitRequest);
|
||||
} else {
|
||||
// Treat Ctrl+D as a normal key event when the composer
|
||||
// is not empty so that it doesn't quit the application
|
||||
// prematurely.
|
||||
self.dispatch_key_event(key_event);
|
||||
}
|
||||
}
|
||||
AppState::Login { .. } | AppState::GitWarning { .. } => {
|
||||
self.app_event_tx.send(AppEvent::ExitRequest);
|
||||
}
|
||||
}
|
||||
self.app_event_tx.send(AppEvent::ExitRequest);
|
||||
}
|
||||
_ => {
|
||||
self.dispatch_key_event(key_event);
|
||||
@@ -287,7 +235,7 @@ impl App<'_> {
|
||||
Vec::new(),
|
||||
));
|
||||
self.app_state = AppState::Chat { widget: new_widget };
|
||||
self.app_event_tx.send(AppEvent::RequestRedraw);
|
||||
self.app_event_tx.send(AppEvent::Redraw);
|
||||
}
|
||||
SlashCommand::ToggleMouseMode => {
|
||||
if let Err(e) = mouse_capture.toggle() {
|
||||
@@ -335,8 +283,6 @@ impl App<'_> {
|
||||
}
|
||||
|
||||
fn draw_next_frame(&mut self, terminal: &mut tui::Tui) -> Result<()> {
|
||||
// TODO: add a throttle to avoid redrawing too often
|
||||
|
||||
match &mut self.app_state {
|
||||
AppState::Chat { widget } => {
|
||||
terminal.draw(|frame| frame.render_widget_ref(&**widget, frame.area()))?;
|
||||
@@ -374,7 +320,7 @@ impl App<'_> {
|
||||
args.initial_images,
|
||||
));
|
||||
self.app_state = AppState::Chat { widget };
|
||||
self.app_event_tx.send(AppEvent::RequestRedraw);
|
||||
self.app_event_tx.send(AppEvent::Redraw);
|
||||
}
|
||||
GitWarningOutcome::Quit => {
|
||||
self.app_event_tx.send(AppEvent::ExitRequest);
|
||||
|
||||
@@ -8,10 +8,6 @@ use crate::slash_command::SlashCommand;
|
||||
pub(crate) enum AppEvent {
|
||||
CodexEvent(Event),
|
||||
|
||||
/// Request a redraw which will be debounced by the [`App`].
|
||||
RequestRedraw,
|
||||
|
||||
/// Actually draw the next frame.
|
||||
Redraw,
|
||||
|
||||
KeyEvent(KeyEvent),
|
||||
|
||||
@@ -76,11 +76,6 @@ impl ChatComposer<'_> {
|
||||
this
|
||||
}
|
||||
|
||||
/// Returns true if the composer currently contains no user input.
|
||||
pub(crate) fn is_empty(&self) -> bool {
|
||||
self.textarea.is_empty()
|
||||
}
|
||||
|
||||
/// Update the cached *context-left* percentage and refresh the placeholder
|
||||
/// text. The UI relies on the placeholder to convey the remaining
|
||||
/// context when the composer is empty.
|
||||
|
||||
@@ -72,7 +72,8 @@ impl ChatComposerHistory {
|
||||
return false;
|
||||
}
|
||||
|
||||
if textarea.is_empty() {
|
||||
let lines = textarea.lines();
|
||||
if lines.len() == 1 && lines[0].is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -84,7 +85,6 @@ impl ChatComposerHistory {
|
||||
return false;
|
||||
}
|
||||
|
||||
let lines = textarea.lines();
|
||||
matches!(&self.last_history_text, Some(prev) if prev == &lines.join("\n"))
|
||||
}
|
||||
|
||||
|
||||
@@ -162,10 +162,6 @@ impl BottomPane<'_> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn composer_is_empty(&self) -> bool {
|
||||
self.composer.is_empty()
|
||||
}
|
||||
|
||||
pub(crate) fn is_task_running(&self) -> bool {
|
||||
self.is_task_running
|
||||
}
|
||||
@@ -212,7 +208,7 @@ impl BottomPane<'_> {
|
||||
}
|
||||
|
||||
pub(crate) fn request_redraw(&self) {
|
||||
self.app_event_tx.send(AppEvent::RequestRedraw)
|
||||
self.app_event_tx.send(AppEvent::Redraw)
|
||||
}
|
||||
|
||||
/// Returns true when a popup inside the composer is visible.
|
||||
|
||||
@@ -24,7 +24,7 @@ impl StatusIndicatorView {
|
||||
}
|
||||
}
|
||||
|
||||
impl BottomPaneView<'_> for StatusIndicatorView {
|
||||
impl<'a> BottomPaneView<'a> for StatusIndicatorView {
|
||||
fn update_status_text(&mut self, text: String) -> ConditionalUpdate {
|
||||
self.update_text(text);
|
||||
ConditionalUpdate::NeedsRedraw
|
||||
|
||||
@@ -3,9 +3,7 @@ use std::sync::Arc;
|
||||
|
||||
use codex_core::codex_wrapper::init_codex;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::protocol::AgentMessageDeltaEvent;
|
||||
use codex_core::protocol::AgentMessageEvent;
|
||||
use codex_core::protocol::AgentReasoningDeltaEvent;
|
||||
use codex_core::protocol::AgentReasoningEvent;
|
||||
use codex_core::protocol::ApplyPatchApprovalRequestEvent;
|
||||
use codex_core::protocol::ErrorEvent;
|
||||
@@ -42,6 +40,15 @@ use crate::history_cell::PatchEventType;
|
||||
use crate::user_approval_widget::ApprovalRequest;
|
||||
use codex_file_search::FileMatch;
|
||||
|
||||
/// Bookkeeping for a live streaming cell. We track the `sub_id` to know when
|
||||
/// a new turn has started (and thus when to start a new cell) and accumulate
|
||||
/// the full text so we can re-render markdown cleanly when the turn ends.
|
||||
#[derive(Default)]
|
||||
struct StreamingBuf {
|
||||
sub_id: Option<String>,
|
||||
text: String,
|
||||
}
|
||||
|
||||
pub(crate) struct ChatWidget<'a> {
|
||||
app_event_tx: AppEventSender,
|
||||
codex_op_tx: UnboundedSender<Op>,
|
||||
@@ -51,8 +58,10 @@ pub(crate) struct ChatWidget<'a> {
|
||||
config: Config,
|
||||
initial_user_message: Option<UserMessage>,
|
||||
token_usage: TokenUsage,
|
||||
reasoning_buffer: String,
|
||||
answer_buffer: String,
|
||||
/// Accumulates assistant streaming text for the *current* turn.
|
||||
streaming_agent: StreamingBuf,
|
||||
/// Accumulates reasoning streaming text for the *current* turn.
|
||||
streaming_reasoning: StreamingBuf,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Eq, PartialEq)]
|
||||
@@ -139,8 +148,8 @@ impl ChatWidget<'_> {
|
||||
initial_images,
|
||||
),
|
||||
token_usage: TokenUsage::default(),
|
||||
reasoning_buffer: String::new(),
|
||||
answer_buffer: String::new(),
|
||||
streaming_agent: StreamingBuf::default(),
|
||||
streaming_reasoning: StreamingBuf::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -226,6 +235,8 @@ impl ChatWidget<'_> {
|
||||
|
||||
pub(crate) fn handle_codex_event(&mut self, event: Event) {
|
||||
let Event { id, msg } = event;
|
||||
// We need a copy of `id` for streaming bookkeeping because it is moved into some match arms.
|
||||
let event_id = id.clone();
|
||||
match msg {
|
||||
EventMsg::SessionConfigured(event) => {
|
||||
// Record session information at the top of the conversation.
|
||||
@@ -246,51 +257,113 @@ impl ChatWidget<'_> {
|
||||
self.request_redraw();
|
||||
}
|
||||
EventMsg::AgentMessage(AgentMessageEvent { message }) => {
|
||||
// if the answer buffer is empty, this means we haven't received any
|
||||
// delta. Thus, we need to print the message as a new answer.
|
||||
if self.answer_buffer.is_empty() {
|
||||
self.conversation_history
|
||||
.add_agent_message(&self.config, message);
|
||||
if self.config.streaming_enabled {
|
||||
// Final full assistant message. If we have an in-flight streaming cell for this id, replace it.
|
||||
let same_turn = self
|
||||
.streaming_agent
|
||||
.sub_id
|
||||
.as_ref()
|
||||
.map(|s| s == &event_id)
|
||||
.unwrap_or(false);
|
||||
if same_turn {
|
||||
self.conversation_history
|
||||
.replace_last_agent_message(&self.config, message.clone());
|
||||
self.streaming_agent.sub_id = None;
|
||||
self.streaming_agent.text.clear();
|
||||
} else {
|
||||
// Streaming enabled but we never saw deltas – just render normally.
|
||||
self.finalize_streams_if_new_turn(&event_id);
|
||||
self.conversation_history
|
||||
.add_agent_message(&self.config, message.clone());
|
||||
}
|
||||
} else {
|
||||
// Streaming disabled -> always render final message, ignore any deltas.
|
||||
self.conversation_history
|
||||
.replace_prev_agent_message(&self.config, message);
|
||||
.add_agent_message(&self.config, message.clone());
|
||||
}
|
||||
self.answer_buffer.clear();
|
||||
self.request_redraw();
|
||||
}
|
||||
EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }) => {
|
||||
if self.answer_buffer.is_empty() {
|
||||
self.conversation_history
|
||||
.add_agent_message(&self.config, "".to_string());
|
||||
EventMsg::AgentMessageDelta(AgentMessageEvent { message }) => {
|
||||
// Streaming Assistant text.
|
||||
if !self.config.streaming_enabled {
|
||||
// Ignore when streaming disabled.
|
||||
return;
|
||||
}
|
||||
self.answer_buffer.push_str(&delta.clone());
|
||||
self.conversation_history
|
||||
.replace_prev_agent_message(&self.config, self.answer_buffer.clone());
|
||||
self.request_redraw();
|
||||
}
|
||||
EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta }) => {
|
||||
if self.reasoning_buffer.is_empty() {
|
||||
// Start a new cell if this delta belongs to a new turn.
|
||||
let is_new_stream = self
|
||||
.streaming_agent
|
||||
.sub_id
|
||||
.as_ref()
|
||||
.map(|s| s != &event_id)
|
||||
.unwrap_or(true);
|
||||
if is_new_stream {
|
||||
// Finalise any in-flight stream from the prior turn.
|
||||
self.finalize_streams_if_new_turn(&event_id);
|
||||
// Start a header-only streaming cell so we don't parse partial markdown.
|
||||
self.conversation_history
|
||||
.add_agent_reasoning(&self.config, "".to_string());
|
||||
.add_agent_message(&self.config, String::new());
|
||||
self.streaming_agent.sub_id = Some(event_id.clone());
|
||||
self.streaming_agent.text.clear();
|
||||
}
|
||||
self.reasoning_buffer.push_str(&delta.clone());
|
||||
// Accumulate full text; incremental markdown re-render happens in ConversationHistoryWidget.
|
||||
self.streaming_agent.text.push_str(&message);
|
||||
self.conversation_history
|
||||
.replace_prev_agent_reasoning(&self.config, self.reasoning_buffer.clone());
|
||||
.append_agent_message_delta(&self.config, message);
|
||||
self.request_redraw();
|
||||
}
|
||||
EventMsg::AgentReasoning(AgentReasoningEvent { text }) => {
|
||||
// if the reasoning buffer is empty, this means we haven't received any
|
||||
// delta. Thus, we need to print the message as a new reasoning.
|
||||
if self.reasoning_buffer.is_empty() {
|
||||
self.conversation_history
|
||||
.add_agent_reasoning(&self.config, "".to_string());
|
||||
} else {
|
||||
// else, we rerender one last time.
|
||||
self.conversation_history
|
||||
.replace_prev_agent_reasoning(&self.config, text);
|
||||
if !self.config.hide_agent_reasoning {
|
||||
if self.config.streaming_enabled {
|
||||
// Final full reasoning summary. Replace streaming cell if same turn.
|
||||
let same_turn = self
|
||||
.streaming_reasoning
|
||||
.sub_id
|
||||
.as_ref()
|
||||
.map(|s| s == &event_id)
|
||||
.unwrap_or(false);
|
||||
if same_turn {
|
||||
self.conversation_history
|
||||
.replace_last_agent_reasoning(&self.config, text.clone());
|
||||
self.streaming_reasoning.sub_id = None;
|
||||
self.streaming_reasoning.text.clear();
|
||||
} else {
|
||||
self.finalize_streams_if_new_turn(&event_id);
|
||||
self.conversation_history
|
||||
.add_agent_reasoning(&self.config, text.clone());
|
||||
}
|
||||
} else {
|
||||
self.conversation_history
|
||||
.add_agent_reasoning(&self.config, text.clone());
|
||||
}
|
||||
self.request_redraw();
|
||||
}
|
||||
}
|
||||
EventMsg::AgentReasoningDelta(AgentReasoningEvent { text }) => {
|
||||
if !self.config.hide_agent_reasoning {
|
||||
if !self.config.streaming_enabled {
|
||||
// Ignore when streaming disabled.
|
||||
return;
|
||||
}
|
||||
let is_new_stream = self
|
||||
.streaming_reasoning
|
||||
.sub_id
|
||||
.as_ref()
|
||||
.map(|s| s != &event_id)
|
||||
.unwrap_or(true);
|
||||
if is_new_stream {
|
||||
self.finalize_streams_if_new_turn(&event_id);
|
||||
// Start header-only streaming cell.
|
||||
self.conversation_history
|
||||
.add_agent_reasoning(&self.config, String::new());
|
||||
self.streaming_reasoning.sub_id = Some(event_id.clone());
|
||||
self.streaming_reasoning.text.clear();
|
||||
}
|
||||
// Accumulate full text; incremental markdown re-render happens in ConversationHistoryWidget.
|
||||
self.streaming_reasoning.text.push_str(&text);
|
||||
self.conversation_history
|
||||
.append_agent_reasoning_delta(&self.config, text);
|
||||
self.request_redraw();
|
||||
}
|
||||
self.reasoning_buffer.clear();
|
||||
self.request_redraw();
|
||||
}
|
||||
EventMsg::TaskStarted => {
|
||||
self.bottom_pane.clear_ctrl_c_quit_hint();
|
||||
@@ -300,6 +373,8 @@ impl ChatWidget<'_> {
|
||||
EventMsg::TaskComplete(TaskCompleteEvent {
|
||||
last_agent_message: _,
|
||||
}) => {
|
||||
// Turn has ended – ensure no lingering streaming cells remain un-finalised.
|
||||
self.finalize_streams();
|
||||
self.bottom_pane.set_task_running(false);
|
||||
self.request_redraw();
|
||||
}
|
||||
@@ -431,7 +506,7 @@ impl ChatWidget<'_> {
|
||||
}
|
||||
|
||||
fn request_redraw(&mut self) {
|
||||
self.app_event_tx.send(AppEvent::RequestRedraw);
|
||||
self.app_event_tx.send(AppEvent::Redraw);
|
||||
}
|
||||
|
||||
pub(crate) fn add_diff_output(&mut self, diff_output: String) {
|
||||
@@ -464,8 +539,6 @@ impl ChatWidget<'_> {
|
||||
if self.bottom_pane.is_task_running() {
|
||||
self.bottom_pane.clear_ctrl_c_quit_hint();
|
||||
self.submit_op(Op::Interrupt);
|
||||
self.answer_buffer.clear();
|
||||
self.reasoning_buffer.clear();
|
||||
false
|
||||
} else if self.bottom_pane.ctrl_c_quit_hint_visible() {
|
||||
true
|
||||
@@ -475,16 +548,48 @@ impl ChatWidget<'_> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn composer_is_empty(&self) -> bool {
|
||||
self.bottom_pane.composer_is_empty()
|
||||
}
|
||||
|
||||
/// Forward an `Op` directly to codex.
|
||||
pub(crate) fn submit_op(&self, op: Op) {
|
||||
if let Err(e) = self.codex_op_tx.send(op) {
|
||||
tracing::error!("failed to submit op: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
/// Finalise (render) streaming buffers when we detect a new turn id.
|
||||
fn finalize_streams_if_new_turn(&mut self, new_id: &str) {
|
||||
// If the incoming id differs from the current stream id(s) we must flush.
|
||||
let agent_changed = self
|
||||
.streaming_agent
|
||||
.sub_id
|
||||
.as_ref()
|
||||
.map(|s| s != new_id)
|
||||
.unwrap_or(false);
|
||||
let reasoning_changed = self
|
||||
.streaming_reasoning
|
||||
.sub_id
|
||||
.as_ref()
|
||||
.map(|s| s != new_id)
|
||||
.unwrap_or(false);
|
||||
if agent_changed || reasoning_changed {
|
||||
self.finalize_streams();
|
||||
}
|
||||
}
|
||||
|
||||
/// Re-render any in-flight streaming cells with full markdown and clear buffers.
|
||||
fn finalize_streams(&mut self) {
|
||||
let had_agent = self.streaming_agent.sub_id.take().is_some();
|
||||
if had_agent {
|
||||
let text = std::mem::take(&mut self.streaming_agent.text);
|
||||
self.conversation_history
|
||||
.replace_last_agent_message(&self.config, text);
|
||||
}
|
||||
let had_reasoning = self.streaming_reasoning.sub_id.take().is_some();
|
||||
if had_reasoning {
|
||||
let text = std::mem::take(&mut self.streaming_reasoning.text);
|
||||
self.conversation_history
|
||||
.replace_last_agent_reasoning(&self.config, text);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WidgetRef for &ChatWidget<'_> {
|
||||
|
||||
@@ -9,9 +9,9 @@ use crossterm::event::KeyCode;
|
||||
use crossterm::event::KeyEvent;
|
||||
use ratatui::prelude::*;
|
||||
use ratatui::style::Style;
|
||||
use ratatui::text::Span;
|
||||
use ratatui::widgets::*;
|
||||
use serde_json::Value as JsonValue;
|
||||
use std::cell::Cell as StdCell;
|
||||
use std::cell::Cell;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
@@ -26,24 +26,30 @@ pub struct ConversationHistoryWidget {
|
||||
entries: Vec<Entry>,
|
||||
/// The width (in terminal cells/columns) that [`Entry::line_count`] was
|
||||
/// computed for. When the available width changes we recompute counts.
|
||||
cached_width: StdCell<u16>,
|
||||
cached_width: Cell<u16>,
|
||||
scroll_position: usize,
|
||||
/// Number of lines the last time render_ref() was called
|
||||
num_rendered_lines: StdCell<usize>,
|
||||
num_rendered_lines: Cell<usize>,
|
||||
/// The height of the viewport last time render_ref() was called
|
||||
last_viewport_height: StdCell<usize>,
|
||||
last_viewport_height: Cell<usize>,
|
||||
has_input_focus: bool,
|
||||
/// Scratch buffer used while incrementally streaming an agent message so we can re-render markdown at newline boundaries.
|
||||
streaming_agent_message_buf: String,
|
||||
/// Scratch buffer used while incrementally streaming agent reasoning so we can re-render markdown at newline boundaries.
|
||||
streaming_agent_reasoning_buf: String,
|
||||
}
|
||||
|
||||
impl ConversationHistoryWidget {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
entries: Vec::new(),
|
||||
cached_width: StdCell::new(0),
|
||||
cached_width: Cell::new(0),
|
||||
scroll_position: usize::MAX,
|
||||
num_rendered_lines: StdCell::new(0),
|
||||
last_viewport_height: StdCell::new(0),
|
||||
num_rendered_lines: Cell::new(0),
|
||||
last_viewport_height: Cell::new(0),
|
||||
has_input_focus: false,
|
||||
streaming_agent_message_buf: String::new(),
|
||||
streaming_agent_reasoning_buf: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,38 +90,26 @@ impl ConversationHistoryWidget {
|
||||
}
|
||||
|
||||
fn scroll_up(&mut self, num_lines: u32) {
|
||||
// If a user is scrolling up from the "stick to bottom" mode, we need to
|
||||
// map this to a specific scroll position so we can calculate the delta.
|
||||
// This requires us to care about how tall the screen is.
|
||||
// Convert sticky-to-bottom sentinel into a concrete offset anchored at the bottom.
|
||||
if self.scroll_position == usize::MAX {
|
||||
self.scroll_position = self
|
||||
.num_rendered_lines
|
||||
.get()
|
||||
.saturating_sub(self.last_viewport_height.get());
|
||||
self.scroll_position = sticky_offset(
|
||||
self.num_rendered_lines.get(),
|
||||
self.last_viewport_height.get(),
|
||||
);
|
||||
}
|
||||
|
||||
self.scroll_position = self.scroll_position.saturating_sub(num_lines as usize);
|
||||
}
|
||||
|
||||
fn scroll_down(&mut self, num_lines: u32) {
|
||||
// If we're already pinned to the bottom there's nothing to do.
|
||||
// Nothing to do if we're already pinned to the bottom.
|
||||
if self.scroll_position == usize::MAX {
|
||||
return;
|
||||
}
|
||||
|
||||
let viewport_height = self.last_viewport_height.get().max(1);
|
||||
let num_rendered_lines = self.num_rendered_lines.get();
|
||||
|
||||
// Compute the maximum explicit scroll offset that still shows a full
|
||||
// viewport. This mirrors the calculation in `scroll_page_down()` and
|
||||
// in the render path.
|
||||
let max_scroll = num_rendered_lines.saturating_sub(viewport_height);
|
||||
|
||||
let max_scroll = sticky_offset(self.num_rendered_lines.get(), viewport_height);
|
||||
let new_pos = self.scroll_position.saturating_add(num_lines as usize);
|
||||
|
||||
if new_pos >= max_scroll {
|
||||
// Reached (or passed) the bottom – switch to stick‑to‑bottom mode
|
||||
// so that additional output keeps the view pinned automatically.
|
||||
// Switch to sticky-bottom mode so subsequent output pins view.
|
||||
self.scroll_position = usize::MAX;
|
||||
} else {
|
||||
self.scroll_position = new_pos;
|
||||
@@ -125,44 +119,21 @@ impl ConversationHistoryWidget {
|
||||
/// Scroll up by one full viewport height (Page Up).
|
||||
fn scroll_page_up(&mut self) {
|
||||
let viewport_height = self.last_viewport_height.get().max(1);
|
||||
|
||||
// If we are currently in the "stick to bottom" mode, first convert the
|
||||
// implicit scroll position (`usize::MAX`) into an explicit offset that
|
||||
// represents the very bottom of the scroll region. This mirrors the
|
||||
// logic from `scroll_up()`.
|
||||
if self.scroll_position == usize::MAX {
|
||||
self.scroll_position = self
|
||||
.num_rendered_lines
|
||||
.get()
|
||||
.saturating_sub(viewport_height);
|
||||
self.scroll_position = sticky_offset(self.num_rendered_lines.get(), viewport_height);
|
||||
}
|
||||
|
||||
// Move up by a full page.
|
||||
self.scroll_position = self.scroll_position.saturating_sub(viewport_height);
|
||||
}
|
||||
|
||||
/// Scroll down by one full viewport height (Page Down).
|
||||
fn scroll_page_down(&mut self) {
|
||||
// Nothing to do if we're already stuck to the bottom.
|
||||
if self.scroll_position == usize::MAX {
|
||||
return;
|
||||
}
|
||||
|
||||
let viewport_height = self.last_viewport_height.get().max(1);
|
||||
let num_lines = self.num_rendered_lines.get();
|
||||
|
||||
// Calculate the maximum explicit scroll offset that is still within
|
||||
// range. This matches the logic in `scroll_down()` and the render
|
||||
// method.
|
||||
let max_scroll = num_lines.saturating_sub(viewport_height);
|
||||
|
||||
// Attempt to move down by a full page.
|
||||
let max_scroll = sticky_offset(self.num_rendered_lines.get(), viewport_height);
|
||||
let new_pos = self.scroll_position.saturating_add(viewport_height);
|
||||
|
||||
if new_pos >= max_scroll {
|
||||
// We have reached (or passed) the bottom – switch back to
|
||||
// automatic stick‑to‑bottom mode so that subsequent output keeps
|
||||
// the viewport pinned.
|
||||
self.scroll_position = usize::MAX;
|
||||
} else {
|
||||
self.scroll_position = new_pos;
|
||||
@@ -195,19 +166,105 @@ impl ConversationHistoryWidget {
|
||||
}
|
||||
|
||||
pub fn add_agent_message(&mut self, config: &Config, message: String) {
|
||||
// Reset streaming scratch because we are starting a fresh agent message.
|
||||
self.streaming_agent_message_buf.clear();
|
||||
self.streaming_agent_message_buf.push_str(&message);
|
||||
self.add_to_history(HistoryCell::new_agent_message(config, message));
|
||||
}
|
||||
|
||||
pub fn add_agent_reasoning(&mut self, config: &Config, text: String) {
|
||||
self.streaming_agent_reasoning_buf.clear();
|
||||
self.streaming_agent_reasoning_buf.push_str(&text);
|
||||
self.add_to_history(HistoryCell::new_agent_reasoning(config, text));
|
||||
}
|
||||
|
||||
pub fn replace_prev_agent_reasoning(&mut self, config: &Config, text: String) {
|
||||
self.replace_last_agent_reasoning(config, text);
|
||||
/// Append incremental assistant text.
|
||||
///
|
||||
/// Previous heuristic: fast‑append chunks until we saw a newline, then re‑render.
|
||||
/// This caused visible "one‑word" lines (e.g., "The" -> "The user") when models
|
||||
/// streamed small token fragments and also delayed Markdown styling (headings, code fences)
|
||||
/// until the first newline arrived. To improve perceived quality we now *always* re‑render
|
||||
/// the accumulated markdown buffer on every incoming delta chunk. We still apply the
|
||||
/// soft‑break collapsing heuristic (outside fenced code blocks) so interim layout more closely
|
||||
/// matches the final message and reduces layout thrash.
|
||||
pub fn append_agent_message_delta(&mut self, _config: &Config, text: String) {
|
||||
if text.is_empty() {
|
||||
return;
|
||||
}
|
||||
// Accumulate full buffer.
|
||||
self.streaming_agent_message_buf.push_str(&text);
|
||||
|
||||
let collapsed = collapse_single_newlines_for_streaming(&self.streaming_agent_message_buf);
|
||||
if let Some(idx) = last_agent_message_idx(&self.entries) {
|
||||
let width = self.cached_width.get();
|
||||
let entry = &mut self.entries[idx];
|
||||
entry.cell = HistoryCell::new_agent_message(_config, collapsed);
|
||||
// Drop trailing blank so we can continue streaming additional tokens cleanly.
|
||||
if let HistoryCell::AgentMessage { view } = &mut entry.cell {
|
||||
drop_trailing_blank_line(&mut view.lines);
|
||||
}
|
||||
if width > 0 {
|
||||
update_entry_height(entry, width);
|
||||
}
|
||||
} else {
|
||||
// No existing cell? Start a new one.
|
||||
self.add_agent_message(_config, self.streaming_agent_message_buf.clone());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn replace_prev_agent_message(&mut self, config: &Config, text: String) {
|
||||
self.replace_last_agent_message(config, text);
|
||||
/// Append incremental reasoning text (mirrors `append_agent_message_delta`).
|
||||
pub fn append_agent_reasoning_delta(&mut self, _config: &Config, text: String) {
|
||||
if text.is_empty() {
|
||||
return;
|
||||
}
|
||||
self.streaming_agent_reasoning_buf.push_str(&text);
|
||||
|
||||
let collapsed = collapse_single_newlines_for_streaming(&self.streaming_agent_reasoning_buf);
|
||||
if let Some(idx) = last_agent_reasoning_idx(&self.entries) {
|
||||
let width = self.cached_width.get();
|
||||
let entry = &mut self.entries[idx];
|
||||
entry.cell = HistoryCell::new_agent_reasoning(_config, collapsed);
|
||||
if let HistoryCell::AgentReasoning { view } = &mut entry.cell {
|
||||
drop_trailing_blank_line(&mut view.lines);
|
||||
}
|
||||
if width > 0 {
|
||||
update_entry_height(entry, width);
|
||||
}
|
||||
} else {
|
||||
self.add_agent_reasoning(_config, self.streaming_agent_reasoning_buf.clone());
|
||||
}
|
||||
}
|
||||
|
||||
/// Replace the most recent AgentMessage cell with the fully accumulated `text`.
|
||||
/// This should be called once the turn is complete so we can render proper markdown.
|
||||
pub fn replace_last_agent_message(&mut self, config: &Config, text: String) {
|
||||
self.streaming_agent_message_buf.clear();
|
||||
if let Some(idx) = last_agent_message_idx(&self.entries) {
|
||||
let width = self.cached_width.get();
|
||||
let entry = &mut self.entries[idx];
|
||||
entry.cell = HistoryCell::new_agent_message(config, text);
|
||||
if width > 0 {
|
||||
update_entry_height(entry, width);
|
||||
}
|
||||
} else {
|
||||
// No existing AgentMessage (shouldn't happen) – append new.
|
||||
self.add_agent_message(config, text);
|
||||
}
|
||||
}
|
||||
|
||||
/// Replace the most recent AgentReasoning cell with the fully accumulated `text`.
|
||||
pub fn replace_last_agent_reasoning(&mut self, config: &Config, text: String) {
|
||||
self.streaming_agent_reasoning_buf.clear();
|
||||
if let Some(idx) = last_agent_reasoning_idx(&self.entries) {
|
||||
let width = self.cached_width.get();
|
||||
let entry = &mut self.entries[idx];
|
||||
entry.cell = HistoryCell::new_agent_reasoning(config, text);
|
||||
if width > 0 {
|
||||
update_entry_height(entry, width);
|
||||
}
|
||||
} else {
|
||||
self.add_agent_reasoning(config, text);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_background_event(&mut self, message: String) {
|
||||
@@ -257,42 +314,6 @@ impl ConversationHistoryWidget {
|
||||
});
|
||||
}
|
||||
|
||||
pub fn replace_last_agent_reasoning(&mut self, config: &Config, text: String) {
|
||||
if let Some(idx) = self
|
||||
.entries
|
||||
.iter()
|
||||
.rposition(|entry| matches!(entry.cell, HistoryCell::AgentReasoning { .. }))
|
||||
{
|
||||
let width = self.cached_width.get();
|
||||
let entry = &mut self.entries[idx];
|
||||
entry.cell = HistoryCell::new_agent_reasoning(config, text);
|
||||
let height = if width > 0 {
|
||||
entry.cell.height(width)
|
||||
} else {
|
||||
0
|
||||
};
|
||||
entry.line_count.set(height);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn replace_last_agent_message(&mut self, config: &Config, text: String) {
|
||||
if let Some(idx) = self
|
||||
.entries
|
||||
.iter()
|
||||
.rposition(|entry| matches!(entry.cell, HistoryCell::AgentMessage { .. }))
|
||||
{
|
||||
let width = self.cached_width.get();
|
||||
let entry = &mut self.entries[idx];
|
||||
entry.cell = HistoryCell::new_agent_message(config, text);
|
||||
let height = if width > 0 {
|
||||
entry.cell.height(width)
|
||||
} else {
|
||||
0
|
||||
};
|
||||
entry.line_count.set(height);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn record_completed_exec_command(
|
||||
&mut self,
|
||||
call_id: String,
|
||||
@@ -323,7 +344,7 @@ impl ConversationHistoryWidget {
|
||||
|
||||
// Update cached line count.
|
||||
if width > 0 {
|
||||
entry.line_count.set(cell.height(width));
|
||||
update_entry_height(entry, width);
|
||||
}
|
||||
break;
|
||||
}
|
||||
@@ -357,7 +378,7 @@ impl ConversationHistoryWidget {
|
||||
entry.cell = completed;
|
||||
|
||||
if width > 0 {
|
||||
entry.line_count.set(entry.cell.height(width));
|
||||
update_entry_height(entry, width);
|
||||
}
|
||||
|
||||
break;
|
||||
@@ -380,7 +401,7 @@ impl WidgetRef for ConversationHistoryWidget {
|
||||
|
||||
let block = Block::default()
|
||||
.title(title)
|
||||
.borders(Borders::NONE)
|
||||
.borders(Borders::ALL)
|
||||
.border_type(BorderType::Rounded)
|
||||
.border_style(border_style);
|
||||
|
||||
@@ -391,9 +412,9 @@ impl WidgetRef for ConversationHistoryWidget {
|
||||
|
||||
// Cache (and if necessary recalculate) the wrapped line counts for every
|
||||
// [`HistoryCell`] so that our scrolling math accounts for text
|
||||
// wrapping. The full inner width is now used because the scrollbar has
|
||||
// been disabled.
|
||||
let effective_width = inner.width;
|
||||
// wrapping. We always reserve one column on the right-hand side for the
|
||||
// scrollbar so that the content never renders "under" the scrollbar.
|
||||
let effective_width = inner.width.saturating_sub(1);
|
||||
|
||||
if effective_width == 0 {
|
||||
return; // Nothing to draw – avoid division by zero.
|
||||
@@ -414,14 +435,12 @@ impl WidgetRef for ConversationHistoryWidget {
|
||||
self.entries.iter().map(|e| e.line_count.get()).sum()
|
||||
};
|
||||
|
||||
// Determine the scroll position. Note the existing value of
|
||||
// `self.scroll_position` could exceed the maximum scroll offset if the
|
||||
// user made the window wider since the last render.
|
||||
let max_scroll = num_lines.saturating_sub(viewport_height);
|
||||
// Determine the scroll position (respect sticky-to-bottom sentinel and clamp).
|
||||
let max_scroll = sticky_offset(num_lines, viewport_height);
|
||||
let scroll_pos = if self.scroll_position == usize::MAX {
|
||||
max_scroll
|
||||
} else {
|
||||
self.scroll_position.min(max_scroll)
|
||||
clamp_scroll_pos(self.scroll_position, max_scroll)
|
||||
};
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
@@ -486,7 +505,48 @@ impl WidgetRef for ConversationHistoryWidget {
|
||||
}
|
||||
}
|
||||
|
||||
// Scrollbar intentionally disabled: scrolling still functions via key / mouse events.
|
||||
// Always render a scrollbar *track* so the reserved column is filled.
|
||||
let overflow = num_lines.saturating_sub(viewport_height);
|
||||
|
||||
let mut scroll_state = ScrollbarState::default()
|
||||
// The Scrollbar widget expects the *content* height minus the
|
||||
// viewport height. When there is no overflow we still provide 0
|
||||
// so that the widget renders only the track without a thumb.
|
||||
.content_length(overflow)
|
||||
.position(scroll_pos);
|
||||
|
||||
{
|
||||
// Choose a thumb color that stands out only when this pane has focus so that the
|
||||
// user's attention is naturally drawn to the active viewport. When unfocused we show
|
||||
// a low-contrast thumb so the scrollbar fades into the background without becoming
|
||||
// invisible.
|
||||
let thumb_style = if self.has_input_focus {
|
||||
Style::reset().fg(Color::LightYellow)
|
||||
} else {
|
||||
Style::reset().fg(Color::Gray)
|
||||
};
|
||||
|
||||
// By default the Scrollbar widget inherits any style that was
|
||||
// present in the underlying buffer cells. That means if a colored
|
||||
// line happens to be underneath the scrollbar, the track (and
|
||||
// potentially the thumb) adopt that color. Explicitly setting the
|
||||
// track/thumb styles ensures we always draw the scrollbar with a
|
||||
// consistent palette regardless of what content is behind it.
|
||||
StatefulWidget::render(
|
||||
Scrollbar::new(ScrollbarOrientation::VerticalRight)
|
||||
.begin_symbol(Some("↑"))
|
||||
.end_symbol(Some("↓"))
|
||||
.begin_style(Style::reset().fg(Color::DarkGray))
|
||||
.end_style(Style::reset().fg(Color::DarkGray))
|
||||
.thumb_symbol("█")
|
||||
.thumb_style(thumb_style)
|
||||
.track_symbol(Some("│"))
|
||||
.track_style(Style::reset().fg(Color::DarkGray)),
|
||||
inner,
|
||||
buf,
|
||||
&mut scroll_state,
|
||||
);
|
||||
}
|
||||
|
||||
// Update auxiliary stats that the scroll handlers rely on.
|
||||
self.num_rendered_lines.set(num_lines);
|
||||
@@ -500,3 +560,118 @@ impl WidgetRef for ConversationHistoryWidget {
|
||||
pub(crate) const fn wrap_cfg() -> ratatui::widgets::Wrap {
|
||||
ratatui::widgets::Wrap { trim: false }
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Scrolling helpers (private)
|
||||
// ---------------------------------------------------------------------------
|
||||
#[inline]
|
||||
fn sticky_offset(num_lines: usize, viewport_height: usize) -> usize {
|
||||
num_lines.saturating_sub(viewport_height.max(1))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn clamp_scroll_pos(pos: usize, max_scroll: usize) -> usize {
|
||||
pos.min(max_scroll)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Streaming helpers (private)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Locate the most recent `HistoryCell::AgentMessage` entry.
|
||||
fn last_agent_message_idx(entries: &[Entry]) -> Option<usize> {
|
||||
entries
|
||||
.iter()
|
||||
.rposition(|e| matches!(e.cell, HistoryCell::AgentMessage { .. }))
|
||||
}
|
||||
|
||||
/// Locate the most recent `HistoryCell::AgentReasoning` entry.
|
||||
fn last_agent_reasoning_idx(entries: &[Entry]) -> Option<usize> {
|
||||
entries
|
||||
.iter()
|
||||
.rposition(|e| matches!(e.cell, HistoryCell::AgentReasoning { .. }))
|
||||
}
|
||||
|
||||
/// True if the line is an empty spacer (single empty span).
|
||||
fn is_blank_line(line: &Line<'_>) -> bool {
|
||||
line.spans.len() == 1 && line.spans[0].content.is_empty()
|
||||
}
|
||||
|
||||
/// Ensure that the vector has *at least* one body line after the header.
|
||||
/// A freshly-created AgentMessage/Reasoning cell always has a header + blank line,
|
||||
/// but streaming cells may be created empty; this makes sure we have a target line.
|
||||
#[allow(dead_code)]
|
||||
fn ensure_body_line(lines: &mut Vec<Line<'static>>) {
|
||||
if lines.len() < 2 {
|
||||
lines.push(Line::from(""));
|
||||
}
|
||||
}
|
||||
|
||||
/// Trim a single trailing blank spacer line (but preserve intentional paragraph breaks).
|
||||
fn drop_trailing_blank_line(lines: &mut Vec<Line<'static>>) {
|
||||
if let Some(last) = lines.last() {
|
||||
if is_blank_line(last) {
|
||||
lines.pop();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Append streaming text, honouring embedded newlines.
|
||||
#[allow(dead_code)]
|
||||
fn append_streaming_text_chunks(lines: &mut Vec<Line<'static>>, text: &str) {
|
||||
// NOTE: This helper is now a fallback path only (we eagerly re-render accumulated markdown).
|
||||
// Still, keep behaviour sane: drop trailing spacer, ensure a writable body line, then append.
|
||||
drop_trailing_blank_line(lines);
|
||||
ensure_body_line(lines);
|
||||
if let Some(last_line) = lines.last_mut() {
|
||||
last_line.spans.push(Span::raw(text.to_string()));
|
||||
} else {
|
||||
lines.push(Line::from(text.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
/// Re-measure a mutated entry at `width` columns and update its cached height.
|
||||
fn update_entry_height(entry: &Entry, width: u16) {
|
||||
entry.line_count.set(entry.cell.height(width));
|
||||
}
|
||||
|
||||
/// Collapse *single* newlines in a streaming buffer into spaces so that interim streaming
|
||||
/// renders more closely match final Markdown layout — *except* when we detect fenced code blocks.
|
||||
/// If the accumulated text contains a Markdown code fence (``` or ~~~), we preserve **all**
|
||||
/// newlines verbatim so multi-line code renders correctly while streaming.
|
||||
fn collapse_single_newlines_for_streaming(src: &str) -> String {
|
||||
// Quick fence detection. If we see a code fence marker anywhere in the accumulated text,
|
||||
// skip collapsing entirely so we do not mangle code formatting.
|
||||
if src.contains("```") || src.contains("~~~") {
|
||||
return src.to_string();
|
||||
}
|
||||
|
||||
let mut out = String::with_capacity(src.len());
|
||||
let mut pending_newlines = 0usize;
|
||||
for ch in src.chars() {
|
||||
if ch == '\n' {
|
||||
pending_newlines += 1;
|
||||
continue;
|
||||
}
|
||||
if pending_newlines == 1 {
|
||||
// soft break -> space
|
||||
out.push(' ');
|
||||
} else if pending_newlines > 1 {
|
||||
// preserve paragraph breaks exactly
|
||||
for _ in 0..pending_newlines {
|
||||
out.push('\n');
|
||||
}
|
||||
}
|
||||
pending_newlines = 0;
|
||||
out.push(ch);
|
||||
}
|
||||
// flush tail
|
||||
if pending_newlines == 1 {
|
||||
out.push(' ');
|
||||
} else if pending_newlines > 1 {
|
||||
for _ in 0..pending_newlines {
|
||||
out.push('\n');
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ use image::GenericImageView;
|
||||
use image::ImageReader;
|
||||
use lazy_static::lazy_static;
|
||||
use mcp_types::EmbeddedResourceResource;
|
||||
use mcp_types::ResourceLink;
|
||||
use ratatui::prelude::*;
|
||||
use ratatui::style::Color;
|
||||
use ratatui::style::Modifier;
|
||||
@@ -332,7 +331,8 @@ impl HistoryCell {
|
||||
) -> Option<Self> {
|
||||
match result {
|
||||
Ok(mcp_types::CallToolResult { content, .. }) => {
|
||||
if let Some(mcp_types::ContentBlock::ImageContent(image)) = content.first() {
|
||||
if let Some(mcp_types::CallToolResultContent::ImageContent(image)) = content.first()
|
||||
{
|
||||
let raw_data =
|
||||
match base64::engine::general_purpose::STANDARD.decode(&image.data) {
|
||||
Ok(data) => data,
|
||||
@@ -405,21 +405,21 @@ impl HistoryCell {
|
||||
|
||||
for tool_call_result in content {
|
||||
let line_text = match tool_call_result {
|
||||
mcp_types::ContentBlock::TextContent(text) => {
|
||||
mcp_types::CallToolResultContent::TextContent(text) => {
|
||||
format_and_truncate_tool_result(
|
||||
&text.text,
|
||||
TOOL_CALL_MAX_LINES,
|
||||
num_cols as usize,
|
||||
)
|
||||
}
|
||||
mcp_types::ContentBlock::ImageContent(_) => {
|
||||
mcp_types::CallToolResultContent::ImageContent(_) => {
|
||||
// TODO show images even if they're not the first result, will require a refactor of `CompletedMcpToolCall`
|
||||
"<image content>".to_string()
|
||||
}
|
||||
mcp_types::ContentBlock::AudioContent(_) => {
|
||||
mcp_types::CallToolResultContent::AudioContent(_) => {
|
||||
"<audio content>".to_string()
|
||||
}
|
||||
mcp_types::ContentBlock::EmbeddedResource(resource) => {
|
||||
mcp_types::CallToolResultContent::EmbeddedResource(resource) => {
|
||||
let uri = match resource.resource {
|
||||
EmbeddedResourceResource::TextResourceContents(text) => {
|
||||
text.uri
|
||||
@@ -430,9 +430,6 @@ impl HistoryCell {
|
||||
};
|
||||
format!("embedded resource: {uri}")
|
||||
}
|
||||
mcp_types::ContentBlock::ResourceLink(ResourceLink { uri, .. }) => {
|
||||
format!("link: {uri}")
|
||||
}
|
||||
};
|
||||
lines.push(Line::styled(line_text, Style::default().fg(Color::Gray)));
|
||||
}
|
||||
|
||||
@@ -65,7 +65,7 @@ impl StatusIndicatorWidget {
|
||||
std::thread::sleep(Duration::from_millis(200));
|
||||
counter = counter.wrapping_add(1);
|
||||
frame_idx_clone.store(counter, Ordering::Relaxed);
|
||||
app_event_tx_clone.send(AppEvent::RequestRedraw);
|
||||
app_event_tx_clone.send(AppEvent::Redraw);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user