This commit is contained in:
jif-oai
2025-11-10 12:01:02 +00:00
parent 557ac63094
commit 1bac24f827
22 changed files with 2476 additions and 3058 deletions

24
codex-rs/Cargo.lock generated
View File

@@ -829,6 +829,29 @@ dependencies = [
"tracing",
]
[[package]]
name = "codex-api-client"
version = "0.0.0"
dependencies = [
"async-trait",
"bytes",
"codex-app-server-protocol",
"codex-otel",
"codex-protocol",
"eventsource-stream",
"futures",
"maplit",
"regex-lite",
"reqwest",
"serde",
"serde_json",
"thiserror 2.0.17",
"tokio",
"tokio-util",
"toml",
"tracing",
]
[[package]]
name = "codex-app-server"
version = "0.0.0"
@@ -1062,6 +1085,7 @@ dependencies = [
"base64",
"bytes",
"chrono",
"codex-api-client",
"codex-app-server-protocol",
"codex-apply-patch",
"codex-async-utils",

View File

@@ -1,5 +1,6 @@
[workspace]
members = [
"api-client",
"backend-client",
"ansi-escape",
"async-utils",
@@ -54,6 +55,7 @@ edition = "2024"
# Internal
app_test_support = { path = "app-server/tests/common" }
codex-ansi-escape = { path = "ansi-escape" }
codex-api-client = { path = "api-client" }
codex-app-server = { path = "app-server" }
codex-app-server-protocol = { path = "app-server-protocol" }
codex-apply-patch = { path = "apply-patch" }

View File

@@ -0,0 +1,29 @@
[package]
name = "codex-api-client"
version.workspace = true
edition.workspace = true
[dependencies]
async-trait = { workspace = true }
bytes = { workspace = true }
codex-app-server-protocol = { workspace = true }
codex-otel = { workspace = true }
codex-protocol = { path = "../protocol" }
eventsource-stream = { workspace = true }
futures = { workspace = true, default-features = false, features = ["std"] }
maplit = "1.0.2"
regex-lite = { workspace = true }
reqwest = { workspace = true, features = ["json", "stream"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["sync", "time", "rt", "rt-multi-thread", "macros", "io-util"] }
tokio-util = { workspace = true }
tracing = { workspace = true }
[dev-dependencies]
maplit = "1.0.2"
toml = { workspace = true }
[lints]
workspace = true

View File

@@ -0,0 +1,13 @@
use async_trait::async_trait;
use crate::error::Result;
use crate::prompt::Prompt;
use crate::stream::ResponseStream;
/// Common interface implemented by the API clients in this crate.
///
/// A client is built from an implementation-specific `Config` and turns a
/// [`Prompt`] into a [`ResponseStream`].
#[async_trait]
pub trait ApiClient: Sized {
/// Implementation-specific configuration consumed by [`ApiClient::new`].
type Config;
/// Builds a client from `config`, returning an error if construction fails.
async fn new(config: Self::Config) -> Result<Self>;
/// Sends `prompt` and returns the resulting stream of response events, or
/// an error if the request could not be started.
async fn stream(&self, prompt: Prompt) -> Result<ResponseStream>;
}

View File

@@ -0,0 +1,17 @@
use async_trait::async_trait;
use codex_app_server_protocol::AuthMode;
use serde::Deserialize;
use serde::Serialize;
/// Credentials describing how an outgoing provider request is authenticated.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AuthContext {
/// Authentication mode in effect for this context.
pub mode: AuthMode,
/// Token to present as a bearer credential, when one is available.
pub bearer_token: Option<String>,
/// Account identifier, if any.
/// NOTE(review): no consumer of this field is visible here; presumably it
/// is forwarded to the backend alongside the token — confirm.
pub account_id: Option<String>,
}
/// Source of authentication state; implementors must be `Send + Sync`.
#[async_trait]
pub trait AuthProvider: Send + Sync {
/// Returns the current [`AuthContext`], or `None` when none is available.
async fn auth_context(&self) -> Option<AuthContext>;
/// Attempts to refresh the token; failures are reported as a message string.
///
/// NOTE(review): the meaning of `Ok(None)` is not visible from this trait;
/// presumably "nothing to refresh" — confirm against the implementors.
async fn refresh_token(&self) -> std::result::Result<Option<String>, String>;
}

View File

@@ -0,0 +1,629 @@
use std::time::Duration;
use async_trait::async_trait;
use bytes::Bytes;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::models::ContentItem;
use codex_protocol::models::FunctionCallOutputContentItem;
use codex_protocol::models::ReasoningItemContent;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::SessionSource;
use eventsource_stream::Eventsource;
use futures::Stream;
use futures::StreamExt;
use futures::TryStreamExt;
use serde_json::Value;
use serde_json::json;
use tokio::sync::mpsc;
use tokio::time::timeout;
use tracing::debug;
use tracing::trace;
use crate::aggregate::ChatAggregationMode;
use crate::api::ApiClient;
use crate::common::apply_subagent_header;
use crate::common::backoff;
use crate::error::Error;
use crate::model_provider::ModelProviderInfo;
use crate::prompt::Prompt;
use crate::stream::ResponseEvent;
use crate::stream::ResponseStream;
/// Result alias for this module, using the crate's [`Error`] type.
pub type Result<T> = std::result::Result<T, Error>;
/// Everything a [`ChatCompletionsApiClient`] needs to issue requests.
#[derive(Clone)]
pub struct ChatCompletionsApiClientConfig {
/// HTTP client used to send requests.
pub http_client: reqwest::Client,
/// Provider definition; supplies the request builder, the retry limit and
/// the stream idle timeout.
pub provider: ModelProviderInfo,
/// Model name placed in the request payload's `model` field.
pub model: String,
/// Telemetry handle used to log each request attempt.
pub otel_event_manager: OtelEventManager,
/// Session origin passed to `apply_subagent_header` when building requests.
pub session_source: SessionSource,
/// Controls whether text deltas are forwarded while the stream is processed.
pub aggregation_mode: ChatAggregationMode,
}
/// API client that streams responses over the Chat Completions wire format.
#[derive(Clone)]
pub struct ChatCompletionsApiClient {
config: ChatCompletionsApiClientConfig,
}
#[async_trait]
impl ApiClient for ChatCompletionsApiClient {
    type Config = ChatCompletionsApiClientConfig;

    /// Wraps `config` in a client. Construction performs no I/O and always succeeds.
    async fn new(config: Self::Config) -> Result<Self> {
        Ok(Self { config })
    }

    /// Sends `prompt` to the provider and returns a stream of response events.
    ///
    /// The request is attempted once and then retried up to
    /// `provider.request_max_retries()` additional times. Only transport
    /// errors, HTTP 429 and HTTP 5xx responses are retried; any other
    /// non-success status is returned immediately because repeating the same
    /// request cannot succeed. A numeric `Retry-After` header takes precedence
    /// over the computed backoff delay.
    ///
    /// # Errors
    /// - [`Error::UnsupportedOperation`] when the prompt uses a feature this
    ///   wire format cannot express.
    /// - [`Error::UnexpectedStatus`] for a non-retryable status, or for a
    ///   retryable one once the retry budget is exhausted.
    /// - [`Error::Http`] when the request keeps failing at the transport level.
    async fn stream(&self, prompt: Prompt) -> Result<ResponseStream> {
        Self::validate_prompt(&prompt)?;
        let payload = self.build_payload(&prompt)?;
        let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);

        // `attempt` is 1-based: attempt 1 is the initial request, attempts
        // 2..=max_retries + 1 are the retries.
        let mut attempt: i64 = 0;
        let max_retries = self.config.provider.request_max_retries();
        loop {
            attempt += 1;
            let req_builder = self
                .config
                .provider
                .create_request_builder(&self.config.http_client, &None)
                .await
                .map(|builder| apply_subagent_header(builder, Some(&self.config.session_source)))?;
            let res = self
                .config
                .otel_event_manager
                .log_request(attempt as u64, || {
                    req_builder
                        .header(reqwest::header::ACCEPT, "text/event-stream")
                        .json(&payload)
                        .send()
                })
                .await;

            match res {
                Ok(resp) if resp.status().is_success() => {
                    let stream = resp
                        .bytes_stream()
                        .map_err(|err| Error::ResponseStreamFailed {
                            source: err,
                            request_id: None,
                        });
                    let idle_timeout = self.config.provider.stream_idle_timeout();
                    let otel = self.config.otel_event_manager.clone();
                    let mode = self.config.aggregation_mode;
                    // The SSE body is consumed on a background task; events
                    // reach the caller through the channel.
                    tokio::spawn(process_chat_sse(
                        stream,
                        tx_event.clone(),
                        idle_timeout,
                        otel,
                        mode,
                    ));
                    return Ok(ResponseStream { rx_event });
                }
                Ok(resp) => {
                    let status = resp.status();
                    let retryable = status == reqwest::StatusCode::TOO_MANY_REQUESTS
                        || status.is_server_error();
                    if !retryable || attempt > max_retries {
                        let body = resp
                            .text()
                            .await
                            .unwrap_or_else(|_| "<failed to read response>".to_string());
                        return Err(Error::UnexpectedStatus { status, body });
                    }
                    // Negative values are clamped to zero; a non-numeric
                    // header falls back to the computed backoff.
                    let retry_after = resp
                        .headers()
                        .get(reqwest::header::RETRY_AFTER)
                        .and_then(|v| v.to_str().ok())
                        .and_then(|s| s.parse::<i64>().ok())
                        .map(|secs| Duration::from_secs(u64::try_from(secs).unwrap_or(0)));
                    tokio::time::sleep(retry_after.unwrap_or_else(|| backoff(attempt))).await;
                }
                Err(error) => {
                    if attempt > max_retries {
                        return Err(Error::Http(error));
                    }
                    tokio::time::sleep(backoff(attempt)).await;
                }
            }
        }
    }
}
impl ChatCompletionsApiClient {
fn validate_prompt(prompt: &Prompt) -> Result<()> {
if prompt.output_schema.is_some() {
return Err(Error::UnsupportedOperation(
"output_schema is not supported for Chat Completions API".to_string(),
));
}
Ok(())
}
fn build_payload(&self, prompt: &Prompt) -> Result<serde_json::Value> {
let mut messages = Vec::<serde_json::Value>::new();
messages.push(json!({ "role": "system", "content": prompt.instructions }));
let mut reasoning_by_anchor_index: std::collections::HashMap<usize, String> =
std::collections::HashMap::new();
let mut last_emitted_role: Option<&str> = None;
for item in &prompt.input {
match item {
ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()),
ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
last_emitted_role = Some("assistant");
}
ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"),
ResponseItem::Reasoning { .. }
| ResponseItem::Other
| ResponseItem::CustomToolCall { .. }
| ResponseItem::CustomToolCallOutput { .. }
| ResponseItem::WebSearchCall { .. }
| ResponseItem::GhostSnapshot { .. } => {}
}
}
let mut last_user_index: Option<usize> = None;
for (idx, item) in prompt.input.iter().enumerate() {
if let ResponseItem::Message { role, .. } = item
&& role == "user"
{
last_user_index = Some(idx);
}
}
if !matches!(last_emitted_role, Some("user")) {
for (idx, item) in prompt.input.iter().enumerate() {
if let Some(u_idx) = last_user_index
&& idx <= u_idx
{
continue;
}
if let ResponseItem::Reasoning {
content: Some(items),
..
} = item
{
let mut text = String::new();
for entry in items {
match entry {
ReasoningItemContent::ReasoningText { text: segment }
| ReasoningItemContent::Text { text: segment } => {
text.push_str(segment);
}
}
}
if text.trim().is_empty() {
continue;
}
let mut attached = false;
if idx > 0
&& let ResponseItem::Message { role, .. } = &prompt.input[idx - 1]
&& role == "assistant"
{
reasoning_by_anchor_index
.entry(idx - 1)
.and_modify(|val| val.push_str(&text))
.or_insert(text.clone());
attached = true;
}
if !attached && idx + 1 < prompt.input.len() {
match &prompt.input[idx + 1] {
ResponseItem::FunctionCall { .. }
| ResponseItem::LocalShellCall { .. } => {
reasoning_by_anchor_index
.entry(idx + 1)
.and_modify(|val| val.push_str(&text))
.or_insert(text.clone());
}
ResponseItem::Message { role, .. } if role == "assistant" => {
reasoning_by_anchor_index
.entry(idx + 1)
.and_modify(|val| val.push_str(&text))
.or_insert(text.clone());
}
_ => {}
}
}
}
}
}
let mut last_assistant_text: Option<String> = None;
for (idx, item) in prompt.input.iter().enumerate() {
match item {
ResponseItem::Message { role, content, .. } => {
let mut text = String::new();
let mut items: Vec<serde_json::Value> = Vec::new();
let mut saw_image = false;
for c in content {
match c {
ContentItem::InputText { text: t }
| ContentItem::OutputText { text: t } => {
text.push_str(t);
items.push(json!({"type":"text","text": t}));
}
ContentItem::InputImage { image_url } => {
saw_image = true;
items.push(
json!({"type":"image_url","image_url": {"url": image_url}}),
);
}
}
}
if role == "assistant" {
if let Some(prev) = &last_assistant_text
&& prev == &text
{
continue;
}
last_assistant_text = Some(text.clone());
}
let content_value = if role == "assistant" {
json!(text)
} else if saw_image {
json!(items)
} else {
json!(text)
};
let mut message = json!({
"role": role,
"content": content_value,
});
if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
&& let Some(obj) = message.as_object_mut()
{
obj.insert("reasoning".to_string(), json!({"text": reasoning}));
}
messages.push(message);
}
ResponseItem::FunctionCall {
name,
arguments,
call_id,
..
} => {
messages.push(json!({
"role": "assistant",
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": name,
"arguments": arguments,
},
}],
}));
}
ResponseItem::FunctionCallOutput { call_id, output } => {
let content_value = if let Some(items) = &output.content_items {
let mapped: Vec<serde_json::Value> = items
.iter()
.map(|item| match item {
FunctionCallOutputContentItem::InputText { text } => {
json!({"type":"text","text": text})
}
FunctionCallOutputContentItem::InputImage { image_url } => {
json!({"type":"image_url","image_url": {"url": image_url}})
}
})
.collect();
json!(mapped)
} else {
json!(output.content)
};
messages.push(json!({
"role": "tool",
"tool_call_id": call_id,
"content": content_value,
}));
}
ResponseItem::LocalShellCall {
id,
call_id,
action,
..
} => {
let tool_id = call_id
.clone()
.filter(|value| !value.is_empty())
.or_else(|| id.clone())
.unwrap_or_default();
messages.push(json!({
"role": "assistant",
"tool_calls": [{
"id": tool_id,
"type": "function",
"function": {
"name": "shell",
"arguments": serde_json::to_string(action).unwrap_or_default(),
},
}],
}));
}
ResponseItem::CustomToolCall {
call_id,
name,
input,
..
} => {
messages.push(json!({
"role": "assistant",
"tool_calls": [{
"id": call_id.clone(),
"type": "function",
"function": {
"name": name,
"arguments": input,
},
}],
}));
}
ResponseItem::CustomToolCallOutput { call_id, output } => {
messages.push(json!({
"role": "tool",
"tool_call_id": call_id,
"content": output,
}));
}
ResponseItem::WebSearchCall { .. }
| ResponseItem::Reasoning { .. }
| ResponseItem::Other
| ResponseItem::GhostSnapshot { .. } => {}
}
}
let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
let payload = json!({
"model": self.config.model,
"messages": messages,
"stream": true,
"tools": tools_json,
});
trace!("chat completions payload: {}", payload);
Ok(payload)
}
}
/// Lightweight SSE processor for Chat Completions streaming, mapped to ResponseEvent.
async fn process_chat_sse<S>(
stream: S,
tx_event: mpsc::Sender<Result<ResponseEvent>>,
idle_timeout: Duration,
_otel_event_manager: OtelEventManager,
aggregation_mode: ChatAggregationMode,
) where
S: Stream<Item = Result<Bytes>> + Unpin,
{
let mut stream = stream.eventsource();
#[derive(Default)]
struct FunctionCallState {
name: Option<String>,
arguments: String,
call_id: Option<String>,
active: bool,
}
let mut fn_call_state = FunctionCallState::default();
let mut assistant_item: Option<ResponseItem> = None;
let mut reasoning_item: Option<ResponseItem> = None;
loop {
let response = timeout(idle_timeout, stream.next()).await;
let sse = match response {
Ok(Some(Ok(ev))) => ev,
Ok(Some(Err(err))) => {
let _ = tx_event
.send(Err(Error::Stream(err.to_string(), None)))
.await;
return;
}
Ok(None) => {
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
token_usage: None,
}))
.await;
return;
}
Err(_) => {
let _ = tx_event
.send(Err(Error::Stream(
"idle timeout waiting for SSE".into(),
None,
)))
.await;
return;
}
};
if sse.data.trim() == "[DONE]" {
if let Some(item) = assistant_item {
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
if let Some(item) = reasoning_item {
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
token_usage: None,
}))
.await;
return;
}
let Ok(parsed_chunk) = serde_json::from_str::<serde_json::Value>(&sse.data) else {
debug!("failed to parse SSE data into JSON: {}", sse.data);
continue;
};
let choices = parsed_chunk
.get("choices")
.and_then(|choices| choices.as_array())
.cloned()
.unwrap_or_default();
for choice in choices {
if let Some(delta) = choice.get("delta") {
if let Some(content) = delta.get("content").and_then(|c| c.as_array()) {
for piece in content {
if let Some(text) = piece.get("text").and_then(|t| t.as_str()) {
append_assistant_text(&tx_event, &mut assistant_item, text.to_string())
.await;
if matches!(aggregation_mode, ChatAggregationMode::Streaming) {
let _ = tx_event
.send(Ok(ResponseEvent::OutputTextDelta(text.to_string())))
.await;
}
}
}
}
if let Some(tool_calls) = delta.get("tool_calls").and_then(|c| c.as_array()) {
for call in tool_calls {
if let Some(id_val) = call.get("id").and_then(|id| id.as_str()) {
fn_call_state.call_id = Some(id_val.to_string());
}
if let Some(function) = call.get("function") {
if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
fn_call_state.name = Some(name.to_string());
fn_call_state.active = true;
}
if let Some(args) = function.get("arguments").and_then(|a| a.as_str()) {
fn_call_state.arguments.push_str(args);
}
}
}
}
if let Some(reasoning) = delta.get("reasoning_content").and_then(|c| c.as_array()) {
for entry in reasoning {
if let Some(text) = entry.get("text").and_then(|t| t.as_str()) {
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string())
.await;
}
}
}
}
if let Some(finish_reason) = choice.get("finish_reason").and_then(|f| f.as_str())
&& finish_reason == "tool_calls"
&& fn_call_state.active
{
let function_name = fn_call_state.name.take().unwrap_or_default();
let call_id = fn_call_state.call_id.take().unwrap_or_default();
let arguments = fn_call_state.arguments.clone();
fn_call_state = FunctionCallState::default();
let item = ResponseItem::FunctionCall {
id: Some(call_id.clone()),
call_id,
name: function_name,
arguments,
};
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
}
}
}
/// Appends `text` to the assistant message being accumulated.
///
/// On the first call the (still empty) assistant message is stored in
/// `assistant_item` and announced with `OutputItemAdded`; later calls only
/// append another output-text part.
async fn append_assistant_text(
    tx_event: &mpsc::Sender<Result<ResponseEvent>>,
    assistant_item: &mut Option<ResponseItem>,
    text: String,
) {
    if assistant_item.is_none() {
        let announced = assistant_item
            .insert(ResponseItem::Message {
                id: None,
                role: "assistant".to_string(),
                content: vec![],
            })
            .clone();
        // A failed send only means the receiver is gone; keep accumulating.
        let _ = tx_event
            .send(Ok(ResponseEvent::OutputItemAdded(announced)))
            .await;
    }

    if let Some(ResponseItem::Message { content: parts, .. }) = assistant_item.as_mut() {
        parts.push(ContentItem::OutputText { text });
    }
}
/// Appends `text` to the reasoning item being accumulated.
///
/// On the first call the (still empty) reasoning item is stored in
/// `reasoning_item` and announced with `OutputItemAdded`; later calls only
/// append another reasoning-text part.
async fn append_reasoning_text(
    tx_event: &mpsc::Sender<Result<ResponseEvent>>,
    reasoning_item: &mut Option<ResponseItem>,
    text: String,
) {
    if reasoning_item.is_none() {
        let announced = reasoning_item
            .insert(ResponseItem::Reasoning {
                id: String::new(),
                summary: Vec::new(),
                content: Some(vec![]),
                encrypted_content: None,
            })
            .clone();
        // A failed send only means the receiver is gone; keep accumulating.
        let _ = tx_event
            .send(Ok(ResponseEvent::OutputItemAdded(announced)))
            .await;
    }

    if let Some(ResponseItem::Reasoning {
        content: Some(parts),
        ..
    }) = reasoning_item.as_mut()
    {
        parts.push(ReasoningItemContent::ReasoningText { text });
    }
}
/// Converts tool definitions into the Chat Completions `tools` array.
///
/// Only tools whose `type` is `"function"` are kept. A tool that already has
/// a nested `function` object contributes that object; otherwise the tool's
/// own fields, minus `type`, become the function description.
fn create_tools_json_for_chat_completions_api(
    tools: &[serde_json::Value],
) -> Result<Vec<serde_json::Value>> {
    let mut converted: Vec<serde_json::Value> = Vec::new();
    for tool in tools {
        let is_function = tool.get("type").and_then(|kind| kind.as_str()) == Some("function");
        if !is_function {
            continue;
        }

        let function_value = match (tool.get("function"), tool.as_object()) {
            (Some(nested), _) => nested.clone(),
            (None, Some(fields)) => {
                let mut flattened = fields.clone();
                flattened.remove("type");
                Value::Object(flattened)
            }
            (None, None) => continue,
        };

        converted.push(json!({
            "type": "function",
            "function": function_value,
        }));
    }
    Ok(converted)
}
// aggregation types and adapters moved to crate::aggregate

