Clean-up 3

This commit is contained in:
jif-oai
2025-11-13 17:09:04 +01:00
parent 29d93176b6
commit e4e627d1a3
6 changed files with 503 additions and 355 deletions

View File

@@ -0,0 +1,229 @@
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use crate::client::ResponseEvent;
use crate::error::Result;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ReasoningItemContent;
use codex_protocol::models::ResponseItem;
use futures::Stream;
/// Optional client-side aggregation helper
///
/// Stream adapter that merges the incremental `OutputItemDone` chunks coming from
/// the chat SSE decoder into a *running* assistant message, **suppressing the
/// per-token deltas**. The stream stays silent while the model is thinking and
/// only emits two events per turn:
///
/// 1. `ResponseEvent::OutputItemDone` with the *complete* assistant message
/// (fully concatenated).
/// 2. The original `ResponseEvent::Completed` right after it.
///
/// The adapter is intentionally *lossless*: callers who do **not** opt in via
/// [`AggregateStreamExt::aggregate()`] keep receiving the original unmodified
/// events.
// Selects how the adapter surfaces assistant output (see doc comment above).
#[derive(Copy, Clone, Eq, PartialEq)]
enum AggregateMode {
    /// Suppress per-token deltas; emit a single aggregated item when
    /// `Completed` arrives.
    AggregatedOnly,
    /// Forward deltas as they arrive, and additionally emit a final
    /// aggregated message at `Completed`.
    Streaming,
}
pub(crate) struct AggregatedChatStream<S> {
    /// The underlying event stream being adapted.
    inner: S,
    /// Assistant message text accumulated from `OutputTextDelta` events.
    cumulative: String,
    /// Reasoning text accumulated from `ReasoningContentDelta` events.
    cumulative_reasoning: String,
    /// Events built at `Completed` time, queued for emission across
    /// subsequent `poll_next` calls.
    pending: std::collections::VecDeque<ResponseEvent>,
    /// Aggregation behaviour selected at construction time.
    mode: AggregateMode,
}
impl<S> Stream for AggregatedChatStream<S>
where
    S: Stream<Item = Result<ResponseEvent>> + Unpin,
{
    type Item = Result<ResponseEvent>;

    /// Drives the inner stream, aggregating assistant/reasoning deltas and
    /// emitting the buffered results in order (Reasoning, Message, Completed).
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        // First, flush any buffered events from the previous call.
        if let Some(ev) = this.pending.pop_front() {
            return Poll::Ready(Some(Ok(ev)));
        }

        loop {
            match Pin::new(&mut this.inner).poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
                    // If this is an incremental assistant message chunk, accumulate but
                    // do NOT emit yet. Forward any other item (e.g. FunctionCall) right
                    // away so downstream consumers see it.
                    let is_assistant_message = matches!(
                        &item,
                        ResponseItem::Message { role, .. } if role == "assistant"
                    );

                    if is_assistant_message {
                        match this.mode {
                            AggregateMode::AggregatedOnly => {
                                // Only use the final assistant message if we have not
                                // seen any deltas; otherwise, deltas already built the
                                // cumulative text and this would duplicate it.
                                if this.cumulative.is_empty()
                                    && let ResponseItem::Message { content, .. } = &item
                                    && let Some(text) = content.iter().find_map(|c| match c {
                                        ContentItem::OutputText { text } => Some(text),
                                        _ => None,
                                    })
                                {
                                    this.cumulative.push_str(text);
                                }
                                // Swallow assistant message here; emit on Completed.
                                continue;
                            }
                            AggregateMode::Streaming => {
                                // In streaming mode, if we have not seen any deltas, forward
                                // the final assistant message directly. If deltas were seen,
                                // suppress the final message to avoid duplication.
                                if this.cumulative.is_empty() {
                                    return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
                                        item,
                                    ))));
                                }
                                continue;
                            }
                        }
                    }

                    // Not an assistant message; forward immediately.
                    return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
                }
                Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
                    return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
                }
                Poll::Ready(Some(Ok(ResponseEvent::Completed {
                    response_id,
                    token_usage,
                }))) => {
                    // Build any aggregated items in the correct order: Reasoning first,
                    // then Message.
                    let mut emitted_any = false;

                    if !this.cumulative_reasoning.is_empty()
                        && matches!(this.mode, AggregateMode::AggregatedOnly)
                    {
                        let aggregated_reasoning = ResponseItem::Reasoning {
                            id: String::new(),
                            summary: Vec::new(),
                            content: Some(vec![ReasoningItemContent::ReasoningText {
                                text: std::mem::take(&mut this.cumulative_reasoning),
                            }]),
                            encrypted_content: None,
                        };
                        this.pending
                            .push_back(ResponseEvent::OutputItemDone(aggregated_reasoning));
                        emitted_any = true;
                    }

                    // Always emit the final aggregated assistant message when any
                    // content deltas have been observed. In AggregatedOnly mode this
                    // is the sole assistant output; in Streaming mode this finalizes
                    // the streamed deltas into a terminal OutputItemDone so callers
                    // can persist/render the message once per turn.
                    if !this.cumulative.is_empty() {
                        let aggregated_message = ResponseItem::Message {
                            id: None,
                            role: "assistant".to_string(),
                            content: vec![ContentItem::OutputText {
                                text: std::mem::take(&mut this.cumulative),
                            }],
                        };
                        this.pending
                            .push_back(ResponseEvent::OutputItemDone(aggregated_message));
                        emitted_any = true;
                    }

                    if emitted_any {
                        // Queue Completed behind the aggregated items so it is always
                        // emitted last. Moving `response_id`/`token_usage` here (rather
                        // than cloning them) is safe because this branch always returns
                        // and never reaches the fallthrough below.
                        this.pending.push_back(ResponseEvent::Completed {
                            response_id,
                            token_usage,
                        });
                        let ev = this
                            .pending
                            .pop_front()
                            .expect("pending is non-empty: Completed was just pushed");
                        return Poll::Ready(Some(Ok(ev)));
                    }

                    // Nothing aggregated; forward Completed directly.
                    return Poll::Ready(Some(Ok(ResponseEvent::Completed {
                        response_id,
                        token_usage,
                    })));
                }
                Poll::Ready(Some(Ok(ResponseEvent::Created))) => {
                    // These events are exclusive to the Responses API and
                    // will never appear in a Chat Completions stream.
                    continue;
                }
                Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => {
                    // Always accumulate deltas so we can emit a final OutputItemDone at
                    // Completed.
                    this.cumulative.push_str(&delta);
                    if matches!(this.mode, AggregateMode::Streaming) {
                        // In streaming mode, also forward the delta immediately.
                        return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta))));
                    }
                }
                Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta)))) => {
                    // Always accumulate reasoning deltas so we can emit a final
                    // Reasoning item at Completed.
                    this.cumulative_reasoning.push_str(&delta);
                    if matches!(this.mode, AggregateMode::Streaming) {
                        // In streaming mode, also forward the delta immediately.
                        return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta))));
                    }
                }
                // Summary deltas are intentionally dropped by this adapter.
                Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta(_)))) => {}
                Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded))) => {}
                Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => {
                    return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item))));
                }
            }
        }
    }
}
/// Extension trait that activates aggregation on any stream of [`ResponseEvent`].
pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized {
    /// Returns a new stream that emits **only** the final assistant message
    /// per turn instead of every incremental delta. The produced
    /// `ResponseEvent` sequence for a typical text turn looks like:
    ///
    /// ```ignore
    /// OutputItemDone(<full message>)
    /// Completed
    /// ```
    ///
    /// No other `OutputItemDone` events will be seen by the caller.
    fn aggregate(self) -> AggregatedChatStream<Self> {
        AggregatedChatStream::new(self, AggregateMode::AggregatedOnly)
    }
}

