mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
282 lines
10 KiB
Rust
282 lines
10 KiB
Rust
use crate::ChatRequest;
|
|
use crate::auth::AuthProvider;
|
|
use crate::common::Prompt as ApiPrompt;
|
|
use crate::common::ResponseEvent;
|
|
use crate::common::ResponseStream;
|
|
use crate::endpoint::streaming::StreamingClient;
|
|
use crate::error::ApiError;
|
|
use crate::provider::Provider;
|
|
use crate::provider::WireApi;
|
|
use crate::sse::chat::spawn_chat_stream;
|
|
use crate::telemetry::SseTelemetry;
|
|
use codex_client::HttpTransport;
|
|
use codex_client::RequestCompression;
|
|
use codex_client::RequestTelemetry;
|
|
use codex_protocol::models::ContentItem;
|
|
use codex_protocol::models::ReasoningItemContent;
|
|
use codex_protocol::models::ResponseItem;
|
|
use codex_protocol::protocol::SessionSource;
|
|
use futures::Stream;
|
|
use http::HeaderMap;
|
|
use serde_json::Value;
|
|
use std::collections::VecDeque;
|
|
use std::pin::Pin;
|
|
use std::sync::Arc;
|
|
use std::task::Context;
|
|
use std::task::Poll;
|
|
|
|
pub struct ChatClient<T: HttpTransport, A: AuthProvider> {
|
|
streaming: StreamingClient<T, A>,
|
|
}
|
|
|
|
impl<T: HttpTransport, A: AuthProvider> ChatClient<T, A> {
|
|
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
|
|
Self {
|
|
streaming: StreamingClient::new(transport, provider, auth),
|
|
}
|
|
}
|
|
|
|
pub fn with_telemetry(
|
|
self,
|
|
request: Option<Arc<dyn RequestTelemetry>>,
|
|
sse: Option<Arc<dyn SseTelemetry>>,
|
|
) -> Self {
|
|
Self {
|
|
streaming: self.streaming.with_telemetry(request, sse),
|
|
}
|
|
}
|
|
|
|
pub async fn stream_request(&self, request: ChatRequest) -> Result<ResponseStream, ApiError> {
|
|
self.stream(request.body, request.headers).await
|
|
}
|
|
|
|
pub async fn stream_prompt(
|
|
&self,
|
|
model: &str,
|
|
prompt: &ApiPrompt,
|
|
conversation_id: Option<String>,
|
|
session_source: Option<SessionSource>,
|
|
) -> Result<ResponseStream, ApiError> {
|
|
use crate::requests::ChatRequestBuilder;
|
|
|
|
let request =
|
|
ChatRequestBuilder::new(model, &prompt.instructions, &prompt.input, &prompt.tools)
|
|
.conversation_id(conversation_id)
|
|
.session_source(session_source)
|
|
.build(self.streaming.provider())?;
|
|
|
|
self.stream_request(request).await
|
|
}
|
|
|
|
fn path(&self) -> &'static str {
|
|
match self.streaming.provider().wire {
|
|
WireApi::Chat => "chat/completions",
|
|
_ => "responses",
|
|
}
|
|
}
|
|
|
|
pub async fn stream(
|
|
&self,
|
|
body: Value,
|
|
extra_headers: HeaderMap,
|
|
) -> Result<ResponseStream, ApiError> {
|
|
self.streaming
|
|
.stream(
|
|
self.path(),
|
|
body,
|
|
extra_headers,
|
|
RequestCompression::None,
|
|
spawn_chat_stream,
|
|
None,
|
|
)
|
|
.await
|
|
}
|
|
}
|
|
|
|
/// Selects how [`AggregatedStream`] treats incremental events.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum AggregateMode {
    /// Deltas are accumulated silently; only the aggregated items are emitted
    /// when the turn completes.
    AggregatedOnly,
    /// Deltas are accumulated and also forwarded to the consumer as they arrive.
    Streaming,
}
|
|
|
|
/// Stream adapter that merges token deltas into a single assistant message per turn.
|
|
pub struct AggregatedStream {
|
|
inner: ResponseStream,
|
|
cumulative: String,
|
|
cumulative_reasoning: String,
|
|
pending: VecDeque<ResponseEvent>,
|
|
mode: AggregateMode,
|
|
}
|
|
|
|
impl Stream for AggregatedStream {
|
|
type Item = Result<ResponseEvent, ApiError>;
|
|
|
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
|
let this = self.get_mut();
|
|
|
|
if let Some(ev) = this.pending.pop_front() {
|
|
return Poll::Ready(Some(Ok(ev)));
|
|
}
|
|
|
|
loop {
|
|
match Pin::new(&mut this.inner).poll_next(cx) {
|
|
Poll::Pending => return Poll::Pending,
|
|
Poll::Ready(None) => return Poll::Ready(None),
|
|
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
|
|
Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
|
|
let is_assistant_message = matches!(
|
|
&item,
|
|
ResponseItem::Message { role, .. } if role == "assistant"
|
|
);
|
|
|
|
if is_assistant_message {
|
|
match this.mode {
|
|
AggregateMode::AggregatedOnly => {
|
|
if this.cumulative.is_empty()
|
|
&& let ResponseItem::Message { content, .. } = &item
|
|
&& let Some(text) = content.iter().find_map(|c| match c {
|
|
ContentItem::OutputText { text } => Some(text),
|
|
_ => None,
|
|
})
|
|
{
|
|
this.cumulative.push_str(text);
|
|
}
|
|
continue;
|
|
}
|
|
AggregateMode::Streaming => {
|
|
if this.cumulative.is_empty() {
|
|
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
|
|
item,
|
|
))));
|
|
} else {
|
|
continue;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
|
|
}
|
|
Poll::Ready(Some(Ok(ResponseEvent::ServerReasoningIncluded(included)))) => {
|
|
return Poll::Ready(Some(Ok(ResponseEvent::ServerReasoningIncluded(included))));
|
|
}
|
|
Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
|
|
return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
|
|
}
|
|
Poll::Ready(Some(Ok(ResponseEvent::ModelsEtag(etag)))) => {
|
|
return Poll::Ready(Some(Ok(ResponseEvent::ModelsEtag(etag))));
|
|
}
|
|
Poll::Ready(Some(Ok(ResponseEvent::Completed {
|
|
response_id,
|
|
token_usage,
|
|
}))) => {
|
|
let mut emitted_any = false;
|
|
|
|
if !this.cumulative_reasoning.is_empty() {
|
|
let aggregated_reasoning = ResponseItem::Reasoning {
|
|
id: String::new(),
|
|
summary: Vec::new(),
|
|
content: Some(vec![ReasoningItemContent::ReasoningText {
|
|
text: std::mem::take(&mut this.cumulative_reasoning),
|
|
}]),
|
|
encrypted_content: None,
|
|
};
|
|
this.pending
|
|
.push_back(ResponseEvent::OutputItemDone(aggregated_reasoning));
|
|
emitted_any = true;
|
|
}
|
|
|
|
if !this.cumulative.is_empty() {
|
|
let aggregated_message = ResponseItem::Message {
|
|
id: None,
|
|
role: "assistant".to_string(),
|
|
content: vec![ContentItem::OutputText {
|
|
text: std::mem::take(&mut this.cumulative),
|
|
}],
|
|
end_turn: None,
|
|
};
|
|
this.pending
|
|
.push_back(ResponseEvent::OutputItemDone(aggregated_message));
|
|
emitted_any = true;
|
|
}
|
|
|
|
if emitted_any {
|
|
this.pending.push_back(ResponseEvent::Completed {
|
|
response_id: response_id.clone(),
|
|
token_usage: token_usage.clone(),
|
|
});
|
|
if let Some(ev) = this.pending.pop_front() {
|
|
return Poll::Ready(Some(Ok(ev)));
|
|
}
|
|
}
|
|
|
|
return Poll::Ready(Some(Ok(ResponseEvent::Completed {
|
|
response_id,
|
|
token_usage,
|
|
})));
|
|
}
|
|
Poll::Ready(Some(Ok(ResponseEvent::Created))) => {
|
|
continue;
|
|
}
|
|
Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => {
|
|
this.cumulative.push_str(&delta);
|
|
if matches!(this.mode, AggregateMode::Streaming) {
|
|
return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta))));
|
|
} else {
|
|
continue;
|
|
}
|
|
}
|
|
Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta {
|
|
delta,
|
|
content_index,
|
|
}))) => {
|
|
this.cumulative_reasoning.push_str(&delta);
|
|
if matches!(this.mode, AggregateMode::Streaming) {
|
|
return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta {
|
|
delta,
|
|
content_index,
|
|
})));
|
|
} else {
|
|
continue;
|
|
}
|
|
}
|
|
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta { .. }))) => continue,
|
|
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded { .. }))) => {
|
|
continue;
|
|
}
|
|
Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => {
|
|
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item))));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
pub trait AggregateStreamExt {
|
|
fn aggregate(self) -> AggregatedStream;
|
|
|
|
fn streaming_mode(self) -> ResponseStream;
|
|
}
|
|
|
|
impl AggregateStreamExt for ResponseStream {
|
|
fn aggregate(self) -> AggregatedStream {
|
|
AggregatedStream::new(self, AggregateMode::AggregatedOnly)
|
|
}
|
|
|
|
fn streaming_mode(self) -> ResponseStream {
|
|
self
|
|
}
|
|
}
|
|
|
|
impl AggregatedStream {
|
|
fn new(inner: ResponseStream, mode: AggregateMode) -> Self {
|
|
AggregatedStream {
|
|
inner,
|
|
cumulative: String::new(),
|
|
cumulative_reasoning: String::new(),
|
|
pending: VecDeque::new(),
|
|
mode,
|
|
}
|
|
}
|
|
}
|