View File

@@ -0,0 +1,38 @@
use reqwest::StatusCode;
use thiserror::Error;
/// Result alias used throughout the crate, with [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
/// Errors produced by the API clients in this crate.
#[derive(Error, Debug)]
pub enum Error {
/// The requested operation is not supported by the selected client or provider.
#[error("{0}")]
UnsupportedOperation(String),
/// The HTTP request failed at the transport level.
#[error(transparent)]
Http(#[from] reqwest::Error),
/// Reading the response body stream failed.
#[error("response stream failed: {source}")]
ResponseStreamFailed {
#[source]
source: reqwest::Error,
/// Identifier of the failed request, when known.
request_id: Option<String>,
},
/// The event stream ended abnormally; the string is the failure message.
/// NOTE(review): the meaning of the optional duration is not visible here;
/// presumably a suggested delay before retrying — confirm.
#[error("stream error: {0}")]
Stream(String, Option<std::time::Duration>),
/// The server answered with a non-success status; `body` is the response text.
#[error("unexpected status {status}: {body}")]
UnexpectedStatus { status: StatusCode, body: String },
/// The retry budget was exhausted.
#[error("retry limit reached {status:?} request_id={request_id:?}")]
RetryLimit {
/// Last status observed, if any.
status: Option<StatusCode>,
/// Identifier of the last request, when known.
request_id: Option<String>,
},
/// A required environment variable is not usable.
#[error("missing env var {var}: {instructions:?}")]
MissingEnvVar {
/// Name of the environment variable.
var: String,
/// Optional guidance on how to obtain and set the value.
instructions: Option<String>,
},
/// Authentication failed; the string is the failure message.
#[error("auth error: {0}")]
Auth(String),
/// JSON serialization or deserialization failed.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// Any other failure, described by its message.
#[error("{0}")]
Other(String),
}

View File

@@ -0,0 +1,37 @@
pub mod aggregate;
pub mod api;
pub mod auth;
pub mod chat;
mod common;
pub mod error;
pub mod model_provider;
pub mod prompt;
pub mod responses;
pub mod stream;
pub use crate::aggregate::AggregateStreamExt;
pub use crate::aggregate::ChatAggregationMode;
pub use crate::api::ApiClient;
pub use crate::auth::AuthContext;
pub use crate::auth::AuthProvider;
pub use crate::chat::ChatCompletionsApiClient;
pub use crate::chat::ChatCompletionsApiClientConfig;
pub use crate::error::Error;
pub use crate::error::Result;
pub use crate::model_provider::BUILT_IN_OSS_MODEL_PROVIDER_ID;
pub use crate::model_provider::ModelProviderInfo;
pub use crate::model_provider::WireApi;
pub use crate::model_provider::built_in_model_providers;
pub use crate::model_provider::create_oss_provider;
pub use crate::model_provider::create_oss_provider_with_base_url;
pub use crate::prompt::Prompt;
pub use crate::responses::ResponsesApiClient;
pub use crate::responses::ResponsesApiClientConfig;
pub use crate::responses::stream_from_fixture;
pub use crate::stream::EventStream;
pub use crate::stream::Reasoning;
pub use crate::stream::ResponseEvent;
pub use crate::stream::ResponseStream;
pub use crate::stream::TextControls;
pub use crate::stream::TextFormat;
pub use crate::stream::TextFormatType;

View File

@@ -0,0 +1,343 @@
//! Registry of model providers supported by Codex.
//!
//! Providers can be defined in two places:
//! 1. Built-in defaults compiled into the binary so Codex works out-of-the-box.
//! 2. User-defined entries inside `~/.codex/config.toml` under the `model_providers`
//! key. These override or extend the defaults at runtime.
use std::collections::HashMap;
use std::env::VarError;
use std::time::Duration;
use codex_app_server_protocol::AuthMode;
use serde::Deserialize;
use serde::Serialize;
use crate::auth::AuthContext;
use crate::error::Error;
use crate::error::Result;
/// Stream idle timeout, in milliseconds, used when the provider sets none.
const DEFAULT_STREAM_IDLE_TIMEOUT_MS: i64 = 300_000;
/// Stream reconnect retries used when the provider sets none.
const DEFAULT_STREAM_MAX_RETRIES: i64 = 5;
/// Request retries used when the provider sets none.
const DEFAULT_REQUEST_MAX_RETRIES: i64 = 4;
/// Hard cap for user-configured `stream_max_retries`.
const MAX_STREAM_MAX_RETRIES: i64 = 100;
/// Hard cap for user-configured `request_max_retries`.
const MAX_REQUEST_MAX_RETRIES: i64 = 100;
/// Port used to build the default localhost URL of the OSS provider.
const DEFAULT_OLLAMA_PORT: i32 = 11434;
/// Wire protocol that the provider speaks. Most third-party services only
/// implement the classic OpenAI Chat Completions JSON schema, whereas OpenAI
/// itself (and a handful of others) additionally expose the more modern
/// Responses API. The two protocols use different request/response shapes
/// and cannot be auto-detected at runtime, therefore each provider entry
/// must declare which one it expects.
///
/// Variants are (de)serialized by their lowercase names; `Chat` is the
/// default when the field is omitted.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum WireApi {
/// The Responses API exposed by OpenAI at `/v1/responses`.
Responses,
/// Regular Chat Completions compatible with `/v1/chat/completions`.
#[default]
Chat,
}
/// Serializable representation of a provider definition.
///
/// Optional numeric settings fall back to built-in defaults and are clamped
/// by the accessor methods on this type; the raw fields hold exactly what was
/// configured.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ModelProviderInfo {
/// Friendly display name.
pub name: String,
/// Base URL for the provider's OpenAI-compatible API.
pub base_url: Option<String>,
/// Environment variable that stores the user's API key for this provider.
pub env_key: Option<String>,
/// Optional instructions to help the user get a valid value for the
/// variable and set it.
pub env_key_instructions: Option<String>,
/// Value to use with `Authorization: Bearer <token>` header. Use of this
/// config is discouraged in favor of `env_key` for security reasons, but
/// this may be necessary when using this programmatically.
pub experimental_bearer_token: Option<String>,
/// Which wire protocol this provider expects.
#[serde(default)]
pub wire_api: WireApi,
/// Optional query parameters to append to the base URL.
pub query_params: Option<HashMap<String, String>>,
/// Additional HTTP headers to include in requests to this provider where
/// the (key, value) pairs are the header name and value.
pub http_headers: Option<HashMap<String, String>>,
/// Optional HTTP headers to include in requests to this provider where the
/// (key, value) pairs are the header name and environment variable whose
/// value should be used. If the environment variable is not set, or the
/// value is empty, the header will not be included in the request.
pub env_http_headers: Option<HashMap<String, String>>,
/// Maximum number of times to retry a failed HTTP request to this provider.
pub request_max_retries: Option<i64>,
/// Number of times to retry reconnecting a dropped streaming response before failing.
pub stream_max_retries: Option<i64>,
/// Idle timeout (in milliseconds) to wait for activity on a streaming response before treating
/// the connection as lost.
pub stream_idle_timeout_ms: Option<i64>,
/// Does this provider require an OpenAI API Key or ChatGPT login token? If true,
/// the user is presented with a login screen on first run, and login preference and token/key
/// are stored in auth.json. If false (which is the default), the login screen is skipped,
/// and the API key (if needed) comes from the `env_key` environment variable.
#[serde(default)]
pub requires_openai_auth: bool,
}
impl ModelProviderInfo {
/// Construct a `POST` request builder for the given URL using the provided
/// [`reqwest::Client`] applying:
/// - provider-specific headers (static and environment based)
/// - Bearer auth header when an API key is available
/// - Auth token for OAuth
///
/// If the provider declares an `env_key` but the variable is missing or empty, this returns an
/// error identical to the one produced by [`ModelProviderInfo::api_key`].
pub async fn create_request_builder(
&self,
client: &reqwest::Client,
auth: &Option<AuthContext>,
) -> Result<reqwest::RequestBuilder> {
let effective_auth = if let Some(secret_key) = &self.experimental_bearer_token {
Some(AuthContext {
mode: AuthMode::ApiKey,
bearer_token: Some(secret_key.clone()),
account_id: None,
})
} else {
match self.api_key()? {
Some(key) => Some(AuthContext {
mode: AuthMode::ApiKey,
bearer_token: Some(key),
account_id: None,
}),
None => auth.clone(),
}
};
let url = self.get_full_url(effective_auth.as_ref());
let mut builder = client.post(url);
if let Some(context) = effective_auth.as_ref()
&& let Some(token) = context.bearer_token.as_ref()
{
builder = builder.bearer_auth(token);
}
Ok(self.apply_http_headers(builder))
}
fn get_query_string(&self) -> String {
self.query_params
.as_ref()
.map_or_else(String::new, |params| {
let full_params = params
.iter()
.map(|(k, v)| format!("{k}={v}"))
.collect::<Vec<_>>()
.join("&");
format!("?{full_params}")
})
}
pub fn get_full_url(&self, auth: Option<&AuthContext>) -> String {
let default_base_url = if matches!(
auth,
Some(AuthContext {
mode: AuthMode::ChatGPT,
..
})
) {
"https://chatgpt.com/backend-api/codex"
} else {
"https://api.openai.com/v1"
};
let query_string = self.get_query_string();
let base_url = self
.base_url
.clone()
.unwrap_or_else(|| default_base_url.to_string());
match self.wire_api {
WireApi::Responses => format!("{base_url}/responses{query_string}"),
WireApi::Chat => format!("{base_url}/chat/completions{query_string}"),
}
}
pub fn is_azure_responses_endpoint(&self) -> bool {
if self.wire_api != WireApi::Responses {
return false;
}
if self.name.eq_ignore_ascii_case("azure") {
return true;
}
self.base_url
.as_ref()
.map(|base| matches_azure_responses_base_url(base))
.unwrap_or(false)
}
/// Apply provider-specific HTTP headers (both static and environment-based) onto an existing
/// [`reqwest::RequestBuilder`] and return the updated builder.
fn apply_http_headers(&self, mut builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
if let Some(extra) = &self.http_headers {
for (k, v) in extra {
builder = builder.header(k, v);
}
}
if let Some(env_headers) = &self.env_http_headers {
for (header, env_var) in env_headers {
if let Ok(val) = std::env::var(env_var)
&& !val.trim().is_empty()
{
builder = builder.header(header, val);
}
}
}
builder
}
pub fn api_key(&self) -> Result<Option<String>> {
Ok(match self.env_key.as_ref() {
Some(env_key) => match std::env::var(env_key) {
Ok(value) if !value.trim().is_empty() => Some(value),
Ok(_missing) => None,
Err(VarError::NotPresent) => {
let instructions = self.env_key_instructions.clone();
return Err(Error::MissingEnvVar {
var: env_key.to_string(),
instructions,
});
}
Err(VarError::NotUnicode(_)) => {
return Err(Error::MissingEnvVar {
var: env_key.to_string(),
instructions: None,
});
}
},
None => None,
})
}
pub fn stream_max_retries(&self) -> i64 {
let value = self
.stream_max_retries
.unwrap_or(DEFAULT_STREAM_MAX_RETRIES)
.min(MAX_STREAM_MAX_RETRIES);
value.max(0)
}
pub fn request_max_retries(&self) -> i64 {
let value = self
.request_max_retries
.unwrap_or(DEFAULT_REQUEST_MAX_RETRIES)
.min(MAX_REQUEST_MAX_RETRIES);
value.max(0)
}
pub fn stream_idle_timeout(&self) -> Duration {
let ms = self
.stream_idle_timeout_ms
.unwrap_or(DEFAULT_STREAM_IDLE_TIMEOUT_MS);
let clamped = if ms < 0 { 0 } else { ms as u64 };
Duration::from_millis(clamped)
}
}
/// Returns `true` when `base` is an `https://` URL whose host is an Azure
/// OpenAI resource, i.e. ends with `.openai.azure.com`.
///
/// Only the host is inspected, case-insensitively and ignoring any port or
/// user-info. The path is irrelevant: the endpoint path is appended later
/// when the full URL is built, so a configured base URL normally does not end
/// in `/responses`, and the marker appearing in the path of some other host
/// must not count as a match.
fn matches_azure_responses_base_url(base: &str) -> bool {
    let normalized = base.trim().to_ascii_lowercase();
    let Some(rest) = normalized.strip_prefix("https://") else {
        return false;
    };
    // The authority ends at the first path, query or fragment delimiter.
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or_default();
    let host_and_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_, host)| host);
    let host = host_and_port.split(':').next().unwrap_or_default();
    host.ends_with(".openai.azure.com")
}
/// Key of the built-in self-hosted, OpenAI-compatible provider entry.
pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "openai/compatible";
/// Key of the built-in OpenAI provider entry.
pub const OPENAI_MODEL_PROVIDER_ID: &str = "openai";
/// Key of the built-in Anthropic provider entry.
pub const ANTHROPIC_MODEL_PROVIDER_ID: &str = "anthropic";
/// Returns the baked-in list of providers. These can be overridden by a `[model_providers]`
/// entry inside `~/.codex/config.toml`.
pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
let mut providers = HashMap::new();
providers.insert(
OPENAI_MODEL_PROVIDER_ID.to_string(),
ModelProviderInfo {
name: "OpenAI".to_string(),
base_url: None,
env_key: Some("OPENAI_API_KEY".to_string()),
env_key_instructions: Some("Log in to OpenAI and create a new API key at https://platform.openai.com/api-keys. Then paste it here.".to_string()),
experimental_bearer_token: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: true,
},
);
providers.insert(
ANTHROPIC_MODEL_PROVIDER_ID.to_string(),
ModelProviderInfo {
name: "Anthropic".to_string(),
base_url: Some("https://api.anthropic.com/v1/messages".to_string()),
env_key: Some("ANTHROPIC_API_KEY".to_string()),
env_key_instructions: Some("Create a new API key at https://console.anthropic.com/settings/keys and paste it here.".to_string()),
experimental_bearer_token: None,
wire_api: WireApi::Chat,
query_params: None,
http_headers: Some(
maplit::hashmap! {
"anthropic-version".to_string() => "2023-06-01".to_string(),
}
),
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: false,
},
);
providers.insert(
BUILT_IN_OSS_MODEL_PROVIDER_ID.to_string(),
create_oss_provider_with_base_url("http://localhost:11434"),
);
providers
}
/// Builds the self-hosted, OpenAI-compatible provider definition for `url`.
///
/// The provider speaks the Chat Completions wire API, sends an
/// `x-oss-provider: ollama` header, and reads its API key from the
/// `CODEX_OSS_PROVIDER_API_KEY` environment variable.
///
/// NOTE(review): `url` is used verbatim as the base URL and the endpoint path
/// is appended to it; servers that expose their OpenAI-compatible API under a
/// `/v1` prefix need that prefix in `url` — confirm the default callers.
pub fn create_oss_provider_with_base_url(url: &str) -> ModelProviderInfo {
    let http_headers = maplit::hashmap! {
        "x-oss-provider".to_string() => "ollama".to_string(),
    };
    ModelProviderInfo {
        name: "Self-hosted OpenAI-compatible (OSS)".to_string(),
        base_url: Some(url.to_string()),
        env_key: Some("CODEX_OSS_PROVIDER_API_KEY".to_string()),
        // The variable name shown to the user must match `env_key` exactly.
        env_key_instructions: Some(
            "Set CODEX_OSS_PROVIDER_API_KEY to authenticate with this provider.".to_string(),
        ),
        experimental_bearer_token: None,
        wire_api: WireApi::Chat,
        query_params: None,
        http_headers: Some(http_headers),
        env_http_headers: None,
        request_max_retries: None,
        stream_max_retries: None,
        stream_idle_timeout_ms: None,
        requires_openai_auth: false,
    }
}
/// Builds the default `openai/compatible` provider, targeting the local
/// machine on `DEFAULT_OLLAMA_PORT`.
pub fn create_oss_provider() -> ModelProviderInfo {
    let localhost_url = format!("http://localhost:{DEFAULT_OLLAMA_PORT}");
    create_oss_provider_with_base_url(&localhost_url)
}