// Blanket impl: any compatible event stream gains `.aggregate()` for free.
impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
impl<S> AggregatedChatStream<S> {
    /// Wraps `inner` in an adapter configured with the given aggregation `mode`.
    fn new(inner: S, mode: AggregateMode) -> Self {
        Self {
            inner,
            cumulative: String::new(),
            cumulative_reasoning: String::new(),
            pending: std::collections::VecDeque::new(),
            mode,
        }
    }

    /// Convenience constructor for an adapter that forwards deltas as they
    /// arrive while still emitting a final aggregated message per turn.
    pub(crate) fn streaming_mode(inner: S) -> Self {
        Self::new(inner, AggregateMode::Streaming)
    }
}

View File

@@ -4,6 +4,8 @@ use crate::ModelProviderInfo;
use crate::client::ResponseEvent;
use crate::client::ResponseStream;
use crate::client::http::CodexHttpClient;
use crate::client::retry::RetryableStreamError;
use crate::client::retry::retry_stream;
use crate::client_common::Prompt;
use crate::error::CodexErr;
use crate::error::ConnectionFailedError;
@@ -25,9 +27,6 @@ use futures::Stream;
use futures::TryStreamExt;
use reqwest::StatusCode;
use serde_json::json;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use tokio::sync::mpsc;
use tracing::trace;
@@ -361,28 +360,15 @@ pub(crate) async fn stream_chat_completions(
}
let max_attempts = provider.request_max_retries();
let mut last_error = None;
for attempt in 0..=max_attempts {
match stream_single_chat_completion(
attempt,
client,
provider,
otel_event_manager,
body.clone(),
)
.await
{
Ok(stream) => return Ok(stream),
Err(e) => {
last_error = Some(e);
if attempt != max_attempts {
tokio::time::sleep(backoff(attempt)).await;
}
}
retry_stream(max_attempts, |attempt| {
let body = body.clone();
async move {
stream_single_chat_completion(attempt, client, provider, otel_event_manager, body)
.await
.map_err(ChatStreamError::Retryable)
}
}
Err(last_error.unwrap_or(CodexErr::InternalServerError))
})
.await
}
async fn stream_single_chat_completion(
@@ -463,6 +449,22 @@ async fn stream_single_chat_completion(
}
}
// Error wrapper used by the chat-completions retry loop. Currently every
// stream-start failure is classified as retryable with the provider backoff
// schedule; the single-variant enum leaves room for a fatal variant later.
enum ChatStreamError {
    Retryable(CodexErr),
}

