mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
Move compression into the request builder
This commit is contained in:
@@ -8,6 +8,8 @@ use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::provider::RequestCompression;
|
||||
use crate::provider::WireApi;
|
||||
use crate::requests::body::encode_body;
|
||||
use crate::requests::body::insert_compression_headers;
|
||||
use crate::sse::chat::spawn_chat_stream;
|
||||
use crate::telemetry::SseTelemetry;
|
||||
use codex_client::HttpTransport;
|
||||
@@ -87,14 +89,11 @@ impl<T: HttpTransport, A: AuthProvider> ChatClient<T, A> {
|
||||
extra_headers: HeaderMap,
|
||||
request_compression: RequestCompression,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
let mut headers = extra_headers;
|
||||
insert_compression_headers(&mut headers, request_compression);
|
||||
let encoded_body = encode_body(&body, request_compression)?;
|
||||
self.streaming
|
||||
.stream(
|
||||
self.path(),
|
||||
body,
|
||||
extra_headers,
|
||||
request_compression,
|
||||
spawn_chat_stream,
|
||||
)
|
||||
.stream(self.path(), encoded_body, headers, spawn_chat_stream)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ use codex_client::HttpTransport;
|
||||
use codex_client::RequestTelemetry;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use http::HeaderMap;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use tracing::instrument;
|
||||
|
||||
@@ -57,10 +56,8 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
|
||||
pub async fn stream_request(
|
||||
&self,
|
||||
request: ResponsesRequest,
|
||||
request_compression: RequestCompression,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
self.stream(request.body, request.headers, request_compression)
|
||||
.await
|
||||
self.stream(request).await
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all, err)]
|
||||
@@ -93,9 +90,10 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
|
||||
.session_source(session_source)
|
||||
.store_override(store_override)
|
||||
.extra_headers(extra_headers)
|
||||
.request_compression(request_compression)
|
||||
.build(self.streaming.provider())?;
|
||||
|
||||
self.stream_request(request, request_compression).await
|
||||
self.stream_request(request).await
|
||||
}
|
||||
|
||||
fn path(&self) -> &'static str {
|
||||
@@ -105,18 +103,12 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream(
|
||||
&self,
|
||||
body: Value,
|
||||
extra_headers: HeaderMap,
|
||||
request_compression: RequestCompression,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
pub async fn stream(&self, request: ResponsesRequest) -> Result<ResponseStream, ApiError> {
|
||||
self.streaming
|
||||
.stream(
|
||||
self.path(),
|
||||
body,
|
||||
extra_headers,
|
||||
request_compression,
|
||||
request.body,
|
||||
request.headers,
|
||||
spawn_response_stream,
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -3,10 +3,8 @@ use crate::auth::add_auth_headers;
|
||||
use crate::common::ResponseStream;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::provider::RequestCompression;
|
||||
use crate::telemetry::SseTelemetry;
|
||||
use crate::telemetry::run_with_request_telemetry;
|
||||
use bytes::Bytes;
|
||||
use codex_client::Body;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::RequestTelemetry;
|
||||
@@ -15,14 +13,9 @@ use http::HeaderMap;
|
||||
use http::HeaderValue;
|
||||
use http::Method;
|
||||
use http::header::ACCEPT;
|
||||
use http::header::CONTENT_ENCODING;
|
||||
use http::header::CONTENT_TYPE;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
use tracing::info;
|
||||
use zstd::stream::encode_all;
|
||||
|
||||
pub(crate) struct StreamingClient<T: HttpTransport, A: AuthProvider> {
|
||||
transport: T,
|
||||
@@ -60,14 +53,10 @@ impl<T: HttpTransport, A: AuthProvider> StreamingClient<T, A> {
|
||||
pub(crate) async fn stream(
|
||||
&self,
|
||||
path: &str,
|
||||
body: Value,
|
||||
body: Body,
|
||||
extra_headers: HeaderMap,
|
||||
request_compression: RequestCompression,
|
||||
spawner: fn(StreamResponse, Duration, Option<Arc<dyn SseTelemetry>>) -> ResponseStream,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
let content_encoding = matches!(request_compression, RequestCompression::Zstd);
|
||||
let encoded_body = encode_body(&body, request_compression).map_err(ApiError::Stream)?;
|
||||
|
||||
let builder = || {
|
||||
let mut req = self.provider.build_request(Method::POST, path);
|
||||
req.headers.extend(extra_headers.clone());
|
||||
@@ -76,11 +65,7 @@ impl<T: HttpTransport, A: AuthProvider> StreamingClient<T, A> {
|
||||
req.headers
|
||||
.entry(CONTENT_TYPE)
|
||||
.or_insert_with(|| HeaderValue::from_static("application/json"));
|
||||
if content_encoding {
|
||||
req.headers
|
||||
.insert(CONTENT_ENCODING, HeaderValue::from_static("zstd"));
|
||||
}
|
||||
req.body = Some(encoded_body.clone());
|
||||
req.body = Some(body.clone());
|
||||
add_auth_headers(&self.auth, req)
|
||||
};
|
||||
|
||||
@@ -99,24 +84,3 @@ impl<T: HttpTransport, A: AuthProvider> StreamingClient<T, A> {
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_body(body: &Value, compression: RequestCompression) -> Result<Body, String> {
|
||||
match compression {
|
||||
RequestCompression::None => Ok(Body::Json(body.clone())),
|
||||
RequestCompression::Zstd => {
|
||||
let json = serde_json::to_vec(body)
|
||||
.map_err(|err| format!("failed to encode request body as json: {err}"))?;
|
||||
let started_at = Instant::now();
|
||||
let compressed = encode_all(json.as_slice(), 0)
|
||||
.map_err(|err| format!("failed to compress request body: {err}"))?;
|
||||
let elapsed = started_at.elapsed();
|
||||
info!(
|
||||
input_bytes = json.len(),
|
||||
output_bytes = compressed.len(),
|
||||
elapsed_ms = elapsed.as_millis(),
|
||||
"compressed request body"
|
||||
);
|
||||
Ok(Body::Bytes(Bytes::from(compressed)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
40
codex-rs/codex-api/src/requests/body.rs
Normal file
40
codex-rs/codex-api/src/requests/body.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::RequestCompression;
|
||||
use bytes::Bytes;
|
||||
use codex_client::Body;
|
||||
use http::HeaderMap;
|
||||
use http::HeaderValue;
|
||||
use http::header::CONTENT_ENCODING;
|
||||
use serde_json::Value;
|
||||
use std::time::Instant;
|
||||
use tracing::info;
|
||||
use zstd::stream::encode_all;
|
||||
|
||||
pub(crate) fn encode_body(body: &Value, compression: RequestCompression) -> Result<Body, ApiError> {
|
||||
match compression {
|
||||
RequestCompression::None => Ok(Body::Json(body.clone())),
|
||||
RequestCompression::Zstd => {
|
||||
let json = serde_json::to_vec(body).map_err(|err| {
|
||||
ApiError::Stream(format!("failed to encode request body as json: {err}"))
|
||||
})?;
|
||||
let started_at = Instant::now();
|
||||
let compressed = encode_all(json.as_slice(), 0).map_err(|err| {
|
||||
ApiError::Stream(format!("failed to compress request body: {err}"))
|
||||
})?;
|
||||
let elapsed = started_at.elapsed();
|
||||
info!(
|
||||
input_bytes = json.len(),
|
||||
output_bytes = compressed.len(),
|
||||
elapsed_ms = elapsed.as_millis(),
|
||||
"compressed request body"
|
||||
);
|
||||
Ok(Body::Bytes(Bytes::from(compressed)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn insert_compression_headers(headers: &mut HeaderMap, compression: RequestCompression) {
|
||||
if matches!(compression, RequestCompression::Zstd) {
|
||||
headers.insert(CONTENT_ENCODING, HeaderValue::from_static("zstd"));
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub(crate) mod body;
|
||||
pub mod chat;
|
||||
pub(crate) mod headers;
|
||||
pub mod responses;
|
||||
|
||||
@@ -3,9 +3,13 @@ use crate::common::ResponsesApiRequest;
|
||||
use crate::common::TextControls;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::provider::RequestCompression;
|
||||
use crate::requests::body::encode_body;
|
||||
use crate::requests::body::insert_compression_headers;
|
||||
use crate::requests::headers::build_conversation_headers;
|
||||
use crate::requests::headers::insert_header;
|
||||
use crate::requests::headers::subagent_header;
|
||||
use codex_client::Body;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use http::HeaderMap;
|
||||
@@ -13,7 +17,7 @@ use serde_json::Value;
|
||||
|
||||
/// Assembled request body plus headers for a Responses stream request.
|
||||
pub struct ResponsesRequest {
|
||||
pub body: Value,
|
||||
pub body: Body,
|
||||
pub headers: HeaderMap,
|
||||
}
|
||||
|
||||
@@ -32,6 +36,7 @@ pub struct ResponsesRequestBuilder<'a> {
|
||||
session_source: Option<SessionSource>,
|
||||
store_override: Option<bool>,
|
||||
headers: HeaderMap,
|
||||
request_compression: RequestCompression,
|
||||
}
|
||||
|
||||
impl<'a> ResponsesRequestBuilder<'a> {
|
||||
@@ -94,6 +99,11 @@ impl<'a> ResponsesRequestBuilder<'a> {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn request_compression(mut self, request_compression: RequestCompression) -> Self {
|
||||
self.request_compression = request_compression;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self, provider: &Provider) -> Result<ResponsesRequest, ApiError> {
|
||||
let model = self
|
||||
.model
|
||||
@@ -137,6 +147,8 @@ impl<'a> ResponsesRequestBuilder<'a> {
|
||||
if let Some(subagent) = subagent_header(&self.session_source) {
|
||||
insert_header(&mut headers, "x-openai-subagent", &subagent);
|
||||
}
|
||||
insert_compression_headers(&mut headers, self.request_compression);
|
||||
let body = encode_body(&body, self.request_compression)?;
|
||||
|
||||
Ok(ResponsesRequest { body, headers })
|
||||
}
|
||||
@@ -175,6 +187,7 @@ mod tests {
|
||||
|
||||
use crate::provider::RetryConfig;
|
||||
use crate::provider::WireApi;
|
||||
use codex_client::Body;
|
||||
use codex_protocol::protocol::SubAgentSource;
|
||||
use http::HeaderValue;
|
||||
use pretty_assertions::assert_eq;
|
||||
@@ -220,10 +233,12 @@ mod tests {
|
||||
.build(&provider)
|
||||
.expect("request");
|
||||
|
||||
assert_eq!(request.body.get("store"), Some(&Value::Bool(true)));
|
||||
let Body::Json(body) = &request.body else {
|
||||
panic!("expected json body for responses request");
|
||||
};
|
||||
assert_eq!(body.get("store"), Some(&Value::Bool(true)));
|
||||
|
||||
let ids: Vec<Option<String>> = request
|
||||
.body
|
||||
let ids: Vec<Option<String>> = body
|
||||
.get("input")
|
||||
.and_then(|v| v.as_array())
|
||||
.into_iter()
|
||||
|
||||
@@ -10,7 +10,9 @@ use codex_api::ChatClient;
|
||||
use codex_api::Provider;
|
||||
use codex_api::ResponsesClient;
|
||||
use codex_api::ResponsesOptions;
|
||||
use codex_api::ResponsesRequest;
|
||||
use codex_api::WireApi;
|
||||
use codex_client::Body;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::Request;
|
||||
use codex_client::Response;
|
||||
@@ -136,6 +138,13 @@ fn provider(name: &str, wire: WireApi) -> Provider {
|
||||
}
|
||||
}
|
||||
|
||||
fn responses_request(body: Value) -> ResponsesRequest {
|
||||
ResponsesRequest {
|
||||
body: Body::Json(body),
|
||||
headers: HeaderMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct FlakyTransport {
|
||||
state: Arc<Mutex<i64>>,
|
||||
@@ -232,10 +241,8 @@ async fn responses_client_uses_responses_path_for_responses_wire() -> Result<()>
|
||||
let transport = RecordingTransport::new(state.clone());
|
||||
let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), NoAuth);
|
||||
|
||||
let body = serde_json::json!({ "echo": true });
|
||||
let _stream = client
|
||||
.stream(body, HeaderMap::new(), Default::default())
|
||||
.await?;
|
||||
let request = responses_request(serde_json::json!({ "echo": true }));
|
||||
let _stream = client.stream(request).await?;
|
||||
|
||||
let requests = state.take_stream_requests();
|
||||
assert_path_ends_with(&requests, "/responses");
|
||||
@@ -248,10 +255,8 @@ async fn responses_client_uses_chat_path_for_chat_wire() -> Result<()> {
|
||||
let transport = RecordingTransport::new(state.clone());
|
||||
let client = ResponsesClient::new(transport, provider("openai", WireApi::Chat), NoAuth);
|
||||
|
||||
let body = serde_json::json!({ "echo": true });
|
||||
let _stream = client
|
||||
.stream(body, HeaderMap::new(), Default::default())
|
||||
.await?;
|
||||
let request = responses_request(serde_json::json!({ "echo": true }));
|
||||
let _stream = client.stream(request).await?;
|
||||
|
||||
let requests = state.take_stream_requests();
|
||||
assert_path_ends_with(&requests, "/chat/completions");
|
||||
@@ -265,10 +270,8 @@ async fn streaming_client_adds_auth_headers() -> Result<()> {
|
||||
let auth = StaticAuth::new("secret-token", "acct-1");
|
||||
let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), auth);
|
||||
|
||||
let body = serde_json::json!({ "model": "gpt-test" });
|
||||
let _stream = client
|
||||
.stream(body, HeaderMap::new(), Default::default())
|
||||
.await?;
|
||||
let request = responses_request(serde_json::json!({ "model": "gpt-test" }));
|
||||
let _stream = client.stream(request).await?;
|
||||
|
||||
let requests = state.take_stream_requests();
|
||||
assert_eq!(requests.len(), 1);
|
||||
|
||||
@@ -8,7 +8,9 @@ use codex_api::AuthProvider;
|
||||
use codex_api::Provider;
|
||||
use codex_api::ResponseEvent;
|
||||
use codex_api::ResponsesClient;
|
||||
use codex_api::ResponsesRequest;
|
||||
use codex_api::WireApi;
|
||||
use codex_client::Body;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::Request;
|
||||
use codex_client::Response;
|
||||
@@ -94,6 +96,13 @@ fn build_responses_body(events: Vec<Value>) -> String {
|
||||
body
|
||||
}
|
||||
|
||||
fn responses_request(body: Value) -> ResponsesRequest {
|
||||
ResponsesRequest {
|
||||
body: Body::Json(body),
|
||||
headers: HeaderMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn responses_stream_parses_items_and_completed_end_to_end() -> Result<()> {
|
||||
let item1 = serde_json::json!({
|
||||
@@ -123,13 +132,8 @@ async fn responses_stream_parses_items_and_completed_end_to_end() -> Result<()>
|
||||
let transport = FixtureSseTransport::new(body);
|
||||
let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), NoAuth);
|
||||
|
||||
let mut stream = client
|
||||
.stream(
|
||||
serde_json::json!({"echo": true}),
|
||||
HeaderMap::new(),
|
||||
Default::default(),
|
||||
)
|
||||
.await?;
|
||||
let request = responses_request(serde_json::json!({"echo": true}));
|
||||
let mut stream = client.stream(request).await?;
|
||||
|
||||
let mut events = Vec::new();
|
||||
while let Some(ev) = stream.next().await {
|
||||
@@ -192,13 +196,8 @@ async fn responses_stream_aggregates_output_text_deltas() -> Result<()> {
|
||||
let transport = FixtureSseTransport::new(body);
|
||||
let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), NoAuth);
|
||||
|
||||
let stream = client
|
||||
.stream(
|
||||
serde_json::json!({"echo": true}),
|
||||
HeaderMap::new(),
|
||||
Default::default(),
|
||||
)
|
||||
.await?;
|
||||
let request = responses_request(serde_json::json!({"echo": true}));
|
||||
let stream = client.stream(request).await?;
|
||||
|
||||
let mut stream = stream.aggregate();
|
||||
let mut events = Vec::new();
|
||||
|
||||
Reference in New Issue
Block a user