View File

@@ -0,0 +1,46 @@
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::SessionSource;
use serde_json::Value;
use crate::Reasoning;
use crate::TextControls;
/// A single request to a model: instructions, conversation input and options.
#[derive(Debug, Clone, Default)]
pub struct Prompt {
/// Instruction text for the model.
pub instructions: String,
/// Conversation items to send, in order.
pub input: Vec<ResponseItem>,
/// Tool definitions as raw JSON values.
pub tools: Vec<Value>,
/// Whether tool calls may run in parallel.
/// NOTE(review): no consumer is visible in this part of the file — confirm
/// how each client uses it.
pub parallel_tool_calls: bool,
/// Optional JSON schema for the output.
pub output_schema: Option<Value>,
/// Optional reasoning settings.
pub reasoning: Option<Reasoning>,
/// Optional text output controls.
pub text_controls: Option<TextControls>,
/// Optional prompt cache key.
pub prompt_cache_key: Option<String>,
/// Optional origin of the session issuing this prompt.
pub session_source: Option<SessionSource>,
}
impl Prompt {
/// Creates a prompt from its parts; every argument is stored unchanged in
/// the field of the same name.
#[allow(clippy::too_many_arguments)]
pub fn new(
instructions: String,
input: Vec<ResponseItem>,
tools: Vec<Value>,
parallel_tool_calls: bool,
output_schema: Option<Value>,
reasoning: Option<Reasoning>,
text_controls: Option<TextControls>,
prompt_cache_key: Option<String>,
session_source: Option<SessionSource>,
) -> Self {
Self {
instructions,
input,
tools,
parallel_tool_calls,
output_schema,
reasoning,
text_controls,
prompt_cache_key,
session_source,
}
}
}

View File

@@ -0,0 +1,819 @@
use std::io::BufRead;
use std::path::Path;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use async_trait::async_trait;
use bytes::Bytes;
use codex_app_server_protocol::AuthMode;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::ConversationId;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::RateLimitSnapshot;
use codex_protocol::protocol::TokenUsage;
use futures::Stream;
use futures::StreamExt;
use futures::TryStreamExt;
use regex_lite::Regex;
use reqwest::StatusCode;
use reqwest::header::HeaderMap;
use serde::Deserialize;
use serde_json::Value;
use serde_json::json;
use tokio::sync::mpsc;
use tokio::time::timeout;
use tokio_util::io::ReaderStream;
use tracing::debug;
use tracing::trace;
use crate::api::ApiClient;
use crate::auth::AuthProvider;
use crate::common::apply_subagent_header;
use crate::common::backoff;
use crate::error::Error;
use crate::model_provider::ModelProviderInfo;
use crate::prompt::Prompt;
use crate::stream::ResponseEvent;
use crate::stream::ResponseStream;
/// Module-local shorthand for `std::result::Result` with the crate's [`Error`] as the error type.
type Result<T> = std::result::Result<T, Error>;
/// Everything a [`ResponsesApiClient`] needs in order to issue requests.
#[derive(Clone)]
pub struct ResponsesApiClientConfig {
/// HTTP client handed to the provider when it builds each request.
pub http_client: reqwest::Client,
/// Provider description; supplies the wire API kind, URL, retry budget and idle timeout.
pub provider: ModelProviderInfo,
/// Model name sent as the `model` field of the payload.
pub model: String,
/// Sent as the `conversation_id` and `session_id` headers, and used as the
/// `prompt_cache_key` when the prompt does not carry one.
pub conversation_id: ConversationId,
/// Optional source of authentication context, queried once per attempt.
pub auth_provider: Option<Arc<dyn AuthProvider>>,
/// Telemetry hook that wraps each outgoing request.
pub otel_event_manager: OtelEventManager,
}
/// Streaming client for providers that speak the Responses wire API.
#[derive(Clone)]
pub struct ResponsesApiClient {
// Immutable configuration captured at construction time.
config: ResponsesApiClientConfig,
}
#[async_trait]
impl ApiClient for ResponsesApiClient {
type Config = ResponsesApiClientConfig;
async fn new(config: Self::Config) -> Result<Self> {
Ok(Self { config })
}
async fn stream(&self, prompt: Prompt) -> Result<ResponseStream> {
if self.config.provider.wire_api != crate::model_provider::WireApi::Responses {
return Err(Error::UnsupportedOperation(
"ResponsesApiClient requires a Responses provider".to_string(),
));
}
let mut payload_json = self.build_payload(&prompt)?;
if self.config.provider.is_azure_responses_endpoint()
&& let Some(input_value) = payload_json.get_mut("input")
&& let Some(array) = input_value.as_array_mut()
{
attach_item_ids_array(array, &prompt.input);
}
let max_attempts = self.config.provider.request_max_retries();
for attempt in 0..=max_attempts {
match self
.attempt_stream_responses(attempt, &prompt, &payload_json)
.await
{
Ok(stream) => return Ok(stream),
Err(StreamAttemptError::Fatal(err)) => return Err(err),
Err(retryable) => {
if attempt == max_attempts {
return Err(retryable.into_error());
}
tokio::time::sleep(retryable.delay(attempt)).await;
}
}
}
unreachable!("attempt_stream_responses should always return");
}
}
impl ResponsesApiClient {
fn build_payload(&self, prompt: &Prompt) -> Result<Value> {
let azure_workaround = self.config.provider.is_azure_responses_endpoint();
let mut payload = json!({
"model": self.config.model,
"instructions": prompt.instructions,
"input": prompt.input,
"tools": prompt.tools,
"tool_choice": "auto",
"parallel_tool_calls": prompt.parallel_tool_calls,
"store": azure_workaround,
"stream": true,
"prompt_cache_key": prompt
.prompt_cache_key
.clone()
.unwrap_or_else(|| self.config.conversation_id.to_string()),
});
if let Some(reasoning) = prompt.reasoning.as_ref()
&& let Some(obj) = payload.as_object_mut()
{
obj.insert("reasoning".to_string(), serde_json::to_value(reasoning)?);
}
if let Some(text) = prompt.text_controls.as_ref()
&& let Some(obj) = payload.as_object_mut()
{
obj.insert("text".to_string(), serde_json::to_value(text)?);
}
let include = if prompt.reasoning.is_some() {
vec!["reasoning.encrypted_content".to_string()]
} else {
Vec::new()
};
if let Some(obj) = payload.as_object_mut() {
obj.insert(
"include".to_string(),
Value::Array(include.into_iter().map(Value::String).collect()),
);
}
Ok(payload)
}
async fn attempt_stream_responses(
&self,
attempt: i64,
prompt: &Prompt,
payload_json: &Value,
) -> std::result::Result<ResponseStream, StreamAttemptError> {
let auth = if let Some(provider) = &self.config.auth_provider {
provider.auth_context().await
} else {
None
};
trace!(
"POST to {}: {:?}",
self.config.provider.get_full_url(auth.as_ref()),
serde_json::to_string(payload_json)
.unwrap_or_else(|_| "<unable to serialize payload>".to_string())
);
let mut req_builder = self
.config
.provider
.create_request_builder(&self.config.http_client, &auth)
.await
.map_err(StreamAttemptError::Fatal)?;
req_builder = apply_subagent_header(req_builder, prompt.session_source.as_ref());
req_builder = req_builder
.header("conversation_id", self.config.conversation_id.to_string())
.header("session_id", self.config.conversation_id.to_string())
.header(reqwest::header::ACCEPT, "text/event-stream")
.json(payload_json);
if let Some(auth_ctx) = auth.as_ref()
&& auth_ctx.mode == AuthMode::ChatGPT
&& let Some(account_id) = auth_ctx.account_id.clone()
{
req_builder = req_builder.header("chatgpt-account-id", account_id);
}
let res = self
.config
.otel_event_manager
.log_request(attempt as u64, || req_builder.send())
.await;
let mut request_id = None;
if let Ok(resp) = &res {
request_id = resp
.headers()
.get("cf-ray")
.and_then(|v| v.to_str().ok())
.map(std::string::ToString::to_string);
}
match res {
Ok(resp) if resp.status().is_success() => {
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
if let Some(snapshot) = parse_rate_limit_snapshot(resp.headers())
&& tx_event
.send(Ok(ResponseEvent::RateLimits(snapshot)))
.await
.is_err()
{
debug!("receiver dropped rate limit snapshot event");
}
let stream = resp
.bytes_stream()
.map_err(move |err| Error::ResponseStreamFailed {
source: err,
request_id: request_id.clone(),
});
let idle_timeout = self.config.provider.stream_idle_timeout();
let otel = self.config.otel_event_manager.clone();
tokio::spawn(process_sse(stream, tx_event, idle_timeout, otel));
Ok(ResponseStream { rx_event })
}
Ok(resp) => Err(handle_error_response(resp, request_id, &self.config).await),
Err(err) => Err(StreamAttemptError::RetryableTransportError(Error::Http(
err,
))),
}
}
}
async fn handle_error_response(
resp: reqwest::Response,
request_id: Option<String>,
_config: &ResponsesApiClientConfig,
) -> StreamAttemptError {
let status = resp.status();
let retry_after_secs = resp
.headers()
.get(reqwest::header::RETRY_AFTER)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<i64>().ok());
let retry_after = retry_after_secs.map(|secs| {
let clamped = if secs < 0 { 0 } else { secs as u64 };
Duration::from_secs(clamped)
});
if !(status == StatusCode::TOO_MANY_REQUESTS
|| status == StatusCode::UNAUTHORIZED
|| status.is_server_error())
{
let body = resp.text().await.unwrap_or_default();
return StreamAttemptError::Fatal(Error::UnexpectedStatus { status, body });
}
if status == StatusCode::TOO_MANY_REQUESTS {
let body = resp.json::<ErrorResponse>().await.ok();
if let Some(ErrorResponse { error }) = body {
if error.r#type.as_deref() == Some("usage_limit_reached") {
return StreamAttemptError::Fatal(Error::Stream(
"usage limit reached".to_string(),
None,
));
} else if error.r#type.as_deref() == Some("usage_not_included") {
return StreamAttemptError::Fatal(Error::Stream(
"usage not included".to_string(),
None,
));
} else if is_quota_exceeded_error(&error) {
return StreamAttemptError::Fatal(Error::Stream(
"quota exceeded".to_string(),
None,
));
}
}
}
StreamAttemptError::RetryableHttpError {
status,
retry_after,
request_id,
}
}
/// Pumps byte chunks from `stream` through `process_sse_chunk` until the stream
/// ends, fails, or stays silent for longer than `max_idle_duration`.
///
/// * Stream error or chunk-processing error: the error is forwarded and the task ends.
/// * Idle timeout: a pending completion (if any) is emitted, otherwise a
///   "stream idle timeout" error is forwarded.
/// * End of stream: a pending completion (if any) is emitted.
///
/// NOTE(review): nothing in this function assigns `response_completed` or
/// `response_error`, so as written the "pending completion" paths are never taken.
#[allow(clippy::too_many_arguments)]
async fn process_sse<S>(
    stream: S,
    tx_event: mpsc::Sender<Result<ResponseEvent>>,
    max_idle_duration: Duration,
    otel_event_manager: OtelEventManager,
) where
    S: Stream<Item = Result<Bytes>> + Send + 'static + Unpin,
{
    let mut stream = stream;
    let mut response_completed: Option<ResponseCompleted> = None;
    let mut response_error: Option<Error> = None;

    loop {
        let Ok(next) = timeout(max_idle_duration, stream.next()).await else {
            // Idle timeout elapsed before the next chunk arrived.
            match response_completed.take() {
                Some(completed) => {
                    let _ = emit_response_completed(
                        tx_event.clone(),
                        completed,
                        response_error.take(),
                        &otel_event_manager,
                    )
                    .await;
                }
                None => {
                    let _ = tx_event
                        .send(Err(Error::Stream(
                            "stream idle timeout fired before Completed event".to_string(),
                            None,
                        )))
                        .await;
                }
            }
            return;
        };

        match next {
            Some(Ok(chunk)) => {
                if let Err(err) = process_sse_chunk(chunk, &tx_event).await {
                    let _ = tx_event.send(Err(err)).await;
                    return;
                }
            }
            Some(Err(err)) => {
                let _ = tx_event.send(Err(err)).await;
                return;
            }
            None => {
                if let Some(completed) = response_completed.take() {
                    let _ = emit_response_completed(
                        tx_event.clone(),
                        completed,
                        response_error.take(),
                        &otel_event_manager,
                    )
                    .await;
                }
                return;
            }
        }
    }
}
/// Sends the terminal event for a response.
///
/// When `response_error` is set it is forwarded instead of the completion;
/// otherwise a `ResponseEvent::Completed` built from `completed` is sent. A closed
/// receiver is ignored, so this always returns `Ok(())`.
async fn emit_response_completed(
    tx_event: mpsc::Sender<Result<ResponseEvent>>,
    completed: ResponseCompleted,
    response_error: Option<Error>,
    _otel_event_manager: &OtelEventManager,
) -> Result<()> {
    let terminal = match response_error {
        Some(err) => Err(err),
        None => Ok(ResponseEvent::Completed {
            response_id: completed.id,
            token_usage: completed.usage,
        }),
    };
    tx_event.send(terminal).await.ok();
    Ok(())
}
/// Placeholder: ignores `_headers` and always returns `None`, so no rate-limit
/// snapshot is ever derived from response headers by this function.
fn parse_rate_limit_snapshot(_headers: &HeaderMap) -> Option<RateLimitSnapshot> {
None
}
/// Parses one network chunk of SSE text and forwards every complete event in it.
///
/// Each `data:` line contributes its value (leading whitespace removed) to a
/// per-chunk buffer. A blank line dispatches the buffered data: it is decoded as an
/// `sse::Payload` and handed to `handle_sse_payload`. Lines with any other field
/// (`event:`, `id:`, comments, ...) are ignored.
///
/// A blank line with nothing buffered is skipped rather than parsed. Keep-alive
/// blank lines, frames without a `data:` field and consecutive blank lines would
/// otherwise be decoded as empty JSON and abort the whole stream.
///
/// NOTE(review): the buffer does not survive across calls, so an event split over
/// two chunks is not reassembled here.
///
/// # Errors
/// Returns `Error::Other` for invalid UTF-8 or undecodable JSON, and propagates
/// errors from `handle_sse_payload`.
async fn process_sse_chunk(
    chunk: Bytes,
    tx_event: &mpsc::Sender<Result<ResponseEvent>>,
) -> Result<()> {
    let chunk_str = std::str::from_utf8(&chunk)
        .map_err(|err| Error::Other(format!("Invalid UTF-8 in SSE chunk: {err}")))?;
    trace!("responses api chunk ({chunk_str:?})");

    let mut data_buffer = String::new();
    for line in chunk_str.lines() {
        if let Some(tail) = line.strip_prefix("data:") {
            data_buffer.push_str(tail.trim_start());
            continue;
        }
        // Only a blank line terminates an event; other fields are not used.
        if !line.is_empty() {
            continue;
        }
        // Nothing to dispatch: an event with an empty data buffer is dropped.
        if data_buffer.is_empty() {
            continue;
        }
        let payload: sse::Payload = serde_json::from_str(&data_buffer)
            .map_err(|err| Error::Other(format!("Cannot parse SSE JSON: {err}")))?;
        handle_sse_payload(payload, tx_event).await?;
        data_buffer.clear();
    }
    Ok(())
}
async fn handle_sse_payload(
payload: sse::Payload,
tx_event: &mpsc::Sender<Result<ResponseEvent>>,
) -> Result<()> {
if let Some(responses) = payload.responses {
for ev in responses {
let event = match ev {
sse::Response::Completed(complete) => ResponseEvent::Completed {
response_id: complete.id,
token_usage: complete.usage,
},
sse::Response::Error(err) => {
let retry_after = err
.retry_after
.map(|secs| Duration::from_secs(if secs < 0 { 0 } else { secs as u64 }));
return Err(Error::Stream(
err.message.unwrap_or_else(|| "fatal error".to_string()),
retry_after,
));
}
};
tx_event.send(Ok(event)).await.ok();
}
}
if let Some(message_delta) = payload.response_message_delta {
let ev = ResponseEvent::OutputTextDelta(message_delta.text.clone());
tx_event.send(Ok(ev)).await.ok();
}
if let Some(_rate_limits) = payload.rate_limits {
// Rate limit snapshots are not emitted for this protocol shape in this build.
}
if let Some(_response_content) = payload.response_content {
// Not used currently
}
if let Some(ev) = payload.response_event {
debug!("Unhandled response_event: {ev:?}");
}
if let Some(item) = payload.response_output_item {
match item.r#type {
sse::OutputItem::Created => {
tx_event.send(Ok(ResponseEvent::Created)).await.ok();
}
}
}
if let Some(done) = payload.response_output_text_delta {
tx_event
.send(Ok(ResponseEvent::OutputTextDelta(done.text)))
.await
.ok();
}
if let Some(completed) = payload.response_output_item_done {
let response_item =
serde_json::from_value::<ResponseItem>(completed.item).map_err(Error::Json)?;
tx_event
.send(Ok(ResponseEvent::OutputItemDone(response_item)))
.await
.ok();
}
if let Some(reasoning_content_delta) = payload.response_output_reasoning_delta {
tx_event
.send(Ok(ResponseEvent::ReasoningContentDelta(
reasoning_content_delta.text,
)))
.await
.ok();
}
if let Some(reasoning_summary_delta) = payload.response_output_reasoning_summary_delta {
tx_event
.send(Ok(ResponseEvent::ReasoningSummaryDelta(
reasoning_summary_delta.text,
)))
.await
.ok();
}
if let Some(ev) = payload.response_error
&& ev.code.as_deref() == Some("max_response_tokens")
{
let _ = tx_event
.send(Err(Error::Stream(
"context window exceeded".to_string(),
None,
)))
.await;
}
Ok(())
}
/// Body of a completed response; turned into `ResponseEvent::Completed` by
/// `emit_response_completed`.
#[derive(Debug, Deserialize)]
struct ResponseCompleted {
// Becomes `response_id` of the completion event.
id: String,
// Becomes `token_usage` of the completion event; may be absent.
usage: Option<TokenUsage>,
}
/// JSON envelope of an error reply: a single `error` object.
#[derive(Debug, Deserialize)]
struct ErrorResponse {
error: ErrorBody,
}
/// Details of an error reply. Every field is optional in the incoming JSON.
#[derive(Debug, Deserialize)]
struct ErrorBody {
// Compared against "usage_limit_reached" / "usage_not_included" for 429 replies.
r#type: Option<String>,
// Compared against "quota_exceeded", "rate_limit_exceeded" and
// "context_length_exceeded" by helpers in this file.
code: Option<String>,
// Human-readable text; also scanned for a "Please try again in ..." hint.
message: Option<String>,
// Not read by the code visible in this file.
plan_type: Option<String>,
// Not read by the code visible in this file.
resets_at: Option<i64>,
}
/// Returns `true` exactly when the error's `code` is `"quota_exceeded"`.
fn is_quota_exceeded_error(error: &ErrorBody) -> bool {
    matches!(error.code.as_deref(), Some("quota_exceeded"))
}
/// Outcome of a failed request attempt, split by whether retrying makes sense.
enum StreamAttemptError {
// The server answered with a status that may succeed on retry.
RetryableHttpError {
status: StatusCode,
// Delay derived from the `Retry-After` header, when one was parseable.
retry_after: Option<Duration>,
// Value of the `cf-ray` response header, when present.
request_id: Option<String>,
},
// The request could not be sent or no response was received.
RetryableTransportError(Error),
// Retrying will not help; the wrapped error is returned to the caller as-is.
Fatal(Error),
}
impl StreamAttemptError {
    /// How long to wait before the next attempt.
    ///
    /// A server-provided `retry_after` wins; otherwise retryable failures use
    /// `backoff(attempt)`. Fatal failures report a zero delay.
    fn delay(&self, attempt: i64) -> Duration {
        match self {
            Self::Fatal(..) => Duration::ZERO,
            Self::RetryableHttpError {
                retry_after: Some(wait),
                ..
            } => *wait,
            Self::RetryableHttpError {
                retry_after: None, ..
            }
            | Self::RetryableTransportError(..) => backoff(attempt),
        }
    }

