This commit is contained in:
jif-oai
2025-11-10 12:05:51 +00:00
parent 1bac24f827
commit dabf219a45
2 changed files with 208 additions and 0 deletions

View File

@@ -0,0 +1,177 @@
use std::collections::VecDeque;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use futures::Stream;
use crate::error::Result;
use crate::stream::ResponseEvent;
/// Public selector describing how aggregated chat output should be delivered.
///
/// NOTE(review): this enum is not referenced anywhere in this file — the
/// stream internals use the private `AggregateMode` instead. Presumably it is
/// consumed by callers outside this chunk; confirm, or wire it through.
#[derive(Clone, Copy, Debug)]
pub enum ChatAggregationMode {
    /// Only the final, fully aggregated assistant message is emitted.
    AggregatedOnly,
    /// Cumulative intermediate events are emitted as they arrive.
    Streaming,
}
/// Extension methods for any stream of [`ResponseEvent`] results, providing
/// the two aggregation flavours of [`AggregatedChatStream`].
pub trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized {
    /// Consumes the stream, yielding only final aggregated output
    /// (`AggregateMode::AggregatedOnly`).
    fn aggregate(self) -> AggregatedChatStream<Self>
    where
        Self: Unpin,
    {
        AggregatedChatStream::new(self, AggregateMode::AggregatedOnly)
    }

    /// Consumes the stream, yielding cumulative intermediate events in
    /// addition to the final output (`AggregateMode::Streaming`).
    fn streaming_mode(self) -> AggregatedChatStream<Self>
    where
        Self: Unpin,
    {
        AggregatedChatStream::new(self, AggregateMode::Streaming)
    }
}

