diff --git a/codex-rs/core/src/client/aggregation.rs b/codex-rs/core/src/client/aggregation.rs new file mode 100644 index 0000000000..42d6c99270 --- /dev/null +++ b/codex-rs/core/src/client/aggregation.rs @@ -0,0 +1,229 @@ +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; + +use crate::client::ResponseEvent; +use crate::error::Result; +use codex_protocol::models::ContentItem; +use codex_protocol::models::ReasoningItemContent; +use codex_protocol::models::ResponseItem; +use futures::Stream; + +/// Optional client-side aggregation helper +/// +/// Stream adapter that merges the incremental `OutputItemDone` chunks coming from +/// the chat SSE decoder into a *running* assistant message, **suppressing the +/// per-token deltas**. The stream stays silent while the model is thinking and +/// only emits two events per turn: +/// +/// 1. `ResponseEvent::OutputItemDone` with the *complete* assistant message +/// (fully concatenated). +/// 2. The original `ResponseEvent::Completed` right after it. +/// +/// The adapter is intentionally *lossless*: callers who do **not** opt in via +/// [`AggregateStreamExt::aggregate()`] keep receiving the original unmodified +/// events. +#[derive(Copy, Clone, Eq, PartialEq)] +enum AggregateMode { + AggregatedOnly, + Streaming, +} + +pub(crate) struct AggregatedChatStream { + inner: S, + cumulative: String, + cumulative_reasoning: String, + pending: std::collections::VecDeque, + mode: AggregateMode, +} + +impl Stream for AggregatedChatStream +where + S: Stream> + Unpin, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + // First, flush any buffered events from the previous call. + if let Some(ev) = this.pending.pop_front() { + return Poll::Ready(Some(Ok(ev))); + } + + loop { + match Pin::new(&mut this.inner).poll_next(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => return Poll::Ready(None), + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => { + // If this is an incremental assistant message chunk, accumulate but + // do NOT emit yet. Forward any other item (e.g. FunctionCall) right + // away so downstream consumers see it. + + let is_assistant_message = matches!( + &item, + ResponseItem::Message { role, .. } if role == "assistant" + ); + + if is_assistant_message { + match this.mode { + AggregateMode::AggregatedOnly => { + // Only use the final assistant message if we have not + // seen any deltas; otherwise, deltas already built the + // cumulative text and this would duplicate it. + if this.cumulative.is_empty() + && let ResponseItem::Message { content, .. } = &item + && let Some(text) = content.iter().find_map(|c| match c { + ContentItem::OutputText { text } => Some(text), + _ => None, + }) + { + this.cumulative.push_str(text); + } + // Swallow assistant message here; emit on Completed. + continue; + } + AggregateMode::Streaming => { + // In streaming mode, if we have not seen any deltas, forward + // the final assistant message directly. If deltas were seen, + // suppress the final message to avoid duplication. + if this.cumulative.is_empty() { + return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone( + item, + )))); + } + continue; + } + } + } + + // Not an assistant message – forward immediately. + return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))); + } + Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => { + return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))); + } + Poll::Ready(Some(Ok(ResponseEvent::Completed { + response_id, + token_usage, + }))) => { + // Build any aggregated items in the correct order: Reasoning first, then Message. + let mut emitted_any = false; + + if !this.cumulative_reasoning.is_empty() + && matches!(this.mode, AggregateMode::AggregatedOnly) + { + let aggregated_reasoning = ResponseItem::Reasoning { + id: String::new(), + summary: Vec::new(), + content: Some(vec![ReasoningItemContent::ReasoningText { + text: std::mem::take(&mut this.cumulative_reasoning), + }]), + encrypted_content: None, + }; + this.pending + .push_back(ResponseEvent::OutputItemDone(aggregated_reasoning)); + emitted_any = true; + } + + // Always emit the final aggregated assistant message when any + // content deltas have been observed. In AggregatedOnly mode this + // is the sole assistant output; in Streaming mode this finalizes + // the streamed deltas into a terminal OutputItemDone so callers + // can persist/render the message once per turn. + if !this.cumulative.is_empty() { + let aggregated_message = ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: std::mem::take(&mut this.cumulative), + }], + }; + this.pending + .push_back(ResponseEvent::OutputItemDone(aggregated_message)); + emitted_any = true; + } + + // Always emit Completed last when anything was aggregated. + if emitted_any { + this.pending.push_back(ResponseEvent::Completed { + response_id: response_id.clone(), + token_usage: token_usage.clone(), + }); + if let Some(ev) = this.pending.pop_front() { + return Poll::Ready(Some(Ok(ev))); + } + } + + // Nothing aggregated – forward Completed directly. + return Poll::Ready(Some(Ok(ResponseEvent::Completed { + response_id, + token_usage, + }))); + } + Poll::Ready(Some(Ok(ResponseEvent::Created))) => { + // These events are exclusive to the Responses API and + // will never appear in a Chat Completions stream. + continue; + } + Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => { + // Always accumulate deltas so we can emit a final OutputItemDone at Completed. + this.cumulative.push_str(&delta); + if matches!(this.mode, AggregateMode::Streaming) { + // In streaming mode, also forward the delta immediately. + return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))); + } + } + Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta)))) => { + // Always accumulate reasoning deltas so we can emit a final Reasoning item at Completed. + this.cumulative_reasoning.push_str(&delta); + if matches!(this.mode, AggregateMode::Streaming) { + // In streaming mode, also forward the delta immediately. + return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta)))); + } + } + Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta(_)))) => {} + Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded))) => {} + Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => { + return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))); + } + } + } + } +} + +/// Extension trait that activates aggregation on any stream of [`ResponseEvent`]. +pub(crate) trait AggregateStreamExt: Stream> + Sized { + /// Returns a new stream that emits **only** the final assistant message + /// per turn instead of every incremental delta. The produced + /// `ResponseEvent` sequence for a typical text turn looks like: + /// + /// ```ignore + /// OutputItemDone() + /// Completed + /// ``` + /// + /// No other `OutputItemDone` events will be seen by the caller. + fn aggregate(self) -> AggregatedChatStream { + AggregatedChatStream::new(self, AggregateMode::AggregatedOnly) + } +} + +impl AggregateStreamExt for T where T: Stream> + Sized {} + +impl AggregatedChatStream { + fn new(inner: S, mode: AggregateMode) -> Self { + AggregatedChatStream { + inner, + cumulative: String::new(), + cumulative_reasoning: String::new(), + pending: std::collections::VecDeque::new(), + mode, + } + } + + pub(crate) fn streaming_mode(inner: S) -> Self { + Self::new(inner, AggregateMode::Streaming) + } +} diff --git a/codex-rs/core/src/client/chat_completions.rs b/codex-rs/core/src/client/chat_completions.rs index b43498cd21..fc78d77a49 100644 --- a/codex-rs/core/src/client/chat_completions.rs +++ b/codex-rs/core/src/client/chat_completions.rs @@ -4,6 +4,8 @@ use crate::ModelProviderInfo; use crate::client::ResponseEvent; use crate::client::ResponseStream; use crate::client::http::CodexHttpClient; +use crate::client::retry::RetryableStreamError; +use crate::client::retry::retry_stream; use crate::client_common::Prompt; use crate::error::CodexErr; use crate::error::ConnectionFailedError; @@ -25,9 +27,6 @@ use futures::Stream; use futures::TryStreamExt; use reqwest::StatusCode; use serde_json::json; -use std::pin::Pin; -use std::task::Context; -use std::task::Poll; use tokio::sync::mpsc; use tracing::trace; @@ -361,28 +360,15 @@ pub(crate) async fn stream_chat_completions( } let max_attempts = provider.request_max_retries(); - let mut last_error = None; - for attempt in 0..=max_attempts { - match stream_single_chat_completion( - attempt, - client, - provider, - otel_event_manager, - body.clone(), - ) - .await - { - Ok(stream) => return Ok(stream), - Err(e) => { - last_error = Some(e); - if attempt != max_attempts { - tokio::time::sleep(backoff(attempt)).await; - } - } + retry_stream(max_attempts, |attempt| { + let body = body.clone(); + async move { + stream_single_chat_completion(attempt, client, provider, otel_event_manager, body) + .await + .map_err(ChatStreamError::Retryable) } - } - - Err(last_error.unwrap_or(CodexErr::InternalServerError)) + }) + .await } async fn stream_single_chat_completion( @@ -463,6 +449,22 @@ async fn stream_single_chat_completion( } } +enum ChatStreamError { + Retryable(CodexErr), +} + +impl RetryableStreamError for ChatStreamError { + fn delay(&self, attempt: u64) -> Option { + Some(backoff(attempt)) + } + + fn into_error(self) -> CodexErr { + match self { + ChatStreamError::Retryable(e) => e, + } + } +} + async fn append_assistant_text( tx_event: &mpsc::Sender>, assistant_item: &mut Option, @@ -748,224 +750,3 @@ async fn process_chat_sse( } } } - -/// Optional client-side aggregation helper -/// -/// Stream adapter that merges the incremental `OutputItemDone` chunks coming from -/// [`process_chat_sse`] into a *running* assistant message, **suppressing the -/// per-token deltas**. The stream stays silent while the model is thinking -/// and only emits two events per turn: -/// -/// 1. `ResponseEvent::OutputItemDone` with the *complete* assistant message -/// (fully concatenated). -/// 2. The original `ResponseEvent::Completed` right after it. -/// -/// This mirrors the behaviour the TypeScript CLI exposes to its higher layers. -/// -/// The adapter is intentionally *lossless*: callers who do **not** opt in via -/// [`AggregateStreamExt::aggregate()`] keep receiving the original unmodified -/// events. -#[derive(Copy, Clone, Eq, PartialEq)] -enum AggregateMode { - AggregatedOnly, - Streaming, -} - -pub(crate) struct AggregatedChatStream { - inner: S, - cumulative: String, - cumulative_reasoning: String, - pending: std::collections::VecDeque, - mode: AggregateMode, -} - -impl Stream for AggregatedChatStream -where - S: Stream> + Unpin, -{ - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - - // First, flush any buffered events from the previous call. - if let Some(ev) = this.pending.pop_front() { - return Poll::Ready(Some(Ok(ev))); - } - - loop { - match Pin::new(&mut this.inner).poll_next(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(None) => return Poll::Ready(None), - Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), - Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => { - // If this is an incremental assistant message chunk, accumulate but - // do NOT emit yet. Forward any other item (e.g. FunctionCall) right - // away so downstream consumers see it. - - let is_assistant_message = matches!( - &item, - ResponseItem::Message { role, .. } if role == "assistant" - ); - - if is_assistant_message { - match this.mode { - AggregateMode::AggregatedOnly => { - // Only use the final assistant message if we have not - // seen any deltas; otherwise, deltas already built the - // cumulative text and this would duplicate it. - if this.cumulative.is_empty() - && let ResponseItem::Message { content, .. } = &item - && let Some(text) = content.iter().find_map(|c| match c { - ContentItem::OutputText { text } => Some(text), - _ => None, - }) - { - this.cumulative.push_str(text); - } - // Swallow assistant message here; emit on Completed. - continue; - } - AggregateMode::Streaming => { - // In streaming mode, if we have not seen any deltas, forward - // the final assistant message directly. If deltas were seen, - // suppress the final message to avoid duplication. - if this.cumulative.is_empty() { - return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone( - item, - )))); - } - continue; - } - } - } - - // Not an assistant message – forward immediately. - return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))); - } - Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => { - return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))); - } - Poll::Ready(Some(Ok(ResponseEvent::Completed { - response_id, - token_usage, - }))) => { - // Build any aggregated items in the correct order: Reasoning first, then Message. - let mut emitted_any = false; - - if !this.cumulative_reasoning.is_empty() - && matches!(this.mode, AggregateMode::AggregatedOnly) - { - let aggregated_reasoning = ResponseItem::Reasoning { - id: String::new(), - summary: Vec::new(), - content: Some(vec![ReasoningItemContent::ReasoningText { - text: std::mem::take(&mut this.cumulative_reasoning), - }]), - encrypted_content: None, - }; - this.pending - .push_back(ResponseEvent::OutputItemDone(aggregated_reasoning)); - emitted_any = true; - } - - // Always emit the final aggregated assistant message when any - // content deltas have been observed. In AggregatedOnly mode this - // is the sole assistant output; in Streaming mode this finalizes - // the streamed deltas into a terminal OutputItemDone so callers - // can persist/render the message once per turn. - if !this.cumulative.is_empty() { - let aggregated_message = ResponseItem::Message { - id: None, - role: "assistant".to_string(), - content: vec![ContentItem::OutputText { - text: std::mem::take(&mut this.cumulative), - }], - }; - this.pending - .push_back(ResponseEvent::OutputItemDone(aggregated_message)); - emitted_any = true; - } - - // Always emit Completed last when anything was aggregated. - if emitted_any { - this.pending.push_back(ResponseEvent::Completed { - response_id: response_id.clone(), - token_usage: token_usage.clone(), - }); - if let Some(ev) = this.pending.pop_front() { - return Poll::Ready(Some(Ok(ev))); - } - } - - // Nothing aggregated – forward Completed directly. - return Poll::Ready(Some(Ok(ResponseEvent::Completed { - response_id, - token_usage, - }))); - } - Poll::Ready(Some(Ok(ResponseEvent::Created))) => { - // These events are exclusive to the Responses API and - // will never appear in a Chat Completions stream. - continue; - } - Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => { - // Always accumulate deltas so we can emit a final OutputItemDone at Completed. - this.cumulative.push_str(&delta); - if matches!(this.mode, AggregateMode::Streaming) { - // In streaming mode, also forward the delta immediately. - return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))); - } - } - Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta)))) => { - // Always accumulate reasoning deltas so we can emit a final Reasoning item at Completed. - this.cumulative_reasoning.push_str(&delta); - if matches!(this.mode, AggregateMode::Streaming) { - // In streaming mode, also forward the delta immediately. - return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta)))); - } - } - Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta(_)))) => {} - Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded))) => {} - Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => { - return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))); - } - } - } - } -} - -/// Extension trait that activates aggregation on any stream of [`ResponseEvent`]. -pub(crate) trait AggregateStreamExt: Stream> + Sized { - /// Returns a new stream that emits **only** the final assistant message - /// per turn instead of every incremental delta. The produced - /// `ResponseEvent` sequence for a typical text turn looks like: - /// - /// ```ignore - /// OutputItemDone() - /// Completed - /// ``` - /// - /// No other `OutputItemDone` events will be seen by the caller. - fn aggregate(self) -> AggregatedChatStream { - AggregatedChatStream::new(self, AggregateMode::AggregatedOnly) - } -} - -impl AggregateStreamExt for T where T: Stream> + Sized {} - -impl AggregatedChatStream { - fn new(inner: S, mode: AggregateMode) -> Self { - AggregatedChatStream { - inner, - cumulative: String::new(), - cumulative_reasoning: String::new(), - pending: std::collections::VecDeque::new(), - mode, - } - } - - pub(crate) fn streaming_mode(inner: S) -> Self { - Self::new(inner, AggregateMode::Streaming) - } -} diff --git a/codex-rs/core/src/client/mod.rs b/codex-rs/core/src/client/mod.rs index 2e901abfbb..2a56246644 100644 --- a/codex-rs/core/src/client/mod.rs +++ b/codex-rs/core/src/client/mod.rs @@ -1,11 +1,14 @@ +mod aggregation; mod chat_completions; pub mod http; +mod rate_limits; mod responses; +mod retry; mod sse; pub mod types; -pub(crate) use chat_completions::AggregateStreamExt; -pub(crate) use chat_completions::AggregatedChatStream; +pub(crate) use aggregation::AggregateStreamExt; +pub(crate) use aggregation::AggregatedChatStream; pub(crate) use chat_completions::stream_chat_completions; pub use responses::ModelClient; pub(crate) use types::FreeformTool; diff --git a/codex-rs/core/src/client/rate_limits.rs b/codex-rs/core/src/client/rate_limits.rs new file mode 100644 index 0000000000..a2b2b0d2fc --- /dev/null +++ b/codex-rs/core/src/client/rate_limits.rs @@ -0,0 +1,86 @@ +use crate::protocol::RateLimitSnapshot; +use crate::protocol::RateLimitWindow; +use chrono::Utc; +use reqwest::header::HeaderMap; + +/// Prefer Codex-specific aggregate rate limit headers if present; fall back +/// to raw OpenAI-style request headers otherwise. +pub(crate) fn parse_rate_limit_snapshot(headers: &HeaderMap) -> Option { + parse_codex_rate_limits(headers).or_else(|| parse_openai_rate_limits(headers)) +} + +fn parse_codex_rate_limits(headers: &HeaderMap) -> Option { + fn parse_f64(headers: &HeaderMap, name: &str) -> Option { + headers + .get(name) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()) + } + + fn parse_i64(headers: &HeaderMap, name: &str) -> Option { + headers + .get(name) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()) + } + + let primary_used = parse_f64(headers, "x-codex-primary-used-percent"); + let secondary_used = parse_f64(headers, "x-codex-secondary-used-percent"); + + if primary_used.is_none() && secondary_used.is_none() { + return None; + } + + let primary = primary_used.map(|used_percent| RateLimitWindow { + used_percent, + window_minutes: parse_i64(headers, "x-codex-primary-window-minutes"), + resets_at: parse_i64(headers, "x-codex-primary-reset-at"), + }); + + let secondary = secondary_used.map(|used_percent| RateLimitWindow { + used_percent, + window_minutes: parse_i64(headers, "x-codex-secondary-window-minutes"), + resets_at: parse_i64(headers, "x-codex-secondary-reset-at"), + }); + + Some(RateLimitSnapshot { primary, secondary }) +} + +fn parse_openai_rate_limits(headers: &HeaderMap) -> Option { + let limit = headers.get("x-ratelimit-limit-requests")?; + let remaining = headers.get("x-ratelimit-remaining-requests")?; + let reset_ms = headers.get("x-ratelimit-reset-requests")?; + + let limit = limit.to_str().ok()?.parse::().ok()?; + let remaining = remaining.to_str().ok()?.parse::().ok()?; + let reset_ms = reset_ms.to_str().ok()?.parse::().ok()?; + + if limit <= 0.0 { + return None; + } + + let used = (limit - remaining).max(0.0); + let used_percent = (used / limit) * 100.0; + + let window_minutes = if reset_ms <= 0 { + None + } else { + let seconds = reset_ms / 1000; + Some((seconds + 59) / 60) + }; + + let resets_at = if reset_ms > 0 { + Some(Utc::now().timestamp() + reset_ms / 1000) + } else { + None + }; + + Some(RateLimitSnapshot { + primary: Some(RateLimitWindow { + used_percent, + window_minutes, + resets_at, + }), + secondary: None, + }) +} diff --git a/codex-rs/core/src/client/responses.rs b/codex-rs/core/src/client/responses.rs index a9090708f9..1cb65fc332 100644 --- a/codex-rs/core/src/client/responses.rs +++ b/codex-rs/core/src/client/responses.rs @@ -16,7 +16,6 @@ use eventsource_stream::Eventsource; use futures::prelude::*; use regex_lite::Regex; use reqwest::StatusCode; -use reqwest::header::HeaderMap; use serde::Deserialize; use serde_json::Value; use tokio::sync::mpsc; @@ -33,6 +32,9 @@ use crate::client::ResponseEvent; use crate::client::ResponseStream; use crate::client::ResponsesApiRequest; use crate::client::create_text_param_for_request; +use crate::client::rate_limits::parse_rate_limit_snapshot; +use crate::client::retry::RetryableStreamError; +use crate::client::retry::retry_stream; use crate::client_common::Prompt; use crate::config::Config; use crate::default_client::CodexHttpClient; @@ -49,7 +51,6 @@ use crate::model_provider_info::ModelProviderInfo; use crate::model_provider_info::WireApi; use crate::openai_model_info::get_model_info; use crate::protocol::RateLimitSnapshot; -use crate::protocol::RateLimitWindow; use crate::protocol::TokenUsage; use crate::token_data::PlanType; use crate::tools::spec::create_tools_json_for_responses_api; @@ -260,30 +261,10 @@ impl ModelClient { } let max_attempts = self.provider.request_max_retries(); - for attempt in 0..=max_attempts { - match self - .attempt_stream_responses(attempt, &payload_json, &auth_manager) - .await - { - Ok(stream) => { - return Ok(stream); - } - Err(StreamAttemptError::Fatal(e)) => { - return Err(e); - } - Err(retryable_attempt_error) => { - if attempt == max_attempts { - return Err(retryable_attempt_error.into_error()); - } - - if let Some(delay) = retryable_attempt_error.delay(attempt) { - tokio::time::sleep(delay).await; - } - } - } - } - - unreachable!("stream_responses_attempt should always return"); + retry_stream(max_attempts, |attempt| { + self.attempt_stream_responses(attempt, &payload_json, &auth_manager) + }) + .await } /// Single attempt to start a streaming Responses API call. @@ -506,88 +487,6 @@ impl ModelClient { } } -fn parse_rate_limit_snapshot(headers: &HeaderMap) -> Option { - // Prefer codex-specific aggregate rate limit headers if present; fall back - // to raw OpenAI-style request headers otherwise. - parse_codex_rate_limits(headers).or_else(|| parse_openai_rate_limits(headers)) -} - -fn parse_codex_rate_limits(headers: &HeaderMap) -> Option { - fn parse_f64(headers: &HeaderMap, name: &str) -> Option { - headers - .get(name) - .and_then(|v| v.to_str().ok()) - .and_then(|s| s.parse::().ok()) - } - - fn parse_i64(headers: &HeaderMap, name: &str) -> Option { - headers - .get(name) - .and_then(|v| v.to_str().ok()) - .and_then(|s| s.parse::().ok()) - } - - let primary_used = parse_f64(headers, "x-codex-primary-used-percent"); - let secondary_used = parse_f64(headers, "x-codex-secondary-used-percent"); - - if primary_used.is_none() && secondary_used.is_none() { - return None; - } - - let primary = primary_used.map(|used_percent| RateLimitWindow { - used_percent, - window_minutes: parse_i64(headers, "x-codex-primary-window-minutes"), - resets_at: parse_i64(headers, "x-codex-primary-reset-at"), - }); - - let secondary = secondary_used.map(|used_percent| RateLimitWindow { - used_percent, - window_minutes: parse_i64(headers, "x-codex-secondary-window-minutes"), - resets_at: parse_i64(headers, "x-codex-secondary-reset-at"), - }); - - Some(RateLimitSnapshot { primary, secondary }) -} - -fn parse_openai_rate_limits(headers: &HeaderMap) -> Option { - let limit = headers.get("x-ratelimit-limit-requests")?; - let remaining = headers.get("x-ratelimit-remaining-requests")?; - let reset_ms = headers.get("x-ratelimit-reset-requests")?; - - let limit = limit.to_str().ok()?.parse::().ok()?; - let remaining = remaining.to_str().ok()?.parse::().ok()?; - let reset_ms = reset_ms.to_str().ok()?.parse::().ok()?; - - if limit <= 0.0 { - return None; - } - - let used = (limit - remaining).max(0.0); - let used_percent = (used / limit) * 100.0; - - let window_minutes = if reset_ms <= 0 { - None - } else { - let seconds = reset_ms / 1000; - Some((seconds + 59) / 60) - }; - - let resets_at = if reset_ms > 0 { - Some(Utc::now().timestamp() + reset_ms / 1000) - } else { - None - }; - - Some(RateLimitSnapshot { - primary: Some(RateLimitWindow { - used_percent, - window_minutes, - resets_at, - }), - secondary: None, - }) -} - /// For Azure Responses endpoints we must use `store: true` and preserve /// per-item identifiers on the input payload. The `ResponseItem` schema /// deliberately skips serializing these IDs by default, so we patch them @@ -681,6 +580,16 @@ impl StreamAttemptError { } } +impl RetryableStreamError for StreamAttemptError { + fn delay(&self, attempt: u64) -> Option { + self.delay(attempt) + } + + fn into_error(self) -> CodexErr { + self.into_error() + } +} + async fn process_sse( stream: impl Stream> + Unpin + Send + 'static, tx_event: mpsc::Sender>, diff --git a/codex-rs/core/src/client/retry.rs b/codex-rs/core/src/client/retry.rs new file mode 100644 index 0000000000..e57b89e7fb --- /dev/null +++ b/codex-rs/core/src/client/retry.rs @@ -0,0 +1,140 @@ +use std::time::Duration; + +use crate::error::CodexErr; +use crate::error::Result; + +/// Common interface for classifying stream start errors as retryable or fatal. +pub(crate) trait RetryableStreamError { + /// Returns a delay for the next retry attempt, or `None` if the error + /// should be treated as fatal and not retried. + fn delay(&self, attempt: u64) -> Option; + + /// Converts this error into the final `CodexErr` that should be surfaced + /// to callers when retries are exhausted or the error is fatal. + fn into_error(self) -> CodexErr; +} + +/// Helper to retry a streaming operation with provider-configured backoff. +/// +/// The caller supplies an `attempt_fn` that is invoked once per attempt with +/// the current attempt index in `[0, max_attempts]`. On success, the value is +/// returned immediately. On error, the error's [`RetryableStreamError`] +/// implementation decides whether to retry (with an optional delay) or to +/// surface a final error. +pub(crate) async fn retry_stream(max_attempts: u64, mut attempt_fn: F) -> Result +where + F: FnMut(u64) -> Fut, + Fut: std::future::Future>, + E: RetryableStreamError, +{ + for attempt in 0..=max_attempts { + match attempt_fn(attempt).await { + Ok(value) => return Ok(value), + Err(err) => { + let delay = err.delay(attempt); + + // Fatal error or final attempt: surface to caller. + if attempt == max_attempts || delay.is_none() { + return Err(err.into_error()); + } + + if let Some(duration) = delay { + tokio::time::sleep(duration).await; + } + } + } + } + + unreachable!("retry_stream should always return"); +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[derive(Clone)] + struct TestError { + fatal: bool, + } + + impl RetryableStreamError for TestError { + fn delay(&self, attempt: u64) -> Option { + if self.fatal { + None + } else { + Some(Duration::from_millis(attempt * 10)) + } + } + + fn into_error(self) -> CodexErr { + if self.fatal { + CodexErr::InternalServerError + } else { + CodexErr::Io(std::io::Error::new(std::io::ErrorKind::Other, "retryable")) + } + } + } + + #[tokio::test] + async fn retries_until_success_before_max_attempts() { + let max_attempts = 3; + + let result: Result<&str> = retry_stream(max_attempts, |attempt| async move { + if attempt < 2 { + Err(TestError { fatal: false }) + } else { + Ok("ok") + } + }) + .await; + + assert_eq!(result.unwrap(), "ok"); + } + + #[tokio::test] + async fn stops_on_fatal_error_without_retrying() { + use std::sync::Arc; + use std::sync::atomic::AtomicUsize; + use std::sync::atomic::Ordering; + + let calls = Arc::new(AtomicUsize::new(0)); + let calls_ref = calls.clone(); + + let result: Result<()> = retry_stream(5, move |_attempt| { + let calls_ref = calls_ref.clone(); + async move { + calls_ref.fetch_add(1, Ordering::SeqCst); + Err(TestError { fatal: true }) + } + }) + .await; + + assert!(result.is_err()); + assert_eq!(calls.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn stops_after_max_attempts_for_retryable_errors() { + use std::sync::Arc; + use std::sync::atomic::AtomicUsize; + use std::sync::atomic::Ordering; + + let calls = Arc::new(AtomicUsize::new(0)); + let calls_ref = calls.clone(); + + let max_attempts = 2; + + let result: Result<()> = retry_stream(max_attempts, move |_attempt| { + let calls_ref = calls_ref.clone(); + async move { + calls_ref.fetch_add(1, Ordering::SeqCst); + Err(TestError { fatal: false }) + } + }) + .await; + + assert!(result.is_err()); + assert_eq!(calls.load(Ordering::SeqCst), (max_attempts + 1) as usize); + } +}