    /// Converts the failure into the error reported once retries are exhausted.
    ///
    /// HTTP failures become `Error::RetryLimit` carrying the status and request id;
    /// transport and fatal failures yield their wrapped error unchanged.
    fn into_error(self) -> Error {
        match self {
            Self::Fatal(err) | Self::RetryableTransportError(err) => err,
            Self::RetryableHttpError {
                status, request_id, ..
            } => Error::RetryLimit {
                status: Some(status),
                request_id,
            },
        }
    }
}
// backoff moved to crate::common
/// Lazily compiled pattern for "Please try again in <number>(s|ms)" hints.
///
/// The pattern is compiled at most once; `None` is returned if compilation failed.
fn rate_limit_regex() -> Option<&'static Regex> {
    static PATTERN: OnceLock<Option<Regex>> = OnceLock::new();
    let compiled =
        PATTERN.get_or_init(|| Regex::new(r"Please try again in (\d+(?:\.\d+)?)(s|ms)").ok());
    compiled.as_ref()
}
/// Extracts a retry delay from a `rate_limit_exceeded` error message.
///
/// Returns `None` unless the error code is `"rate_limit_exceeded"` and the message
/// contains a "Please try again in <number>(s|ms)" hint. Seconds may be fractional;
/// milliseconds are truncated to a whole number.
///
/// The number comes from the server, so an absurdly large seconds value must not
/// bring the client down: a value that does not fit in a `Duration` yields `None`
/// instead of panicking.
fn try_parse_retry_after(err: &ErrorResponse) -> Option<Duration> {
    if err.error.code.as_deref() != Some("rate_limit_exceeded") {
        return None;
    }
    let message = err.error.message.as_deref()?;
    let captures = rate_limit_regex()?.captures(message)?;
    let value = captures.get(1)?.as_str().parse::<f64>().ok()?;
    match captures.get(2)?.as_str() {
        // `from_secs_f64` panics on overflow; the fallible variant does not.
        "s" => Duration::try_from_secs_f64(value).ok(),
        // Float-to-integer `as` casts saturate, so this cannot panic.
        "ms" => Some(Duration::from_millis(value as u64)),
        _ => None,
    }
}
/// Returns `true` exactly when the error's `code` is `"context_length_exceeded"`.
fn is_context_window_error(error: &ErrorResponse) -> bool {
    matches!(
        error.error.code.as_deref(),
        Some("context_length_exceeded")
    )
}
/// used in tests to stream from a text SSE file
pub async fn stream_from_fixture(
path: impl AsRef<Path>,
provider: ModelProviderInfo,
otel_event_manager: OtelEventManager,
) -> Result<ResponseStream> {
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
let display_path = path.as_ref().display().to_string();
let file = std::fs::File::open(path.as_ref())
.map_err(|err| Error::Other(format!("failed to open fixture {display_path}: {err}")))?;
let lines = std::io::BufReader::new(file).lines();
let mut content = String::new();
for line in lines {
let line = line
.map_err(|err| Error::Other(format!("failed to read fixture {display_path}: {err}")))?;
content.push_str(&line);
content.push('\n');
content.push('\n');
}
let rdr = std::io::Cursor::new(content);
let stream = ReaderStream::new(rdr).map_err(|err| Error::Other(err.to_string()));
tokio::spawn(process_sse(
stream,
tx_event,
provider.stream_idle_timeout(),
otel_event_manager,
));
Ok(ResponseStream { rx_event })
}
/// Placeholder called for Azure Responses endpoints with the payload's `input`
/// array and the prompt's input items. It currently leaves both untouched.
fn attach_item_ids_array(_json_array: &mut Vec<Value>, _prompt_input: &[ResponseItem]) {
// no-op for current protocol version
}
/// One typed stream event as consumed by `handle_stream_event`.
#[derive(Debug, Deserialize)]
struct StreamEvent {
// Event name, e.g. "response.completed"; selects how the event is handled.
r#type: String,
// Payload read for "response.completed" events.
response: Option<Value>,
// Payload read for "response.output_text.delta" and "response.output_item.added".
item: Option<Value>,
// Payload read for "response.error" events.
error: Option<Value>,
}
/// Wrapper object holding a single `event`.
/// NOTE(review): not referenced by any other code visible in this file.
#[derive(Debug, Deserialize)]
struct StreamResponsePayload {
event: StreamEvent,
}
async fn handle_stream_event(
event: StreamEvent,
tx_event: mpsc::Sender<Result<ResponseEvent>>,
response_completed: &mut Option<ResponseCompleted>,
response_error: &mut Option<Error>,
) {
trace!("response event: {}", event.r#type);
match event.r#type.as_str() {
"response.output_text.delta" => {
if let Some(item_val) = event.item {
let resp = serde_json::from_value::<TextDelta>(item_val);
if let Ok(delta) = resp {
let event = ResponseEvent::OutputTextDelta(delta.delta);
if tx_event.send(Ok(event)).await.is_err() {}
}
}
}
"response.error" => {
if let Some(err_val) = event.error {
let err_resp = serde_json::from_value::<ErrorResponse>(err_val);
match err_resp {
Ok(err) => {
let retry_after = try_parse_retry_after(&err);
*response_error = Some(Error::Stream(
err.error
.message
.unwrap_or_else(|| "unknown error".to_string()),
retry_after,
));
}
Err(err) => {
let _ = tx_event
.send(Err(Error::Stream(
format!("failed to parse ErrorResponse: {err}"),
None,
)))
.await;
}
}
}
}
"response.completed" => {
if let Some(resp_val) = event.response {
match serde_json::from_value::<ResponseCompleted>(resp_val) {
Ok(resp) => {
*response_completed = Some(resp);
}
Err(err) => {
let _ = tx_event
.send(Err(Error::Stream(
format!("failed to parse ResponseCompleted: {err}"),
None,
)))
.await;
}
};
};
}
"response.output_item.added" => {
if let Some(item_val) = event.item
&& let Ok(item) = serde_json::from_value::<ResponseItem>(item_val)
{
let event = ResponseEvent::OutputItemAdded(item);
if tx_event.send(Ok(event)).await.is_err() {}
}
}
"response.reasoning_summary_part.added" => {
let event = ResponseEvent::ReasoningSummaryPartAdded;
let _ = tx_event.send(Ok(event)).await;
}
_ => {}
}
}
/// Payload decoded for "response.output_text.delta" events. Both fields are
/// required for decoding to succeed.
#[derive(Debug, Deserialize)]
struct TextDelta {
// Required by the decoder but not read afterwards in the visible code.
role: String,
// Text forwarded as `ResponseEvent::OutputTextDelta`.
delta: String,
}
/// Serde models for the JSON carried in SSE `data:` fields, as decoded by
/// `process_sse_chunk` and interpreted by `handle_sse_payload`.
mod sse {
use serde::Deserialize;
use serde_json::Value;
/// One decoded SSE data payload. Every section is optional.
#[derive(Debug, Deserialize)]
pub struct Payload {
/// Terminal entries (completed / error).
pub responses: Option<Vec<Response>>,
/// Accepted but currently ignored by the handler.
pub response_content: Option<Value>,
/// Only the code "max_response_tokens" is acted upon by the handler.
pub response_error: Option<ResponseError>,
/// Logged as unhandled by the handler.
pub response_event: Option<String>,
/// Its `text` is forwarded as an output text delta.
pub response_message_delta: Option<ResponseMessageDelta>,
/// A `Created` item produces `ResponseEvent::Created`.
pub response_output_item: Option<ResponseOutputItem>,
/// Its `text` is forwarded as an output text delta.
pub response_output_text_delta: Option<ResponseOutputTextDelta>,
/// Its `item` is decoded into a `ResponseItem` and forwarded as done.
pub response_output_item_done: Option<ResponseOutputItemDone>,
/// Its `text` is forwarded as a reasoning content delta.
pub response_output_reasoning_delta: Option<ResponseOutputReasoningDelta>,
/// Its `text` is forwarded as a reasoning summary delta.
pub response_output_reasoning_summary_delta: Option<ResponseOutputReasoningSummaryDelta>,
/// Accepted but currently ignored by the handler.
pub rate_limits: Option<Vec<RateLimit>>,
}
/// Entry of `Payload::responses`; variants are renamed for (de)serialization to
/// "response.completed" and "response.error".
#[derive(Debug, Deserialize)]
pub enum Response {
#[serde(rename = "response.completed")]
Completed(ResponseCompleted),
#[serde(rename = "response.error")]
Error(ResponseError),
}
/// Completion details: response id and optional token usage.
#[derive(Debug, Deserialize)]
pub struct ResponseCompleted {
pub id: String,
pub usage: Option<codex_protocol::protocol::TokenUsage>,
}
/// Error details; `retry_after` is interpreted by the handler as whole seconds,
/// with negative values clamped to zero.
#[derive(Debug, Deserialize)]
pub struct ResponseError {
pub code: Option<String>,
pub message: Option<String>,
pub retry_after: Option<i64>,
}
/// Message delta; only `text` is read by the handler, the other fields are
/// required for decoding.
#[derive(Debug, Deserialize)]
pub struct ResponseMessageDelta {
pub text: String,
pub role: String,
pub appended_content: Vec<codex_protocol::models::ContentItem>,
}
/// Kind of output item event; renamed to "response.output_item.created".
#[derive(Debug, Deserialize)]
pub enum OutputItem {
#[serde(rename = "response.output_item.created")]
Created,
}
/// Output item event; the handler looks only at `type`.
#[derive(Debug, Deserialize)]
pub struct ResponseOutputItem {
pub r#type: OutputItem,
pub item: Value,
}
/// Output text delta carrying the new text.
#[derive(Debug, Deserialize)]
pub struct ResponseOutputTextDelta {
pub text: String,
}
/// Finished output item as raw JSON, decoded later into a `ResponseItem`.
#[derive(Debug, Deserialize)]
pub struct ResponseOutputItemDone {
pub item: Value,
}
/// Reasoning delta; only `text` is read by the handler.
#[derive(Debug, Deserialize)]
pub struct ResponseOutputReasoningDelta {
pub content: Vec<codex_protocol::models::ReasoningItemContent>,
pub text: String,
}
/// Reasoning summary delta; only `text` is read by the handler.
#[derive(Debug, Deserialize)]
pub struct ResponseOutputReasoningSummaryDelta {
pub summary: Vec<codex_protocol::models::ReasoningItemReasoningSummary>,
pub text: String,
}
/// Rate-limit entry. Decoded but not used by the handler.
/// NOTE(review): units of `reset_seconds` / token counts are assumed from the
/// field names — confirm against the wire format.
#[derive(Debug, Deserialize)]
pub struct RateLimit {
pub window: String,
pub remaining_tokens: i64,
pub limit: i64,
pub reset_seconds: i64,
}
}