impl RetryableStreamError for ChatStreamError {
    /// Always retry, waiting the backoff computed for this `attempt`.
    fn delay(&self, attempt: u64) -> Option<Duration> {
        Some(backoff(attempt))
    }

    /// Unwraps the underlying `CodexErr` to surface to the caller.
    fn into_error(self) -> CodexErr {
        match self {
            ChatStreamError::Retryable(e) => e,
        }
    }
}
async fn append_assistant_text(
tx_event: &mpsc::Sender<Result<ResponseEvent>>,
assistant_item: &mut Option<ResponseItem>,
@@ -748,224 +750,3 @@ async fn process_chat_sse<S>(
}
}
}
/// Optional client-side aggregation helper
///
/// Stream adapter that merges the incremental `OutputItemDone` chunks coming from
/// [`process_chat_sse`] into a *running* assistant message, **suppressing the
/// per-token deltas**. The stream stays silent while the model is thinking
/// and only emits two events per turn:
///
/// 1. `ResponseEvent::OutputItemDone` with the *complete* assistant message
/// (fully concatenated).
/// 2. The original `ResponseEvent::Completed` right after it.
///
/// This mirrors the behaviour the TypeScript CLI exposes to its higher layers.
///
/// The adapter is intentionally *lossless*: callers who do **not** opt in via
/// [`AggregateStreamExt::aggregate()`] keep receiving the original unmodified
/// events.
#[derive(Copy, Clone, Eq, PartialEq)]
enum AggregateMode {
AggregatedOnly,
Streaming,
}
pub(crate) struct AggregatedChatStream<S> {
inner: S,
cumulative: String,
cumulative_reasoning: String,
pending: std::collections::VecDeque<ResponseEvent>,
mode: AggregateMode,
}
impl<S> Stream for AggregatedChatStream<S>
where
S: Stream<Item = Result<ResponseEvent>> + Unpin,
{
type Item = Result<ResponseEvent>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
// First, flush any buffered events from the previous call.
if let Some(ev) = this.pending.pop_front() {
return Poll::Ready(Some(Ok(ev)));
}
loop {
match Pin::new(&mut this.inner).poll_next(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => return Poll::Ready(None),
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
// If this is an incremental assistant message chunk, accumulate but
// do NOT emit yet. Forward any other item (e.g. FunctionCall) right
// away so downstream consumers see it.
let is_assistant_message = matches!(
&item,
ResponseItem::Message { role, .. } if role == "assistant"
);
if is_assistant_message {
match this.mode {
AggregateMode::AggregatedOnly => {
// Only use the final assistant message if we have not
// seen any deltas; otherwise, deltas already built the
// cumulative text and this would duplicate it.
if this.cumulative.is_empty()
&& let ResponseItem::Message { content, .. } = &item
&& let Some(text) = content.iter().find_map(|c| match c {
ContentItem::OutputText { text } => Some(text),
_ => None,
})
{
this.cumulative.push_str(text);
}
// Swallow assistant message here; emit on Completed.
continue;
}
AggregateMode::Streaming => {
// In streaming mode, if we have not seen any deltas, forward
// the final assistant message directly. If deltas were seen,
// suppress the final message to avoid duplication.
if this.cumulative.is_empty() {
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
item,
))));
}
continue;
}
}
}
// Not an assistant message forward immediately.
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
}
Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
}
Poll::Ready(Some(Ok(ResponseEvent::Completed {
response_id,
token_usage,
}))) => {
// Build any aggregated items in the correct order: Reasoning first, then Message.
let mut emitted_any = false;
if !this.cumulative_reasoning.is_empty()
&& matches!(this.mode, AggregateMode::AggregatedOnly)
{
let aggregated_reasoning = ResponseItem::Reasoning {
id: String::new(),
summary: Vec::new(),
content: Some(vec![ReasoningItemContent::ReasoningText {
text: std::mem::take(&mut this.cumulative_reasoning),
}]),
encrypted_content: None,
};
this.pending
.push_back(ResponseEvent::OutputItemDone(aggregated_reasoning));
emitted_any = true;
}
// Always emit the final aggregated assistant message when any
// content deltas have been observed. In AggregatedOnly mode this
// is the sole assistant output; in Streaming mode this finalizes
// the streamed deltas into a terminal OutputItemDone so callers
// can persist/render the message once per turn.
if !this.cumulative.is_empty() {
let aggregated_message = ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: std::mem::take(&mut this.cumulative),
}],
};
this.pending
.push_back(ResponseEvent::OutputItemDone(aggregated_message));
emitted_any = true;
}
// Always emit Completed last when anything was aggregated.
if emitted_any {
this.pending.push_back(ResponseEvent::Completed {
response_id: response_id.clone(),
token_usage: token_usage.clone(),
});
if let Some(ev) = this.pending.pop_front() {
return Poll::Ready(Some(Ok(ev)));
}
}
// Nothing aggregated forward Completed directly.
return Poll::Ready(Some(Ok(ResponseEvent::Completed {
response_id,
token_usage,
})));
}
Poll::Ready(Some(Ok(ResponseEvent::Created))) => {
// These events are exclusive to the Responses API and
// will never appear in a Chat Completions stream.
continue;
}
Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => {
// Always accumulate deltas so we can emit a final OutputItemDone at Completed.
this.cumulative.push_str(&delta);
if matches!(this.mode, AggregateMode::Streaming) {
// In streaming mode, also forward the delta immediately.
return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta))));
}
}
Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta)))) => {
// Always accumulate reasoning deltas so we can emit a final Reasoning item at Completed.
this.cumulative_reasoning.push_str(&delta);
if matches!(this.mode, AggregateMode::Streaming) {
// In streaming mode, also forward the delta immediately.
return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta))));
}
}
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta(_)))) => {}
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded))) => {}
Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => {
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item))));
}
}
}
}
}
/// Extension trait that activates aggregation on any stream of [`ResponseEvent`].
pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized {
/// Returns a new stream that emits **only** the final assistant message
/// per turn instead of every incremental delta. The produced
/// `ResponseEvent` sequence for a typical text turn looks like:
///
/// ```ignore
/// OutputItemDone(<full message>)
/// Completed
/// ```
///
/// No other `OutputItemDone` events will be seen by the caller.
fn aggregate(self) -> AggregatedChatStream<Self> {
AggregatedChatStream::new(self, AggregateMode::AggregatedOnly)
}
}
impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
impl<S> AggregatedChatStream<S> {
fn new(inner: S, mode: AggregateMode) -> Self {
AggregatedChatStream {
inner,
cumulative: String::new(),
cumulative_reasoning: String::new(),
pending: std::collections::VecDeque::new(),
mode,
}
}
pub(crate) fn streaming_mode(inner: S) -> Self {
Self::new(inner, AggregateMode::Streaming)
}
}

