chore: proper client extraction (#6996)

jif-oai
2025-11-25 18:06:12 +00:00
committed by GitHub
parent 2845e2c006
commit 4502b1b263
45 changed files with 4893 additions and 2748 deletions

View File

@@ -0,0 +1,27 @@
use codex_client::Request;
/// Provides bearer and account identity information for API requests.
///
/// Implementations should be cheap and non-blocking; any asynchronous
/// refresh or I/O should be handled by higher layers before requests
/// reach this interface.
pub trait AuthProvider: Send + Sync {
fn bearer_token(&self) -> Option<String>;
fn account_id(&self) -> Option<String> {
None
}
}
pub(crate) fn add_auth_headers<A: AuthProvider>(auth: &A, mut req: Request) -> Request {
if let Some(token) = auth.bearer_token()
&& let Ok(header) = format!("Bearer {token}").parse()
{
let _ = req.headers.insert(http::header::AUTHORIZATION, header);
}
if let Some(account_id) = auth.account_id()
&& let Ok(header) = account_id.parse()
{
let _ = req.headers.insert("ChatGPT-Account-ID", header);
}
req
}
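
A minimal sketch (not part of this commit) of a provider backed by a fixed token; `add_auth_headers` then attaches `Authorization: Bearer <token>` and, when an account id is available, `ChatGPT-Account-ID`:

struct StaticAuth {
    token: String,
}

impl AuthProvider for StaticAuth {
    fn bearer_token(&self) -> Option<String> {
        Some(self.token.clone())
    }
}

// In-crate usage (`add_auth_headers` is pub(crate)):
// let req = add_auth_headers(&auth, provider.build_request(http::Method::POST, "responses"));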

View File

@@ -0,0 +1,167 @@
use crate::error::ApiError;
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
use codex_protocol::config_types::Verbosity as VerbosityConfig;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::RateLimitSnapshot;
use codex_protocol::protocol::TokenUsage;
use futures::Stream;
use serde::Serialize;
use serde_json::Value;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use tokio::sync::mpsc;
/// Canonical prompt input for Chat and Responses endpoints.
#[derive(Debug, Clone)]
pub struct Prompt {
/// Fully-resolved system instructions for this turn.
pub instructions: String,
/// Conversation history and user/tool messages.
pub input: Vec<ResponseItem>,
/// JSON-encoded tool definitions compatible with the target API.
// TODO(jif) have a proper type here
pub tools: Vec<Value>,
/// Whether parallel tool calls are permitted.
pub parallel_tool_calls: bool,
/// Optional output schema used to build the `text.format` controls.
pub output_schema: Option<Value>,
}
/// Canonical input payload for the compaction endpoint.
#[derive(Debug, Clone, Serialize)]
pub struct CompactionInput<'a> {
pub model: &'a str,
pub input: &'a [ResponseItem],
pub instructions: &'a str,
}
#[derive(Debug)]
pub enum ResponseEvent {
Created,
OutputItemDone(ResponseItem),
OutputItemAdded(ResponseItem),
Completed {
response_id: String,
token_usage: Option<TokenUsage>,
},
OutputTextDelta(String),
ReasoningSummaryDelta {
delta: String,
summary_index: i64,
},
ReasoningContentDelta {
delta: String,
content_index: i64,
},
ReasoningSummaryPartAdded {
summary_index: i64,
},
RateLimits(RateLimitSnapshot),
}
#[derive(Debug, Serialize, Clone)]
pub struct Reasoning {
#[serde(skip_serializing_if = "Option::is_none")]
pub effort: Option<ReasoningEffortConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
pub summary: Option<ReasoningSummaryConfig>,
}
#[derive(Debug, Serialize, Default, Clone)]
#[serde(rename_all = "snake_case")]
pub enum TextFormatType {
#[default]
JsonSchema,
}
#[derive(Debug, Serialize, Default, Clone)]
pub struct TextFormat {
/// Format type used by the OpenAI text controls.
pub r#type: TextFormatType,
/// When true, the server is expected to strictly validate responses.
pub strict: bool,
/// JSON schema for the desired output.
pub schema: Value,
/// Friendly name for the format, used in telemetry/debugging.
pub name: String,
}
/// Controls the `text` field for the Responses API, combining verbosity and
/// optional JSON schema output formatting.
#[derive(Debug, Serialize, Default, Clone)]
pub struct TextControls {
#[serde(skip_serializing_if = "Option::is_none")]
pub verbosity: Option<OpenAiVerbosity>,
#[serde(skip_serializing_if = "Option::is_none")]
pub format: Option<TextFormat>,
}
#[derive(Debug, Serialize, Default, Clone)]
#[serde(rename_all = "lowercase")]
pub enum OpenAiVerbosity {
Low,
#[default]
Medium,
High,
}
impl From<VerbosityConfig> for OpenAiVerbosity {
fn from(v: VerbosityConfig) -> Self {
match v {
VerbosityConfig::Low => OpenAiVerbosity::Low,
VerbosityConfig::Medium => OpenAiVerbosity::Medium,
VerbosityConfig::High => OpenAiVerbosity::High,
}
}
}
#[derive(Debug, Serialize)]
pub struct ResponsesApiRequest<'a> {
pub model: &'a str,
pub instructions: &'a str,
pub input: &'a [ResponseItem],
pub tools: &'a [serde_json::Value],
pub tool_choice: &'static str,
pub parallel_tool_calls: bool,
pub reasoning: Option<Reasoning>,
pub store: bool,
pub stream: bool,
pub include: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_cache_key: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<TextControls>,
}
pub fn create_text_param_for_request(
verbosity: Option<VerbosityConfig>,
output_schema: &Option<Value>,
) -> Option<TextControls> {
if verbosity.is_none() && output_schema.is_none() {
return None;
}
Some(TextControls {
verbosity: verbosity.map(std::convert::Into::into),
format: output_schema.as_ref().map(|schema| TextFormat {
r#type: TextFormatType::JsonSchema,
strict: true,
schema: schema.clone(),
name: "codex_output_schema".to_string(),
}),
})
}
pub struct ResponseStream {
pub rx_event: mpsc::Receiver<Result<ResponseEvent, ApiError>>,
}
impl Stream for ResponseStream {
type Item = Result<ResponseEvent, ApiError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.rx_event.poll_recv(cx)
}
}
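
For illustration (not in the diff): with an output schema and no verbosity override, `create_text_param_for_request` yields strict JSON-schema text controls:

let schema = serde_json::json!({ "type": "object" });
let text = create_text_param_for_request(None, &Some(schema));
// `text` serializes as:
// {"format":{"type":"json_schema","strict":true,"schema":{"type":"object"},"name":"codex_output_schema"}}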

View File

@@ -0,0 +1,266 @@
use crate::ChatRequest;
use crate::auth::AuthProvider;
use crate::common::Prompt as ApiPrompt;
use crate::common::ResponseEvent;
use crate::common::ResponseStream;
use crate::endpoint::streaming::StreamingClient;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::provider::WireApi;
use crate::sse::chat::spawn_chat_stream;
use crate::telemetry::SseTelemetry;
use codex_client::HttpTransport;
use codex_client::RequestTelemetry;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ReasoningItemContent;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::SessionSource;
use futures::Stream;
use http::HeaderMap;
use serde_json::Value;
use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
pub struct ChatClient<T: HttpTransport, A: AuthProvider> {
streaming: StreamingClient<T, A>,
}
impl<T: HttpTransport, A: AuthProvider> ChatClient<T, A> {
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
Self {
streaming: StreamingClient::new(transport, provider, auth),
}
}
pub fn with_telemetry(
self,
request: Option<Arc<dyn RequestTelemetry>>,
sse: Option<Arc<dyn SseTelemetry>>,
) -> Self {
Self {
streaming: self.streaming.with_telemetry(request, sse),
}
}
pub async fn stream_request(&self, request: ChatRequest) -> Result<ResponseStream, ApiError> {
self.stream(request.body, request.headers).await
}
pub async fn stream_prompt(
&self,
model: &str,
prompt: &ApiPrompt,
conversation_id: Option<String>,
session_source: Option<SessionSource>,
) -> Result<ResponseStream, ApiError> {
use crate::requests::ChatRequestBuilder;
let request =
ChatRequestBuilder::new(model, &prompt.instructions, &prompt.input, &prompt.tools)
.conversation_id(conversation_id)
.session_source(session_source)
.build(self.streaming.provider())?;
self.stream_request(request).await
}
fn path(&self) -> &'static str {
match self.streaming.provider().wire {
WireApi::Chat => "chat/completions",
_ => "responses",
}
}
pub async fn stream(
&self,
body: Value,
extra_headers: HeaderMap,
) -> Result<ResponseStream, ApiError> {
self.streaming
.stream(self.path(), body, extra_headers, spawn_chat_stream)
.await
}
}
#[derive(Copy, Clone, Eq, PartialEq)]
pub enum AggregateMode {
AggregatedOnly,
Streaming,
}
/// Stream adapter that merges token deltas into a single assistant message per turn.
pub struct AggregatedStream {
inner: ResponseStream,
cumulative: String,
cumulative_reasoning: String,
pending: VecDeque<ResponseEvent>,
mode: AggregateMode,
}
impl Stream for AggregatedStream {
type Item = Result<ResponseEvent, ApiError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
if let Some(ev) = this.pending.pop_front() {
return Poll::Ready(Some(Ok(ev)));
}
loop {
match Pin::new(&mut this.inner).poll_next(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => return Poll::Ready(None),
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
let is_assistant_message = matches!(
&item,
ResponseItem::Message { role, .. } if role == "assistant"
);
if is_assistant_message {
match this.mode {
AggregateMode::AggregatedOnly => {
if this.cumulative.is_empty()
&& let ResponseItem::Message { content, .. } = &item
&& let Some(text) = content.iter().find_map(|c| match c {
ContentItem::OutputText { text } => Some(text),
_ => None,
})
{
this.cumulative.push_str(text);
}
continue;
}
AggregateMode::Streaming => {
if this.cumulative.is_empty() {
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
item,
))));
} else {
continue;
}
}
}
}
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
}
Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
}
Poll::Ready(Some(Ok(ResponseEvent::Completed {
response_id,
token_usage,
}))) => {
let mut emitted_any = false;
if !this.cumulative_reasoning.is_empty() {
let aggregated_reasoning = ResponseItem::Reasoning {
id: String::new(),
summary: Vec::new(),
content: Some(vec![ReasoningItemContent::ReasoningText {
text: std::mem::take(&mut this.cumulative_reasoning),
}]),
encrypted_content: None,
};
this.pending
.push_back(ResponseEvent::OutputItemDone(aggregated_reasoning));
emitted_any = true;
}
if !this.cumulative.is_empty() {
let aggregated_message = ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: std::mem::take(&mut this.cumulative),
}],
};
this.pending
.push_back(ResponseEvent::OutputItemDone(aggregated_message));
emitted_any = true;
}
if emitted_any {
this.pending.push_back(ResponseEvent::Completed {
response_id: response_id.clone(),
token_usage: token_usage.clone(),
});
if let Some(ev) = this.pending.pop_front() {
return Poll::Ready(Some(Ok(ev)));
}
}
return Poll::Ready(Some(Ok(ResponseEvent::Completed {
response_id,
token_usage,
})));
}
Poll::Ready(Some(Ok(ResponseEvent::Created))) => {
continue;
}
Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => {
this.cumulative.push_str(&delta);
if matches!(this.mode, AggregateMode::Streaming) {
return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta))));
} else {
continue;
}
}
Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta {
delta,
content_index,
}))) => {
this.cumulative_reasoning.push_str(&delta);
if matches!(this.mode, AggregateMode::Streaming) {
return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta {
delta,
content_index,
})));
} else {
continue;
}
}
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta { .. }))) => continue,
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded { .. }))) => {
continue;
}
Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => {
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item))));
}
}
}
}
}
pub trait AggregateStreamExt {
fn aggregate(self) -> AggregatedStream;
fn streaming_mode(self) -> ResponseStream;
}
impl AggregateStreamExt for ResponseStream {
fn aggregate(self) -> AggregatedStream {
AggregatedStream::new(self, AggregateMode::AggregatedOnly)
}
fn streaming_mode(self) -> ResponseStream {
self
}
}
impl AggregatedStream {
fn new(inner: ResponseStream, mode: AggregateMode) -> Self {
AggregatedStream {
inner,
cumulative: String::new(),
cumulative_reasoning: String::new(),
pending: VecDeque::new(),
mode,
}
}
}
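
A hedged usage sketch (stream construction assumed, inside a function returning `Result<(), ApiError>`): callers that want one merged assistant message per turn wrap any `ResponseStream` via the extension trait:

use futures::StreamExt;

// let stream = chat_client.stream_prompt("gpt-test", &prompt, None, None).await?;
let mut aggregated = stream.aggregate();
while let Some(event) = aggregated.next().await {
    match event? {
        ResponseEvent::OutputItemDone(_item) => { /* one merged message/reasoning item per turn */ }
        ResponseEvent::Completed { .. } => break,
        _ => {}
    }
}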

View File

@@ -0,0 +1,162 @@
use crate::auth::AuthProvider;
use crate::auth::add_auth_headers;
use crate::common::CompactionInput;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::provider::WireApi;
use crate::telemetry::run_with_request_telemetry;
use codex_client::HttpTransport;
use codex_client::RequestTelemetry;
use codex_protocol::models::ResponseItem;
use http::HeaderMap;
use http::Method;
use serde::Deserialize;
use serde_json::to_value;
use std::sync::Arc;
pub struct CompactClient<T: HttpTransport, A: AuthProvider> {
transport: T,
provider: Provider,
auth: A,
request_telemetry: Option<Arc<dyn RequestTelemetry>>,
}
impl<T: HttpTransport, A: AuthProvider> CompactClient<T, A> {
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
Self {
transport,
provider,
auth,
request_telemetry: None,
}
}
pub fn with_telemetry(mut self, request: Option<Arc<dyn RequestTelemetry>>) -> Self {
self.request_telemetry = request;
self
}
fn path(&self) -> Result<&'static str, ApiError> {
match self.provider.wire {
WireApi::Compact | WireApi::Responses => Ok("responses/compact"),
WireApi::Chat => Err(ApiError::Stream(
"compact endpoint requires responses wire api".to_string(),
)),
}
}
pub async fn compact(
&self,
body: serde_json::Value,
extra_headers: HeaderMap,
) -> Result<Vec<ResponseItem>, ApiError> {
let path = self.path()?;
let builder = || {
let mut req = self.provider.build_request(Method::POST, path);
req.headers.extend(extra_headers.clone());
req.body = Some(body.clone());
add_auth_headers(&self.auth, req)
};
let resp = run_with_request_telemetry(
self.provider.retry.to_policy(),
self.request_telemetry.clone(),
builder,
|req| self.transport.execute(req),
)
.await?;
let parsed: CompactHistoryResponse =
serde_json::from_slice(&resp.body).map_err(|e| ApiError::Stream(e.to_string()))?;
Ok(parsed.output)
}
pub async fn compact_input(
&self,
input: &CompactionInput<'_>,
extra_headers: HeaderMap,
) -> Result<Vec<ResponseItem>, ApiError> {
let body = to_value(input)
.map_err(|e| ApiError::Stream(format!("failed to encode compaction input: {e}")))?;
self.compact(body, extra_headers).await
}
}
#[derive(Debug, Deserialize)]
struct CompactHistoryResponse {
output: Vec<ResponseItem>,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::provider::RetryConfig;
use async_trait::async_trait;
use codex_client::Request;
use codex_client::Response;
use codex_client::StreamResponse;
use codex_client::TransportError;
use http::HeaderMap;
use std::time::Duration;
#[derive(Clone, Default)]
struct DummyTransport;
#[async_trait]
impl HttpTransport for DummyTransport {
async fn execute(&self, _req: Request) -> Result<Response, TransportError> {
Err(TransportError::Build("execute should not run".to_string()))
}
async fn stream(&self, _req: Request) -> Result<StreamResponse, TransportError> {
Err(TransportError::Build("stream should not run".to_string()))
}
}
#[derive(Clone, Default)]
struct DummyAuth;
impl AuthProvider for DummyAuth {
fn bearer_token(&self) -> Option<String> {
None
}
}
fn provider(wire: WireApi) -> Provider {
Provider {
name: "test".to_string(),
base_url: "https://example.com/v1".to_string(),
query_params: None,
wire,
headers: HeaderMap::new(),
retry: RetryConfig {
max_attempts: 1,
base_delay: Duration::from_millis(1),
retry_429: false,
retry_5xx: true,
retry_transport: true,
},
stream_idle_timeout: Duration::from_secs(1),
}
}
#[tokio::test]
async fn errors_when_wire_is_chat() {
let client = CompactClient::new(DummyTransport, provider(WireApi::Chat), DummyAuth);
let input = CompactionInput {
model: "gpt-test",
input: &[],
instructions: "inst",
};
let err = client
.compact_input(&input, HeaderMap::new())
.await
.expect_err("expected wire mismatch to fail");
match err {
ApiError::Stream(msg) => {
assert_eq!(msg, "compact endpoint requires responses wire api");
}
other => panic!("unexpected error: {other:?}"),
}
}
}
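
Happy-path usage, sketched on the assumption that `transport`, a Responses-wire `provider`, `auth`, and `history` already exist:

let client = CompactClient::new(transport, provider, auth);
let input = CompactionInput {
    model: "gpt-test",
    input: &history,
    instructions: "Summarize the conversation so far.",
};
let compacted: Vec<ResponseItem> = client.compact_input(&input, HeaderMap::new()).await?;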

View File

@@ -0,0 +1,4 @@
pub mod chat;
pub mod compact;
pub mod responses;
mod streaming;

View File

@@ -0,0 +1,107 @@
use crate::auth::AuthProvider;
use crate::common::Prompt as ApiPrompt;
use crate::common::Reasoning;
use crate::common::ResponseStream;
use crate::common::TextControls;
use crate::endpoint::streaming::StreamingClient;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::provider::WireApi;
use crate::requests::ResponsesRequest;
use crate::requests::ResponsesRequestBuilder;
use crate::sse::spawn_response_stream;
use crate::telemetry::SseTelemetry;
use codex_client::HttpTransport;
use codex_client::RequestTelemetry;
use codex_protocol::protocol::SessionSource;
use http::HeaderMap;
use serde_json::Value;
use std::sync::Arc;
pub struct ResponsesClient<T: HttpTransport, A: AuthProvider> {
streaming: StreamingClient<T, A>,
}
#[derive(Default)]
pub struct ResponsesOptions {
pub reasoning: Option<Reasoning>,
pub include: Vec<String>,
pub prompt_cache_key: Option<String>,
pub text: Option<TextControls>,
pub store_override: Option<bool>,
pub conversation_id: Option<String>,
pub session_source: Option<SessionSource>,
}
impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
Self {
streaming: StreamingClient::new(transport, provider, auth),
}
}
pub fn with_telemetry(
self,
request: Option<Arc<dyn RequestTelemetry>>,
sse: Option<Arc<dyn SseTelemetry>>,
) -> Self {
Self {
streaming: self.streaming.with_telemetry(request, sse),
}
}
pub async fn stream_request(
&self,
request: ResponsesRequest,
) -> Result<ResponseStream, ApiError> {
self.stream(request.body, request.headers).await
}
pub async fn stream_prompt(
&self,
model: &str,
prompt: &ApiPrompt,
options: ResponsesOptions,
) -> Result<ResponseStream, ApiError> {
let ResponsesOptions {
reasoning,
include,
prompt_cache_key,
text,
store_override,
conversation_id,
session_source,
} = options;
let request = ResponsesRequestBuilder::new(model, &prompt.instructions, &prompt.input)
.tools(&prompt.tools)
.parallel_tool_calls(prompt.parallel_tool_calls)
.reasoning(reasoning)
.include(include)
.prompt_cache_key(prompt_cache_key)
.text(text)
.conversation(conversation_id)
.session_source(session_source)
.store_override(store_override)
.build(self.streaming.provider())?;
self.stream_request(request).await
}
fn path(&self) -> &'static str {
match self.streaming.provider().wire {
WireApi::Responses | WireApi::Compact => "responses",
WireApi::Chat => "chat/completions",
}
}
pub async fn stream(
&self,
body: Value,
extra_headers: HeaderMap,
) -> Result<ResponseStream, ApiError> {
self.streaming
.stream(self.path(), body, extra_headers, spawn_response_stream)
.await
}
}
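
A sketch of a typical call (client construction assumed); unset options fall back to `Default`, and `store` is derived from the provider unless overridden:

let options = ResponsesOptions {
    conversation_id: Some("conv-1".to_string()),
    ..Default::default()
};
let stream = responses_client.stream_prompt("gpt-test", &prompt, options).await?;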

View File

@@ -0,0 +1,82 @@
use crate::auth::AuthProvider;
use crate::auth::add_auth_headers;
use crate::common::ResponseStream;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::telemetry::SseTelemetry;
use crate::telemetry::run_with_request_telemetry;
use codex_client::HttpTransport;
use codex_client::RequestTelemetry;
use codex_client::StreamResponse;
use http::HeaderMap;
use http::Method;
use serde_json::Value;
use std::sync::Arc;
use std::time::Duration;
pub(crate) struct StreamingClient<T: HttpTransport, A: AuthProvider> {
transport: T,
provider: Provider,
auth: A,
request_telemetry: Option<Arc<dyn RequestTelemetry>>,
sse_telemetry: Option<Arc<dyn SseTelemetry>>,
}
impl<T: HttpTransport, A: AuthProvider> StreamingClient<T, A> {
pub(crate) fn new(transport: T, provider: Provider, auth: A) -> Self {
Self {
transport,
provider,
auth,
request_telemetry: None,
sse_telemetry: None,
}
}
pub(crate) fn with_telemetry(
mut self,
request: Option<Arc<dyn RequestTelemetry>>,
sse: Option<Arc<dyn SseTelemetry>>,
) -> Self {
self.request_telemetry = request;
self.sse_telemetry = sse;
self
}
pub(crate) fn provider(&self) -> &Provider {
&self.provider
}
pub(crate) async fn stream(
&self,
path: &str,
body: Value,
extra_headers: HeaderMap,
spawner: fn(StreamResponse, Duration, Option<Arc<dyn SseTelemetry>>) -> ResponseStream,
) -> Result<ResponseStream, ApiError> {
let builder = || {
let mut req = self.provider.build_request(Method::POST, path);
req.headers.extend(extra_headers.clone());
req.headers.insert(
http::header::ACCEPT,
http::HeaderValue::from_static("text/event-stream"),
);
req.body = Some(body.clone());
add_auth_headers(&self.auth, req)
};
let stream_response = run_with_request_telemetry(
self.provider.retry.to_policy(),
self.request_telemetry.clone(),
builder,
|req| self.transport.stream(req),
)
.await?;
Ok(spawner(
stream_response,
self.provider.stream_idle_timeout,
self.sse_telemetry.clone(),
))
}
}

View File

@@ -0,0 +1,34 @@
use crate::rate_limits::RateLimitError;
use codex_client::TransportError;
use http::StatusCode;
use std::time::Duration;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum ApiError {
#[error(transparent)]
Transport(#[from] TransportError),
#[error("api error {status}: {message}")]
Api { status: StatusCode, message: String },
#[error("stream error: {0}")]
Stream(String),
#[error("context window exceeded")]
ContextWindowExceeded,
#[error("quota exceeded")]
QuotaExceeded,
#[error("usage not included")]
UsageNotIncluded,
#[error("retryable error: {message}")]
Retryable {
message: String,
delay: Option<Duration>,
},
#[error("rate limit: {0}")]
RateLimit(String),
}
impl From<RateLimitError> for ApiError {
fn from(err: RateLimitError) -> Self {
Self::RateLimit(err.to_string())
}
}

View File

@@ -0,0 +1,35 @@
pub mod auth;
pub mod common;
pub mod endpoint;
pub mod error;
pub mod provider;
pub mod rate_limits;
pub mod requests;
pub mod sse;
pub mod telemetry;
pub use codex_client::RequestTelemetry;
pub use codex_client::ReqwestTransport;
pub use codex_client::TransportError;
pub use crate::auth::AuthProvider;
pub use crate::common::CompactionInput;
pub use crate::common::Prompt;
pub use crate::common::ResponseEvent;
pub use crate::common::ResponseStream;
pub use crate::common::ResponsesApiRequest;
pub use crate::common::create_text_param_for_request;
pub use crate::endpoint::chat::AggregateStreamExt;
pub use crate::endpoint::chat::ChatClient;
pub use crate::endpoint::compact::CompactClient;
pub use crate::endpoint::responses::ResponsesClient;
pub use crate::endpoint::responses::ResponsesOptions;
pub use crate::error::ApiError;
pub use crate::provider::Provider;
pub use crate::provider::WireApi;
pub use crate::requests::ChatRequest;
pub use crate::requests::ChatRequestBuilder;
pub use crate::requests::ResponsesRequest;
pub use crate::requests::ResponsesRequestBuilder;
pub use crate::sse::stream_from_fixture;
pub use crate::telemetry::SseTelemetry;

View File

@@ -0,0 +1,118 @@
use codex_client::Request;
use codex_client::RetryOn;
use codex_client::RetryPolicy;
use http::Method;
use http::header::HeaderMap;
use std::collections::HashMap;
use std::time::Duration;
/// Wire-level APIs supported by a `Provider`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WireApi {
Responses,
Chat,
Compact,
}
/// High-level retry configuration for a provider.
///
/// This is converted into a `RetryPolicy` used by `codex-client` to drive
/// transport-level retries for both unary and streaming calls.
#[derive(Debug, Clone)]
pub struct RetryConfig {
pub max_attempts: u64,
pub base_delay: Duration,
pub retry_429: bool,
pub retry_5xx: bool,
pub retry_transport: bool,
}
impl RetryConfig {
pub fn to_policy(&self) -> RetryPolicy {
RetryPolicy {
max_attempts: self.max_attempts,
base_delay: self.base_delay,
retry_on: RetryOn {
retry_429: self.retry_429,
retry_5xx: self.retry_5xx,
retry_transport: self.retry_transport,
},
}
}
}
/// HTTP endpoint configuration used to talk to a concrete API deployment.
///
/// Encapsulates base URL, default headers, query params, retry policy, and
/// stream idle timeout, plus helper methods for building requests.
#[derive(Debug, Clone)]
pub struct Provider {
pub name: String,
pub base_url: String,
pub query_params: Option<HashMap<String, String>>,
pub wire: WireApi,
pub headers: HeaderMap,
pub retry: RetryConfig,
pub stream_idle_timeout: Duration,
}
impl Provider {
pub fn url_for_path(&self, path: &str) -> String {
let base = self.base_url.trim_end_matches('/');
let path = path.trim_start_matches('/');
let mut url = if path.is_empty() {
base.to_string()
} else {
format!("{base}/{path}")
};
if let Some(params) = &self.query_params
&& !params.is_empty()
{
let qs = params
.iter()
.map(|(k, v)| format!("{k}={v}"))
.collect::<Vec<_>>()
.join("&");
url.push('?');
url.push_str(&qs);
}
url
}
pub fn build_request(&self, method: Method, path: &str) -> Request {
Request {
method,
url: self.url_for_path(path),
headers: self.headers.clone(),
body: None,
timeout: None,
}
}
pub fn is_azure_responses_endpoint(&self) -> bool {
if self.wire != WireApi::Responses {
return false;
}
if self.name.eq_ignore_ascii_case("azure") {
return true;
}
self.base_url.to_ascii_lowercase().contains("openai.azure.")
|| matches_azure_responses_base_url(&self.base_url)
}
}
fn matches_azure_responses_base_url(base_url: &str) -> bool {
const AZURE_MARKERS: [&str; 5] = [
"cognitiveservices.azure.",
"aoai.azure.",
"azure-api.",
"azurefd.",
"windows.net/openai",
];
let base = base_url.to_ascii_lowercase();
AZURE_MARKERS.iter().any(|marker| base.contains(marker))
}
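
A self-contained sketch (values illustrative) of URL assembly and the Azure heuristic:

use http::HeaderMap;
use std::collections::HashMap;
use std::time::Duration;

let provider = Provider {
    name: "azure".to_string(),
    base_url: "https://example.openai.azure.com/v1/".to_string(),
    query_params: Some(HashMap::from([(
        "api-version".to_string(),
        "preview".to_string(),
    )])),
    wire: WireApi::Responses,
    headers: HeaderMap::new(),
    retry: RetryConfig {
        max_attempts: 3,
        base_delay: Duration::from_millis(200),
        retry_429: true,
        retry_5xx: true,
        retry_transport: true,
    },
    stream_idle_timeout: Duration::from_secs(60),
};
assert_eq!(
    provider.url_for_path("/responses"),
    "https://example.openai.azure.com/v1/responses?api-version=preview"
);
assert!(provider.is_azure_responses_endpoint());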

View File

@@ -0,0 +1,105 @@
use codex_protocol::protocol::CreditsSnapshot;
use codex_protocol::protocol::RateLimitSnapshot;
use codex_protocol::protocol::RateLimitWindow;
use http::HeaderMap;
use std::fmt::Display;
#[derive(Debug)]
pub struct RateLimitError {
pub message: String,
}
impl Display for RateLimitError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.message)
}
}
/// Parses the bespoke Codex rate-limit headers into a `RateLimitSnapshot`.
pub fn parse_rate_limit(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
let primary = parse_rate_limit_window(
headers,
"x-codex-primary-used-percent",
"x-codex-primary-window-minutes",
"x-codex-primary-reset-at",
);
let secondary = parse_rate_limit_window(
headers,
"x-codex-secondary-used-percent",
"x-codex-secondary-window-minutes",
"x-codex-secondary-reset-at",
);
let credits = parse_credits_snapshot(headers);
Some(RateLimitSnapshot {
primary,
secondary,
credits,
})
}
fn parse_rate_limit_window(
headers: &HeaderMap,
used_percent_header: &str,
window_minutes_header: &str,
resets_at_header: &str,
) -> Option<RateLimitWindow> {
let used_percent: Option<f64> = parse_header_f64(headers, used_percent_header);
used_percent.and_then(|used_percent| {
let window_minutes = parse_header_i64(headers, window_minutes_header);
let resets_at = parse_header_i64(headers, resets_at_header);
let has_data = used_percent != 0.0
|| window_minutes.is_some_and(|minutes| minutes != 0)
|| resets_at.is_some();
has_data.then_some(RateLimitWindow {
used_percent,
window_minutes,
resets_at,
})
})
}
fn parse_credits_snapshot(headers: &HeaderMap) -> Option<CreditsSnapshot> {
let has_credits = parse_header_bool(headers, "x-codex-credits-has-credits")?;
let unlimited = parse_header_bool(headers, "x-codex-credits-unlimited")?;
let balance = parse_header_str(headers, "x-codex-credits-balance")
.map(str::trim)
.filter(|value| !value.is_empty())
.map(std::string::ToString::to_string);
Some(CreditsSnapshot {
has_credits,
unlimited,
balance,
})
}
fn parse_header_f64(headers: &HeaderMap, name: &str) -> Option<f64> {
parse_header_str(headers, name)?
.parse::<f64>()
.ok()
.filter(|v| v.is_finite())
}
fn parse_header_i64(headers: &HeaderMap, name: &str) -> Option<i64> {
parse_header_str(headers, name)?.parse::<i64>().ok()
}
fn parse_header_bool(headers: &HeaderMap, name: &str) -> Option<bool> {
let raw = parse_header_str(headers, name)?;
if raw.eq_ignore_ascii_case("true") || raw == "1" {
Some(true)
} else if raw.eq_ignore_ascii_case("false") || raw == "0" {
Some(false)
} else {
None
}
}
fn parse_header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> {
headers.get(name)?.to_str().ok()
}
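
For example (header values illustrative), a response carrying only the primary-window headers parses to a snapshot with `secondary` and `credits` unset:

use http::HeaderMap;

let mut headers = HeaderMap::new();
headers.insert("x-codex-primary-used-percent", "42.5".parse().unwrap());
headers.insert("x-codex-primary-window-minutes", "300".parse().unwrap());
let snapshot = parse_rate_limit(&headers).expect("snapshot");
let primary = snapshot.primary.expect("primary window");
assert_eq!(primary.used_percent, 42.5);
assert_eq!(primary.window_minutes, Some(300));
assert!(snapshot.secondary.is_none() && snapshot.credits.is_none());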

View File

@@ -0,0 +1,388 @@
use crate::error::ApiError;
use crate::provider::Provider;
use crate::requests::headers::build_conversation_headers;
use crate::requests::headers::insert_header;
use crate::requests::headers::subagent_header;
use codex_protocol::models::ContentItem;
use codex_protocol::models::FunctionCallOutputContentItem;
use codex_protocol::models::ReasoningItemContent;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::SessionSource;
use http::HeaderMap;
use serde_json::Value;
use serde_json::json;
use std::collections::HashMap;
/// Assembled request body plus headers for Chat Completions streaming calls.
pub struct ChatRequest {
pub body: Value,
pub headers: HeaderMap,
}
pub struct ChatRequestBuilder<'a> {
model: &'a str,
instructions: &'a str,
input: &'a [ResponseItem],
tools: &'a [Value],
conversation_id: Option<String>,
session_source: Option<SessionSource>,
}
impl<'a> ChatRequestBuilder<'a> {
pub fn new(
model: &'a str,
instructions: &'a str,
input: &'a [ResponseItem],
tools: &'a [Value],
) -> Self {
Self {
model,
instructions,
input,
tools,
conversation_id: None,
session_source: None,
}
}
pub fn conversation_id(mut self, id: Option<String>) -> Self {
self.conversation_id = id;
self
}
pub fn session_source(mut self, source: Option<SessionSource>) -> Self {
self.session_source = source;
self
}
pub fn build(self, _provider: &Provider) -> Result<ChatRequest, ApiError> {
let mut messages = Vec::<Value>::new();
messages.push(json!({"role": "system", "content": self.instructions}));
let input = self.input;
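// First pass: note which role ends the transcript; reasoning is only
// re-anchored below when the history does not end with a user message.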
let mut reasoning_by_anchor_index: HashMap<usize, String> = HashMap::new();
let mut last_emitted_role: Option<&str> = None;
for item in input {
match item {
ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()),
ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
last_emitted_role = Some("assistant")
}
ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"),
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
ResponseItem::CustomToolCall { .. } => {}
ResponseItem::CustomToolCallOutput { .. } => {}
ResponseItem::WebSearchCall { .. } => {}
ResponseItem::GhostSnapshot { .. } => {}
ResponseItem::CompactionSummary { .. } => {}
}
}
let mut last_user_index: Option<usize> = None;
for (idx, item) in input.iter().enumerate() {
if let ResponseItem::Message { role, .. } = item
&& role == "user"
{
last_user_index = Some(idx);
}
}
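// Anchor reasoning emitted after the last user message onto the adjacent
// assistant message (preceding) or assistant message / tool call (following),
// so it can be replayed as a `reasoning` field on that message.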
if !matches!(last_emitted_role, Some("user")) {
for (idx, item) in input.iter().enumerate() {
if let Some(u_idx) = last_user_index
&& idx <= u_idx
{
continue;
}
if let ResponseItem::Reasoning {
content: Some(items),
..
} = item
{
let mut text = String::new();
for entry in items {
match entry {
ReasoningItemContent::ReasoningText { text: segment }
| ReasoningItemContent::Text { text: segment } => {
text.push_str(segment)
}
}
}
if text.trim().is_empty() {
continue;
}
let mut attached = false;
if idx > 0
&& let ResponseItem::Message { role, .. } = &input[idx - 1]
&& role == "assistant"
{
reasoning_by_anchor_index
.entry(idx - 1)
.and_modify(|v| v.push_str(&text))
.or_insert(text.clone());
attached = true;
}
if !attached && idx + 1 < input.len() {
match &input[idx + 1] {
ResponseItem::FunctionCall { .. }
| ResponseItem::LocalShellCall { .. } => {
reasoning_by_anchor_index
.entry(idx + 1)
.and_modify(|v| v.push_str(&text))
.or_insert(text.clone());
}
ResponseItem::Message { role, .. } if role == "assistant" => {
reasoning_by_anchor_index
.entry(idx + 1)
.and_modify(|v| v.push_str(&text))
.or_insert(text.clone());
}
_ => {}
}
}
}
}
}
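// Main pass: translate each item into a Chat Completions message, skipping
// repeats of the previous assistant text and attaching anchored reasoning.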
let mut last_assistant_text: Option<String> = None;
for (idx, item) in input.iter().enumerate() {
match item {
ResponseItem::Message { role, content, .. } => {
let mut text = String::new();
let mut items: Vec<Value> = Vec::new();
let mut saw_image = false;
for c in content {
match c {
ContentItem::InputText { text: t }
| ContentItem::OutputText { text: t } => {
text.push_str(t);
items.push(json!({"type":"text","text": t}));
}
ContentItem::InputImage { image_url } => {
saw_image = true;
items.push(
json!({"type":"image_url","image_url": {"url": image_url}}),
);
}
}
}
if role == "assistant" {
if let Some(prev) = &last_assistant_text
&& prev == &text
{
continue;
}
last_assistant_text = Some(text.clone());
}
let content_value = if role == "assistant" {
json!(text)
} else if saw_image {
json!(items)
} else {
json!(text)
};
let mut msg = json!({"role": role, "content": content_value});
if role == "assistant"
&& let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
&& let Some(obj) = msg.as_object_mut()
{
obj.insert("reasoning".to_string(), json!(reasoning));
}
messages.push(msg);
}
ResponseItem::FunctionCall {
name,
arguments,
call_id,
..
} => {
let mut msg = json!({
"role": "assistant",
"content": null,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": name,
"arguments": arguments,
}
}]
});
if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
&& let Some(obj) = msg.as_object_mut()
{
obj.insert("reasoning".to_string(), json!(reasoning));
}
messages.push(msg);
}
ResponseItem::LocalShellCall {
id,
call_id: _,
status,
action,
} => {
let mut msg = json!({
"role": "assistant",
"content": null,
"tool_calls": [{
"id": id.clone().unwrap_or_default(),
"type": "local_shell_call",
"status": status,
"action": action,
}]
});
if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
&& let Some(obj) = msg.as_object_mut()
{
obj.insert("reasoning".to_string(), json!(reasoning));
}
messages.push(msg);
}
ResponseItem::FunctionCallOutput { call_id, output } => {
let content_value = if let Some(items) = &output.content_items {
let mapped: Vec<Value> = items
.iter()
.map(|it| match it {
FunctionCallOutputContentItem::InputText { text } => {
json!({"type":"text","text": text})
}
FunctionCallOutputContentItem::InputImage { image_url } => {
json!({"type":"image_url","image_url": {"url": image_url}})
}
})
.collect();
json!(mapped)
} else {
json!(output.content)
};
messages.push(json!({
"role": "tool",
"tool_call_id": call_id,
"content": content_value,
}));
}
ResponseItem::CustomToolCall {
id,
call_id: _,
name,
input,
status: _,
} => {
messages.push(json!({
"role": "assistant",
"content": null,
"tool_calls": [{
"id": id,
"type": "custom",
"custom": {
"name": name,
"input": input,
}
}]
}));
}
ResponseItem::CustomToolCallOutput { call_id, output } => {
messages.push(json!({
"role": "tool",
"tool_call_id": call_id,
"content": output,
}));
}
ResponseItem::GhostSnapshot { .. } => {
continue;
}
ResponseItem::Reasoning { .. }
| ResponseItem::WebSearchCall { .. }
| ResponseItem::Other
| ResponseItem::CompactionSummary { .. } => {
continue;
}
}
}
let payload = json!({
"model": self.model,
"messages": messages,
"stream": true,
"tools": self.tools,
});
let mut headers = build_conversation_headers(self.conversation_id);
if let Some(subagent) = subagent_header(&self.session_source) {
insert_header(&mut headers, "x-openai-subagent", &subagent);
}
Ok(ChatRequest {
body: payload,
headers,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::provider::RetryConfig;
use crate::provider::WireApi;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use http::HeaderValue;
use pretty_assertions::assert_eq;
use std::time::Duration;
fn provider() -> Provider {
Provider {
name: "openai".to_string(),
base_url: "https://api.openai.com/v1".to_string(),
query_params: None,
wire: WireApi::Chat,
headers: HeaderMap::new(),
retry: RetryConfig {
max_attempts: 1,
base_delay: Duration::from_millis(10),
retry_429: false,
retry_5xx: true,
retry_transport: true,
},
stream_idle_timeout: Duration::from_secs(1),
}
}
#[test]
fn attaches_conversation_and_subagent_headers() {
let prompt_input = vec![ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "hi".to_string(),
}],
}];
let req = ChatRequestBuilder::new("gpt-test", "inst", &prompt_input, &[])
.conversation_id(Some("conv-1".into()))
.session_source(Some(SessionSource::SubAgent(SubAgentSource::Review)))
.build(&provider())
.expect("request");
assert_eq!(
req.headers.get("conversation_id"),
Some(&HeaderValue::from_static("conv-1"))
);
assert_eq!(
req.headers.get("session_id"),
Some(&HeaderValue::from_static("conv-1"))
);
assert_eq!(
req.headers.get("x-openai-subagent"),
Some(&HeaderValue::from_static("review"))
);
}
}

View File

@@ -0,0 +1,36 @@
use codex_protocol::protocol::SessionSource;
use http::HeaderMap;
use http::HeaderValue;
pub(crate) fn build_conversation_headers(conversation_id: Option<String>) -> HeaderMap {
let mut headers = HeaderMap::new();
if let Some(id) = conversation_id {
insert_header(&mut headers, "conversation_id", &id);
insert_header(&mut headers, "session_id", &id);
}
headers
}
pub(crate) fn subagent_header(source: &Option<SessionSource>) -> Option<String> {
let SessionSource::SubAgent(sub) = source.as_ref()? else {
return None;
};
match sub {
codex_protocol::protocol::SubAgentSource::Other(label) => Some(label.clone()),
other => Some(
serde_json::to_value(other)
.ok()
.and_then(|v| v.as_str().map(std::string::ToString::to_string))
.unwrap_or_else(|| "other".to_string()),
),
}
}
pub(crate) fn insert_header(headers: &mut HeaderMap, name: &str, value: &str) {
if let (Ok(header_name), Ok(header_value)) = (
name.parse::<http::HeaderName>(),
HeaderValue::from_str(value),
) {
headers.insert(header_name, header_value);
}
}
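
Sketched expectation (mirroring the builder tests in this commit): a `SubAgent` session source maps to its serialized label, e.g.

let source = Some(SessionSource::SubAgent(
    codex_protocol::protocol::SubAgentSource::Review,
));
assert_eq!(subagent_header(&source), Some("review".to_string()));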

View File

@@ -0,0 +1,8 @@
pub mod chat;
pub(crate) mod headers;
pub mod responses;
pub use chat::ChatRequest;
pub use chat::ChatRequestBuilder;
pub use responses::ResponsesRequest;
pub use responses::ResponsesRequestBuilder;

View File

@@ -0,0 +1,247 @@
use crate::common::Reasoning;
use crate::common::ResponsesApiRequest;
use crate::common::TextControls;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::requests::headers::build_conversation_headers;
use crate::requests::headers::insert_header;
use crate::requests::headers::subagent_header;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::SessionSource;
use http::HeaderMap;
use serde_json::Value;
/// Assembled request body plus headers for a Responses stream request.
pub struct ResponsesRequest {
pub body: Value,
pub headers: HeaderMap,
}
#[derive(Default)]
pub struct ResponsesRequestBuilder<'a> {
model: Option<&'a str>,
instructions: Option<&'a str>,
input: Option<&'a [ResponseItem]>,
tools: Option<&'a [Value]>,
parallel_tool_calls: bool,
reasoning: Option<Reasoning>,
include: Vec<String>,
prompt_cache_key: Option<String>,
text: Option<TextControls>,
conversation_id: Option<String>,
session_source: Option<SessionSource>,
store_override: Option<bool>,
headers: HeaderMap,
}
impl<'a> ResponsesRequestBuilder<'a> {
pub fn new(model: &'a str, instructions: &'a str, input: &'a [ResponseItem]) -> Self {
Self {
model: Some(model),
instructions: Some(instructions),
input: Some(input),
..Default::default()
}
}
pub fn tools(mut self, tools: &'a [Value]) -> Self {
self.tools = Some(tools);
self
}
pub fn parallel_tool_calls(mut self, enabled: bool) -> Self {
self.parallel_tool_calls = enabled;
self
}
pub fn reasoning(mut self, reasoning: Option<Reasoning>) -> Self {
self.reasoning = reasoning;
self
}
pub fn include(mut self, include: Vec<String>) -> Self {
self.include = include;
self
}
pub fn prompt_cache_key(mut self, key: Option<String>) -> Self {
self.prompt_cache_key = key;
self
}
pub fn text(mut self, text: Option<TextControls>) -> Self {
self.text = text;
self
}
pub fn conversation(mut self, conversation_id: Option<String>) -> Self {
self.conversation_id = conversation_id;
self
}
pub fn session_source(mut self, source: Option<SessionSource>) -> Self {
self.session_source = source;
self
}
pub fn store_override(mut self, store: Option<bool>) -> Self {
self.store_override = store;
self
}
pub fn extra_headers(mut self, headers: HeaderMap) -> Self {
self.headers = headers;
self
}
pub fn build(self, provider: &Provider) -> Result<ResponsesRequest, ApiError> {
let model = self
.model
.ok_or_else(|| ApiError::Stream("missing model for responses request".into()))?;
let instructions = self
.instructions
.ok_or_else(|| ApiError::Stream("missing instructions for responses request".into()))?;
let input = self
.input
.ok_or_else(|| ApiError::Stream("missing input for responses request".into()))?;
let tools = self.tools.unwrap_or_default();
let store = self
.store_override
.unwrap_or_else(|| provider.is_azure_responses_endpoint());
let req = ResponsesApiRequest {
model,
instructions,
input,
tools,
tool_choice: "auto",
parallel_tool_calls: self.parallel_tool_calls,
reasoning: self.reasoning,
store,
stream: true,
include: self.include,
prompt_cache_key: self.prompt_cache_key,
text: self.text,
};
let mut body = serde_json::to_value(&req)
.map_err(|e| ApiError::Stream(format!("failed to encode responses request: {e}")))?;
if store && provider.is_azure_responses_endpoint() {
attach_item_ids(&mut body, input);
}
let mut headers = self.headers;
headers.extend(build_conversation_headers(self.conversation_id));
if let Some(subagent) = subagent_header(&self.session_source) {
insert_header(&mut headers, "x-openai-subagent", &subagent);
}
Ok(ResponsesRequest { body, headers })
}
}
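/// Copies each non-empty item `id` from the typed input back onto the
/// serialized `input` array; applied when `store` is enabled against an
/// Azure Responses endpoint.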
fn attach_item_ids(payload_json: &mut Value, original_items: &[ResponseItem]) {
let Some(input_value) = payload_json.get_mut("input") else {
return;
};
let Value::Array(items) = input_value else {
return;
};
for (value, item) in items.iter_mut().zip(original_items.iter()) {
if let ResponseItem::Reasoning { id, .. }
| ResponseItem::Message { id: Some(id), .. }
| ResponseItem::WebSearchCall { id: Some(id), .. }
| ResponseItem::FunctionCall { id: Some(id), .. }
| ResponseItem::LocalShellCall { id: Some(id), .. }
| ResponseItem::CustomToolCall { id: Some(id), .. } = item
{
if id.is_empty() {
continue;
}
if let Some(obj) = value.as_object_mut() {
obj.insert("id".to_string(), Value::String(id.clone()));
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::provider::RetryConfig;
use crate::provider::WireApi;
use codex_protocol::protocol::SubAgentSource;
use http::HeaderValue;
use pretty_assertions::assert_eq;
use std::time::Duration;
fn provider(name: &str, base_url: &str) -> Provider {
Provider {
name: name.to_string(),
base_url: base_url.to_string(),
query_params: None,
wire: WireApi::Responses,
headers: HeaderMap::new(),
retry: RetryConfig {
max_attempts: 1,
base_delay: Duration::from_millis(50),
retry_429: false,
retry_5xx: true,
retry_transport: true,
},
stream_idle_timeout: Duration::from_secs(5),
}
}
#[test]
fn azure_default_store_attaches_ids_and_headers() {
let provider = provider("azure", "https://example.openai.azure.com/v1");
let input = vec![
ResponseItem::Message {
id: Some("m1".into()),
role: "assistant".into(),
content: Vec::new(),
},
ResponseItem::Message {
id: None,
role: "assistant".into(),
content: Vec::new(),
},
];
let request = ResponsesRequestBuilder::new("gpt-test", "inst", &input)
.conversation(Some("conv-1".into()))
.session_source(Some(SessionSource::SubAgent(SubAgentSource::Review)))
.build(&provider)
.expect("request");
assert_eq!(request.body.get("store"), Some(&Value::Bool(true)));
let ids: Vec<Option<String>> = request
.body
.get("input")
.and_then(|v| v.as_array())
.into_iter()
.flatten()
.map(|item| item.get("id").and_then(|v| v.as_str().map(str::to_string)))
.collect();
assert_eq!(ids, vec![Some("m1".to_string()), None]);
assert_eq!(
request.headers.get("conversation_id"),
Some(&HeaderValue::from_static("conv-1"))
);
assert_eq!(
request.headers.get("session_id"),
Some(&HeaderValue::from_static("conv-1"))
);
assert_eq!(
request.headers.get("x-openai-subagent"),
Some(&HeaderValue::from_static("review"))
);
}
}

View File

@@ -0,0 +1,504 @@
use crate::common::ResponseEvent;
use crate::common::ResponseStream;
use crate::error::ApiError;
use crate::telemetry::SseTelemetry;
use codex_client::StreamResponse;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ReasoningItemContent;
use codex_protocol::models::ResponseItem;
use eventsource_stream::Eventsource;
use futures::Stream;
use futures::StreamExt;
use std::collections::HashMap;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time::Instant;
use tokio::time::timeout;
use tracing::debug;
use tracing::trace;
pub(crate) fn spawn_chat_stream(
stream_response: StreamResponse,
idle_timeout: Duration,
telemetry: Option<std::sync::Arc<dyn SseTelemetry>>,
) -> ResponseStream {
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(1600);
tokio::spawn(async move {
process_chat_sse(stream_response.bytes, tx_event, idle_timeout, telemetry).await;
});
ResponseStream { rx_event }
}
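/// Translates a Chat Completions SSE stream into `ResponseEvent`s,
/// accumulating assistant text, reasoning, and tool-call argument deltas and
/// flushing them when a `finish_reason` (or the end of the stream) arrives.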
pub async fn process_chat_sse<S>(
stream: S,
tx_event: mpsc::Sender<Result<ResponseEvent, ApiError>>,
idle_timeout: Duration,
telemetry: Option<std::sync::Arc<dyn SseTelemetry>>,
) where
S: Stream<Item = Result<bytes::Bytes, codex_client::TransportError>> + Unpin,
{
let mut stream = stream.eventsource();
#[derive(Default, Debug)]
struct ToolCallState {
name: Option<String>,
arguments: String,
}
let mut tool_calls: HashMap<String, ToolCallState> = HashMap::new();
let mut tool_call_order: Vec<String> = Vec::new();
let mut assistant_item: Option<ResponseItem> = None;
let mut reasoning_item: Option<ResponseItem> = None;
let mut completed_sent = false;
loop {
let start = Instant::now();
let response = timeout(idle_timeout, stream.next()).await;
if let Some(t) = telemetry.as_ref() {
t.on_sse_poll(&response, start.elapsed());
}
let sse = match response {
Ok(Some(Ok(sse))) => sse,
Ok(Some(Err(e))) => {
let _ = tx_event.send(Err(ApiError::Stream(e.to_string()))).await;
return;
}
Ok(None) => {
if let Some(reasoning) = reasoning_item {
let _ = tx_event
.send(Ok(ResponseEvent::OutputItemDone(reasoning)))
.await;
}
if let Some(assistant) = assistant_item {
let _ = tx_event
.send(Ok(ResponseEvent::OutputItemDone(assistant)))
.await;
}
if !completed_sent {
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
token_usage: None,
}))
.await;
}
return;
}
Err(_) => {
let _ = tx_event
.send(Err(ApiError::Stream("idle timeout waiting for SSE".into())))
.await;
return;
}
};
trace!("SSE event: {}", sse.data);
if sse.data.trim().is_empty() {
continue;
}
let value: serde_json::Value = match serde_json::from_str(&sse.data) {
Ok(val) => val,
Err(err) => {
debug!(
"Failed to parse ChatCompletions SSE event: {err}, data: {}",
&sse.data
);
continue;
}
};
let Some(choices) = value.get("choices").and_then(|c| c.as_array()) else {
continue;
};
for choice in choices {
if let Some(delta) = choice.get("delta") {
if let Some(reasoning) = delta.get("reasoning")
&& let Some(text) = reasoning_text(reasoning)
{
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string()).await;
}
if let Some(content) = delta.get("content") {
if let Some(items) = content.as_array() {
for item in items {
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
append_assistant_text(
&tx_event,
&mut assistant_item,
text.to_string(),
)
.await;
}
}
} else if let Some(text) = content.as_str() {
append_assistant_text(&tx_event, &mut assistant_item, text.to_string())
.await;
}
}
if let Some(tool_call_values) = delta.get("tool_calls").and_then(|c| c.as_array()) {
for tool_call in tool_call_values {
let id = tool_call
.get("id")
.and_then(|i| i.as_str())
.map(str::to_string)
.unwrap_or_else(|| format!("tool-call-{}", tool_call_order.len()));
let call_state = tool_calls.entry(id.clone()).or_default();
if !tool_call_order.contains(&id) {
tool_call_order.push(id.clone());
}
if let Some(func) = tool_call.get("function") {
if let Some(fname) = func.get("name").and_then(|n| n.as_str()) {
call_state.name = Some(fname.to_string());
}
if let Some(arguments) = func.get("arguments").and_then(|a| a.as_str())
{
call_state.arguments.push_str(arguments);
}
}
}
}
}
if let Some(message) = choice.get("message")
&& let Some(reasoning) = message.get("reasoning")
&& let Some(text) = reasoning_text(reasoning)
{
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string()).await;
}
let finish_reason = choice.get("finish_reason").and_then(|r| r.as_str());
if finish_reason == Some("stop") {
if let Some(reasoning) = reasoning_item.take() {
let _ = tx_event
.send(Ok(ResponseEvent::OutputItemDone(reasoning)))
.await;
}
if let Some(assistant) = assistant_item.take() {
let _ = tx_event
.send(Ok(ResponseEvent::OutputItemDone(assistant)))
.await;
}
if !completed_sent {
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
token_usage: None,
}))
.await;
completed_sent = true;
}
continue;
}
if finish_reason == Some("length") {
let _ = tx_event.send(Err(ApiError::ContextWindowExceeded)).await;
return;
}
if finish_reason == Some("tool_calls") {
if let Some(reasoning) = reasoning_item.take() {
let _ = tx_event
.send(Ok(ResponseEvent::OutputItemDone(reasoning)))
.await;
}
for call_id in tool_call_order.drain(..) {
let state = tool_calls.remove(&call_id).unwrap_or_default();
let item = ResponseItem::FunctionCall {
id: None,
name: state.name.unwrap_or_default(),
arguments: state.arguments,
call_id: call_id.clone(),
};
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
}
}
}
}
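/// Extracts reasoning text from the shapes this parser accepts: a bare
/// string, `{"text": ...}`, or `{"content": ...}`.
fn reasoning_text(value: &serde_json::Value) -> Option<&str> {
value.as_str()
.or_else(|| value.get("text").and_then(|v| v.as_str()))
.or_else(|| value.get("content").and_then(|v| v.as_str()))
}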
async fn append_assistant_text(
tx_event: &mpsc::Sender<Result<ResponseEvent, ApiError>>,
assistant_item: &mut Option<ResponseItem>,
text: String,
) {
if assistant_item.is_none() {
let item = ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![],
};
*assistant_item = Some(item.clone());
let _ = tx_event
.send(Ok(ResponseEvent::OutputItemAdded(item)))
.await;
}
if let Some(ResponseItem::Message { content, .. }) = assistant_item {
content.push(ContentItem::OutputText { text: text.clone() });
let _ = tx_event
.send(Ok(ResponseEvent::OutputTextDelta(text.clone())))
.await;
}
}
async fn append_reasoning_text(
tx_event: &mpsc::Sender<Result<ResponseEvent, ApiError>>,
reasoning_item: &mut Option<ResponseItem>,
text: String,
) {
if reasoning_item.is_none() {
let item = ResponseItem::Reasoning {
id: String::new(),
summary: Vec::new(),
content: Some(vec![]),
encrypted_content: None,
};
*reasoning_item = Some(item.clone());
let _ = tx_event
.send(Ok(ResponseEvent::OutputItemAdded(item)))
.await;
}
if let Some(ResponseItem::Reasoning {
content: Some(content),
..
}) = reasoning_item
{
let content_index = content.len() as i64;
content.push(ReasoningItemContent::ReasoningText { text: text.clone() });
let _ = tx_event
.send(Ok(ResponseEvent::ReasoningContentDelta {
delta: text.clone(),
content_index,
}))
.await;
}
}
#[cfg(test)]
mod tests {
use super::*;
use assert_matches::assert_matches;
use codex_protocol::models::ResponseItem;
use futures::TryStreamExt;
use serde_json::json;
use tokio::sync::mpsc;
use tokio_util::io::ReaderStream;
fn build_body(events: &[serde_json::Value]) -> String {
let mut body = String::new();
for e in events {
body.push_str(&format!("event: message\ndata: {e}\n\n"));
}
body
}
async fn collect_events(body: &str) -> Vec<ResponseEvent> {
let reader = ReaderStream::new(std::io::Cursor::new(body.to_string()))
.map_err(|err| codex_client::TransportError::Network(err.to_string()));
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent, ApiError>>(16);
tokio::spawn(process_chat_sse(
reader,
tx,
Duration::from_millis(1000),
None,
));
let mut out = Vec::new();
while let Some(ev) = rx.recv().await {
out.push(ev.expect("stream error"));
}
out
}
#[tokio::test]
async fn emits_multiple_tool_calls() {
let delta_a = json!({
"choices": [{
"delta": {
"tool_calls": [{
"id": "call_a",
"function": { "name": "do_a", "arguments": "{\"foo\":1}" }
}]
}
}]
});
let delta_b = json!({
"choices": [{
"delta": {
"tool_calls": [{
"id": "call_b",
"function": { "name": "do_b", "arguments": "{\"bar\":2}" }
}]
}
}]
});
let finish = json!({
"choices": [{
"finish_reason": "tool_calls"
}]
});
let body = build_body(&[delta_a, delta_b, finish]);
let events = collect_events(&body).await;
assert_eq!(events.len(), 3);
assert_matches!(
&events[0],
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. })
if call_id == "call_a" && name == "do_a" && arguments == "{\"foo\":1}"
);
assert_matches!(
&events[1],
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. })
if call_id == "call_b" && name == "do_b" && arguments == "{\"bar\":2}"
);
assert_matches!(events[2], ResponseEvent::Completed { .. });
}
#[tokio::test]
async fn concatenates_tool_call_arguments_across_deltas() {
let delta_name = json!({
"choices": [{
"delta": {
"tool_calls": [{
"id": "call_a",
"function": { "name": "do_a" }
}]
}
}]
});
let delta_args_1 = json!({
"choices": [{
"delta": {
"tool_calls": [{
"id": "call_a",
"function": { "arguments": "{ \"foo\":" }
}]
}
}]
});
let delta_args_2 = json!({
"choices": [{
"delta": {
"tool_calls": [{
"id": "call_a",
"function": { "arguments": "1}" }
}]
}
}]
});
let finish = json!({
"choices": [{
"finish_reason": "tool_calls"
}]
});
let body = build_body(&[delta_name, delta_args_1, delta_args_2, finish]);
let events = collect_events(&body).await;
assert_matches!(
&events[..],
[
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. }),
ResponseEvent::Completed { .. }
] if call_id == "call_a" && name == "do_a" && arguments == "{ \"foo\":1}"
);
}
#[tokio::test]
async fn emits_tool_calls_even_when_content_and_reasoning_present() {
let delta_content_and_tools = json!({
"choices": [{
"delta": {
"content": [{"text": "hi"}],
"reasoning": "because",
"tool_calls": [{
"id": "call_a",
"function": { "name": "do_a", "arguments": "{}" }
}]
}
}]
});
let finish = json!({
"choices": [{
"finish_reason": "tool_calls"
}]
});
let body = build_body(&[delta_content_and_tools, finish]);
let events = collect_events(&body).await;
assert_matches!(
&events[..],
[
ResponseEvent::OutputItemAdded(ResponseItem::Reasoning { .. }),
ResponseEvent::ReasoningContentDelta { .. },
ResponseEvent::OutputItemAdded(ResponseItem::Message { .. }),
ResponseEvent::OutputTextDelta(delta),
ResponseEvent::OutputItemDone(ResponseItem::Reasoning { .. }),
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, .. }),
ResponseEvent::OutputItemDone(ResponseItem::Message { .. }),
ResponseEvent::Completed { .. }
] if delta == "hi" && call_id == "call_a" && name == "do_a"
);
}
#[tokio::test]
async fn drops_partial_tool_calls_on_stop_finish_reason() {
let delta_tool = json!({
"choices": [{
"delta": {
"tool_calls": [{
"id": "call_a",
"function": { "name": "do_a", "arguments": "{}" }
}]
}
}]
});
let finish_stop = json!({
"choices": [{
"finish_reason": "stop"
}]
});
let body = build_body(&[delta_tool, finish_stop]);
let events = collect_events(&body).await;
assert!(!events.iter().any(|ev| {
matches!(
ev,
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { .. })
)
}));
assert_matches!(events.last(), Some(ResponseEvent::Completed { .. }));
}
}

View File

@@ -0,0 +1,6 @@
pub mod chat;
pub mod responses;
pub use responses::process_sse;
pub use responses::spawn_response_stream;
pub use responses::stream_from_fixture;

View File

@@ -0,0 +1,672 @@
use crate::common::ResponseEvent;
use crate::common::ResponseStream;
use crate::error::ApiError;
use crate::rate_limits::parse_rate_limit;
use crate::telemetry::SseTelemetry;
use codex_client::ByteStream;
use codex_client::StreamResponse;
use codex_client::TransportError;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::TokenUsage;
use eventsource_stream::Eventsource;
use futures::StreamExt;
use futures::TryStreamExt;
use serde::Deserialize;
use serde_json::Value;
use std::io::BufRead;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time::Instant;
use tokio::time::timeout;
use tokio_util::io::ReaderStream;
use tracing::debug;
use tracing::trace;
/// Streams SSE events from an on-disk fixture for tests.
pub fn stream_from_fixture(
path: impl AsRef<Path>,
idle_timeout: Duration,
) -> Result<ResponseStream, ApiError> {
let file =
std::fs::File::open(path.as_ref()).map_err(|err| ApiError::Stream(err.to_string()))?;
let mut content = String::new();
for line in std::io::BufReader::new(file).lines() {
let line = line.map_err(|err| ApiError::Stream(err.to_string()))?;
content.push_str(&line);
content.push_str("\n\n");
}
let reader = std::io::Cursor::new(content);
let stream = ReaderStream::new(reader).map_err(|err| TransportError::Network(err.to_string()));
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(1600);
tokio::spawn(process_sse(Box::pin(stream), tx_event, idle_timeout, None));
Ok(ResponseStream { rx_event })
}
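/// Bridges an HTTP streaming response into a `ResponseStream`, emitting a
/// `RateLimits` event first when the response headers carry a snapshot.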
pub fn spawn_response_stream(
stream_response: StreamResponse,
idle_timeout: Duration,
telemetry: Option<Arc<dyn SseTelemetry>>,
) -> ResponseStream {
let rate_limits = parse_rate_limit(&stream_response.headers);
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(1600);
tokio::spawn(async move {
if let Some(snapshot) = rate_limits {
let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await;
}
process_sse(stream_response.bytes, tx_event, idle_timeout, telemetry).await;
});
ResponseStream { rx_event }
}
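/// Error payload carried inside `response.failed` events.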
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Error {
r#type: Option<String>,
code: Option<String>,
message: Option<String>,
plan_type: Option<String>,
resets_at: Option<i64>,
}
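/// Payload of the terminal `response.completed` event.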
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct ResponseCompleted {
id: String,
#[serde(default)]
usage: Option<ResponseCompletedUsage>,
}
#[derive(Debug, Deserialize)]
struct ResponseCompletedUsage {
input_tokens: i64,
input_tokens_details: Option<ResponseCompletedInputTokensDetails>,
output_tokens: i64,
output_tokens_details: Option<ResponseCompletedOutputTokensDetails>,
total_tokens: i64,
}
impl From<ResponseCompletedUsage> for TokenUsage {
fn from(val: ResponseCompletedUsage) -> Self {
TokenUsage {
input_tokens: val.input_tokens,
cached_input_tokens: val
.input_tokens_details
.map(|d| d.cached_tokens)
.unwrap_or(0),
output_tokens: val.output_tokens,
reasoning_output_tokens: val
.output_tokens_details
.map(|d| d.reasoning_tokens)
.unwrap_or(0),
total_tokens: val.total_tokens,
}
}
}
#[derive(Debug, Deserialize)]
struct ResponseCompletedInputTokensDetails {
cached_tokens: i64,
}
#[derive(Debug, Deserialize)]
struct ResponseCompletedOutputTokensDetails {
reasoning_tokens: i64,
}
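/// Envelope for a decoded SSE frame; only the fields `process_sse` inspects
/// are deserialized.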
#[derive(Deserialize, Debug)]
struct SseEvent {
#[serde(rename = "type")]
kind: String,
response: Option<Value>,
item: Option<Value>,
delta: Option<String>,
summary_index: Option<i64>,
content_index: Option<i64>,
}
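/// Decodes an SSE byte stream into `ResponseEvent`s sent over `tx_event`.
/// The loop ends when the stream closes, the idle timeout fires, or the
/// receiver is dropped; a stream that closes without `response.completed`
/// is reported as an error.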
pub async fn process_sse(
stream: ByteStream,
tx_event: mpsc::Sender<Result<ResponseEvent, ApiError>>,
idle_timeout: Duration,
telemetry: Option<Arc<dyn SseTelemetry>>,
) {
let mut stream = stream.eventsource();
let mut response_completed: Option<ResponseCompleted> = None;
let mut response_error: Option<ApiError> = None;
loop {
let start = Instant::now();
let response = timeout(idle_timeout, stream.next()).await;
if let Some(t) = telemetry.as_ref() {
t.on_sse_poll(&response, start.elapsed());
}
let sse = match response {
Ok(Some(Ok(sse))) => sse,
Ok(Some(Err(e))) => {
debug!("SSE Error: {e:#}");
let _ = tx_event.send(Err(ApiError::Stream(e.to_string()))).await;
return;
}
Ok(None) => {
match response_completed.take() {
Some(ResponseCompleted { id, usage }) => {
let event = ResponseEvent::Completed {
response_id: id,
token_usage: usage.map(Into::into),
};
let _ = tx_event.send(Ok(event)).await;
}
None => {
let error = response_error.unwrap_or(ApiError::Stream(
"stream closed before response.completed".into(),
));
let _ = tx_event.send(Err(error)).await;
}
}
return;
}
Err(_) => {
let _ = tx_event
.send(Err(ApiError::Stream("idle timeout waiting for SSE".into())))
.await;
return;
}
};
        trace!("SSE event: {}", sse.data);
let event: SseEvent = match serde_json::from_str(&sse.data) {
Ok(event) => event,
Err(e) => {
debug!("Failed to parse SSE event: {e}, data: {}", &sse.data);
continue;
}
};
match event.kind.as_str() {
"response.output_item.done" => {
let Some(item_val) = event.item else { continue };
let Ok(item) = serde_json::from_value::<ResponseItem>(item_val) else {
debug!("failed to parse ResponseItem from output_item.done");
continue;
};
let event = ResponseEvent::OutputItemDone(item);
if tx_event.send(Ok(event)).await.is_err() {
return;
}
}
"response.output_text.delta" => {
if let Some(delta) = event.delta {
let event = ResponseEvent::OutputTextDelta(delta);
if tx_event.send(Ok(event)).await.is_err() {
return;
}
}
}
"response.reasoning_summary_text.delta" => {
if let (Some(delta), Some(summary_index)) = (event.delta, event.summary_index) {
let event = ResponseEvent::ReasoningSummaryDelta {
delta,
summary_index,
};
if tx_event.send(Ok(event)).await.is_err() {
return;
}
}
}
"response.reasoning_text.delta" => {
if let (Some(delta), Some(content_index)) = (event.delta, event.content_index) {
let event = ResponseEvent::ReasoningContentDelta {
delta,
content_index,
};
if tx_event.send(Ok(event)).await.is_err() {
return;
}
}
}
"response.created" => {
if event.response.is_some() {
                    let _ = tx_event.send(Ok(ResponseEvent::Created)).await;
}
}
"response.failed" => {
if let Some(resp_val) = event.response {
response_error =
Some(ApiError::Stream("response.failed event received".into()));
if let Some(error) = resp_val.get("error")
&& let Ok(error) = serde_json::from_value::<Error>(error.clone())
{
if is_context_window_error(&error) {
response_error = Some(ApiError::ContextWindowExceeded);
} else if is_quota_exceeded_error(&error) {
response_error = Some(ApiError::QuotaExceeded);
} else if is_usage_not_included(&error) {
response_error = Some(ApiError::UsageNotIncluded);
} else {
let delay = try_parse_retry_after(&error);
let message = error.message.clone().unwrap_or_default();
response_error = Some(ApiError::Retryable { message, delay });
}
}
}
}
"response.completed" => {
if let Some(resp_val) = event.response {
match serde_json::from_value::<ResponseCompleted>(resp_val) {
Ok(r) => {
response_completed = Some(r);
}
Err(e) => {
let error = format!("failed to parse ResponseCompleted: {e}");
debug!(error);
response_error = Some(ApiError::Stream(error));
continue;
}
};
};
}
"response.output_item.added" => {
let Some(item_val) = event.item else { continue };
let Ok(item) = serde_json::from_value::<ResponseItem>(item_val) else {
debug!("failed to parse ResponseItem from output_item.done");
continue;
};
let event = ResponseEvent::OutputItemAdded(item);
if tx_event.send(Ok(event)).await.is_err() {
return;
}
}
"response.reasoning_summary_part.added" => {
if let Some(summary_index) = event.summary_index {
let event = ResponseEvent::ReasoningSummaryPartAdded { summary_index };
if tx_event.send(Ok(event)).await.is_err() {
return;
}
}
}
_ => {}
}
}
}
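/// Parses the suggested retry delay out of a `rate_limit_exceeded` error
/// message, e.g. "Please try again in 11.054s" or "try again in 28ms".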
fn try_parse_retry_after(err: &Error) -> Option<Duration> {
if err.code.as_deref() != Some("rate_limit_exceeded") {
return None;
}
let re = rate_limit_regex();
if let Some(message) = &err.message
&& let Some(captures) = re.captures(message)
{
        if let (Some(value), Some(unit)) = (captures.get(1), captures.get(2)) {
let value = value.as_str().parse::<f64>().ok()?;
let unit = unit.as_str().to_ascii_lowercase();
if unit == "s" || unit.starts_with("second") {
return Some(Duration::from_secs_f64(value));
} else if unit == "ms" {
return Some(Duration::from_millis(value as u64));
}
}
}
None
}
fn is_context_window_error(error: &Error) -> bool {
error.code.as_deref() == Some("context_length_exceeded")
}
fn is_quota_exceeded_error(error: &Error) -> bool {
error.code.as_deref() == Some("insufficient_quota")
}
fn is_usage_not_included(error: &Error) -> bool {
error.code.as_deref() == Some("usage_not_included")
}
fn rate_limit_regex() -> &'static regex_lite::Regex {
static RE: std::sync::OnceLock<regex_lite::Regex> = std::sync::OnceLock::new();
#[expect(clippy::unwrap_used)]
RE.get_or_init(|| {
regex_lite::Regex::new(r"(?i)try again in\s*(\d+(?:\.\d+)?)\s*(s|ms|seconds?)").unwrap()
})
}
#[cfg(test)]
mod tests {
use super::*;
use assert_matches::assert_matches;
use codex_protocol::models::ResponseItem;
use pretty_assertions::assert_eq;
use serde_json::json;
use tokio::sync::mpsc;
use tokio_test::io::Builder as IoBuilder;
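    /// Feeds raw SSE byte chunks through `process_sse` and drains every
    /// emitted event, errors included.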
async fn collect_events(chunks: &[&[u8]]) -> Vec<Result<ResponseEvent, ApiError>> {
let mut builder = IoBuilder::new();
for chunk in chunks {
builder.read(chunk);
}
let reader = builder.build();
let stream =
ReaderStream::new(reader).map_err(|err| TransportError::Network(err.to_string()));
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent, ApiError>>(16);
tokio::spawn(process_sse(Box::pin(stream), tx, idle_timeout(), None));
let mut events = Vec::new();
while let Some(ev) = rx.recv().await {
events.push(ev);
}
events
}
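    /// Serializes JSON fixtures into SSE frames, runs `process_sse`, and
    /// collects the decoded events, panicking on any error event.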
async fn run_sse(events: Vec<serde_json::Value>) -> Vec<ResponseEvent> {
let mut body = String::new();
for e in events {
let kind = e
.get("type")
.and_then(|v| v.as_str())
.expect("fixture event missing type");
if e.as_object().map(|o| o.len() == 1).unwrap_or(false) {
body.push_str(&format!("event: {kind}\n\n"));
} else {
body.push_str(&format!("event: {kind}\ndata: {e}\n\n"));
}
}
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent, ApiError>>(8);
let stream = ReaderStream::new(std::io::Cursor::new(body))
.map_err(|err| TransportError::Network(err.to_string()));
tokio::spawn(process_sse(Box::pin(stream), tx, idle_timeout(), None));
let mut out = Vec::new();
while let Some(ev) = rx.recv().await {
            out.push(ev.expect("unexpected error event in stream"));
}
out
}
fn idle_timeout() -> Duration {
Duration::from_millis(1000)
}
#[tokio::test]
async fn parses_items_and_completed() {
let item1 = json!({
"type": "response.output_item.done",
"item": {
"type": "message",
"role": "assistant",
"content": [{"type": "output_text", "text": "Hello"}]
}
})
.to_string();
let item2 = json!({
"type": "response.output_item.done",
"item": {
"type": "message",
"role": "assistant",
"content": [{"type": "output_text", "text": "World"}]
}
})
.to_string();
let completed = json!({
"type": "response.completed",
"response": { "id": "resp1" }
})
.to_string();
let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n");
let sse2 = format!("event: response.output_item.done\ndata: {item2}\n\n");
let sse3 = format!("event: response.completed\ndata: {completed}\n\n");
let events = collect_events(&[sse1.as_bytes(), sse2.as_bytes(), sse3.as_bytes()]).await;
assert_eq!(events.len(), 3);
assert_matches!(
&events[0],
Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. }))
if role == "assistant"
);
assert_matches!(
&events[1],
Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. }))
if role == "assistant"
);
match &events[2] {
Ok(ResponseEvent::Completed {
response_id,
token_usage,
}) => {
assert_eq!(response_id, "resp1");
assert!(token_usage.is_none());
}
other => panic!("unexpected third event: {other:?}"),
}
}
#[tokio::test]
async fn error_when_missing_completed() {
let item1 = json!({
"type": "response.output_item.done",
"item": {
"type": "message",
"role": "assistant",
"content": [{"type": "output_text", "text": "Hello"}]
}
})
.to_string();
let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n");
let events = collect_events(&[sse1.as_bytes()]).await;
assert_eq!(events.len(), 2);
assert_matches!(events[0], Ok(ResponseEvent::OutputItemDone(_)));
match &events[1] {
Err(ApiError::Stream(msg)) => {
assert_eq!(msg, "stream closed before response.completed")
}
other => panic!("unexpected second event: {other:?}"),
}
}
#[tokio::test]
async fn error_when_error_event() {
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_689bcf18d7f08194bf3440ba62fe05d803fee0cdac429894","object":"response","created_at":1755041560,"status":"failed","background":false,"error":{"code":"rate_limit_exceeded","message":"Rate limit reached for gpt-5.1 in organization org-AAA on tokens per min (TPM): Limit 30000, Used 22999, Requested 12528. Please try again in 11.054s. Visit https://platform.openai.com/account/rate-limits to learn more."}, "usage":null,"user":null,"metadata":{}}}"#;
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
let events = collect_events(&[sse1.as_bytes()]).await;
assert_eq!(events.len(), 1);
match &events[0] {
Err(ApiError::Retryable { message, delay }) => {
assert_eq!(
message,
"Rate limit reached for gpt-5.1 in organization org-AAA on tokens per min (TPM): Limit 30000, Used 22999, Requested 12528. Please try again in 11.054s. Visit https://platform.openai.com/account/rate-limits to learn more."
);
assert_eq!(*delay, Some(Duration::from_secs_f64(11.054)));
}
other => panic!("unexpected second event: {other:?}"),
}
}
#[tokio::test]
async fn context_window_error_is_fatal() {
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_5c66275b97b9baef1ed95550adb3b7ec13b17aafd1d2f11b","object":"response","created_at":1759510079,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again."},"usage":null,"user":null,"metadata":{}}}"#;
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
let events = collect_events(&[sse1.as_bytes()]).await;
assert_eq!(events.len(), 1);
assert_matches!(events[0], Err(ApiError::ContextWindowExceeded));
}
#[tokio::test]
async fn context_window_error_with_newline_is_fatal() {
let raw_error = r#"{"type":"response.failed","sequence_number":4,"response":{"id":"resp_fatal_newline","object":"response","created_at":1759510080,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try\nagain."},"usage":null,"user":null,"metadata":{}}}"#;
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
let events = collect_events(&[sse1.as_bytes()]).await;
assert_eq!(events.len(), 1);
assert_matches!(events[0], Err(ApiError::ContextWindowExceeded));
}
#[tokio::test]
async fn quota_exceeded_error_is_fatal() {
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_fatal_quota","object":"response","created_at":1759771626,"status":"failed","background":false,"error":{"code":"insufficient_quota","message":"You exceeded your current quota, please check your plan and billing details. For more information on this error, read the docs: https://platform.openai.com/docs/guides/error-codes/api-errors."},"incomplete_details":null}}"#;
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
let events = collect_events(&[sse1.as_bytes()]).await;
assert_eq!(events.len(), 1);
assert_matches!(events[0], Err(ApiError::QuotaExceeded));
}
#[tokio::test]
async fn table_driven_event_kinds() {
struct TestCase {
name: &'static str,
event: serde_json::Value,
expect_first: fn(&ResponseEvent) -> bool,
expected_len: usize,
}
fn is_created(ev: &ResponseEvent) -> bool {
matches!(ev, ResponseEvent::Created)
}
fn is_output(ev: &ResponseEvent) -> bool {
matches!(ev, ResponseEvent::OutputItemDone(_))
}
fn is_completed(ev: &ResponseEvent) -> bool {
matches!(ev, ResponseEvent::Completed { .. })
}
let completed = json!({
"type": "response.completed",
"response": {
"id": "c",
"usage": {
"input_tokens": 0,
"input_tokens_details": null,
"output_tokens": 0,
"output_tokens_details": null,
"total_tokens": 0
},
"output": []
}
});
let cases = vec![
TestCase {
name: "created",
event: json!({"type": "response.created", "response": {}}),
expect_first: is_created,
expected_len: 2,
},
TestCase {
name: "output_item.done",
event: json!({
"type": "response.output_item.done",
"item": {
"type": "message",
"role": "assistant",
"content": [
{"type": "output_text", "text": "hi"}
]
}
}),
expect_first: is_output,
expected_len: 2,
},
TestCase {
name: "unknown",
event: json!({"type": "response.new_tool_event"}),
expect_first: is_completed,
expected_len: 1,
},
];
for case in cases {
let mut evs = vec![case.event];
evs.push(completed.clone());
let out = run_sse(evs).await;
assert_eq!(out.len(), case.expected_len, "case {}", case.name);
assert!(
(case.expect_first)(&out[0]),
"first event mismatch in case {}",
case.name
);
}
}
#[test]
fn test_try_parse_retry_after() {
let err = Error {
r#type: None,
message: Some("Rate limit reached for gpt-5.1 in organization org- on tokens per min (TPM): Limit 1, Used 1, Requested 19304. Please try again in 28ms. Visit https://platform.openai.com/account/rate-limits to learn more.".to_string()),
code: Some("rate_limit_exceeded".to_string()),
plan_type: None,
resets_at: None,
};
let delay = try_parse_retry_after(&err);
assert_eq!(delay, Some(Duration::from_millis(28)));
}
#[test]
fn test_try_parse_retry_after_no_delay() {
let err = Error {
r#type: None,
message: Some("Rate limit reached for gpt-5.1 in organization <ORG> on tokens per min (TPM): Limit 30000, Used 6899, Requested 24050. Please try again in 1.898s. Visit https://platform.openai.com/account/rate-limits to learn more.".to_string()),
code: Some("rate_limit_exceeded".to_string()),
plan_type: None,
resets_at: None,
};
let delay = try_parse_retry_after(&err);
assert_eq!(delay, Some(Duration::from_secs_f64(1.898)));
}
#[test]
fn test_try_parse_retry_after_azure() {
let err = Error {
r#type: None,
message: Some("Rate limit exceeded. Try again in 35 seconds.".to_string()),
code: Some("rate_limit_exceeded".to_string()),
plan_type: None,
resets_at: None,
};
let delay = try_parse_retry_after(&err);
assert_eq!(delay, Some(Duration::from_secs(35)));
}
}

View File

@@ -0,0 +1,84 @@
use codex_client::Request;
use codex_client::RequestTelemetry;
use codex_client::Response;
use codex_client::RetryPolicy;
use codex_client::StreamResponse;
use codex_client::TransportError;
use codex_client::run_with_retry;
use http::StatusCode;
use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;
/// Telemetry hook invoked after each poll of an SSE event stream.
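/// # Example (a minimal sketch; `SsePoll` stands in for the full nested
/// `Result` parameter type)
/// ```ignore
/// struct LogTelemetry;
/// impl SseTelemetry for LogTelemetry {
///     fn on_sse_poll(&self, result: &SsePoll, duration: Duration) {
///         tracing::trace!(ok = result.is_ok(), ?duration, "sse poll");
///     }
/// }
/// ```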
pub trait SseTelemetry: Send + Sync {
fn on_sse_poll(
&self,
result: &Result<
Option<
Result<
eventsource_stream::Event,
eventsource_stream::EventStreamError<TransportError>,
>,
>,
tokio::time::error::Elapsed,
>,
duration: Duration,
);
}
pub(crate) trait WithStatus {
fn status(&self) -> StatusCode;
}
fn http_status(err: &TransportError) -> Option<StatusCode> {
match err {
TransportError::Http { status, .. } => Some(*status),
_ => None,
}
}
impl WithStatus for Response {
fn status(&self) -> StatusCode {
self.status
}
}
impl WithStatus for StreamResponse {
fn status(&self) -> StatusCode {
self.status
}
}
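/// Sends a request under `policy`, reporting each attempt's HTTP status,
/// error, and latency to the optional `RequestTelemetry` hook.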
pub(crate) async fn run_with_request_telemetry<T, F, Fut>(
policy: RetryPolicy,
telemetry: Option<Arc<dyn RequestTelemetry>>,
make_request: impl FnMut() -> Request,
send: F,
) -> Result<T, TransportError>
where
T: WithStatus,
F: Clone + Fn(Request) -> Fut,
Fut: Future<Output = Result<T, TransportError>>,
{
// Wraps `run_with_retry` to attach per-attempt request telemetry for both
// unary and streaming HTTP calls.
run_with_retry(policy, make_request, move |req, attempt| {
let telemetry = telemetry.clone();
let send = send.clone();
async move {
let start = Instant::now();
let result = send(req).await;
if let Some(t) = telemetry.as_ref() {
let (status, err) = match &result {
Ok(resp) => (Some(resp.status()), None),
Err(err) => (http_status(err), Some(err)),
};
t.on_request(attempt, status, err, start.elapsed());
}
result
}
})
.await
}