Compare commits

...

8 Commits

Author SHA1 Message Date
Channing Conger
e213999759 merge fix 2026-01-05 12:41:13 -08:00
Channing Conger
0124a95c83 Move compression into the request builder 2026-01-05 12:41:13 -08:00
Channing Conger
7361c1fea2 await on default_config_for_test 2026-01-05 12:41:13 -08:00
Channing Conger
f5adf8291a Request compression is no longer tied to the provider 2026-01-05 12:41:12 -08:00
Channing Conger
47366e8417 Revert to boolean for the feature since they only operate on booleans 2026-01-05 12:41:12 -08:00
Channing Conger
eef24681c8 Update to use Feature, with log line about time and compression ratio 2026-01-05 12:41:12 -08:00
Channing Conger
c069660811 Fix test 2026-01-05 12:41:12 -08:00
Channing Conger
55acaaffac Request compression.
Add new model_provider flag for compression to enable request
compression. We support zstd and gzip; the server also supports brotli.

You can test this against the sign-in-with-ChatGPT flow by adding the
following profile:

```
[profiles.compressed]
name = "compressed"
model_provider = "openai-zstd"

[model_providers.openai-zstd]
name = "OpenAI (ChatGPT, zstd)"
wire_api = "responses"
request_compression = "zstd"
requires_openai_auth = true
```

This will zstd compress your request before sending it to the server.
2026-01-05 12:41:08 -08:00
34 changed files with 571 additions and 119 deletions

45
codex-rs/Cargo.lock generated
View File

@@ -819,6 +819,8 @@ version = "1.2.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "deec109607ca693028562ed836a5f1c4b8bd77755c4e132fc5ce11b0b6211ae7"
dependencies = [
"jobserver",
"libc",
"shlex",
]
@@ -985,6 +987,7 @@ dependencies = [
"tokio-util",
"tracing",
"wiremock",
"zstd",
]
[[package]]
@@ -1348,6 +1351,7 @@ dependencies = [
"which",
"wildmatch",
"wiremock",
"zstd",
]
[[package]]
@@ -2109,16 +2113,19 @@ dependencies = [
"codex-protocol",
"codex-utils-absolute-path",
"codex-utils-cargo-bin",
"http 1.3.1",
"notify",
"pretty_assertions",
"regex-lite",
"reqwest",
"serde",
"serde_json",
"shlex",
"tempfile",
"tokio",
"walkdir",
"wiremock",
"zstd",
]
[[package]]
@@ -3924,6 +3931,16 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"
[[package]]
name = "jobserver"
version = "0.1.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33"
dependencies = [
"getrandom 0.3.3",
"libc",
]
[[package]]
name = "js-sys"
version = "0.3.77"
@@ -8809,6 +8826,34 @@ dependencies = [
"syn 2.0.104",
]
[[package]]
name = "zstd"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a"
dependencies = [
"zstd-safe",
]
[[package]]
name = "zstd-safe"
version = "7.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d"
dependencies = [
"zstd-sys",
]
[[package]]
name = "zstd-sys"
version = "2.0.16+zstd.1.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748"
dependencies = [
"cc",
"pkg-config",
]
[[package]]
name = "zune-core"
version = "0.4.12"

View File

@@ -235,6 +235,7 @@ wildmatch = "2.6.1"
wiremock = "0.6"
zeroize = "1.8.2"
zstd = "0.13"
[workspace.lints]
rust = {}

View File

@@ -19,6 +19,7 @@ tracing = { workspace = true }
eventsource-stream = { workspace = true }
regex-lite = { workspace = true }
tokio-util = { workspace = true, features = ["codec"] }
zstd = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }

View File

