use crate::ChatRequest; use crate::auth::AuthProvider; use crate::common::Prompt as ApiPrompt; use crate::common::ResponseEvent; use crate::common::ResponseStream; use crate::endpoint::streaming::StreamingClient; use crate::error::ApiError; use crate::provider::Provider; use crate::provider::WireApi; use crate::sse::chat::spawn_chat_stream; use crate::telemetry::SseTelemetry; use codex_client::HttpTransport; use codex_client::RequestCompression; use codex_client::RequestTelemetry; use codex_protocol::models::ContentItem; use codex_protocol::models::ReasoningItemContent; use codex_protocol::models::ResponseItem; use codex_protocol::protocol::SessionSource; use futures::Stream; use http::HeaderMap; use serde_json::Value; use std::collections::VecDeque; use std::pin::Pin; use std::sync::Arc; use std::task::Context; use std::task::Poll; pub struct ChatClient { streaming: StreamingClient, } impl ChatClient { pub fn new(transport: T, provider: Provider, auth: A) -> Self { Self { streaming: StreamingClient::new(transport, provider, auth), } } pub fn with_telemetry( self, request: Option>, sse: Option>, ) -> Self { Self { streaming: self.streaming.with_telemetry(request, sse), } } pub async fn stream_request(&self, request: ChatRequest) -> Result { self.stream(request.body, request.headers).await } pub async fn stream_prompt( &self, model: &str, prompt: &ApiPrompt, conversation_id: Option, session_source: Option, ) -> Result { use crate::requests::ChatRequestBuilder; let request = ChatRequestBuilder::new(model, &prompt.instructions, &prompt.input, &prompt.tools) .conversation_id(conversation_id) .session_source(session_source) .build(self.streaming.provider())?; self.stream_request(request).await } fn path(&self) -> &'static str { match self.streaming.provider().wire { WireApi::Chat => "chat/completions", _ => "responses", } } pub async fn stream( &self, body: Value, extra_headers: HeaderMap, ) -> Result { self.streaming .stream( self.path(), body, extra_headers, RequestCompression::None, spawn_chat_stream, ) .await } } #[derive(Copy, Clone, Eq, PartialEq)] pub enum AggregateMode { AggregatedOnly, Streaming, } /// Stream adapter that merges token deltas into a single assistant message per turn. pub struct AggregatedStream { inner: ResponseStream, cumulative: String, cumulative_reasoning: String, pending: VecDeque, mode: AggregateMode, } impl Stream for AggregatedStream { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); if let Some(ev) = this.pending.pop_front() { return Poll::Ready(Some(Ok(ev))); } loop { match Pin::new(&mut this.inner).poll_next(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(None) => return Poll::Ready(None), Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => { let is_assistant_message = matches!( &item, ResponseItem::Message { role, .. } if role == "assistant" ); if is_assistant_message { match this.mode { AggregateMode::AggregatedOnly => { if this.cumulative.is_empty() && let ResponseItem::Message { content, .. } = &item && let Some(text) = content.iter().find_map(|c| match c { ContentItem::OutputText { text } => Some(text), _ => None, }) { this.cumulative.push_str(text); } continue; } AggregateMode::Streaming => { if this.cumulative.is_empty() { return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone( item, )))); } else { continue; } } } } return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))); } Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => { return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))); } Poll::Ready(Some(Ok(ResponseEvent::ModelsEtag(etag)))) => { return Poll::Ready(Some(Ok(ResponseEvent::ModelsEtag(etag)))); } Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id, token_usage, }))) => { let mut emitted_any = false; if !this.cumulative_reasoning.is_empty() { let aggregated_reasoning = ResponseItem::Reasoning { id: String::new(), summary: Vec::new(), content: Some(vec![ReasoningItemContent::ReasoningText { text: std::mem::take(&mut this.cumulative_reasoning), }]), encrypted_content: None, }; this.pending .push_back(ResponseEvent::OutputItemDone(aggregated_reasoning)); emitted_any = true; } if !this.cumulative.is_empty() { let aggregated_message = ResponseItem::Message { id: None, role: "assistant".to_string(), content: vec![ContentItem::OutputText { text: std::mem::take(&mut this.cumulative), }], }; this.pending .push_back(ResponseEvent::OutputItemDone(aggregated_message)); emitted_any = true; } if emitted_any { this.pending.push_back(ResponseEvent::Completed { response_id: response_id.clone(), token_usage: token_usage.clone(), }); if let Some(ev) = this.pending.pop_front() { return Poll::Ready(Some(Ok(ev))); } } return Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id, token_usage, }))); } Poll::Ready(Some(Ok(ResponseEvent::Created))) => { continue; } Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => { this.cumulative.push_str(&delta); if matches!(this.mode, AggregateMode::Streaming) { return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))); } else { continue; } } Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta { delta, content_index, }))) => { this.cumulative_reasoning.push_str(&delta); if matches!(this.mode, AggregateMode::Streaming) { return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta { delta, content_index, }))); } else { continue; } } Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta { .. }))) => continue, Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded { .. }))) => { continue; } Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => { return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))); } } } } } pub trait AggregateStreamExt { fn aggregate(self) -> AggregatedStream; fn streaming_mode(self) -> ResponseStream; } impl AggregateStreamExt for ResponseStream { fn aggregate(self) -> AggregatedStream { AggregatedStream::new(self, AggregateMode::AggregatedOnly) } fn streaming_mode(self) -> ResponseStream { self } } impl AggregatedStream { fn new(inner: ResponseStream, mode: AggregateMode) -> Self { AggregatedStream { inner, cumulative: String::new(), cumulative_reasoning: String::new(), pending: VecDeque::new(), mode, } } }