mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
V2
This commit is contained in:
177
codex-rs/api-client/src/aggregate.rs
Normal file
177
codex-rs/api-client/src/aggregate.rs
Normal file
@@ -0,0 +1,177 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::pin::Pin;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use futures::Stream;
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::stream::ResponseEvent;
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub enum ChatAggregationMode {
|
||||
AggregatedOnly,
|
||||
Streaming,
|
||||
}
|
||||
|
||||
pub trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized {
|
||||
fn aggregate(self) -> AggregatedChatStream<Self>
|
||||
where
|
||||
Self: Unpin,
|
||||
{
|
||||
AggregatedChatStream::new(self, AggregateMode::AggregatedOnly)
|
||||
}
|
||||
|
||||
fn streaming_mode(self) -> AggregatedChatStream<Self>
|
||||
where
|
||||
Self: Unpin,
|
||||
{
|
||||
AggregatedChatStream::new(self, AggregateMode::Streaming)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> AggregateStreamExt for S where S: Stream<Item = Result<ResponseEvent>> + Sized + Unpin {}
|
||||
|
||||
enum AggregateMode {
|
||||
AggregatedOnly,
|
||||
Streaming,
|
||||
}
|
||||
|
||||
pub struct AggregatedChatStream<S> {
|
||||
inner: S,
|
||||
cumulative: String,
|
||||
cumulative_reasoning: String,
|
||||
pending: VecDeque<ResponseEvent>,
|
||||
mode: AggregateMode,
|
||||
}
|
||||
|
||||
impl<S> AggregatedChatStream<S>
|
||||
where
|
||||
S: Stream<Item = Result<ResponseEvent>> + Unpin,
|
||||
{
|
||||
fn new(inner: S, mode: AggregateMode) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
cumulative: String::new(),
|
||||
cumulative_reasoning: String::new(),
|
||||
pending: VecDeque::new(),
|
||||
mode,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Stream for AggregatedChatStream<S>
|
||||
where
|
||||
S: Stream<Item = Result<ResponseEvent>> + Unpin,
|
||||
{
|
||||
type Item = Result<ResponseEvent>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
if let Some(ev) = self.pending.pop_front() {
|
||||
return Poll::Ready(Some(Ok(ev)));
|
||||
}
|
||||
|
||||
loop {
|
||||
match Pin::new(&mut self.inner).poll_next(cx) {
|
||||
std::task::Poll::Pending => return Poll::Pending,
|
||||
std::task::Poll::Ready(None) => return std::task::Poll::Ready(None),
|
||||
std::task::Poll::Ready(Some(Err(err))) => {
|
||||
return std::task::Poll::Ready(Some(Err(err)));
|
||||
}
|
||||
std::task::Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
|
||||
let is_assistant_message = matches!(
|
||||
&item,
|
||||
ResponseItem::Message { role, .. } if role == "assistant"
|
||||
);
|
||||
|
||||
if is_assistant_message {
|
||||
if let ResponseItem::Message { role, content, .. } = item {
|
||||
let mut text = String::new();
|
||||
for c in content {
|
||||
match c {
|
||||
ContentItem::InputText { text: t }
|
||||
| ContentItem::OutputText { text: t } => text.push_str(&t),
|
||||
ContentItem::InputImage { image_url } => {
|
||||
text.push_str(&image_url)
|
||||
}
|
||||
}
|
||||
}
|
||||
self.cumulative.push_str(&text);
|
||||
if matches!(self.mode, AggregateMode::Streaming) {
|
||||
let output_item =
|
||||
ResponseEvent::OutputItemDone(ResponseItem::Message {
|
||||
id: None,
|
||||
role,
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: self.cumulative.clone(),
|
||||
}],
|
||||
});
|
||||
self.pending.push_back(output_item);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return std::task::Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
|
||||
item,
|
||||
))));
|
||||
}
|
||||
}
|
||||
std::task::Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => {
|
||||
if !matches!(
|
||||
&item,
|
||||
ResponseItem::Message { role, .. } if role == "assistant"
|
||||
) {
|
||||
return std::task::Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(
|
||||
item,
|
||||
))));
|
||||
}
|
||||
}
|
||||
std::task::Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta)))) => {
|
||||
self.cumulative_reasoning.push_str(&delta);
|
||||
if matches!(self.mode, AggregateMode::Streaming) {
|
||||
let ev =
|
||||
ResponseEvent::ReasoningContentDelta(self.cumulative_reasoning.clone());
|
||||
self.pending.push_back(ev);
|
||||
}
|
||||
}
|
||||
std::task::Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta(delta)))) => {
|
||||
if matches!(self.mode, AggregateMode::Streaming) {
|
||||
let ev = ResponseEvent::ReasoningSummaryDelta(delta);
|
||||
self.pending.push_back(ev);
|
||||
}
|
||||
}
|
||||
std::task::Poll::Ready(Some(Ok(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
}))) => {
|
||||
let assistant_event = ResponseEvent::OutputItemDone(ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: self.cumulative.clone(),
|
||||
}],
|
||||
});
|
||||
let completion_event = ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
};
|
||||
|
||||
if matches!(self.mode, AggregateMode::Streaming) {
|
||||
self.pending.push_back(assistant_event);
|
||||
self.pending.push_back(completion_event);
|
||||
} else {
|
||||
return std::task::Poll::Ready(Some(Ok(assistant_event)));
|
||||
}
|
||||
}
|
||||
std::task::Poll::Ready(Some(Ok(ev))) => {
|
||||
return std::task::Poll::Ready(Some(Ok(ev)));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ev) = self.pending.pop_front() {
|
||||
return std::task::Poll::Ready(Some(Ok(ev)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
31
codex-rs/api-client/src/common.rs
Normal file
31
codex-rs/api-client/src/common.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::SubAgentSource;
|
||||
|
||||
/// Exponential backoff with a 100ms base and a cap on the exponent to avoid
|
||||
/// unbounded growth. The attempt number is clamped to [0, 6].
|
||||
pub(crate) fn backoff(attempt: i64) -> Duration {
|
||||
let capped = attempt.clamp(0, 6) as u32;
|
||||
Duration::from_millis(100 * 2_i64.pow(capped) as u64)
|
||||
}
|
||||
|
||||
/// Apply the `x-openai-subagent` header when the session source indicates a
|
||||
/// subagent. Returns the original builder unchanged when not applicable.
|
||||
pub(crate) fn apply_subagent_header(
|
||||
mut builder: reqwest::RequestBuilder,
|
||||
session_source: Option<&SessionSource>,
|
||||
) -> reqwest::RequestBuilder {
|
||||
if let Some(SessionSource::SubAgent(sub)) = session_source {
|
||||
let subagent = if let SubAgentSource::Other(label) = sub {
|
||||
label.clone()
|
||||
} else {
|
||||
serde_json::to_value(sub)
|
||||
.ok()
|
||||
.and_then(|v| v.as_str().map(std::string::ToString::to_string))
|
||||
.unwrap_or_else(|| "other".to_string())
|
||||
};
|
||||
builder = builder.header("x-openai-subagent", subagent);
|
||||
}
|
||||
builder
|
||||
}
|
||||
Reference in New Issue
Block a user