mirror of
https://github.com/openai/codex.git
synced 2026-04-24 06:35:50 +00:00
Clean-up 3
This commit is contained in:
229
codex-rs/core/src/client/aggregation.rs
Normal file
229
codex-rs/core/src/client/aggregation.rs
Normal file
@@ -0,0 +1,229 @@
|
||||
use std::pin::Pin;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
|
||||
use crate::client::ResponseEvent;
|
||||
use crate::error::Result;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ReasoningItemContent;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use futures::Stream;
|
||||
|
||||
/// Optional client-side aggregation helper
|
||||
///
|
||||
/// Stream adapter that merges the incremental `OutputItemDone` chunks coming from
|
||||
/// the chat SSE decoder into a *running* assistant message, **suppressing the
|
||||
/// per-token deltas**. The stream stays silent while the model is thinking and
|
||||
/// only emits two events per turn:
|
||||
///
|
||||
/// 1. `ResponseEvent::OutputItemDone` with the *complete* assistant message
|
||||
/// (fully concatenated).
|
||||
/// 2. The original `ResponseEvent::Completed` right after it.
|
||||
///
|
||||
/// The adapter is intentionally *lossless*: callers who do **not** opt in via
|
||||
/// [`AggregateStreamExt::aggregate()`] keep receiving the original unmodified
|
||||
/// events.
|
||||
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
enum AggregateMode {
|
||||
AggregatedOnly,
|
||||
Streaming,
|
||||
}
|
||||
|
||||
pub(crate) struct AggregatedChatStream<S> {
|
||||
inner: S,
|
||||
cumulative: String,
|
||||
cumulative_reasoning: String,
|
||||
pending: std::collections::VecDeque<ResponseEvent>,
|
||||
mode: AggregateMode,
|
||||
}
|
||||
|
||||
impl<S> Stream for AggregatedChatStream<S>
|
||||
where
|
||||
S: Stream<Item = Result<ResponseEvent>> + Unpin,
|
||||
{
|
||||
type Item = Result<ResponseEvent>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
// First, flush any buffered events from the previous call.
|
||||
if let Some(ev) = this.pending.pop_front() {
|
||||
return Poll::Ready(Some(Ok(ev)));
|
||||
}
|
||||
|
||||
loop {
|
||||
match Pin::new(&mut this.inner).poll_next(cx) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(None) => return Poll::Ready(None),
|
||||
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
|
||||
Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
|
||||
// If this is an incremental assistant message chunk, accumulate but
|
||||
// do NOT emit yet. Forward any other item (e.g. FunctionCall) right
|
||||
// away so downstream consumers see it.
|
||||
|
||||
let is_assistant_message = matches!(
|
||||
&item,
|
||||
ResponseItem::Message { role, .. } if role == "assistant"
|
||||
);
|
||||
|
||||
if is_assistant_message {
|
||||
match this.mode {
|
||||
AggregateMode::AggregatedOnly => {
|
||||
// Only use the final assistant message if we have not
|
||||
// seen any deltas; otherwise, deltas already built the
|
||||
// cumulative text and this would duplicate it.
|
||||
if this.cumulative.is_empty()
|
||||
&& let ResponseItem::Message { content, .. } = &item
|
||||
&& let Some(text) = content.iter().find_map(|c| match c {
|
||||
ContentItem::OutputText { text } => Some(text),
|
||||
_ => None,
|
||||
})
|
||||
{
|
||||
this.cumulative.push_str(text);
|
||||
}
|
||||
// Swallow assistant message here; emit on Completed.
|
||||
continue;
|
||||
}
|
||||
AggregateMode::Streaming => {
|
||||
// In streaming mode, if we have not seen any deltas, forward
|
||||
// the final assistant message directly. If deltas were seen,
|
||||
// suppress the final message to avoid duplication.
|
||||
if this.cumulative.is_empty() {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
|
||||
item,
|
||||
))));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Not an assistant message – forward immediately.
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
}))) => {
|
||||
// Build any aggregated items in the correct order: Reasoning first, then Message.
|
||||
let mut emitted_any = false;
|
||||
|
||||
if !this.cumulative_reasoning.is_empty()
|
||||
&& matches!(this.mode, AggregateMode::AggregatedOnly)
|
||||
{
|
||||
let aggregated_reasoning = ResponseItem::Reasoning {
|
||||
id: String::new(),
|
||||
summary: Vec::new(),
|
||||
content: Some(vec![ReasoningItemContent::ReasoningText {
|
||||
text: std::mem::take(&mut this.cumulative_reasoning),
|
||||
}]),
|
||||
encrypted_content: None,
|
||||
};
|
||||
this.pending
|
||||
.push_back(ResponseEvent::OutputItemDone(aggregated_reasoning));
|
||||
emitted_any = true;
|
||||
}
|
||||
|
||||
// Always emit the final aggregated assistant message when any
|
||||
// content deltas have been observed. In AggregatedOnly mode this
|
||||
// is the sole assistant output; in Streaming mode this finalizes
|
||||
// the streamed deltas into a terminal OutputItemDone so callers
|
||||
// can persist/render the message once per turn.
|
||||
if !this.cumulative.is_empty() {
|
||||
let aggregated_message = ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: std::mem::take(&mut this.cumulative),
|
||||
}],
|
||||
};
|
||||
this.pending
|
||||
.push_back(ResponseEvent::OutputItemDone(aggregated_message));
|
||||
emitted_any = true;
|
||||
}
|
||||
|
||||
// Always emit Completed last when anything was aggregated.
|
||||
if emitted_any {
|
||||
this.pending.push_back(ResponseEvent::Completed {
|
||||
response_id: response_id.clone(),
|
||||
token_usage: token_usage.clone(),
|
||||
});
|
||||
if let Some(ev) = this.pending.pop_front() {
|
||||
return Poll::Ready(Some(Ok(ev)));
|
||||
}
|
||||
}
|
||||
|
||||
// Nothing aggregated – forward Completed directly.
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
})));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::Created))) => {
|
||||
// These events are exclusive to the Responses API and
|
||||
// will never appear in a Chat Completions stream.
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => {
|
||||
// Always accumulate deltas so we can emit a final OutputItemDone at Completed.
|
||||
this.cumulative.push_str(&delta);
|
||||
if matches!(this.mode, AggregateMode::Streaming) {
|
||||
// In streaming mode, also forward the delta immediately.
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta))));
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta)))) => {
|
||||
// Always accumulate reasoning deltas so we can emit a final Reasoning item at Completed.
|
||||
this.cumulative_reasoning.push_str(&delta);
|
||||
if matches!(this.mode, AggregateMode::Streaming) {
|
||||
// In streaming mode, also forward the delta immediately.
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta))));
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta(_)))) => {}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded))) => {}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item))));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extension trait that activates aggregation on any stream of [`ResponseEvent`].
|
||||
pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized {
|
||||
/// Returns a new stream that emits **only** the final assistant message
|
||||
/// per turn instead of every incremental delta. The produced
|
||||
/// `ResponseEvent` sequence for a typical text turn looks like:
|
||||
///
|
||||
/// ```ignore
|
||||
/// OutputItemDone(<full message>)
|
||||
/// Completed
|
||||
/// ```
|
||||
///
|
||||
/// No other `OutputItemDone` events will be seen by the caller.
|
||||
fn aggregate(self) -> AggregatedChatStream<Self> {
|
||||
AggregatedChatStream::new(self, AggregateMode::AggregatedOnly)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
|
||||
|
||||
impl<S> AggregatedChatStream<S> {
|
||||
fn new(inner: S, mode: AggregateMode) -> Self {
|
||||
AggregatedChatStream {
|
||||
inner,
|
||||
cumulative: String::new(),
|
||||
cumulative_reasoning: String::new(),
|
||||
pending: std::collections::VecDeque::new(),
|
||||
mode,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn streaming_mode(inner: S) -> Self {
|
||||
Self::new(inner, AggregateMode::Streaming)
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,8 @@ use crate::ModelProviderInfo;
|
||||
use crate::client::ResponseEvent;
|
||||
use crate::client::ResponseStream;
|
||||
use crate::client::http::CodexHttpClient;
|
||||
use crate::client::retry::RetryableStreamError;
|
||||
use crate::client::retry::retry_stream;
|
||||
use crate::client_common::Prompt;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::ConnectionFailedError;
|
||||
@@ -25,9 +27,6 @@ use futures::Stream;
|
||||
use futures::TryStreamExt;
|
||||
use reqwest::StatusCode;
|
||||
use serde_json::json;
|
||||
use std::pin::Pin;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::trace;
|
||||
|
||||
@@ -361,28 +360,15 @@ pub(crate) async fn stream_chat_completions(
|
||||
}
|
||||
|
||||
let max_attempts = provider.request_max_retries();
|
||||
let mut last_error = None;
|
||||
for attempt in 0..=max_attempts {
|
||||
match stream_single_chat_completion(
|
||||
attempt,
|
||||
client,
|
||||
provider,
|
||||
otel_event_manager,
|
||||
body.clone(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(stream) => return Ok(stream),
|
||||
Err(e) => {
|
||||
last_error = Some(e);
|
||||
if attempt != max_attempts {
|
||||
tokio::time::sleep(backoff(attempt)).await;
|
||||
}
|
||||
}
|
||||
retry_stream(max_attempts, |attempt| {
|
||||
let body = body.clone();
|
||||
async move {
|
||||
stream_single_chat_completion(attempt, client, provider, otel_event_manager, body)
|
||||
.await
|
||||
.map_err(ChatStreamError::Retryable)
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_error.unwrap_or(CodexErr::InternalServerError))
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn stream_single_chat_completion(
|
||||
@@ -463,6 +449,22 @@ async fn stream_single_chat_completion(
|
||||
}
|
||||
}
|
||||
|
||||
enum ChatStreamError {
|
||||
Retryable(CodexErr),
|
||||
}
|
||||
|
||||
impl RetryableStreamError for ChatStreamError {
|
||||
fn delay(&self, attempt: u64) -> Option<Duration> {
|
||||
Some(backoff(attempt))
|
||||
}
|
||||
|
||||
fn into_error(self) -> CodexErr {
|
||||
match self {
|
||||
ChatStreamError::Retryable(e) => e,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn append_assistant_text(
|
||||
tx_event: &mpsc::Sender<Result<ResponseEvent>>,
|
||||
assistant_item: &mut Option<ResponseItem>,
|
||||
@@ -748,224 +750,3 @@ async fn process_chat_sse<S>(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Optional client-side aggregation helper
|
||||
///
|
||||
/// Stream adapter that merges the incremental `OutputItemDone` chunks coming from
|
||||
/// [`process_chat_sse`] into a *running* assistant message, **suppressing the
|
||||
/// per-token deltas**. The stream stays silent while the model is thinking
|
||||
/// and only emits two events per turn:
|
||||
///
|
||||
/// 1. `ResponseEvent::OutputItemDone` with the *complete* assistant message
|
||||
/// (fully concatenated).
|
||||
/// 2. The original `ResponseEvent::Completed` right after it.
|
||||
///
|
||||
/// This mirrors the behaviour the TypeScript CLI exposes to its higher layers.
|
||||
///
|
||||
/// The adapter is intentionally *lossless*: callers who do **not** opt in via
|
||||
/// [`AggregateStreamExt::aggregate()`] keep receiving the original unmodified
|
||||
/// events.
|
||||
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
enum AggregateMode {
|
||||
AggregatedOnly,
|
||||
Streaming,
|
||||
}
|
||||
|
||||
pub(crate) struct AggregatedChatStream<S> {
|
||||
inner: S,
|
||||
cumulative: String,
|
||||
cumulative_reasoning: String,
|
||||
pending: std::collections::VecDeque<ResponseEvent>,
|
||||
mode: AggregateMode,
|
||||
}
|
||||
|
||||
impl<S> Stream for AggregatedChatStream<S>
|
||||
where
|
||||
S: Stream<Item = Result<ResponseEvent>> + Unpin,
|
||||
{
|
||||
type Item = Result<ResponseEvent>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
// First, flush any buffered events from the previous call.
|
||||
if let Some(ev) = this.pending.pop_front() {
|
||||
return Poll::Ready(Some(Ok(ev)));
|
||||
}
|
||||
|
||||
loop {
|
||||
match Pin::new(&mut this.inner).poll_next(cx) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(None) => return Poll::Ready(None),
|
||||
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
|
||||
Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
|
||||
// If this is an incremental assistant message chunk, accumulate but
|
||||
// do NOT emit yet. Forward any other item (e.g. FunctionCall) right
|
||||
// away so downstream consumers see it.
|
||||
|
||||
let is_assistant_message = matches!(
|
||||
&item,
|
||||
ResponseItem::Message { role, .. } if role == "assistant"
|
||||
);
|
||||
|
||||
if is_assistant_message {
|
||||
match this.mode {
|
||||
AggregateMode::AggregatedOnly => {
|
||||
// Only use the final assistant message if we have not
|
||||
// seen any deltas; otherwise, deltas already built the
|
||||
// cumulative text and this would duplicate it.
|
||||
if this.cumulative.is_empty()
|
||||
&& let ResponseItem::Message { content, .. } = &item
|
||||
&& let Some(text) = content.iter().find_map(|c| match c {
|
||||
ContentItem::OutputText { text } => Some(text),
|
||||
_ => None,
|
||||
})
|
||||
{
|
||||
this.cumulative.push_str(text);
|
||||
}
|
||||
// Swallow assistant message here; emit on Completed.
|
||||
continue;
|
||||
}
|
||||
AggregateMode::Streaming => {
|
||||
// In streaming mode, if we have not seen any deltas, forward
|
||||
// the final assistant message directly. If deltas were seen,
|
||||
// suppress the final message to avoid duplication.
|
||||
if this.cumulative.is_empty() {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
|
||||
item,
|
||||
))));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Not an assistant message – forward immediately.
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
}))) => {
|
||||
// Build any aggregated items in the correct order: Reasoning first, then Message.
|
||||
let mut emitted_any = false;
|
||||
|
||||
if !this.cumulative_reasoning.is_empty()
|
||||
&& matches!(this.mode, AggregateMode::AggregatedOnly)
|
||||
{
|
||||
let aggregated_reasoning = ResponseItem::Reasoning {
|
||||
id: String::new(),
|
||||
summary: Vec::new(),
|
||||
content: Some(vec![ReasoningItemContent::ReasoningText {
|
||||
text: std::mem::take(&mut this.cumulative_reasoning),
|
||||
}]),
|
||||
encrypted_content: None,
|
||||
};
|
||||
this.pending
|
||||
.push_back(ResponseEvent::OutputItemDone(aggregated_reasoning));
|
||||
emitted_any = true;
|
||||
}
|
||||
|
||||
// Always emit the final aggregated assistant message when any
|
||||
// content deltas have been observed. In AggregatedOnly mode this
|
||||
// is the sole assistant output; in Streaming mode this finalizes
|
||||
// the streamed deltas into a terminal OutputItemDone so callers
|
||||
// can persist/render the message once per turn.
|
||||
if !this.cumulative.is_empty() {
|
||||
let aggregated_message = ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: std::mem::take(&mut this.cumulative),
|
||||
}],
|
||||
};
|
||||
this.pending
|
||||
.push_back(ResponseEvent::OutputItemDone(aggregated_message));
|
||||
emitted_any = true;
|
||||
}
|
||||
|
||||
// Always emit Completed last when anything was aggregated.
|
||||
if emitted_any {
|
||||
this.pending.push_back(ResponseEvent::Completed {
|
||||
response_id: response_id.clone(),
|
||||
token_usage: token_usage.clone(),
|
||||
});
|
||||
if let Some(ev) = this.pending.pop_front() {
|
||||
return Poll::Ready(Some(Ok(ev)));
|
||||
}
|
||||
}
|
||||
|
||||
// Nothing aggregated – forward Completed directly.
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
})));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::Created))) => {
|
||||
// These events are exclusive to the Responses API and
|
||||
// will never appear in a Chat Completions stream.
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => {
|
||||
// Always accumulate deltas so we can emit a final OutputItemDone at Completed.
|
||||
this.cumulative.push_str(&delta);
|
||||
if matches!(this.mode, AggregateMode::Streaming) {
|
||||
// In streaming mode, also forward the delta immediately.
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta))));
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta)))) => {
|
||||
// Always accumulate reasoning deltas so we can emit a final Reasoning item at Completed.
|
||||
this.cumulative_reasoning.push_str(&delta);
|
||||
if matches!(this.mode, AggregateMode::Streaming) {
|
||||
// In streaming mode, also forward the delta immediately.
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta))));
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta(_)))) => {}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded))) => {}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item))));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extension trait that activates aggregation on any stream of [`ResponseEvent`].
|
||||
pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized {
|
||||
/// Returns a new stream that emits **only** the final assistant message
|
||||
/// per turn instead of every incremental delta. The produced
|
||||
/// `ResponseEvent` sequence for a typical text turn looks like:
|
||||
///
|
||||
/// ```ignore
|
||||
/// OutputItemDone(<full message>)
|
||||
/// Completed
|
||||
/// ```
|
||||
///
|
||||
/// No other `OutputItemDone` events will be seen by the caller.
|
||||
fn aggregate(self) -> AggregatedChatStream<Self> {
|
||||
AggregatedChatStream::new(self, AggregateMode::AggregatedOnly)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
|
||||
|
||||
impl<S> AggregatedChatStream<S> {
|
||||
fn new(inner: S, mode: AggregateMode) -> Self {
|
||||
AggregatedChatStream {
|
||||
inner,
|
||||
cumulative: String::new(),
|
||||
cumulative_reasoning: String::new(),
|
||||
pending: std::collections::VecDeque::new(),
|
||||
mode,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn streaming_mode(inner: S) -> Self {
|
||||
Self::new(inner, AggregateMode::Streaming)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
mod aggregation;
|
||||
mod chat_completions;
|
||||
pub mod http;
|
||||
mod rate_limits;
|
||||
mod responses;
|
||||
mod retry;
|
||||
mod sse;
|
||||
pub mod types;
|
||||
|
||||
pub(crate) use chat_completions::AggregateStreamExt;
|
||||
pub(crate) use chat_completions::AggregatedChatStream;
|
||||
pub(crate) use aggregation::AggregateStreamExt;
|
||||
pub(crate) use aggregation::AggregatedChatStream;
|
||||
pub(crate) use chat_completions::stream_chat_completions;
|
||||
pub use responses::ModelClient;
|
||||
pub(crate) use types::FreeformTool;
|
||||
|
||||
86
codex-rs/core/src/client/rate_limits.rs
Normal file
86
codex-rs/core/src/client/rate_limits.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
use crate::protocol::RateLimitSnapshot;
|
||||
use crate::protocol::RateLimitWindow;
|
||||
use chrono::Utc;
|
||||
use reqwest::header::HeaderMap;
|
||||
|
||||
/// Prefer Codex-specific aggregate rate limit headers if present; fall back
|
||||
/// to raw OpenAI-style request headers otherwise.
|
||||
pub(crate) fn parse_rate_limit_snapshot(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
|
||||
parse_codex_rate_limits(headers).or_else(|| parse_openai_rate_limits(headers))
|
||||
}
|
||||
|
||||
fn parse_codex_rate_limits(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
|
||||
fn parse_f64(headers: &HeaderMap, name: &str) -> Option<f64> {
|
||||
headers
|
||||
.get(name)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.parse::<f64>().ok())
|
||||
}
|
||||
|
||||
fn parse_i64(headers: &HeaderMap, name: &str) -> Option<i64> {
|
||||
headers
|
||||
.get(name)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.parse::<i64>().ok())
|
||||
}
|
||||
|
||||
let primary_used = parse_f64(headers, "x-codex-primary-used-percent");
|
||||
let secondary_used = parse_f64(headers, "x-codex-secondary-used-percent");
|
||||
|
||||
if primary_used.is_none() && secondary_used.is_none() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let primary = primary_used.map(|used_percent| RateLimitWindow {
|
||||
used_percent,
|
||||
window_minutes: parse_i64(headers, "x-codex-primary-window-minutes"),
|
||||
resets_at: parse_i64(headers, "x-codex-primary-reset-at"),
|
||||
});
|
||||
|
||||
let secondary = secondary_used.map(|used_percent| RateLimitWindow {
|
||||
used_percent,
|
||||
window_minutes: parse_i64(headers, "x-codex-secondary-window-minutes"),
|
||||
resets_at: parse_i64(headers, "x-codex-secondary-reset-at"),
|
||||
});
|
||||
|
||||
Some(RateLimitSnapshot { primary, secondary })
|
||||
}
|
||||
|
||||
fn parse_openai_rate_limits(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
|
||||
let limit = headers.get("x-ratelimit-limit-requests")?;
|
||||
let remaining = headers.get("x-ratelimit-remaining-requests")?;
|
||||
let reset_ms = headers.get("x-ratelimit-reset-requests")?;
|
||||
|
||||
let limit = limit.to_str().ok()?.parse::<f64>().ok()?;
|
||||
let remaining = remaining.to_str().ok()?.parse::<f64>().ok()?;
|
||||
let reset_ms = reset_ms.to_str().ok()?.parse::<i64>().ok()?;
|
||||
|
||||
if limit <= 0.0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let used = (limit - remaining).max(0.0);
|
||||
let used_percent = (used / limit) * 100.0;
|
||||
|
||||
let window_minutes = if reset_ms <= 0 {
|
||||
None
|
||||
} else {
|
||||
let seconds = reset_ms / 1000;
|
||||
Some((seconds + 59) / 60)
|
||||
};
|
||||
|
||||
let resets_at = if reset_ms > 0 {
|
||||
Some(Utc::now().timestamp() + reset_ms / 1000)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Some(RateLimitSnapshot {
|
||||
primary: Some(RateLimitWindow {
|
||||
used_percent,
|
||||
window_minutes,
|
||||
resets_at,
|
||||
}),
|
||||
secondary: None,
|
||||
})
|
||||
}
|
||||
@@ -16,7 +16,6 @@ use eventsource_stream::Eventsource;
|
||||
use futures::prelude::*;
|
||||
use regex_lite::Regex;
|
||||
use reqwest::StatusCode;
|
||||
use reqwest::header::HeaderMap;
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use tokio::sync::mpsc;
|
||||
@@ -33,6 +32,9 @@ use crate::client::ResponseEvent;
|
||||
use crate::client::ResponseStream;
|
||||
use crate::client::ResponsesApiRequest;
|
||||
use crate::client::create_text_param_for_request;
|
||||
use crate::client::rate_limits::parse_rate_limit_snapshot;
|
||||
use crate::client::retry::RetryableStreamError;
|
||||
use crate::client::retry::retry_stream;
|
||||
use crate::client_common::Prompt;
|
||||
use crate::config::Config;
|
||||
use crate::default_client::CodexHttpClient;
|
||||
@@ -49,7 +51,6 @@ use crate::model_provider_info::ModelProviderInfo;
|
||||
use crate::model_provider_info::WireApi;
|
||||
use crate::openai_model_info::get_model_info;
|
||||
use crate::protocol::RateLimitSnapshot;
|
||||
use crate::protocol::RateLimitWindow;
|
||||
use crate::protocol::TokenUsage;
|
||||
use crate::token_data::PlanType;
|
||||
use crate::tools::spec::create_tools_json_for_responses_api;
|
||||
@@ -260,30 +261,10 @@ impl ModelClient {
|
||||
}
|
||||
|
||||
let max_attempts = self.provider.request_max_retries();
|
||||
for attempt in 0..=max_attempts {
|
||||
match self
|
||||
.attempt_stream_responses(attempt, &payload_json, &auth_manager)
|
||||
.await
|
||||
{
|
||||
Ok(stream) => {
|
||||
return Ok(stream);
|
||||
}
|
||||
Err(StreamAttemptError::Fatal(e)) => {
|
||||
return Err(e);
|
||||
}
|
||||
Err(retryable_attempt_error) => {
|
||||
if attempt == max_attempts {
|
||||
return Err(retryable_attempt_error.into_error());
|
||||
}
|
||||
|
||||
if let Some(delay) = retryable_attempt_error.delay(attempt) {
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unreachable!("stream_responses_attempt should always return");
|
||||
retry_stream(max_attempts, |attempt| {
|
||||
self.attempt_stream_responses(attempt, &payload_json, &auth_manager)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Single attempt to start a streaming Responses API call.
|
||||
@@ -506,88 +487,6 @@ impl ModelClient {
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_rate_limit_snapshot(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
|
||||
// Prefer codex-specific aggregate rate limit headers if present; fall back
|
||||
// to raw OpenAI-style request headers otherwise.
|
||||
parse_codex_rate_limits(headers).or_else(|| parse_openai_rate_limits(headers))
|
||||
}
|
||||
|
||||
fn parse_codex_rate_limits(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
|
||||
fn parse_f64(headers: &HeaderMap, name: &str) -> Option<f64> {
|
||||
headers
|
||||
.get(name)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.parse::<f64>().ok())
|
||||
}
|
||||
|
||||
fn parse_i64(headers: &HeaderMap, name: &str) -> Option<i64> {
|
||||
headers
|
||||
.get(name)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.parse::<i64>().ok())
|
||||
}
|
||||
|
||||
let primary_used = parse_f64(headers, "x-codex-primary-used-percent");
|
||||
let secondary_used = parse_f64(headers, "x-codex-secondary-used-percent");
|
||||
|
||||
if primary_used.is_none() && secondary_used.is_none() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let primary = primary_used.map(|used_percent| RateLimitWindow {
|
||||
used_percent,
|
||||
window_minutes: parse_i64(headers, "x-codex-primary-window-minutes"),
|
||||
resets_at: parse_i64(headers, "x-codex-primary-reset-at"),
|
||||
});
|
||||
|
||||
let secondary = secondary_used.map(|used_percent| RateLimitWindow {
|
||||
used_percent,
|
||||
window_minutes: parse_i64(headers, "x-codex-secondary-window-minutes"),
|
||||
resets_at: parse_i64(headers, "x-codex-secondary-reset-at"),
|
||||
});
|
||||
|
||||
Some(RateLimitSnapshot { primary, secondary })
|
||||
}
|
||||
|
||||
fn parse_openai_rate_limits(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
|
||||
let limit = headers.get("x-ratelimit-limit-requests")?;
|
||||
let remaining = headers.get("x-ratelimit-remaining-requests")?;
|
||||
let reset_ms = headers.get("x-ratelimit-reset-requests")?;
|
||||
|
||||
let limit = limit.to_str().ok()?.parse::<f64>().ok()?;
|
||||
let remaining = remaining.to_str().ok()?.parse::<f64>().ok()?;
|
||||
let reset_ms = reset_ms.to_str().ok()?.parse::<i64>().ok()?;
|
||||
|
||||
if limit <= 0.0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let used = (limit - remaining).max(0.0);
|
||||
let used_percent = (used / limit) * 100.0;
|
||||
|
||||
let window_minutes = if reset_ms <= 0 {
|
||||
None
|
||||
} else {
|
||||
let seconds = reset_ms / 1000;
|
||||
Some((seconds + 59) / 60)
|
||||
};
|
||||
|
||||
let resets_at = if reset_ms > 0 {
|
||||
Some(Utc::now().timestamp() + reset_ms / 1000)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Some(RateLimitSnapshot {
|
||||
primary: Some(RateLimitWindow {
|
||||
used_percent,
|
||||
window_minutes,
|
||||
resets_at,
|
||||
}),
|
||||
secondary: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// For Azure Responses endpoints we must use `store: true` and preserve
|
||||
/// per-item identifiers on the input payload. The `ResponseItem` schema
|
||||
/// deliberately skips serializing these IDs by default, so we patch them
|
||||
@@ -681,6 +580,16 @@ impl StreamAttemptError {
|
||||
}
|
||||
}
|
||||
|
||||
impl RetryableStreamError for StreamAttemptError {
|
||||
fn delay(&self, attempt: u64) -> Option<Duration> {
|
||||
self.delay(attempt)
|
||||
}
|
||||
|
||||
fn into_error(self) -> CodexErr {
|
||||
self.into_error()
|
||||
}
|
||||
}
|
||||
|
||||
async fn process_sse(
|
||||
stream: impl Stream<Item = std::result::Result<Bytes, CodexErr>> + Unpin + Send + 'static,
|
||||
tx_event: mpsc::Sender<Result<ResponseEvent>>,
|
||||
|
||||
140
codex-rs/core/src/client/retry.rs
Normal file
140
codex-rs/core/src/client/retry.rs
Normal file
@@ -0,0 +1,140 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result;
|
||||
|
||||
/// Common interface for classifying stream start errors as retryable or fatal.
///
/// Implemented by the per-endpoint attempt error types so that
/// [`retry_stream`] can drive their retry loops uniformly.
pub(crate) trait RetryableStreamError {
    /// Returns a delay for the next retry attempt, or `None` if the error
    /// should be treated as fatal and not retried.
    ///
    /// `attempt` is the zero-based index of the attempt that just failed.
    /// [`retry_stream`] calls this for every failure, including the last
    /// permitted attempt, where the returned delay is not used.
    fn delay(&self, attempt: u64) -> Option<Duration>;

    /// Converts this error into the final `CodexErr` that should be surfaced
    /// to callers when retries are exhausted or the error is fatal.
    fn into_error(self) -> CodexErr;
}
|
||||
|
||||
/// Helper to retry a streaming operation with provider-configured backoff.
|
||||
///
|
||||
/// The caller supplies an `attempt_fn` that is invoked once per attempt with
|
||||
/// the current attempt index in `[0, max_attempts]`. On success, the value is
|
||||
/// returned immediately. On error, the error's [`RetryableStreamError`]
|
||||
/// implementation decides whether to retry (with an optional delay) or to
|
||||
/// surface a final error.
|
||||
pub(crate) async fn retry_stream<F, Fut, T, E>(max_attempts: u64, mut attempt_fn: F) -> Result<T>
|
||||
where
|
||||
F: FnMut(u64) -> Fut,
|
||||
Fut: std::future::Future<Output = std::result::Result<T, E>>,
|
||||
E: RetryableStreamError,
|
||||
{
|
||||
for attempt in 0..=max_attempts {
|
||||
match attempt_fn(attempt).await {
|
||||
Ok(value) => return Ok(value),
|
||||
Err(err) => {
|
||||
let delay = err.delay(attempt);
|
||||
|
||||
// Fatal error or final attempt: surface to caller.
|
||||
if attempt == max_attempts || delay.is_none() {
|
||||
return Err(err.into_error());
|
||||
}
|
||||
|
||||
if let Some(duration) = delay {
|
||||
tokio::time::sleep(duration).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unreachable!("retry_stream should always return");
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    /// Test error whose `fatal` flag decides whether a retry is allowed.
    #[derive(Clone)]
    struct TestError {
        fatal: bool,
    }

    impl RetryableStreamError for TestError {
        fn delay(&self, attempt: u64) -> Option<Duration> {
            (!self.fatal).then(|| Duration::from_millis(attempt * 10))
        }

        fn into_error(self) -> CodexErr {
            if self.fatal {
                return CodexErr::InternalServerError;
            }
            CodexErr::Io(std::io::Error::other("retryable"))
        }
    }

    /// Drives `retry_stream` with an attempt function that always fails with
    /// `TestError { fatal }`; returns the outcome and the number of attempts
    /// that were actually made.
    async fn run_always_failing(max_attempts: u64, fatal: bool) -> (Result<()>, usize) {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);

        let outcome = retry_stream(max_attempts, move |_attempt| {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(TestError { fatal })
            }
        })
        .await;

        (outcome, calls.load(Ordering::SeqCst))
    }

    #[tokio::test]
    async fn retries_until_success_before_max_attempts() {
        let max_attempts = 3;

        // Attempts 0 and 1 fail with a retryable error; attempt 2 succeeds.
        let outcome: Result<&str> = retry_stream(max_attempts, |attempt| async move {
            match attempt {
                0 | 1 => Err(TestError { fatal: false }),
                _ => Ok("ok"),
            }
        })
        .await;

        assert_eq!(outcome.unwrap(), "ok");
    }

    #[tokio::test]
    async fn stops_on_fatal_error_without_retrying() {
        let (outcome, calls) = run_always_failing(5, true).await;

        assert!(outcome.is_err());
        assert_eq!(calls, 1);
    }

    #[tokio::test]
    async fn stops_after_max_attempts_for_retryable_errors() {
        let max_attempts = 2;

        let (outcome, calls) = run_always_failing(max_attempts, false).await;

        assert!(outcome.is_err());
        assert_eq!(calls, (max_attempts + 1) as usize);
    }
}
|
||||
Reference in New Issue
Block a user