View File

@@ -0,0 +1,83 @@
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::RateLimitSnapshot;
use codex_protocol::protocol::TokenUsage;
use futures::Stream;
use serde::Serialize;
use serde_json::Value;
use tokio::sync::mpsc;
use crate::error::Result;
/// Reasoning settings for a request. Each field is omitted from the serialized
/// form when it is `None`.
#[derive(Debug, Serialize, Clone)]
pub struct Reasoning {
#[serde(skip_serializing_if = "Option::is_none")]
pub effort: Option<ReasoningEffortConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
pub summary: Option<ReasoningSummaryConfig>,
}
/// Kind of text format; variant names serialize in snake_case, so the only
/// variant is written as `json_schema`. It is also the default.
#[derive(Debug, Serialize, Default, Clone)]
#[serde(rename_all = "snake_case")]
pub enum TextFormatType {
#[default]
JsonSchema,
}
/// Output format description: a format type, a strictness flag, a schema value
/// and a name. All four fields are always serialized.
#[derive(Debug, Serialize, Default, Clone)]
pub struct TextFormat {
pub r#type: TextFormatType,
pub strict: bool,
pub schema: Value,
pub name: String,
}
/// Text output controls. Each field is omitted from the serialized form when it
/// is `None`.
#[derive(Debug, Serialize, Default, Clone)]
pub struct TextControls {
#[serde(skip_serializing_if = "Option::is_none")]
pub verbosity: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub format: Option<TextFormat>,
}
/// Events delivered to consumers of a [`ResponseStream`].
#[derive(Debug)]
pub enum ResponseEvent {
/// An output item was created.
Created,
/// An output item finished; carries the complete item.
OutputItemDone(ResponseItem),
/// An output item was added; carries the item as first seen.
OutputItemAdded(ResponseItem),
/// The response finished, with its id and optional token usage.
Completed {
response_id: String,
token_usage: Option<TokenUsage>,
},
/// A piece of output text.
OutputTextDelta(String),
/// A piece of reasoning summary text.
ReasoningSummaryDelta(String),
/// A piece of reasoning content text.
ReasoningContentDelta(String),
/// A new reasoning summary part was started.
ReasoningSummaryPartAdded,
/// A rate-limit snapshot.
RateLimits(RateLimitSnapshot),
}
/// A `Stream` backed by the receiving half of a tokio mpsc channel.
#[derive(Debug)]
pub struct EventStream<T> {
// Receiver polled by the `Stream` implementation.
pub(crate) rx_event: mpsc::Receiver<T>,
}
impl<T> EventStream<T> {
pub fn from_receiver(rx_event: mpsc::Receiver<T>) -> Self {
Self { rx_event }
}
}
impl<T> Stream for EventStream<T> {
    type Item = T;

    /// Yields the next value received on the channel by delegating to the
    /// receiver's `poll_recv`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        this.rx_event.poll_recv(cx)
    }
}
/// Stream of response events, where each item is either an event or an error.
pub type ResponseStream = EventStream<Result<ResponseEvent>>;

View File

@@ -22,6 +22,7 @@ chrono = { workspace = true, features = ["serde"] }
codex-app-server-protocol = { workspace = true }
codex-apply-patch = { workspace = true }
codex-async-utils = { workspace = true }
codex-api-client = { workspace = true }
codex-file-search = { workspace = true }
codex-git = { workspace = true }
codex-keyring-store = { workspace = true }

View File