View File

@@ -1,11 +1,14 @@
mod aggregation;
mod chat_completions;
pub mod http;
mod rate_limits;
mod responses;
mod retry;
mod sse;
pub mod types;
pub(crate) use chat_completions::AggregateStreamExt;
pub(crate) use chat_completions::AggregatedChatStream;
pub(crate) use aggregation::AggregateStreamExt;
pub(crate) use aggregation::AggregatedChatStream;
pub(crate) use chat_completions::stream_chat_completions;
pub use responses::ModelClient;
pub(crate) use types::FreeformTool;

View File

@@ -0,0 +1,86 @@
use crate::protocol::RateLimitSnapshot;
use crate::protocol::RateLimitWindow;
use chrono::Utc;
use reqwest::header::HeaderMap;
/// Prefer Codex-specific aggregate rate limit headers if present; fall back
/// to raw OpenAI-style request headers otherwise.
///
/// Returns `None` when neither header family yields a usable snapshot.
pub(crate) fn parse_rate_limit_snapshot(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
    parse_codex_rate_limits(headers).or_else(|| parse_openai_rate_limits(headers))
}
/// Builds a snapshot from the `x-codex-*` aggregate rate-limit headers.
/// Returns `None` when neither used-percent header is present and parseable.
fn parse_codex_rate_limits(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
    // Small helpers: read a header as UTF-8 and parse it, ignoring
    // missing or malformed values.
    let header_str = |name: &str| headers.get(name).and_then(|v| v.to_str().ok());
    let header_f64 = |name: &str| header_str(name).and_then(|s| s.parse::<f64>().ok());
    let header_i64 = |name: &str| header_str(name).and_then(|s| s.parse::<i64>().ok());

    // Each window exists only if its used-percent header is present; the
    // window-minutes and reset-at fields are optional within a window.
    let primary = header_f64("x-codex-primary-used-percent").map(|used_percent| RateLimitWindow {
        used_percent,
        window_minutes: header_i64("x-codex-primary-window-minutes"),
        resets_at: header_i64("x-codex-primary-reset-at"),
    });
    let secondary =
        header_f64("x-codex-secondary-used-percent").map(|used_percent| RateLimitWindow {
            used_percent,
            window_minutes: header_i64("x-codex-secondary-window-minutes"),
            resets_at: header_i64("x-codex-secondary-reset-at"),
        });

    if primary.is_none() && secondary.is_none() {
        return None;
    }
    Some(RateLimitSnapshot { primary, secondary })
}
/// Derives a single-window snapshot from OpenAI-style `x-ratelimit-*-requests`
/// headers. All three headers must be present and parseable; a non-positive
/// limit yields `None` since no meaningful percentage can be computed.
fn parse_openai_rate_limits(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
    let header = |name: &str| headers.get(name).and_then(|v| v.to_str().ok());
    let limit: f64 = header("x-ratelimit-limit-requests")?.parse().ok()?;
    let remaining: f64 = header("x-ratelimit-remaining-requests")?.parse().ok()?;
    // NOTE(review): the reset header is assumed to carry milliseconds — confirm
    // against the provider's actual header format.
    let reset_ms: i64 = header("x-ratelimit-reset-requests")?.parse().ok()?;

    if limit <= 0.0 {
        return None;
    }

    // Clamp at zero in case `remaining` exceeds `limit`.
    let used_percent = ((limit - remaining).max(0.0) / limit) * 100.0;

    let (window_minutes, resets_at) = if reset_ms > 0 {
        let seconds = reset_ms / 1000;
        // Round whole seconds up to the next minute for the window size.
        (
            Some((seconds + 59) / 60),
            Some(Utc::now().timestamp() + seconds),
        )
    } else {
        (None, None)
    };

    Some(RateLimitSnapshot {
        primary: Some(RateLimitWindow {
            used_percent,
            window_minutes,
            resets_at,
        }),
        secondary: None,
    })
}

