mirror of
https://github.com/openai/codex.git
synced 2026-05-04 19:36:45 +00:00
chore: proper client extraction (#6996)
This commit is contained in:
27
codex-rs/codex-api/src/auth.rs
Normal file
27
codex-rs/codex-api/src/auth.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
use codex_client::Request;
|
||||
|
||||
/// Provides bearer and account identity information for API requests.
|
||||
///
|
||||
/// Implementations should be cheap and non-blocking; any asynchronous
|
||||
/// refresh or I/O should be handled by higher layers before requests
|
||||
/// reach this interface.
|
||||
pub trait AuthProvider: Send + Sync {
|
||||
fn bearer_token(&self) -> Option<String>;
|
||||
fn account_id(&self) -> Option<String> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn add_auth_headers<A: AuthProvider>(auth: &A, mut req: Request) -> Request {
|
||||
if let Some(token) = auth.bearer_token()
|
||||
&& let Ok(header) = format!("Bearer {token}").parse()
|
||||
{
|
||||
let _ = req.headers.insert(http::header::AUTHORIZATION, header);
|
||||
}
|
||||
if let Some(account_id) = auth.account_id()
|
||||
&& let Ok(header) = account_id.parse()
|
||||
{
|
||||
let _ = req.headers.insert("ChatGPT-Account-ID", header);
|
||||
}
|
||||
req
|
||||
}
|
||||
167
codex-rs/codex-api/src/common.rs
Normal file
167
codex-rs/codex-api/src/common.rs
Normal file
@@ -0,0 +1,167 @@
|
||||
use crate::error::ApiError;
|
||||
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
|
||||
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
||||
use codex_protocol::config_types::Verbosity as VerbosityConfig;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::RateLimitSnapshot;
|
||||
use codex_protocol::protocol::TokenUsage;
|
||||
use futures::Stream;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use std::pin::Pin;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// Canonical prompt input for Chat and Responses endpoints.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Prompt {
|
||||
/// Fully-resolved system instructions for this turn.
|
||||
pub instructions: String,
|
||||
/// Conversation history and user/tool messages.
|
||||
pub input: Vec<ResponseItem>,
|
||||
/// JSON-encoded tool definitions compatible with the target API.
|
||||
// TODO(jif) have a proper type here
|
||||
pub tools: Vec<Value>,
|
||||
/// Whether parallel tool calls are permitted.
|
||||
pub parallel_tool_calls: bool,
|
||||
/// Optional output schema used to build the `text.format` controls.
|
||||
pub output_schema: Option<Value>,
|
||||
}
|
||||
|
||||
/// Canonical input payload for the compaction endpoint.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct CompactionInput<'a> {
|
||||
pub model: &'a str,
|
||||
pub input: &'a [ResponseItem],
|
||||
pub instructions: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ResponseEvent {
|
||||
Created,
|
||||
OutputItemDone(ResponseItem),
|
||||
OutputItemAdded(ResponseItem),
|
||||
Completed {
|
||||
response_id: String,
|
||||
token_usage: Option<TokenUsage>,
|
||||
},
|
||||
OutputTextDelta(String),
|
||||
ReasoningSummaryDelta {
|
||||
delta: String,
|
||||
summary_index: i64,
|
||||
},
|
||||
ReasoningContentDelta {
|
||||
delta: String,
|
||||
content_index: i64,
|
||||
},
|
||||
ReasoningSummaryPartAdded {
|
||||
summary_index: i64,
|
||||
},
|
||||
RateLimits(RateLimitSnapshot),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
pub struct Reasoning {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub effort: Option<ReasoningEffortConfig>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub summary: Option<ReasoningSummaryConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Default, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TextFormatType {
|
||||
#[default]
|
||||
JsonSchema,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Default, Clone)]
|
||||
pub struct TextFormat {
|
||||
/// Format type used by the OpenAI text controls.
|
||||
pub r#type: TextFormatType,
|
||||
/// When true, the server is expected to strictly validate responses.
|
||||
pub strict: bool,
|
||||
/// JSON schema for the desired output.
|
||||
pub schema: Value,
|
||||
/// Friendly name for the format, used in telemetry/debugging.
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
/// Controls the `text` field for the Responses API, combining verbosity and
|
||||
/// optional JSON schema output formatting.
|
||||
#[derive(Debug, Serialize, Default, Clone)]
|
||||
pub struct TextControls {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub verbosity: Option<OpenAiVerbosity>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub format: Option<TextFormat>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Default, Clone)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum OpenAiVerbosity {
|
||||
Low,
|
||||
#[default]
|
||||
Medium,
|
||||
High,
|
||||
}
|
||||
|
||||
impl From<VerbosityConfig> for OpenAiVerbosity {
|
||||
fn from(v: VerbosityConfig) -> Self {
|
||||
match v {
|
||||
VerbosityConfig::Low => OpenAiVerbosity::Low,
|
||||
VerbosityConfig::Medium => OpenAiVerbosity::Medium,
|
||||
VerbosityConfig::High => OpenAiVerbosity::High,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ResponsesApiRequest<'a> {
|
||||
pub model: &'a str,
|
||||
pub instructions: &'a str,
|
||||
pub input: &'a [ResponseItem],
|
||||
pub tools: &'a [serde_json::Value],
|
||||
pub tool_choice: &'static str,
|
||||
pub parallel_tool_calls: bool,
|
||||
pub reasoning: Option<Reasoning>,
|
||||
pub store: bool,
|
||||
pub stream: bool,
|
||||
pub include: Vec<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt_cache_key: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub text: Option<TextControls>,
|
||||
}
|
||||
|
||||
pub fn create_text_param_for_request(
|
||||
verbosity: Option<VerbosityConfig>,
|
||||
output_schema: &Option<Value>,
|
||||
) -> Option<TextControls> {
|
||||
if verbosity.is_none() && output_schema.is_none() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(TextControls {
|
||||
verbosity: verbosity.map(std::convert::Into::into),
|
||||
format: output_schema.as_ref().map(|schema| TextFormat {
|
||||
r#type: TextFormatType::JsonSchema,
|
||||
strict: true,
|
||||
schema: schema.clone(),
|
||||
name: "codex_output_schema".to_string(),
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
pub struct ResponseStream {
|
||||
pub rx_event: mpsc::Receiver<Result<ResponseEvent, ApiError>>,
|
||||
}
|
||||
|
||||
impl Stream for ResponseStream {
|
||||
type Item = Result<ResponseEvent, ApiError>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
self.rx_event.poll_recv(cx)
|
||||
}
|
||||
}
|
||||
266
codex-rs/codex-api/src/endpoint/chat.rs
Normal file
266
codex-rs/codex-api/src/endpoint/chat.rs
Normal file
@@ -0,0 +1,266 @@
|
||||
use crate::ChatRequest;
|
||||
use crate::auth::AuthProvider;
|
||||
use crate::common::Prompt as ApiPrompt;
|
||||
use crate::common::ResponseEvent;
|
||||
use crate::common::ResponseStream;
|
||||
use crate::endpoint::streaming::StreamingClient;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::provider::WireApi;
|
||||
use crate::sse::chat::spawn_chat_stream;
|
||||
use crate::telemetry::SseTelemetry;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::RequestTelemetry;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ReasoningItemContent;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use futures::Stream;
|
||||
use http::HeaderMap;
|
||||
use serde_json::Value;
|
||||
use std::collections::VecDeque;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
|
||||
pub struct ChatClient<T: HttpTransport, A: AuthProvider> {
|
||||
streaming: StreamingClient<T, A>,
|
||||
}
|
||||
|
||||
impl<T: HttpTransport, A: AuthProvider> ChatClient<T, A> {
|
||||
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
|
||||
Self {
|
||||
streaming: StreamingClient::new(transport, provider, auth),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_telemetry(
|
||||
self,
|
||||
request: Option<Arc<dyn RequestTelemetry>>,
|
||||
sse: Option<Arc<dyn SseTelemetry>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
streaming: self.streaming.with_telemetry(request, sse),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream_request(&self, request: ChatRequest) -> Result<ResponseStream, ApiError> {
|
||||
self.stream(request.body, request.headers).await
|
||||
}
|
||||
|
||||
pub async fn stream_prompt(
|
||||
&self,
|
||||
model: &str,
|
||||
prompt: &ApiPrompt,
|
||||
conversation_id: Option<String>,
|
||||
session_source: Option<SessionSource>,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
use crate::requests::ChatRequestBuilder;
|
||||
|
||||
let request =
|
||||
ChatRequestBuilder::new(model, &prompt.instructions, &prompt.input, &prompt.tools)
|
||||
.conversation_id(conversation_id)
|
||||
.session_source(session_source)
|
||||
.build(self.streaming.provider())?;
|
||||
|
||||
self.stream_request(request).await
|
||||
}
|
||||
|
||||
fn path(&self) -> &'static str {
|
||||
match self.streaming.provider().wire {
|
||||
WireApi::Chat => "chat/completions",
|
||||
_ => "responses",
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream(
|
||||
&self,
|
||||
body: Value,
|
||||
extra_headers: HeaderMap,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
self.streaming
|
||||
.stream(self.path(), body, extra_headers, spawn_chat_stream)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
pub enum AggregateMode {
|
||||
AggregatedOnly,
|
||||
Streaming,
|
||||
}
|
||||
|
||||
/// Stream adapter that merges token deltas into a single assistant message per turn.
|
||||
pub struct AggregatedStream {
|
||||
inner: ResponseStream,
|
||||
cumulative: String,
|
||||
cumulative_reasoning: String,
|
||||
pending: VecDeque<ResponseEvent>,
|
||||
mode: AggregateMode,
|
||||
}
|
||||
|
||||
impl Stream for AggregatedStream {
|
||||
type Item = Result<ResponseEvent, ApiError>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
if let Some(ev) = this.pending.pop_front() {
|
||||
return Poll::Ready(Some(Ok(ev)));
|
||||
}
|
||||
|
||||
loop {
|
||||
match Pin::new(&mut this.inner).poll_next(cx) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(None) => return Poll::Ready(None),
|
||||
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
|
||||
Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
|
||||
let is_assistant_message = matches!(
|
||||
&item,
|
||||
ResponseItem::Message { role, .. } if role == "assistant"
|
||||
);
|
||||
|
||||
if is_assistant_message {
|
||||
match this.mode {
|
||||
AggregateMode::AggregatedOnly => {
|
||||
if this.cumulative.is_empty()
|
||||
&& let ResponseItem::Message { content, .. } = &item
|
||||
&& let Some(text) = content.iter().find_map(|c| match c {
|
||||
ContentItem::OutputText { text } => Some(text),
|
||||
_ => None,
|
||||
})
|
||||
{
|
||||
this.cumulative.push_str(text);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
AggregateMode::Streaming => {
|
||||
if this.cumulative.is_empty() {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
|
||||
item,
|
||||
))));
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
}))) => {
|
||||
let mut emitted_any = false;
|
||||
|
||||
if !this.cumulative_reasoning.is_empty() {
|
||||
let aggregated_reasoning = ResponseItem::Reasoning {
|
||||
id: String::new(),
|
||||
summary: Vec::new(),
|
||||
content: Some(vec![ReasoningItemContent::ReasoningText {
|
||||
text: std::mem::take(&mut this.cumulative_reasoning),
|
||||
}]),
|
||||
encrypted_content: None,
|
||||
};
|
||||
this.pending
|
||||
.push_back(ResponseEvent::OutputItemDone(aggregated_reasoning));
|
||||
emitted_any = true;
|
||||
}
|
||||
|
||||
if !this.cumulative.is_empty() {
|
||||
let aggregated_message = ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: std::mem::take(&mut this.cumulative),
|
||||
}],
|
||||
};
|
||||
this.pending
|
||||
.push_back(ResponseEvent::OutputItemDone(aggregated_message));
|
||||
emitted_any = true;
|
||||
}
|
||||
|
||||
if emitted_any {
|
||||
this.pending.push_back(ResponseEvent::Completed {
|
||||
response_id: response_id.clone(),
|
||||
token_usage: token_usage.clone(),
|
||||
});
|
||||
if let Some(ev) = this.pending.pop_front() {
|
||||
return Poll::Ready(Some(Ok(ev)));
|
||||
}
|
||||
}
|
||||
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
})));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::Created))) => {
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => {
|
||||
this.cumulative.push_str(&delta);
|
||||
if matches!(this.mode, AggregateMode::Streaming) {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta))));
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta {
|
||||
delta,
|
||||
content_index,
|
||||
}))) => {
|
||||
this.cumulative_reasoning.push_str(&delta);
|
||||
if matches!(this.mode, AggregateMode::Streaming) {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta {
|
||||
delta,
|
||||
content_index,
|
||||
})));
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta { .. }))) => continue,
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded { .. }))) => {
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item))));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait AggregateStreamExt {
|
||||
fn aggregate(self) -> AggregatedStream;
|
||||
|
||||
fn streaming_mode(self) -> ResponseStream;
|
||||
}
|
||||
|
||||
impl AggregateStreamExt for ResponseStream {
|
||||
fn aggregate(self) -> AggregatedStream {
|
||||
AggregatedStream::new(self, AggregateMode::AggregatedOnly)
|
||||
}
|
||||
|
||||
fn streaming_mode(self) -> ResponseStream {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl AggregatedStream {
|
||||
fn new(inner: ResponseStream, mode: AggregateMode) -> Self {
|
||||
AggregatedStream {
|
||||
inner,
|
||||
cumulative: String::new(),
|
||||
cumulative_reasoning: String::new(),
|
||||
pending: VecDeque::new(),
|
||||
mode,
|
||||
}
|
||||
}
|
||||
}
|
||||
162
codex-rs/codex-api/src/endpoint/compact.rs
Normal file
162
codex-rs/codex-api/src/endpoint/compact.rs
Normal file
@@ -0,0 +1,162 @@
|
||||
use crate::auth::AuthProvider;
|
||||
use crate::auth::add_auth_headers;
|
||||
use crate::common::CompactionInput;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::provider::WireApi;
|
||||
use crate::telemetry::run_with_request_telemetry;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::RequestTelemetry;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use http::HeaderMap;
|
||||
use http::Method;
|
||||
use serde::Deserialize;
|
||||
use serde_json::to_value;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct CompactClient<T: HttpTransport, A: AuthProvider> {
|
||||
transport: T,
|
||||
provider: Provider,
|
||||
auth: A,
|
||||
request_telemetry: Option<Arc<dyn RequestTelemetry>>,
|
||||
}
|
||||
|
||||
impl<T: HttpTransport, A: AuthProvider> CompactClient<T, A> {
|
||||
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
|
||||
Self {
|
||||
transport,
|
||||
provider,
|
||||
auth,
|
||||
request_telemetry: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_telemetry(mut self, request: Option<Arc<dyn RequestTelemetry>>) -> Self {
|
||||
self.request_telemetry = request;
|
||||
self
|
||||
}
|
||||
|
||||
fn path(&self) -> Result<&'static str, ApiError> {
|
||||
match self.provider.wire {
|
||||
WireApi::Compact | WireApi::Responses => Ok("responses/compact"),
|
||||
WireApi::Chat => Err(ApiError::Stream(
|
||||
"compact endpoint requires responses wire api".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn compact(
|
||||
&self,
|
||||
body: serde_json::Value,
|
||||
extra_headers: HeaderMap,
|
||||
) -> Result<Vec<ResponseItem>, ApiError> {
|
||||
let path = self.path()?;
|
||||
let builder = || {
|
||||
let mut req = self.provider.build_request(Method::POST, path);
|
||||
req.headers.extend(extra_headers.clone());
|
||||
req.body = Some(body.clone());
|
||||
add_auth_headers(&self.auth, req)
|
||||
};
|
||||
|
||||
let resp = run_with_request_telemetry(
|
||||
self.provider.retry.to_policy(),
|
||||
self.request_telemetry.clone(),
|
||||
builder,
|
||||
|req| self.transport.execute(req),
|
||||
)
|
||||
.await?;
|
||||
let parsed: CompactHistoryResponse =
|
||||
serde_json::from_slice(&resp.body).map_err(|e| ApiError::Stream(e.to_string()))?;
|
||||
Ok(parsed.output)
|
||||
}
|
||||
|
||||
pub async fn compact_input(
|
||||
&self,
|
||||
input: &CompactionInput<'_>,
|
||||
extra_headers: HeaderMap,
|
||||
) -> Result<Vec<ResponseItem>, ApiError> {
|
||||
let body = to_value(input)
|
||||
.map_err(|e| ApiError::Stream(format!("failed to encode compaction input: {e}")))?;
|
||||
self.compact(body, extra_headers).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CompactHistoryResponse {
|
||||
output: Vec<ResponseItem>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::provider::RetryConfig;
|
||||
use async_trait::async_trait;
|
||||
use codex_client::Request;
|
||||
use codex_client::Response;
|
||||
use codex_client::StreamResponse;
|
||||
use codex_client::TransportError;
|
||||
use http::HeaderMap;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct DummyTransport;
|
||||
|
||||
#[async_trait]
|
||||
impl HttpTransport for DummyTransport {
|
||||
async fn execute(&self, _req: Request) -> Result<Response, TransportError> {
|
||||
Err(TransportError::Build("execute should not run".to_string()))
|
||||
}
|
||||
|
||||
async fn stream(&self, _req: Request) -> Result<StreamResponse, TransportError> {
|
||||
Err(TransportError::Build("stream should not run".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct DummyAuth;
|
||||
|
||||
impl AuthProvider for DummyAuth {
|
||||
fn bearer_token(&self) -> Option<String> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn provider(wire: WireApi) -> Provider {
|
||||
Provider {
|
||||
name: "test".to_string(),
|
||||
base_url: "https://example.com/v1".to_string(),
|
||||
query_params: None,
|
||||
wire,
|
||||
headers: HeaderMap::new(),
|
||||
retry: RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(1),
|
||||
retry_429: false,
|
||||
retry_5xx: true,
|
||||
retry_transport: true,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_secs(1),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn errors_when_wire_is_chat() {
|
||||
let client = CompactClient::new(DummyTransport, provider(WireApi::Chat), DummyAuth);
|
||||
let input = CompactionInput {
|
||||
model: "gpt-test",
|
||||
input: &[],
|
||||
instructions: "inst",
|
||||
};
|
||||
let err = client
|
||||
.compact_input(&input, HeaderMap::new())
|
||||
.await
|
||||
.expect_err("expected wire mismatch to fail");
|
||||
|
||||
match err {
|
||||
ApiError::Stream(msg) => {
|
||||
assert_eq!(msg, "compact endpoint requires responses wire api");
|
||||
}
|
||||
other => panic!("unexpected error: {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
4
codex-rs/codex-api/src/endpoint/mod.rs
Normal file
4
codex-rs/codex-api/src/endpoint/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
pub mod chat;
|
||||
pub mod compact;
|
||||
pub mod responses;
|
||||
mod streaming;
|
||||
107
codex-rs/codex-api/src/endpoint/responses.rs
Normal file
107
codex-rs/codex-api/src/endpoint/responses.rs
Normal file
@@ -0,0 +1,107 @@
|
||||
use crate::auth::AuthProvider;
|
||||
use crate::common::Prompt as ApiPrompt;
|
||||
use crate::common::Reasoning;
|
||||
use crate::common::ResponseStream;
|
||||
use crate::common::TextControls;
|
||||
use crate::endpoint::streaming::StreamingClient;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::provider::WireApi;
|
||||
use crate::requests::ResponsesRequest;
|
||||
use crate::requests::ResponsesRequestBuilder;
|
||||
use crate::sse::spawn_response_stream;
|
||||
use crate::telemetry::SseTelemetry;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::RequestTelemetry;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use http::HeaderMap;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct ResponsesClient<T: HttpTransport, A: AuthProvider> {
|
||||
streaming: StreamingClient<T, A>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ResponsesOptions {
|
||||
pub reasoning: Option<Reasoning>,
|
||||
pub include: Vec<String>,
|
||||
pub prompt_cache_key: Option<String>,
|
||||
pub text: Option<TextControls>,
|
||||
pub store_override: Option<bool>,
|
||||
pub conversation_id: Option<String>,
|
||||
pub session_source: Option<SessionSource>,
|
||||
}
|
||||
|
||||
impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
|
||||
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
|
||||
Self {
|
||||
streaming: StreamingClient::new(transport, provider, auth),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_telemetry(
|
||||
self,
|
||||
request: Option<Arc<dyn RequestTelemetry>>,
|
||||
sse: Option<Arc<dyn SseTelemetry>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
streaming: self.streaming.with_telemetry(request, sse),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream_request(
|
||||
&self,
|
||||
request: ResponsesRequest,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
self.stream(request.body, request.headers).await
|
||||
}
|
||||
|
||||
pub async fn stream_prompt(
|
||||
&self,
|
||||
model: &str,
|
||||
prompt: &ApiPrompt,
|
||||
options: ResponsesOptions,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
let ResponsesOptions {
|
||||
reasoning,
|
||||
include,
|
||||
prompt_cache_key,
|
||||
text,
|
||||
store_override,
|
||||
conversation_id,
|
||||
session_source,
|
||||
} = options;
|
||||
|
||||
let request = ResponsesRequestBuilder::new(model, &prompt.instructions, &prompt.input)
|
||||
.tools(&prompt.tools)
|
||||
.parallel_tool_calls(prompt.parallel_tool_calls)
|
||||
.reasoning(reasoning)
|
||||
.include(include)
|
||||
.prompt_cache_key(prompt_cache_key)
|
||||
.text(text)
|
||||
.conversation(conversation_id)
|
||||
.session_source(session_source)
|
||||
.store_override(store_override)
|
||||
.build(self.streaming.provider())?;
|
||||
|
||||
self.stream_request(request).await
|
||||
}
|
||||
|
||||
fn path(&self) -> &'static str {
|
||||
match self.streaming.provider().wire {
|
||||
WireApi::Responses | WireApi::Compact => "responses",
|
||||
WireApi::Chat => "chat/completions",
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream(
|
||||
&self,
|
||||
body: Value,
|
||||
extra_headers: HeaderMap,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
self.streaming
|
||||
.stream(self.path(), body, extra_headers, spawn_response_stream)
|
||||
.await
|
||||
}
|
||||
}
|
||||
82
codex-rs/codex-api/src/endpoint/streaming.rs
Normal file
82
codex-rs/codex-api/src/endpoint/streaming.rs
Normal file
@@ -0,0 +1,82 @@
|
||||
use crate::auth::AuthProvider;
|
||||
use crate::auth::add_auth_headers;
|
||||
use crate::common::ResponseStream;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::telemetry::SseTelemetry;
|
||||
use crate::telemetry::run_with_request_telemetry;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::RequestTelemetry;
|
||||
use codex_client::StreamResponse;
|
||||
use http::HeaderMap;
|
||||
use http::Method;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
pub(crate) struct StreamingClient<T: HttpTransport, A: AuthProvider> {
|
||||
transport: T,
|
||||
provider: Provider,
|
||||
auth: A,
|
||||
request_telemetry: Option<Arc<dyn RequestTelemetry>>,
|
||||
sse_telemetry: Option<Arc<dyn SseTelemetry>>,
|
||||
}
|
||||
|
||||
impl<T: HttpTransport, A: AuthProvider> StreamingClient<T, A> {
|
||||
pub(crate) fn new(transport: T, provider: Provider, auth: A) -> Self {
|
||||
Self {
|
||||
transport,
|
||||
provider,
|
||||
auth,
|
||||
request_telemetry: None,
|
||||
sse_telemetry: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn with_telemetry(
|
||||
mut self,
|
||||
request: Option<Arc<dyn RequestTelemetry>>,
|
||||
sse: Option<Arc<dyn SseTelemetry>>,
|
||||
) -> Self {
|
||||
self.request_telemetry = request;
|
||||
self.sse_telemetry = sse;
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn provider(&self) -> &Provider {
|
||||
&self.provider
|
||||
}
|
||||
|
||||
pub(crate) async fn stream(
|
||||
&self,
|
||||
path: &str,
|
||||
body: Value,
|
||||
extra_headers: HeaderMap,
|
||||
spawner: fn(StreamResponse, Duration, Option<Arc<dyn SseTelemetry>>) -> ResponseStream,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
let builder = || {
|
||||
let mut req = self.provider.build_request(Method::POST, path);
|
||||
req.headers.extend(extra_headers.clone());
|
||||
req.headers.insert(
|
||||
http::header::ACCEPT,
|
||||
http::HeaderValue::from_static("text/event-stream"),
|
||||
);
|
||||
req.body = Some(body.clone());
|
||||
add_auth_headers(&self.auth, req)
|
||||
};
|
||||
|
||||
let stream_response = run_with_request_telemetry(
|
||||
self.provider.retry.to_policy(),
|
||||
self.request_telemetry.clone(),
|
||||
builder,
|
||||
|req| self.transport.stream(req),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(spawner(
|
||||
stream_response,
|
||||
self.provider.stream_idle_timeout,
|
||||
self.sse_telemetry.clone(),
|
||||
))
|
||||
}
|
||||
}
|
||||
34
codex-rs/codex-api/src/error.rs
Normal file
34
codex-rs/codex-api/src/error.rs
Normal file
@@ -0,0 +1,34 @@
|
||||
use crate::rate_limits::RateLimitError;
|
||||
use codex_client::TransportError;
|
||||
use http::StatusCode;
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ApiError {
|
||||
#[error(transparent)]
|
||||
Transport(#[from] TransportError),
|
||||
#[error("api error {status}: {message}")]
|
||||
Api { status: StatusCode, message: String },
|
||||
#[error("stream error: {0}")]
|
||||
Stream(String),
|
||||
#[error("context window exceeded")]
|
||||
ContextWindowExceeded,
|
||||
#[error("quota exceeded")]
|
||||
QuotaExceeded,
|
||||
#[error("usage not included")]
|
||||
UsageNotIncluded,
|
||||
#[error("retryable error: {message}")]
|
||||
Retryable {
|
||||
message: String,
|
||||
delay: Option<Duration>,
|
||||
},
|
||||
#[error("rate limit: {0}")]
|
||||
RateLimit(String),
|
||||
}
|
||||
|
||||
impl From<RateLimitError> for ApiError {
|
||||
fn from(err: RateLimitError) -> Self {
|
||||
Self::RateLimit(err.to_string())
|
||||
}
|
||||
}
|
||||
35
codex-rs/codex-api/src/lib.rs
Normal file
35
codex-rs/codex-api/src/lib.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
pub mod auth;
|
||||
pub mod common;
|
||||
pub mod endpoint;
|
||||
pub mod error;
|
||||
pub mod provider;
|
||||
pub mod rate_limits;
|
||||
pub mod requests;
|
||||
pub mod sse;
|
||||
pub mod telemetry;
|
||||
|
||||
pub use codex_client::RequestTelemetry;
|
||||
pub use codex_client::ReqwestTransport;
|
||||
pub use codex_client::TransportError;
|
||||
|
||||
pub use crate::auth::AuthProvider;
|
||||
pub use crate::common::CompactionInput;
|
||||
pub use crate::common::Prompt;
|
||||
pub use crate::common::ResponseEvent;
|
||||
pub use crate::common::ResponseStream;
|
||||
pub use crate::common::ResponsesApiRequest;
|
||||
pub use crate::common::create_text_param_for_request;
|
||||
pub use crate::endpoint::chat::AggregateStreamExt;
|
||||
pub use crate::endpoint::chat::ChatClient;
|
||||
pub use crate::endpoint::compact::CompactClient;
|
||||
pub use crate::endpoint::responses::ResponsesClient;
|
||||
pub use crate::endpoint::responses::ResponsesOptions;
|
||||
pub use crate::error::ApiError;
|
||||
pub use crate::provider::Provider;
|
||||
pub use crate::provider::WireApi;
|
||||
pub use crate::requests::ChatRequest;
|
||||
pub use crate::requests::ChatRequestBuilder;
|
||||
pub use crate::requests::ResponsesRequest;
|
||||
pub use crate::requests::ResponsesRequestBuilder;
|
||||
pub use crate::sse::stream_from_fixture;
|
||||
pub use crate::telemetry::SseTelemetry;
|
||||
118
codex-rs/codex-api/src/provider.rs
Normal file
118
codex-rs/codex-api/src/provider.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
use codex_client::Request;
|
||||
use codex_client::RetryOn;
|
||||
use codex_client::RetryPolicy;
|
||||
use http::Method;
|
||||
use http::header::HeaderMap;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Wire-level APIs supported by a `Provider`.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum WireApi {
|
||||
Responses,
|
||||
Chat,
|
||||
Compact,
|
||||
}
|
||||
|
||||
/// High-level retry configuration for a provider.
|
||||
///
|
||||
/// This is converted into a `RetryPolicy` used by `codex-client` to drive
|
||||
/// transport-level retries for both unary and streaming calls.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetryConfig {
|
||||
pub max_attempts: u64,
|
||||
pub base_delay: Duration,
|
||||
pub retry_429: bool,
|
||||
pub retry_5xx: bool,
|
||||
pub retry_transport: bool,
|
||||
}
|
||||
|
||||
impl RetryConfig {
|
||||
pub fn to_policy(&self) -> RetryPolicy {
|
||||
RetryPolicy {
|
||||
max_attempts: self.max_attempts,
|
||||
base_delay: self.base_delay,
|
||||
retry_on: RetryOn {
|
||||
retry_429: self.retry_429,
|
||||
retry_5xx: self.retry_5xx,
|
||||
retry_transport: self.retry_transport,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// HTTP endpoint configuration used to talk to a concrete API deployment.
|
||||
///
|
||||
/// Encapsulates base URL, default headers, query params, retry policy, and
|
||||
/// stream idle timeout, plus helper methods for building requests.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Provider {
|
||||
pub name: String,
|
||||
pub base_url: String,
|
||||
pub query_params: Option<HashMap<String, String>>,
|
||||
pub wire: WireApi,
|
||||
pub headers: HeaderMap,
|
||||
pub retry: RetryConfig,
|
||||
pub stream_idle_timeout: Duration,
|
||||
}
|
||||
|
||||
impl Provider {
|
||||
pub fn url_for_path(&self, path: &str) -> String {
|
||||
let base = self.base_url.trim_end_matches('/');
|
||||
let path = path.trim_start_matches('/');
|
||||
let mut url = if path.is_empty() {
|
||||
base.to_string()
|
||||
} else {
|
||||
format!("{base}/{path}")
|
||||
};
|
||||
|
||||
if let Some(params) = &self.query_params
|
||||
&& !params.is_empty()
|
||||
{
|
||||
let qs = params
|
||||
.iter()
|
||||
.map(|(k, v)| format!("{k}={v}"))
|
||||
.collect::<Vec<_>>()
|
||||
.join("&");
|
||||
url.push('?');
|
||||
url.push_str(&qs);
|
||||
}
|
||||
|
||||
url
|
||||
}
|
||||
|
||||
pub fn build_request(&self, method: Method, path: &str) -> Request {
|
||||
Request {
|
||||
method,
|
||||
url: self.url_for_path(path),
|
||||
headers: self.headers.clone(),
|
||||
body: None,
|
||||
timeout: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_azure_responses_endpoint(&self) -> bool {
|
||||
if self.wire != WireApi::Responses {
|
||||
return false;
|
||||
}
|
||||
|
||||
if self.name.eq_ignore_ascii_case("azure") {
|
||||
return true;
|
||||
}
|
||||
|
||||
self.base_url.to_ascii_lowercase().contains("openai.azure.")
|
||||
|| matches_azure_responses_base_url(&self.base_url)
|
||||
}
|
||||
}
|
||||
|
||||
fn matches_azure_responses_base_url(base_url: &str) -> bool {
|
||||
const AZURE_MARKERS: [&str; 5] = [
|
||||
"cognitiveservices.azure.",
|
||||
"aoai.azure.",
|
||||
"azure-api.",
|
||||
"azurefd.",
|
||||
"windows.net/openai",
|
||||
];
|
||||
let base = base_url.to_ascii_lowercase();
|
||||
AZURE_MARKERS.iter().any(|marker| base.contains(marker))
|
||||
}
|
||||
105
codex-rs/codex-api/src/rate_limits.rs
Normal file
105
codex-rs/codex-api/src/rate_limits.rs
Normal file
@@ -0,0 +1,105 @@
|
||||
use codex_protocol::protocol::CreditsSnapshot;
|
||||
use codex_protocol::protocol::RateLimitSnapshot;
|
||||
use codex_protocol::protocol::RateLimitWindow;
|
||||
use http::HeaderMap;
|
||||
use std::fmt::Display;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RateLimitError {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
impl Display for RateLimitError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.message)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parses the bespoke Codex rate-limit headers into a `RateLimitSnapshot`.
|
||||
pub fn parse_rate_limit(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
|
||||
let primary = parse_rate_limit_window(
|
||||
headers,
|
||||
"x-codex-primary-used-percent",
|
||||
"x-codex-primary-window-minutes",
|
||||
"x-codex-primary-reset-at",
|
||||
);
|
||||
|
||||
let secondary = parse_rate_limit_window(
|
||||
headers,
|
||||
"x-codex-secondary-used-percent",
|
||||
"x-codex-secondary-window-minutes",
|
||||
"x-codex-secondary-reset-at",
|
||||
);
|
||||
|
||||
let credits = parse_credits_snapshot(headers);
|
||||
|
||||
Some(RateLimitSnapshot {
|
||||
primary,
|
||||
secondary,
|
||||
credits,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_rate_limit_window(
|
||||
headers: &HeaderMap,
|
||||
used_percent_header: &str,
|
||||
window_minutes_header: &str,
|
||||
resets_at_header: &str,
|
||||
) -> Option<RateLimitWindow> {
|
||||
let used_percent: Option<f64> = parse_header_f64(headers, used_percent_header);
|
||||
|
||||
used_percent.and_then(|used_percent| {
|
||||
let window_minutes = parse_header_i64(headers, window_minutes_header);
|
||||
let resets_at = parse_header_i64(headers, resets_at_header);
|
||||
|
||||
let has_data = used_percent != 0.0
|
||||
|| window_minutes.is_some_and(|minutes| minutes != 0)
|
||||
|| resets_at.is_some();
|
||||
|
||||
has_data.then_some(RateLimitWindow {
|
||||
used_percent,
|
||||
window_minutes,
|
||||
resets_at,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_credits_snapshot(headers: &HeaderMap) -> Option<CreditsSnapshot> {
|
||||
let has_credits = parse_header_bool(headers, "x-codex-credits-has-credits")?;
|
||||
let unlimited = parse_header_bool(headers, "x-codex-credits-unlimited")?;
|
||||
let balance = parse_header_str(headers, "x-codex-credits-balance")
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.map(std::string::ToString::to_string);
|
||||
Some(CreditsSnapshot {
|
||||
has_credits,
|
||||
unlimited,
|
||||
balance,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_header_f64(headers: &HeaderMap, name: &str) -> Option<f64> {
|
||||
parse_header_str(headers, name)?
|
||||
.parse::<f64>()
|
||||
.ok()
|
||||
.filter(|v| v.is_finite())
|
||||
}
|
||||
|
||||
fn parse_header_i64(headers: &HeaderMap, name: &str) -> Option<i64> {
|
||||
parse_header_str(headers, name)?.parse::<i64>().ok()
|
||||
}
|
||||
|
||||
fn parse_header_bool(headers: &HeaderMap, name: &str) -> Option<bool> {
|
||||
let raw = parse_header_str(headers, name)?;
|
||||
if raw.eq_ignore_ascii_case("true") || raw == "1" {
|
||||
Some(true)
|
||||
} else if raw.eq_ignore_ascii_case("false") || raw == "0" {
|
||||
Some(false)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> {
|
||||
headers.get(name)?.to_str().ok()
|
||||
}
|
||||
388
codex-rs/codex-api/src/requests/chat.rs
Normal file
388
codex-rs/codex-api/src/requests/chat.rs
Normal file
@@ -0,0 +1,388 @@
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::requests::headers::build_conversation_headers;
|
||||
use crate::requests::headers::insert_header;
|
||||
use crate::requests::headers::subagent_header;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::FunctionCallOutputContentItem;
|
||||
use codex_protocol::models::ReasoningItemContent;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use http::HeaderMap;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Assembled request body plus headers for Chat Completions streaming calls.
|
||||
pub struct ChatRequest {
|
||||
pub body: Value,
|
||||
pub headers: HeaderMap,
|
||||
}
|
||||
|
||||
pub struct ChatRequestBuilder<'a> {
|
||||
model: &'a str,
|
||||
instructions: &'a str,
|
||||
input: &'a [ResponseItem],
|
||||
tools: &'a [Value],
|
||||
conversation_id: Option<String>,
|
||||
session_source: Option<SessionSource>,
|
||||
}
|
||||
|
||||
impl<'a> ChatRequestBuilder<'a> {
|
||||
pub fn new(
|
||||
model: &'a str,
|
||||
instructions: &'a str,
|
||||
input: &'a [ResponseItem],
|
||||
tools: &'a [Value],
|
||||
) -> Self {
|
||||
Self {
|
||||
model,
|
||||
instructions,
|
||||
input,
|
||||
tools,
|
||||
conversation_id: None,
|
||||
session_source: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn conversation_id(mut self, id: Option<String>) -> Self {
|
||||
self.conversation_id = id;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn session_source(mut self, source: Option<SessionSource>) -> Self {
|
||||
self.session_source = source;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self, _provider: &Provider) -> Result<ChatRequest, ApiError> {
|
||||
let mut messages = Vec::<Value>::new();
|
||||
messages.push(json!({"role": "system", "content": self.instructions}));
|
||||
|
||||
let input = self.input;
|
||||
let mut reasoning_by_anchor_index: HashMap<usize, String> = HashMap::new();
|
||||
let mut last_emitted_role: Option<&str> = None;
|
||||
for item in input {
|
||||
match item {
|
||||
ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()),
|
||||
ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
|
||||
last_emitted_role = Some("assistant")
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"),
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
|
||||
ResponseItem::CustomToolCall { .. } => {}
|
||||
ResponseItem::CustomToolCallOutput { .. } => {}
|
||||
ResponseItem::WebSearchCall { .. } => {}
|
||||
ResponseItem::GhostSnapshot { .. } => {}
|
||||
ResponseItem::CompactionSummary { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
let mut last_user_index: Option<usize> = None;
|
||||
for (idx, item) in input.iter().enumerate() {
|
||||
if let ResponseItem::Message { role, .. } = item
|
||||
&& role == "user"
|
||||
{
|
||||
last_user_index = Some(idx);
|
||||
}
|
||||
}
|
||||
|
||||
if !matches!(last_emitted_role, Some("user")) {
|
||||
for (idx, item) in input.iter().enumerate() {
|
||||
if let Some(u_idx) = last_user_index
|
||||
&& idx <= u_idx
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if let ResponseItem::Reasoning {
|
||||
content: Some(items),
|
||||
..
|
||||
} = item
|
||||
{
|
||||
let mut text = String::new();
|
||||
for entry in items {
|
||||
match entry {
|
||||
ReasoningItemContent::ReasoningText { text: segment }
|
||||
| ReasoningItemContent::Text { text: segment } => {
|
||||
text.push_str(segment)
|
||||
}
|
||||
}
|
||||
}
|
||||
if text.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut attached = false;
|
||||
if idx > 0
|
||||
&& let ResponseItem::Message { role, .. } = &input[idx - 1]
|
||||
&& role == "assistant"
|
||||
{
|
||||
reasoning_by_anchor_index
|
||||
.entry(idx - 1)
|
||||
.and_modify(|v| v.push_str(&text))
|
||||
.or_insert(text.clone());
|
||||
attached = true;
|
||||
}
|
||||
|
||||
if !attached && idx + 1 < input.len() {
|
||||
match &input[idx + 1] {
|
||||
ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::LocalShellCall { .. } => {
|
||||
reasoning_by_anchor_index
|
||||
.entry(idx + 1)
|
||||
.and_modify(|v| v.push_str(&text))
|
||||
.or_insert(text.clone());
|
||||
}
|
||||
ResponseItem::Message { role, .. } if role == "assistant" => {
|
||||
reasoning_by_anchor_index
|
||||
.entry(idx + 1)
|
||||
.and_modify(|v| v.push_str(&text))
|
||||
.or_insert(text.clone());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut last_assistant_text: Option<String> = None;
|
||||
|
||||
for (idx, item) in input.iter().enumerate() {
|
||||
match item {
|
||||
ResponseItem::Message { role, content, .. } => {
|
||||
let mut text = String::new();
|
||||
let mut items: Vec<Value> = Vec::new();
|
||||
let mut saw_image = false;
|
||||
|
||||
for c in content {
|
||||
match c {
|
||||
ContentItem::InputText { text: t }
|
||||
| ContentItem::OutputText { text: t } => {
|
||||
text.push_str(t);
|
||||
items.push(json!({"type":"text","text": t}));
|
||||
}
|
||||
ContentItem::InputImage { image_url } => {
|
||||
saw_image = true;
|
||||
items.push(
|
||||
json!({"type":"image_url","image_url": {"url": image_url}}),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if role == "assistant" {
|
||||
if let Some(prev) = &last_assistant_text
|
||||
&& prev == &text
|
||||
{
|
||||
continue;
|
||||
}
|
||||
last_assistant_text = Some(text.clone());
|
||||
}
|
||||
|
||||
let content_value = if role == "assistant" {
|
||||
json!(text)
|
||||
} else if saw_image {
|
||||
json!(items)
|
||||
} else {
|
||||
json!(text)
|
||||
};
|
||||
|
||||
let mut msg = json!({"role": role, "content": content_value});
|
||||
if role == "assistant"
|
||||
&& let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
|
||||
&& let Some(obj) = msg.as_object_mut()
|
||||
{
|
||||
obj.insert("reasoning".to_string(), json!(reasoning));
|
||||
}
|
||||
messages.push(msg);
|
||||
}
|
||||
ResponseItem::FunctionCall {
|
||||
name,
|
||||
arguments,
|
||||
call_id,
|
||||
..
|
||||
} => {
|
||||
let mut msg = json!({
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"arguments": arguments,
|
||||
}
|
||||
}]
|
||||
});
|
||||
if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
|
||||
&& let Some(obj) = msg.as_object_mut()
|
||||
{
|
||||
obj.insert("reasoning".to_string(), json!(reasoning));
|
||||
}
|
||||
messages.push(msg);
|
||||
}
|
||||
ResponseItem::LocalShellCall {
|
||||
id,
|
||||
call_id: _,
|
||||
status,
|
||||
action,
|
||||
} => {
|
||||
let mut msg = json!({
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{
|
||||
"id": id.clone().unwrap_or_default(),
|
||||
"type": "local_shell_call",
|
||||
"status": status,
|
||||
"action": action,
|
||||
}]
|
||||
});
|
||||
if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
|
||||
&& let Some(obj) = msg.as_object_mut()
|
||||
{
|
||||
obj.insert("reasoning".to_string(), json!(reasoning));
|
||||
}
|
||||
messages.push(msg);
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { call_id, output } => {
|
||||
let content_value = if let Some(items) = &output.content_items {
|
||||
let mapped: Vec<Value> = items
|
||||
.iter()
|
||||
.map(|it| match it {
|
||||
FunctionCallOutputContentItem::InputText { text } => {
|
||||
json!({"type":"text","text": text})
|
||||
}
|
||||
FunctionCallOutputContentItem::InputImage { image_url } => {
|
||||
json!({"type":"image_url","image_url": {"url": image_url}})
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
json!(mapped)
|
||||
} else {
|
||||
json!(output.content)
|
||||
};
|
||||
|
||||
messages.push(json!({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": content_value,
|
||||
}));
|
||||
}
|
||||
ResponseItem::CustomToolCall {
|
||||
id,
|
||||
call_id: _,
|
||||
name,
|
||||
input,
|
||||
status: _,
|
||||
} => {
|
||||
messages.push(json!({
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{
|
||||
"id": id,
|
||||
"type": "custom",
|
||||
"custom": {
|
||||
"name": name,
|
||||
"input": input,
|
||||
}
|
||||
}]
|
||||
}));
|
||||
}
|
||||
ResponseItem::CustomToolCallOutput { call_id, output } => {
|
||||
messages.push(json!({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": output,
|
||||
}));
|
||||
}
|
||||
ResponseItem::GhostSnapshot { .. } => {
|
||||
continue;
|
||||
}
|
||||
ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::WebSearchCall { .. }
|
||||
| ResponseItem::Other
|
||||
| ResponseItem::CompactionSummary { .. } => {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let payload = json!({
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"stream": true,
|
||||
"tools": self.tools,
|
||||
});
|
||||
|
||||
let mut headers = build_conversation_headers(self.conversation_id);
|
||||
if let Some(subagent) = subagent_header(&self.session_source) {
|
||||
insert_header(&mut headers, "x-openai-subagent", &subagent);
|
||||
}
|
||||
|
||||
Ok(ChatRequest {
|
||||
body: payload,
|
||||
headers,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::provider::RetryConfig;
|
||||
use crate::provider::WireApi;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::SubAgentSource;
|
||||
use http::HeaderValue;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::time::Duration;
|
||||
|
||||
fn provider() -> Provider {
|
||||
Provider {
|
||||
name: "openai".to_string(),
|
||||
base_url: "https://api.openai.com/v1".to_string(),
|
||||
query_params: None,
|
||||
wire: WireApi::Chat,
|
||||
headers: HeaderMap::new(),
|
||||
retry: RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(10),
|
||||
retry_429: false,
|
||||
retry_5xx: true,
|
||||
retry_transport: true,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_secs(1),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn attaches_conversation_and_subagent_headers() {
|
||||
let prompt_input = vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hi".to_string(),
|
||||
}],
|
||||
}];
|
||||
let req = ChatRequestBuilder::new("gpt-test", "inst", &prompt_input, &[])
|
||||
.conversation_id(Some("conv-1".into()))
|
||||
.session_source(Some(SessionSource::SubAgent(SubAgentSource::Review)))
|
||||
.build(&provider())
|
||||
.expect("request");
|
||||
|
||||
assert_eq!(
|
||||
req.headers.get("conversation_id"),
|
||||
Some(&HeaderValue::from_static("conv-1"))
|
||||
);
|
||||
assert_eq!(
|
||||
req.headers.get("session_id"),
|
||||
Some(&HeaderValue::from_static("conv-1"))
|
||||
);
|
||||
assert_eq!(
|
||||
req.headers.get("x-openai-subagent"),
|
||||
Some(&HeaderValue::from_static("review"))
|
||||
);
|
||||
}
|
||||
}
|
||||
36
codex-rs/codex-api/src/requests/headers.rs
Normal file
36
codex-rs/codex-api/src/requests/headers.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use http::HeaderMap;
|
||||
use http::HeaderValue;
|
||||
|
||||
pub(crate) fn build_conversation_headers(conversation_id: Option<String>) -> HeaderMap {
|
||||
let mut headers = HeaderMap::new();
|
||||
if let Some(id) = conversation_id {
|
||||
insert_header(&mut headers, "conversation_id", &id);
|
||||
insert_header(&mut headers, "session_id", &id);
|
||||
}
|
||||
headers
|
||||
}
|
||||
|
||||
pub(crate) fn subagent_header(source: &Option<SessionSource>) -> Option<String> {
|
||||
let SessionSource::SubAgent(sub) = source.as_ref()? else {
|
||||
return None;
|
||||
};
|
||||
match sub {
|
||||
codex_protocol::protocol::SubAgentSource::Other(label) => Some(label.clone()),
|
||||
other => Some(
|
||||
serde_json::to_value(other)
|
||||
.ok()
|
||||
.and_then(|v| v.as_str().map(std::string::ToString::to_string))
|
||||
.unwrap_or_else(|| "other".to_string()),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn insert_header(headers: &mut HeaderMap, name: &str, value: &str) {
|
||||
if let (Ok(header_name), Ok(header_value)) = (
|
||||
name.parse::<http::HeaderName>(),
|
||||
HeaderValue::from_str(value),
|
||||
) {
|
||||
headers.insert(header_name, header_value);
|
||||
}
|
||||
}
|
||||
8
codex-rs/codex-api/src/requests/mod.rs
Normal file
8
codex-rs/codex-api/src/requests/mod.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
pub mod chat;
|
||||
pub(crate) mod headers;
|
||||
pub mod responses;
|
||||
|
||||
pub use chat::ChatRequest;
|
||||
pub use chat::ChatRequestBuilder;
|
||||
pub use responses::ResponsesRequest;
|
||||
pub use responses::ResponsesRequestBuilder;
|
||||
247
codex-rs/codex-api/src/requests/responses.rs
Normal file
247
codex-rs/codex-api/src/requests/responses.rs
Normal file
@@ -0,0 +1,247 @@
|
||||
use crate::common::Reasoning;
|
||||
use crate::common::ResponsesApiRequest;
|
||||
use crate::common::TextControls;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::requests::headers::build_conversation_headers;
|
||||
use crate::requests::headers::insert_header;
|
||||
use crate::requests::headers::subagent_header;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use http::HeaderMap;
|
||||
use serde_json::Value;
|
||||
|
||||
/// Assembled request body plus headers for a Responses stream request.
|
||||
pub struct ResponsesRequest {
|
||||
pub body: Value,
|
||||
pub headers: HeaderMap,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ResponsesRequestBuilder<'a> {
|
||||
model: Option<&'a str>,
|
||||
instructions: Option<&'a str>,
|
||||
input: Option<&'a [ResponseItem]>,
|
||||
tools: Option<&'a [Value]>,
|
||||
parallel_tool_calls: bool,
|
||||
reasoning: Option<Reasoning>,
|
||||
include: Vec<String>,
|
||||
prompt_cache_key: Option<String>,
|
||||
text: Option<TextControls>,
|
||||
conversation_id: Option<String>,
|
||||
session_source: Option<SessionSource>,
|
||||
store_override: Option<bool>,
|
||||
headers: HeaderMap,
|
||||
}
|
||||
|
||||
impl<'a> ResponsesRequestBuilder<'a> {
|
||||
pub fn new(model: &'a str, instructions: &'a str, input: &'a [ResponseItem]) -> Self {
|
||||
Self {
|
||||
model: Some(model),
|
||||
instructions: Some(instructions),
|
||||
input: Some(input),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tools(mut self, tools: &'a [Value]) -> Self {
|
||||
self.tools = Some(tools);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn parallel_tool_calls(mut self, enabled: bool) -> Self {
|
||||
self.parallel_tool_calls = enabled;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn reasoning(mut self, reasoning: Option<Reasoning>) -> Self {
|
||||
self.reasoning = reasoning;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn include(mut self, include: Vec<String>) -> Self {
|
||||
self.include = include;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn prompt_cache_key(mut self, key: Option<String>) -> Self {
|
||||
self.prompt_cache_key = key;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn text(mut self, text: Option<TextControls>) -> Self {
|
||||
self.text = text;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn conversation(mut self, conversation_id: Option<String>) -> Self {
|
||||
self.conversation_id = conversation_id;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn session_source(mut self, source: Option<SessionSource>) -> Self {
|
||||
self.session_source = source;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn store_override(mut self, store: Option<bool>) -> Self {
|
||||
self.store_override = store;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn extra_headers(mut self, headers: HeaderMap) -> Self {
|
||||
self.headers = headers;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self, provider: &Provider) -> Result<ResponsesRequest, ApiError> {
|
||||
let model = self
|
||||
.model
|
||||
.ok_or_else(|| ApiError::Stream("missing model for responses request".into()))?;
|
||||
let instructions = self
|
||||
.instructions
|
||||
.ok_or_else(|| ApiError::Stream("missing instructions for responses request".into()))?;
|
||||
let input = self
|
||||
.input
|
||||
.ok_or_else(|| ApiError::Stream("missing input for responses request".into()))?;
|
||||
let tools = self.tools.unwrap_or_default();
|
||||
|
||||
let store = self
|
||||
.store_override
|
||||
.unwrap_or_else(|| provider.is_azure_responses_endpoint());
|
||||
|
||||
let req = ResponsesApiRequest {
|
||||
model,
|
||||
instructions,
|
||||
input,
|
||||
tools,
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: self.parallel_tool_calls,
|
||||
reasoning: self.reasoning,
|
||||
store,
|
||||
stream: true,
|
||||
include: self.include,
|
||||
prompt_cache_key: self.prompt_cache_key,
|
||||
text: self.text,
|
||||
};
|
||||
|
||||
let mut body = serde_json::to_value(&req)
|
||||
.map_err(|e| ApiError::Stream(format!("failed to encode responses request: {e}")))?;
|
||||
|
||||
if store && provider.is_azure_responses_endpoint() {
|
||||
attach_item_ids(&mut body, input);
|
||||
}
|
||||
|
||||
let mut headers = self.headers;
|
||||
headers.extend(build_conversation_headers(self.conversation_id));
|
||||
if let Some(subagent) = subagent_header(&self.session_source) {
|
||||
insert_header(&mut headers, "x-openai-subagent", &subagent);
|
||||
}
|
||||
|
||||
Ok(ResponsesRequest { body, headers })
|
||||
}
|
||||
}
|
||||
|
||||
fn attach_item_ids(payload_json: &mut Value, original_items: &[ResponseItem]) {
|
||||
let Some(input_value) = payload_json.get_mut("input") else {
|
||||
return;
|
||||
};
|
||||
let Value::Array(items) = input_value else {
|
||||
return;
|
||||
};
|
||||
|
||||
for (value, item) in items.iter_mut().zip(original_items.iter()) {
|
||||
if let ResponseItem::Reasoning { id, .. }
|
||||
| ResponseItem::Message { id: Some(id), .. }
|
||||
| ResponseItem::WebSearchCall { id: Some(id), .. }
|
||||
| ResponseItem::FunctionCall { id: Some(id), .. }
|
||||
| ResponseItem::LocalShellCall { id: Some(id), .. }
|
||||
| ResponseItem::CustomToolCall { id: Some(id), .. } = item
|
||||
{
|
||||
if id.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(obj) = value.as_object_mut() {
|
||||
obj.insert("id".to_string(), Value::String(id.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::provider::RetryConfig;
|
||||
use crate::provider::WireApi;
|
||||
use codex_protocol::protocol::SubAgentSource;
|
||||
use http::HeaderValue;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::time::Duration;
|
||||
|
||||
fn provider(name: &str, base_url: &str) -> Provider {
|
||||
Provider {
|
||||
name: name.to_string(),
|
||||
base_url: base_url.to_string(),
|
||||
query_params: None,
|
||||
wire: WireApi::Responses,
|
||||
headers: HeaderMap::new(),
|
||||
retry: RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(50),
|
||||
retry_429: false,
|
||||
retry_5xx: true,
|
||||
retry_transport: true,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_secs(5),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn azure_default_store_attaches_ids_and_headers() {
|
||||
let provider = provider("azure", "https://example.openai.azure.com/v1");
|
||||
let input = vec![
|
||||
ResponseItem::Message {
|
||||
id: Some("m1".into()),
|
||||
role: "assistant".into(),
|
||||
content: Vec::new(),
|
||||
},
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".into(),
|
||||
content: Vec::new(),
|
||||
},
|
||||
];
|
||||
|
||||
let request = ResponsesRequestBuilder::new("gpt-test", "inst", &input)
|
||||
.conversation(Some("conv-1".into()))
|
||||
.session_source(Some(SessionSource::SubAgent(SubAgentSource::Review)))
|
||||
.build(&provider)
|
||||
.expect("request");
|
||||
|
||||
assert_eq!(request.body.get("store"), Some(&Value::Bool(true)));
|
||||
|
||||
let ids: Vec<Option<String>> = request
|
||||
.body
|
||||
.get("input")
|
||||
.and_then(|v| v.as_array())
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.map(|item| item.get("id").and_then(|v| v.as_str().map(str::to_string)))
|
||||
.collect();
|
||||
assert_eq!(ids, vec![Some("m1".to_string()), None]);
|
||||
|
||||
assert_eq!(
|
||||
request.headers.get("conversation_id"),
|
||||
Some(&HeaderValue::from_static("conv-1"))
|
||||
);
|
||||
assert_eq!(
|
||||
request.headers.get("session_id"),
|
||||
Some(&HeaderValue::from_static("conv-1"))
|
||||
);
|
||||
assert_eq!(
|
||||
request.headers.get("x-openai-subagent"),
|
||||
Some(&HeaderValue::from_static("review"))
|
||||
);
|
||||
}
|
||||
}
|
||||
504
codex-rs/codex-api/src/sse/chat.rs
Normal file
504
codex-rs/codex-api/src/sse/chat.rs
Normal file
@@ -0,0 +1,504 @@
|
||||
use crate::common::ResponseEvent;
|
||||
use crate::common::ResponseStream;
|
||||
use crate::error::ApiError;
|
||||
use crate::telemetry::SseTelemetry;
|
||||
use codex_client::StreamResponse;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ReasoningItemContent;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures::Stream;
|
||||
use futures::StreamExt;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::Instant;
|
||||
use tokio::time::timeout;
|
||||
use tracing::debug;
|
||||
use tracing::trace;
|
||||
|
||||
pub(crate) fn spawn_chat_stream(
|
||||
stream_response: StreamResponse,
|
||||
idle_timeout: Duration,
|
||||
telemetry: Option<std::sync::Arc<dyn SseTelemetry>>,
|
||||
) -> ResponseStream {
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(1600);
|
||||
tokio::spawn(async move {
|
||||
process_chat_sse(stream_response.bytes, tx_event, idle_timeout, telemetry).await;
|
||||
});
|
||||
ResponseStream { rx_event }
|
||||
}
|
||||
|
||||
pub async fn process_chat_sse<S>(
|
||||
stream: S,
|
||||
tx_event: mpsc::Sender<Result<ResponseEvent, ApiError>>,
|
||||
idle_timeout: Duration,
|
||||
telemetry: Option<std::sync::Arc<dyn SseTelemetry>>,
|
||||
) where
|
||||
S: Stream<Item = Result<bytes::Bytes, codex_client::TransportError>> + Unpin,
|
||||
{
|
||||
let mut stream = stream.eventsource();
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
struct ToolCallState {
|
||||
name: Option<String>,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
let mut tool_calls: HashMap<String, ToolCallState> = HashMap::new();
|
||||
let mut tool_call_order: Vec<String> = Vec::new();
|
||||
let mut assistant_item: Option<ResponseItem> = None;
|
||||
let mut reasoning_item: Option<ResponseItem> = None;
|
||||
let mut completed_sent = false;
|
||||
|
||||
loop {
|
||||
let start = Instant::now();
|
||||
let response = timeout(idle_timeout, stream.next()).await;
|
||||
if let Some(t) = telemetry.as_ref() {
|
||||
t.on_sse_poll(&response, start.elapsed());
|
||||
}
|
||||
let sse = match response {
|
||||
Ok(Some(Ok(sse))) => sse,
|
||||
Ok(Some(Err(e))) => {
|
||||
let _ = tx_event.send(Err(ApiError::Stream(e.to_string()))).await;
|
||||
return;
|
||||
}
|
||||
Ok(None) => {
|
||||
if let Some(reasoning) = reasoning_item {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemDone(reasoning)))
|
||||
.await;
|
||||
}
|
||||
|
||||
if let Some(assistant) = assistant_item {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemDone(assistant)))
|
||||
.await;
|
||||
}
|
||||
if !completed_sent {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::Completed {
|
||||
response_id: String::new(),
|
||||
token_usage: None,
|
||||
}))
|
||||
.await;
|
||||
}
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
let _ = tx_event
|
||||
.send(Err(ApiError::Stream("idle timeout waiting for SSE".into())))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
trace!("SSE event: {}", sse.data);
|
||||
|
||||
if sse.data.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let value: serde_json::Value = match serde_json::from_str(&sse.data) {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
debug!(
|
||||
"Failed to parse ChatCompletions SSE event: {err}, data: {}",
|
||||
&sse.data
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let Some(choices) = value.get("choices").and_then(|c| c.as_array()) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
for choice in choices {
|
||||
if let Some(delta) = choice.get("delta") {
|
||||
if let Some(reasoning) = delta.get("reasoning") {
|
||||
if let Some(text) = reasoning.as_str() {
|
||||
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string())
|
||||
.await;
|
||||
} else if let Some(text) = reasoning.get("text").and_then(|v| v.as_str()) {
|
||||
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string())
|
||||
.await;
|
||||
} else if let Some(text) = reasoning.get("content").and_then(|v| v.as_str()) {
|
||||
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string())
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(content) = delta.get("content") {
|
||||
if content.is_array() {
|
||||
for item in content.as_array().unwrap_or(&vec![]) {
|
||||
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
|
||||
append_assistant_text(
|
||||
&tx_event,
|
||||
&mut assistant_item,
|
||||
text.to_string(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
} else if let Some(text) = content.as_str() {
|
||||
append_assistant_text(&tx_event, &mut assistant_item, text.to_string())
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(tool_call_values) = delta.get("tool_calls").and_then(|c| c.as_array()) {
|
||||
for tool_call in tool_call_values {
|
||||
let id = tool_call
|
||||
.get("id")
|
||||
.and_then(|i| i.as_str())
|
||||
.map(str::to_string)
|
||||
.unwrap_or_else(|| format!("tool-call-{}", tool_call_order.len()));
|
||||
|
||||
let call_state = tool_calls.entry(id.clone()).or_default();
|
||||
if !tool_call_order.contains(&id) {
|
||||
tool_call_order.push(id.clone());
|
||||
}
|
||||
|
||||
if let Some(func) = tool_call.get("function") {
|
||||
if let Some(fname) = func.get("name").and_then(|n| n.as_str()) {
|
||||
call_state.name = Some(fname.to_string());
|
||||
}
|
||||
if let Some(arguments) = func.get("arguments").and_then(|a| a.as_str())
|
||||
{
|
||||
call_state.arguments.push_str(arguments);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(message) = choice.get("message")
|
||||
&& let Some(reasoning) = message.get("reasoning")
|
||||
{
|
||||
if let Some(text) = reasoning.as_str() {
|
||||
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string()).await;
|
||||
} else if let Some(text) = reasoning.get("text").and_then(|v| v.as_str()) {
|
||||
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string()).await;
|
||||
} else if let Some(text) = reasoning.get("content").and_then(|v| v.as_str()) {
|
||||
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string()).await;
|
||||
}
|
||||
}
|
||||
|
||||
let finish_reason = choice.get("finish_reason").and_then(|r| r.as_str());
|
||||
if finish_reason == Some("stop") {
|
||||
if let Some(reasoning) = reasoning_item.take() {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemDone(reasoning)))
|
||||
.await;
|
||||
}
|
||||
|
||||
if let Some(assistant) = assistant_item.take() {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemDone(assistant)))
|
||||
.await;
|
||||
}
|
||||
if !completed_sent {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::Completed {
|
||||
response_id: String::new(),
|
||||
token_usage: None,
|
||||
}))
|
||||
.await;
|
||||
completed_sent = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if finish_reason == Some("length") {
|
||||
let _ = tx_event.send(Err(ApiError::ContextWindowExceeded)).await;
|
||||
return;
|
||||
}
|
||||
|
||||
if finish_reason == Some("tool_calls") {
|
||||
if let Some(reasoning) = reasoning_item.take() {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemDone(reasoning)))
|
||||
.await;
|
||||
}
|
||||
|
||||
for call_id in tool_call_order.drain(..) {
|
||||
let state = tool_calls.remove(&call_id).unwrap_or_default();
|
||||
let item = ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: state.name.unwrap_or_default(),
|
||||
arguments: state.arguments,
|
||||
call_id: call_id.clone(),
|
||||
};
|
||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn append_assistant_text(
|
||||
tx_event: &mpsc::Sender<Result<ResponseEvent, ApiError>>,
|
||||
assistant_item: &mut Option<ResponseItem>,
|
||||
text: String,
|
||||
) {
|
||||
if assistant_item.is_none() {
|
||||
let item = ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![],
|
||||
};
|
||||
*assistant_item = Some(item.clone());
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemAdded(item)))
|
||||
.await;
|
||||
}
|
||||
|
||||
if let Some(ResponseItem::Message { content, .. }) = assistant_item {
|
||||
content.push(ContentItem::OutputText { text: text.clone() });
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputTextDelta(text.clone())))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn append_reasoning_text(
|
||||
tx_event: &mpsc::Sender<Result<ResponseEvent, ApiError>>,
|
||||
reasoning_item: &mut Option<ResponseItem>,
|
||||
text: String,
|
||||
) {
|
||||
if reasoning_item.is_none() {
|
||||
let item = ResponseItem::Reasoning {
|
||||
id: String::new(),
|
||||
summary: Vec::new(),
|
||||
content: Some(vec![]),
|
||||
encrypted_content: None,
|
||||
};
|
||||
*reasoning_item = Some(item.clone());
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemAdded(item)))
|
||||
.await;
|
||||
}
|
||||
|
||||
if let Some(ResponseItem::Reasoning {
|
||||
content: Some(content),
|
||||
..
|
||||
}) = reasoning_item
|
||||
{
|
||||
let content_index = content.len() as i64;
|
||||
content.push(ReasoningItemContent::ReasoningText { text: text.clone() });
|
||||
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::ReasoningContentDelta {
|
||||
delta: text.clone(),
|
||||
content_index,
|
||||
}))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use assert_matches::assert_matches;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use futures::TryStreamExt;
|
||||
use serde_json::json;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::io::ReaderStream;
|
||||
|
||||
fn build_body(events: &[serde_json::Value]) -> String {
|
||||
let mut body = String::new();
|
||||
for e in events {
|
||||
body.push_str(&format!("event: message\ndata: {e}\n\n"));
|
||||
}
|
||||
body
|
||||
}
|
||||
|
||||
async fn collect_events(body: &str) -> Vec<ResponseEvent> {
|
||||
let reader = ReaderStream::new(std::io::Cursor::new(body.to_string()))
|
||||
.map_err(|err| codex_client::TransportError::Network(err.to_string()));
|
||||
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent, ApiError>>(16);
|
||||
tokio::spawn(process_chat_sse(
|
||||
reader,
|
||||
tx,
|
||||
Duration::from_millis(1000),
|
||||
None,
|
||||
));
|
||||
|
||||
let mut out = Vec::new();
|
||||
while let Some(ev) = rx.recv().await {
|
||||
out.push(ev.expect("stream error"));
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn emits_multiple_tool_calls() {
|
||||
let delta_a = json!({
|
||||
"choices": [{
|
||||
"delta": {
|
||||
"tool_calls": [{
|
||||
"id": "call_a",
|
||||
"function": { "name": "do_a", "arguments": "{\"foo\":1}" }
|
||||
}]
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
let delta_b = json!({
|
||||
"choices": [{
|
||||
"delta": {
|
||||
"tool_calls": [{
|
||||
"id": "call_b",
|
||||
"function": { "name": "do_b", "arguments": "{\"bar\":2}" }
|
||||
}]
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
let finish = json!({
|
||||
"choices": [{
|
||||
"finish_reason": "tool_calls"
|
||||
}]
|
||||
});
|
||||
|
||||
let body = build_body(&[delta_a, delta_b, finish]);
|
||||
let events = collect_events(&body).await;
|
||||
assert_eq!(events.len(), 3);
|
||||
|
||||
assert_matches!(
|
||||
&events[0],
|
||||
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. })
|
||||
if call_id == "call_a" && name == "do_a" && arguments == "{\"foo\":1}"
|
||||
);
|
||||
assert_matches!(
|
||||
&events[1],
|
||||
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. })
|
||||
if call_id == "call_b" && name == "do_b" && arguments == "{\"bar\":2}"
|
||||
);
|
||||
assert_matches!(events[2], ResponseEvent::Completed { .. });
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn concatenates_tool_call_arguments_across_deltas() {
|
||||
let delta_name = json!({
|
||||
"choices": [{
|
||||
"delta": {
|
||||
"tool_calls": [{
|
||||
"id": "call_a",
|
||||
"function": { "name": "do_a" }
|
||||
}]
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
let delta_args_1 = json!({
|
||||
"choices": [{
|
||||
"delta": {
|
||||
"tool_calls": [{
|
||||
"id": "call_a",
|
||||
"function": { "arguments": "{ \"foo\":" }
|
||||
}]
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
let delta_args_2 = json!({
|
||||
"choices": [{
|
||||
"delta": {
|
||||
"tool_calls": [{
|
||||
"id": "call_a",
|
||||
"function": { "arguments": "1}" }
|
||||
}]
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
let finish = json!({
|
||||
"choices": [{
|
||||
"finish_reason": "tool_calls"
|
||||
}]
|
||||
});
|
||||
|
||||
let body = build_body(&[delta_name, delta_args_1, delta_args_2, finish]);
|
||||
let events = collect_events(&body).await;
|
||||
assert_matches!(
|
||||
&events[..],
|
||||
[
|
||||
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. }),
|
||||
ResponseEvent::Completed { .. }
|
||||
] if call_id == "call_a" && name == "do_a" && arguments == "{ \"foo\":1}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn emits_tool_calls_even_when_content_and_reasoning_present() {
|
||||
let delta_content_and_tools = json!({
|
||||
"choices": [{
|
||||
"delta": {
|
||||
"content": [{"text": "hi"}],
|
||||
"reasoning": "because",
|
||||
"tool_calls": [{
|
||||
"id": "call_a",
|
||||
"function": { "name": "do_a", "arguments": "{}" }
|
||||
}]
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
let finish = json!({
|
||||
"choices": [{
|
||||
"finish_reason": "tool_calls"
|
||||
}]
|
||||
});
|
||||
|
||||
let body = build_body(&[delta_content_and_tools, finish]);
|
||||
let events = collect_events(&body).await;
|
||||
|
||||
assert_matches!(
|
||||
&events[..],
|
||||
[
|
||||
ResponseEvent::OutputItemAdded(ResponseItem::Reasoning { .. }),
|
||||
ResponseEvent::ReasoningContentDelta { .. },
|
||||
ResponseEvent::OutputItemAdded(ResponseItem::Message { .. }),
|
||||
ResponseEvent::OutputTextDelta(delta),
|
||||
ResponseEvent::OutputItemDone(ResponseItem::Reasoning { .. }),
|
||||
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, .. }),
|
||||
ResponseEvent::OutputItemDone(ResponseItem::Message { .. }),
|
||||
ResponseEvent::Completed { .. }
|
||||
] if delta == "hi" && call_id == "call_a" && name == "do_a"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn drops_partial_tool_calls_on_stop_finish_reason() {
|
||||
let delta_tool = json!({
|
||||
"choices": [{
|
||||
"delta": {
|
||||
"tool_calls": [{
|
||||
"id": "call_a",
|
||||
"function": { "name": "do_a", "arguments": "{}" }
|
||||
}]
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
let finish_stop = json!({
|
||||
"choices": [{
|
||||
"finish_reason": "stop"
|
||||
}]
|
||||
});
|
||||
|
||||
let body = build_body(&[delta_tool, finish_stop]);
|
||||
let events = collect_events(&body).await;
|
||||
|
||||
assert!(!events.iter().any(|ev| {
|
||||
matches!(
|
||||
ev,
|
||||
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { .. })
|
||||
)
|
||||
}));
|
||||
assert_matches!(events.last(), Some(ResponseEvent::Completed { .. }));
|
||||
}
|
||||
}
|
||||
6
codex-rs/codex-api/src/sse/mod.rs
Normal file
6
codex-rs/codex-api/src/sse/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
pub mod chat;
|
||||
pub mod responses;
|
||||
|
||||
pub use responses::process_sse;
|
||||
pub use responses::spawn_response_stream;
|
||||
pub use responses::stream_from_fixture;
|
||||
672
codex-rs/codex-api/src/sse/responses.rs
Normal file
672
codex-rs/codex-api/src/sse/responses.rs
Normal file
@@ -0,0 +1,672 @@
|
||||
use crate::common::ResponseEvent;
|
||||
use crate::common::ResponseStream;
|
||||
use crate::error::ApiError;
|
||||
use crate::rate_limits::parse_rate_limit;
|
||||
use crate::telemetry::SseTelemetry;
|
||||
use codex_client::ByteStream;
|
||||
use codex_client::StreamResponse;
|
||||
use codex_client::TransportError;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::TokenUsage;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures::StreamExt;
|
||||
use futures::TryStreamExt;
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use std::io::BufRead;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::Instant;
|
||||
use tokio::time::timeout;
|
||||
use tokio_util::io::ReaderStream;
|
||||
use tracing::debug;
|
||||
use tracing::trace;
|
||||
|
||||
/// Streams SSE events from an on-disk fixture for tests.
|
||||
pub fn stream_from_fixture(
|
||||
path: impl AsRef<Path>,
|
||||
idle_timeout: Duration,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
let file =
|
||||
std::fs::File::open(path.as_ref()).map_err(|err| ApiError::Stream(err.to_string()))?;
|
||||
let mut content = String::new();
|
||||
for line in std::io::BufReader::new(file).lines() {
|
||||
let line = line.map_err(|err| ApiError::Stream(err.to_string()))?;
|
||||
content.push_str(&line);
|
||||
content.push_str("\n\n");
|
||||
}
|
||||
|
||||
let reader = std::io::Cursor::new(content);
|
||||
let stream = ReaderStream::new(reader).map_err(|err| TransportError::Network(err.to_string()));
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(1600);
|
||||
tokio::spawn(process_sse(Box::pin(stream), tx_event, idle_timeout, None));
|
||||
Ok(ResponseStream { rx_event })
|
||||
}
|
||||
|
||||
pub fn spawn_response_stream(
|
||||
stream_response: StreamResponse,
|
||||
idle_timeout: Duration,
|
||||
telemetry: Option<Arc<dyn SseTelemetry>>,
|
||||
) -> ResponseStream {
|
||||
let rate_limits = parse_rate_limit(&stream_response.headers);
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(1600);
|
||||
tokio::spawn(async move {
|
||||
if let Some(snapshot) = rate_limits {
|
||||
let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await;
|
||||
}
|
||||
process_sse(stream_response.bytes, tx_event, idle_timeout, telemetry).await;
|
||||
});
|
||||
|
||||
ResponseStream { rx_event }
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
struct Error {
|
||||
r#type: Option<String>,
|
||||
code: Option<String>,
|
||||
message: Option<String>,
|
||||
plan_type: Option<String>,
|
||||
resets_at: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
struct ResponseCompleted {
|
||||
id: String,
|
||||
#[serde(default)]
|
||||
usage: Option<ResponseCompletedUsage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseCompletedUsage {
|
||||
input_tokens: i64,
|
||||
input_tokens_details: Option<ResponseCompletedInputTokensDetails>,
|
||||
output_tokens: i64,
|
||||
output_tokens_details: Option<ResponseCompletedOutputTokensDetails>,
|
||||
total_tokens: i64,
|
||||
}
|
||||
|
||||
impl From<ResponseCompletedUsage> for TokenUsage {
|
||||
fn from(val: ResponseCompletedUsage) -> Self {
|
||||
TokenUsage {
|
||||
input_tokens: val.input_tokens,
|
||||
cached_input_tokens: val
|
||||
.input_tokens_details
|
||||
.map(|d| d.cached_tokens)
|
||||
.unwrap_or(0),
|
||||
output_tokens: val.output_tokens,
|
||||
reasoning_output_tokens: val
|
||||
.output_tokens_details
|
||||
.map(|d| d.reasoning_tokens)
|
||||
.unwrap_or(0),
|
||||
total_tokens: val.total_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseCompletedInputTokensDetails {
|
||||
cached_tokens: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseCompletedOutputTokensDetails {
|
||||
reasoning_tokens: i64,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct SseEvent {
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
response: Option<Value>,
|
||||
item: Option<Value>,
|
||||
delta: Option<String>,
|
||||
summary_index: Option<i64>,
|
||||
content_index: Option<i64>,
|
||||
}
|
||||
|
||||
pub async fn process_sse(
|
||||
stream: ByteStream,
|
||||
tx_event: mpsc::Sender<Result<ResponseEvent, ApiError>>,
|
||||
idle_timeout: Duration,
|
||||
telemetry: Option<Arc<dyn SseTelemetry>>,
|
||||
) {
|
||||
let mut stream = stream.eventsource();
|
||||
let mut response_completed: Option<ResponseCompleted> = None;
|
||||
let mut response_error: Option<ApiError> = None;
|
||||
|
||||
loop {
|
||||
let start = Instant::now();
|
||||
let response = timeout(idle_timeout, stream.next()).await;
|
||||
if let Some(t) = telemetry.as_ref() {
|
||||
t.on_sse_poll(&response, start.elapsed());
|
||||
}
|
||||
let sse = match response {
|
||||
Ok(Some(Ok(sse))) => sse,
|
||||
Ok(Some(Err(e))) => {
|
||||
debug!("SSE Error: {e:#}");
|
||||
let _ = tx_event.send(Err(ApiError::Stream(e.to_string()))).await;
|
||||
return;
|
||||
}
|
||||
Ok(None) => {
|
||||
match response_completed.take() {
|
||||
Some(ResponseCompleted { id, usage }) => {
|
||||
let event = ResponseEvent::Completed {
|
||||
response_id: id,
|
||||
token_usage: usage.map(Into::into),
|
||||
};
|
||||
let _ = tx_event.send(Ok(event)).await;
|
||||
}
|
||||
None => {
|
||||
let error = response_error.unwrap_or(ApiError::Stream(
|
||||
"stream closed before response.completed".into(),
|
||||
));
|
||||
let _ = tx_event.send(Err(error)).await;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
let _ = tx_event
|
||||
.send(Err(ApiError::Stream("idle timeout waiting for SSE".into())))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let raw = sse.data.clone();
|
||||
trace!("SSE event: {raw}");
|
||||
|
||||
let event: SseEvent = match serde_json::from_str(&sse.data) {
|
||||
Ok(event) => event,
|
||||
Err(e) => {
|
||||
debug!("Failed to parse SSE event: {e}, data: {}", &sse.data);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
match event.kind.as_str() {
|
||||
"response.output_item.done" => {
|
||||
let Some(item_val) = event.item else { continue };
|
||||
let Ok(item) = serde_json::from_value::<ResponseItem>(item_val) else {
|
||||
debug!("failed to parse ResponseItem from output_item.done");
|
||||
continue;
|
||||
};
|
||||
|
||||
let event = ResponseEvent::OutputItemDone(item);
|
||||
if tx_event.send(Ok(event)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
"response.output_text.delta" => {
|
||||
if let Some(delta) = event.delta {
|
||||
let event = ResponseEvent::OutputTextDelta(delta);
|
||||
if tx_event.send(Ok(event)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
"response.reasoning_summary_text.delta" => {
|
||||
if let (Some(delta), Some(summary_index)) = (event.delta, event.summary_index) {
|
||||
let event = ResponseEvent::ReasoningSummaryDelta {
|
||||
delta,
|
||||
summary_index,
|
||||
};
|
||||
if tx_event.send(Ok(event)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
"response.reasoning_text.delta" => {
|
||||
if let (Some(delta), Some(content_index)) = (event.delta, event.content_index) {
|
||||
let event = ResponseEvent::ReasoningContentDelta {
|
||||
delta,
|
||||
content_index,
|
||||
};
|
||||
if tx_event.send(Ok(event)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
"response.created" => {
|
||||
if event.response.is_some() {
|
||||
let _ = tx_event.send(Ok(ResponseEvent::Created {})).await;
|
||||
}
|
||||
}
|
||||
"response.failed" => {
|
||||
if let Some(resp_val) = event.response {
|
||||
response_error =
|
||||
Some(ApiError::Stream("response.failed event received".into()));
|
||||
|
||||
if let Some(error) = resp_val.get("error")
|
||||
&& let Ok(error) = serde_json::from_value::<Error>(error.clone())
|
||||
{
|
||||
if is_context_window_error(&error) {
|
||||
response_error = Some(ApiError::ContextWindowExceeded);
|
||||
} else if is_quota_exceeded_error(&error) {
|
||||
response_error = Some(ApiError::QuotaExceeded);
|
||||
} else if is_usage_not_included(&error) {
|
||||
response_error = Some(ApiError::UsageNotIncluded);
|
||||
} else {
|
||||
let delay = try_parse_retry_after(&error);
|
||||
let message = error.message.clone().unwrap_or_default();
|
||||
response_error = Some(ApiError::Retryable { message, delay });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"response.completed" => {
|
||||
if let Some(resp_val) = event.response {
|
||||
match serde_json::from_value::<ResponseCompleted>(resp_val) {
|
||||
Ok(r) => {
|
||||
response_completed = Some(r);
|
||||
}
|
||||
Err(e) => {
|
||||
let error = format!("failed to parse ResponseCompleted: {e}");
|
||||
debug!(error);
|
||||
response_error = Some(ApiError::Stream(error));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
};
|
||||
}
|
||||
"response.output_item.added" => {
|
||||
let Some(item_val) = event.item else { continue };
|
||||
let Ok(item) = serde_json::from_value::<ResponseItem>(item_val) else {
|
||||
debug!("failed to parse ResponseItem from output_item.done");
|
||||
continue;
|
||||
};
|
||||
|
||||
let event = ResponseEvent::OutputItemAdded(item);
|
||||
if tx_event.send(Ok(event)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
"response.reasoning_summary_part.added" => {
|
||||
if let Some(summary_index) = event.summary_index {
|
||||
let event = ResponseEvent::ReasoningSummaryPartAdded { summary_index };
|
||||
if tx_event.send(Ok(event)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_parse_retry_after(err: &Error) -> Option<Duration> {
|
||||
if err.code.as_deref() != Some("rate_limit_exceeded") {
|
||||
return None;
|
||||
}
|
||||
|
||||
let re = rate_limit_regex();
|
||||
if let Some(message) = &err.message
|
||||
&& let Some(captures) = re.captures(message)
|
||||
{
|
||||
let seconds = captures.get(1);
|
||||
let unit = captures.get(2);
|
||||
|
||||
if let (Some(value), Some(unit)) = (seconds, unit) {
|
||||
let value = value.as_str().parse::<f64>().ok()?;
|
||||
let unit = unit.as_str().to_ascii_lowercase();
|
||||
|
||||
if unit == "s" || unit.starts_with("second") {
|
||||
return Some(Duration::from_secs_f64(value));
|
||||
} else if unit == "ms" {
|
||||
return Some(Duration::from_millis(value as u64));
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn is_context_window_error(error: &Error) -> bool {
|
||||
error.code.as_deref() == Some("context_length_exceeded")
|
||||
}
|
||||
|
||||
fn is_quota_exceeded_error(error: &Error) -> bool {
|
||||
error.code.as_deref() == Some("insufficient_quota")
|
||||
}
|
||||
|
||||
fn is_usage_not_included(error: &Error) -> bool {
|
||||
error.code.as_deref() == Some("usage_not_included")
|
||||
}
|
||||
|
||||
fn rate_limit_regex() -> &'static regex_lite::Regex {
|
||||
static RE: std::sync::OnceLock<regex_lite::Regex> = std::sync::OnceLock::new();
|
||||
#[expect(clippy::unwrap_used)]
|
||||
RE.get_or_init(|| {
|
||||
regex_lite::Regex::new(r"(?i)try again in\s*(\d+(?:\.\d+)?)\s*(s|ms|seconds?)").unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use assert_matches::assert_matches;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_test::io::Builder as IoBuilder;
|
||||
|
||||
async fn collect_events(chunks: &[&[u8]]) -> Vec<Result<ResponseEvent, ApiError>> {
|
||||
let mut builder = IoBuilder::new();
|
||||
for chunk in chunks {
|
||||
builder.read(chunk);
|
||||
}
|
||||
|
||||
let reader = builder.build();
|
||||
let stream =
|
||||
ReaderStream::new(reader).map_err(|err| TransportError::Network(err.to_string()));
|
||||
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent, ApiError>>(16);
|
||||
tokio::spawn(process_sse(Box::pin(stream), tx, idle_timeout(), None));
|
||||
|
||||
let mut events = Vec::new();
|
||||
while let Some(ev) = rx.recv().await {
|
||||
events.push(ev);
|
||||
}
|
||||
events
|
||||
}
|
||||
|
||||
async fn run_sse(events: Vec<serde_json::Value>) -> Vec<ResponseEvent> {
|
||||
let mut body = String::new();
|
||||
for e in events {
|
||||
let kind = e
|
||||
.get("type")
|
||||
.and_then(|v| v.as_str())
|
||||
.expect("fixture event missing type");
|
||||
if e.as_object().map(|o| o.len() == 1).unwrap_or(false) {
|
||||
body.push_str(&format!("event: {kind}\n\n"));
|
||||
} else {
|
||||
body.push_str(&format!("event: {kind}\ndata: {e}\n\n"));
|
||||
}
|
||||
}
|
||||
|
||||
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent, ApiError>>(8);
|
||||
let stream = ReaderStream::new(std::io::Cursor::new(body))
|
||||
.map_err(|err| TransportError::Network(err.to_string()));
|
||||
tokio::spawn(process_sse(Box::pin(stream), tx, idle_timeout(), None));
|
||||
|
||||
let mut out = Vec::new();
|
||||
while let Some(ev) = rx.recv().await {
|
||||
out.push(ev.expect("channel closed"));
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn idle_timeout() -> Duration {
|
||||
Duration::from_millis(1000)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parses_items_and_completed() {
|
||||
let item1 = json!({
|
||||
"type": "response.output_item.done",
|
||||
"item": {
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": "Hello"}]
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let item2 = json!({
|
||||
"type": "response.output_item.done",
|
||||
"item": {
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": "World"}]
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let completed = json!({
|
||||
"type": "response.completed",
|
||||
"response": { "id": "resp1" }
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n");
|
||||
let sse2 = format!("event: response.output_item.done\ndata: {item2}\n\n");
|
||||
let sse3 = format!("event: response.completed\ndata: {completed}\n\n");
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes(), sse2.as_bytes(), sse3.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 3);
|
||||
|
||||
assert_matches!(
|
||||
&events[0],
|
||||
Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. }))
|
||||
if role == "assistant"
|
||||
);
|
||||
|
||||
assert_matches!(
|
||||
&events[1],
|
||||
Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. }))
|
||||
if role == "assistant"
|
||||
);
|
||||
|
||||
match &events[2] {
|
||||
Ok(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
}) => {
|
||||
assert_eq!(response_id, "resp1");
|
||||
assert!(token_usage.is_none());
|
||||
}
|
||||
other => panic!("unexpected third event: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn error_when_missing_completed() {
|
||||
let item1 = json!({
|
||||
"type": "response.output_item.done",
|
||||
"item": {
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": "Hello"}]
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n");
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 2);
|
||||
|
||||
assert_matches!(events[0], Ok(ResponseEvent::OutputItemDone(_)));
|
||||
|
||||
match &events[1] {
|
||||
Err(ApiError::Stream(msg)) => {
|
||||
assert_eq!(msg, "stream closed before response.completed")
|
||||
}
|
||||
other => panic!("unexpected second event: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn error_when_error_event() {
|
||||
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_689bcf18d7f08194bf3440ba62fe05d803fee0cdac429894","object":"response","created_at":1755041560,"status":"failed","background":false,"error":{"code":"rate_limit_exceeded","message":"Rate limit reached for gpt-5.1 in organization org-AAA on tokens per min (TPM): Limit 30000, Used 22999, Requested 12528. Please try again in 11.054s. Visit https://platform.openai.com/account/rate-limits to learn more."}, "usage":null,"user":null,"metadata":{}}}"#;
|
||||
|
||||
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 1);
|
||||
|
||||
match &events[0] {
|
||||
Err(ApiError::Retryable { message, delay }) => {
|
||||
assert_eq!(
|
||||
message,
|
||||
"Rate limit reached for gpt-5.1 in organization org-AAA on tokens per min (TPM): Limit 30000, Used 22999, Requested 12528. Please try again in 11.054s. Visit https://platform.openai.com/account/rate-limits to learn more."
|
||||
);
|
||||
assert_eq!(*delay, Some(Duration::from_secs_f64(11.054)));
|
||||
}
|
||||
other => panic!("unexpected second event: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn context_window_error_is_fatal() {
|
||||
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_5c66275b97b9baef1ed95550adb3b7ec13b17aafd1d2f11b","object":"response","created_at":1759510079,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again."},"usage":null,"user":null,"metadata":{}}}"#;
|
||||
|
||||
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 1);
|
||||
|
||||
assert_matches!(events[0], Err(ApiError::ContextWindowExceeded));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn context_window_error_with_newline_is_fatal() {
|
||||
let raw_error = r#"{"type":"response.failed","sequence_number":4,"response":{"id":"resp_fatal_newline","object":"response","created_at":1759510080,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try\nagain."},"usage":null,"user":null,"metadata":{}}}"#;
|
||||
|
||||
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 1);
|
||||
|
||||
assert_matches!(events[0], Err(ApiError::ContextWindowExceeded));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn quota_exceeded_error_is_fatal() {
|
||||
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_fatal_quota","object":"response","created_at":1759771626,"status":"failed","background":false,"error":{"code":"insufficient_quota","message":"You exceeded your current quota, please check your plan and billing details. For more information on this error, read the docs: https://platform.openai.com/docs/guides/error-codes/api-errors."},"incomplete_details":null}}"#;
|
||||
|
||||
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 1);
|
||||
|
||||
assert_matches!(events[0], Err(ApiError::QuotaExceeded));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn table_driven_event_kinds() {
|
||||
struct TestCase {
|
||||
name: &'static str,
|
||||
event: serde_json::Value,
|
||||
expect_first: fn(&ResponseEvent) -> bool,
|
||||
expected_len: usize,
|
||||
}
|
||||
|
||||
fn is_created(ev: &ResponseEvent) -> bool {
|
||||
matches!(ev, ResponseEvent::Created)
|
||||
}
|
||||
fn is_output(ev: &ResponseEvent) -> bool {
|
||||
matches!(ev, ResponseEvent::OutputItemDone(_))
|
||||
}
|
||||
fn is_completed(ev: &ResponseEvent) -> bool {
|
||||
matches!(ev, ResponseEvent::Completed { .. })
|
||||
}
|
||||
|
||||
let completed = json!({
|
||||
"type": "response.completed",
|
||||
"response": {
|
||||
"id": "c",
|
||||
"usage": {
|
||||
"input_tokens": 0,
|
||||
"input_tokens_details": null,
|
||||
"output_tokens": 0,
|
||||
"output_tokens_details": null,
|
||||
"total_tokens": 0
|
||||
},
|
||||
"output": []
|
||||
}
|
||||
});
|
||||
|
||||
let cases = vec![
|
||||
TestCase {
|
||||
name: "created",
|
||||
event: json!({"type": "response.created", "response": {}}),
|
||||
expect_first: is_created,
|
||||
expected_len: 2,
|
||||
},
|
||||
TestCase {
|
||||
name: "output_item.done",
|
||||
event: json!({
|
||||
"type": "response.output_item.done",
|
||||
"item": {
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "output_text", "text": "hi"}
|
||||
]
|
||||
}
|
||||
}),
|
||||
expect_first: is_output,
|
||||
expected_len: 2,
|
||||
},
|
||||
TestCase {
|
||||
name: "unknown",
|
||||
event: json!({"type": "response.new_tool_event"}),
|
||||
expect_first: is_completed,
|
||||
expected_len: 1,
|
||||
},
|
||||
];
|
||||
|
||||
for case in cases {
|
||||
let mut evs = vec![case.event];
|
||||
evs.push(completed.clone());
|
||||
|
||||
let out = run_sse(evs).await;
|
||||
assert_eq!(out.len(), case.expected_len, "case {}", case.name);
|
||||
assert!(
|
||||
(case.expect_first)(&out[0]),
|
||||
"first event mismatch in case {}",
|
||||
case.name
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_try_parse_retry_after() {
|
||||
let err = Error {
|
||||
r#type: None,
|
||||
message: Some("Rate limit reached for gpt-5.1 in organization org- on tokens per min (TPM): Limit 1, Used 1, Requested 19304. Please try again in 28ms. Visit https://platform.openai.com/account/rate-limits to learn more.".to_string()),
|
||||
code: Some("rate_limit_exceeded".to_string()),
|
||||
plan_type: None,
|
||||
resets_at: None,
|
||||
};
|
||||
|
||||
let delay = try_parse_retry_after(&err);
|
||||
assert_eq!(delay, Some(Duration::from_millis(28)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_try_parse_retry_after_no_delay() {
|
||||
let err = Error {
|
||||
r#type: None,
|
||||
message: Some("Rate limit reached for gpt-5.1 in organization <ORG> on tokens per min (TPM): Limit 30000, Used 6899, Requested 24050. Please try again in 1.898s. Visit https://platform.openai.com/account/rate-limits to learn more.".to_string()),
|
||||
code: Some("rate_limit_exceeded".to_string()),
|
||||
plan_type: None,
|
||||
resets_at: None,
|
||||
};
|
||||
let delay = try_parse_retry_after(&err);
|
||||
assert_eq!(delay, Some(Duration::from_secs_f64(1.898)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_try_parse_retry_after_azure() {
|
||||
let err = Error {
|
||||
r#type: None,
|
||||
message: Some("Rate limit exceeded. Try again in 35 seconds.".to_string()),
|
||||
code: Some("rate_limit_exceeded".to_string()),
|
||||
plan_type: None,
|
||||
resets_at: None,
|
||||
};
|
||||
let delay = try_parse_retry_after(&err);
|
||||
assert_eq!(delay, Some(Duration::from_secs(35)));
|
||||
}
|
||||
}
|
||||
84
codex-rs/codex-api/src/telemetry.rs
Normal file
84
codex-rs/codex-api/src/telemetry.rs
Normal file
@@ -0,0 +1,84 @@
|
||||
use codex_client::Request;
|
||||
use codex_client::RequestTelemetry;
|
||||
use codex_client::Response;
|
||||
use codex_client::RetryPolicy;
|
||||
use codex_client::StreamResponse;
|
||||
use codex_client::TransportError;
|
||||
use codex_client::run_with_retry;
|
||||
use http::StatusCode;
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::time::Instant;
|
||||
|
||||
/// Generic telemetry.
|
||||
pub trait SseTelemetry: Send + Sync {
|
||||
fn on_sse_poll(
|
||||
&self,
|
||||
result: &Result<
|
||||
Option<
|
||||
Result<
|
||||
eventsource_stream::Event,
|
||||
eventsource_stream::EventStreamError<TransportError>,
|
||||
>,
|
||||
>,
|
||||
tokio::time::error::Elapsed,
|
||||
>,
|
||||
duration: Duration,
|
||||
);
|
||||
}
|
||||
|
||||
pub(crate) trait WithStatus {
|
||||
fn status(&self) -> StatusCode;
|
||||
}
|
||||
|
||||
fn http_status(err: &TransportError) -> Option<StatusCode> {
|
||||
match err {
|
||||
TransportError::Http { status, .. } => Some(*status),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
impl WithStatus for Response {
|
||||
fn status(&self) -> StatusCode {
|
||||
self.status
|
||||
}
|
||||
}
|
||||
|
||||
impl WithStatus for StreamResponse {
|
||||
fn status(&self) -> StatusCode {
|
||||
self.status
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn run_with_request_telemetry<T, F, Fut>(
|
||||
policy: RetryPolicy,
|
||||
telemetry: Option<Arc<dyn RequestTelemetry>>,
|
||||
make_request: impl FnMut() -> Request,
|
||||
send: F,
|
||||
) -> Result<T, TransportError>
|
||||
where
|
||||
T: WithStatus,
|
||||
F: Clone + Fn(Request) -> Fut,
|
||||
Fut: Future<Output = Result<T, TransportError>>,
|
||||
{
|
||||
// Wraps `run_with_retry` to attach per-attempt request telemetry for both
|
||||
// unary and streaming HTTP calls.
|
||||
run_with_retry(policy, make_request, move |req, attempt| {
|
||||
let telemetry = telemetry.clone();
|
||||
let send = send.clone();
|
||||
async move {
|
||||
let start = Instant::now();
|
||||
let result = send(req).await;
|
||||
if let Some(t) = telemetry.as_ref() {
|
||||
let (status, err) = match &result {
|
||||
Ok(resp) => (Some(resp.status()), None),
|
||||
Err(err) => (http_status(err), Some(err)),
|
||||
};
|
||||
t.on_request(attempt, status, err, start.elapsed());
|
||||
}
|
||||
result
|
||||
}
|
||||
})
|
||||
.await
|
||||
}
|
||||
Reference in New Issue
Block a user