/// Blanket impl: every unpinned stream of response-event results gets the
/// extension methods for free.
impl<S> AggregateStreamExt for S where S: Stream<Item = Result<ResponseEvent>> + Sized + Unpin {}
/// Crate-private twin of `ChatAggregationMode`, consulted by
/// `AggregatedChatStream` to decide whether intermediate events are surfaced.
enum AggregateMode {
    /// Suppress intermediate events; emit only the final aggregate.
    AggregatedOnly,
    /// Also emit cumulative intermediate snapshots while streaming.
    Streaming,
}
/// Stream adapter that folds assistant-message text and reasoning deltas
/// from an inner [`ResponseEvent`] stream into cumulative buffers.
pub struct AggregatedChatStream<S> {
    /// Underlying event stream being adapted.
    inner: S,
    /// Concatenation of all assistant-message text seen so far.
    cumulative: String,
    /// Concatenation of all reasoning-content deltas seen so far.
    cumulative_reasoning: String,
    /// Events queued for delivery on subsequent polls.
    pending: VecDeque<ResponseEvent>,
    /// Controls whether intermediate events are forwarded or suppressed.
    mode: AggregateMode,
}
impl<S> AggregatedChatStream<S>
where
    S: Stream<Item = Result<ResponseEvent>> + Unpin,
{
    /// Wraps `inner` with empty aggregation buffers in the given `mode`.
    fn new(inner: S, mode: AggregateMode) -> Self {
        Self {
            inner,
            cumulative: String::new(),
            cumulative_reasoning: String::new(),
            pending: VecDeque::new(),
            mode,
        }
    }
}
impl<S> Stream for AggregatedChatStream<S>
where
    S: Stream<Item = Result<ResponseEvent>> + Unpin,
{
    type Item = Result<ResponseEvent>;

    /// Drains queued events first, then polls the inner stream, folding
    /// assistant-message text and reasoning deltas into cumulative buffers.
    /// Intermediate snapshots are only surfaced in `AggregateMode::Streaming`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Deliver anything queued by a previous poll before touching `inner`.
        if let Some(ev) = self.pending.pop_front() {
            return Poll::Ready(Some(Ok(ev)));
        }
        loop {
            match Pin::new(&mut self.inner).poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
                Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
                    match item {
                        // Assistant messages are absorbed into `cumulative`
                        // rather than forwarded verbatim.
                        ResponseItem::Message { role, content, .. } if role == "assistant" => {
                            for c in content {
                                match c {
                                    ContentItem::InputText { text }
                                    | ContentItem::OutputText { text } => {
                                        self.cumulative.push_str(&text);
                                    }
                                    ContentItem::InputImage { image_url } => {
                                        self.cumulative.push_str(&image_url);
                                    }
                                }
                            }
                            if matches!(self.mode, AggregateMode::Streaming) {
                                // Streaming mode surfaces a snapshot of the
                                // whole aggregated text after each item.
                                let snapshot =
                                    ResponseEvent::OutputItemDone(ResponseItem::Message {
                                        id: None,
                                        role,
                                        content: vec![ContentItem::OutputText {
                                            text: self.cumulative.clone(),
                                        }],
                                    });
                                self.pending.push_back(snapshot);
                            }
                        }
                        // Anything that is not an assistant message passes
                        // through untouched.
                        other => {
                            return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(other))));
                        }
                    }
                }
                Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => {
                    // Suppress "added" notifications for assistant messages;
                    // their content is delivered via aggregated events instead.
                    let suppress = matches!(
                        &item,
                        ResponseItem::Message { role, .. } if role == "assistant"
                    );
                    if !suppress {
                        return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item))));
                    }
                }
                Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta)))) => {
                    self.cumulative_reasoning.push_str(&delta);
                    if matches!(self.mode, AggregateMode::Streaming) {
                        // Streaming emits the *cumulative* reasoning text, not
                        // the raw delta, mirroring the message snapshots.
                        self.pending.push_back(ResponseEvent::ReasoningContentDelta(
                            self.cumulative_reasoning.clone(),
                        ));
                    }
                }
                Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta(delta)))) => {
                    // Summary deltas are forwarded verbatim in streaming mode
                    // and dropped entirely when aggregating.
                    if matches!(self.mode, AggregateMode::Streaming) {
                        self.pending
                            .push_back(ResponseEvent::ReasoningSummaryDelta(delta));
                    }
                }
                Poll::Ready(Some(Ok(ResponseEvent::Completed {
                    response_id,
                    token_usage,
                }))) => {
                    // Flush the aggregated assistant message followed by the
                    // completion marker. `mem::take` resets the buffer so no
                    // text leaks into a subsequent completion, and avoids a
                    // clone of the full transcript.
                    let aggregated = ResponseEvent::OutputItemDone(ResponseItem::Message {
                        id: None,
                        role: "assistant".to_string(),
                        content: vec![ContentItem::OutputText {
                            text: std::mem::take(&mut self.cumulative),
                        }],
                    });
                    // Bug fix: the `Completed` event was previously dropped in
                    // `AggregatedOnly` mode (only the assistant message was
                    // returned), so callers never saw response_id/token_usage.
                    // Queue both in order; the pop below yields the first.
                    self.pending.push_back(aggregated);
                    self.pending.push_back(ResponseEvent::Completed {
                        response_id,
                        token_usage,
                    });
                }
                Poll::Ready(Some(Ok(ev))) => return Poll::Ready(Some(Ok(ev))),
            }
            // An arm may have queued events without returning; drain one now.
            if let Some(ev) = self.pending.pop_front() {
                return Poll::Ready(Some(Ok(ev)));
            }
        }
    }
}

View File

@@ -0,0 +1,31 @@
use std::time::Duration;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
/// Exponential backoff with a 100 ms base, doubling per attempt, with the
/// exponent capped to avoid unbounded growth. The attempt number is clamped
/// to [0, 6], so the result ranges from 100 ms to 6400 ms.
pub(crate) fn backoff(attempt: i64) -> Duration {
    // Clamp before exponentiation; the cast is lossless for 0..=6.
    let capped = attempt.clamp(0, 6) as u32;
    // Compute directly in u64 (100 << 6 == 6400), avoiding the signed
    // `2_i64.pow(..) as u64` round-trip of the original.
    Duration::from_millis(100u64 << capped)
}
/// Apply the `x-openai-subagent` header when the session source indicates a
/// subagent. Returns the original builder unchanged when not applicable.
pub(crate) fn apply_subagent_header(
    builder: reqwest::RequestBuilder,
    session_source: Option<&SessionSource>,
) -> reqwest::RequestBuilder {
    match session_source {
        Some(SessionSource::SubAgent(sub)) => {
            // Free-form labels pass through verbatim; every other variant is
            // rendered via its serde string form, defaulting to "other" when
            // serialization does not yield a plain string.
            let value = match sub {
                SubAgentSource::Other(label) => label.clone(),
                _ => serde_json::to_value(sub)
                    .ok()
                    .and_then(|v| v.as_str().map(str::to_string))
                    .unwrap_or_else(|| "other".to_string()),
            };
            builder.header("x-openai-subagent", value)
        }
        _ => builder,
    }
}