@@ -1,967 +0,0 @@
use std::time::Duration;
use crate::ModelProviderInfo;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::client_common::ResponseStream;
use crate::default_client::CodexHttpClient;
use crate::error::CodexErr;
use crate::error::ConnectionFailedError;
use crate::error::ResponseStreamFailed;
use crate::error::Result;
use crate::error::RetryLimitReachedError;
use crate::error::UnexpectedResponseError;
use crate::model_family::ModelFamily;
use crate::tools::spec::create_tools_json_for_chat_completions_api;
use crate::util::backoff;
use bytes::Bytes;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::models::ContentItem;
use codex_protocol::models::FunctionCallOutputContentItem;
use codex_protocol::models::ReasoningItemContent;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use eventsource_stream::Eventsource;
use futures::Stream;
use futures::StreamExt;
use futures::TryStreamExt;
use reqwest::StatusCode;
use serde_json::json;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use tokio::sync::mpsc;
use tokio::time::timeout;
use tracing::debug;
use tracing::trace;
/// Implementation for the classic Chat Completions API.
pub(crate) async fn stream_chat_completions(
prompt: &Prompt,
model_family: &ModelFamily,
client: &CodexHttpClient,
provider: &ModelProviderInfo,
otel_event_manager: &OtelEventManager,
session_source: &SessionSource,
) -> Result<ResponseStream> {
if prompt.output_schema.is_some() {
return Err(CodexErr::UnsupportedOperation(
"output_schema is not supported for Chat Completions API".to_string(),
));
}
// Build messages array
let mut messages = Vec::<serde_json::Value>::new();
let full_instructions = prompt.get_full_instructions(model_family);
messages.push(json!({"role": "system", "content": full_instructions}));
let input = prompt.get_formatted_input();
// Pre-scan: map Reasoning blocks to the adjacent assistant anchor after the last user.
// - If the last emitted message is a user message, drop all reasoning.
// - Otherwise, for each Reasoning item after the last user message, attach it
// to the immediate previous assistant message (stop turns) or the immediate
// next assistant anchor (tool-call turns: function/local shell call, or assistant message).
let mut reasoning_by_anchor_index: std::collections::HashMap<usize, String> =
std::collections::HashMap::new();
// Determine the last role that would be emitted to Chat Completions.
let mut last_emitted_role: Option<&str> = None;
for item in &input {
match item {
ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()),
ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
last_emitted_role = Some("assistant")
}
ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"),
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
ResponseItem::CustomToolCall { .. } => {}
ResponseItem::CustomToolCallOutput { .. } => {}
ResponseItem::WebSearchCall { .. } => {}
ResponseItem::GhostSnapshot { .. } => {}
}
}
// Find the last user message index in the input.
let mut last_user_index: Option<usize> = None;
for (idx, item) in input.iter().enumerate() {
if let ResponseItem::Message { role, .. } = item
&& role == "user"
{
last_user_index = Some(idx);
}
}
// Attach reasoning only if the conversation does not end with a user message.
if !matches!(last_emitted_role, Some("user")) {
for (idx, item) in input.iter().enumerate() {
// Only consider reasoning that appears after the last user message.
if let Some(u_idx) = last_user_index
&& idx <= u_idx
{
continue;
}
if let ResponseItem::Reasoning {
content: Some(items),
..
} = item
{
let mut text = String::new();
for entry in items {
match entry {
ReasoningItemContent::ReasoningText { text: segment }
| ReasoningItemContent::Text { text: segment } => text.push_str(segment),
}
}
if text.trim().is_empty() {
continue;
}
// Prefer immediate previous assistant message (stop turns)
let mut attached = false;
if idx > 0
&& let ResponseItem::Message { role, .. } = &input[idx - 1]
&& role == "assistant"
{
reasoning_by_anchor_index
.entry(idx - 1)
.and_modify(|v| v.push_str(&text))
.or_insert(text.clone());
attached = true;
}
// Otherwise, attach to immediate next assistant anchor (tool-calls or assistant message)
if !attached && idx + 1 < input.len() {
match &input[idx + 1] {
ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
reasoning_by_anchor_index
.entry(idx + 1)
.and_modify(|v| v.push_str(&text))
.or_insert(text.clone());
}
ResponseItem::Message { role, .. } if role == "assistant" => {
reasoning_by_anchor_index
.entry(idx + 1)
.and_modify(|v| v.push_str(&text))
.or_insert(text.clone());
}
_ => {}
}
}
}
}
}
// Track last assistant text we emitted to avoid duplicate assistant messages
// in the outbound Chat Completions payload (can happen if a final
// aggregated assistant message was recorded alongside an earlier partial).
let mut last_assistant_text: Option<String> = None;
for (idx, item) in input.iter().enumerate() {
match item {
ResponseItem::Message { role, content, .. } => {
// Build content either as a plain string (typical for assistant text)
// or as an array of content items when images are present (user/tool multimodal).
let mut text = String::new();
let mut items: Vec<serde_json::Value> = Vec::new();
let mut saw_image = false;
for c in content {
match c {
ContentItem::InputText { text: t }
| ContentItem::OutputText { text: t } => {
text.push_str(t);
items.push(json!({"type":"text","text": t}));
}
ContentItem::InputImage { image_url } => {
saw_image = true;
items.push(json!({"type":"image_url","image_url": {"url": image_url}}));
}
}
}
// Skip exact-duplicate assistant messages.
if role == "assistant" {
if let Some(prev) = &last_assistant_text
&& prev == &text
{
continue;
}
last_assistant_text = Some(text.clone());
}
// For assistant messages, always send a plain string for compatibility.
// For user messages, if an image is present, send an array of content items.
let content_value = if role == "assistant" {
json!(text)
} else if saw_image {
json!(items)
} else {
json!(text)
};
let mut msg = json!({"role": role, "content": content_value});
if role == "assistant"
&& let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
&& let Some(obj) = msg.as_object_mut()
{
obj.insert("reasoning".to_string(), json!(reasoning));
}
messages.push(msg);
}
ResponseItem::FunctionCall {
name,
arguments,
call_id,
..
} => {
let mut msg = json!({
"role": "assistant",
"content": null,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": name,
"arguments": arguments,
}
}]
});
if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
&& let Some(obj) = msg.as_object_mut()
{
obj.insert("reasoning".to_string(), json!(reasoning));
}
messages.push(msg);
}
ResponseItem::LocalShellCall {
id,
call_id: _,
status,
action,
} => {
// Confirm with API team.
let mut msg = json!({
"role": "assistant",
"content": null,
"tool_calls": [{
"id": id.clone().unwrap_or_else(|| "".to_string()),
"type": "local_shell_call",
"status": status,
"action": action,
}]
});
if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
&& let Some(obj) = msg.as_object_mut()
{
obj.insert("reasoning".to_string(), json!(reasoning));
}
messages.push(msg);
}
ResponseItem::FunctionCallOutput { call_id, output } => {
// Prefer structured content items when available (e.g., images)
// otherwise fall back to the legacy plain-string content.
let content_value = if let Some(items) = &output.content_items {
let mapped: Vec<serde_json::Value> = items
.iter()
.map(|it| match it {
FunctionCallOutputContentItem::InputText { text } => {
json!({"type":"text","text": text})
}
FunctionCallOutputContentItem::InputImage { image_url } => {
json!({"type":"image_url","image_url": {"url": image_url}})
}
})
.collect();
json!(mapped)
} else {
json!(output.content)
};
messages.push(json!({
"role": "tool",
"tool_call_id": call_id,
"content": content_value,
}));
}
ResponseItem::CustomToolCall {
id,
call_id: _,
name,
input,
status: _,
} => {
messages.push(json!({
"role": "assistant",
"content": null,
"tool_calls": [{
"id": id,
"type": "custom",
"custom": {
"name": name,
"input": input,
}
}]
}));
}
ResponseItem::CustomToolCallOutput { call_id, output } => {
messages.push(json!({
"role": "tool",
"tool_call_id": call_id,
"content": output,
}));
}
ResponseItem::GhostSnapshot { .. } => {
// Ghost snapshots annotate history but are not sent to the model.
continue;
}
ResponseItem::Reasoning { .. }
| ResponseItem::WebSearchCall { .. }
| ResponseItem::Other => {
// Omit these items from the conversation history.
continue;
}
}
}
let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
let payload = json!({
"model": model_family.slug,
"messages": messages,
"stream": true,
"tools": tools_json,
});
debug!(
"POST to {}: {}",
provider.get_full_url(&None),
serde_json::to_string_pretty(&payload).unwrap_or_default()
);
let mut attempt = 0;
let max_retries = provider.request_max_retries();
loop {
attempt += 1;
let mut req_builder = provider.create_request_builder(client, &None).await?;
// Include subagent header only for subagent sessions.
if let SessionSource::SubAgent(sub) = session_source.clone() {
let subagent = if let SubAgentSource::Other(label) = sub {
label
} else {
serde_json::to_value(&sub)
.ok()
.and_then(|v| v.as_str().map(std::string::ToString::to_string))
.unwrap_or_else(|| "other".to_string())
};
req_builder = req_builder.header("x-openai-subagent", subagent);
}
let res = otel_event_manager
.log_request(attempt, || {
req_builder
.header(reqwest::header::ACCEPT, "text/event-stream")
.json(&payload)
.send()
})
.await;
match res {
Ok(resp) if resp.status().is_success() => {
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
let stream = resp.bytes_stream().map_err(|e| {
CodexErr::ResponseStreamFailed(ResponseStreamFailed {
source: e,
request_id: None,
})
});
tokio::spawn(process_chat_sse(
stream,
tx_event,
provider.stream_idle_timeout(),
otel_event_manager.clone(),
));
return Ok(ResponseStream { rx_event });
}
Ok(res) => {
let status = res.status();
if !(status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()) {
let body = (res.text().await).unwrap_or_default();
return Err(CodexErr::UnexpectedStatus(UnexpectedResponseError {
status,
body,
request_id: None,
}));
}
if attempt > max_retries {
return Err(CodexErr::RetryLimit(RetryLimitReachedError {
status,
request_id: None,
}));
}
let retry_after_secs = res
.headers()
.get(reqwest::header::RETRY_AFTER)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
let delay = retry_after_secs
.map(|s| Duration::from_millis(s * 1_000))
.unwrap_or_else(|| backoff(attempt));
tokio::time::sleep(delay).await;
}
Err(e) => {
if attempt > max_retries {
return Err(CodexErr::ConnectionFailed(ConnectionFailedError {
source: e,
}));
}
let delay = backoff(attempt);
tokio::time::sleep(delay).await;
}
}
}
}
/// Appends one chunk of assistant text to the in-progress assistant message.
///
/// On the first chunk the message item is created (role `assistant`, empty
/// content) and announced with `ResponseEvent::OutputItemAdded`. Every chunk
/// is then stored as a new `ContentItem::OutputText` entry and forwarded as a
/// `ResponseEvent::OutputTextDelta`.
///
/// Send failures on `tx_event` are deliberately ignored.
async fn append_assistant_text(
    tx_event: &mpsc::Sender<Result<ResponseEvent>>,
    assistant_item: &mut Option<ResponseItem>,
    text: String,
) {
    // Lazily open the assistant message the first time text shows up.
    if assistant_item.is_none() {
        let opened = ResponseItem::Message {
            id: None,
            role: String::from("assistant"),
            content: Vec::new(),
        };
        let announcement = ResponseEvent::OutputItemAdded(opened.clone());
        *assistant_item = Some(opened);
        let _ = tx_event.send(Ok(announcement)).await;
    }

    // Anything other than a message in the slot means there is nothing to append to.
    let Some(ResponseItem::Message { content, .. }) = assistant_item else {
        return;
    };

    content.push(ContentItem::OutputText { text: text.clone() });
    let _ = tx_event.send(Ok(ResponseEvent::OutputTextDelta(text))).await;
}
/// Appends one chunk of raw reasoning text to the in-progress reasoning item.
///
/// On the first chunk the reasoning item is created (empty id, empty summary,
/// empty content list) and announced with `ResponseEvent::OutputItemAdded`.
/// Every chunk is then stored as a `ReasoningItemContent::ReasoningText`
/// entry and forwarded as a `ResponseEvent::ReasoningContentDelta`.
///
/// Send failures on `tx_event` are deliberately ignored.
async fn append_reasoning_text(
    tx_event: &mpsc::Sender<Result<ResponseEvent>>,
    reasoning_item: &mut Option<ResponseItem>,
    text: String,
) {
    // Lazily open the reasoning item the first time reasoning text shows up.
    if reasoning_item.is_none() {
        let opened = ResponseItem::Reasoning {
            id: String::new(),
            summary: Vec::new(),
            content: Some(Vec::new()),
            encrypted_content: None,
        };
        let announcement = ResponseEvent::OutputItemAdded(opened.clone());
        *reasoning_item = Some(opened);
        let _ = tx_event.send(Ok(announcement)).await;
    }

    // Only a reasoning item that carries a content list can be appended to.
    let Some(ResponseItem::Reasoning {
        content: Some(content),
        ..
    }) = reasoning_item
    else {
        return;
    };

    content.push(ReasoningItemContent::ReasoningText { text: text.clone() });
    let _ = tx_event
        .send(Ok(ResponseEvent::ReasoningContentDelta(text)))
        .await;
}
/// Lightweight SSE processor for the Chat Completions streaming format. The
/// output is mapped onto Codex's internal [`ResponseEvent`] so that the rest
/// of the pipeline can stay agnostic of the underlying wire format.
///
/// * `stream` - raw byte stream of the HTTP response body; it is decoded as
///   server-sent events.
/// * `tx_event` - channel on which every mapped event (or error) is sent.
///   Send failures are ignored throughout.
/// * `idle_timeout` - maximum time to wait for each next SSE event; when it
///   elapses a `CodexErr::Stream` error is sent and processing stops.
/// * `otel_event_manager` - receives the outcome and wait duration of every
///   poll of the SSE stream.
///
/// The function returns (ending the stream) on the first of: a stream error,
/// the idle timeout, end of stream, a literal `[DONE]` event, or a chunk whose
/// first choice carries a `finish_reason`. Every exit path except the two
/// error paths sends `ResponseEvent::Completed` with an empty `response_id`
/// and no token usage.
///
/// Only the first entry of `choices` and the first entry of
/// `delta.tool_calls` are inspected in each chunk.
async fn process_chat_sse<S>(
stream: S,
tx_event: mpsc::Sender<Result<ResponseEvent>>,
idle_timeout: Duration,
otel_event_manager: OtelEventManager,
) where
S: Stream<Item = Result<Bytes>> + Unpin,
{
let mut stream = stream.eventsource();
// State to accumulate a function call across streaming chunks.
// OpenAI may split the `arguments` string over multiple `delta` events
// until the chunk whose `finish_reason` is `tool_calls` is emitted. We
// keep collecting the pieces here and forward a single
// `ResponseItem::FunctionCall` once the call is complete.
#[derive(Default)]
struct FunctionCallState {
name: Option<String>,
arguments: String,
call_id: Option<String>,
active: bool,
}
let mut fn_call_state = FunctionCallState::default();
// Items under construction; created lazily by the `append_*` helpers on the
// first text / reasoning delta.
let mut assistant_item: Option<ResponseItem> = None;
let mut reasoning_item: Option<ResponseItem> = None;
loop {
// Time each wait so telemetry can record how long the event took to arrive.
let start = std::time::Instant::now();
let response = timeout(idle_timeout, stream.next()).await;
let duration = start.elapsed();
otel_event_manager.log_sse_event(&response, duration);
let sse = match response {
Ok(Some(Ok(ev))) => ev,
Ok(Some(Err(e))) => {
// SSE decoding / transport error: report it and stop.
let _ = tx_event
.send(Err(CodexErr::Stream(e.to_string(), None)))
.await;
return;
}
Ok(None) => {
// Stream closed gracefully; emit Completed with a dummy id.
// Note: items still held in `assistant_item` / `reasoning_item`
// are not flushed as `OutputItemDone` on this path.
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
token_usage: None,
}))
.await;
return;
}
Err(_) => {
// No event arrived within `idle_timeout`.
let _ = tx_event
.send(Err(CodexErr::Stream(
"idle timeout waiting for SSE".into(),
None,
)))
.await;
return;
}
};
// OpenAI Chat streaming sends a literal string "[DONE]" when finished.
if sse.data.trim() == "[DONE]" {
// Emit any finalized items before closing so downstream consumers receive
// terminal events for both assistant content and raw reasoning.
if let Some(item) = assistant_item {
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
if let Some(item) = reasoning_item {
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
token_usage: None,
}))
.await;
return;
}
// Parse the JSON chunk; payloads that are not valid JSON are skipped silently.
let chunk: serde_json::Value = match serde_json::from_str(&sse.data) {
Ok(v) => v,
Err(_) => continue,
};
trace!("chat_completions received SSE chunk: {chunk:?}");
// Only the first choice is considered; chunks without one are ignored.
let choice_opt = chunk.get("choices").and_then(|c| c.get(0));
if let Some(choice) = choice_opt {
// Handle assistant content tokens as streaming deltas.
if let Some(content) = choice
.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str())
&& !content.is_empty()
{
append_assistant_text(&tx_event, &mut assistant_item, content.to_string()).await;
}
// Forward any reasoning/thinking deltas if present.
// Some providers stream `reasoning` as a plain string while others
// nest the text under an object (e.g. `{ "reasoning": { "text": "…" } }`).
if let Some(reasoning_val) = choice.get("delta").and_then(|d| d.get("reasoning")) {
let mut maybe_text = reasoning_val
.as_str()
.map(str::to_string)
.filter(|s| !s.is_empty());
// Object form: prefer a non-empty `text` field, then a non-empty `content` field.
if maybe_text.is_none() && reasoning_val.is_object() {
if let Some(s) = reasoning_val
.get("text")
.and_then(|t| t.as_str())
.filter(|s| !s.is_empty())
{
maybe_text = Some(s.to_string());
} else if let Some(s) = reasoning_val
.get("content")
.and_then(|t| t.as_str())
.filter(|s| !s.is_empty())
{
maybe_text = Some(s.to_string());
}
}
if let Some(reasoning) = maybe_text {
// Accumulate so we can emit a terminal Reasoning item at the end.
append_reasoning_text(&tx_event, &mut reasoning_item, reasoning).await;
}
}
// Some providers only include reasoning on the final message object.
if let Some(message_reasoning) = choice.get("message").and_then(|m| m.get("reasoning"))
{
// Accept either a plain string or an object with { text | content }.
// Unlike the delta case above, `content` is consulted only when `text`
// is absent or not a string (an empty `text` does not fall back).
if let Some(s) = message_reasoning.as_str() {
if !s.is_empty() {
append_reasoning_text(&tx_event, &mut reasoning_item, s.to_string()).await;
}
} else if let Some(obj) = message_reasoning.as_object()
&& let Some(s) = obj
.get("text")
.and_then(|v| v.as_str())
.or_else(|| obj.get("content").and_then(|v| v.as_str()))
&& !s.is_empty()
{
append_reasoning_text(&tx_event, &mut reasoning_item, s.to_string()).await;
}
}
// Handle streaming function / tool calls. Only the first entry of the
// `tool_calls` array is read.
if let Some(tool_calls) = choice
.get("delta")
.and_then(|d| d.get("tool_calls"))
.and_then(|tc| tc.as_array())
&& let Some(tool_call) = tool_calls.first()
{
// Mark that we have an active function call in progress.
fn_call_state.active = true;
// Extract call_id if present; the first id seen wins.
if let Some(id) = tool_call.get("id").and_then(|v| v.as_str()) {
fn_call_state.call_id.get_or_insert_with(|| id.to_string());
}
// Extract function details if present.
if let Some(function) = tool_call.get("function") {
// The first name seen wins; argument fragments are concatenated in arrival order.
if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
fn_call_state.name.get_or_insert_with(|| name.to_string());
}
if let Some(args_fragment) = function.get("arguments").and_then(|a| a.as_str())
{
fn_call_state.arguments.push_str(args_fragment);
}
}
}
// Emit end-of-turn when finish_reason signals completion.
if let Some(finish_reason) = choice.get("finish_reason").and_then(|v| v.as_str()) {
match finish_reason {
"tool_calls" if fn_call_state.active => {
// First, flush the terminal raw reasoning so UIs can finalize
// the reasoning stream before any exec/tool events begin.
if let Some(item) = reasoning_item.take() {
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
// Then emit the FunctionCall response item. A missing name or
// call id is sent as an empty string. Any accumulated
// `assistant_item` is not emitted on this path.
let item = ResponseItem::FunctionCall {
id: None,
name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()),
arguments: fn_call_state.arguments.clone(),
call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new),
};
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
"stop" => {
// Regular turn without tool-call. Emit the final assistant message
// as a single OutputItemDone so non-delta consumers see the result.
if let Some(item) = assistant_item.take() {
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
// Also emit a terminal Reasoning item so UIs can finalize raw reasoning.
if let Some(item) = reasoning_item.take() {
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
}
// Any other finish reason emits no item; only Completed below.
_ => {}
}
// Emit Completed regardless of reason so the agent can advance.
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
token_usage: None,
}))
.await;
// No state reset is needed: the function returns right here, so
// `fn_call_state` is never reused for another turn on this stream.
return; // End processing for this SSE stream.
}
}
}
}
/// Selects how [`AggregatedChatStream`] treats incremental text deltas.
#[derive(Copy, Clone, Eq, PartialEq)]
enum AggregateMode {
/// Deltas are accumulated and never forwarded; the accumulated reasoning
/// and assistant text are emitted as complete items when `Completed` arrives.
AggregatedOnly,
/// Deltas are forwarded as they arrive and also accumulated, so that one
/// final assistant message can still be emitted when `Completed` arrives.
/// Accumulated reasoning is not re-emitted in this mode.
Streaming,
}
/// Optional client-side aggregation helper
///
/// Stream adapter that merges the incremental `OutputTextDelta` /
/// `ReasoningContentDelta` events coming from [`process_chat_sse`] into a
/// *running* assistant message. In `AggregateMode::AggregatedOnly` it
/// **suppresses the per-token deltas**: no assistant text is emitted while the
/// model is producing it, and the turn ends with:
///
/// 1. `ResponseEvent::OutputItemDone` with the aggregated reasoning item, if
///    any reasoning deltas were seen.
/// 2. `ResponseEvent::OutputItemDone` with the *complete* assistant message
///    (fully concatenated), if any assistant text was seen.
/// 3. The original `ResponseEvent::Completed` right after them.
///
/// In both modes `OutputItemAdded`, `RateLimits` and non-assistant
/// `OutputItemDone` events are forwarded unchanged, while `Created`,
/// `ReasoningSummaryDelta` and `ReasoningSummaryPartAdded` are dropped.
///
/// This mirrors the behaviour the TypeScript CLI exposes to its higher layers.
///
/// The adapter is strictly opt-in: callers who do **not** opt in via
/// [`AggregateStreamExt::aggregate()`] keep receiving the original unmodified
/// events.
pub(crate) struct AggregatedChatStream<S> {
// The wrapped event stream.
inner: S,
// Assistant text accumulated so far for the current turn.
cumulative: String,
// Reasoning text accumulated so far for the current turn.
cumulative_reasoning: String,
// Events queued at `Completed` time and handed out one per poll.
pending: std::collections::VecDeque<ResponseEvent>,
// Whether deltas are forwarded (`Streaming`) or withheld (`AggregatedOnly`).
mode: AggregateMode,
}
impl<S> Stream for AggregatedChatStream<S>
where
    S: Stream<Item = Result<ResponseEvent>> + Unpin,
{
    type Item = Result<ResponseEvent>;

    /// Polls the wrapped stream and applies the aggregation rules of the
    /// configured [`AggregateMode`].
    ///
    /// Events queued by an earlier `Completed` are handed out first, one per
    /// call. Otherwise the inner stream is polled repeatedly until it yields
    /// an event that must be surfaced, reports an error, ends, or is pending.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        // Drain whatever a previous `Completed` left behind before polling again.
        if let Some(queued) = this.pending.pop_front() {
            return Poll::Ready(Some(Ok(queued)));
        }

        loop {
            // Peel off the non-event outcomes once, so the match below deals
            // with events only.
            let event = match Pin::new(&mut this.inner).poll_next(cx) {
                Poll::Ready(Some(Ok(event))) => event,
                Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Pending => return Poll::Pending,
            };

            // Arms that do not `return` swallow the event and fall through to
            // the next loop iteration.
            match event {
                ResponseEvent::OutputItemDone(item) => {
                    let from_assistant = matches!(
                        &item,
                        codex_protocol::models::ResponseItem::Message { role, .. } if role == "assistant"
                    );
                    // Anything that is not an assistant message (e.g. a
                    // FunctionCall) is forwarded right away.
                    if !from_assistant {
                        return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
                    }

                    let saw_text_deltas = !this.cumulative.is_empty();
                    match (this.mode, saw_text_deltas) {
                        // Streaming without deltas: the final message is the
                        // only copy of the text, so pass it through.
                        (AggregateMode::Streaming, false) => {
                            return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
                        }
                        // Aggregating without deltas: seed the buffer from the
                        // first text part; the message is emitted on Completed.
                        (AggregateMode::AggregatedOnly, false) => {
                            if let codex_protocol::models::ResponseItem::Message {
                                content, ..
                            } = &item
                            {
                                let first_text = content.iter().find_map(|part| match part {
                                    codex_protocol::models::ContentItem::OutputText { text } => {
                                        Some(text)
                                    }
                                    _ => None,
                                });
                                if let Some(text) = first_text {
                                    this.cumulative.push_str(text);
                                }
                            }
                        }
                        // Deltas already built the cumulative text; using the
                        // final message as well would duplicate it.
                        (AggregateMode::Streaming | AggregateMode::AggregatedOnly, true) => {}
                    }
                }
                ResponseEvent::Completed {
                    response_id,
                    token_usage,
                } => {
                    // `pending` is empty whenever the loop runs, so the queue
                    // built here is exactly: [Reasoning] [Message] Completed.
                    if this.mode == AggregateMode::AggregatedOnly
                        && !this.cumulative_reasoning.is_empty()
                    {
                        let text = std::mem::take(&mut this.cumulative_reasoning);
                        let aggregated_reasoning =
                            codex_protocol::models::ResponseItem::Reasoning {
                                id: String::new(),
                                summary: Vec::new(),
                                content: Some(vec![
                                    codex_protocol::models::ReasoningItemContent::ReasoningText {
                                        text,
                                    },
                                ]),
                                encrypted_content: None,
                            };
                        this.pending
                            .push_back(ResponseEvent::OutputItemDone(aggregated_reasoning));
                    }

                    // In AggregatedOnly mode this is the sole assistant output;
                    // in Streaming mode it finalizes the streamed deltas into a
                    // terminal OutputItemDone so callers can persist/render the
                    // message once per turn.
                    if !this.cumulative.is_empty() {
                        let text = std::mem::take(&mut this.cumulative);
                        let aggregated_message = codex_protocol::models::ResponseItem::Message {
                            id: None,
                            role: "assistant".to_string(),
                            content: vec![codex_protocol::models::ContentItem::OutputText { text }],
                        };
                        this.pending
                            .push_back(ResponseEvent::OutputItemDone(aggregated_message));
                    }

                    // Completed always goes last; hand out the head of the queue
                    // now (Completed itself when nothing was aggregated).
                    this.pending.push_back(ResponseEvent::Completed {
                        response_id,
                        token_usage,
                    });
                    return Poll::Ready(this.pending.pop_front().map(Ok));
                }
                ResponseEvent::OutputTextDelta(delta) => {
                    // Always accumulate so a final OutputItemDone can be built
                    // at Completed; only Streaming mode also forwards the delta.
                    this.cumulative.push_str(&delta);
                    if this.mode == AggregateMode::Streaming {
                        return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta))));
                    }
                }
                ResponseEvent::ReasoningContentDelta(delta) => {
                    // Same policy as assistant text, applied to reasoning.
                    this.cumulative_reasoning.push_str(&delta);
                    if this.mode == AggregateMode::Streaming {
                        return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta))));
                    }
                }
                // Forwarded untouched in every mode.
                passthrough @ (ResponseEvent::RateLimits(_)
                | ResponseEvent::OutputItemAdded(_)) => {
                    return Poll::Ready(Some(Ok(passthrough)));
                }
                // `Created` is exclusive to the Responses API and will never
                // appear in a Chat Completions stream; the reasoning-summary
                // events are dropped as well.
                ResponseEvent::Created
                | ResponseEvent::ReasoningSummaryDelta(_)
                | ResponseEvent::ReasoningSummaryPartAdded => {}
            }
        }
    }
}
/// Extension trait that activates aggregation on any stream of [`ResponseEvent`].
pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized {
/// Returns a new stream that emits **only** the final assistant message
/// per turn instead of every incremental delta. The produced
/// `ResponseEvent` sequence for a typical text turn looks like:
///
/// ```ignore
/// OutputItemDone(<full message>)
/// Completed
/// ```
///
/// Non-assistant items (for example function calls) are still forwarded as
/// their own `OutputItemDone` events, and any accumulated reasoning text is
/// emitted as one aggregated reasoning item just before the message.
///
/// Usage:
///
/// ```ignore
/// let mut agg_stream = client.stream(&prompt).await?.aggregate();
/// while let Some(event) = agg_stream.next().await {
///     // event now contains cumulative text
/// }
/// ```
fn aggregate(self) -> AggregatedChatStream<Self> {
AggregatedChatStream::new(self, AggregateMode::AggregatedOnly)
}
}
// Blanket implementation: every suitable event stream gets `.aggregate()`.
impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
impl<S> AggregatedChatStream<S> {
    /// Wraps `inner` with empty accumulation buffers and an empty pending
    /// queue, operating in the given `mode`.
    fn new(inner: S, mode: AggregateMode) -> Self {
        Self {
            mode,
            inner,
            pending: std::collections::VecDeque::new(),
            cumulative_reasoning: String::new(),
            cumulative: String::new(),
        }
    }