View File

@@ -16,7 +16,6 @@ use eventsource_stream::Eventsource;
use futures::prelude::*;
use regex_lite::Regex;
use reqwest::StatusCode;
use reqwest::header::HeaderMap;
use serde::Deserialize;
use serde_json::Value;
use tokio::sync::mpsc;
@@ -33,6 +32,9 @@ use crate::client::ResponseEvent;
use crate::client::ResponseStream;
use crate::client::ResponsesApiRequest;
use crate::client::create_text_param_for_request;
use crate::client::rate_limits::parse_rate_limit_snapshot;
use crate::client::retry::RetryableStreamError;
use crate::client::retry::retry_stream;
use crate::client_common::Prompt;
use crate::config::Config;
use crate::default_client::CodexHttpClient;
@@ -49,7 +51,6 @@ use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;
use crate::openai_model_info::get_model_info;
use crate::protocol::RateLimitSnapshot;
use crate::protocol::RateLimitWindow;
use crate::protocol::TokenUsage;
use crate::token_data::PlanType;
use crate::tools::spec::create_tools_json_for_responses_api;
@@ -260,30 +261,10 @@ impl ModelClient {
}
let max_attempts = self.provider.request_max_retries();
for attempt in 0..=max_attempts {
match self
.attempt_stream_responses(attempt, &payload_json, &auth_manager)
.await
{
Ok(stream) => {
return Ok(stream);
}
Err(StreamAttemptError::Fatal(e)) => {
return Err(e);
}
Err(retryable_attempt_error) => {
if attempt == max_attempts {
return Err(retryable_attempt_error.into_error());
}
if let Some(delay) = retryable_attempt_error.delay(attempt) {
tokio::time::sleep(delay).await;
}
}
}
}
unreachable!("stream_responses_attempt should always return");
retry_stream(max_attempts, |attempt| {
self.attempt_stream_responses(attempt, &payload_json, &auth_manager)
})
.await
}
/// Single attempt to start a streaming Responses API call.
@@ -506,88 +487,6 @@ impl ModelClient {
}
}
fn parse_rate_limit_snapshot(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
// Prefer codex-specific aggregate rate limit headers if present; fall back
// to raw OpenAI-style request headers otherwise.
parse_codex_rate_limits(headers).or_else(|| parse_openai_rate_limits(headers))
}
fn parse_codex_rate_limits(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
fn parse_f64(headers: &HeaderMap, name: &str) -> Option<f64> {
headers
.get(name)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<f64>().ok())
}
fn parse_i64(headers: &HeaderMap, name: &str) -> Option<i64> {
headers
.get(name)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<i64>().ok())
}
let primary_used = parse_f64(headers, "x-codex-primary-used-percent");
let secondary_used = parse_f64(headers, "x-codex-secondary-used-percent");
if primary_used.is_none() && secondary_used.is_none() {
return None;
}
let primary = primary_used.map(|used_percent| RateLimitWindow {
used_percent,
window_minutes: parse_i64(headers, "x-codex-primary-window-minutes"),
resets_at: parse_i64(headers, "x-codex-primary-reset-at"),
});
let secondary = secondary_used.map(|used_percent| RateLimitWindow {
used_percent,
window_minutes: parse_i64(headers, "x-codex-secondary-window-minutes"),
resets_at: parse_i64(headers, "x-codex-secondary-reset-at"),
});
Some(RateLimitSnapshot { primary, secondary })
}
fn parse_openai_rate_limits(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
let limit = headers.get("x-ratelimit-limit-requests")?;
let remaining = headers.get("x-ratelimit-remaining-requests")?;
let reset_ms = headers.get("x-ratelimit-reset-requests")?;
let limit = limit.to_str().ok()?.parse::<f64>().ok()?;
let remaining = remaining.to_str().ok()?.parse::<f64>().ok()?;
let reset_ms = reset_ms.to_str().ok()?.parse::<i64>().ok()?;
if limit <= 0.0 {
return None;
}
let used = (limit - remaining).max(0.0);
let used_percent = (used / limit) * 100.0;
let window_minutes = if reset_ms <= 0 {
None
} else {
let seconds = reset_ms / 1000;
Some((seconds + 59) / 60)
};
let resets_at = if reset_ms > 0 {
Some(Utc::now().timestamp() + reset_ms / 1000)
} else {
None
};
Some(RateLimitSnapshot {
primary: Some(RateLimitWindow {
used_percent,
window_minutes,
resets_at,
}),
secondary: None,
})
}
/// For Azure Responses endpoints we must use `store: true` and preserve
/// per-item identifiers on the input payload. The `ResponseItem` schema
/// deliberately skips serializing these IDs by default, so we patch them
@@ -681,6 +580,16 @@ impl StreamAttemptError {
}
}
impl RetryableStreamError for StreamAttemptError {
    /// Delegates to the inherent `StreamAttemptError::delay` (inherent methods
    /// take precedence over trait methods, so this does not recurse).
    fn delay(&self, attempt: u64) -> Option<Duration> {
        self.delay(attempt)
    }