@@ -6,7 +6,10 @@ use crate::common::ResponseStream;
use crate::endpoint::streaming::StreamingClient;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::provider::RequestCompression;
use crate::provider::WireApi;
use crate::requests::body::encode_body;
use crate::requests::body::insert_compression_headers;
use crate::sse::chat::spawn_chat_stream;
use crate::telemetry::SseTelemetry;
use codex_client::HttpTransport;
@@ -45,8 +48,13 @@ impl<T: HttpTransport, A: AuthProvider> ChatClient<T, A> {
}
}
pub async fn stream_request(&self, request: ChatRequest) -> Result<ResponseStream, ApiError> {
self.stream(request.body, request.headers).await
pub async fn stream_request(
&self,
request: ChatRequest,
request_compression: RequestCompression,
) -> Result<ResponseStream, ApiError> {
self.stream(request.body, request.headers, request_compression)
.await
}
pub async fn stream_prompt(
@@ -55,6 +63,7 @@ impl<T: HttpTransport, A: AuthProvider> ChatClient<T, A> {
prompt: &ApiPrompt,
conversation_id: Option<String>,
session_source: Option<SessionSource>,
request_compression: RequestCompression,
) -> Result<ResponseStream, ApiError> {
use crate::requests::ChatRequestBuilder;
@@ -64,7 +73,7 @@ impl<T: HttpTransport, A: AuthProvider> ChatClient<T, A> {
.session_source(session_source)
.build(self.streaming.provider())?;
self.stream_request(request).await
self.stream_request(request, request_compression).await
}
fn path(&self) -> &'static str {
@@ -78,9 +87,13 @@ impl<T: HttpTransport, A: AuthProvider> ChatClient<T, A> {
&self,
body: Value,
extra_headers: HeaderMap,
request_compression: RequestCompression,
) -> Result<ResponseStream, ApiError> {
let mut headers = extra_headers;
insert_compression_headers(&mut headers, request_compression);
let encoded_body = encode_body(&body, request_compression)?;
self.streaming
.stream(self.path(), body, extra_headers, spawn_chat_stream)
.stream(self.path(), encoded_body, headers, spawn_chat_stream)
.await
}
}

View File

@@ -5,6 +5,7 @@ use crate::error::ApiError;
use crate::provider::Provider;
use crate::provider::WireApi;
use crate::telemetry::run_with_request_telemetry;
use codex_client::Body;
use codex_client::HttpTransport;
use codex_client::RequestTelemetry;
use codex_protocol::models::ResponseItem;
@@ -54,7 +55,7 @@ impl<T: HttpTransport, A: AuthProvider> CompactClient<T, A> {
let builder = || {
let mut req = self.provider.build_request(Method::POST, path);
req.headers.extend(extra_headers.clone());
req.body = Some(body.clone());
req.body = Some(Body::Json(body.clone()));
add_auth_headers(&self.auth, req)
};
@@ -89,6 +90,7 @@ struct CompactHistoryResponse {
#[cfg(test)]
mod tests {
use super::*;
use crate::provider::RetryConfig;
use async_trait::async_trait;
use codex_client::Request;

View File

@@ -6,6 +6,7 @@ use crate::common::TextControls;
use crate::endpoint::streaming::StreamingClient;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::provider::RequestCompression;
use crate::provider::WireApi;
use crate::requests::ResponsesRequest;
use crate::requests::ResponsesRequestBuilder;
@@ -15,7 +16,6 @@ use codex_client::HttpTransport;
use codex_client::RequestTelemetry;
use codex_protocol::protocol::SessionSource;
use http::HeaderMap;
use serde_json::Value;
use std::sync::Arc;
use tracing::instrument;
@@ -33,6 +33,7 @@ pub struct ResponsesOptions {
pub conversation_id: Option<String>,
pub session_source: Option<SessionSource>,
pub extra_headers: HeaderMap,
pub request_compression: RequestCompression,
}
impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
@@ -56,7 +57,7 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
&self,
request: ResponsesRequest,
) -> Result<ResponseStream, ApiError> {
self.stream(request.body, request.headers).await
self.stream(request).await
}
#[instrument(level = "trace", skip_all, err)]
@@ -75,6 +76,7 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
conversation_id,
session_source,
extra_headers,
request_compression,
} = options;
let request = ResponsesRequestBuilder::new(model, &prompt.instructions, &prompt.input)
@@ -88,6 +90,7 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
.session_source(session_source)
.store_override(store_override)
.extra_headers(extra_headers)
.request_compression(request_compression)
.build(self.streaming.provider())?;
self.stream_request(request).await
@@ -100,13 +103,14 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
}
}
pub async fn stream(
&self,
body: Value,
extra_headers: HeaderMap,
) -> Result<ResponseStream, ApiError> {
pub async fn stream(&self, request: ResponsesRequest) -> Result<ResponseStream, ApiError> {
self.streaming
.stream(self.path(), body, extra_headers, spawn_response_stream)
.stream(
self.path(),
request.body,
request.headers,
spawn_response_stream,
)
.await
}
}

View File

@@ -5,12 +5,15 @@ use crate::error::ApiError;
use crate::provider::Provider;
use crate::telemetry::SseTelemetry;
use crate::telemetry::run_with_request_telemetry;
use codex_client::Body;
use codex_client::HttpTransport;
use codex_client::RequestTelemetry;
use codex_client::StreamResponse;
use http::HeaderMap;
use http::HeaderValue;
use http::Method;
use serde_json::Value;
use http::header::ACCEPT;
use http::header::CONTENT_TYPE;
use std::sync::Arc;
use std::time::Duration;
@@ -50,17 +53,18 @@ impl<T: HttpTransport, A: AuthProvider> StreamingClient<T, A> {
pub(crate) async fn stream(
&self,
path: &str,
body: Value,
body: Body,
extra_headers: HeaderMap,
spawner: fn(StreamResponse, Duration, Option<Arc<dyn SseTelemetry>>) -> ResponseStream,
) -> Result<ResponseStream, ApiError> {
let builder = || {
let mut req = self.provider.build_request(Method::POST, path);
req.headers.extend(extra_headers.clone());
req.headers.insert(
http::header::ACCEPT,
http::HeaderValue::from_static("text/event-stream"),
);
req.headers
.insert(ACCEPT, HeaderValue::from_static("text/event-stream"));
req.headers
.entry(CONTENT_TYPE)
.or_insert_with(|| HeaderValue::from_static("application/json"));
req.body = Some(body.clone());
add_auth_headers(&self.auth, req)
};

View File

@@ -41,6 +41,13 @@ impl RetryConfig {
}
}
/// Compression applied to an outbound request body before it is sent.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum RequestCompression {
    /// Send the body as plain JSON (the default).
    #[default]
    None,
    /// zstd-compress the serialized body and set `Content-Encoding: zstd`.
    Zstd,
}
/// HTTP endpoint configuration used to talk to a concrete API deployment.
///
/// Encapsulates base URL, default headers, query params, retry policy, and

View File

@@ -0,0 +1,40 @@
use crate::error::ApiError;
use crate::provider::RequestCompression;
use bytes::Bytes;
use codex_client::Body;
use http::HeaderMap;
use http::HeaderValue;
use http::header::CONTENT_ENCODING;
use serde_json::Value;
use std::time::Instant;
use tracing::info;
use zstd::stream::encode_all;
/// Serialize `body` and, when requested, zstd-compress it into a transport `Body`.
///
/// For `RequestCompression::None` the JSON value is passed through untouched.
/// For `RequestCompression::Zstd` the serialized bytes are compressed at
/// zstd's default level (`0`) and the input/output sizes plus elapsed time
/// are logged so the compression ratio can be observed.
pub(crate) fn encode_body(body: &Value, compression: RequestCompression) -> Result<Body, ApiError> {
    let RequestCompression::Zstd = compression else {
        return Ok(Body::Json(body.clone()));
    };
    // Serialize first; compression operates on the raw JSON bytes.
    let raw = serde_json::to_vec(body)
        .map_err(|err| ApiError::Stream(format!("failed to encode request body as json: {err}")))?;
    let timer = Instant::now();
    let encoded = encode_all(raw.as_slice(), 0)
        .map_err(|err| ApiError::Stream(format!("failed to compress request body: {err}")))?;
    info!(
        input_bytes = raw.len(),
        output_bytes = encoded.len(),
        elapsed_ms = timer.elapsed().as_millis(),
        "compressed request body"
    );
    Ok(Body::Bytes(Bytes::from(encoded)))
}
/// Add the `Content-Encoding` header that matches the chosen body compression.
///
/// No header is added for `RequestCompression::None`.
pub(crate) fn insert_compression_headers(headers: &mut HeaderMap, compression: RequestCompression) {
    match compression {
        RequestCompression::Zstd => {
            headers.insert(CONTENT_ENCODING, HeaderValue::from_static("zstd"));
        }
        RequestCompression::None => {}
    }
}

View File

@@ -351,6 +351,7 @@ fn push_tool_call_message(messages: &mut Vec<Value>, tool_call: Value, reasoning
#[cfg(test)]
mod tests {
use super::*;
use crate::provider::RetryConfig;
use crate::provider::WireApi;
use codex_protocol::models::FunctionCallOutputPayload;

View File

@@ -1,3 +1,4 @@
pub(crate) mod body;
pub mod chat;
pub(crate) mod headers;
pub mod responses;

View File

@@ -3,9 +3,13 @@ use crate::common::ResponsesApiRequest;
use crate::common::TextControls;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::provider::RequestCompression;
use crate::requests::body::encode_body;
use crate::requests::body::insert_compression_headers;
use crate::requests::headers::build_conversation_headers;
use crate::requests::headers::insert_header;
use crate::requests::headers::subagent_header;
use codex_client::Body;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::SessionSource;
use http::HeaderMap;
@@ -13,7 +17,7 @@ use serde_json::Value;
/// Assembled request body plus headers for a Responses stream request.
pub struct ResponsesRequest {
pub body: Value,
pub body: Body,
pub headers: HeaderMap,
}
@@ -32,6 +36,7 @@ pub struct ResponsesRequestBuilder<'a> {
session_source: Option<SessionSource>,
store_override: Option<bool>,
headers: HeaderMap,
request_compression: RequestCompression,
}
impl<'a> ResponsesRequestBuilder<'a> {
@@ -94,6 +99,11 @@ impl<'a> ResponsesRequestBuilder<'a> {
self
}
pub fn request_compression(mut self, request_compression: RequestCompression) -> Self {
self.request_compression = request_compression;
self
}
pub fn build(self, provider: &Provider) -> Result<ResponsesRequest, ApiError> {
let model = self
.model
@@ -137,6 +147,8 @@ impl<'a> ResponsesRequestBuilder<'a> {
if let Some(subagent) = subagent_header(&self.session_source) {
insert_header(&mut headers, "x-openai-subagent", &subagent);
}
insert_compression_headers(&mut headers, self.request_compression);
let body = encode_body(&body, self.request_compression)?;
Ok(ResponsesRequest { body, headers })
}
@@ -172,8 +184,10 @@ fn attach_item_ids(payload_json: &mut Value, original_items: &[ResponseItem]) {
#[cfg(test)]
mod tests {
use super::*;
use crate::provider::RetryConfig;
use crate::provider::WireApi;
use codex_client::Body;
use codex_protocol::protocol::SubAgentSource;
use http::HeaderValue;
use pretty_assertions::assert_eq;
@@ -219,10 +233,12 @@ mod tests {
.build(&provider)
.expect("request");
assert_eq!(request.body.get("store"), Some(&Value::Bool(true)));
let Body::Json(body) = &request.body else {
panic!("expected json body for responses request");
};
assert_eq!(body.get("store"), Some(&Value::Bool(true)));
let ids: Vec<Option<String>> = request
.body
let ids: Vec<Option<String>> = body
.get("input")
.and_then(|v| v.as_array())
.into_iter()

View File

@@ -10,7 +10,9 @@ use codex_api::ChatClient;
use codex_api::Provider;
use codex_api::ResponsesClient;
use codex_api::ResponsesOptions;
use codex_api::ResponsesRequest;
use codex_api::WireApi;
use codex_client::Body;
use codex_client::HttpTransport;
use codex_client::Request;
use codex_client::Response;
@@ -136,6 +138,13 @@ fn provider(name: &str, wire: WireApi) -> Provider {
}
}
// Test helper: wrap a JSON value in a `ResponsesRequest` with empty headers.
fn responses_request(body: Value) -> ResponsesRequest {
    ResponsesRequest {
        body: Body::Json(body),
        headers: HeaderMap::new(),
    }
}
#[derive(Clone)]
struct FlakyTransport {
state: Arc<Mutex<i64>>,
@@ -201,7 +210,9 @@ async fn chat_client_uses_chat_completions_path_for_chat_wire() -> Result<()> {
let client = ChatClient::new(transport, provider("openai", WireApi::Chat), NoAuth);
let body = serde_json::json!({ "echo": true });
let _stream = client.stream(body, HeaderMap::new()).await?;
let _stream = client
.stream(body, HeaderMap::new(), Default::default())
.await?;
let requests = state.take_stream_requests();
assert_path_ends_with(&requests, "/chat/completions");
@@ -215,7 +226,9 @@ async fn chat_client_uses_responses_path_for_responses_wire() -> Result<()> {
let client = ChatClient::new(transport, provider("openai", WireApi::Responses), NoAuth);
let body = serde_json::json!({ "echo": true });
let _stream = client.stream(body, HeaderMap::new()).await?;
let _stream = client
.stream(body, HeaderMap::new(), Default::default())
.await?;
let requests = state.take_stream_requests();
assert_path_ends_with(&requests, "/responses");
@@ -228,8 +241,8 @@ async fn responses_client_uses_responses_path_for_responses_wire() -> Result<()>
let transport = RecordingTransport::new(state.clone());
let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), NoAuth);
let body = serde_json::json!({ "echo": true });
let _stream = client.stream(body, HeaderMap::new()).await?;
let request = responses_request(serde_json::json!({ "echo": true }));
let _stream = client.stream(request).await?;
let requests = state.take_stream_requests();
assert_path_ends_with(&requests, "/responses");
@@ -242,8 +255,8 @@ async fn responses_client_uses_chat_path_for_chat_wire() -> Result<()> {
let transport = RecordingTransport::new(state.clone());
let client = ResponsesClient::new(transport, provider("openai", WireApi::Chat), NoAuth);
let body = serde_json::json!({ "echo": true });
let _stream = client.stream(body, HeaderMap::new()).await?;
let request = responses_request(serde_json::json!({ "echo": true }));
let _stream = client.stream(request).await?;
let requests = state.take_stream_requests();
assert_path_ends_with(&requests, "/chat/completions");
@@ -257,8 +270,8 @@ async fn streaming_client_adds_auth_headers() -> Result<()> {
let auth = StaticAuth::new("secret-token", "acct-1");
let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), auth);
let body = serde_json::json!({ "model": "gpt-test" });
let _stream = client.stream(body, HeaderMap::new()).await?;
let request = responses_request(serde_json::json!({ "model": "gpt-test" }));
let _stream = client.stream(request).await?;
let requests = state.take_stream_requests();
assert_eq!(requests.len(), 1);

View File

@@ -8,7 +8,9 @@ use codex_api::AuthProvider;
use codex_api::Provider;
use codex_api::ResponseEvent;
use codex_api::ResponsesClient;
use codex_api::ResponsesRequest;
use codex_api::WireApi;
use codex_client::Body;
use codex_client::HttpTransport;
use codex_client::Request;
use codex_client::Response;
@@ -94,6 +96,13 @@ fn build_responses_body(events: Vec<Value>) -> String {
body
}
// Test helper: wrap a JSON value in a `ResponsesRequest` with empty headers.
fn responses_request(body: Value) -> ResponsesRequest {
    ResponsesRequest {
        body: Body::Json(body),
        headers: HeaderMap::new(),
    }
}
#[tokio::test]
async fn responses_stream_parses_items_and_completed_end_to_end() -> Result<()> {
let item1 = serde_json::json!({
@@ -123,9 +132,8 @@ async fn responses_stream_parses_items_and_completed_end_to_end() -> Result<()>
let transport = FixtureSseTransport::new(body);
let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), NoAuth);
let mut stream = client
.stream(serde_json::json!({"echo": true}), HeaderMap::new())
.await?;
let request = responses_request(serde_json::json!({"echo": true}));
let mut stream = client.stream(request).await?;
let mut events = Vec::new();
while let Some(ev) = stream.next().await {
@@ -188,9 +196,8 @@ async fn responses_stream_aggregates_output_text_deltas() -> Result<()> {
let transport = FixtureSseTransport::new(body);
let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), NoAuth);
let stream = client
.stream(serde_json::json!({"echo": true}), HeaderMap::new())
.await?;
let request = responses_request(serde_json::json!({"echo": true}));
let stream = client.stream(request).await?;
let mut stream = stream.aggregate();
let mut events = Vec::new();

View File

@@ -104,6 +104,13 @@ impl CodexRequestBuilder {
self.map(|builder| builder.json(value))
}
pub fn body<T>(self, body: T) -> Self
where
T: Into<reqwest::Body>,
{
self.map(|builder| builder.body(body))
}
pub async fn send(self) -> Result<Response, reqwest::Error> {
let headers = trace_headers();

View File

@@ -10,6 +10,7 @@ pub use crate::default_client::CodexHttpClient;
pub use crate::default_client::CodexRequestBuilder;
pub use crate::error::StreamError;
pub use crate::error::TransportError;
pub use crate::request::Body;
pub use crate::request::Request;
pub use crate::request::Response;
pub use crate::retry::RetryOn;

View File

@@ -5,12 +5,18 @@ use serde::Serialize;
use serde_json::Value;
use std::time::Duration;
/// Request payload in either structured or pre-encoded form.
#[derive(Debug, Clone)]
pub enum Body {
    /// JSON value serialized by the transport at send time.
    Json(Value),
    /// Raw bytes sent verbatim (e.g. a compressed JSON payload).
    Bytes(Bytes),
}
#[derive(Debug, Clone)]
pub struct Request {
pub method: Method,
pub url: String,
pub headers: HeaderMap,
pub body: Option<Value>,
pub body: Option<Body>,
pub timeout: Option<Duration>,
}
@@ -26,7 +32,7 @@ impl Request {
}
pub fn with_json<T: Serialize>(mut self, body: &T) -> Self {
self.body = serde_json::to_value(body).ok();
self.body = serde_json::to_value(body).ok().map(Body::Json);
self
}
}

View File

@@ -1,6 +1,7 @@
use crate::default_client::CodexHttpClient;
use crate::default_client::CodexRequestBuilder;
use crate::error::TransportError;
use crate::request::Body;
use crate::request::Request;
use crate::request::Response;
use async_trait::async_trait;
@@ -52,7 +53,10 @@ impl ReqwestTransport {
builder = builder.timeout(timeout);
}
if let Some(body) = req.body {
builder = builder.json(&body);
builder = match body {
Body::Json(value) => builder.json(&value),
Body::Bytes(bytes) => builder.body(bytes),
};
}
Ok(builder)
}
@@ -101,10 +105,10 @@ impl HttpTransport for ReqwestTransport {
async fn stream(&self, req: Request) -> Result<StreamResponse, TransportError> {
if enabled!(Level::TRACE) {
trace!(
"{} to {}: {}",
req.method,
req.url,
req.body.as_ref().unwrap_or_default()
method = %req.method,
url = %req.url,
body = ?req.body,
"Sending streaming request"
);
}

View File

@@ -89,6 +89,7 @@ url = { workspace = true }
uuid = { workspace = true, features = ["serde", "v4", "v5"] }
which = { workspace = true }
wildmatch = { workspace = true }
zstd = { workspace = true }
[features]
deterministic_process_ids = []

View File

@@ -156,6 +156,9 @@ impl ModelClient {
let mut refreshed = false;
loop {
let auth = auth_manager.as_ref().and_then(|m| m.auth());
let request_compression = self
.provider
.request_compression_for(auth.as_ref().map(|a| a.mode), &self.config.features);
let api_provider = self
.provider
.to_api_provider(auth.as_ref().map(|a| a.mode))?;
@@ -171,6 +174,7 @@ impl ModelClient {
&api_prompt,
Some(conversation_id.clone()),
Some(session_source.clone()),
request_compression,
)
.await;
@@ -245,6 +249,9 @@ impl ModelClient {
let mut refreshed = false;
loop {
let auth = auth_manager.as_ref().and_then(|m| m.auth());
let request_compression = self
.provider
.request_compression_for(auth.as_ref().map(|a| a.mode), &self.config.features);
let api_provider = self
.provider
.to_api_provider(auth.as_ref().map(|a| a.mode))?;
@@ -263,6 +270,7 @@ impl ModelClient {
conversation_id: Some(conversation_id.clone()),
session_source: Some(session_source.clone()),
extra_headers: beta_feature_headers(&self.config),
request_compression,
};
let stream_result = client

View File

@@ -24,7 +24,6 @@ use std::sync::OnceLock;
pub static USER_AGENT_SUFFIX: LazyLock<Mutex<Option<String>>> = LazyLock::new(|| Mutex::new(None));
pub const DEFAULT_ORIGINATOR: &str = "codex_cli_rs";
pub const CODEX_INTERNAL_ORIGINATOR_OVERRIDE_ENV_VAR: &str = "CODEX_INTERNAL_ORIGINATOR_OVERRIDE";
#[derive(Debug, Clone)]
pub struct Originator {
pub value: String,

View File

@@ -8,7 +8,6 @@
use crate::config::ConfigToml;
use crate::config::profile::ConfigProfile;
use serde::Deserialize;
use serde::Serialize;
use std::collections::BTreeMap;
use std::collections::BTreeSet;
@@ -74,6 +73,8 @@ pub enum Feature {
ApplyPatchFreeform,
/// Allow the model to request web searches.
WebSearchRequest,
/// Allow request body compression when using ChatGPT auth.
RequestCompression,
/// Gate the execpolicy enforcement for shell/unified exec.
ExecPolicy,
/// Enable Windows sandbox (restricted token) on Windows.
@@ -150,16 +151,16 @@ impl FeatureOverrides {
impl Features {
/// Starts with built-in defaults.
pub fn with_defaults() -> Self {
let mut set = BTreeSet::new();
let mut features = Self {
enabled: BTreeSet::new(),
legacy_usages: BTreeSet::new(),
};
for spec in FEATURES {
if spec.default_enabled {
set.insert(spec.id);
features.enable(spec.id);
}
}
Self {
enabled: set,
legacy_usages: BTreeSet::new(),
}
features
}
pub fn enabled(&self, f: Feature) -> bool {
@@ -196,7 +197,7 @@ impl Features {
.map(|usage| (usage.alias.as_str(), usage.feature))
}
/// Apply a table of key -> bool toggles (e.g. from TOML).
/// Apply a table of key -> value toggles (e.g. from TOML).
pub fn apply_map(&mut self, m: &BTreeMap<String, bool>) {
for (k, v) in m {
match feature_for_key(k) {
@@ -330,6 +331,12 @@ pub const FEATURES: &[FeatureSpec] = &[
stage: Stage::Stable,
default_enabled: false,
},
FeatureSpec {
id: Feature::RequestCompression,
key: "request_compression",
stage: Stage::Experimental,
default_enabled: false,
},
// Beta program. Rendered in the `/experimental` menu for users.
FeatureSpec {
id: Feature::UnifiedExec,

View File

@@ -19,6 +19,8 @@ use std::env::VarError;
use std::time::Duration;
use crate::error::EnvVarError;
use crate::features::Feature;
use crate::features::Features;
const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000;
const DEFAULT_STREAM_MAX_RETRIES: u64 = 5;
const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
@@ -253,6 +255,21 @@ impl ModelProviderInfo {
pub fn is_openai(&self) -> bool {
self.name == OPENAI_PROVIDER_NAME
}
/// Decide whether outbound request bodies should be compressed.
///
/// Zstd compression is used only for the built-in OpenAI provider when the
/// caller is authenticated via ChatGPT and the `request_compression` feature
/// flag is enabled; every other combination sends plain JSON.
pub fn request_compression_for(
    &self,
    auth_mode: Option<AuthMode>,
    features: &Features,
) -> codex_api::provider::RequestCompression {
    let chatgpt_auth = matches!(auth_mode, Some(AuthMode::ChatGPT));
    let feature_enabled = features.enabled(Feature::RequestCompression);
    if self.is_openai() && chatgpt_auth && feature_enabled {
        codex_api::provider::RequestCompression::Zstd
    } else {
        codex_api::provider::RequestCompression::None
    }
}
}
pub const DEFAULT_LMSTUDIO_PORT: u16 = 1234;

View File

@@ -15,14 +15,17 @@ codex-core = { workspace = true, features = ["test-support"] }
codex-protocol = { workspace = true }
codex-utils-absolute-path = { workspace = true }
codex-utils-cargo-bin = { workspace = true }
http = { workspace = true }
notify = { workspace = true }
regex-lite = { workspace = true }
serde_json = { workspace = true }
serde = { workspace = true }
tempfile = { workspace = true }
tokio = { workspace = true, features = ["time"] }
walkdir = { workspace = true }
wiremock = { workspace = true }
shlex = { workspace = true }
zstd = { workspace = true }
[dev-dependencies]
pretty_assertions = { workspace = true }

View File

@@ -11,11 +11,15 @@ use regex_lite::Regex;
use std::path::PathBuf;
pub mod process;
pub mod request;
pub mod responses;
pub mod streaming_sse;
pub mod test_codex;
pub mod test_codex_exec;
pub use request::RequestBodyExt;
pub use request::body_contains;
#[track_caller]
pub fn assert_regex_match<'s>(pattern: &str, actual: &'s str) -> regex_lite::Captures<'s> {
let regex = Regex::new(pattern).unwrap_or_else(|err| {
@@ -178,7 +182,7 @@ where
F: FnMut(&codex_core::protocol::EventMsg) -> bool,
{
use tokio::time::Duration;
wait_for_event_with_timeout(codex, predicate, Duration::from_secs(1)).await
wait_for_event_with_timeout(codex, predicate, Duration::from_secs(10)).await
}
pub async fn wait_for_event_match<T, F>(codex: &CodexConversation, matcher: F) -> T

View File

@@ -0,0 +1,59 @@
use http::header::CONTENT_ENCODING;
use serde::de::DeserializeOwned;
use wiremock::Match;
/// Return the request body, transparently decompressing zstd-encoded payloads.
///
/// # Panics
/// Panics if the body advertises `Content-Encoding: zstd` but fails to decode.
pub fn decoded_body_bytes(request: &wiremock::Request) -> Vec<u8> {
    if !is_zstd_encoded(request) {
        return request.body.clone();
    }
    match zstd::decode_all(request.body.as_slice()) {
        Ok(decoded) => decoded,
        Err(err) => panic!("failed to decode zstd-encoded request body: {err}"),
    }
}
/// Decoded request body interpreted as (lossy) UTF-8 text.
pub fn decoded_body_string(request: &wiremock::Request) -> String {
    let bytes = decoded_body_bytes(request);
    String::from_utf8_lossy(&bytes).into_owned()
}
/// Convenience accessors for wiremock request bodies that may be zstd-compressed.
pub trait RequestBodyExt {
    /// Decode the (possibly compressed) body as JSON into `T`.
    fn json_body<T: DeserializeOwned>(&self) -> T;
    /// Decode the (possibly compressed) body as UTF-8 text.
    fn text_body(&self) -> String;
}
impl RequestBodyExt for wiremock::Request {
    /// Decompress (if needed) and deserialize the body, panicking on bad JSON.
    fn json_body<T: DeserializeOwned>(&self) -> T {
        let bytes = decoded_body_bytes(self);
        match serde_json::from_slice(&bytes) {
            Ok(value) => value,
            Err(err) => panic!("failed to decode request body as JSON: {err}"),
        }
    }

    /// Decompress (if needed) and return the body as text.
    fn text_body(&self) -> String {
        decoded_body_string(self)
    }
}
/// Build a wiremock matcher that succeeds when the decoded body contains `needle`.
pub fn body_contains(needle: impl Into<String>) -> impl Match {
    let needle = needle.into();
    BodyContains { needle }
}
/// Matcher backing [`body_contains`].
struct BodyContains {
    /// Substring searched for in the decoded (possibly decompressed) body.
    needle: String,
}
impl Match for BodyContains {
    /// True when the decoded request body contains the configured needle.
    fn matches(&self, request: &wiremock::Request) -> bool {
        let body = decoded_body_string(request);
        body.contains(self.needle.as_str())
    }
}
/// True when the request declares `Content-Encoding: zstd` (case-insensitive).
fn is_zstd_encoded(request: &wiremock::Request) -> bool {
    request
        .headers
        .get(CONTENT_ENCODING)
        .and_then(|value| value.to_str().ok())
        // `is_some_and` is the idiomatic form of `.map(..).unwrap_or(false)`
        // (clippy: map_unwrap_or).
        .is_some_and(|value| value.eq_ignore_ascii_case("zstd"))
}

View File

@@ -15,6 +15,7 @@ use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path_regex;
use crate::RequestBodyExt;
use crate::test_codex::ApplyPatchModelOutput;
#[derive(Debug, Clone)]
@@ -67,7 +68,7 @@ pub struct ResponsesRequest(wiremock::Request);
impl ResponsesRequest {
pub fn body_json(&self) -> Value {
self.0.body_json().unwrap()
self.0.json_body()
}
/// Returns all `input_text` spans from `message` inputs for the provided role.
@@ -83,7 +84,7 @@ impl ResponsesRequest {
}
pub fn input(&self) -> Vec<Value> {
self.0.body_json::<Value>().unwrap()["input"]
self.body_json()["input"]
.as_array()
.expect("input array not found in request")
.clone()
@@ -721,10 +722,7 @@ pub async fn get_responses_request_bodies(server: &MockServer) -> Vec<Value> {
get_responses_requests(server)
.await
.into_iter()
.map(|req| {
req.body_json::<Value>()
.expect("request body to be valid JSON")
})
.map(|req| req.json_body::<Value>())
.collect()
}

View File

@@ -10,6 +10,7 @@ use codex_core::Prompt;
use codex_core::ResponseEvent;
use codex_core::ResponseItem;
use codex_core::WireApi;
use codex_core::features::Feature;
use codex_core::models_manager::manager::ModelsManager;
use codex_otel::otel_manager::OtelManager;
use codex_protocol::ConversationId;
@@ -317,3 +318,191 @@ async fn responses_respects_model_family_overrides_from_config() {
Some("detailed")
);
}
// Verifies that with the `request_compression` feature enabled and ChatGPT
// auth, the Responses request body is zstd-compressed, flagged via
// `Content-Encoding: zstd`, and still decodes to the expected JSON.
#[tokio::test]
async fn responses_request_body_is_zstd_encoded() {
    core_test_support::skip_if_no_network!();

    let server = responses::start_mock_server().await;
    let response_body = responses::sse(vec![
        responses::ev_response_created("resp-1"),
        responses::ev_completed("resp-1"),
    ]);
    let request_recorder = responses::mount_sse_once(&server, response_body).await;

    let provider = ModelProviderInfo {
        name: "OpenAI".into(),
        base_url: Some(format!("{}/v1", server.uri())),
        env_key: None,
        env_key_instructions: None,
        experimental_bearer_token: None,
        wire_api: WireApi::Responses,
        query_params: None,
        http_headers: None,
        env_http_headers: None,
        request_max_retries: Some(0),
        stream_max_retries: Some(0),
        stream_idle_timeout_ms: Some(5_000),
        requires_openai_auth: false,
    };

    let codex_home = TempDir::new().expect("failed to create TempDir");
    let mut config = load_default_config_for_test(&codex_home).await;
    config.model_provider_id = provider.name.clone();
    config.model_provider = provider.clone();
    // Opt in to the experimental compression feature for this test only.
    config.features.enable(Feature::RequestCompression);
    let effort = config.model_reasoning_effort;
    let summary = config.model_reasoning_summary;
    let model = ModelsManager::get_model_offline(config.model.as_deref());
    config.model = Some(model.clone());
    let config = Arc::new(config);
    let conversation_id = ConversationId::new();
    let session_source = SessionSource::Exec;
    let model_family = ModelsManager::construct_model_family_offline(model.as_str(), &config);
    let otel_manager = OtelManager::new(
        conversation_id,
        model.as_str(),
        model_family.slug.as_str(),
        None,
        Some("test@test.com".to_string()),
        Some(AuthMode::ChatGPT),
        false,
        "test".to_string(),
        session_source.clone(),
    );
    // ChatGPT auth is required for compression to kick in (see
    // `request_compression_for`).
    let auth_manager =
        AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing());
    let client = ModelClient::new(
        Arc::clone(&config),
        Some(auth_manager),
        model_family,
        otel_manager,
        provider,
        effort,
        summary,
        conversation_id,
        session_source,
    );

    let mut prompt = Prompt::default();
    prompt.input = vec![ResponseItem::Message {
        id: None,
        role: "user".into(),
        content: vec![ContentItem::InputText {
            text: "hello".into(),
        }],
    }];

    // Drain the stream until the turn completes so the request gets recorded.
    let mut stream = client.stream(&prompt).await.expect("stream failed");
    while let Some(event) = stream.next().await {
        if matches!(event, Ok(ResponseEvent::Completed { .. })) {
            break;
        }
    }

    let request = request_recorder.single_request();
    // The wire request must declare the encoding and keep the JSON content type.
    assert_eq!(request.header("content-encoding").as_deref(), Some("zstd"));
    assert_eq!(
        request.header("content-type").as_deref(),
        Some("application/json")
    );
    // The compressed payload must round-trip back to the original JSON body.
    let request_body = request.body_json();
    assert_eq!(request_body["stream"].as_bool(), Some(true));
    assert_eq!(
        request_body["input"][0]["content"][0]["text"].as_str(),
        Some("hello")
    );
}
/// When the `RequestCompression` feature is NOT enabled, the Responses request
/// must go out as plain JSON: no `Content-Encoding` header, a JSON
/// `Content-Type`, and — crucially — the same payload the compressed path
/// would have produced. This mirrors the enabled-path test above so the two
/// stay in lockstep.
#[tokio::test]
async fn responses_request_body_is_uncompressed_when_disabled() {
    core_test_support::skip_if_no_network!();
    let server = responses::start_mock_server().await;
    // Minimal SSE stream: created + completed, enough to drive the client
    // through a full request/response cycle so the recorder captures the body.
    let response_body = responses::sse(vec![
        responses::ev_response_created("resp-1"),
        responses::ev_completed("resp-1"),
    ]);
    let request_recorder = responses::mount_sse_once(&server, response_body).await;
    let provider = ModelProviderInfo {
        name: "OpenAI".into(),
        base_url: Some(format!("{}/v1", server.uri())),
        env_key: None,
        env_key_instructions: None,
        experimental_bearer_token: None,
        wire_api: WireApi::Responses,
        query_params: None,
        http_headers: None,
        env_http_headers: None,
        request_max_retries: Some(0),
        stream_max_retries: Some(0),
        stream_idle_timeout_ms: Some(5_000),
        requires_openai_auth: false,
    };
    let codex_home = TempDir::new().expect("failed to create TempDir");
    let mut config = load_default_config_for_test(&codex_home).await;
    config.model_provider_id = provider.name.clone();
    config.model_provider = provider.clone();
    // NOTE: unlike the compressed-path test, Feature::RequestCompression is
    // deliberately left disabled here — that is the behavior under test.
    let effort = config.model_reasoning_effort;
    let summary = config.model_reasoning_summary;
    let model = ModelsManager::get_model_offline(config.model.as_deref());
    config.model = Some(model.clone());
    let config = Arc::new(config);
    let conversation_id = ConversationId::new();
    let session_source = SessionSource::Exec;
    let model_family = ModelsManager::construct_model_family_offline(model.as_str(), &config);
    let otel_manager = OtelManager::new(
        conversation_id,
        model.as_str(),
        model_family.slug.as_str(),
        None,
        Some("test@test.com".to_string()),
        Some(AuthMode::ChatGPT),
        false,
        "test".to_string(),
        session_source.clone(),
    );
    let auth_manager =
        AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing());
    let client = ModelClient::new(
        Arc::clone(&config),
        Some(auth_manager),
        model_family,
        otel_manager,
        provider,
        effort,
        summary,
        conversation_id,
        session_source,
    );
    let mut prompt = Prompt::default();
    prompt.input = vec![ResponseItem::Message {
        id: None,
        role: "user".into(),
        content: vec![ContentItem::InputText {
            text: "hello".into(),
        }],
    }];
    let mut stream = client.stream(&prompt).await.expect("stream failed");
    // Drain the stream until completion so the mock has recorded the request.
    while let Some(event) = stream.next().await {
        if matches!(event, Ok(ResponseEvent::Completed { .. })) {
            break;
        }
    }
    let request = request_recorder.single_request();
    // No compression feature => no Content-Encoding header at all.
    assert_eq!(request.header("content-encoding"), None);
    assert_eq!(
        request.header("content-type").as_deref(),
        Some("application/json")
    );
    // Parity with the compressed-path test: the uncompressed body must still
    // be the expected JSON payload, not just correctly labeled.
    let request_body = request.body_json();
    assert_eq!(request_body["stream"].as_bool(), Some(true));
    assert_eq!(
        request_body["input"][0]["content"][0]["text"].as_str(),
        Some("hello")
    );
}

View File

@@ -29,6 +29,8 @@ use codex_protocol::models::ReasoningItemReasoningSummary;
use codex_protocol::models::WebSearchAction;
use codex_protocol::openai_models::ReasoningEffort;
use codex_protocol::user_input::UserInput;
use core_test_support::RequestBodyExt;
use core_test_support::body_contains;
use core_test_support::load_default_config_for_test;
use core_test_support::load_sse_fixture_with_id;
use core_test_support::responses::ev_completed_with_tokens;
@@ -51,7 +53,6 @@ use uuid::Uuid;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::body_string_contains;
use wiremock::matchers::header_regex;
use wiremock::matchers::method;
use wiremock::matchers::path;
@@ -507,7 +508,7 @@ async fn chatgpt_auth_sends_correct_request() {
let request_authorization = request.headers.get("authorization").unwrap();
let request_originator = request.headers.get("originator").unwrap();
let request_chatgpt_account_id = request.headers.get("chatgpt-account-id").unwrap();
let request_body = request.body_json::<serde_json::Value>().unwrap();
let request_body = request.json_body::<serde_json::Value>();
assert_eq!(
request_conversation_id.to_str().unwrap(),
@@ -1495,7 +1496,7 @@ async fn context_window_error_sets_total_tokens_to_model_window() -> anyhow::Res
mount_sse_once_match(
&server,
body_string_contains("trigger context window"),
body_contains("trigger context window"),
sse_failed(
"resp_context_window",
"context_length_exceeded",
@@ -1506,7 +1507,7 @@ async fn context_window_error_sets_total_tokens_to_model_window() -> anyhow::Res
mount_sse_once_match(
&server,
body_string_contains("seed turn"),
body_contains("seed turn"),
sse_completed("resp_seed"),
)
.await;
@@ -1882,8 +1883,7 @@ async fn history_dedupes_streamed_and_final_messages_across_turns() {
]);
let r3_input_array = requests[2]
.body_json::<serde_json::Value>()
.unwrap()
.json_body::<serde_json::Value>()
.get("input")
.and_then(|v| v.as_array())
.cloned()

View File

@@ -17,6 +17,7 @@ use codex_core::protocol::SandboxPolicy;
use codex_core::protocol::WarningEvent;
use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::user_input::UserInput;
use core_test_support::RequestBodyExt;
use core_test_support::load_default_config_for_test;
use core_test_support::responses::ev_local_shell_call;
use core_test_support::responses::ev_reasoning_item;
@@ -132,7 +133,6 @@ async fn summarize_context_three_requests_and_instructions() {
// SSE 3: minimal completed; we only need to capture the request body.
let sse3 = sse(vec![ev_completed("r3")]);
// Mount the three expected requests in sequence so the assertions below can
// inspect them without relying on specific prompt markers.
let request_log = mount_sse_sequence(&server, vec![sse1, sse2, sse3]).await;
@@ -361,7 +361,8 @@ async fn manual_compact_uses_custom_prompt() {
let requests = get_responses_requests(&server).await;
let body = requests
.iter()
.find_map(|req| req.body_json::<serde_json::Value>().ok())
.map(core_test_support::RequestBodyExt::json_body::<serde_json::Value>)
.next()
.expect("summary request body");
let input = body
@@ -591,9 +592,7 @@ async fn multiple_auto_compact_per_task_runs_after_token_limit_hit() {
// collect the requests payloads from the model
let requests_payloads = get_responses_requests(&server).await;
let body = requests_payloads[0]
.body_json::<serde_json::Value>()
.unwrap();
let body = requests_payloads[0].json_body::<serde_json::Value>();
let input = body.get("input").and_then(|v| v.as_array()).unwrap();
fn normalize_inputs(values: &[serde_json::Value]) -> Vec<serde_json::Value> {
@@ -634,9 +633,7 @@ async fn multiple_auto_compact_per_task_runs_after_token_limit_hit() {
prefixed_third_summary.as_str(),
];
for (i, expected_summary) in compaction_indices.into_iter().zip(expected_summaries) {
let body = requests_payloads.clone()[i]
.body_json::<serde_json::Value>()
.unwrap();
let body = requests_payloads.clone()[i].json_body::<serde_json::Value>();
let input = body.get("input").and_then(|v| v.as_array()).unwrap();
let input = normalize_inputs(input);
assert_eq!(input.len(), 3);
@@ -999,7 +996,7 @@ async fn multiple_auto_compact_per_task_runs_after_token_limit_hit() {
]);
for (i, request) in requests_payloads.iter().enumerate() {
let body = request.body_json::<serde_json::Value>().unwrap();
let body = request.json_body::<serde_json::Value>();
let input = body.get("input").and_then(|v| v.as_array()).unwrap();
let expected_input = expected_requests_inputs[i].as_array().unwrap();
assert_eq!(normalize_inputs(input), normalize_inputs(expected_input));
@@ -1038,30 +1035,30 @@ async fn auto_compact_runs_after_token_limit_hit() {
let prefixed_auto_summary = AUTO_SUMMARY_TEXT;
let first_matcher = |req: &wiremock::Request| {
let body = std::str::from_utf8(&req.body).unwrap_or("");
let body = req.text_body();
body.contains(FIRST_AUTO_MSG)
&& !body.contains(SECOND_AUTO_MSG)
&& !body_contains_text(body, SUMMARIZATION_PROMPT)
&& !body_contains_text(body.as_str(), SUMMARIZATION_PROMPT)
};
mount_sse_once_match(&server, first_matcher, sse1).await;
let second_matcher = |req: &wiremock::Request| {
let body = std::str::from_utf8(&req.body).unwrap_or("");
let body = req.text_body();
body.contains(SECOND_AUTO_MSG)
&& body.contains(FIRST_AUTO_MSG)
&& !body_contains_text(body, SUMMARIZATION_PROMPT)
&& !body_contains_text(body.as_str(), SUMMARIZATION_PROMPT)
};
mount_sse_once_match(&server, second_matcher, sse2).await;
let third_matcher = |req: &wiremock::Request| {
let body = std::str::from_utf8(&req.body).unwrap_or("");
body_contains_text(body, SUMMARIZATION_PROMPT)
let body = req.text_body();
body_contains_text(body.as_str(), SUMMARIZATION_PROMPT)
};
mount_sse_once_match(&server, third_matcher, sse3).await;
let fourth_matcher = |req: &wiremock::Request| {
let body = std::str::from_utf8(&req.body).unwrap_or("");
body.contains(POST_AUTO_USER_MSG) && !body_contains_text(body, SUMMARIZATION_PROMPT)
let body = req.text_body();
body.contains(POST_AUTO_USER_MSG)
&& !body_contains_text(body.as_str(), SUMMARIZATION_PROMPT)
};
mount_sse_once_match(&server, fourth_matcher, sse4).await;
@@ -1126,10 +1123,7 @@ async fn auto_compact_runs_after_token_limit_hit() {
requests.len()
);
let is_auto_compact = |req: &wiremock::Request| {
body_contains_text(
std::str::from_utf8(&req.body).unwrap_or(""),
SUMMARIZATION_PROMPT,
)
body_contains_text(req.text_body().as_str(), SUMMARIZATION_PROMPT)
};
let auto_compact_count = requests.iter().filter(|req| is_auto_compact(req)).count();
assert_eq!(
@@ -1151,20 +1145,16 @@ async fn auto_compact_runs_after_token_limit_hit() {
.enumerate()
.rev()
.find_map(|(idx, req)| {
let body = std::str::from_utf8(&req.body).unwrap_or("");
(body.contains(POST_AUTO_USER_MSG) && !body_contains_text(body, SUMMARIZATION_PROMPT))
let body = req.text_body();
(body.contains(POST_AUTO_USER_MSG) && !body_contains_text(&body, SUMMARIZATION_PROMPT))
.then_some(idx)
})
.expect("follow-up request missing");
assert_eq!(follow_up_index, 3, "follow-up request should be last");
let body_first = requests[0].body_json::<serde_json::Value>().unwrap();
let body_auto = requests[auto_compact_index]
.body_json::<serde_json::Value>()
.unwrap();
let body_follow_up = requests[follow_up_index]
.body_json::<serde_json::Value>()
.unwrap();
let body_first = requests[0].json_body::<serde_json::Value>();
let body_auto = requests[auto_compact_index].json_body::<serde_json::Value>();
let body_follow_up = requests[follow_up_index].json_body::<serde_json::Value>();
let instructions = body_auto
.get("instructions")
.and_then(|v| v.as_str())
@@ -1375,24 +1365,24 @@ async fn auto_compact_persists_rollout_entries() {
]);
let first_matcher = |req: &wiremock::Request| {
let body = std::str::from_utf8(&req.body).unwrap_or("");
let body = req.text_body();
body.contains(FIRST_AUTO_MSG)
&& !body.contains(SECOND_AUTO_MSG)
&& !body_contains_text(body, SUMMARIZATION_PROMPT)
&& !body_contains_text(body.as_str(), SUMMARIZATION_PROMPT)
};
mount_sse_once_match(&server, first_matcher, sse1).await;
let second_matcher = |req: &wiremock::Request| {
let body = std::str::from_utf8(&req.body).unwrap_or("");
let body = req.text_body();
body.contains(SECOND_AUTO_MSG)
&& body.contains(FIRST_AUTO_MSG)
&& !body_contains_text(body, SUMMARIZATION_PROMPT)
&& !body_contains_text(body.as_str(), SUMMARIZATION_PROMPT)
};
mount_sse_once_match(&server, second_matcher, sse2).await;
let third_matcher = |req: &wiremock::Request| {
let body = std::str::from_utf8(&req.body).unwrap_or("");
body_contains_text(body, SUMMARIZATION_PROMPT)
let body = req.text_body();
body_contains_text(body.as_str(), SUMMARIZATION_PROMPT)
};
mount_sse_once_match(&server, third_matcher, sse3).await;

View File

@@ -23,6 +23,7 @@ use codex_core::protocol::Op;
use codex_core::protocol::WarningEvent;
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
use codex_protocol::user_input::UserInput;
use core_test_support::RequestBodyExt;
use core_test_support::load_default_config_for_test;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
@@ -796,8 +797,9 @@ async fn mount_initial_flow(server: &MockServer) {
let sse5 = sse(vec![ev_completed("r5")]);
let match_first = |req: &wiremock::Request| {
let body = std::str::from_utf8(&req.body).unwrap_or("");
let body = req.text_body();
body.contains("\"text\":\"hello world\"")
&& !body_contains_text(body.as_str(), SUMMARIZATION_PROMPT)
&& !body.contains(&format!("\"text\":\"{SUMMARY_TEXT}\""))
&& !body.contains("\"text\":\"AFTER_COMPACT\"")
&& !body.contains("\"text\":\"AFTER_RESUME\"")
@@ -806,13 +808,13 @@ async fn mount_initial_flow(server: &MockServer) {
mount_sse_once_match(server, match_first, sse1).await;
let match_compact = |req: &wiremock::Request| {
let body = std::str::from_utf8(&req.body).unwrap_or("");
body_contains_text(body, SUMMARIZATION_PROMPT) || body.contains(&json_fragment(FIRST_REPLY))
let body = req.text_body();
body_contains_text(body.as_str(), SUMMARIZATION_PROMPT)
};
mount_sse_once_match(server, match_compact, sse2).await;
let match_after_compact = |req: &wiremock::Request| {
let body = std::str::from_utf8(&req.body).unwrap_or("");
let body = req.text_body();
body.contains("\"text\":\"AFTER_COMPACT\"")
&& !body.contains("\"text\":\"AFTER_RESUME\"")
&& !body.contains("\"text\":\"AFTER_FORK\"")
@@ -820,13 +822,13 @@ async fn mount_initial_flow(server: &MockServer) {
mount_sse_once_match(server, match_after_compact, sse3).await;
let match_after_resume = |req: &wiremock::Request| {
let body = std::str::from_utf8(&req.body).unwrap_or("");
let body = req.text_body();
body.contains("\"text\":\"AFTER_RESUME\"")
};
mount_sse_once_match(server, match_after_resume, sse4).await;
let match_after_fork = |req: &wiremock::Request| {
let body = std::str::from_utf8(&req.body).unwrap_or("");
let body = req.text_body();
body.contains("\"text\":\"AFTER_FORK\"")
};
mount_sse_once_match(server, match_after_fork, sse5).await;
@@ -840,13 +842,13 @@ async fn mount_second_compact_flow(server: &MockServer) {
let sse7 = sse(vec![ev_completed("r7")]);
let match_second_compact = |req: &wiremock::Request| {
let body = std::str::from_utf8(&req.body).unwrap_or("");
body.contains("AFTER_FORK")
let body = req.text_body();
body_contains_text(body.as_str(), SUMMARIZATION_PROMPT) && body.contains("AFTER_FORK")
};
mount_sse_once_match(server, match_second_compact, sse6).await;
let match_after_second_resume = |req: &wiremock::Request| {
let body = std::str::from_utf8(&req.body).unwrap_or("");
let body = req.text_body();
body.contains(&format!("\"text\":\"{AFTER_SECOND_RESUME}\""))
};
mount_sse_once_match(server, match_after_second_resume, sse7).await;

View File

@@ -6,6 +6,7 @@ use codex_core::protocol::Op;
use codex_core::protocol::SandboxPolicy;
use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::user_input::UserInput;
use core_test_support::RequestBodyExt;
use core_test_support::responses;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
@@ -54,7 +55,7 @@ async fn codex_returns_json_result(model: String) -> anyhow::Result<()> {
let expected_schema: serde_json::Value = serde_json::from_str(SCHEMA)?;
let match_json_text_param = move |req: &wiremock::Request| {
let body: serde_json::Value = serde_json::from_slice(&req.body).unwrap_or_default();
let body: serde_json::Value = req.json_body();
let Some(text) = body.get("text") else {
return false;
};

View File

@@ -21,6 +21,7 @@ use codex_core::protocol::RolloutItem;
use codex_core::protocol::RolloutLine;
use codex_core::review_format::render_review_output_text;
use codex_protocol::user_input::UserInput;
use core_test_support::RequestBodyExt;
use core_test_support::load_default_config_for_test;
use core_test_support::load_sse_fixture_with_id_from_str;
use core_test_support::responses::get_responses_requests;
@@ -430,7 +431,7 @@ async fn review_uses_custom_review_model_from_config() {
let request = requests
.first()
.expect("expected POST request to /responses");
let body = request.body_json::<serde_json::Value>().unwrap();
let body = request.json_body::<serde_json::Value>();
assert_eq!(body["model"].as_str().unwrap(), "gpt-5.1");
server.verify().await;
@@ -551,7 +552,7 @@ async fn review_input_isolated_from_parent_history() {
let request = requests
.first()
.expect("expected POST request to /responses");
let body = request.body_json::<serde_json::Value>().unwrap();
let body = request.json_body::<serde_json::Value>();
let input = body["input"].as_array().expect("input array");
assert!(
input.len() >= 2,
@@ -676,7 +677,7 @@ async fn review_history_surfaces_in_parent_session() {
// Critically, no messages from the review thread should appear.
let requests = get_responses_requests(&server).await;
assert_eq!(requests.len(), 2);
let body = requests[1].body_json::<serde_json::Value>().unwrap();
let body = requests[1].json_body::<serde_json::Value>();
let input = body["input"].as_array().expect("input array");
// Must include the followup as the last item for this turn

View File

@@ -3,6 +3,7 @@ use codex_core::WireApi;
use codex_core::protocol::EventMsg;
use codex_core::protocol::Op;
use codex_protocol::user_input::UserInput;
use core_test_support::body_contains;
use core_test_support::load_sse_fixture_with_id;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
@@ -11,7 +12,6 @@ use core_test_support::wait_for_event;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::body_string_contains;
use wiremock::matchers::method;
use wiremock::matchers::path;
@@ -38,7 +38,7 @@ async fn continue_after_stream_error() {
// so the failing request should only occur once.
Mock::given(method("POST"))
.and(path("/v1/responses"))
.and(body_string_contains("first message"))
.and(body_contains("first message"))
.respond_with(fail)
.up_to_n_times(2)
.mount(&server)
@@ -50,7 +50,7 @@ async fn continue_after_stream_error() {
Mock::given(method("POST"))
.and(path("/v1/responses"))
.and(body_string_contains("follow up"))
.and(body_contains("follow up"))
.respond_with(ok)
.expect(1)
.mount(&server)