Mirror of https://github.com/openai/codex.git, synced 2026-02-01 22:47:52 +00:00
Compare commits: joshka/tui...jif/client
24 Commits
| SHA1 |
|---|
| 001ed59f5c |
| 166ca2fce7 |
| f5918d7e1b |
| 34b57eff3f |
| c3480d94a1 |
| 3d58659451 |
| fe95c24442 |
| 769e9cc92c |
| b7dcc8ef5c |
| 0c609d441b |
| f6494aa85c |
| 5fc0c39386 |
| e13d10531a |
| ba2873074e |
| 6239decccc |
| 7a5786f49f |
| 5b43146ba5 |
| 690de0d4c6 |
| 2e2e9627de |
| 6a767b7230 |
| 9c267f0204 |
| 10c880d886 |
| dabf219a45 |
| 1bac24f827 |
46  codex-rs/Cargo.lock  generated

@@ -830,6 +830,29 @@ dependencies = [
  "tracing",
 ]
 
+[[package]]
+name = "codex-api-client"
+version = "0.0.0"
+dependencies = [
+ "async-trait",
+ "bytes",
+ "codex-app-server-protocol",
+ "codex-otel",
+ "codex-protocol",
+ "codex-provider-config",
+ "futures",
+ "maplit",
+ "pretty_assertions",
+ "regex-lite",
+ "reqwest",
+ "serde",
+ "serde_json",
+ "thiserror 2.0.17",
+ "tokio",
+ "tokio-util",
+ "tracing",
+]
+
 [[package]]
 name = "codex-app-server"
 version = "0.0.0"
@@ -1065,6 +1088,7 @@ dependencies = [
  "base64",
  "bytes",
  "chrono",
+ "codex-api-client",
  "codex-app-server-protocol",
  "codex-apply-patch",
  "codex-arg0",
@@ -1074,6 +1098,7 @@ dependencies = [
  "codex-keyring-store",
  "codex-otel",
  "codex-protocol",
+ "codex-provider-config",
  "codex-rmcp-client",
  "codex-utils-pty",
  "codex-utils-readiness",
@@ -1368,6 +1393,17 @@ dependencies = [
  "uuid",
 ]
 
+[[package]]
+name = "codex-provider-config"
+version = "0.0.0"
+dependencies = [
+ "codex-app-server-protocol",
+ "reqwest",
+ "serde",
+ "serde_json",
+ "thiserror 2.0.17",
+]
+
 [[package]]
 name = "codex-responses-api-proxy"
 version = "0.0.0"
@@ -4450,7 +4486,7 @@ checksum = "3af6b589e163c5a788fab00ce0c0366f6efbb9959c2f9874b224936af7fce7e1"
 dependencies = [
  "base64",
  "indexmap 2.12.0",
- "quick-xml",
+ "quick-xml 0.38.0",
  "serde",
  "time",
 ]
@@ -7093,7 +7129,7 @@ version = "0.31.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c66a47e840dc20793f2264eb4b3e4ecb4b75d91c0dd4af04b456128e0bdd449d"
 dependencies = [
- "bitflags 2.9.1",
+ "bitflags 2.10.0",
  "rustix 1.0.8",
  "wayland-backend",
  "wayland-scanner",
@@ -7105,7 +7141,7 @@ version = "0.32.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "efa790ed75fbfd71283bd2521a1cfdc022aabcc28bdcff00851f9e4ae88d9901"
 dependencies = [
- "bitflags 2.9.1",
+ "bitflags 2.10.0",
  "wayland-backend",
  "wayland-client",
  "wayland-scanner",
@@ -7117,7 +7153,7 @@ version = "0.3.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "efd94963ed43cf9938a090ca4f7da58eb55325ec8200c3848963e98dc25b78ec"
 dependencies = [
- "bitflags 2.9.1",
+ "bitflags 2.10.0",
  "wayland-backend",
  "wayland-client",
  "wayland-protocols",
@@ -7726,7 +7762,7 @@ dependencies = [
  "os_pipe",
  "rustix 0.38.44",
  "tempfile",
- "thiserror 2.0.16",
+ "thiserror 2.0.17",
  "tree_magic_mini",
  "wayland-backend",
  "wayland-client",
codex-rs/Cargo.toml

@@ -1,5 +1,6 @@
 [workspace]
 members = [
+    "api-client",
     "backend-client",
     "ansi-escape",
     "async-utils",
@@ -25,6 +26,7 @@ members = [
     "ollama",
     "process-hardening",
     "protocol",
+    "provider-config",
     "rmcp-client",
     "responses-api-proxy",
     "stdio-to-uds",
@@ -53,6 +55,7 @@ edition = "2024"
 # Internal
 app_test_support = { path = "app-server/tests/common" }
 codex-ansi-escape = { path = "ansi-escape" }
+codex-api-client = { path = "api-client" }
 codex-app-server = { path = "app-server" }
 codex-app-server-protocol = { path = "app-server-protocol" }
 codex-apply-patch = { path = "apply-patch" }
@@ -74,6 +77,7 @@ codex-ollama = { path = "ollama" }
 codex-otel = { path = "otel" }
 codex-process-hardening = { path = "process-hardening" }
 codex-protocol = { path = "protocol" }
+codex-provider-config = { path = "provider-config" }
 codex-responses-api-proxy = { path = "responses-api-proxy" }
 codex-rmcp-client = { path = "rmcp-client" }
 codex-stdio-to-uds = { path = "stdio-to-uds" }
29  codex-rs/api-client/Cargo.toml  Normal file

@@ -0,0 +1,29 @@
[package]
name = "codex-api-client"
version.workspace = true
edition.workspace = true

[dependencies]
async-trait = { workspace = true }
bytes = { workspace = true }
codex-app-server-protocol = { workspace = true }
codex-otel = { workspace = true }
codex-protocol = { path = "../protocol" }
codex-provider-config = { path = "../provider-config" }
futures = { workspace = true, default-features = false, features = ["std"] }
maplit = { workspace = true }
regex-lite = { workspace = true }
reqwest = { workspace = true, features = ["json", "stream"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["sync", "time", "rt", "rt-multi-thread", "macros", "io-util"] }
tokio-util = { workspace = true }
tracing = { workspace = true }

[dev-dependencies]
maplit = { workspace = true }
pretty_assertions = { workspace = true }

[lints]
workspace = true
18  codex-rs/api-client/src/auth.rs  Normal file

@@ -0,0 +1,18 @@
use crate::error::Result;
use async_trait::async_trait;
use codex_app_server_protocol::AuthMode;
use serde::Deserialize;
use serde::Serialize;

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AuthContext {
    pub mode: AuthMode,
    pub bearer_token: Option<String>,
    pub account_id: Option<String>,
}

#[async_trait]
pub trait AuthProvider: Send + Sync {
    async fn auth_context(&self) -> Option<AuthContext>;
    async fn refresh_token(&self) -> Result<Option<String>>;
}
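The `AuthProvider` trait above is the only seam the HTTP clients use for credentials. A minimal sketch of an implementor follows; `StaticTokenProvider` is a hypothetical name (not part of this diff), and the `AuthMode::ApiKey` variant is an assumption about the protocol enum:

```rust
use async_trait::async_trait;
use codex_api_client::{AuthContext, AuthProvider, Result};
use codex_app_server_protocol::AuthMode;

/// Hypothetical provider that always serves one fixed API key.
struct StaticTokenProvider {
    token: String,
}

#[async_trait]
impl AuthProvider for StaticTokenProvider {
    async fn auth_context(&self) -> Option<AuthContext> {
        Some(AuthContext {
            mode: AuthMode::ApiKey, // assumed variant; adjust to the real enum
            bearer_token: Some(self.token.clone()),
            account_id: None,
        })
    }

    async fn refresh_token(&self) -> Result<Option<String>> {
        // A static token never rotates, so there is nothing new to hand back.
        Ok(None)
    }
}
```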
75  codex-rs/api-client/src/chat.rs  Normal file

@@ -0,0 +1,75 @@
use crate::error::Error;
use crate::error::Result;
use crate::stream::WireResponseStream;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_provider_config::ModelProviderInfo;
use futures::TryStreamExt;

#[derive(Clone)]
/// Configuration for the Chat Completions client (OpenAI-compatible `/v1/chat/completions`).
///
/// - `http_client`: Reqwest client used for HTTP requests.
/// - `provider`: Provider configuration (base URL, headers, retries, etc.).
/// - `otel_event_manager`: Telemetry event manager for request/stream instrumentation.
pub struct ChatCompletionsApiClientConfig {
    pub http_client: reqwest::Client,
    pub provider: ModelProviderInfo,
    pub otel_event_manager: OtelEventManager,
    pub extra_headers: Vec<(String, String)>,
}

#[derive(Clone)]
pub struct ChatCompletionsApiClient {
    config: ChatCompletionsApiClientConfig,
}

impl ChatCompletionsApiClient {
    pub fn new(config: ChatCompletionsApiClientConfig) -> Result<Self> {
        Ok(Self { config })
    }

    pub async fn stream_payload_wire(
        &self,
        payload_json: &serde_json::Value,
    ) -> Result<WireResponseStream> {
        if self.config.provider.wire_api != codex_provider_config::WireApi::Chat {
            return Err(crate::error::Error::UnsupportedOperation(
                "ChatCompletionsApiClient requires a Chat provider".to_string(),
            ));
        }

        let extra_headers = crate::client::http::header_pairs(&self.config.extra_headers);
        let mut req_builder = crate::client::http::build_request(
            &self.config.http_client,
            &self.config.provider,
            &None,
            &extra_headers,
        )
        .await?;

        req_builder = req_builder
            .header(reqwest::header::ACCEPT, "text/event-stream")
            .json(payload_json);

        let res = self
            .config
            .otel_event_manager
            .log_request(0, || req_builder.send())
            .await?;

        let stream = res
            .bytes_stream()
            .map_err(|err| Error::ResponseStreamFailed {
                source: err,
                request_id: None,
            });
        let (_, rx_event) = crate::client::sse::spawn_wire_stream(
            stream,
            &self.config.provider,
            self.config.otel_event_manager.clone(),
            crate::decode_wire::chat::WireChatSseDecoder::new(),
        );

        Ok(rx_event)
    }
}
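For orientation, one way a caller might drain the wire stream returned by `stream_payload_wire`; this is a sketch that assumes an already-configured client and chat payload, and the `drain` helper is hypothetical:

```rust
use futures::StreamExt;

use codex_api_client::{ChatCompletionsApiClient, Result, WireEvent};

/// Hypothetical helper: print text deltas until the decoder signals completion.
async fn drain(client: &ChatCompletionsApiClient, payload: &serde_json::Value) -> Result<()> {
    let mut stream = client.stream_payload_wire(payload).await?;
    while let Some(event) = stream.next().await {
        match event? {
            WireEvent::OutputTextDelta(delta) => print!("{delta}"),
            WireEvent::Completed { .. } => break,
            _ => {}
        }
    }
    Ok(())
}
```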
41  codex-rs/api-client/src/client/fixtures.rs  Normal file

@@ -0,0 +1,41 @@
use std::path::Path;

use codex_otel::otel_event_manager::OtelEventManager;
use futures::TryStreamExt;
use tokio_util::io::ReaderStream;

use crate::error::Error;
use crate::error::Result;
use codex_provider_config::ModelProviderInfo;

pub async fn stream_from_fixture_wire(
    path: impl AsRef<Path>,
    provider: ModelProviderInfo,
    otel_event_manager: OtelEventManager,
) -> Result<crate::stream::WireResponseStream> {
    let display_path = path.as_ref().display().to_string();
    let content = std::fs::read_to_string(path.as_ref()).map_err(|err| {
        Error::Other(format!(
            "failed to read fixture text from {display_path}: {err}"
        ))
    })?;
    let content = content
        .lines()
        .map(|line| {
            let mut line_with_spacing = line.to_string();
            line_with_spacing.push('\n');
            line_with_spacing.push('\n');
            line_with_spacing
        })
        .collect::<String>();

    let rdr = std::io::Cursor::new(content);
    let stream = ReaderStream::new(rdr).map_err(|err| Error::Other(err.to_string()));
    let (_, rx_event) = crate::client::sse::spawn_wire_stream(
        stream,
        &provider,
        otel_event_manager,
        crate::decode_wire::responses::WireResponsesSseDecoder,
    );
    Ok(rx_event)
}
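`stream_from_fixture_wire` treats each line of the fixture file as one SSE record and re-inserts the blank-line frame separators itself, so (reading the loader and the framer together) a hypothetical two-event fixture could be as simple as:

```
data: {"type":"response.created","response":{"id":"resp-1"}}
data: {"type":"response.completed","response":{"id":"resp-1","usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}
```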
56  codex-rs/api-client/src/client/http.rs  Normal file

@@ -0,0 +1,56 @@
use std::sync::Arc;

use crate::auth::AuthContext;
use crate::auth::AuthProvider;
use crate::error::Result;
use codex_provider_config::ModelProviderInfo;

/// Build a request builder with provider/auth/session headers applied.
pub async fn build_request(
    http_client: &reqwest::Client,
    provider: &ModelProviderInfo,
    auth: &Option<AuthContext>,
    extra_headers: &[(&str, String)],
) -> Result<reqwest::RequestBuilder> {
    let mut builder = provider
        .create_request_builder(
            http_client,
            &auth.as_ref().map(|a| codex_provider_config::AuthContext {
                mode: a.mode,
                bearer_token: a.bearer_token.clone(),
                account_id: a.account_id.clone(),
            }),
        )
        .await
        .map_err(|e| crate::error::Error::MissingEnvVar {
            var: match e {
                codex_provider_config::Error::MissingEnvVar { ref var, .. } => var.clone(),
            },
            instructions: match e {
                codex_provider_config::Error::MissingEnvVar {
                    ref instructions, ..
                } => instructions.clone(),
            },
        })?;
    for (name, value) in extra_headers {
        builder = builder.header(*name, value);
    }
    Ok(builder)
}

/// Resolve auth context from an optional provider.
pub async fn resolve_auth(auth_provider: &Option<Arc<dyn AuthProvider>>) -> Option<AuthContext> {
    if let Some(p) = auth_provider {
        p.auth_context().await
    } else {
        None
    }
}

/// Convert owned header pairs into borrowed key/value tuples for reqwest.
pub fn header_pairs(headers: &[(String, String)]) -> Vec<(&str, String)> {
    headers
        .iter()
        .map(|(k, v)| (k.as_str(), v.clone()))
        .collect()
}
24  codex-rs/api-client/src/client/mod.rs  Normal file

@@ -0,0 +1,24 @@
use async_trait::async_trait;
use codex_otel::otel_event_manager::OtelEventManager;
use tokio::sync::mpsc;

use crate::error::Result;
use crate::stream::WireEvent;

pub mod fixtures;
pub mod http;
pub mod rate_limits;
pub mod sse;

// Legacy ResponseEvent-based decoder removed

/// Decodes framed SSE JSON into WireEvent(s).
#[async_trait]
pub trait WireResponseDecoder {
    async fn on_frame(
        &mut self,
        json: &str,
        tx: &mpsc::Sender<Result<WireEvent>>,
        otel: &OtelEventManager,
    ) -> Result<()>;
}
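Any SSE dialect can be plugged into the framer by implementing this crate-internal trait. A toy implementor, purely illustrative and not in the diff, that forwards each frame's raw JSON downstream as a text delta:

```rust
use async_trait::async_trait;
use codex_otel::otel_event_manager::OtelEventManager;
use tokio::sync::mpsc;

use crate::client::WireResponseDecoder;
use crate::error::Result;
use crate::stream::WireEvent;

struct EchoDecoder;

#[async_trait]
impl WireResponseDecoder for EchoDecoder {
    async fn on_frame(
        &mut self,
        json: &str,
        tx: &mpsc::Sender<Result<WireEvent>>,
        _otel: &OtelEventManager,
    ) -> Result<()> {
        // A real decoder parses `json`; this one just echoes it downstream.
        let _ = tx.send(Ok(WireEvent::OutputTextDelta(json.to_string()))).await;
        Ok(())
    }
}
```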
64  codex-rs/api-client/src/client/rate_limits.rs  Normal file

@@ -0,0 +1,64 @@
use codex_protocol::protocol::RateLimitSnapshot;
use codex_protocol::protocol::RateLimitWindow;
use reqwest::header::HeaderMap;

pub fn parse_rate_limit_snapshot(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
    let primary = parse_rate_limit_window(
        headers,
        "x-codex-primary-used-percent",
        "x-codex-primary-window-minutes",
        "x-codex-primary-reset-at",
    );

    let secondary = parse_rate_limit_window(
        headers,
        "x-codex-secondary-used-percent",
        "x-codex-secondary-window-minutes",
        "x-codex-secondary-reset-at",
    );

    if primary.is_none() && secondary.is_none() {
        return None;
    }

    Some(RateLimitSnapshot { primary, secondary })
}

fn parse_rate_limit_window(
    headers: &HeaderMap,
    used_percent_header: &str,
    window_minutes_header: &str,
    resets_at_header: &str,
) -> Option<RateLimitWindow> {
    let used_percent: Option<f64> = parse_header_f64(headers, used_percent_header);

    used_percent.and_then(|used_percent| {
        let window_minutes = parse_header_i64(headers, window_minutes_header);
        let resets_at = parse_header_i64(headers, resets_at_header);

        let has_data = used_percent != 0.0
            || window_minutes.is_some_and(|minutes| minutes != 0)
            || resets_at.is_some();

        has_data.then_some(RateLimitWindow {
            used_percent,
            window_minutes,
            resets_at,
        })
    })
}

fn parse_header_f64(headers: &HeaderMap, name: &str) -> Option<f64> {
    parse_header_str(headers, name)?
        .parse::<f64>()
        .ok()
        .filter(|v| v.is_finite())
}

fn parse_header_i64(headers: &HeaderMap, name: &str) -> Option<i64> {
    parse_header_str(headers, name)?.parse::<i64>().ok()
}

fn parse_header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> {
    headers.get(name)?.to_str().ok()
}
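A sketch of feeding hand-built headers through the crate-internal `parse_rate_limit_snapshot`; the header names come straight from the parser above, while the values and the `rate_limit_headers_demo` name are made up for illustration:

```rust
use reqwest::header::{HeaderMap, HeaderValue};

fn rate_limit_headers_demo() {
    let mut headers = HeaderMap::new();
    headers.insert(
        "x-codex-primary-used-percent",
        HeaderValue::from_static("37.5"),
    );
    headers.insert(
        "x-codex-primary-window-minutes",
        HeaderValue::from_static("60"),
    );
    let snapshot = parse_rate_limit_snapshot(&headers)
        .expect("a non-zero primary window should yield a snapshot");
    assert_eq!(snapshot.primary.map(|w| w.used_percent), Some(37.5));
    // No secondary headers were set, so that window stays None.
    assert!(snapshot.secondary.is_none());
}
```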
430  codex-rs/api-client/src/client/sse.rs  Normal file

@@ -0,0 +1,430 @@
use std::time::Duration;
use std::time::Instant;

use bytes::Bytes;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_provider_config::ModelProviderInfo;
use futures::Stream;
use futures::StreamExt;
use tokio::sync::mpsc;
use tokio::time::timeout;

use crate::error::Error;
use crate::error::Result;
// Legacy ResponseEvent-based SSE framer removed
use crate::stream::WireEvent;

// Legacy ResponseEvent-based SSE framer removed

struct SseProcessor<S, D> {
    stream: S,
    decoder: D,
    tx_event: mpsc::Sender<Result<WireEvent>>,
    otel_event_manager: OtelEventManager,
    buffer: String,
    max_idle_duration: Duration,
}

impl<S, D> SseProcessor<S, D>
where
    S: Stream<Item = Result<Bytes>> + Send + 'static + Unpin,
    D: crate::client::WireResponseDecoder + Send,
{
    async fn run(mut self) {
        loop {
            let start = Instant::now();
            let result = timeout(self.max_idle_duration, self.stream.next()).await;
            let duration = start.elapsed();
            match result {
                Err(_) => {
                    self.send_error(
                        None,
                        duration,
                        "idle timeout waiting for SSE",
                        Error::Stream(
                            "stream idle timeout fired before Completed event".to_string(),
                            None,
                        ),
                    )
                    .await;
                    return;
                }
                Ok(Some(Err(err))) => {
                    let message = format!("{err}");
                    self.send_error(None, duration, &message, err).await;
                    return;
                }
                Ok(Some(Ok(chunk))) => {
                    if !self.process_chunk(chunk, duration).await {
                        return;
                    }
                }
                Ok(None) => {
                    if !self.drain_buffer(duration).await {
                        return;
                    }
                    return;
                }
            }
        }
    }

    async fn process_chunk(&mut self, chunk: Bytes, duration: Duration) -> bool {
        let chunk_str = match std::str::from_utf8(&chunk) {
            Ok(s) => s,
            Err(err) => {
                self.send_error(
                    None,
                    duration,
                    &format!("UTF8 error: {err}"),
                    Error::Other(format!("Invalid UTF-8 in SSE chunk: {err}")),
                )
                .await;
                return false;
            }
        }
        .replace("\r\n", "\n")
        .replace('\r', "\n");

        self.buffer.push_str(&chunk_str);
        while let Some(frame) = next_frame(&mut self.buffer) {
            if !self.handle_frame(frame, duration).await {
                return false;
            }
        }

        true
    }

    async fn drain_buffer(&mut self, duration: Duration) -> bool {
        while let Some(frame) = next_frame(&mut self.buffer) {
            if !self.handle_frame(frame, duration).await {
                return false;
            }
        }

        if self.buffer.is_empty() {
            return true;
        }

        let remainder = std::mem::take(&mut self.buffer);
        self.handle_frame(remainder, duration).await
    }

    async fn handle_frame(&mut self, frame: String, duration: Duration) -> bool {
        if let Some(frame) = parse_sse_frame(&frame) {
            if frame.data.trim() == "[DONE]" {
                self.otel_event_manager.sse_event_kind(&frame.event);
                return true;
            }

            match self
                .decoder
                .on_frame(&frame.data, &self.tx_event, &self.otel_event_manager)
                .await
            {
                Ok(_) => {
                    self.otel_event_manager.sse_event_kind(&frame.event);
                }
                Err(e) => {
                    let reason = format!("{e}");
                    self.send_error(Some(frame.event.clone()), duration, &reason, e)
                        .await;
                    return false;
                }
            };
        }

        true
    }

    async fn send_error(
        &mut self,
        event: Option<String>,
        duration: Duration,
        log_reason: impl std::fmt::Display,
        error: Error,
    ) {
        self.otel_event_manager
            .sse_event_failed(event.as_ref(), duration, &log_reason);
        let _ = self.tx_event.send(Err(error)).await;
    }
}

/// Spawn an SSE processing task and return a sender/stream pair for wire events.
pub fn spawn_wire_stream<S, D>(
    stream: S,
    provider: &ModelProviderInfo,
    otel_event_manager: OtelEventManager,
    decoder: D,
) -> (
    mpsc::Sender<Result<WireEvent>>,
    crate::stream::WireResponseStream,
)
where
    S: Stream<Item = Result<Bytes>> + Send + 'static + Unpin,
    D: crate::client::WireResponseDecoder + Send + 'static,
{
    let (tx_event, rx_event) = mpsc::channel::<Result<WireEvent>>(1600);
    let idle_timeout = provider.stream_idle_timeout();
    let otel = otel_event_manager;
    let tx_for_task = tx_event.clone();

    tokio::spawn(process_sse_wire(
        stream,
        tx_for_task,
        idle_timeout,
        otel,
        decoder,
    ));

    (
        tx_event,
        crate::stream::EventStream::from_receiver(rx_event),
    )
}

/// Generic SSE framer for wire events: Byte stream -> framed JSON -> WireResponseDecoder.
#[allow(clippy::too_many_arguments)]
pub async fn process_sse_wire<S, D>(
    stream: S,
    tx_event: mpsc::Sender<Result<WireEvent>>,
    max_idle_duration: Duration,
    otel_event_manager: OtelEventManager,
    decoder: D,
) where
    S: Stream<Item = Result<Bytes>> + Send + 'static + Unpin,
    D: crate::client::WireResponseDecoder + Send,
{
    SseProcessor {
        stream,
        decoder,
        tx_event,
        otel_event_manager,
        buffer: String::new(),
        max_idle_duration,
    }
    .run()
    .await;
}

fn next_frame(buffer: &mut String) -> Option<String> {
    loop {
        let idx = buffer.find("\n\n")?;

        let frame = buffer[..idx].to_string();
        buffer.drain(..idx + 2);

        if frame.is_empty() {
            continue;
        }

        return Some(frame);
    }
}

fn parse_sse_frame(frame: &str) -> Option<SseFrame> {
    let mut data = String::new();
    let mut event: Option<String> = None;
    let mut saw_data_line = false;

    for raw_line in frame.split('\n') {
        let line = raw_line.strip_suffix('\r').unwrap_or(raw_line);
        if line.is_empty() {
            continue;
        }

        if let Some(rest) = line.strip_prefix("event:") {
            let trimmed = rest.trim_start();
            if !trimmed.is_empty() {
                event = Some(trimmed.to_string());
            }
            continue;
        }

        if let Some(rest) = line.strip_prefix("data:") {
            let content = rest.strip_prefix(' ').unwrap_or(rest);
            if saw_data_line {
                data.push('\n');
            }
            data.push_str(content);
            saw_data_line = true;
            continue;
        }

        if saw_data_line {
            data.push('\n');
            data.push_str(line.trim_start());
        }
    }

    if data.is_empty() && event.is_none() && !saw_data_line {
        return None;
    }

    Some(SseFrame {
        event: event.unwrap_or_else(|| "message".to_string()),
        data,
    })
}

struct SseFrame {
    event: String,
    data: String,
}

#[cfg(test)]
mod tests {
    use super::*;
    use codex_protocol::ConversationId;
    use futures::stream;
    use pretty_assertions::assert_eq;
    use serde_json::json;
    use std::fmt::Write as _;
    use tokio::sync::mpsc;

    #[tokio::test]
    async fn apply_patch_body_handles_coalesced_and_split_chunks() {
        let events = apply_patch_events();
        let chunk_variants = vec![
            vec![sse(events.clone())],
            vec![sse(events[..2].to_vec()), sse(events[2..].to_vec())],
        ];

        for chunks in chunk_variants {
            let events = collect_events(chunks).await;
            assert_eq!(
                events,
                vec![
                    "created",
                    "response.output_item.done",
                    "response.output_item.added",
                    "response.completed"
                ]
            );
        }
    }

    #[tokio::test]
    async fn multiple_events_in_single_chunk_emit_done() {
        let chunk = sse(vec![
            event_output_item_done("call-inline"),
            event_completed("resp-inline"),
        ]);
        let events = collect_events(vec![chunk]).await;
        assert_eq!(
            events,
            vec!["response.output_item.done", "response.completed",]
        );
    }

    async fn collect_events(chunks: Vec<String>) -> Vec<String> {
        let (tx_event, mut rx_event) = mpsc::channel::<Result<WireEvent>>(16);
        let stream = stream::iter(chunks.into_iter().map(|chunk| Ok(Bytes::from(chunk))));
        let otel_event_manager = OtelEventManager::new(
            ConversationId::new(),
            "test-model",
            "test-slug",
            None,
            None,
            None,
            false,
            "terminal".to_string(),
        );

        let handle = tokio::spawn(process_sse_wire(
            stream,
            tx_event,
            Duration::from_secs(5),
            otel_event_manager,
            crate::decode_wire::responses::WireResponsesSseDecoder,
        ));

        let mut out = Vec::new();
        while let Some(event) = rx_event.recv().await {
            let event = event.expect("event decoding should succeed");
            out.push(event_name(&event));
        }
        handle
            .await
            .expect("SSE framing task should complete without panicking");
        out
    }

    fn event_name(event: &WireEvent) -> String {
        match event {
            WireEvent::Created => "created",
            WireEvent::OutputItemDone(_) => "response.output_item.done",
            WireEvent::OutputItemAdded(_) => "response.output_item.added",
            WireEvent::Completed { .. } => "response.completed",
            WireEvent::OutputTextDelta(_) => "response.output_text.delta",
            WireEvent::ReasoningSummaryDelta(_) => "response.reasoning_summary_text.delta",
            WireEvent::ReasoningContentDelta(_) => "response.reasoning_text.delta",
            WireEvent::ReasoningSummaryPartAdded => "response.reasoning_summary_part.added",
            WireEvent::RateLimits(_) => "response.rate_limits",
        }
        .to_string()
    }

    fn apply_patch_events() -> Vec<serde_json::Value> {
        vec![
            json!({
                "type": "response.created",
                "response": { "id": "resp-apply-patch" }
            }),
            event_output_item_done("apply-patch-call"),
            json!({
                "type": "response.output_item.added",
                "item": {
                    "type": "message",
                    "role": "assistant",
                    "content": [{"type": "output_text", "text": "ok"}]
                }
            }),
            event_completed("resp-apply-patch"),
        ]
    }

    fn event_output_item_done(call_id: &str) -> serde_json::Value {
        json!({
            "type": "response.output_item.done",
            "item": {
                "type": "function_call",
                "name": "apply_patch",
                "arguments": "{\"input\":\"*** Begin Patch\\n*** End Patch\"}",
                "call_id": call_id
            }
        })
    }

    fn event_completed(id: &str) -> serde_json::Value {
        json!({
            "type": "response.completed",
            "response": {
                "id": id,
                "usage": {
                    "input_tokens": 0,
                    "input_tokens_details": null,
                    "output_tokens": 0,
                    "output_tokens_details": null,
                    "reasoning_output_tokens": 0,
                    "total_tokens": 0
                }
            }
        })
    }

    fn sse(events: Vec<serde_json::Value>) -> String {
        let mut out = String::new();
        for ev in events {
            let kind = ev.get("type").and_then(|v| v.as_str()).unwrap_or_default();
            writeln!(&mut out, "event: {kind}").unwrap();
            if !ev.as_object().map(|o| o.len() == 1).unwrap_or(false) {
                write!(&mut out, "data: {ev}\n\n").unwrap();
            } else {
                out.push('\n');
            }
        }
        out
    }
}
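The framing contract of `next_frame` is easy to see in isolation. A test-style sketch against the crate-internal helper, with a made-up payload:

```rust
fn framing_demo() {
    let mut buffer = String::from(
        "event: message\ndata: {\"type\":\"response.created\"}\n\nevent: partial",
    );
    let frame = next_frame(&mut buffer).expect("one complete frame is buffered");
    assert_eq!(frame, "event: message\ndata: {\"type\":\"response.created\"}");
    // The partial frame stays buffered until its terminating blank line arrives.
    assert_eq!(buffer, "event: partial");
}
```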
293  codex-rs/api-client/src/decode_wire/chat.rs  Normal file

@@ -0,0 +1,293 @@
use async_trait::async_trait;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ReasoningItemContent;
use codex_protocol::models::ResponseItem;
use serde::Deserialize;
use tokio::sync::mpsc;
use tracing::debug;

use crate::client::WireResponseDecoder;
use crate::error::Error;
use crate::error::Result;
use crate::stream::WireEvent;

async fn send_wire_event(tx: &mpsc::Sender<crate::error::Result<WireEvent>>, event: WireEvent) {
    let _ = tx.send(Ok(event)).await;
}

#[derive(Default)]
struct FunctionCallState {
    active: bool,
    call_id: Option<String>,
    name: Option<String>,
    arguments: String,
}

#[derive(Debug, Default, Deserialize)]
struct ChatChunk {
    #[serde(default)]
    choices: Vec<ChatChoice>,
}

#[derive(Debug, Default, Deserialize)]
struct ChatChoice {
    #[serde(default)]
    delta: Option<ChatDelta>,
    #[serde(default)]
    finish_reason: Option<String>,
}

#[derive(Debug, Default, Deserialize)]
struct ChatDelta {
    #[serde(default)]
    content: Vec<DeltaText>,
    #[serde(default)]
    reasoning_content: Vec<DeltaText>,
    #[serde(default)]
    tool_calls: Vec<ChatToolCall>,
}

#[derive(Debug, Default, Deserialize)]
struct DeltaText {
    #[serde(default)]
    text: String,
}

#[derive(Debug, Default, Deserialize)]
struct ChatToolCall {
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<ChatFunction>,
}

#[derive(Debug, Default, Deserialize)]
struct ChatFunction {
    #[serde(default)]
    name: String,
    #[serde(default)]
    arguments: String,
}

#[derive(Default)]
pub struct WireChatSseDecoder {
    fn_call_state: FunctionCallState,
    created_emitted: bool,
    assistant_started: bool,
    assistant_text: String,
    reasoning_started: bool,
    reasoning_text: String,
}

impl WireChatSseDecoder {
    pub fn new() -> Self {
        Self::default()
    }

    async fn emit_created_once(&mut self, tx: &mpsc::Sender<crate::error::Result<WireEvent>>) {
        if self.created_emitted {
            return;
        }
        send_wire_event(tx, WireEvent::Created).await;
        self.created_emitted = true;
    }

    async fn handle_content_delta(
        &mut self,
        delta: &ChatDelta,
        tx: &mpsc::Sender<crate::error::Result<WireEvent>>,
    ) {
        for piece in &delta.content {
            if !piece.text.is_empty() {
                self.push_assistant_text(&piece.text, tx).await;
            }
        }

        for entry in &delta.reasoning_content {
            if !entry.text.is_empty() {
                self.push_reasoning_text(&entry.text, tx).await;
            }
        }

        self.record_tool_calls(&delta.tool_calls);
    }

    async fn push_assistant_text(
        &mut self,
        text: &str,
        tx: &mpsc::Sender<crate::error::Result<WireEvent>>,
    ) {
        self.start_assistant(tx).await;
        self.assistant_text.push_str(text);
        send_wire_event(tx, WireEvent::OutputTextDelta(text.to_string())).await;
    }

    async fn push_reasoning_text(
        &mut self,
        text: &str,
        tx: &mpsc::Sender<crate::error::Result<WireEvent>>,
    ) {
        self.start_reasoning(tx).await;
        self.reasoning_text.push_str(text);
        send_wire_event(tx, WireEvent::ReasoningContentDelta(text.to_string())).await;
    }

    async fn start_assistant(&mut self, tx: &mpsc::Sender<crate::error::Result<WireEvent>>) {
        if self.assistant_started {
            return;
        }
        self.assistant_started = true;
        let message = ResponseItem::Message {
            id: None,
            role: "assistant".to_string(),
            content: vec![ContentItem::OutputText {
                text: String::new(),
            }],
        };
        send_wire_event(tx, WireEvent::OutputItemAdded(message)).await;
    }

    async fn start_reasoning(&mut self, tx: &mpsc::Sender<crate::error::Result<WireEvent>>) {
        if self.reasoning_started {
            return;
        }
        self.reasoning_started = true;
        let reasoning_item = ResponseItem::Reasoning {
            id: String::new(),
            summary: vec![],
            content: None,
            encrypted_content: None,
        };
        send_wire_event(tx, WireEvent::OutputItemAdded(reasoning_item)).await;
    }

    fn record_tool_calls(&mut self, tool_calls: &[ChatToolCall]) {
        for call in tool_calls {
            if let Some(id_val) = &call.id {
                self.fn_call_state.call_id = Some(id_val.clone());
            }
            if let Some(function) = &call.function {
                if !function.name.is_empty() {
                    self.fn_call_state.name = Some(function.name.clone());
                    self.fn_call_state.active = true;
                }
                if !function.arguments.is_empty() {
                    self.fn_call_state.arguments.push_str(&function.arguments);
                }
            }
        }
    }

    fn finish_function_call(&mut self) -> Option<ResponseItem> {
        if !self.fn_call_state.active {
            return None;
        }
        let function_name = self.fn_call_state.name.take().unwrap_or_default();
        let call_id = self.fn_call_state.call_id.take().unwrap_or_default();
        let arguments = std::mem::take(&mut self.fn_call_state.arguments);
        self.fn_call_state = FunctionCallState::default();

        Some(ResponseItem::FunctionCall {
            id: Some(call_id.clone()),
            name: function_name,
            arguments,
            call_id,
        })
    }

    fn finish_reasoning(&mut self) -> Option<ResponseItem> {
        if !self.reasoning_started {
            return None;
        }

        let mut content = Vec::new();
        let text = std::mem::take(&mut self.reasoning_text);
        if !text.is_empty() {
            content.push(ReasoningItemContent::ReasoningText { text });
        }
        self.reasoning_started = false;

        Some(ResponseItem::Reasoning {
            id: String::new(),
            summary: vec![],
            content: Some(content),
            encrypted_content: None,
        })
    }

    fn finish_assistant(&mut self) -> Option<ResponseItem> {
        if !self.assistant_started {
            return None;
        }
        let text = std::mem::take(&mut self.assistant_text);
        self.assistant_started = false;

        Some(ResponseItem::Message {
            id: None,
            role: "assistant".to_string(),
            content: vec![ContentItem::OutputText { text }],
        })
    }

    fn reset_reasoning_and_assistant(&mut self) {
        self.assistant_started = false;
        self.assistant_text.clear();
        self.reasoning_started = false;
        self.reasoning_text.clear();
    }
}

#[async_trait]
impl WireResponseDecoder for WireChatSseDecoder {
    async fn on_frame(
        &mut self,
        json: &str,
        tx: &mpsc::Sender<crate::error::Result<WireEvent>>,
        _otel: &OtelEventManager,
    ) -> Result<()> {
        let chunk = serde_json::from_str::<ChatChunk>(json).map_err(|err| {
            debug!("failed to parse Chat SSE JSON: {}", json);
            Error::Other(format!("failed to parse Chat SSE JSON: {err}"))
        })?;

        for choice in chunk.choices {
            self.emit_created_once(tx).await;

            if let Some(delta) = &choice.delta {
                self.handle_content_delta(delta, tx).await;
            }

            match choice.finish_reason.as_deref() {
                Some("tool_calls") => {
                    if let Some(item) = self.finish_function_call() {
                        send_wire_event(tx, WireEvent::OutputItemDone(item)).await;
                    }
                }
                Some("stop") | Some("length") => {
                    if let Some(reasoning_item) = self.finish_reasoning() {
                        send_wire_event(tx, WireEvent::OutputItemDone(reasoning_item)).await;
                    }

                    if let Some(message) = self.finish_assistant() {
                        send_wire_event(tx, WireEvent::OutputItemDone(message)).await;
                    }

                    send_wire_event(
                        tx,
                        WireEvent::Completed {
                            response_id: String::new(),
                            token_usage: None,
                        },
                    )
                    .await;

                    self.reset_reasoning_and_assistant();
                }
                _ => {}
            }
        }

        Ok(())
    }
}
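Note that `ChatDelta` deserializes `content` and `reasoning_content` as arrays of `{ "text": ... }` objects rather than plain strings. A sketch of two hypothetical chunks in the shape these serde structs accept (made-up payloads, not captured wire traffic):

```rust
fn chunk_shapes_demo() {
    // A content delta followed by a finishing chunk.
    let content_chunk = serde_json::json!({
        "choices": [{ "delta": { "content": [{ "text": "Hel" }] } }]
    });
    let finish_chunk = serde_json::json!({
        "choices": [{ "delta": {}, "finish_reason": "stop" }]
    });
    // Fed through on_frame in order, these yield: Created, OutputItemAdded(message),
    // OutputTextDelta("Hel"), then OutputItemDone(message) and Completed.
    let _ = (content_chunk, finish_chunk);
}
```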
2  codex-rs/api-client/src/decode_wire/mod.rs  Normal file

@@ -0,0 +1,2 @@
pub mod chat;
pub mod responses;
187  codex-rs/api-client/src/decode_wire/responses.rs  Normal file

@@ -0,0 +1,187 @@
use async_trait::async_trait;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::TokenUsage;
use serde::Deserialize;
use serde_json::Value;
use tokio::sync::mpsc;
use tracing::debug;

use crate::client::WireResponseDecoder;
use crate::error::Error;
use crate::error::Result;
use crate::stream::WireEvent;

#[derive(Debug, Deserialize)]
struct StreamEvent {
    #[serde(rename = "type")]
    event_type: String,
    #[serde(default)]
    response: Option<Value>,
    #[serde(default)]
    item: Option<Value>,
    #[serde(default)]
    error: Option<Value>,
    #[serde(default)]
    delta: Option<String>,
}

#[derive(Default, Deserialize)]
struct WireUsage {
    #[serde(default)]
    input_tokens: i64,
    #[serde(default)]
    cached_input_tokens: Option<i64>,
    #[serde(default)]
    output_tokens: i64,
    #[serde(default)]
    reasoning_output_tokens: Option<i64>,
    #[serde(default)]
    total_tokens: i64,
    #[serde(default)]
    input_tokens_details: Option<WireInputTokensDetails>,
    #[serde(default)]
    output_tokens_details: Option<WireOutputTokensDetails>,
}

#[derive(Default, Deserialize)]
struct WireInputTokensDetails {
    #[serde(default)]
    cached_tokens: Option<i64>,
}

#[derive(Default, Deserialize)]
struct WireOutputTokensDetails {
    #[serde(default)]
    reasoning_tokens: Option<i64>,
}

pub struct WireResponsesSseDecoder;

#[async_trait]
impl WireResponseDecoder for WireResponsesSseDecoder {
    async fn on_frame(
        &mut self,
        json: &str,
        tx: &mpsc::Sender<Result<WireEvent>>,
        otel: &OtelEventManager,
    ) -> Result<()> {
        let event = serde_json::from_str::<StreamEvent>(json).map_err(|err| {
            debug!("failed to parse Responses SSE JSON: {}", json);
            Error::Other(format!("failed to parse Responses SSE JSON: {err}"))
        })?;

        match event.event_type.as_str() {
            "response.created" => {
                let _ = tx.send(Ok(WireEvent::Created)).await;
            }
            "response.output_text.delta" => {
                if let Some(delta) = event.delta.or_else(|| {
                    event.item.and_then(|v| {
                        v.get("delta")
                            .and_then(|d| d.as_str().map(std::string::ToString::to_string))
                    })
                }) {
                    let _ = tx.send(Ok(WireEvent::OutputTextDelta(delta))).await;
                }
            }
            "response.reasoning_text.delta" => {
                if let Some(delta) = event.delta {
                    let _ = tx.send(Ok(WireEvent::ReasoningContentDelta(delta))).await;
                }
            }
            "response.reasoning_summary_text.delta" => {
                if let Some(delta) = event.delta {
                    let _ = tx.send(Ok(WireEvent::ReasoningSummaryDelta(delta))).await;
                }
            }
            "response.output_item.done" => {
                if let Some(item_val) = event.item {
                    let item = parse_response_item(item_val);
                    let _ = tx.send(Ok(WireEvent::OutputItemDone(item))).await;
                }
            }
            "response.output_item.added" => {
                if let Some(item_val) = event.item {
                    let item = parse_response_item(item_val);
                    let _ = tx.send(Ok(WireEvent::OutputItemAdded(item))).await;
                }
            }
            "response.reasoning_summary_part.added" => {
                let _ = tx.send(Ok(WireEvent::ReasoningSummaryPartAdded)).await;
            }
            "response.completed" => {
                if let Some(resp) = event.response {
                    let response_id = resp
                        .get("id")
                        .and_then(|v| v.as_str())
                        .unwrap_or_default()
                        .to_string();
                    let usage = parse_usage(&resp);
                    if let Some(u) = &usage {
                        otel.sse_event_completed(
                            u.input_tokens,
                            u.output_tokens,
                            Some(u.cached_input_tokens),
                            Some(u.reasoning_output_tokens),
                            u.total_tokens,
                        );
                    } else {
                        otel.see_event_completed_failed(&"missing token usage".to_string());
                    }
                    let _ = tx
                        .send(Ok(WireEvent::Completed {
                            response_id,
                            token_usage: usage,
                        }))
                        .await;
                }
            }
            "response.error" | "response.failed" => {
                let message = event
                    .error
                    .as_ref()
                    .and_then(|v| v.get("message"))
                    .and_then(|v| v.as_str())
                    .map(std::string::ToString::to_string)
                    .unwrap_or_else(|| "unknown error".to_string());
                let _ = tx.send(Err(Error::Stream(message, None))).await;
            }
            _ => {}
        }

        Ok(())
    }
}

fn parse_usage(resp: &Value) -> Option<TokenUsage> {
    let usage: WireUsage = serde_json::from_value(resp.get("usage")?.clone()).ok()?;
    let cached_input_tokens = usage
        .cached_input_tokens
        .or_else(|| {
            usage
                .input_tokens_details
                .and_then(|details| details.cached_tokens)
        })
        .unwrap_or(0);
    let reasoning_output_tokens = usage
        .reasoning_output_tokens
        .or_else(|| {
            usage
                .output_tokens_details
                .and_then(|details| details.reasoning_tokens)
        })
        .unwrap_or(0);

    Some(TokenUsage {
        input_tokens: usage.input_tokens,
        cached_input_tokens,
        output_tokens: usage.output_tokens,
        reasoning_output_tokens,
        total_tokens: usage.total_tokens,
    })
}

fn parse_response_item(value: Value) -> ResponseItem {
    serde_json::from_value(value).unwrap_or(ResponseItem::Other)
}
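`parse_usage` prefers a top-level `cached_input_tokens` field and falls back to `input_tokens_details.cached_tokens` (and similarly for reasoning tokens). A test-style sketch of the crate-internal function with made-up numbers:

```rust
fn usage_fallback_demo() {
    let resp = serde_json::json!({
        "usage": {
            "input_tokens": 10,
            "output_tokens": 3,
            "total_tokens": 13,
            "input_tokens_details": { "cached_tokens": 4 }
        }
    });
    let usage = parse_usage(&resp).expect("usage block is present");
    // No top-level cached_input_tokens, so the details fallback supplies 4.
    assert_eq!(usage.cached_input_tokens, 4);
}
```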
45  codex-rs/api-client/src/error.rs  Normal file

@@ -0,0 +1,45 @@
use codex_protocol::protocol::RateLimitSnapshot;
use reqwest::StatusCode;
use thiserror::Error;

pub type Result<T> = std::result::Result<T, Error>;

#[derive(Error, Debug)]
pub enum Error {
    #[error("{0}")]
    UnsupportedOperation(String),
    #[error(transparent)]
    Http(#[from] reqwest::Error),
    #[error("response stream failed: {source}")]
    ResponseStreamFailed {
        #[source]
        source: reqwest::Error,
        request_id: Option<String>,
    },
    #[error("stream error: {0}")]
    Stream(String, Option<std::time::Duration>),
    #[error("usage limit reached")]
    UsageLimitReached {
        plan_type: Option<String>,
        resets_at: Option<i64>,
        rate_limits: Option<RateLimitSnapshot>,
    },
    #[error("unexpected status {status}: {body}")]
    UnexpectedStatus { status: StatusCode, body: String },
    #[error("retry limit reached {status:?} request_id={request_id:?}")]
    RetryLimit {
        status: Option<StatusCode>,
        request_id: Option<String>,
    },
    #[error("missing env var {var}: {instructions:?}")]
    MissingEnvVar {
        var: String,
        instructions: Option<String>,
    },
    #[error("auth error: {0}")]
    Auth(String),
    #[error(transparent)]
    Json(#[from] serde_json::Error),
    #[error("{0}")]
    Other(String),
}
32  codex-rs/api-client/src/lib.rs  Normal file

@@ -0,0 +1,32 @@
pub mod auth;
pub mod chat;
mod client;
// Legacy payload decoding has been removed; wire decoding lives in decode_wire
mod decode_wire;
pub mod error;
// payload building lives in codex-core now
pub mod responses;
pub mod routed_client;
pub mod stream;

pub use crate::auth::AuthContext;
pub use crate::auth::AuthProvider;
pub use crate::chat::ChatCompletionsApiClient;
pub use crate::chat::ChatCompletionsApiClientConfig;
pub use crate::error::Error;
pub use crate::error::Result;
pub use crate::responses::ResponsesApiClient;
pub use crate::responses::ResponsesApiClientConfig;
pub use crate::routed_client::RoutedApiClient;
pub use crate::routed_client::RoutedApiClientConfig;
pub use crate::stream::EventStream;
pub use crate::stream::Reasoning;
pub use crate::stream::ResponseEvent;
pub use crate::stream::ResponseStream;
pub use crate::stream::TextControls;
pub use crate::stream::TextFormat;
pub use crate::stream::TextFormatType;
pub use crate::stream::WireEvent;
pub use crate::stream::WireResponseStream;
pub use codex_provider_config::ModelProviderInfo;
pub use codex_provider_config::WireApi;
137  codex-rs/api-client/src/responses.rs  Normal file

@@ -0,0 +1,137 @@
use std::sync::Arc;

use codex_app_server_protocol::AuthMode;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::ConversationId;
use futures::TryStreamExt;
use serde_json::Value;
use tracing::debug;
use tracing::trace;

use crate::auth::AuthProvider;
use crate::error::Error;
use crate::error::Result;
use crate::stream::WireResponseStream;
use codex_provider_config::ModelProviderInfo;

#[derive(Clone)]
/// Configuration for the OpenAI Responses API client (`/v1/responses`).
///
/// - `http_client`: Reqwest client used for HTTP requests.
/// - `provider`: Provider configuration (base URL, headers, retries, etc.).
/// - `conversation_id`: Used to set conversation/session headers and cache keys.
/// - `auth_provider`: Optional provider of auth context (e.g., ChatGPT login token).
/// - `otel_event_manager`: Telemetry event manager for request/stream instrumentation.
pub struct ResponsesApiClientConfig {
    pub http_client: reqwest::Client,
    pub provider: ModelProviderInfo,
    pub conversation_id: ConversationId,
    pub auth_provider: Option<Arc<dyn AuthProvider>>,
    pub otel_event_manager: OtelEventManager,
    pub extra_headers: Vec<(String, String)>,
}

#[derive(Clone)]
pub struct ResponsesApiClient {
    config: ResponsesApiClientConfig,
}

impl ResponsesApiClient {
    pub fn new(config: ResponsesApiClientConfig) -> Result<Self> {
        Ok(Self { config })
    }
}

impl ResponsesApiClient {
    pub async fn stream_payload_wire(&self, payload_json: &Value) -> Result<WireResponseStream> {
        if self.config.provider.wire_api != codex_provider_config::WireApi::Responses {
            return Err(Error::UnsupportedOperation(
                "ResponsesApiClient requires a Responses provider".to_string(),
            ));
        }

        let auth = crate::client::http::resolve_auth(&self.config.auth_provider).await;

        trace!(
            "POST to {}: {:?}",
            self.config.provider.get_full_url(
                auth.as_ref()
                    .map(|a| codex_provider_config::AuthContext {
                        mode: a.mode,
                        bearer_token: a.bearer_token.clone(),
                        account_id: a.account_id.clone(),
                    })
                    .as_ref()
            ),
            serde_json::to_string(payload_json)
                .unwrap_or_else(|_| "<unable to serialize payload>".to_string())
        );

        let mut owned_headers: Vec<(String, String)> = vec![
            (
                "conversation_id".to_string(),
                self.config.conversation_id.to_string(),
            ),
            (
                "session_id".to_string(),
                self.config.conversation_id.to_string(),
            ),
        ];
        owned_headers.extend(self.config.extra_headers.iter().cloned());
        let extra_headers = crate::client::http::header_pairs(&owned_headers);
        let mut req_builder = crate::client::http::build_request(
            &self.config.http_client,
            &self.config.provider,
            &auth,
            &extra_headers,
        )
        .await?;

        req_builder = req_builder
            .header(reqwest::header::ACCEPT, "text/event-stream")
            .json(payload_json);

        if let Some(auth_ctx) = auth.as_ref()
            && auth_ctx.mode == AuthMode::ChatGPT
            && let Some(account_id) = auth_ctx.account_id.clone()
        {
            req_builder = req_builder.header("chatgpt-account-id", account_id);
        }

        let res = self
            .config
            .otel_event_manager
            .log_request(0, || req_builder.send())
            .await
            .map_err(|source| Error::ResponseStreamFailed {
                source,
                request_id: None,
            })?;

        let snapshot = crate::client::rate_limits::parse_rate_limit_snapshot(res.headers());

        let stream = res
            .bytes_stream()
            .map_err(|err| Error::ResponseStreamFailed {
                source: err,
                request_id: None,
            });

        let (tx_event, rx_event) = crate::client::sse::spawn_wire_stream(
            stream,
            &self.config.provider,
            self.config.otel_event_manager.clone(),
            crate::decode_wire::responses::WireResponsesSseDecoder,
        );
        if let Some(snapshot) = snapshot
            && tx_event
                .send(Ok(crate::stream::WireEvent::RateLimits(snapshot)))
                .await
                .is_err()
        {
            debug!("receiver dropped rate limit snapshot event");
        }

        Ok(rx_event)
    }
}
84  codex-rs/api-client/src/routed_client.rs  Normal file

@@ -0,0 +1,84 @@
use std::path::PathBuf;
use std::sync::Arc;

use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::ConversationId;

use crate::ChatCompletionsApiClient;
use crate::ChatCompletionsApiClientConfig;
use crate::ResponsesApiClient;
use crate::ResponsesApiClientConfig;
use crate::Result;
use crate::WireApi;
use crate::WireResponseStream;
use crate::auth::AuthProvider;
use codex_provider_config::ModelProviderInfo;

/// Dispatches to the appropriate API client implementation based on the provider wire API.
#[derive(Clone)]
pub struct RoutedApiClientConfig {
    pub http_client: reqwest::Client,
    pub provider: ModelProviderInfo,
    pub conversation_id: ConversationId,
    pub auth_provider: Option<Arc<dyn AuthProvider>>,
    pub otel_event_manager: OtelEventManager,
    pub responses_fixture_path: Option<PathBuf>,
    pub extra_headers: Vec<(String, String)>,
}

#[derive(Clone)]
pub struct RoutedApiClient {
    config: RoutedApiClientConfig,
}

impl RoutedApiClient {
    pub fn new(config: RoutedApiClientConfig) -> Self {
        Self { config }
    }

    pub async fn stream_payload_wire(
        &self,
        payload_json: &serde_json::Value,
    ) -> Result<WireResponseStream> {
        match self.config.provider.wire_api {
            WireApi::Responses => {
                if let Some(path) = &self.config.responses_fixture_path {
                    return crate::client::fixtures::stream_from_fixture_wire(
                        path,
                        self.config.provider.clone(),
                        self.config.otel_event_manager.clone(),
                    )
                    .await;
                }
                let client = ResponsesApiClient::new(self.responses_config())?;
                client.stream_payload_wire(payload_json).await
            }
            WireApi::Chat => {
                let client = ChatCompletionsApiClient::new(self.chat_config())?;
                client.stream_payload_wire(payload_json).await
            }
        }
    }
}

impl RoutedApiClient {
    fn responses_config(&self) -> ResponsesApiClientConfig {
        ResponsesApiClientConfig {
            http_client: self.config.http_client.clone(),
            provider: self.config.provider.clone(),
            conversation_id: self.config.conversation_id,
            auth_provider: self.config.auth_provider.clone(),
            otel_event_manager: self.config.otel_event_manager.clone(),
            extra_headers: self.config.extra_headers.clone(),
        }
    }

    fn chat_config(&self) -> ChatCompletionsApiClientConfig {
        ChatCompletionsApiClientConfig {
            http_client: self.config.http_client.clone(),
            provider: self.config.provider.clone(),
            otel_event_manager: self.config.otel_event_manager.clone(),
            extra_headers: self.config.extra_headers.clone(),
        }
    }
}
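The routing is the whole public surface here. A sketch of a caller, assuming a `RoutedApiClientConfig` has already been assembled elsewhere (the `stream_once` wrapper is hypothetical):

```rust
use codex_api_client::{Result, RoutedApiClient, RoutedApiClientConfig, WireResponseStream};

/// Hypothetical wrapper: one call, routed by provider.wire_api.
async fn stream_once(
    config: RoutedApiClientConfig,
    payload: &serde_json::Value,
) -> Result<WireResponseStream> {
    // Responses providers may be replayed from a fixture file when
    // `responses_fixture_path` is set; Chat providers always go over HTTP.
    RoutedApiClient::new(config).stream_payload_wire(payload).await
}
```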
86  codex-rs/api-client/src/stream.rs  Normal file

@@ -0,0 +1,86 @@
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;

use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::RateLimitSnapshot;
use codex_protocol::protocol::TokenUsage;
use futures::Stream;
use serde::Serialize;
use serde_json::Value;
use tokio::sync::mpsc;

use crate::error::Result;

#[derive(Debug, Serialize, Clone)]
pub struct Reasoning {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub effort: Option<ReasoningEffortConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub summary: Option<ReasoningSummaryConfig>,
}

#[derive(Debug, Serialize, Default, Clone)]
#[serde(rename_all = "snake_case")]
pub enum TextFormatType {
    #[default]
    JsonSchema,
}

#[derive(Debug, Serialize, Default, Clone)]
pub struct TextFormat {
    pub r#type: TextFormatType,
    pub strict: bool,
    pub schema: Value,
    pub name: String,
}

#[derive(Debug, Serialize, Default, Clone)]
pub struct TextControls {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub verbosity: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub format: Option<TextFormat>,
}

#[derive(Debug)]
pub enum ResponseEvent {
    Created,
    OutputItemDone(ResponseItem),
    OutputItemAdded(ResponseItem),
    Completed {
        response_id: String,
        token_usage: Option<TokenUsage>,
    },
    OutputTextDelta(String),
    ReasoningSummaryDelta(String),
    ReasoningContentDelta(String),
    ReasoningSummaryPartAdded,
    RateLimits(RateLimitSnapshot),
}

#[derive(Debug)]
pub struct EventStream<T> {
    pub(crate) rx_event: mpsc::Receiver<T>,
}

impl<T> EventStream<T> {
    pub fn from_receiver(rx_event: mpsc::Receiver<T>) -> Self {
        Self { rx_event }
    }
}

impl<T> Stream for EventStream<T> {
    type Item = T;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.rx_event.poll_recv(cx)
    }
}

pub type ResponseStream = EventStream<Result<ResponseEvent>>;

pub type WireEvent = ResponseEvent;
pub type WireResponseStream = ResponseStream;
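`EventStream` is just a `Stream` adapter over a tokio mpsc receiver, which is what lets the spawned SSE task and its consumer decouple. A small sketch of that contract (hypothetical `demo` function, using a plain `u32` channel since the type is generic):

```rust
use futures::StreamExt;
use tokio::sync::mpsc;

use codex_api_client::EventStream;

async fn demo() {
    let (tx, rx) = mpsc::channel::<u32>(4);
    tx.send(7).await.unwrap();
    drop(tx); // closing the last sender terminates the stream
    let mut events = EventStream::from_receiver(rx);
    assert_eq!(events.next().await, Some(7));
    assert_eq!(events.next().await, None);
}
```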
codex-rs/core/Cargo.toml

@@ -22,6 +22,8 @@ chrono = { workspace = true, features = ["serde"] }
 codex-app-server-protocol = { workspace = true }
 codex-apply-patch = { workspace = true }
 codex-async-utils = { workspace = true }
+codex-api-client = { workspace = true }
+codex-provider-config = { workspace = true }
 codex-file-search = { workspace = true }
 codex-git = { workspace = true }
 codex-keyring-store = { workspace = true }
codex-rs/core/src/aggregate.rs (new file, 167 lines)
@@ -0,0 +1,167 @@
use std::collections::VecDeque;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;

use crate::ContentItem;
use crate::ResponseItem;
use futures::Stream;

use crate::ResponseEvent;
use crate::error::Result;

pub trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized {
    fn aggregate(self) -> AggregatedChatStream<Self>
    where
        Self: Unpin,
    {
        AggregatedChatStream::new(self, AggregateMode::AggregatedOnly)
    }

    fn streaming_mode(self) -> AggregatedChatStream<Self>
    where
        Self: Unpin,
    {
        AggregatedChatStream::new(self, AggregateMode::Streaming)
    }
}

impl<S> AggregateStreamExt for S where S: Stream<Item = Result<ResponseEvent>> + Sized + Unpin {}

enum AggregateMode {
    AggregatedOnly,
    Streaming,
}

pub struct AggregatedChatStream<S> {
    inner: S,
    cumulative: String,
    cumulative_reasoning: String,
    pending: VecDeque<ResponseEvent>,
    mode: AggregateMode,
}

impl<S> AggregatedChatStream<S>
where
    S: Stream<Item = Result<ResponseEvent>> + Unpin,
{
    fn new(inner: S, mode: AggregateMode) -> Self {
        Self {
            inner,
            cumulative: String::new(),
            cumulative_reasoning: String::new(),
            pending: VecDeque::new(),
            mode,
        }
    }
}

impl<S> Stream for AggregatedChatStream<S>
where
    S: Stream<Item = Result<ResponseEvent>> + Unpin,
{
    type Item = Result<ResponseEvent>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if let Some(ev) = self.pending.pop_front() {
            return Poll::Ready(Some(Ok(ev)));
        }

        loop {
            match Pin::new(&mut self.inner).poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(Err(err))) => {
                    return Poll::Ready(Some(Err(err)));
                }
                Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
                    let is_assistant_message = matches!(
                        &item,
                        ResponseItem::Message { role, .. } if role == "assistant"
                    );

                    if is_assistant_message {
                        if let ResponseItem::Message { role, content, .. } = item {
                            let mut text = String::new();
                            for c in content {
                                match c {
                                    ContentItem::InputText { text: t }
                                    | ContentItem::OutputText { text: t } => text.push_str(&t),
                                    ContentItem::InputImage { image_url } => {
                                        text.push_str(&image_url)
                                    }
                                }
                            }
                            self.cumulative.push_str(&text);
                            if matches!(self.mode, AggregateMode::Streaming) {
                                let output_item =
                                    ResponseEvent::OutputItemDone(ResponseItem::Message {
                                        id: None,
                                        role,
                                        content: vec![ContentItem::OutputText {
                                            text: self.cumulative.clone(),
                                        }],
                                    });
                                self.pending.push_back(output_item);
                            }
                        }
                    } else {
                        return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
                    }
                }
                Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => {
                    if !matches!(
                        &item,
                        ResponseItem::Message { role, .. } if role == "assistant"
                    ) {
                        return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item))));
                    }
                }
                Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta)))) => {
                    self.cumulative_reasoning.push_str(&delta);
                    if matches!(self.mode, AggregateMode::Streaming) {
                        let ev =
                            ResponseEvent::ReasoningContentDelta(self.cumulative_reasoning.clone());
                        self.pending.push_back(ev);
                    }
                }
                Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta(delta)))) => {
                    if matches!(self.mode, AggregateMode::Streaming) {
                        let ev = ResponseEvent::ReasoningSummaryDelta(delta);
                        self.pending.push_back(ev);
                    }
                }
                Poll::Ready(Some(Ok(ResponseEvent::Completed {
                    response_id,
                    token_usage,
                }))) => {
                    let assistant_event = ResponseEvent::OutputItemDone(ResponseItem::Message {
                        id: None,
                        role: "assistant".to_string(),
                        content: vec![ContentItem::OutputText {
                            text: self.cumulative.clone(),
                        }],
                    });
                    let completion_event = ResponseEvent::Completed {
                        response_id,
                        token_usage,
                    };

                    if matches!(self.mode, AggregateMode::Streaming) {
                        self.pending.push_back(assistant_event);
                        self.pending.push_back(completion_event);
                    } else {
                        return Poll::Ready(Some(Ok(assistant_event)));
                    }
                }
                Poll::Ready(Some(Ok(ev))) => {
                    return Poll::Ready(Some(Ok(ev)));
                }
            }

            if let Some(ev) = self.pending.pop_front() {
                return Poll::Ready(Some(Ok(ev)));
            }
        }
    }
}
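Both entry points wrap the same adapter: aggregate() stays silent until Completed and then emits one fully concatenated assistant message, while streaming_mode() forwards deltas as they arrive and queues cumulative snapshots. A hedged usage sketch (the events vector stands in for a real model stream; Result is this crate's error::Result):

    use futures::StreamExt;
    use futures::stream;

    async fn demo(events: Vec<crate::error::Result<ResponseEvent>>) {
        // stream::iter(..) is Unpin, which satisfies the trait bound above.
        let mut agg = stream::iter(events).aggregate();
        while let Some(ev) = agg.next().await {
            match ev {
                // One fully concatenated assistant message per turn.
                Ok(ResponseEvent::OutputItemDone(_item)) => {}
                Ok(ResponseEvent::Completed { .. }) => break,
                // Other events (rate limits, tool calls) pass through unchanged.
                _ => {}
            }
        }
    }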
@@ -219,12 +219,10 @@ impl CodexAuth
             InternalKnownPlan::Edu => AccountPlanType::Edu,
         };

-        self.get_current_token_data()
-            .and_then(|t| t.id_token.chatgpt_plan_type)
-            .map(|pt| match pt {
-                InternalPlanType::Known(k) => map_known(&k),
-                InternalPlanType::Unknown(_) => AccountPlanType::Unknown,
-            })
+        self.get_plan_type().map(|pt| match pt {
+            InternalPlanType::Known(k) => map_known(&k),
+            InternalPlanType::Unknown(_) => AccountPlanType::Unknown,
+        })
     }

     /// Raw internal plan value from the ID token.

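This folds the token-data plumbing into a single helper so both call sites share it. get_plan_type itself is outside this hunk; judging from the removed lines it is presumably equivalent to:

    // Hedged reconstruction; the actual helper is not shown in this diff.
    fn get_plan_type(&self) -> Option<InternalPlanType> {
        self.get_current_token_data()
            .and_then(|t| t.id_token.chatgpt_plan_type)
    }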
@@ -1,967 +0,0 @@
use std::time::Duration;

use crate::ModelProviderInfo;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::client_common::ResponseStream;
use crate::default_client::CodexHttpClient;
use crate::error::CodexErr;
use crate::error::ConnectionFailedError;
use crate::error::ResponseStreamFailed;
use crate::error::Result;
use crate::error::RetryLimitReachedError;
use crate::error::UnexpectedResponseError;
use crate::model_family::ModelFamily;
use crate::tools::spec::create_tools_json_for_chat_completions_api;
use crate::util::backoff;
use bytes::Bytes;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::models::ContentItem;
use codex_protocol::models::FunctionCallOutputContentItem;
use codex_protocol::models::ReasoningItemContent;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use eventsource_stream::Eventsource;
use futures::Stream;
use futures::StreamExt;
use futures::TryStreamExt;
use reqwest::StatusCode;
use serde_json::json;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use tokio::sync::mpsc;
use tokio::time::timeout;
use tracing::debug;
use tracing::trace;

/// Implementation for the classic Chat Completions API.
pub(crate) async fn stream_chat_completions(
    prompt: &Prompt,
    model_family: &ModelFamily,
    client: &CodexHttpClient,
    provider: &ModelProviderInfo,
    otel_event_manager: &OtelEventManager,
    session_source: &SessionSource,
) -> Result<ResponseStream> {
    if prompt.output_schema.is_some() {
        return Err(CodexErr::UnsupportedOperation(
            "output_schema is not supported for Chat Completions API".to_string(),
        ));
    }

    // Build messages array
    let mut messages = Vec::<serde_json::Value>::new();

    let full_instructions = prompt.get_full_instructions(model_family);
    messages.push(json!({"role": "system", "content": full_instructions}));

    let input = prompt.get_formatted_input();

    // Pre-scan: map Reasoning blocks to the adjacent assistant anchor after the last user.
    // - If the last emitted message is a user message, drop all reasoning.
    // - Otherwise, for each Reasoning item after the last user message, attach it
    //   to the immediate previous assistant message (stop turns) or the immediate
    //   next assistant anchor (tool-call turns: function/local shell call, or assistant message).
    let mut reasoning_by_anchor_index: std::collections::HashMap<usize, String> =
        std::collections::HashMap::new();

    // Determine the last role that would be emitted to Chat Completions.
    let mut last_emitted_role: Option<&str> = None;
    for item in &input {
        match item {
            ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()),
            ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
                last_emitted_role = Some("assistant")
            }
            ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"),
            ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
            ResponseItem::CustomToolCall { .. } => {}
            ResponseItem::CustomToolCallOutput { .. } => {}
            ResponseItem::WebSearchCall { .. } => {}
            ResponseItem::GhostSnapshot { .. } => {}
        }
    }

    // Find the last user message index in the input.
    let mut last_user_index: Option<usize> = None;
    for (idx, item) in input.iter().enumerate() {
        if let ResponseItem::Message { role, .. } = item
            && role == "user"
        {
            last_user_index = Some(idx);
        }
    }

    // Attach reasoning only if the conversation does not end with a user message.
    if !matches!(last_emitted_role, Some("user")) {
        for (idx, item) in input.iter().enumerate() {
            // Only consider reasoning that appears after the last user message.
            if let Some(u_idx) = last_user_index
                && idx <= u_idx
            {
                continue;
            }

            if let ResponseItem::Reasoning {
                content: Some(items),
                ..
            } = item
            {
                let mut text = String::new();
                for entry in items {
                    match entry {
                        ReasoningItemContent::ReasoningText { text: segment }
                        | ReasoningItemContent::Text { text: segment } => text.push_str(segment),
                    }
                }
                if text.trim().is_empty() {
                    continue;
                }

                // Prefer immediate previous assistant message (stop turns)
                let mut attached = false;
                if idx > 0
                    && let ResponseItem::Message { role, .. } = &input[idx - 1]
                    && role == "assistant"
                {
                    reasoning_by_anchor_index
                        .entry(idx - 1)
                        .and_modify(|v| v.push_str(&text))
                        .or_insert(text.clone());
                    attached = true;
                }

                // Otherwise, attach to immediate next assistant anchor (tool-calls or assistant message)
                if !attached && idx + 1 < input.len() {
                    match &input[idx + 1] {
                        ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
                            reasoning_by_anchor_index
                                .entry(idx + 1)
                                .and_modify(|v| v.push_str(&text))
                                .or_insert(text.clone());
                        }
                        ResponseItem::Message { role, .. } if role == "assistant" => {
                            reasoning_by_anchor_index
                                .entry(idx + 1)
                                .and_modify(|v| v.push_str(&text))
                                .or_insert(text.clone());
                        }
                        _ => {}
                    }
                }
            }
        }
    }

    // Track last assistant text we emitted to avoid duplicate assistant messages
    // in the outbound Chat Completions payload (can happen if a final
    // aggregated assistant message was recorded alongside an earlier partial).
    let mut last_assistant_text: Option<String> = None;

    for (idx, item) in input.iter().enumerate() {
        match item {
            ResponseItem::Message { role, content, .. } => {
                // Build content either as a plain string (typical for assistant text)
                // or as an array of content items when images are present (user/tool multimodal).
                let mut text = String::new();
                let mut items: Vec<serde_json::Value> = Vec::new();
                let mut saw_image = false;

                for c in content {
                    match c {
                        ContentItem::InputText { text: t }
                        | ContentItem::OutputText { text: t } => {
                            text.push_str(t);
                            items.push(json!({"type":"text","text": t}));
                        }
                        ContentItem::InputImage { image_url } => {
                            saw_image = true;
                            items.push(json!({"type":"image_url","image_url": {"url": image_url}}));
                        }
                    }
                }

                // Skip exact-duplicate assistant messages.
                if role == "assistant" {
                    if let Some(prev) = &last_assistant_text
                        && prev == &text
                    {
                        continue;
                    }
                    last_assistant_text = Some(text.clone());
                }

                // For assistant messages, always send a plain string for compatibility.
                // For user messages, if an image is present, send an array of content items.
                let content_value = if role == "assistant" {
                    json!(text)
                } else if saw_image {
                    json!(items)
                } else {
                    json!(text)
                };

                let mut msg = json!({"role": role, "content": content_value});
                if role == "assistant"
                    && let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
                    && let Some(obj) = msg.as_object_mut()
                {
                    obj.insert("reasoning".to_string(), json!(reasoning));
                }
                messages.push(msg);
            }
            ResponseItem::FunctionCall {
                name,
                arguments,
                call_id,
                ..
            } => {
                let mut msg = json!({
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
                        "id": call_id,
                        "type": "function",
                        "function": {
                            "name": name,
                            "arguments": arguments,
                        }
                    }]
                });
                if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
                    && let Some(obj) = msg.as_object_mut()
                {
                    obj.insert("reasoning".to_string(), json!(reasoning));
                }
                messages.push(msg);
            }
            ResponseItem::LocalShellCall {
                id,
                call_id: _,
                status,
                action,
            } => {
                // Confirm with API team.
                let mut msg = json!({
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
                        "id": id.clone().unwrap_or_else(|| "".to_string()),
                        "type": "local_shell_call",
                        "status": status,
                        "action": action,
                    }]
                });
                if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
                    && let Some(obj) = msg.as_object_mut()
                {
                    obj.insert("reasoning".to_string(), json!(reasoning));
                }
                messages.push(msg);
            }
            ResponseItem::FunctionCallOutput { call_id, output } => {
                // Prefer structured content items when available (e.g., images)
                // otherwise fall back to the legacy plain-string content.
                let content_value = if let Some(items) = &output.content_items {
                    let mapped: Vec<serde_json::Value> = items
                        .iter()
                        .map(|it| match it {
                            FunctionCallOutputContentItem::InputText { text } => {
                                json!({"type":"text","text": text})
                            }
                            FunctionCallOutputContentItem::InputImage { image_url } => {
                                json!({"type":"image_url","image_url": {"url": image_url}})
                            }
                        })
                        .collect();
                    json!(mapped)
                } else {
                    json!(output.content)
                };

                messages.push(json!({
                    "role": "tool",
                    "tool_call_id": call_id,
                    "content": content_value,
                }));
            }
            ResponseItem::CustomToolCall {
                id,
                call_id: _,
                name,
                input,
                status: _,
            } => {
                messages.push(json!({
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
                        "id": id,
                        "type": "custom",
                        "custom": {
                            "name": name,
                            "input": input,
                        }
                    }]
                }));
            }
            ResponseItem::CustomToolCallOutput { call_id, output } => {
                messages.push(json!({
                    "role": "tool",
                    "tool_call_id": call_id,
                    "content": output,
                }));
            }
            ResponseItem::GhostSnapshot { .. } => {
                // Ghost snapshots annotate history but are not sent to the model.
                continue;
            }
            ResponseItem::Reasoning { .. }
            | ResponseItem::WebSearchCall { .. }
            | ResponseItem::Other => {
                // Omit these items from the conversation history.
                continue;
            }
        }
    }

    let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
    let payload = json!({
        "model": model_family.slug,
        "messages": messages,
        "stream": true,
        "tools": tools_json,
    });

    debug!(
        "POST to {}: {}",
        provider.get_full_url(&None),
        serde_json::to_string_pretty(&payload).unwrap_or_default()
    );

    let mut attempt = 0;
    let max_retries = provider.request_max_retries();
    loop {
        attempt += 1;

        let mut req_builder = provider.create_request_builder(client, &None).await?;

        // Include subagent header only for subagent sessions.
        if let SessionSource::SubAgent(sub) = session_source.clone() {
            let subagent = if let SubAgentSource::Other(label) = sub {
                label
            } else {
                serde_json::to_value(&sub)
                    .ok()
                    .and_then(|v| v.as_str().map(std::string::ToString::to_string))
                    .unwrap_or_else(|| "other".to_string())
            };
            req_builder = req_builder.header("x-openai-subagent", subagent);
        }

        let res = otel_event_manager
            .log_request(attempt, || {
                req_builder
                    .header(reqwest::header::ACCEPT, "text/event-stream")
                    .json(&payload)
                    .send()
            })
            .await;

        match res {
            Ok(resp) if resp.status().is_success() => {
                let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
                let stream = resp.bytes_stream().map_err(|e| {
                    CodexErr::ResponseStreamFailed(ResponseStreamFailed {
                        source: e,
                        request_id: None,
                    })
                });
                tokio::spawn(process_chat_sse(
                    stream,
                    tx_event,
                    provider.stream_idle_timeout(),
                    otel_event_manager.clone(),
                ));
                return Ok(ResponseStream { rx_event });
            }
            Ok(res) => {
                let status = res.status();
                if !(status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()) {
                    let body = (res.text().await).unwrap_or_default();
                    return Err(CodexErr::UnexpectedStatus(UnexpectedResponseError {
                        status,
                        body,
                        request_id: None,
                    }));
                }

                if attempt > max_retries {
                    return Err(CodexErr::RetryLimit(RetryLimitReachedError {
                        status,
                        request_id: None,
                    }));
                }

                let retry_after_secs = res
                    .headers()
                    .get(reqwest::header::RETRY_AFTER)
                    .and_then(|v| v.to_str().ok())
                    .and_then(|s| s.parse::<u64>().ok());

                let delay = retry_after_secs
                    .map(|s| Duration::from_millis(s * 1_000))
                    .unwrap_or_else(|| backoff(attempt));
                tokio::time::sleep(delay).await;
            }
            Err(e) => {
                if attempt > max_retries {
                    return Err(CodexErr::ConnectionFailed(ConnectionFailedError {
                        source: e,
                    }));
                }
                let delay = backoff(attempt);
                tokio::time::sleep(delay).await;
            }
        }
    }
}

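The retry policy above is: non-retryable statuses fail immediately, 429 and 5xx retry up to request_max_retries, and each delay honors a Retry-After header (whole seconds) before falling back to exponential backoff. A standalone sketch of just the schedule (the backoff shape is illustrative; the real helper lives in crate::util):

    use std::time::Duration;

    // Illustrative stand-in for crate::util::backoff.
    fn backoff(attempt: u64) -> Duration {
        let ms = 200u64.saturating_mul(1u64 << (attempt.saturating_sub(1)).min(16));
        Duration::from_millis(ms.min(10_000))
    }

    fn next_delay(attempt: u64, retry_after_secs: Option<u64>) -> Duration {
        retry_after_secs
            .map(|s| Duration::from_millis(s * 1_000))
            .unwrap_or_else(|| backoff(attempt))
    }

    fn main() {
        assert_eq!(next_delay(3, Some(2)), Duration::from_millis(2_000));
        assert_eq!(next_delay(3, None), backoff(3));
    }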
async fn append_assistant_text(
    tx_event: &mpsc::Sender<Result<ResponseEvent>>,
    assistant_item: &mut Option<ResponseItem>,
    text: String,
) {
    if assistant_item.is_none() {
        let item = ResponseItem::Message {
            id: None,
            role: "assistant".to_string(),
            content: vec![],
        };
        *assistant_item = Some(item.clone());
        let _ = tx_event
            .send(Ok(ResponseEvent::OutputItemAdded(item)))
            .await;
    }

    if let Some(ResponseItem::Message { content, .. }) = assistant_item {
        content.push(ContentItem::OutputText { text: text.clone() });
        let _ = tx_event
            .send(Ok(ResponseEvent::OutputTextDelta(text.clone())))
            .await;
    }
}

async fn append_reasoning_text(
    tx_event: &mpsc::Sender<Result<ResponseEvent>>,
    reasoning_item: &mut Option<ResponseItem>,
    text: String,
) {
    if reasoning_item.is_none() {
        let item = ResponseItem::Reasoning {
            id: String::new(),
            summary: Vec::new(),
            content: Some(vec![]),
            encrypted_content: None,
        };
        *reasoning_item = Some(item.clone());
        let _ = tx_event
            .send(Ok(ResponseEvent::OutputItemAdded(item)))
            .await;
    }

    if let Some(ResponseItem::Reasoning {
        content: Some(content),
        ..
    }) = reasoning_item
    {
        content.push(ReasoningItemContent::ReasoningText { text: text.clone() });

        let _ = tx_event
            .send(Ok(ResponseEvent::ReasoningContentDelta(text.clone())))
            .await;
    }
}
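Both helpers follow the same two-step shape: on the first fragment they allocate an empty item and announce it with OutputItemAdded, then every fragment (including the first) is appended and surfaced as a delta. An illustrative driver showing the resulting event order (sketch only; the helpers are private to this module):

    // Illustrative only.
    async fn trace_order(tx: mpsc::Sender<Result<ResponseEvent>>) {
        let mut assistant_item = None;
        append_assistant_text(&tx, &mut assistant_item, "Hel".to_string()).await;
        append_assistant_text(&tx, &mut assistant_item, "lo".to_string()).await;
        // Receiver sees: OutputItemAdded(empty assistant Message),
        // OutputTextDelta("Hel"), OutputTextDelta("lo"). The finalized
        // Message is emitted later as OutputItemDone on finish_reason/[DONE].
    }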
/// Lightweight SSE processor for the Chat Completions streaming format. The
/// output is mapped onto Codex's internal [`ResponseEvent`] so that the rest
/// of the pipeline can stay agnostic of the underlying wire format.
async fn process_chat_sse<S>(
    stream: S,
    tx_event: mpsc::Sender<Result<ResponseEvent>>,
    idle_timeout: Duration,
    otel_event_manager: OtelEventManager,
) where
    S: Stream<Item = Result<Bytes>> + Unpin,
{
    let mut stream = stream.eventsource();

    // State to accumulate a function call across streaming chunks.
    // OpenAI may split the `arguments` string over multiple `delta` events
    // until the chunk whose `finish_reason` is `tool_calls` is emitted. We
    // keep collecting the pieces here and forward a single
    // `ResponseItem::FunctionCall` once the call is complete.
    #[derive(Default)]
    struct FunctionCallState {
        name: Option<String>,
        arguments: String,
        call_id: Option<String>,
        active: bool,
    }

    let mut fn_call_state = FunctionCallState::default();
    let mut assistant_item: Option<ResponseItem> = None;
    let mut reasoning_item: Option<ResponseItem> = None;

    loop {
        let start = std::time::Instant::now();
        let response = timeout(idle_timeout, stream.next()).await;
        let duration = start.elapsed();
        otel_event_manager.log_sse_event(&response, duration);

        let sse = match response {
            Ok(Some(Ok(ev))) => ev,
            Ok(Some(Err(e))) => {
                let _ = tx_event
                    .send(Err(CodexErr::Stream(e.to_string(), None)))
                    .await;
                return;
            }
            Ok(None) => {
                // Stream closed gracefully – emit Completed with dummy id.
                let _ = tx_event
                    .send(Ok(ResponseEvent::Completed {
                        response_id: String::new(),
                        token_usage: None,
                    }))
                    .await;
                return;
            }
            Err(_) => {
                let _ = tx_event
                    .send(Err(CodexErr::Stream(
                        "idle timeout waiting for SSE".into(),
                        None,
                    )))
                    .await;
                return;
            }
        };

        // OpenAI Chat streaming sends a literal string "[DONE]" when finished.
        if sse.data.trim() == "[DONE]" {
            // Emit any finalized items before closing so downstream consumers receive
            // terminal events for both assistant content and raw reasoning.
            if let Some(item) = assistant_item {
                let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
            }

            if let Some(item) = reasoning_item {
                let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
            }

            let _ = tx_event
                .send(Ok(ResponseEvent::Completed {
                    response_id: String::new(),
                    token_usage: None,
                }))
                .await;
            return;
        }

        // Parse JSON chunk
        let chunk: serde_json::Value = match serde_json::from_str(&sse.data) {
            Ok(v) => v,
            Err(_) => continue,
        };
        trace!("chat_completions received SSE chunk: {chunk:?}");

        let choice_opt = chunk.get("choices").and_then(|c| c.get(0));

        if let Some(choice) = choice_opt {
            // Handle assistant content tokens as streaming deltas.
            if let Some(content) = choice
                .get("delta")
                .and_then(|d| d.get("content"))
                .and_then(|c| c.as_str())
                && !content.is_empty()
            {
                append_assistant_text(&tx_event, &mut assistant_item, content.to_string()).await;
            }

            // Forward any reasoning/thinking deltas if present.
            // Some providers stream `reasoning` as a plain string while others
            // nest the text under an object (e.g. `{ "reasoning": { "text": "…" } }`).
            if let Some(reasoning_val) = choice.get("delta").and_then(|d| d.get("reasoning")) {
                let mut maybe_text = reasoning_val
                    .as_str()
                    .map(str::to_string)
                    .filter(|s| !s.is_empty());

                if maybe_text.is_none() && reasoning_val.is_object() {
                    if let Some(s) = reasoning_val
                        .get("text")
                        .and_then(|t| t.as_str())
                        .filter(|s| !s.is_empty())
                    {
                        maybe_text = Some(s.to_string());
                    } else if let Some(s) = reasoning_val
                        .get("content")
                        .and_then(|t| t.as_str())
                        .filter(|s| !s.is_empty())
                    {
                        maybe_text = Some(s.to_string());
                    }
                }

                if let Some(reasoning) = maybe_text {
                    // Accumulate so we can emit a terminal Reasoning item at the end.
                    append_reasoning_text(&tx_event, &mut reasoning_item, reasoning).await;
                }
            }

            // Some providers only include reasoning on the final message object.
            if let Some(message_reasoning) = choice.get("message").and_then(|m| m.get("reasoning"))
            {
                // Accept either a plain string or an object with { text | content }
                if let Some(s) = message_reasoning.as_str() {
                    if !s.is_empty() {
                        append_reasoning_text(&tx_event, &mut reasoning_item, s.to_string()).await;
                    }
                } else if let Some(obj) = message_reasoning.as_object()
                    && let Some(s) = obj
                        .get("text")
                        .and_then(|v| v.as_str())
                        .or_else(|| obj.get("content").and_then(|v| v.as_str()))
                    && !s.is_empty()
                {
                    append_reasoning_text(&tx_event, &mut reasoning_item, s.to_string()).await;
                }
            }

            // Handle streaming function / tool calls.
            if let Some(tool_calls) = choice
                .get("delta")
                .and_then(|d| d.get("tool_calls"))
                .and_then(|tc| tc.as_array())
                && let Some(tool_call) = tool_calls.first()
            {
                // Mark that we have an active function call in progress.
                fn_call_state.active = true;

                // Extract call_id if present.
                if let Some(id) = tool_call.get("id").and_then(|v| v.as_str()) {
                    fn_call_state.call_id.get_or_insert_with(|| id.to_string());
                }

                // Extract function details if present.
                if let Some(function) = tool_call.get("function") {
                    if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
                        fn_call_state.name.get_or_insert_with(|| name.to_string());
                    }

                    if let Some(args_fragment) = function.get("arguments").and_then(|a| a.as_str())
                    {
                        fn_call_state.arguments.push_str(args_fragment);
                    }
                }
            }

            // Emit end-of-turn when finish_reason signals completion.
            if let Some(finish_reason) = choice.get("finish_reason").and_then(|v| v.as_str()) {
                match finish_reason {
                    "tool_calls" if fn_call_state.active => {
                        // First, flush the terminal raw reasoning so UIs can finalize
                        // the reasoning stream before any exec/tool events begin.
                        if let Some(item) = reasoning_item.take() {
                            let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
                        }

                        // Then emit the FunctionCall response item.
                        let item = ResponseItem::FunctionCall {
                            id: None,
                            name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()),
                            arguments: fn_call_state.arguments.clone(),
                            call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new),
                        };

                        let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
                    }
                    "stop" => {
                        // Regular turn without tool-call. Emit the final assistant message
                        // as a single OutputItemDone so non-delta consumers see the result.
                        if let Some(item) = assistant_item.take() {
                            let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
                        }
                        // Also emit a terminal Reasoning item so UIs can finalize raw reasoning.
                        if let Some(item) = reasoning_item.take() {
                            let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
                        }
                    }
                    _ => {}
                }

                // Emit Completed regardless of reason so the agent can advance.
                let _ = tx_event
                    .send(Ok(ResponseEvent::Completed {
                        response_id: String::new(),
                        token_usage: None,
                    }))
                    .await;

                // Prepare for potential next turn (should not happen in same stream).
                // fn_call_state = FunctionCallState::default();

                return; // End processing for this SSE stream.
            }
        }
    }
}

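Concretely, a delta.content string in a chunk becomes OutputTextDelta (preceded by OutputItemAdded the first time), and finish_reason drives the terminal items. A self-contained sketch of the JSON shape being parsed (field layout per the code above; values illustrative):

    use serde_json::json;

    fn main() {
        let chunk = json!({
            "choices": [{
                "delta": { "content": "Hello" },
                "finish_reason": null
            }]
        });
        // Mirrors the lookups in process_chat_sse.
        let content = chunk["choices"][0]["delta"]["content"].as_str();
        assert_eq!(content, Some("Hello"));
    }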
/// Optional client-side aggregation helper
///
/// Stream adapter that merges the incremental `OutputItemDone` chunks coming from
/// [`process_chat_sse`] into a *running* assistant message, **suppressing the
/// per-token deltas**. The stream stays silent while the model is thinking
/// and only emits two events per turn:
///
/// 1. `ResponseEvent::OutputItemDone` with the *complete* assistant message
///    (fully concatenated).
/// 2. The original `ResponseEvent::Completed` right after it.
///
/// This mirrors the behaviour the TypeScript CLI exposes to its higher layers.
///
/// The adapter is intentionally *lossless*: callers who do **not** opt in via
/// [`AggregateStreamExt::aggregate()`] keep receiving the original unmodified
/// events.
#[derive(Copy, Clone, Eq, PartialEq)]
enum AggregateMode {
    AggregatedOnly,
    Streaming,
}
pub(crate) struct AggregatedChatStream<S> {
    inner: S,
    cumulative: String,
    cumulative_reasoning: String,
    pending: std::collections::VecDeque<ResponseEvent>,
    mode: AggregateMode,
}

impl<S> Stream for AggregatedChatStream<S>
where
    S: Stream<Item = Result<ResponseEvent>> + Unpin,
{
    type Item = Result<ResponseEvent>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        // First, flush any buffered events from the previous call.
        if let Some(ev) = this.pending.pop_front() {
            return Poll::Ready(Some(Ok(ev)));
        }

        loop {
            match Pin::new(&mut this.inner).poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
                    // If this is an incremental assistant message chunk, accumulate but
                    // do NOT emit yet. Forward any other item (e.g. FunctionCall) right
                    // away so downstream consumers see it.

                    let is_assistant_message = matches!(
                        &item,
                        codex_protocol::models::ResponseItem::Message { role, .. } if role == "assistant"
                    );

                    if is_assistant_message {
                        match this.mode {
                            AggregateMode::AggregatedOnly => {
                                // Only use the final assistant message if we have not
                                // seen any deltas; otherwise, deltas already built the
                                // cumulative text and this would duplicate it.
                                if this.cumulative.is_empty()
                                    && let codex_protocol::models::ResponseItem::Message {
                                        content,
                                        ..
                                    } = &item
                                    && let Some(text) = content.iter().find_map(|c| match c {
                                        codex_protocol::models::ContentItem::OutputText {
                                            text,
                                        } => Some(text),
                                        _ => None,
                                    })
                                {
                                    this.cumulative.push_str(text);
                                }
                                // Swallow assistant message here; emit on Completed.
                                continue;
                            }
                            AggregateMode::Streaming => {
                                // In streaming mode, if we have not seen any deltas, forward
                                // the final assistant message directly. If deltas were seen,
                                // suppress the final message to avoid duplication.
                                if this.cumulative.is_empty() {
                                    return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
                                        item,
                                    ))));
                                } else {
                                    continue;
                                }
                            }
                        }
                    }

                    // Not an assistant message – forward immediately.
                    return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
                }
                Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
                    return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
                }
                Poll::Ready(Some(Ok(ResponseEvent::Completed {
                    response_id,
                    token_usage,
                }))) => {
                    // Build any aggregated items in the correct order: Reasoning first, then Message.
                    let mut emitted_any = false;

                    if !this.cumulative_reasoning.is_empty()
                        && matches!(this.mode, AggregateMode::AggregatedOnly)
                    {
                        let aggregated_reasoning =
                            codex_protocol::models::ResponseItem::Reasoning {
                                id: String::new(),
                                summary: Vec::new(),
                                content: Some(vec![
                                    codex_protocol::models::ReasoningItemContent::ReasoningText {
                                        text: std::mem::take(&mut this.cumulative_reasoning),
                                    },
                                ]),
                                encrypted_content: None,
                            };
                        this.pending
                            .push_back(ResponseEvent::OutputItemDone(aggregated_reasoning));
                        emitted_any = true;
                    }

                    // Always emit the final aggregated assistant message when any
                    // content deltas have been observed. In AggregatedOnly mode this
                    // is the sole assistant output; in Streaming mode this finalizes
                    // the streamed deltas into a terminal OutputItemDone so callers
                    // can persist/render the message once per turn.
                    if !this.cumulative.is_empty() {
                        let aggregated_message = codex_protocol::models::ResponseItem::Message {
                            id: None,
                            role: "assistant".to_string(),
                            content: vec![codex_protocol::models::ContentItem::OutputText {
                                text: std::mem::take(&mut this.cumulative),
                            }],
                        };
                        this.pending
                            .push_back(ResponseEvent::OutputItemDone(aggregated_message));
                        emitted_any = true;
                    }

                    // Always emit Completed last when anything was aggregated.
                    if emitted_any {
                        this.pending.push_back(ResponseEvent::Completed {
                            response_id: response_id.clone(),
                            token_usage: token_usage.clone(),
                        });
                        // Return the first pending event now.
                        if let Some(ev) = this.pending.pop_front() {
                            return Poll::Ready(Some(Ok(ev)));
                        }
                    }

                    // Nothing aggregated – forward Completed directly.
                    return Poll::Ready(Some(Ok(ResponseEvent::Completed {
                        response_id,
                        token_usage,
                    })));
                }
                Poll::Ready(Some(Ok(ResponseEvent::Created))) => {
                    // These events are exclusive to the Responses API and
                    // will never appear in a Chat Completions stream.
                    continue;
                }
                Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => {
                    // Always accumulate deltas so we can emit a final OutputItemDone at Completed.
                    this.cumulative.push_str(&delta);
                    if matches!(this.mode, AggregateMode::Streaming) {
                        // In streaming mode, also forward the delta immediately.
                        return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta))));
                    } else {
                        continue;
                    }
                }
                Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta)))) => {
                    // Always accumulate reasoning deltas so we can emit a final Reasoning item at Completed.
                    this.cumulative_reasoning.push_str(&delta);
                    if matches!(this.mode, AggregateMode::Streaming) {
                        // In streaming mode, also forward the delta immediately.
                        return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta))));
                    } else {
                        continue;
                    }
                }
                Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta(_)))) => {
                    continue;
                }
                Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded))) => {
                    continue;
                }
                Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => {
                    return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item))));
                }
            }
        }
    }
}

/// Extension trait that activates aggregation on any stream of [`ResponseEvent`].
pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized {
    /// Returns a new stream that emits **only** the final assistant message
    /// per turn instead of every incremental delta. The produced
    /// `ResponseEvent` sequence for a typical text turn looks like:
    ///
    /// ```ignore
    /// OutputItemDone(<full message>)
    /// Completed
    /// ```
    ///
    /// No other `OutputItemDone` events will be seen by the caller.
    ///
    /// Usage:
    ///
    /// ```ignore
    /// let agg_stream = client.stream(&prompt).await?.aggregate();
    /// while let Some(event) = agg_stream.next().await {
    ///     // event now contains cumulative text
    /// }
    /// ```
    fn aggregate(self) -> AggregatedChatStream<Self> {
        AggregatedChatStream::new(self, AggregateMode::AggregatedOnly)
    }
}

impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}

impl<S> AggregatedChatStream<S> {
    fn new(inner: S, mode: AggregateMode) -> Self {
        AggregatedChatStream {
            inner,
            cumulative: String::new(),
            cumulative_reasoning: String::new(),
            pending: std::collections::VecDeque::new(),
            mode,
        }
    }

    pub(crate) fn streaming_mode(inner: S) -> Self {
        Self::new(inner, AggregateMode::Streaming)
    }
}
File diff suppressed because it is too large
@@ -1,24 +1,21 @@
use crate::client_common::tools::ToolSpec;
use crate::error::Result;
use crate::model_family::ModelFamily;
use crate::protocol::RateLimitSnapshot;
use crate::protocol::TokenUsage;
use codex_api_client::Reasoning;
pub use codex_api_client::ResponseEvent;
use codex_api_client::TextControls;
use codex_api_client::TextFormat;
use codex_api_client::TextFormatType;
use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS;
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
use codex_protocol::config_types::Verbosity as VerbosityConfig;
use codex_protocol::models::ResponseItem;
use futures::Stream;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
use std::borrow::Cow;
use std::collections::HashSet;
use std::ops::Deref;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use tokio::sync::mpsc;

/// Review thread system prompt. Edit `core/src/review_prompt.md` to customize.
pub const REVIEW_PROMPT: &str = include_str!("../review_prompt.md");
@@ -193,95 +190,7 @@ fn strip_total_output_header(output: &str) -> Option<&str> {
     Some(remainder)
 }

-#[derive(Debug)]
-pub enum ResponseEvent {
-    Created,
-    OutputItemDone(ResponseItem),
-    OutputItemAdded(ResponseItem),
-    Completed {
-        response_id: String,
-        token_usage: Option<TokenUsage>,
-    },
-    OutputTextDelta(String),
-    ReasoningSummaryDelta(String),
-    ReasoningContentDelta(String),
-    ReasoningSummaryPartAdded,
-    RateLimits(RateLimitSnapshot),
-}
-
-#[derive(Debug, Serialize)]
-pub(crate) struct Reasoning {
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub(crate) effort: Option<ReasoningEffortConfig>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub(crate) summary: Option<ReasoningSummaryConfig>,
-}
-
-#[derive(Debug, Serialize, Default, Clone)]
-#[serde(rename_all = "snake_case")]
-pub(crate) enum TextFormatType {
-    #[default]
-    JsonSchema,
-}
-
-#[derive(Debug, Serialize, Default, Clone)]
-pub(crate) struct TextFormat {
-    pub(crate) r#type: TextFormatType,
-    pub(crate) strict: bool,
-    pub(crate) schema: Value,
-    pub(crate) name: String,
-}
-
-/// Controls under the `text` field in the Responses API for GPT-5.
-#[derive(Debug, Serialize, Default, Clone)]
-pub(crate) struct TextControls {
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub(crate) verbosity: Option<OpenAiVerbosity>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub(crate) format: Option<TextFormat>,
-}
-
-#[derive(Debug, Serialize, Default, Clone)]
-#[serde(rename_all = "lowercase")]
-pub(crate) enum OpenAiVerbosity {
-    Low,
-    #[default]
-    Medium,
-    High,
-}
-
-impl From<VerbosityConfig> for OpenAiVerbosity {
-    fn from(v: VerbosityConfig) -> Self {
-        match v {
-            VerbosityConfig::Low => OpenAiVerbosity::Low,
-            VerbosityConfig::Medium => OpenAiVerbosity::Medium,
-            VerbosityConfig::High => OpenAiVerbosity::High,
-        }
-    }
-}
-
-/// Request object that is serialized as JSON and POST'ed when using the
-/// Responses API.
-#[derive(Debug, Serialize)]
-pub(crate) struct ResponsesApiRequest<'a> {
-    pub(crate) model: &'a str,
-    pub(crate) instructions: &'a str,
-    // TODO(mbolin): ResponseItem::Other should not be serialized. Currently,
-    // we code defensively to avoid this case, but perhaps we should use a
-    // separate enum for serialization.
-    pub(crate) input: &'a Vec<ResponseItem>,
-    pub(crate) tools: &'a [serde_json::Value],
-    pub(crate) tool_choice: &'static str,
-    pub(crate) parallel_tool_calls: bool,
-    pub(crate) reasoning: Option<Reasoning>,
-    pub(crate) store: bool,
-    pub(crate) stream: bool,
-    pub(crate) include: Vec<String>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub(crate) prompt_cache_key: Option<String>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub(crate) text: Option<TextControls>,
-}
+pub type ResponseStream = codex_api_client::EventStream<Result<ResponseEvent>>;

 pub(crate) mod tools {
     use crate::tools::spec::JsonSchema;
@@ -366,7 +275,11 @@ pub(crate) fn create_text_param_for_request(
     }

     Some(TextControls {
-        verbosity: verbosity.map(std::convert::Into::into),
+        verbosity: verbosity.map(|v| match v {
+            VerbosityConfig::Low => "low".to_string(),
+            VerbosityConfig::Medium => "medium".to_string(),
+            VerbosityConfig::High => "high".to_string(),
+        }),
         format: output_schema.as_ref().map(|schema| TextFormat {
             r#type: TextFormatType::JsonSchema,
             strict: true,
@@ -376,18 +289,6 @@ pub(crate) fn create_text_param_for_request(
         }),
     })
 }

-pub struct ResponseStream {
-    pub(crate) rx_event: mpsc::Receiver<Result<ResponseEvent>>,
-}
-
-impl Stream for ResponseStream {
-    type Item = Result<ResponseEvent>;
-
-    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        self.rx_event.poll_recv(cx)
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use crate::model_family::find_family_for_model;
@@ -453,39 +354,14 @@ mod tests {

     #[test]
     fn serializes_text_verbosity_when_set() {
         let input: Vec<ResponseItem> = vec![];
         let tools: Vec<serde_json::Value> = vec![];
-        let req = ResponsesApiRequest {
-            model: "gpt-5",
-            instructions: "i",
-            input: &input,
-            tools: &tools,
-            tool_choice: "auto",
-            parallel_tool_calls: true,
-            reasoning: None,
-            store: false,
-            stream: true,
-            include: vec![],
-            prompt_cache_key: None,
-            text: Some(TextControls {
-                verbosity: Some(OpenAiVerbosity::Low),
-                format: None,
-            }),
-        };
-
-        let v = serde_json::to_value(&req).expect("json");
-        assert_eq!(
-            v.get("text")
-                .and_then(|t| t.get("verbosity"))
-                .and_then(|s| s.as_str()),
-            Some("low")
-        );
+        let controls =
+            create_text_param_for_request(Some(VerbosityConfig::Low), &None).expect("controls");
+        assert_eq!(controls.verbosity.as_deref(), Some("low"));
+        assert!(controls.format.is_none());
     }

     #[test]
     fn serializes_text_schema_with_strict_format() {
         let input: Vec<ResponseItem> = vec![];
         let tools: Vec<serde_json::Value> = vec![];
         let schema = serde_json::json!({
             "type": "object",
             "properties": {
@@ -493,61 +369,17 @@ mod tests {
             },
             "required": ["answer"],
         });
-        let text_controls =
+        let controls =
             create_text_param_for_request(None, &Some(schema.clone())).expect("text controls");

-        let req = ResponsesApiRequest {
-            model: "gpt-5",
-            instructions: "i",
-            input: &input,
-            tools: &tools,
-            tool_choice: "auto",
-            parallel_tool_calls: true,
-            reasoning: None,
-            store: false,
-            stream: true,
-            include: vec![],
-            prompt_cache_key: None,
-            text: Some(text_controls),
-        };
-
-        let v = serde_json::to_value(&req).expect("json");
-        let text = v.get("text").expect("text field");
-        assert!(text.get("verbosity").is_none());
-        let format = text.get("format").expect("format field");
-
-        assert_eq!(
-            format.get("name"),
-            Some(&serde_json::Value::String("codex_output_schema".into()))
-        );
-        assert_eq!(
-            format.get("type"),
-            Some(&serde_json::Value::String("json_schema".into()))
-        );
-        assert_eq!(format.get("strict"), Some(&serde_json::Value::Bool(true)));
-        assert_eq!(format.get("schema"), Some(&schema));
+        assert!(controls.verbosity.is_none());
+        let format = controls.format.expect("format");
+        assert_eq!(format.name, "codex_output_schema");
+        assert!(format.strict);
+        assert_eq!(format.schema, schema);
     }

     #[test]
     fn omits_text_when_not_set() {
         let input: Vec<ResponseItem> = vec![];
         let tools: Vec<serde_json::Value> = vec![];
-        let req = ResponsesApiRequest {
-            model: "gpt-5",
-            instructions: "i",
-            input: &input,
-            tools: &tools,
-            tool_choice: "auto",
-            parallel_tool_calls: true,
-            reasoning: None,
-            store: false,
-            stream: true,
-            include: vec![],
-            prompt_cache_key: None,
-            text: None,
-        };
-
-        let v = serde_json::to_value(&req).expect("json");
-        assert!(v.get("text").is_none());
+        assert!(create_text_param_for_request(None, &None).is_none());
     }
 }
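After the refactor the tests assert on the TextControls value itself rather than on a serialized ResponsesApiRequest. For reference, the JSON the text field serializes to, per the skip_serializing_if attributes on the codex-api-client types (a sketch, runnable against that crate):

    use codex_api_client::TextControls;

    fn main() {
        let controls = TextControls {
            verbosity: Some("low".to_string()),
            format: None,
        };
        // format is skipped when None, so only verbosity appears.
        assert_eq!(
            serde_json::to_value(&controls).unwrap(),
            serde_json::json!({ "verbosity": "low" })
        );
    }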
@@ -5,6 +5,7 @@ use std::sync::Arc;
 use std::sync::atomic::AtomicU64;

 use crate::AuthManager;
+use crate::ModelClient;
 use crate::client_common::REVIEW_PROMPT;
 use crate::compact;
 use crate::features::Feature;
@@ -53,7 +54,6 @@ use tracing::info;
 use tracing::warn;

 use crate::ModelProviderInfo;
-use crate::client::ModelClient;
 use crate::client_common::Prompt;
 use crate::client_common::ResponseEvent;
 use crate::config::Config;
@@ -293,6 +293,8 @@ impl TurnContext {
     }
 }

+// Model-specific helpers live on ModelClient; TurnContext remains lean.
+
 #[allow(dead_code)]
 #[derive(Clone)]
 pub(crate) struct SessionConfiguration {
@@ -402,6 +404,11 @@ impl Session {
             session_configuration.model.as_str(),
         );

+        let tools_config = ToolsConfig::new(&ToolsConfigParams {
+            model_family: &model_family,
+            features: &config.features,
+        });
+
         let client = ModelClient::new(
             Arc::new(per_turn_config),
             auth_manager,
@@ -413,11 +420,6 @@ impl Session {
             session_configuration.session_source.clone(),
         );

-        let tools_config = ToolsConfig::new(&ToolsConfigParams {
-            model_family: &model_family,
-            features: &config.features,
-        });
-
         TurnContext {
             sub_id,
             client,
@@ -1677,6 +1679,7 @@ async fn spawn_review_thread(
    );

    let per_turn_config = Arc::new(per_turn_config);

    let client = ModelClient::new(
        per_turn_config.clone(),
        auth_manager,
@@ -1939,7 +1942,7 @@ async fn run_turn(
             retries += 1;
             let delay = match e {
                 CodexErr::Stream(_, Some(delay)) => delay,
-                _ => backoff(retries),
+                _ => backoff(retries.max(0) as u64),
             };
             warn!(
                 "stream disconnected - retrying turn ({retries}/{max_retries} in {delay:?})...",
@@ -1998,10 +2001,7 @@ async fn try_run_turn(
     });

     sess.persist_rollout_items(&[rollout_item]).await;
-    let mut stream = turn_context
-        .client
-        .clone()
-        .stream(prompt)
+    let mut stream = crate::client::stream_for_turn(&turn_context, prompt)
         .or_cancel(&cancellation_token)
         .await??;

@@ -2013,6 +2013,9 @@ async fn try_run_turn(
     );
     let mut output: FuturesOrdered<BoxFuture<CodexResult<ProcessedResponseItem>>> =
         FuturesOrdered::new();
+    // Track whether any tool calls have been scheduled so we can salvage
+    // their outputs if the stream disconnects before `response.completed`.
+    let mut saw_tool_call = false;

     let mut active_item: Option<TurnItem> = None;

@@ -2031,8 +2034,27 @@ async fn try_run_turn(
         };

         let event = match event {
-            Some(res) => res?,
+            Some(res) => match res {
+                Ok(ev) => ev,
+                Err(e) => {
+                    if saw_tool_call {
+                        let processed_items = output.try_collect().await?;
+                        return Ok(TurnRunResult {
+                            processed_items,
+                            total_token_usage: None,
+                        });
+                    }
+                    return Err(e);
+                }
+            },
             None => {
+                if saw_tool_call {
+                    let processed_items = output.try_collect().await?;
+                    return Ok(TurnRunResult {
+                        processed_items,
+                        total_token_usage: None,
+                    });
+                }
                 return Err(CodexErr::Stream(
                     "stream closed before response.completed".into(),
                     None,
@@ -2045,7 +2067,11 @@ async fn try_run_turn(
         };

         match event {
-            ResponseEvent::Created => {}
+            ResponseEvent::Created => {
+                // Emit an initial TokenCount so UIs (and rollouts) have a
+                // marker even when providers omit rate-limit headers.
+                sess.send_token_count_event(&turn_context).await;
+            }
             ResponseEvent::OutputItemDone(item) => {
                 let previously_active_item = active_item.take();
                 match ToolRouter::build_tool_call(sess.as_ref(), item.clone()) {
@@ -2056,6 +2082,7 @@ async fn try_run_turn(
                     let response =
                         tool_runtime.handle_tool_call(call, cancellation_token.child_token());

+                    saw_tool_call = true;
                     output.push_back(
                         async move {
                             Ok(ProcessedResponseItem {

@@ -120,7 +120,7 @@ async fn run_compact_task_inner(
         Err(e) => {
             if retries < max_retries {
                 retries += 1;
-                let delay = backoff(retries);
+                let delay = backoff(retries.max(0) as u64);
                 sess.notify_stream_error(
                     turn_context.as_ref(),
                     format!("Reconnecting... {retries}/{max_retries}"),
@@ -266,7 +266,7 @@ async fn drain_to_completed(
     turn_context: &TurnContext,
     prompt: &Prompt,
 ) -> CodexResult<()> {
-    let mut stream = turn_context.client.clone().stream(prompt).await?;
+    let mut stream = crate::client::stream_for_turn(turn_context, prompt).await?;
     loop {
         let maybe_event = stream.next().await;
         let Some(event) = maybe_event else {

@@ -1,4 +1,6 @@
+use crate::ModelProviderInfo;
 use crate::auth::AuthCredentialsStoreMode;
+use crate::built_in_model_providers;
 use crate::config::types::DEFAULT_OTEL_ENVIRONMENT;
 use crate::config::types::History;
 use crate::config::types::McpServerConfig;
@@ -25,8 +27,6 @@ use crate::git_info::resolve_root_git_project_for_trust;
 use crate::model_family::ModelFamily;
 use crate::model_family::derive_default_model_family;
 use crate::model_family::find_family_for_model;
-use crate::model_provider_info::ModelProviderInfo;
-use crate::model_provider_info::built_in_model_providers;
 use crate::openai_model_info::get_model_info;
 use crate::project_doc::DEFAULT_PROJECT_DOC_FILENAME;
 use crate::project_doc::LOCAL_PROJECT_DOC_FILENAME;

@@ -41,6 +41,14 @@ impl CodexHttpClient {
         Self { inner }
     }

+    pub fn inner(&self) -> &reqwest::Client {
+        &self.inner
+    }
+
+    pub fn clone_inner(&self) -> reqwest::Client {
+        self.inner.clone()
+    }
+
     pub fn get<U>(&self, url: U) -> CodexRequestBuilder
     where
         U: IntoUrl,

@@ -8,7 +8,6 @@
 mod apply_patch;
 pub mod auth;
 pub mod bash;
-mod chat_completions;
 mod client;
 mod client_common;
 pub mod codex;
@@ -32,7 +31,6 @@ pub mod mcp;
 mod mcp_connection_manager;
 mod mcp_tool_call;
 mod message_history;
-mod model_provider_info;
 pub mod parse_command;
 mod response_processing;
 pub mod sandboxing;
@@ -40,11 +38,13 @@ pub mod token_data;
 mod truncate;
 mod unified_exec;
 mod user_instructions;
-pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID;
-pub use model_provider_info::ModelProviderInfo;
-pub use model_provider_info::WireApi;
-pub use model_provider_info::built_in_model_providers;
-pub use model_provider_info::create_oss_provider_with_base_url;
+mod wire_payload;
+pub use codex_provider_config::BUILT_IN_OSS_MODEL_PROVIDER_ID;
+pub use codex_provider_config::ModelProviderInfo;
+pub use codex_provider_config::WireApi;
+pub use codex_provider_config::built_in_model_providers;
+pub use codex_provider_config::create_oss_provider_with_base_url;
+mod aggregate;
 mod conversation_manager;
 mod event_mapping;
 pub mod review_format;
@@ -95,6 +95,7 @@ pub use codex_protocol::protocol;
|
||||
// as those in the protocol crate when constructing protocol messages.
|
||||
pub use codex_protocol::config_types as protocol_config_types;
|
||||
|
||||
pub use aggregate::AggregateStreamExt;
|
||||
pub use client::ModelClient;
|
||||
pub use client_common::Prompt;
|
||||
pub use client_common::REVIEW_PROMPT;
|
||||
|
||||
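Because the old `pub use model_provider_info::…` re-exports are replaced one-for-one by `pub use codex_provider_config::…`, the public paths through this crate are unchanged; downstream imports such as the following should keep compiling (sketch, assuming the usual `codex_core` crate name):

use codex_core::ModelProviderInfo;
use codex_core::WireApi;

fn uses_responses_api(provider: &ModelProviderInfo) -> bool {
    provider.wire_api == WireApi::Responses
}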
@@ -1,532 +0,0 @@
//! Registry of model providers supported by Codex.
//!
//! Providers can be defined in two places:
//! 1. Built-in defaults compiled into the binary so Codex works out-of-the-box.
//! 2. User-defined entries inside `~/.codex/config.toml` under the `model_providers`
//!    key. These override or extend the defaults at runtime.

use crate::CodexAuth;
use crate::default_client::CodexHttpClient;
use crate::default_client::CodexRequestBuilder;
use codex_app_server_protocol::AuthMode;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::env::VarError;
use std::time::Duration;

use crate::error::EnvVarError;
const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000;
const DEFAULT_STREAM_MAX_RETRIES: u64 = 5;
const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
/// Hard cap for user-configured `stream_max_retries`.
const MAX_STREAM_MAX_RETRIES: u64 = 100;
/// Hard cap for user-configured `request_max_retries`.
const MAX_REQUEST_MAX_RETRIES: u64 = 100;

/// Wire protocol that the provider speaks. Most third-party services only
/// implement the classic OpenAI Chat Completions JSON schema, whereas OpenAI
/// itself (and a handful of others) additionally expose the more modern
/// *Responses* API. The two protocols use different request/response shapes
/// and *cannot* be auto-detected at runtime, therefore each provider entry
/// must declare which one it expects.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum WireApi {
    /// The Responses API exposed by OpenAI at `/v1/responses`.
    Responses,

    /// Regular Chat Completions compatible with `/v1/chat/completions`.
    #[default]
    Chat,
}

/// Serializable representation of a provider definition.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ModelProviderInfo {
    /// Friendly display name.
    pub name: String,
    /// Base URL for the provider's OpenAI-compatible API.
    pub base_url: Option<String>,
    /// Environment variable that stores the user's API key for this provider.
    pub env_key: Option<String>,

    /// Optional instructions to help the user get a valid value for the
    /// variable and set it.
    pub env_key_instructions: Option<String>,

    /// Value to use with `Authorization: Bearer <token>` header. Use of this
    /// config is discouraged in favor of `env_key` for security reasons, but
    /// this may be necessary when using this programmatically.
    pub experimental_bearer_token: Option<String>,

    /// Which wire protocol this provider expects.
    #[serde(default)]
    pub wire_api: WireApi,

    /// Optional query parameters to append to the base URL.
    pub query_params: Option<HashMap<String, String>>,

    /// Additional HTTP headers to include in requests to this provider where
    /// the (key, value) pairs are the header name and value.
    pub http_headers: Option<HashMap<String, String>>,

    /// Optional HTTP headers to include in requests to this provider where the
    /// (key, value) pairs are the header name and _environment variable_ whose
    /// value should be used. If the environment variable is not set, or the
    /// value is empty, the header will not be included in the request.
    pub env_http_headers: Option<HashMap<String, String>>,

    /// Maximum number of times to retry a failed HTTP request to this provider.
    pub request_max_retries: Option<u64>,

    /// Number of times to retry reconnecting a dropped streaming response before failing.
    pub stream_max_retries: Option<u64>,

    /// Idle timeout (in milliseconds) to wait for activity on a streaming response before treating
    /// the connection as lost.
    pub stream_idle_timeout_ms: Option<u64>,

    /// Does this provider require an OpenAI API Key or ChatGPT login token? If true,
    /// user is presented with login screen on first run, and login preference and token/key
    /// are stored in auth.json. If false (which is the default), login screen is skipped,
    /// and API key (if needed) comes from the "env_key" environment variable.
    #[serde(default)]
    pub requires_openai_auth: bool,
}

impl ModelProviderInfo {
    /// Construct a `POST` RequestBuilder for the given URL using the provided
    /// [`CodexHttpClient`] applying:
    /// • provider-specific headers (static + env based)
    /// • Bearer auth header when an API key is available.
    /// • Auth token for OAuth.
    ///
    /// If the provider declares an `env_key` but the variable is missing/empty, returns an [`Err`] identical to the
    /// one produced by [`ModelProviderInfo::api_key`].
    pub async fn create_request_builder<'a>(
        &'a self,
        client: &'a CodexHttpClient,
        auth: &Option<CodexAuth>,
    ) -> crate::error::Result<CodexRequestBuilder> {
        let effective_auth = if let Some(secret_key) = &self.experimental_bearer_token {
            Some(CodexAuth::from_api_key(secret_key))
        } else {
            match self.api_key() {
                Ok(Some(key)) => Some(CodexAuth::from_api_key(&key)),
                Ok(None) => auth.clone(),
                Err(err) => {
                    if auth.is_some() {
                        auth.clone()
                    } else {
                        return Err(err);
                    }
                }
            }
        };

        let url = self.get_full_url(&effective_auth);

        let mut builder = client.post(url);

        if let Some(auth) = effective_auth.as_ref() {
            builder = builder.bearer_auth(auth.get_token().await?);
        }

        Ok(self.apply_http_headers(builder))
    }

    fn get_query_string(&self) -> String {
        self.query_params
            .as_ref()
            .map_or_else(String::new, |params| {
                let full_params = params
                    .iter()
                    .map(|(k, v)| format!("{k}={v}"))
                    .collect::<Vec<_>>()
                    .join("&");
                format!("?{full_params}")
            })
    }

    pub(crate) fn get_full_url(&self, auth: &Option<CodexAuth>) -> String {
        let default_base_url = if matches!(
            auth,
            Some(CodexAuth {
                mode: AuthMode::ChatGPT,
                ..
            })
        ) {
            "https://chatgpt.com/backend-api/codex"
        } else {
            "https://api.openai.com/v1"
        };
        let query_string = self.get_query_string();
        let base_url = self
            .base_url
            .clone()
            .unwrap_or(default_base_url.to_string());

        match self.wire_api {
            WireApi::Responses => format!("{base_url}/responses{query_string}"),
            WireApi::Chat => format!("{base_url}/chat/completions{query_string}"),
        }
    }

    pub(crate) fn is_azure_responses_endpoint(&self) -> bool {
        if self.wire_api != WireApi::Responses {
            return false;
        }

        if self.name.eq_ignore_ascii_case("azure") {
            return true;
        }

        self.base_url
            .as_ref()
            .map(|base| matches_azure_responses_base_url(base))
            .unwrap_or(false)
    }

    /// Apply provider-specific HTTP headers (both static and environment-based)
    /// onto an existing [`CodexRequestBuilder`] and return the updated
    /// builder.
    fn apply_http_headers(&self, mut builder: CodexRequestBuilder) -> CodexRequestBuilder {
        if let Some(extra) = &self.http_headers {
            for (k, v) in extra {
                builder = builder.header(k, v);
            }
        }

        if let Some(env_headers) = &self.env_http_headers {
            for (header, env_var) in env_headers {
                if let Ok(val) = std::env::var(env_var)
                    && !val.trim().is_empty()
                {
                    builder = builder.header(header, val);
                }
            }
        }
        builder
    }

    /// If `env_key` is Some, returns the API key for this provider if present
    /// (and non-empty) in the environment. If `env_key` is required but
    /// cannot be found, returns an error.
    pub fn api_key(&self) -> crate::error::Result<Option<String>> {
        match &self.env_key {
            Some(env_key) => {
                let env_value = std::env::var(env_key);
                env_value
                    .and_then(|v| {
                        if v.trim().is_empty() {
                            Err(VarError::NotPresent)
                        } else {
                            Ok(Some(v))
                        }
                    })
                    .map_err(|_| {
                        crate::error::CodexErr::EnvVar(EnvVarError {
                            var: env_key.clone(),
                            instructions: self.env_key_instructions.clone(),
                        })
                    })
            }
            None => Ok(None),
        }
    }

    /// Effective maximum number of request retries for this provider.
    pub fn request_max_retries(&self) -> u64 {
        self.request_max_retries
            .unwrap_or(DEFAULT_REQUEST_MAX_RETRIES)
            .min(MAX_REQUEST_MAX_RETRIES)
    }

    /// Effective maximum number of stream reconnection attempts for this provider.
    pub fn stream_max_retries(&self) -> u64 {
        self.stream_max_retries
            .unwrap_or(DEFAULT_STREAM_MAX_RETRIES)
            .min(MAX_STREAM_MAX_RETRIES)
    }

    /// Effective idle timeout for streaming responses.
    pub fn stream_idle_timeout(&self) -> Duration {
        self.stream_idle_timeout_ms
            .map(Duration::from_millis)
            .unwrap_or(Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS))
    }
}

const DEFAULT_OLLAMA_PORT: u32 = 11434;

pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "oss";

/// Built-in default provider list.
pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
    use ModelProviderInfo as P;

    // We do not want to be in the business of adjudicating which third-party
    // providers are bundled with Codex CLI, so we only include the OpenAI and
    // open source ("oss") providers by default. Users are encouraged to extend
    // `model_providers` in config.toml with their own providers.
    [
        (
            "openai",
            P {
                name: "OpenAI".into(),
                // Allow users to override the default OpenAI endpoint by
                // exporting `OPENAI_BASE_URL`. This is useful when pointing
                // Codex at a proxy, mock server, or Azure-style deployment
                // without requiring a full TOML override for the built-in
                // OpenAI provider.
                base_url: std::env::var("OPENAI_BASE_URL")
                    .ok()
                    .filter(|v| !v.trim().is_empty()),
                env_key: None,
                env_key_instructions: None,
                experimental_bearer_token: None,
                wire_api: WireApi::Responses,
                query_params: None,
                http_headers: Some(
                    [("version".to_string(), env!("CARGO_PKG_VERSION").to_string())]
                        .into_iter()
                        .collect(),
                ),
                env_http_headers: Some(
                    [
                        (
                            "OpenAI-Organization".to_string(),
                            "OPENAI_ORGANIZATION".to_string(),
                        ),
                        ("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()),
                    ]
                    .into_iter()
                    .collect(),
                ),
                // Use global defaults for retry/timeout unless overridden in config.toml.
                request_max_retries: None,
                stream_max_retries: None,
                stream_idle_timeout_ms: None,
                requires_openai_auth: true,
            },
        ),
        (BUILT_IN_OSS_MODEL_PROVIDER_ID, create_oss_provider()),
    ]
    .into_iter()
    .map(|(k, v)| (k.to_string(), v))
    .collect()
}

pub fn create_oss_provider() -> ModelProviderInfo {
    // These CODEX_OSS_ environment variables are experimental: we may
    // switch to reading values from config.toml instead.
    let codex_oss_base_url = match std::env::var("CODEX_OSS_BASE_URL")
        .ok()
        .filter(|v| !v.trim().is_empty())
    {
        Some(url) => url,
        None => format!(
            "http://localhost:{port}/v1",
            port = std::env::var("CODEX_OSS_PORT")
                .ok()
                .filter(|v| !v.trim().is_empty())
                .and_then(|v| v.parse::<u32>().ok())
                .unwrap_or(DEFAULT_OLLAMA_PORT)
        ),
    };

    create_oss_provider_with_base_url(&codex_oss_base_url)
}

pub fn create_oss_provider_with_base_url(base_url: &str) -> ModelProviderInfo {
    ModelProviderInfo {
        name: "gpt-oss".into(),
        base_url: Some(base_url.into()),
        env_key: None,
        env_key_instructions: None,
        experimental_bearer_token: None,
        wire_api: WireApi::Chat,
        query_params: None,
        http_headers: None,
        env_http_headers: None,
        request_max_retries: None,
        stream_max_retries: None,
        stream_idle_timeout_ms: None,
        requires_openai_auth: false,
    }
}

fn matches_azure_responses_base_url(base_url: &str) -> bool {
    let base = base_url.to_ascii_lowercase();
    const AZURE_MARKERS: [&str; 5] = [
        "openai.azure.",
        "cognitiveservices.azure.",
        "aoai.azure.",
        "azure-api.",
        "azurefd.",
    ];
    AZURE_MARKERS.iter().any(|marker| base.contains(marker))
}

#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn test_deserialize_ollama_model_provider_toml() {
        let ollama_provider_toml = r#"
name = "Ollama"
base_url = "http://localhost:11434/v1"
"#;
        let expected_provider = ModelProviderInfo {
            name: "Ollama".into(),
            base_url: Some("http://localhost:11434/v1".into()),
            env_key: None,
            env_key_instructions: None,
            experimental_bearer_token: None,
            wire_api: WireApi::Chat,
            query_params: None,
            http_headers: None,
            env_http_headers: None,
            request_max_retries: None,
            stream_max_retries: None,
            stream_idle_timeout_ms: None,
            requires_openai_auth: false,
        };

        let provider: ModelProviderInfo = toml::from_str(ollama_provider_toml).unwrap();
        assert_eq!(expected_provider, provider);
    }

    #[test]
    fn test_deserialize_azure_model_provider_toml() {
        let azure_provider_toml = r#"
name = "Azure"
base_url = "https://xxxxx.openai.azure.com/openai"
env_key = "AZURE_OPENAI_API_KEY"
query_params = { api-version = "2025-04-01-preview" }
"#;
        let expected_provider = ModelProviderInfo {
            name: "Azure".into(),
            base_url: Some("https://xxxxx.openai.azure.com/openai".into()),
            env_key: Some("AZURE_OPENAI_API_KEY".into()),
            env_key_instructions: None,
            experimental_bearer_token: None,
            wire_api: WireApi::Chat,
            query_params: Some(maplit::hashmap! {
                "api-version".to_string() => "2025-04-01-preview".to_string(),
            }),
            http_headers: None,
            env_http_headers: None,
            request_max_retries: None,
            stream_max_retries: None,
            stream_idle_timeout_ms: None,
            requires_openai_auth: false,
        };

        let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
        assert_eq!(expected_provider, provider);
    }

    #[test]
    fn test_deserialize_example_model_provider_toml() {
        let example_provider_toml = r#"
name = "Example"
base_url = "https://example.com"
env_key = "API_KEY"
http_headers = { "X-Example-Header" = "example-value" }
env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
"#;
        let expected_provider = ModelProviderInfo {
            name: "Example".into(),
            base_url: Some("https://example.com".into()),
            env_key: Some("API_KEY".into()),
            env_key_instructions: None,
            experimental_bearer_token: None,
            wire_api: WireApi::Chat,
            query_params: None,
            http_headers: Some(maplit::hashmap! {
                "X-Example-Header".to_string() => "example-value".to_string(),
            }),
            env_http_headers: Some(maplit::hashmap! {
                "X-Example-Env-Header".to_string() => "EXAMPLE_ENV_VAR".to_string(),
            }),
            request_max_retries: None,
            stream_max_retries: None,
            stream_idle_timeout_ms: None,
            requires_openai_auth: false,
        };

        let provider: ModelProviderInfo = toml::from_str(example_provider_toml).unwrap();
        assert_eq!(expected_provider, provider);
    }

    #[test]
    fn detects_azure_responses_base_urls() {
        fn provider_for(base_url: &str) -> ModelProviderInfo {
            ModelProviderInfo {
                name: "test".into(),
                base_url: Some(base_url.into()),
                env_key: None,
                env_key_instructions: None,
                experimental_bearer_token: None,
                wire_api: WireApi::Responses,
                query_params: None,
                http_headers: None,
                env_http_headers: None,
                request_max_retries: None,
                stream_max_retries: None,
                stream_idle_timeout_ms: None,
                requires_openai_auth: false,
            }
        }

        let positive_cases = [
            "https://foo.openai.azure.com/openai",
            "https://foo.openai.azure.us/openai/deployments/bar",
            "https://foo.cognitiveservices.azure.cn/openai",
            "https://foo.aoai.azure.com/openai",
            "https://foo.openai.azure-api.net/openai",
            "https://foo.z01.azurefd.net/",
        ];
        for base_url in positive_cases {
            let provider = provider_for(base_url);
            assert!(
                provider.is_azure_responses_endpoint(),
                "expected {base_url} to be detected as Azure"
            );
        }

        let named_provider = ModelProviderInfo {
            name: "Azure".into(),
            base_url: Some("https://example.com".into()),
            env_key: None,
            env_key_instructions: None,
            experimental_bearer_token: None,
            wire_api: WireApi::Responses,
            query_params: None,
            http_headers: None,
            env_http_headers: None,
            request_max_retries: None,
            stream_max_retries: None,
            stream_idle_timeout_ms: None,
            requires_openai_auth: false,
        };
        assert!(named_provider.is_azure_responses_endpoint());

        let negative_cases = [
            "https://api.openai.com/v1",
            "https://example.com/openai",
            "https://myproxy.azurewebsites.net/openai",
        ];
        for base_url in negative_cases {
            let provider = provider_for(base_url);
            assert!(
                !provider.is_azure_responses_endpoint(),
                "expected {base_url} not to be detected as Azure"
            );
        }
    }
}
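A quick way to see the URL composition above in action is a unit-style sketch in the same vein as the existing tests (not part of the diff):

#[test]
fn full_url_reflects_wire_api() {
    let mut provider = create_oss_provider_with_base_url("http://localhost:11434/v1");
    // Chat providers post to {base_url}/chat/completions...
    assert_eq!(
        provider.get_full_url(&None),
        "http://localhost:11434/v1/chat/completions"
    );
    // ...while Responses providers post to {base_url}/responses.
    provider.wire_api = WireApi::Responses;
    assert_eq!(
        provider.get_full_url(&None),
        "http://localhost:11434/v1/responses"
    );
}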
342
codex-rs/core/src/wire_payload.rs
Normal file
@@ -0,0 +1,342 @@
use codex_protocol::ConversationId;
use codex_protocol::models::ResponseItem;
use serde_json::Value;
use serde_json::json;

use crate::client_common::Prompt;
use crate::tools::spec::create_tools_json_for_responses_api;

pub fn build_responses_payload(
    prompt: &Prompt,
    model: &str,
    conversation_id: ConversationId,
    azure_workaround: bool,
    reasoning: Option<codex_api_client::Reasoning>,
    text_controls: Option<codex_api_client::TextControls>,
    instructions: String,
) -> Value {
    let tools =
        create_tools_json_for_responses_api(&prompt.tools).unwrap_or_else(|_| Vec::<Value>::new());

    let mut payload = json!({
        "model": model,
        "instructions": instructions,
        "input": prompt.get_formatted_input(),
        "tools": tools,
        "tool_choice": "auto",
        "parallel_tool_calls": prompt.parallel_tool_calls,
        "store": azure_workaround,
        "stream": true,
        "prompt_cache_key": conversation_id.to_string(),
    });

    if let Some(reasoning) = reasoning
        && let Some(obj) = payload.as_object_mut()
    {
        obj.insert(
            "reasoning".to_string(),
            serde_json::to_value(reasoning).unwrap_or(Value::Null),
        );
    }

    if let Some(text) = text_controls
        && let Some(obj) = payload.as_object_mut()
    {
        obj.insert(
            "text".to_string(),
            serde_json::to_value(text).unwrap_or(Value::Null),
        );
    }

    let include = if prompt
        .get_formatted_input()
        .iter()
        .any(|it| matches!(it, ResponseItem::Reasoning { .. }))
    {
        vec!["reasoning.encrypted_content".to_string()]
    } else {
        Vec::new()
    };
    if let Some(obj) = payload.as_object_mut() {
        obj.insert(
            "include".to_string(),
            Value::Array(include.into_iter().map(Value::String).collect()),
        );
    }

    // Azure Responses requires ids attached to input items
    if azure_workaround
        && let Some(input_value) = payload.get_mut("input")
        && let Some(array) = input_value.as_array_mut()
    {
        attach_item_ids_array(array, &prompt.get_formatted_input());
    }

    payload
}

fn attach_item_ids_array(json_array: &mut [Value], prompt_input: &[ResponseItem]) {
    for (json_item, item) in json_array.iter_mut().zip(prompt_input.iter()) {
        let Some(obj) = json_item.as_object_mut() else {
            continue;
        };
        let mut set_id_if_absent = |id: &str| match obj.get("id") {
            Some(Value::String(s)) if !s.is_empty() => {}
            Some(Value::Null) | None => {
                obj.insert("id".to_string(), Value::String(id.to_string()));
            }
            _ => {}
        };
        match item {
            ResponseItem::Reasoning { id, .. } => set_id_if_absent(id),
            ResponseItem::Message { id, .. } => {
                if let Some(id) = id.as_ref() {
                    set_id_if_absent(id);
                }
            }
            ResponseItem::WebSearchCall { id, .. }
            | ResponseItem::FunctionCall { id, .. }
            | ResponseItem::LocalShellCall { id, .. }
            | ResponseItem::CustomToolCall { id, .. } => {
                if let Some(id) = id.as_ref() {
                    set_id_if_absent(id);
                }
            }
            _ => {}
        }
    }
}

pub fn build_chat_payload(prompt: &Prompt, model: &str, instructions: String) -> Value {
    use crate::tools::spec::create_tools_json_for_chat_completions_api;
    use codex_protocol::models::ContentItem;
    use codex_protocol::models::FunctionCallOutputContentItem;
    use codex_protocol::models::ReasoningItemContent;
    use std::collections::HashMap;

    let mut messages = Vec::<Value>::new();
    messages.push(json!({ "role": "system", "content": instructions }));

    let mut reasoning_by_anchor_index: HashMap<usize, String> = HashMap::new();

    let mut last_emitted_role: Option<&str> = None;
    for item in &prompt.input {
        match item {
            ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()),
            ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
                last_emitted_role = Some("assistant");
            }
            ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"),
            ResponseItem::Reasoning { .. }
            | ResponseItem::Other
            | ResponseItem::CustomToolCall { .. }
            | ResponseItem::CustomToolCallOutput { .. }
            | ResponseItem::WebSearchCall { .. }
            | ResponseItem::GhostSnapshot { .. } => {}
        }
    }

    let mut last_user_index: Option<usize> = None;
    for (idx, item) in prompt.input.iter().enumerate() {
        if let ResponseItem::Message { role, .. } = item
            && role == "user"
        {
            last_user_index = Some(idx);
        }
    }

    if !matches!(last_emitted_role, Some("user")) {
        for (idx, item) in prompt.input.iter().enumerate() {
            if let Some(u_idx) = last_user_index
                && idx <= u_idx
            {
                continue;
            }

            if let ResponseItem::Reasoning {
                content: Some(items),
                ..
            } = item
            {
                let mut text = String::new();
                for entry in items {
                    match entry {
                        ReasoningItemContent::ReasoningText { text: segment }
                        | ReasoningItemContent::Text { text: segment } => {
                            text.push_str(segment);
                        }
                    }
                }
                if text.trim().is_empty() {
                    continue;
                }

                let mut attached = false;
                if idx > 0
                    && let ResponseItem::Message { role, .. } = &prompt.input[idx - 1]
                    && role == "assistant"
                {
                    reasoning_by_anchor_index
                        .entry(idx - 1)
                        .and_modify(|val| val.push_str(&text))
                        .or_insert(text.clone());
                    attached = true;
                }

                if !attached && idx + 1 < prompt.input.len() {
                    match &prompt.input[idx + 1] {
                        ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
                            reasoning_by_anchor_index
                                .entry(idx + 1)
                                .and_modify(|val| val.push_str(&text))
                                .or_insert(text.clone());
                        }
                        ResponseItem::Message { role, .. } if role == "assistant" => {
                            reasoning_by_anchor_index
                                .entry(idx + 1)
                                .and_modify(|val| val.push_str(&text))
                                .or_insert(text.clone());
                        }
                        _ => {}
                    }
                }
            }
        }
    }

    let mut last_assistant_text: Option<String> = None;
    for (idx, item) in prompt.input.iter().enumerate() {
        match item {
            ResponseItem::Message { role, content, .. } => {
                let mut text = String::new();
                let mut items: Vec<Value> = Vec::new();
                let mut saw_image = false;

                for c in content {
                    match c {
                        ContentItem::InputText { text: t }
                        | ContentItem::OutputText { text: t } => {
                            text.push_str(t);
                            items.push(json!({"type":"text","text": t}));
                        }
                        ContentItem::InputImage { image_url } => {
                            saw_image = true;
                            items.push(json!({"type":"image_url","image_url": {"url": image_url}}));
                        }
                    }
                }

                if role == "assistant" {
                    if let Some(prev) = &last_assistant_text
                        && prev == &text
                    {
                        continue;
                    }
                    last_assistant_text = Some(text.clone());
                }

                let content_value = if role == "assistant" {
                    json!(text)
                } else if saw_image {
                    json!(items)
                } else {
                    json!(text)
                };

                let mut message = json!({ "role": role, "content": content_value });
                if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
                    && let Some(obj) = message.as_object_mut()
                {
                    obj.insert("reasoning".to_string(), json!({"text": reasoning}));
                }
                messages.push(message);
            }
            ResponseItem::FunctionCall {
                name,
                arguments,
                call_id,
                ..
            } => {
                messages.push(json!({
                    "role": "assistant",
                    "tool_calls": [{
                        "id": call_id,
                        "type": "function",
                        "function": { "name": name, "arguments": arguments },
                    }],
                }));
            }
            ResponseItem::FunctionCallOutput { call_id, output } => {
                let content_value = if let Some(items) = &output.content_items {
                    let mapped: Vec<Value> = items
                        .iter()
                        .map(|item| match item {
                            FunctionCallOutputContentItem::InputText { text } => {
                                json!({"type":"text","text": text})
                            }
                            FunctionCallOutputContentItem::InputImage { image_url } => {
                                json!({"type":"image_url","image_url": {"url": image_url}})
                            }
                        })
                        .collect();
                    json!(mapped)
                } else {
                    json!(output.content)
                };
                messages.push(
                    json!({ "role": "tool", "tool_call_id": call_id, "content": content_value }),
                );
            }
            ResponseItem::LocalShellCall {
                id,
                call_id,
                action,
                ..
            } => {
                let tool_id = call_id
                    .clone()
                    .filter(|value| !value.is_empty())
                    .or_else(|| id.clone())
                    .unwrap_or_default();
                messages.push(json!({
                    "role": "assistant",
                    "tool_calls": [{
                        "id": tool_id,
                        "type": "function",
                        "function": {
                            "name": "shell",
                            "arguments": serde_json::to_string(action).unwrap_or_default(),
                        },
                    }],
                }));
            }
            ResponseItem::CustomToolCall {
                call_id,
                name,
                input,
                ..
            } => {
                messages.push(json!({
                    "role": "assistant",
                    "tool_calls": [{
                        "id": call_id.clone(),
                        "type": "function",
                        "function": { "name": name, "arguments": input },
                    }],
                }));
            }
            ResponseItem::CustomToolCallOutput { call_id, output } => {
                messages
                    .push(json!({ "role": "tool", "tool_call_id": call_id, "content": output }));
            }
            ResponseItem::WebSearchCall { .. }
            | ResponseItem::Reasoning { .. }
            | ResponseItem::Other
            | ResponseItem::GhostSnapshot { .. } => {}
        }
    }

    let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)
        .unwrap_or_else(|_| Vec::<Value>::new());
    json!({ "model": model, "messages": messages, "stream": true, "tools": tools_json })
}
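For intuition, this is roughly the Chat Completions body that `build_chat_payload` assembles for a single user turn with no tools; the values below are illustrative, only the shape comes from the code above:

use serde_json::json;

fn main() {
    // System message from `instructions`, then the mapped history;
    // always streaming, with a (possibly empty) tools array.
    let payload = json!({
        "model": "gpt-oss",
        "messages": [
            { "role": "system", "content": "You are Codex." },
            { "role": "user", "content": "List the files in this repo." },
        ],
        "stream": true,
        "tools": [],
    });
    println!("{payload}");
}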
@@ -220,6 +220,10 @@ impl OtelEventManager {
        );
    }

    pub fn sse_event_kind(&self, kind: &str) {
        self.sse_event(kind, Duration::from_millis(0));
    }

    pub fn sse_event_failed<T>(&self, kind: Option<&String>, duration: Duration, error: &T)
    where
        T: Display,
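The new `sse_event_kind` helper records an event's kind when no meaningful latency is available, delegating to `sse_event` with a zero duration. A hypothetical call site (the event name is illustrative):

// Tag a parsed SSE frame by kind only; duration is unknown here.
otel_event_manager.sse_event_kind("response.output_item.done");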
18
codex-rs/provider-config/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
[package]
name = "codex-provider-config"
version.workspace = true
edition.workspace = true

[lib]
name = "codex_provider_config"
path = "src/lib.rs"

[lints]
workspace = true

[dependencies]
codex-app-server-protocol = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
thiserror = { workspace = true }
reqwest = { workspace = true }
337
codex-rs/provider-config/src/lib.rs
Normal file
@@ -0,0 +1,337 @@
//! Provider configuration shared across Codex layers.
//!
//! This crate defines the provider-agnostic configuration and wire API
//! selection that higher layers (core/app/client) can use. It intentionally
//! avoids Codex-domain concepts like prompts, token counting, or event types.

use std::collections::HashMap;
use std::env::VarError;
use std::time::Duration;

use codex_app_server_protocol::AuthMode;
use serde::Deserialize;
use serde::Serialize;

#[derive(Debug, thiserror::Error)]
pub enum Error {
    #[error("missing environment variable {var}")]
    MissingEnvVar {
        var: String,
        instructions: Option<String>,
    },
}

pub type Result<T> = std::result::Result<T, Error>;

const DEFAULT_STREAM_IDLE_TIMEOUT_MS: i64 = 300_000;
const DEFAULT_STREAM_MAX_RETRIES: i64 = 5;
const DEFAULT_REQUEST_MAX_RETRIES: i64 = 4;
/// Hard cap for user-configured `stream_max_retries`.
const MAX_STREAM_MAX_RETRIES: i64 = 100;
/// Hard cap for user-configured `request_max_retries`.
const MAX_REQUEST_MAX_RETRIES: i64 = 100;
const DEFAULT_OLLAMA_PORT: i32 = 11434;

/// Wire protocol that the provider speaks.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum WireApi {
    /// The Responses API exposed by OpenAI at `/v1/responses`.
    Responses,
    /// Regular Chat Completions compatible with `/v1/chat/completions`.
    #[default]
    Chat,
}

/// Serializable representation of a provider definition.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ModelProviderInfo {
    /// Friendly display name.
    pub name: String,
    /// Base URL for the provider's OpenAI-compatible API.
    pub base_url: Option<String>,
    /// Environment variable that stores the user's API key for this provider.
    pub env_key: Option<String>,
    /// Optional instructions to help the user set the environment variable.
    pub env_key_instructions: Option<String>,
    /// Value to use with `Authorization: Bearer <token>` header. Prefer `env_key` when possible.
    pub experimental_bearer_token: Option<String>,
    /// Which wire protocol this provider expects.
    #[serde(default)]
    pub wire_api: WireApi,
    /// Optional query parameters to append to the base URL.
    pub query_params: Option<HashMap<String, String>>,
    /// Additional static HTTP headers to include in requests.
    pub http_headers: Option<HashMap<String, String>>,
    /// Optional HTTP headers whose values come from environment variables.
    pub env_http_headers: Option<HashMap<String, String>>,
    /// Maximum number of times to retry a failed HTTP request.
    pub request_max_retries: Option<i64>,
    /// Number of times to retry reconnecting a dropped streaming response before failing.
    pub stream_max_retries: Option<i64>,
    /// Idle timeout (in milliseconds) to wait for activity on a streaming response.
    pub stream_idle_timeout_ms: Option<i64>,
    /// If true, user is prompted for OpenAI login; otherwise uses `env_key`.
    #[serde(default)]
    pub requires_openai_auth: bool,
}

impl ModelProviderInfo {
    /// Construct a `POST` request URL for the configured wire API.
    pub fn get_full_url(&self, auth: Option<&AuthContext>) -> String {
        let default_base_url = if matches!(
            auth,
            Some(AuthContext {
                mode: AuthMode::ChatGPT,
                ..
            })
        ) {
            "https://chatgpt.com/backend-api/codex"
        } else {
            "https://api.openai.com/v1"
        };
        let query_string = self.get_query_string();
        let base_url = self
            .base_url
            .clone()
            .unwrap_or_else(|| default_base_url.to_string());

        match self.wire_api {
            WireApi::Responses => format!("{base_url}/responses{query_string}"),
            WireApi::Chat => format!("{base_url}/chat/completions{query_string}"),
        }
    }

    fn get_query_string(&self) -> String {
        self.query_params
            .as_ref()
            .map_or_else(String::new, |params| {
                let full_params = params
                    .iter()
                    .map(|(k, v)| format!("{k}={v}"))
                    .collect::<Vec<_>>()
                    .join("&");
                format!("?{full_params}")
            })
    }

    pub fn is_azure_responses_endpoint(&self) -> bool {
        if self.wire_api != WireApi::Responses {
            return false;
        }
        if self.name.eq_ignore_ascii_case("azure") {
            return true;
        }
        self.base_url
            .as_ref()
            .map(|base| matches_azure_responses_base_url(base))
            .unwrap_or(false)
    }

    /// Apply static and env-derived headers to the provided builder.
    pub fn apply_http_headers(
        &self,
        mut builder: reqwest::RequestBuilder,
    ) -> reqwest::RequestBuilder {
        if let Some(extra) = &self.http_headers {
            for (k, v) in extra {
                builder = builder.header(k, v);
            }
        }
        if let Some(env_headers) = &self.env_http_headers {
            for (header, env_var) in env_headers {
                if let Ok(val) = std::env::var(env_var)
                    && !val.trim().is_empty()
                {
                    builder = builder.header(header, val);
                }
            }
        }
        builder
    }

    pub fn api_key(&self) -> Result<Option<String>> {
        Ok(match self.env_key.as_ref() {
            Some(env_key) => match std::env::var(env_key) {
                Ok(value) if !value.trim().is_empty() => Some(value),
                Ok(_missing) => None,
                Err(VarError::NotPresent) => {
                    let instructions = self.env_key_instructions.clone();
                    return Err(Error::MissingEnvVar {
                        var: env_key.to_string(),
                        instructions,
                    });
                }
                Err(VarError::NotUnicode(_)) => {
                    return Err(Error::MissingEnvVar {
                        var: env_key.to_string(),
                        instructions: None,
                    });
                }
            },
            None => None,
        })
    }

    pub fn stream_max_retries(&self) -> i64 {
        let value = self
            .stream_max_retries
            .unwrap_or(DEFAULT_STREAM_MAX_RETRIES)
            .min(MAX_STREAM_MAX_RETRIES);
        value.max(0)
    }

    pub fn request_max_retries(&self) -> i64 {
        let value = self
            .request_max_retries
            .unwrap_or(DEFAULT_REQUEST_MAX_RETRIES)
            .min(MAX_REQUEST_MAX_RETRIES);
        value.max(0)
    }

    pub fn stream_idle_timeout(&self) -> Duration {
        let ms = self
            .stream_idle_timeout_ms
            .unwrap_or(DEFAULT_STREAM_IDLE_TIMEOUT_MS);
        let clamped = if ms < 0 { 0 } else { ms as u64 };
        Duration::from_millis(clamped)
    }
}
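One behavioral difference from the deleted core module is worth noting: retry counts are now `i64` and are clamped into range, so a negative configured value degrades to zero retries rather than failing TOML deserialization the way a `u64` field would. A sketch (not part of the diff):

#[test]
fn retry_limits_clamp_to_valid_range() {
    let mut provider = create_oss_provider();
    provider.stream_max_retries = Some(-3);
    assert_eq!(provider.stream_max_retries(), 0); // negative clamps to 0
    provider.stream_max_retries = Some(1_000);
    assert_eq!(provider.stream_max_retries(), 100); // capped at MAX_STREAM_MAX_RETRIES
}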
fn matches_azure_responses_base_url(base: &str) -> bool {
    base.starts_with("https://") && base.ends_with(".openai.azure.com/openai/responses")
}

pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "openai/compatible";
pub const OPENAI_MODEL_PROVIDER_ID: &str = "openai";
pub const ANTHROPIC_MODEL_PROVIDER_ID: &str = "anthropic";

/// Convenience helper to construct a default `openai/compatible` provider pointing at localhost.
pub fn create_oss_provider_with_base_url(url: &str) -> ModelProviderInfo {
    ModelProviderInfo {
        name: "openai/compatible".to_string(),
        base_url: Some(url.to_string()),
        env_key: None,
        env_key_instructions: None,
        experimental_bearer_token: None,
        wire_api: WireApi::Chat,
        query_params: None,
        http_headers: None,
        env_http_headers: None,
        request_max_retries: None,
        stream_max_retries: None,
        stream_idle_timeout_ms: None,
        requires_openai_auth: false,
    }
}

pub fn create_oss_provider() -> ModelProviderInfo {
    create_oss_provider_with_base_url(&format!("http://localhost:{DEFAULT_OLLAMA_PORT}"))
}

pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
    let mut map = HashMap::new();

    map.insert(
        OPENAI_MODEL_PROVIDER_ID.to_string(),
        ModelProviderInfo {
            name: "OpenAI".to_string(),
            base_url: None,
            env_key: Some("OPENAI_API_KEY".to_string()),
            env_key_instructions: Some(
                "Log in to OpenAI and create a new API key at https://platform.openai.com/api-keys. Then paste it here.".to_string(),
            ),
            experimental_bearer_token: None,
            wire_api: WireApi::Responses,
            query_params: None,
            http_headers: None,
            env_http_headers: None,
            request_max_retries: None,
            stream_max_retries: None,
            stream_idle_timeout_ms: None,
            requires_openai_auth: true,
        },
    );

    map.insert(
        ANTHROPIC_MODEL_PROVIDER_ID.to_string(),
        ModelProviderInfo {
            name: "Anthropic".to_string(),
            base_url: Some("https://api.anthropic.com/v1".to_string()),
            env_key: Some("ANTHROPIC_API_KEY".to_string()),
            env_key_instructions: None,
            experimental_bearer_token: None,
            wire_api: WireApi::Chat,
            query_params: None,
            http_headers: None,
            env_http_headers: None,
            request_max_retries: None,
            stream_max_retries: None,
            stream_idle_timeout_ms: None,
            requires_openai_auth: false,
        },
    );

    map.insert(
        BUILT_IN_OSS_MODEL_PROVIDER_ID.to_string(),
        create_oss_provider_with_base_url("http://localhost:11434"),
    );

    map
}

/// Minimal auth context used only for computing URLs and headers.
#[derive(Debug, Clone)]
pub struct AuthContext {
    pub mode: AuthMode,
    pub bearer_token: Option<String>,
    pub account_id: Option<String>,
}

impl ModelProviderInfo {
    /// Convenience to create a request builder with provider and auth headers.
    pub async fn create_request_builder(
        &self,
        client: &reqwest::Client,
        auth: &Option<AuthContext>,
    ) -> Result<reqwest::RequestBuilder> {
        let effective_auth = if let Some(secret_key) = &self.experimental_bearer_token {
            Some(AuthContext {
                mode: AuthMode::ApiKey,
                bearer_token: Some(secret_key.clone()),
                account_id: None,
            })
        } else {
            match self.api_key() {
                Ok(Some(key)) => Some(AuthContext {
                    mode: AuthMode::ApiKey,
                    bearer_token: Some(key),
                    account_id: None,
                }),
                Ok(None) => auth.clone(),
                Err(err) => {
                    if auth.is_some() {
                        auth.clone()
                    } else {
                        return Err(err);
                    }
                }
            }
        };

        let mut builder = client.post(self.get_full_url(effective_auth.as_ref()));
        builder = self.apply_http_headers(builder);

        if let Some(context) = effective_auth.as_ref() {
            if let Some(token) = context.bearer_token.as_ref() {
                builder = builder.bearer_auth(token);
            }
            if let Some(account) = context.account_id.as_ref() {
                builder = builder.header("OpenAI-Beta", "codex-2");
                builder = builder.header("OpenAI-Organization", account);
            }
        }

        Ok(builder)
    }
}
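Auth precedence in `create_request_builder` mirrors the core version above: `experimental_bearer_token` wins, then `env_key`, then the caller-supplied `AuthContext`; an env-key error surfaces only when no fallback auth exists. A runnable sketch (assuming a tokio runtime is available; no request is actually sent):

#[tokio::main]
async fn main() -> Result<()> {
    let mut provider = create_oss_provider();
    provider.experimental_bearer_token = Some("sk-local-test".to_string());

    let client = reqwest::Client::new();
    // No caller auth needed: the static bearer token takes precedence.
    let builder = provider.create_request_builder(&client, &None).await?;
    let request = builder.build().expect("valid request");
    assert!(request.headers().contains_key("authorization"));
    Ok(())
}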