    /// Delegates to the inherent `StreamAttemptError::into_error`.
    fn into_error(self) -> CodexErr {
        self.into_error()
    }
}
async fn process_sse(
stream: impl Stream<Item = std::result::Result<Bytes, CodexErr>> + Unpin + Send + 'static,
tx_event: mpsc::Sender<Result<ResponseEvent>>,

View File

@@ -0,0 +1,140 @@
use std::time::Duration;
use crate::error::CodexErr;
use crate::error::Result;
/// Common interface for classifying stream start errors as retryable or fatal.
pub(crate) trait RetryableStreamError {
    /// Returns a delay for the next retry attempt, or `None` if the error
    /// should be treated as fatal and not retried.
    ///
    /// `attempt` is the zero-based index of the attempt that just failed.
    fn delay(&self, attempt: u64) -> Option<Duration>;

    /// Converts this error into the final `CodexErr` that should be surfaced
    /// to callers when retries are exhausted or the error is fatal.
    fn into_error(self) -> CodexErr;
}
/// Helper to retry a streaming operation with provider-configured backoff.
///
/// The caller supplies an `attempt_fn` that is invoked once per attempt with
/// the current attempt index in `[0, max_attempts]`. On success, the value is
/// returned immediately. On error, the error's [`RetryableStreamError`]
/// implementation decides whether to retry (with an optional delay) or to
/// surface a final error.
pub(crate) async fn retry_stream<F, Fut, T, E>(max_attempts: u64, mut attempt_fn: F) -> Result<T>
where
    F: FnMut(u64) -> Fut,
    Fut: std::future::Future<Output = std::result::Result<T, E>>,
    E: RetryableStreamError,
{
    for attempt in 0..=max_attempts {
        match attempt_fn(attempt).await {
            Ok(value) => return Ok(value),
            // A single match on the delay replaces the original's redundant
            // `is_none()` check followed by a second `if let Some(..)` test.
            Err(err) => match err.delay(attempt) {
                // Retryable and attempts remain: back off, then try again.
                Some(duration) if attempt < max_attempts => tokio::time::sleep(duration).await,
                // Fatal (`None`) or final attempt: surface to caller.
                _ => return Err(err.into_error()),
            },
        }
    }
    // The iteration at `attempt == max_attempts` always returns above.
    unreachable!("retry_stream should always return");
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    // Hoisted here instead of being repeated inside each test function.
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    /// Test double whose `fatal` flag drives the retry classification.
    #[derive(Clone)]
    struct TestError {
        fatal: bool,
    }

    impl RetryableStreamError for TestError {
        fn delay(&self, attempt: u64) -> Option<Duration> {
            if self.fatal {
                None
            } else {
                Some(Duration::from_millis(attempt * 10))
            }
        }

        fn into_error(self) -> CodexErr {
            if self.fatal {
                CodexErr::InternalServerError
            } else {
                // `io::Error::other` is the concise stable replacement for
                // `io::Error::new(io::ErrorKind::Other, ..)`.
                CodexErr::Io(std::io::Error::other("retryable"))
            }
        }
    }

    #[tokio::test]
    async fn retries_until_success_before_max_attempts() {
        let max_attempts = 3;
        let result: Result<&str> = retry_stream(max_attempts, |attempt| async move {
            if attempt < 2 {
                Err(TestError { fatal: false })
            } else {
                Ok("ok")
            }
        })
        .await;
        assert_eq!(result.unwrap(), "ok");
    }

    #[tokio::test]
    async fn stops_on_fatal_error_without_retrying() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_ref = calls.clone();
        let result: Result<()> = retry_stream(5, move |_attempt| {
            let calls_ref = calls_ref.clone();
            async move {
                calls_ref.fetch_add(1, Ordering::SeqCst);
                Err(TestError { fatal: true })
            }
        })
        .await;
        assert!(result.is_err());
        // A fatal error must short-circuit after exactly one invocation.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn stops_after_max_attempts_for_retryable_errors() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_ref = calls.clone();
        let max_attempts: u64 = 2;
        let result: Result<()> = retry_stream(max_attempts, move |_attempt| {
            let calls_ref = calls_ref.clone();
            async move {
                calls_ref.fetch_add(1, Ordering::SeqCst);
                Err(TestError { fatal: false })
            }
        })
        .await;
        assert!(result.is_err());
        // `0..=max_attempts` yields max_attempts + 1 invocations; use a
        // checked conversion instead of a silently-truncating `as` cast.
        let expected = usize::try_from(max_attempts + 1).unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), expected);
    }
}