    /// Wraps `inner` in `AggregateMode::Streaming`: deltas keep flowing to the
    /// caller while the adapter still accumulates them for a final message.
    pub(crate) fn streaming_mode(inner: S) -> Self {
        let mode = AggregateMode::Streaming;
        Self::new(inner, mode)
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,24 +1,21 @@
use crate::client_common::tools::ToolSpec;
use crate::error::Result;
use crate::model_family::ModelFamily;
use crate::protocol::RateLimitSnapshot;
use crate::protocol::TokenUsage;
use codex_api_client::Reasoning;
pub use codex_api_client::ResponseEvent;
use codex_api_client::TextControls;
use codex_api_client::TextFormat;
use codex_api_client::TextFormatType;
use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS;
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
use codex_protocol::config_types::Verbosity as VerbosityConfig;
use codex_protocol::models::ResponseItem;
use futures::Stream;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
use std::borrow::Cow;
use std::collections::HashSet;
use std::ops::Deref;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use tokio::sync::mpsc;
/// Review thread system prompt. Edit `core/src/review_prompt.md` to customize.
pub const REVIEW_PROMPT: &str = include_str!("../review_prompt.md");
@@ -193,95 +190,7 @@ fn strip_total_output_header(output: &str) -> Option<&str> {
Some(remainder)
}
/// Event yielded while a model response is being streamed.
#[derive(Debug)]
pub enum ResponseEvent {
// NOTE(review): the chat aggregation adapter describes this as exclusive to
// the Responses API - confirm against the Responses stream implementation.
Created,
/// An output item is complete; carries the finished item.
OutputItemDone(ResponseItem),
/// A new output item has been started; carries its initial state.
OutputItemAdded(ResponseItem),
/// The response has ended. No further events are expected for this turn.
Completed {
response_id: String,
token_usage: Option<TokenUsage>,
},
/// Incremental chunk of assistant output text.
OutputTextDelta(String),
// Presumably an incremental chunk of the reasoning summary - TODO confirm.
ReasoningSummaryDelta(String),
/// Incremental chunk of raw reasoning text.
ReasoningContentDelta(String),
// Presumably marks the start of a new reasoning-summary part - TODO confirm.
ReasoningSummaryPartAdded,
/// Rate-limit snapshot reported alongside the response.
RateLimits(RateLimitSnapshot),
}
/// Reasoning settings serialized into a request. Each field is left out of
/// the JSON entirely when it is `None`.
#[derive(Debug, Serialize)]
pub(crate) struct Reasoning {
/// Requested reasoning effort, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) effort: Option<ReasoningEffortConfig>,
/// Requested reasoning summary setting, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) summary: Option<ReasoningSummaryConfig>,
}
/// Kind of output format requested under `text.format`. Variants serialize in
/// snake_case, so `JsonSchema` becomes `"json_schema"`.
#[derive(Debug, Serialize, Default, Clone)]
#[serde(rename_all = "snake_case")]
pub(crate) enum TextFormatType {
/// Output constrained by a JSON schema; the only (and default) variant.
#[default]
JsonSchema,
}
/// Output-format descriptor carried in the `format` field of [`TextControls`].
#[derive(Debug, Serialize, Default, Clone)]
pub(crate) struct TextFormat {
/// Format kind; serialized under the key `type`.
pub(crate) r#type: TextFormatType,
/// Value of the `strict` flag sent with the schema.
pub(crate) strict: bool,
/// The JSON schema itself.
pub(crate) schema: Value,
/// Name under which the schema is sent.
pub(crate) name: String,
}
/// Controls under the `text` field in the Responses API for GPT-5.
///
/// Both fields are optional and are left out of the JSON when `None`.
#[derive(Debug, Serialize, Default, Clone)]
pub(crate) struct TextControls {
/// Requested output verbosity, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) verbosity: Option<OpenAiVerbosity>,
/// Structured-output format, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) format: Option<TextFormat>,
}
/// Wire representation of the verbosity setting. Variants serialize in
/// lowercase (`"low"`, `"medium"`, `"high"`); `Medium` is the default.
#[derive(Debug, Serialize, Default, Clone)]
#[serde(rename_all = "lowercase")]
pub(crate) enum OpenAiVerbosity {
Low,
#[default]
Medium,
High,
}
impl From<VerbosityConfig> for OpenAiVerbosity {
    /// Maps each configuration-level verbosity onto the wire-level variant of
    /// the same name.
    fn from(v: VerbosityConfig) -> Self {
        match v {
            VerbosityConfig::High => Self::High,
            VerbosityConfig::Medium => Self::Medium,
            VerbosityConfig::Low => Self::Low,
        }
    }
}
/// Request object that is serialized as JSON and POST'ed when using the
/// Responses API.
///
/// All string and slice data is borrowed for `'a`, so building a request does
/// not copy the model name, instructions, input items or tool definitions.
#[derive(Debug, Serialize)]
pub(crate) struct ResponsesApiRequest<'a> {
pub(crate) model: &'a str,
pub(crate) instructions: &'a str,
// TODO(mbolin): ResponseItem::Other should not be serialized. Currently,
// we code defensively to avoid this case, but perhaps we should use a
// separate enum for serialization.
pub(crate) input: &'a Vec<ResponseItem>,
/// Tool definitions, already rendered as JSON values.
pub(crate) tools: &'a [serde_json::Value],
pub(crate) tool_choice: &'static str,
pub(crate) parallel_tool_calls: bool,
/// Serialized as `null` when `None` (there is no skip attribute on this field).
pub(crate) reasoning: Option<Reasoning>,
pub(crate) store: bool,
pub(crate) stream: bool,
pub(crate) include: Vec<String>,
/// Omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) prompt_cache_key: Option<String>,
/// Omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) text: Option<TextControls>,
}
pub type ResponseStream = codex_api_client::EventStream<Result<ResponseEvent>>;
pub(crate) mod tools {
use crate::tools::spec::JsonSchema;
@@ -366,7 +275,11 @@ pub(crate) fn create_text_param_for_request(
}
Some(TextControls {
verbosity: verbosity.map(std::convert::Into::into),
verbosity: verbosity.map(|v| match v {
VerbosityConfig::Low => "low".to_string(),
VerbosityConfig::Medium => "medium".to_string(),
VerbosityConfig::High => "high".to_string(),
}),
format: output_schema.as_ref().map(|schema| TextFormat {
r#type: TextFormatType::JsonSchema,
strict: true,
@@ -376,18 +289,6 @@ pub(crate) fn create_text_param_for_request(
})
}
/// Stream of response events backed by the receiving half of an mpsc channel.
pub struct ResponseStream {
// Receiver from which the events (or errors) are read.
pub(crate) rx_event: mpsc::Receiver<Result<ResponseEvent>>,
}
impl Stream for ResponseStream {
type Item = Result<ResponseEvent>;
/// Delegates directly to the channel receiver's `poll_recv`.
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.rx_event.poll_recv(cx)
}
}
#[cfg(test)]
mod tests {
use crate::model_family::find_family_for_model;
@@ -453,39 +354,14 @@ mod tests {
#[test]
fn serializes_text_verbosity_when_set() {
let input: Vec<ResponseItem> = vec![];
let tools: Vec<serde_json::Value> = vec![];
let req = ResponsesApiRequest {
model: "gpt-5",
instructions: "i",
input: &input,
tools: &tools,
tool_choice: "auto",
parallel_tool_calls: true,
reasoning: None,
store: false,
stream: true,
include: vec![],
prompt_cache_key: None,
text: Some(TextControls {
verbosity: Some(OpenAiVerbosity::Low),
format: None,
}),
};
let v = serde_json::to_value(&req).expect("json");
assert_eq!(
v.get("text")
.and_then(|t| t.get("verbosity"))
.and_then(|s| s.as_str()),
Some("low")
);
let controls =
create_text_param_for_request(Some(VerbosityConfig::Low), &None).expect("controls");
assert_eq!(controls.verbosity.as_deref(), Some("low"));
assert!(controls.format.is_none());
}
#[test]
fn serializes_text_schema_with_strict_format() {
let input: Vec<ResponseItem> = vec![];
let tools: Vec<serde_json::Value> = vec![];
let schema = serde_json::json!({
"type": "object",
"properties": {
@@ -493,61 +369,17 @@ mod tests {
},
"required": ["answer"],
});
let text_controls =
let controls =
create_text_param_for_request(None, &Some(schema.clone())).expect("text controls");
let req = ResponsesApiRequest {
model: "gpt-5",
instructions: "i",
input: &input,
tools: &tools,
tool_choice: "auto",
parallel_tool_calls: true,
reasoning: None,
store: false,
stream: true,
include: vec![],
prompt_cache_key: None,
text: Some(text_controls),
};
let v = serde_json::to_value(&req).expect("json");
let text = v.get("text").expect("text field");
assert!(text.get("verbosity").is_none());
let format = text.get("format").expect("format field");
assert_eq!(
format.get("name"),
Some(&serde_json::Value::String("codex_output_schema".into()))
);
assert_eq!(
format.get("type"),
Some(&serde_json::Value::String("json_schema".into()))
);
assert_eq!(format.get("strict"), Some(&serde_json::Value::Bool(true)));
assert_eq!(format.get("schema"), Some(&schema));
assert!(controls.verbosity.is_none());
let format = controls.format.expect("format");
assert_eq!(format.name, "codex_output_schema");
assert!(format.strict);
assert_eq!(format.schema, schema);
}
#[test]
fn omits_text_when_not_set() {
let input: Vec<ResponseItem> = vec![];
let tools: Vec<serde_json::Value> = vec![];
let req = ResponsesApiRequest {
model: "gpt-5",
instructions: "i",
input: &input,
tools: &tools,
tool_choice: "auto",
parallel_tool_calls: true,
reasoning: None,
store: false,
stream: true,
include: vec![],
prompt_cache_key: None,
text: None,
};
let v = serde_json::to_value(&req).expect("json");
assert!(v.get("text").is_none());
assert!(create_text_param_for_request(None, &None).is_none());
}
}

View File

@@ -52,7 +52,6 @@ use tracing::info;
use tracing::warn;
use crate::ModelProviderInfo;
use crate::client::ModelClient;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::config::Config;
@@ -294,6 +293,8 @@ impl TurnContext {
}
}
// Model-specific helpers live on ModelClient; TurnContext remains lean.
#[allow(dead_code)]
#[derive(Clone)]
pub(crate) struct SessionConfiguration {
@@ -403,6 +404,11 @@ impl Session {
session_configuration.model.as_str(),
);
let tools_config = ToolsConfig::new(&ToolsConfigParams {
model_family: &model_family,
features: &config.features,
});
let client = ModelClient::new(
Arc::new(per_turn_config),
auth_manager,
@@ -414,11 +420,6 @@ impl Session {
session_configuration.session_source.clone(),
);
let tools_config = ToolsConfig::new(&ToolsConfigParams {
model_family: &model_family,
features: &config.features,
});
TurnContext {
sub_id,
client,
@@ -1674,6 +1675,7 @@ async fn spawn_review_thread(
);
let per_turn_config = Arc::new(per_turn_config);
let client = ModelClient::new(
per_turn_config.clone(),
auth_manager,
@@ -1936,7 +1938,7 @@ async fn run_turn(
retries += 1;
let delay = match e {
CodexErr::Stream(_, Some(delay)) => delay,
_ => backoff(retries),
_ => backoff(retries.max(0) as u64),
};
warn!(
"stream disconnected - retrying turn ({retries}/{max_retries} in {delay:?})...",
@@ -1995,10 +1997,7 @@ async fn try_run_turn(
});
sess.persist_rollout_items(&[rollout_item]).await;
let mut stream = turn_context
.client
.clone()
.stream(prompt)
let mut stream = crate::client::stream_for_turn(&turn_context, prompt)
.or_cancel(&cancellation_token)
.await??;
@@ -3144,3 +3143,4 @@ mod tests {
);
}
}
use crate::ModelClient;

View File

@@ -120,7 +120,7 @@ async fn run_compact_task_inner(
Err(e) => {
if retries < max_retries {
retries += 1;
let delay = backoff(retries);
let delay = backoff(retries.max(0) as u64);
sess.notify_stream_error(
turn_context.as_ref(),
format!("Reconnecting... {retries}/{max_retries}"),
@@ -266,7 +266,7 @@ async fn drain_to_completed(
turn_context: &TurnContext,
prompt: &Prompt,
) -> CodexResult<()> {
let mut stream = turn_context.client.clone().stream(prompt).await?;
let mut stream = crate::client::stream_for_turn(turn_context, prompt).await?;
loop {
let maybe_event = stream.next().await;
let Some(event) = maybe_event else {

View File

@@ -1,4 +1,6 @@
use crate::ModelProviderInfo;
use crate::auth::AuthCredentialsStoreMode;
use crate::built_in_model_providers;
use crate::config::types::DEFAULT_OTEL_ENVIRONMENT;
use crate::config::types::History;
use crate::config::types::McpServerConfig;
@@ -25,8 +27,6 @@ use crate::git_info::resolve_root_git_project_for_trust;
use crate::model_family::ModelFamily;
use crate::model_family::derive_default_model_family;
use crate::model_family::find_family_for_model;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::built_in_model_providers;
use crate::openai_model_info::get_model_info;
use crate::project_doc::DEFAULT_PROJECT_DOC_FILENAME;
use crate::project_doc::LOCAL_PROJECT_DOC_FILENAME;

View File

@@ -41,6 +41,14 @@ impl CodexHttpClient {
Self { inner }
}
/// Borrows the wrapped `reqwest::Client`.
pub fn inner(&self) -> &reqwest::Client {
&self.inner
}
/// Returns an owned clone of the wrapped `reqwest::Client`.
pub fn clone_inner(&self) -> reqwest::Client {
self.inner.clone()
}
pub fn get<U>(&self, url: U) -> CodexRequestBuilder
where
U: IntoUrl,

View File

@@ -8,7 +8,6 @@
mod apply_patch;
pub mod auth;
pub mod bash;
mod chat_completions;
mod client;
mod client_common;
pub mod codex;
@@ -32,7 +31,6 @@ pub mod mcp;
mod mcp_connection_manager;
mod mcp_tool_call;
mod message_history;
mod model_provider_info;
pub mod parse_command;
mod response_processing;
pub mod sandboxing;
@@ -40,11 +38,11 @@ pub mod token_data;
mod truncate;
mod unified_exec;
mod user_instructions;
pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID;
pub use model_provider_info::ModelProviderInfo;
pub use model_provider_info::WireApi;
pub use model_provider_info::built_in_model_providers;
pub use model_provider_info::create_oss_provider_with_base_url;
pub use codex_api_client::BUILT_IN_OSS_MODEL_PROVIDER_ID;
pub use codex_api_client::ModelProviderInfo;
pub use codex_api_client::WireApi;
pub use codex_api_client::built_in_model_providers;
pub use codex_api_client::create_oss_provider_with_base_url;
mod conversation_manager;
mod event_mapping;
pub mod review_format;

View File

@@ -1,532 +0,0 @@
//! Registry of model providers supported by Codex.
//!
//! Providers can be defined in two places:
//! 1. Built-in defaults compiled into the binary so Codex works out-of-the-box.
//! 2. User-defined entries inside `~/.codex/config.toml` under the `model_providers`
//! key. These override or extend the defaults at runtime.
use crate::CodexAuth;
use crate::default_client::CodexHttpClient;
use crate::default_client::CodexRequestBuilder;
use codex_app_server_protocol::AuthMode;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::env::VarError;
use std::time::Duration;
use crate::error::EnvVarError;
// NOTE(review): the three DEFAULT_* values are presumably the fallbacks used
// when the matching `ModelProviderInfo` option is unset; the code that reads
// them is not visible here - confirm.
/// Stream idle timeout in milliseconds (300 000 ms = 5 minutes).
const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000;
/// Stream reconnect attempts.
const DEFAULT_STREAM_MAX_RETRIES: u64 = 5;
/// HTTP request retries.
const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
/// Hard cap for user-configured `stream_max_retries`.
const MAX_STREAM_MAX_RETRIES: u64 = 100;
/// Hard cap for user-configured `request_max_retries`.
const MAX_REQUEST_MAX_RETRIES: u64 = 100;
/// Wire protocol that the provider speaks. Most third-party services only
/// implement the classic OpenAI Chat Completions JSON schema, whereas OpenAI
/// itself (and a handful of others) additionally expose the more modern
/// *Responses* API. The two protocols use different request/response shapes
/// and *cannot* be auto-detected at runtime, therefore each provider entry
/// must declare which one it expects.
///
/// Variants are (de)serialized in lowercase: `responses` and `chat`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum WireApi {
/// The Responses API exposed by OpenAI at `/v1/responses`.
Responses,
/// Regular Chat Completions compatible with `/v1/chat/completions`.
/// This is the `Default` variant.
#[default]
Chat,
}
/// Serializable representation of a provider definition.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ModelProviderInfo {
/// Friendly display name.
pub name: String,
/// Base URL for the provider's OpenAI-compatible API. When `None`, the
/// request URL falls back to a built-in OpenAI base URL selected from the
/// auth mode.
pub base_url: Option<String>,
/// Environment variable that stores the user's API key for this provider.
pub env_key: Option<String>,
/// Optional instructions to help the user get a valid value for the
/// variable and set it.
pub env_key_instructions: Option<String>,
/// Value to use with `Authorization: Bearer <token>` header. Use of this
/// config is discouraged in favor of `env_key` for security reasons, but
/// this may be necessary when using this programmatically. When set, it
/// takes precedence over both `env_key` and the caller-supplied auth.
pub experimental_bearer_token: Option<String>,
/// Which wire protocol this provider expects. Defaults to `WireApi::Chat`
/// when the field is omitted.
#[serde(default)]
pub wire_api: WireApi,
/// Optional query parameters to append to the base URL. They are joined
/// as `key=value` pairs separated by `&`, inserted verbatim (no
/// percent-encoding is applied).
pub query_params: Option<HashMap<String, String>>,
/// Additional HTTP headers to include in requests to this provider where
/// the (key, value) pairs are the header name and value.
pub http_headers: Option<HashMap<String, String>>,
/// Optional HTTP headers to include in requests to this provider where the
/// (key, value) pairs are the header name and _environment variable_ whose
/// value should be used. If the environment variable is not set, or the
/// value is empty, the header will not be included in the request.
pub env_http_headers: Option<HashMap<String, String>>,
/// Maximum number of times to retry a failed HTTP request to this provider.
pub request_max_retries: Option<u64>,
/// Number of times to retry reconnecting a dropped streaming response before failing.
pub stream_max_retries: Option<u64>,
/// Idle timeout (in milliseconds) to wait for activity on a streaming response before treating
/// the connection as lost.
pub stream_idle_timeout_ms: Option<u64>,
/// Does this provider require an OpenAI API Key or ChatGPT login token? If true,
/// user is presented with login screen on first run, and login preference and token/key
/// are stored in auth.json. If false (which is the default), login screen is skipped,
/// and API key (if needed) comes from the "env_key" environment variable.
#[serde(default)]
pub requires_openai_auth: bool,
}
impl ModelProviderInfo {
/// Builds the `POST` request for this provider's endpoint on top of the given
/// [`CodexHttpClient`], applying:
///   • a Bearer `Authorization` header when credentials are available, and
///   • the provider-specific headers (static + env based).
///
/// Credentials are chosen in this order: `experimental_bearer_token`, then
/// the key returned by [`ModelProviderInfo::api_key`], then the `auth` passed
/// by the caller.
///
/// If the provider declares an `env_key` but the variable is missing/empty
/// and no `auth` was supplied, returns an [`Err`] identical to the one
/// produced by [`ModelProviderInfo::api_key`]. Fetching the token from the
/// chosen credentials may also fail, and that error is propagated.
pub async fn create_request_builder<'a>(
    &'a self,
    client: &'a CodexHttpClient,
    auth: &Option<CodexAuth>,
) -> crate::error::Result<CodexRequestBuilder> {
    let effective_auth = match &self.experimental_bearer_token {
        Some(secret_key) => Some(CodexAuth::from_api_key(secret_key)),
        None => match (self.api_key(), auth) {
            (Ok(Some(key)), _) => Some(CodexAuth::from_api_key(&key)),
            // No key configured, or the key lookup failed but the caller
            // brought credentials: use the caller's auth.
            (Ok(None), _) | (Err(_), Some(_)) => auth.clone(),
            (Err(err), None) => return Err(err),
        },
    };

    // The URL depends on the credentials (ChatGPT auth selects another default host).
    let request = client.post(self.get_full_url(&effective_auth));
    let request = match &effective_auth {
        Some(credentials) => request.bearer_auth(credentials.get_token().await?),
        None => request,
    };
    Ok(self.apply_http_headers(request))
}
fn get_query_string(&self) -> String {
self.query_params
.as_ref()
.map_or_else(String::new, |params| {
let full_params = params
.iter()
.map(|(k, v)| format!("{k}={v}"))
.collect::<Vec<_>>()
.join("&");
format!("?{full_params}")
})
}
pub(crate) fn get_full_url(&self, auth: &Option<CodexAuth>) -> String {
let default_base_url = if matches!(
auth,
Some(CodexAuth {
mode: AuthMode::ChatGPT,
..
})
) {
"https://chatgpt.com/backend-api/codex"
} else {
"https://api.openai.com/v1"
};
let query_string = self.get_query_string();
let base_url = self
.base_url
.clone()
.unwrap_or(default_base_url.to_string());
match self.wire_api {
WireApi::Responses => format!("{base_url}/responses{query_string}"),
WireApi::Chat => format!("{base_url}/chat/completions{query_string}"),
}
}
pub(crate) fn is_azure_responses_endpoint(&self) -> bool {
if self.wire_api != WireApi::Responses {
return false;
}
if self.name.eq_ignore_ascii_case("azure") {
return true;
}
self.base_url
.as_ref()
.map(|base| matches_azure_responses_base_url(base))
.unwrap_or(false)
}
/// Apply provider-specific HTTP headers (both static and environment-based)
/// onto an existing [`CodexRequestBuilder`] and return the updated
/// builder.
fn apply_http_headers(&self, mut builder: CodexRequestBuilder) -> CodexRequestBuilder {
if let Some(extra) = &self.http_headers {
for (k, v) in extra {
builder = builder.header(k, v);
}
}
if let Some(env_headers) = &self.env_http_headers {
for (header, env_var) in env_headers {
if let Ok(val) = std::env::var(env_var)
&& !val.trim().is_empty()
{
builder = builder.header(header, val);
}
}
}
builder
}
/// If `env_key` is Some, returns the API key for this provider if present
/// (and non-empty) in the environment. If `env_key` is required but
/// cannot be found, returns an error.
pub fn api_key(&self) -> crate::error::Result<Option<String>> {
match &self.env_key {
Some(env_key) => {
let env_value = std::env::var(env_key);
env_value
.and_then(|v| {
if v.trim().is_empty() {
Err(VarError::NotPresent)
} else {
Ok(Some(v))
}
})
.map_err(|_| {
crate::error::CodexErr::EnvVar(EnvVarError {
var: env_key.clone(),
instructions: self.env_key_instructions.clone(),
})
})
}
None => Ok(None),
}
}
/// Effective maximum number of request retries for this provider.
pub fn request_max_retries(&self) -> u64 {
self.request_max_retries
.unwrap_or(DEFAULT_REQUEST_MAX_RETRIES)
.min(MAX_REQUEST_MAX_RETRIES)
}
/// Effective maximum number of stream reconnection attempts for this provider.
pub fn stream_max_retries(&self) -> u64 {
self.stream_max_retries
.unwrap_or(DEFAULT_STREAM_MAX_RETRIES)
.min(MAX_STREAM_MAX_RETRIES)
}
/// Effective idle timeout for streaming responses.
pub fn stream_idle_timeout(&self) -> Duration {
self.stream_idle_timeout_ms
.map(Duration::from_millis)
.unwrap_or(Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS))
}
}
/// Port used to build the local OSS provider's default base URL
/// (`http://localhost:{port}/v1`) when `CODEX_OSS_PORT` is unset or does not
/// parse as a number. Presumably Ollama's default listening port, per the name.
const DEFAULT_OLLAMA_PORT: u32 = 11434;
/// Key under which the open-source provider is registered in the map returned
/// by [`built_in_model_providers`].
pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "oss";
/// Built-in default provider list.
pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
use ModelProviderInfo as P;
// We do not want to be in the business of adjucating which third-party
// providers are bundled with Codex CLI, so we only include the OpenAI and
// open source ("oss") providers by default. Users are encouraged to add to
// `model_providers` in config.toml to add their own providers.
[
(
"openai",
P {
name: "OpenAI".into(),
// Allow users to override the default OpenAI endpoint by
// exporting `OPENAI_BASE_URL`. This is useful when pointing
// Codex at a proxy, mock server, or Azure-style deployment
// without requiring a full TOML override for the built-in
// OpenAI provider.
base_url: std::env::var("OPENAI_BASE_URL")
.ok()
.filter(|v| !v.trim().is_empty()),
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: Some(
[("version".to_string(), env!("CARGO_PKG_VERSION").to_string())]
.into_iter()
.collect(),
),
env_http_headers: Some(
[
(
"OpenAI-Organization".to_string(),
"OPENAI_ORGANIZATION".to_string(),
),
("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()),
]
.into_iter()
.collect(),
),
// Use global defaults for retry/timeout unless overridden in config.toml.
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: true,
},
),
(BUILT_IN_OSS_MODEL_PROVIDER_ID, create_oss_provider()),
]
.into_iter()
.map(|(k, v)| (k.to_string(), v))
.collect()
}
/// Build the open-source ("oss") provider from the process environment.
///
/// * `CODEX_OSS_BASE_URL`, when set to a non-blank value, is used as the base
///   URL as-is.
/// * Otherwise the base URL is `http://localhost:{port}/v1`, where `port` is
///   taken from `CODEX_OSS_PORT` when it holds a valid TCP port number
///   (0-65535, surrounding whitespace ignored) and defaults to
///   `DEFAULT_OLLAMA_PORT` in every other case.
pub fn create_oss_provider() -> ModelProviderInfo {
    // These CODEX_OSS_ environment variables are experimental: we may
    // switch to reading values from config.toml instead.
    let codex_oss_base_url = std::env::var("CODEX_OSS_BASE_URL")
        .ok()
        .filter(|url| !url.trim().is_empty())
        .unwrap_or_else(|| {
            // Parse as `u16` so out-of-range values cannot leak into the URL;
            // anything unparsable falls back to the default port.
            let port = std::env::var("CODEX_OSS_PORT")
                .ok()
                .and_then(|raw| raw.trim().parse::<u16>().ok())
                .map_or(DEFAULT_OLLAMA_PORT, u32::from);
            format!("http://localhost:{port}/v1")
        });

    create_oss_provider_with_base_url(&codex_oss_base_url)
}
/// Describe an open-source provider reachable at `base_url`.
///
/// The provider is named "gpt-oss", speaks the Chat wire API, and does not
/// require OpenAI auth. Every optional setting (API key, bearer token, query
/// parameters, extra headers, retry/timeout overrides) is left unset.
pub fn create_oss_provider_with_base_url(base_url: &str) -> ModelProviderInfo {
    let endpoint = base_url.to_owned();

    ModelProviderInfo {
        // Identity and transport.
        name: "gpt-oss".to_string(),
        base_url: Some(endpoint),
        wire_api: WireApi::Chat,
        requires_openai_auth: false,
        // No credentials are configured for the local provider.
        env_key: None,
        env_key_instructions: None,
        experimental_bearer_token: None,
        // No request customisation.
        query_params: None,
        http_headers: None,
        env_http_headers: None,
        // Defer to the global retry/timeout defaults.
        request_max_retries: None,
        stream_max_retries: None,
        stream_idle_timeout_ms: None,
    }
}
/// Report whether `base_url` looks like an Azure-hosted endpoint.
///
/// The check is a case-insensitive (ASCII) substring search over the whole
/// URL for any of a fixed set of Azure domain fragments.
fn matches_azure_responses_base_url(base_url: &str) -> bool {
    const AZURE_MARKERS: &[&str] = &[
        "openai.azure.",
        "cognitiveservices.azure.",
        "aoai.azure.",
        "azure-api.",
        "azurefd.",
    ];

    let lowered = base_url.to_ascii_lowercase();
    for &marker in AZURE_MARKERS {
        if lowered.contains(marker) {
            return true;
        }
    }
    false
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Provider with only the name, base URL and wire API set; every other
    /// field is `None`/`false`. Tests override individual fields with
    /// struct-update syntax.
    fn bare_provider(name: &str, base_url: &str, wire_api: WireApi) -> ModelProviderInfo {
        ModelProviderInfo {
            name: name.into(),
            base_url: Some(base_url.into()),
            env_key: None,
            env_key_instructions: None,
            experimental_bearer_token: None,
            wire_api,
            query_params: None,
            http_headers: None,
            env_http_headers: None,
            request_max_retries: None,
            stream_max_retries: None,
            stream_idle_timeout_ms: None,
            requires_openai_auth: false,
        }
    }

    /// A minimal TOML table (name + base URL) deserializes with all defaults,
    /// including the Chat wire API.
    #[test]
    fn test_deserialize_ollama_model_provider_toml() {
        let ollama_toml = r#"
name = "Ollama"
base_url = "http://localhost:11434/v1"
"#;

        let expected = bare_provider("Ollama", "http://localhost:11434/v1", WireApi::Chat);

        let parsed: ModelProviderInfo = toml::from_str(ollama_toml).unwrap();
        assert_eq!(expected, parsed);
    }

    /// `env_key` and inline-table `query_params` are picked up from TOML.
    #[test]
    fn test_deserialize_azure_model_provider_toml() {
        let azure_toml = r#"
name = "Azure"
base_url = "https://xxxxx.openai.azure.com/openai"
env_key = "AZURE_OPENAI_API_KEY"
query_params = { api-version = "2025-04-01-preview" }
"#;

        let expected = ModelProviderInfo {
            env_key: Some("AZURE_OPENAI_API_KEY".into()),
            query_params: Some(maplit::hashmap! {
                "api-version".to_string() => "2025-04-01-preview".to_string(),
            }),
            ..bare_provider(
                "Azure",
                "https://xxxxx.openai.azure.com/openai",
                WireApi::Chat,
            )
        };

        let parsed: ModelProviderInfo = toml::from_str(azure_toml).unwrap();
        assert_eq!(expected, parsed);
    }

    /// Static and environment-backed header tables are picked up from TOML.
    #[test]
    fn test_deserialize_example_model_provider_toml() {
        let example_toml = r#"
name = "Example"
base_url = "https://example.com"
env_key = "API_KEY"
http_headers = { "X-Example-Header" = "example-value" }
env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
"#;

        let expected = ModelProviderInfo {
            env_key: Some("API_KEY".into()),
            http_headers: Some(maplit::hashmap! {
                "X-Example-Header".to_string() => "example-value".to_string(),
            }),
            env_http_headers: Some(maplit::hashmap! {
                "X-Example-Env-Header".to_string() => "EXAMPLE_ENV_VAR".to_string(),
            }),
            ..bare_provider("Example", "https://example.com", WireApi::Chat)
        };

        let parsed: ModelProviderInfo = toml::from_str(example_toml).unwrap();
        assert_eq!(expected, parsed);
    }

    /// Azure detection: known Azure hosts match, a provider named "Azure"
    /// matches regardless of host, and unrelated hosts do not match.
    #[test]
    fn detects_azure_responses_base_urls() {
        let responses_provider =
            |base_url: &str| bare_provider("test", base_url, WireApi::Responses);

        let positive_cases = [
            "https://foo.openai.azure.com/openai",
            "https://foo.openai.azure.us/openai/deployments/bar",
            "https://foo.cognitiveservices.azure.cn/openai",
            "https://foo.aoai.azure.com/openai",
            "https://foo.openai.azure-api.net/openai",
            "https://foo.z01.azurefd.net/",
        ];
        for base_url in positive_cases {
            assert!(
                responses_provider(base_url).is_azure_responses_endpoint(),
                "expected {base_url} to be detected as Azure"
            );
        }

        // The provider name alone is enough, even with a non-Azure host.
        let named_provider = bare_provider("Azure", "https://example.com", WireApi::Responses);
        assert!(named_provider.is_azure_responses_endpoint());

        let negative_cases = [
            "https://api.openai.com/v1",
            "https://example.com/openai",
            "https://myproxy.azurewebsites.net/openai",
        ];
        for base_url in negative_cases {
            assert!(
                !responses_provider(base_url).is_azure_responses_endpoint(),
                "expected {base_url} not to be detected as Azure"
            );
        }
    }
}