mirror of
https://github.com/openai/codex.git
synced 2026-02-01 22:47:52 +00:00
Compare commits
8 Commits
dh--git-in
...
cconger/co
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e213999759 | ||
|
|
0124a95c83 | ||
|
|
7361c1fea2 | ||
|
|
f5adf8291a | ||
|
|
47366e8417 | ||
|
|
eef24681c8 | ||
|
|
c069660811 | ||
|
|
55acaaffac |
45
codex-rs/Cargo.lock
generated
45
codex-rs/Cargo.lock
generated
@@ -819,6 +819,8 @@ version = "1.2.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "deec109607ca693028562ed836a5f1c4b8bd77755c4e132fc5ce11b0b6211ae7"
|
||||
dependencies = [
|
||||
"jobserver",
|
||||
"libc",
|
||||
"shlex",
|
||||
]
|
||||
|
||||
@@ -985,6 +987,7 @@ dependencies = [
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"wiremock",
|
||||
"zstd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1348,6 +1351,7 @@ dependencies = [
|
||||
"which",
|
||||
"wildmatch",
|
||||
"wiremock",
|
||||
"zstd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2109,16 +2113,19 @@ dependencies = [
|
||||
"codex-protocol",
|
||||
"codex-utils-absolute-path",
|
||||
"codex-utils-cargo-bin",
|
||||
"http 1.3.1",
|
||||
"notify",
|
||||
"pretty_assertions",
|
||||
"regex-lite",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"shlex",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"walkdir",
|
||||
"wiremock",
|
||||
"zstd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3924,6 +3931,16 @@ version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"
|
||||
|
||||
[[package]]
|
||||
name = "jobserver"
|
||||
version = "0.1.34"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33"
|
||||
dependencies = [
|
||||
"getrandom 0.3.3",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "js-sys"
|
||||
version = "0.3.77"
|
||||
@@ -8809,6 +8826,34 @@ dependencies = [
|
||||
"syn 2.0.104",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zstd"
|
||||
version = "0.13.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a"
|
||||
dependencies = [
|
||||
"zstd-safe",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zstd-safe"
|
||||
version = "7.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d"
|
||||
dependencies = [
|
||||
"zstd-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zstd-sys"
|
||||
version = "2.0.16+zstd.1.5.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"pkg-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zune-core"
|
||||
version = "0.4.12"
|
||||
|
||||
@@ -235,6 +235,7 @@ wildmatch = "2.6.1"
|
||||
|
||||
wiremock = "0.6"
|
||||
zeroize = "1.8.2"
|
||||
zstd = "0.13"
|
||||
|
||||
[workspace.lints]
|
||||
rust = {}
|
||||
|
||||
@@ -19,6 +19,7 @@ tracing = { workspace = true }
|
||||
eventsource-stream = { workspace = true }
|
||||
regex-lite = { workspace = true }
|
||||
tokio-util = { workspace = true, features = ["codec"] }
|
||||
zstd = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
|
||||
@@ -6,7 +6,10 @@ use crate::common::ResponseStream;
|
||||
use crate::endpoint::streaming::StreamingClient;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::provider::RequestCompression;
|
||||
use crate::provider::WireApi;
|
||||
use crate::requests::body::encode_body;
|
||||
use crate::requests::body::insert_compression_headers;
|
||||
use crate::sse::chat::spawn_chat_stream;
|
||||
use crate::telemetry::SseTelemetry;
|
||||
use codex_client::HttpTransport;
|
||||
@@ -45,8 +48,13 @@ impl<T: HttpTransport, A: AuthProvider> ChatClient<T, A> {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream_request(&self, request: ChatRequest) -> Result<ResponseStream, ApiError> {
|
||||
self.stream(request.body, request.headers).await
|
||||
pub async fn stream_request(
|
||||
&self,
|
||||
request: ChatRequest,
|
||||
request_compression: RequestCompression,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
self.stream(request.body, request.headers, request_compression)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn stream_prompt(
|
||||
@@ -55,6 +63,7 @@ impl<T: HttpTransport, A: AuthProvider> ChatClient<T, A> {
|
||||
prompt: &ApiPrompt,
|
||||
conversation_id: Option<String>,
|
||||
session_source: Option<SessionSource>,
|
||||
request_compression: RequestCompression,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
use crate::requests::ChatRequestBuilder;
|
||||
|
||||
@@ -64,7 +73,7 @@ impl<T: HttpTransport, A: AuthProvider> ChatClient<T, A> {
|
||||
.session_source(session_source)
|
||||
.build(self.streaming.provider())?;
|
||||
|
||||
self.stream_request(request).await
|
||||
self.stream_request(request, request_compression).await
|
||||
}
|
||||
|
||||
fn path(&self) -> &'static str {
|
||||
@@ -78,9 +87,13 @@ impl<T: HttpTransport, A: AuthProvider> ChatClient<T, A> {
|
||||
&self,
|
||||
body: Value,
|
||||
extra_headers: HeaderMap,
|
||||
request_compression: RequestCompression,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
let mut headers = extra_headers;
|
||||
insert_compression_headers(&mut headers, request_compression);
|
||||
let encoded_body = encode_body(&body, request_compression)?;
|
||||
self.streaming
|
||||
.stream(self.path(), body, extra_headers, spawn_chat_stream)
|
||||
.stream(self.path(), encoded_body, headers, spawn_chat_stream)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::provider::WireApi;
|
||||
use crate::telemetry::run_with_request_telemetry;
|
||||
use codex_client::Body;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::RequestTelemetry;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
@@ -54,7 +55,7 @@ impl<T: HttpTransport, A: AuthProvider> CompactClient<T, A> {
|
||||
let builder = || {
|
||||
let mut req = self.provider.build_request(Method::POST, path);
|
||||
req.headers.extend(extra_headers.clone());
|
||||
req.body = Some(body.clone());
|
||||
req.body = Some(Body::Json(body.clone()));
|
||||
add_auth_headers(&self.auth, req)
|
||||
};
|
||||
|
||||
@@ -89,6 +90,7 @@ struct CompactHistoryResponse {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use crate::provider::RetryConfig;
|
||||
use async_trait::async_trait;
|
||||
use codex_client::Request;
|
||||
|
||||
@@ -6,6 +6,7 @@ use crate::common::TextControls;
|
||||
use crate::endpoint::streaming::StreamingClient;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::provider::RequestCompression;
|
||||
use crate::provider::WireApi;
|
||||
use crate::requests::ResponsesRequest;
|
||||
use crate::requests::ResponsesRequestBuilder;
|
||||
@@ -15,7 +16,6 @@ use codex_client::HttpTransport;
|
||||
use codex_client::RequestTelemetry;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use http::HeaderMap;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use tracing::instrument;
|
||||
|
||||
@@ -33,6 +33,7 @@ pub struct ResponsesOptions {
|
||||
pub conversation_id: Option<String>,
|
||||
pub session_source: Option<SessionSource>,
|
||||
pub extra_headers: HeaderMap,
|
||||
pub request_compression: RequestCompression,
|
||||
}
|
||||
|
||||
impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
|
||||
@@ -56,7 +57,7 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
|
||||
&self,
|
||||
request: ResponsesRequest,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
self.stream(request.body, request.headers).await
|
||||
self.stream(request).await
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all, err)]
|
||||
@@ -75,6 +76,7 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
|
||||
conversation_id,
|
||||
session_source,
|
||||
extra_headers,
|
||||
request_compression,
|
||||
} = options;
|
||||
|
||||
let request = ResponsesRequestBuilder::new(model, &prompt.instructions, &prompt.input)
|
||||
@@ -88,6 +90,7 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
|
||||
.session_source(session_source)
|
||||
.store_override(store_override)
|
||||
.extra_headers(extra_headers)
|
||||
.request_compression(request_compression)
|
||||
.build(self.streaming.provider())?;
|
||||
|
||||
self.stream_request(request).await
|
||||
@@ -100,13 +103,14 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream(
|
||||
&self,
|
||||
body: Value,
|
||||
extra_headers: HeaderMap,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
pub async fn stream(&self, request: ResponsesRequest) -> Result<ResponseStream, ApiError> {
|
||||
self.streaming
|
||||
.stream(self.path(), body, extra_headers, spawn_response_stream)
|
||||
.stream(
|
||||
self.path(),
|
||||
request.body,
|
||||
request.headers,
|
||||
spawn_response_stream,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,12 +5,15 @@ use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::telemetry::SseTelemetry;
|
||||
use crate::telemetry::run_with_request_telemetry;
|
||||
use codex_client::Body;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::RequestTelemetry;
|
||||
use codex_client::StreamResponse;
|
||||
use http::HeaderMap;
|
||||
use http::HeaderValue;
|
||||
use http::Method;
|
||||
use serde_json::Value;
|
||||
use http::header::ACCEPT;
|
||||
use http::header::CONTENT_TYPE;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -50,17 +53,18 @@ impl<T: HttpTransport, A: AuthProvider> StreamingClient<T, A> {
|
||||
pub(crate) async fn stream(
|
||||
&self,
|
||||
path: &str,
|
||||
body: Value,
|
||||
body: Body,
|
||||
extra_headers: HeaderMap,
|
||||
spawner: fn(StreamResponse, Duration, Option<Arc<dyn SseTelemetry>>) -> ResponseStream,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
let builder = || {
|
||||
let mut req = self.provider.build_request(Method::POST, path);
|
||||
req.headers.extend(extra_headers.clone());
|
||||
req.headers.insert(
|
||||
http::header::ACCEPT,
|
||||
http::HeaderValue::from_static("text/event-stream"),
|
||||
);
|
||||
req.headers
|
||||
.insert(ACCEPT, HeaderValue::from_static("text/event-stream"));
|
||||
req.headers
|
||||
.entry(CONTENT_TYPE)
|
||||
.or_insert_with(|| HeaderValue::from_static("application/json"));
|
||||
req.body = Some(body.clone());
|
||||
add_auth_headers(&self.auth, req)
|
||||
};
|
||||
|
||||
@@ -41,6 +41,13 @@ impl RetryConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
||||
pub enum RequestCompression {
|
||||
#[default]
|
||||
None,
|
||||
Zstd,
|
||||
}
|
||||
|
||||
/// HTTP endpoint configuration used to talk to a concrete API deployment.
|
||||
///
|
||||
/// Encapsulates base URL, default headers, query params, retry policy, and
|
||||
|
||||
40
codex-rs/codex-api/src/requests/body.rs
Normal file
40
codex-rs/codex-api/src/requests/body.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::RequestCompression;
|
||||
use bytes::Bytes;
|
||||
use codex_client::Body;
|
||||
use http::HeaderMap;
|
||||
use http::HeaderValue;
|
||||
use http::header::CONTENT_ENCODING;
|
||||
use serde_json::Value;
|
||||
use std::time::Instant;
|
||||
use tracing::info;
|
||||
use zstd::stream::encode_all;
|
||||
|
||||
pub(crate) fn encode_body(body: &Value, compression: RequestCompression) -> Result<Body, ApiError> {
|
||||
match compression {
|
||||
RequestCompression::None => Ok(Body::Json(body.clone())),
|
||||
RequestCompression::Zstd => {
|
||||
let json = serde_json::to_vec(body).map_err(|err| {
|
||||
ApiError::Stream(format!("failed to encode request body as json: {err}"))
|
||||
})?;
|
||||
let started_at = Instant::now();
|
||||
let compressed = encode_all(json.as_slice(), 0).map_err(|err| {
|
||||
ApiError::Stream(format!("failed to compress request body: {err}"))
|
||||
})?;
|
||||
let elapsed = started_at.elapsed();
|
||||
info!(
|
||||
input_bytes = json.len(),
|
||||
output_bytes = compressed.len(),
|
||||
elapsed_ms = elapsed.as_millis(),
|
||||
"compressed request body"
|
||||
);
|
||||
Ok(Body::Bytes(Bytes::from(compressed)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn insert_compression_headers(headers: &mut HeaderMap, compression: RequestCompression) {
|
||||
if matches!(compression, RequestCompression::Zstd) {
|
||||
headers.insert(CONTENT_ENCODING, HeaderValue::from_static("zstd"));
|
||||
}
|
||||
}
|
||||
@@ -351,6 +351,7 @@ fn push_tool_call_message(messages: &mut Vec<Value>, tool_call: Value, reasoning
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use crate::provider::RetryConfig;
|
||||
use crate::provider::WireApi;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pub(crate) mod body;
|
||||
pub mod chat;
|
||||
pub(crate) mod headers;
|
||||
pub mod responses;
|
||||
|
||||
@@ -3,9 +3,13 @@ use crate::common::ResponsesApiRequest;
|
||||
use crate::common::TextControls;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::provider::RequestCompression;
|
||||
use crate::requests::body::encode_body;
|
||||
use crate::requests::body::insert_compression_headers;
|
||||
use crate::requests::headers::build_conversation_headers;
|
||||
use crate::requests::headers::insert_header;
|
||||
use crate::requests::headers::subagent_header;
|
||||
use codex_client::Body;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use http::HeaderMap;
|
||||
@@ -13,7 +17,7 @@ use serde_json::Value;
|
||||
|
||||
/// Assembled request body plus headers for a Responses stream request.
|
||||
pub struct ResponsesRequest {
|
||||
pub body: Value,
|
||||
pub body: Body,
|
||||
pub headers: HeaderMap,
|
||||
}
|
||||
|
||||
@@ -32,6 +36,7 @@ pub struct ResponsesRequestBuilder<'a> {
|
||||
session_source: Option<SessionSource>,
|
||||
store_override: Option<bool>,
|
||||
headers: HeaderMap,
|
||||
request_compression: RequestCompression,
|
||||
}
|
||||
|
||||
impl<'a> ResponsesRequestBuilder<'a> {
|
||||
@@ -94,6 +99,11 @@ impl<'a> ResponsesRequestBuilder<'a> {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn request_compression(mut self, request_compression: RequestCompression) -> Self {
|
||||
self.request_compression = request_compression;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self, provider: &Provider) -> Result<ResponsesRequest, ApiError> {
|
||||
let model = self
|
||||
.model
|
||||
@@ -137,6 +147,8 @@ impl<'a> ResponsesRequestBuilder<'a> {
|
||||
if let Some(subagent) = subagent_header(&self.session_source) {
|
||||
insert_header(&mut headers, "x-openai-subagent", &subagent);
|
||||
}
|
||||
insert_compression_headers(&mut headers, self.request_compression);
|
||||
let body = encode_body(&body, self.request_compression)?;
|
||||
|
||||
Ok(ResponsesRequest { body, headers })
|
||||
}
|
||||
@@ -172,8 +184,10 @@ fn attach_item_ids(payload_json: &mut Value, original_items: &[ResponseItem]) {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use crate::provider::RetryConfig;
|
||||
use crate::provider::WireApi;
|
||||
use codex_client::Body;
|
||||
use codex_protocol::protocol::SubAgentSource;
|
||||
use http::HeaderValue;
|
||||
use pretty_assertions::assert_eq;
|
||||
@@ -219,10 +233,12 @@ mod tests {
|
||||
.build(&provider)
|
||||
.expect("request");
|
||||
|
||||
assert_eq!(request.body.get("store"), Some(&Value::Bool(true)));
|
||||
let Body::Json(body) = &request.body else {
|
||||
panic!("expected json body for responses request");
|
||||
};
|
||||
assert_eq!(body.get("store"), Some(&Value::Bool(true)));
|
||||
|
||||
let ids: Vec<Option<String>> = request
|
||||
.body
|
||||
let ids: Vec<Option<String>> = body
|
||||
.get("input")
|
||||
.and_then(|v| v.as_array())
|
||||
.into_iter()
|
||||
|
||||
@@ -10,7 +10,9 @@ use codex_api::ChatClient;
|
||||
use codex_api::Provider;
|
||||
use codex_api::ResponsesClient;
|
||||
use codex_api::ResponsesOptions;
|
||||
use codex_api::ResponsesRequest;
|
||||
use codex_api::WireApi;
|
||||
use codex_client::Body;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::Request;
|
||||
use codex_client::Response;
|
||||
@@ -136,6 +138,13 @@ fn provider(name: &str, wire: WireApi) -> Provider {
|
||||
}
|
||||
}
|
||||
|
||||
fn responses_request(body: Value) -> ResponsesRequest {
|
||||
ResponsesRequest {
|
||||
body: Body::Json(body),
|
||||
headers: HeaderMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct FlakyTransport {
|
||||
state: Arc<Mutex<i64>>,
|
||||
@@ -201,7 +210,9 @@ async fn chat_client_uses_chat_completions_path_for_chat_wire() -> Result<()> {
|
||||
let client = ChatClient::new(transport, provider("openai", WireApi::Chat), NoAuth);
|
||||
|
||||
let body = serde_json::json!({ "echo": true });
|
||||
let _stream = client.stream(body, HeaderMap::new()).await?;
|
||||
let _stream = client
|
||||
.stream(body, HeaderMap::new(), Default::default())
|
||||
.await?;
|
||||
|
||||
let requests = state.take_stream_requests();
|
||||
assert_path_ends_with(&requests, "/chat/completions");
|
||||
@@ -215,7 +226,9 @@ async fn chat_client_uses_responses_path_for_responses_wire() -> Result<()> {
|
||||
let client = ChatClient::new(transport, provider("openai", WireApi::Responses), NoAuth);
|
||||
|
||||
let body = serde_json::json!({ "echo": true });
|
||||
let _stream = client.stream(body, HeaderMap::new()).await?;
|
||||
let _stream = client
|
||||
.stream(body, HeaderMap::new(), Default::default())
|
||||
.await?;
|
||||
|
||||
let requests = state.take_stream_requests();
|
||||
assert_path_ends_with(&requests, "/responses");
|
||||
@@ -228,8 +241,8 @@ async fn responses_client_uses_responses_path_for_responses_wire() -> Result<()>
|
||||
let transport = RecordingTransport::new(state.clone());
|
||||
let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), NoAuth);
|
||||
|
||||
let body = serde_json::json!({ "echo": true });
|
||||
let _stream = client.stream(body, HeaderMap::new()).await?;
|
||||
let request = responses_request(serde_json::json!({ "echo": true }));
|
||||
let _stream = client.stream(request).await?;
|
||||
|
||||
let requests = state.take_stream_requests();
|
||||
assert_path_ends_with(&requests, "/responses");
|
||||
@@ -242,8 +255,8 @@ async fn responses_client_uses_chat_path_for_chat_wire() -> Result<()> {
|
||||
let transport = RecordingTransport::new(state.clone());
|
||||
let client = ResponsesClient::new(transport, provider("openai", WireApi::Chat), NoAuth);
|
||||
|
||||
let body = serde_json::json!({ "echo": true });
|
||||
let _stream = client.stream(body, HeaderMap::new()).await?;
|
||||
let request = responses_request(serde_json::json!({ "echo": true }));
|
||||
let _stream = client.stream(request).await?;
|
||||
|
||||
let requests = state.take_stream_requests();
|
||||
assert_path_ends_with(&requests, "/chat/completions");
|
||||
@@ -257,8 +270,8 @@ async fn streaming_client_adds_auth_headers() -> Result<()> {
|
||||
let auth = StaticAuth::new("secret-token", "acct-1");
|
||||
let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), auth);
|
||||
|
||||
let body = serde_json::json!({ "model": "gpt-test" });
|
||||
let _stream = client.stream(body, HeaderMap::new()).await?;
|
||||
let request = responses_request(serde_json::json!({ "model": "gpt-test" }));
|
||||
let _stream = client.stream(request).await?;
|
||||
|
||||
let requests = state.take_stream_requests();
|
||||
assert_eq!(requests.len(), 1);
|
||||
|
||||
@@ -8,7 +8,9 @@ use codex_api::AuthProvider;
|
||||
use codex_api::Provider;
|
||||
use codex_api::ResponseEvent;
|
||||
use codex_api::ResponsesClient;
|
||||
use codex_api::ResponsesRequest;
|
||||
use codex_api::WireApi;
|
||||
use codex_client::Body;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::Request;
|
||||
use codex_client::Response;
|
||||
@@ -94,6 +96,13 @@ fn build_responses_body(events: Vec<Value>) -> String {
|
||||
body
|
||||
}
|
||||
|
||||
fn responses_request(body: Value) -> ResponsesRequest {
|
||||
ResponsesRequest {
|
||||
body: Body::Json(body),
|
||||
headers: HeaderMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn responses_stream_parses_items_and_completed_end_to_end() -> Result<()> {
|
||||
let item1 = serde_json::json!({
|
||||
@@ -123,9 +132,8 @@ async fn responses_stream_parses_items_and_completed_end_to_end() -> Result<()>
|
||||
let transport = FixtureSseTransport::new(body);
|
||||
let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), NoAuth);
|
||||
|
||||
let mut stream = client
|
||||
.stream(serde_json::json!({"echo": true}), HeaderMap::new())
|
||||
.await?;
|
||||
let request = responses_request(serde_json::json!({"echo": true}));
|
||||
let mut stream = client.stream(request).await?;
|
||||
|
||||
let mut events = Vec::new();
|
||||
while let Some(ev) = stream.next().await {
|
||||
@@ -188,9 +196,8 @@ async fn responses_stream_aggregates_output_text_deltas() -> Result<()> {
|
||||
let transport = FixtureSseTransport::new(body);
|
||||
let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), NoAuth);
|
||||
|
||||
let stream = client
|
||||
.stream(serde_json::json!({"echo": true}), HeaderMap::new())
|
||||
.await?;
|
||||
let request = responses_request(serde_json::json!({"echo": true}));
|
||||
let stream = client.stream(request).await?;
|
||||
|
||||
let mut stream = stream.aggregate();
|
||||
let mut events = Vec::new();
|
||||
|
||||
@@ -104,6 +104,13 @@ impl CodexRequestBuilder {
|
||||
self.map(|builder| builder.json(value))
|
||||
}
|
||||
|
||||
pub fn body<T>(self, body: T) -> Self
|
||||
where
|
||||
T: Into<reqwest::Body>,
|
||||
{
|
||||
self.map(|builder| builder.body(body))
|
||||
}
|
||||
|
||||
pub async fn send(self) -> Result<Response, reqwest::Error> {
|
||||
let headers = trace_headers();
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ pub use crate::default_client::CodexHttpClient;
|
||||
pub use crate::default_client::CodexRequestBuilder;
|
||||
pub use crate::error::StreamError;
|
||||
pub use crate::error::TransportError;
|
||||
pub use crate::request::Body;
|
||||
pub use crate::request::Request;
|
||||
pub use crate::request::Response;
|
||||
pub use crate::retry::RetryOn;
|
||||
|
||||
@@ -5,12 +5,18 @@ use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Body {
|
||||
Json(Value),
|
||||
Bytes(Bytes),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Request {
|
||||
pub method: Method,
|
||||
pub url: String,
|
||||
pub headers: HeaderMap,
|
||||
pub body: Option<Value>,
|
||||
pub body: Option<Body>,
|
||||
pub timeout: Option<Duration>,
|
||||
}
|
||||
|
||||
@@ -26,7 +32,7 @@ impl Request {
|
||||
}
|
||||
|
||||
pub fn with_json<T: Serialize>(mut self, body: &T) -> Self {
|
||||
self.body = serde_json::to_value(body).ok();
|
||||
self.body = serde_json::to_value(body).ok().map(Body::Json);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::default_client::CodexHttpClient;
|
||||
use crate::default_client::CodexRequestBuilder;
|
||||
use crate::error::TransportError;
|
||||
use crate::request::Body;
|
||||
use crate::request::Request;
|
||||
use crate::request::Response;
|
||||
use async_trait::async_trait;
|
||||
@@ -52,7 +53,10 @@ impl ReqwestTransport {
|
||||
builder = builder.timeout(timeout);
|
||||
}
|
||||
if let Some(body) = req.body {
|
||||
builder = builder.json(&body);
|
||||
builder = match body {
|
||||
Body::Json(value) => builder.json(&value),
|
||||
Body::Bytes(bytes) => builder.body(bytes),
|
||||
};
|
||||
}
|
||||
Ok(builder)
|
||||
}
|
||||
@@ -101,10 +105,10 @@ impl HttpTransport for ReqwestTransport {
|
||||
async fn stream(&self, req: Request) -> Result<StreamResponse, TransportError> {
|
||||
if enabled!(Level::TRACE) {
|
||||
trace!(
|
||||
"{} to {}: {}",
|
||||
req.method,
|
||||
req.url,
|
||||
req.body.as_ref().unwrap_or_default()
|
||||
method = %req.method,
|
||||
url = %req.url,
|
||||
body = ?req.body,
|
||||
"Sending streaming request"
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -89,6 +89,7 @@ url = { workspace = true }
|
||||
uuid = { workspace = true, features = ["serde", "v4", "v5"] }
|
||||
which = { workspace = true }
|
||||
wildmatch = { workspace = true }
|
||||
zstd = { workspace = true }
|
||||
|
||||
[features]
|
||||
deterministic_process_ids = []
|
||||
|
||||
@@ -156,6 +156,9 @@ impl ModelClient {
|
||||
let mut refreshed = false;
|
||||
loop {
|
||||
let auth = auth_manager.as_ref().and_then(|m| m.auth());
|
||||
let request_compression = self
|
||||
.provider
|
||||
.request_compression_for(auth.as_ref().map(|a| a.mode), &self.config.features);
|
||||
let api_provider = self
|
||||
.provider
|
||||
.to_api_provider(auth.as_ref().map(|a| a.mode))?;
|
||||
@@ -171,6 +174,7 @@ impl ModelClient {
|
||||
&api_prompt,
|
||||
Some(conversation_id.clone()),
|
||||
Some(session_source.clone()),
|
||||
request_compression,
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -245,6 +249,9 @@ impl ModelClient {
|
||||
let mut refreshed = false;
|
||||
loop {
|
||||
let auth = auth_manager.as_ref().and_then(|m| m.auth());
|
||||
let request_compression = self
|
||||
.provider
|
||||
.request_compression_for(auth.as_ref().map(|a| a.mode), &self.config.features);
|
||||
let api_provider = self
|
||||
.provider
|
||||
.to_api_provider(auth.as_ref().map(|a| a.mode))?;
|
||||
@@ -263,6 +270,7 @@ impl ModelClient {
|
||||
conversation_id: Some(conversation_id.clone()),
|
||||
session_source: Some(session_source.clone()),
|
||||
extra_headers: beta_feature_headers(&self.config),
|
||||
request_compression,
|
||||
};
|
||||
|
||||
let stream_result = client
|
||||
|
||||
@@ -24,7 +24,6 @@ use std::sync::OnceLock;
|
||||
pub static USER_AGENT_SUFFIX: LazyLock<Mutex<Option<String>>> = LazyLock::new(|| Mutex::new(None));
|
||||
pub const DEFAULT_ORIGINATOR: &str = "codex_cli_rs";
|
||||
pub const CODEX_INTERNAL_ORIGINATOR_OVERRIDE_ENV_VAR: &str = "CODEX_INTERNAL_ORIGINATOR_OVERRIDE";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Originator {
|
||||
pub value: String,
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
use crate::config::ConfigToml;
|
||||
use crate::config::profile::ConfigProfile;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::BTreeSet;
|
||||
|
||||
@@ -74,6 +73,8 @@ pub enum Feature {
|
||||
ApplyPatchFreeform,
|
||||
/// Allow the model to request web searches.
|
||||
WebSearchRequest,
|
||||
/// Allow request body compression when using ChatGPT auth.
|
||||
RequestCompression,
|
||||
/// Gate the execpolicy enforcement for shell/unified exec.
|
||||
ExecPolicy,
|
||||
/// Enable Windows sandbox (restricted token) on Windows.
|
||||
@@ -150,16 +151,16 @@ impl FeatureOverrides {
|
||||
impl Features {
|
||||
/// Starts with built-in defaults.
|
||||
pub fn with_defaults() -> Self {
|
||||
let mut set = BTreeSet::new();
|
||||
let mut features = Self {
|
||||
enabled: BTreeSet::new(),
|
||||
legacy_usages: BTreeSet::new(),
|
||||
};
|
||||
for spec in FEATURES {
|
||||
if spec.default_enabled {
|
||||
set.insert(spec.id);
|
||||
features.enable(spec.id);
|
||||
}
|
||||
}
|
||||
Self {
|
||||
enabled: set,
|
||||
legacy_usages: BTreeSet::new(),
|
||||
}
|
||||
features
|
||||
}
|
||||
|
||||
pub fn enabled(&self, f: Feature) -> bool {
|
||||
@@ -196,7 +197,7 @@ impl Features {
|
||||
.map(|usage| (usage.alias.as_str(), usage.feature))
|
||||
}
|
||||
|
||||
/// Apply a table of key -> bool toggles (e.g. from TOML).
|
||||
/// Apply a table of key -> value toggles (e.g. from TOML).
|
||||
pub fn apply_map(&mut self, m: &BTreeMap<String, bool>) {
|
||||
for (k, v) in m {
|
||||
match feature_for_key(k) {
|
||||
@@ -330,6 +331,12 @@ pub const FEATURES: &[FeatureSpec] = &[
|
||||
stage: Stage::Stable,
|
||||
default_enabled: false,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::RequestCompression,
|
||||
key: "request_compression",
|
||||
stage: Stage::Experimental,
|
||||
default_enabled: false,
|
||||
},
|
||||
// Beta program. Rendered in the `/experimental` menu for users.
|
||||
FeatureSpec {
|
||||
id: Feature::UnifiedExec,
|
||||
|
||||
@@ -19,6 +19,8 @@ use std::env::VarError;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::error::EnvVarError;
|
||||
use crate::features::Feature;
|
||||
use crate::features::Features;
|
||||
const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000;
|
||||
const DEFAULT_STREAM_MAX_RETRIES: u64 = 5;
|
||||
const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
|
||||
@@ -253,6 +255,21 @@ impl ModelProviderInfo {
|
||||
pub fn is_openai(&self) -> bool {
|
||||
self.name == OPENAI_PROVIDER_NAME
|
||||
}
|
||||
|
||||
pub fn request_compression_for(
|
||||
&self,
|
||||
auth_mode: Option<AuthMode>,
|
||||
features: &Features,
|
||||
) -> codex_api::provider::RequestCompression {
|
||||
if self.is_openai()
|
||||
&& matches!(auth_mode, Some(AuthMode::ChatGPT))
|
||||
&& features.enabled(Feature::RequestCompression)
|
||||
{
|
||||
codex_api::provider::RequestCompression::Zstd
|
||||
} else {
|
||||
codex_api::provider::RequestCompression::None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub const DEFAULT_LMSTUDIO_PORT: u16 = 1234;
|
||||
|
||||
@@ -15,14 +15,17 @@ codex-core = { workspace = true, features = ["test-support"] }
|
||||
codex-protocol = { workspace = true }
|
||||
codex-utils-absolute-path = { workspace = true }
|
||||
codex-utils-cargo-bin = { workspace = true }
|
||||
http = { workspace = true }
|
||||
notify = { workspace = true }
|
||||
regex-lite = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
tokio = { workspace = true, features = ["time"] }
|
||||
walkdir = { workspace = true }
|
||||
wiremock = { workspace = true }
|
||||
shlex = { workspace = true }
|
||||
zstd = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = { workspace = true }
|
||||
|
||||
@@ -11,11 +11,15 @@ use regex_lite::Regex;
|
||||
use std::path::PathBuf;
|
||||
|
||||
pub mod process;
|
||||
pub mod request;
|
||||
pub mod responses;
|
||||
pub mod streaming_sse;
|
||||
pub mod test_codex;
|
||||
pub mod test_codex_exec;
|
||||
|
||||
pub use request::RequestBodyExt;
|
||||
pub use request::body_contains;
|
||||
|
||||
#[track_caller]
|
||||
pub fn assert_regex_match<'s>(pattern: &str, actual: &'s str) -> regex_lite::Captures<'s> {
|
||||
let regex = Regex::new(pattern).unwrap_or_else(|err| {
|
||||
@@ -178,7 +182,7 @@ where
|
||||
F: FnMut(&codex_core::protocol::EventMsg) -> bool,
|
||||
{
|
||||
use tokio::time::Duration;
|
||||
wait_for_event_with_timeout(codex, predicate, Duration::from_secs(1)).await
|
||||
wait_for_event_with_timeout(codex, predicate, Duration::from_secs(10)).await
|
||||
}
|
||||
|
||||
pub async fn wait_for_event_match<T, F>(codex: &CodexConversation, matcher: F) -> T
|
||||
|
||||
59
codex-rs/core/tests/common/request.rs
Normal file
59
codex-rs/core/tests/common/request.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
use http::header::CONTENT_ENCODING;
|
||||
use serde::de::DeserializeOwned;
|
||||
use wiremock::Match;
|
||||
|
||||
pub fn decoded_body_bytes(request: &wiremock::Request) -> Vec<u8> {
|
||||
if is_zstd_encoded(request) {
|
||||
zstd::decode_all(request.body.as_slice()).unwrap_or_else(|err| {
|
||||
panic!("failed to decode zstd-encoded request body: {err}");
|
||||
})
|
||||
} else {
|
||||
request.body.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decoded_body_string(request: &wiremock::Request) -> String {
|
||||
String::from_utf8_lossy(&decoded_body_bytes(request)).into_owned()
|
||||
}
|
||||
|
||||
pub trait RequestBodyExt {
|
||||
fn json_body<T: DeserializeOwned>(&self) -> T;
|
||||
fn text_body(&self) -> String;
|
||||
}
|
||||
|
||||
impl RequestBodyExt for wiremock::Request {
|
||||
fn json_body<T: DeserializeOwned>(&self) -> T {
|
||||
serde_json::from_slice(&decoded_body_bytes(self)).unwrap_or_else(|err| {
|
||||
panic!("failed to decode request body as JSON: {err}");
|
||||
})
|
||||
}
|
||||
|
||||
fn text_body(&self) -> String {
|
||||
decoded_body_string(self)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn body_contains(needle: impl Into<String>) -> impl Match {
|
||||
BodyContains {
|
||||
needle: needle.into(),
|
||||
}
|
||||
}
|
||||
|
||||
struct BodyContains {
|
||||
needle: String,
|
||||
}
|
||||
|
||||
impl Match for BodyContains {
|
||||
fn matches(&self, request: &wiremock::Request) -> bool {
|
||||
decoded_body_string(request).contains(self.needle.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
fn is_zstd_encoded(request: &wiremock::Request) -> bool {
|
||||
request
|
||||
.headers
|
||||
.get(CONTENT_ENCODING)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(|value| value.eq_ignore_ascii_case("zstd"))
|
||||
.unwrap_or(false)
|
||||
}
|
||||
@@ -15,6 +15,7 @@ use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path_regex;
|
||||
|
||||
use crate::RequestBodyExt;
|
||||
use crate::test_codex::ApplyPatchModelOutput;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -67,7 +68,7 @@ pub struct ResponsesRequest(wiremock::Request);
|
||||
|
||||
impl ResponsesRequest {
|
||||
pub fn body_json(&self) -> Value {
|
||||
self.0.body_json().unwrap()
|
||||
self.0.json_body()
|
||||
}
|
||||
|
||||
/// Returns all `input_text` spans from `message` inputs for the provided role.
|
||||
@@ -83,7 +84,7 @@ impl ResponsesRequest {
|
||||
}
|
||||
|
||||
pub fn input(&self) -> Vec<Value> {
|
||||
self.0.body_json::<Value>().unwrap()["input"]
|
||||
self.body_json()["input"]
|
||||
.as_array()
|
||||
.expect("input array not found in request")
|
||||
.clone()
|
||||
@@ -721,10 +722,7 @@ pub async fn get_responses_request_bodies(server: &MockServer) -> Vec<Value> {
|
||||
get_responses_requests(server)
|
||||
.await
|
||||
.into_iter()
|
||||
.map(|req| {
|
||||
req.body_json::<Value>()
|
||||
.expect("request body to be valid JSON")
|
||||
})
|
||||
.map(|req| req.json_body::<Value>())
|
||||
.collect()
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ use codex_core::Prompt;
|
||||
use codex_core::ResponseEvent;
|
||||
use codex_core::ResponseItem;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::features::Feature;
|
||||
use codex_core::models_manager::manager::ModelsManager;
|
||||
use codex_otel::otel_manager::OtelManager;
|
||||
use codex_protocol::ConversationId;
|
||||
@@ -317,3 +318,191 @@ async fn responses_respects_model_family_overrides_from_config() {
|
||||
Some("detailed")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn responses_request_body_is_zstd_encoded() {
|
||||
core_test_support::skip_if_no_network!();
|
||||
|
||||
let server = responses::start_mock_server().await;
|
||||
let response_body = responses::sse(vec![
|
||||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_completed("resp-1"),
|
||||
]);
|
||||
|
||||
let request_recorder = responses::mount_sse_once(&server, response_body).await;
|
||||
|
||||
let provider = ModelProviderInfo {
|
||||
name: "OpenAI".into(),
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(5_000),
|
||||
requires_openai_auth: false,
|
||||
};
|
||||
|
||||
let codex_home = TempDir::new().expect("failed to create TempDir");
|
||||
let mut config = load_default_config_for_test(&codex_home).await;
|
||||
config.model_provider_id = provider.name.clone();
|
||||
config.model_provider = provider.clone();
|
||||
config.features.enable(Feature::RequestCompression);
|
||||
let effort = config.model_reasoning_effort;
|
||||
let summary = config.model_reasoning_summary;
|
||||
let model = ModelsManager::get_model_offline(config.model.as_deref());
|
||||
config.model = Some(model.clone());
|
||||
let config = Arc::new(config);
|
||||
|
||||
let conversation_id = ConversationId::new();
|
||||
let session_source = SessionSource::Exec;
|
||||
let model_family = ModelsManager::construct_model_family_offline(model.as_str(), &config);
|
||||
let otel_manager = OtelManager::new(
|
||||
conversation_id,
|
||||
model.as_str(),
|
||||
model_family.slug.as_str(),
|
||||
None,
|
||||
Some("test@test.com".to_string()),
|
||||
Some(AuthMode::ChatGPT),
|
||||
false,
|
||||
"test".to_string(),
|
||||
session_source.clone(),
|
||||
);
|
||||
|
||||
let auth_manager =
|
||||
AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing());
|
||||
let client = ModelClient::new(
|
||||
Arc::clone(&config),
|
||||
Some(auth_manager),
|
||||
model_family,
|
||||
otel_manager,
|
||||
provider,
|
||||
effort,
|
||||
summary,
|
||||
conversation_id,
|
||||
session_source,
|
||||
);
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input = vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".into(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
}];
|
||||
|
||||
let mut stream = client.stream(&prompt).await.expect("stream failed");
|
||||
while let Some(event) = stream.next().await {
|
||||
if matches!(event, Ok(ResponseEvent::Completed { .. })) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let request = request_recorder.single_request();
|
||||
assert_eq!(request.header("content-encoding").as_deref(), Some("zstd"));
|
||||
assert_eq!(
|
||||
request.header("content-type").as_deref(),
|
||||
Some("application/json")
|
||||
);
|
||||
let request_body = request.body_json();
|
||||
assert_eq!(request_body["stream"].as_bool(), Some(true));
|
||||
assert_eq!(
|
||||
request_body["input"][0]["content"][0]["text"].as_str(),
|
||||
Some("hello")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn responses_request_body_is_uncompressed_when_disabled() {
|
||||
core_test_support::skip_if_no_network!();
|
||||
|
||||
let server = responses::start_mock_server().await;
|
||||
let response_body = responses::sse(vec![
|
||||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_completed("resp-1"),
|
||||
]);
|
||||
|
||||
let request_recorder = responses::mount_sse_once(&server, response_body).await;
|
||||
|
||||
let provider = ModelProviderInfo {
|
||||
name: "OpenAI".into(),
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(5_000),
|
||||
requires_openai_auth: false,
|
||||
};
|
||||
|
||||
let codex_home = TempDir::new().expect("failed to create TempDir");
|
||||
let mut config = load_default_config_for_test(&codex_home).await;
|
||||
config.model_provider_id = provider.name.clone();
|
||||
config.model_provider = provider.clone();
|
||||
let effort = config.model_reasoning_effort;
|
||||
let summary = config.model_reasoning_summary;
|
||||
let model = ModelsManager::get_model_offline(config.model.as_deref());
|
||||
config.model = Some(model.clone());
|
||||
let config = Arc::new(config);
|
||||
|
||||
let conversation_id = ConversationId::new();
|
||||
let session_source = SessionSource::Exec;
|
||||
let model_family = ModelsManager::construct_model_family_offline(model.as_str(), &config);
|
||||
let otel_manager = OtelManager::new(
|
||||
conversation_id,
|
||||
model.as_str(),
|
||||
model_family.slug.as_str(),
|
||||
None,
|
||||
Some("test@test.com".to_string()),
|
||||
Some(AuthMode::ChatGPT),
|
||||
false,
|
||||
"test".to_string(),
|
||||
session_source.clone(),
|
||||
);
|
||||
|
||||
let auth_manager =
|
||||
AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing());
|
||||
let client = ModelClient::new(
|
||||
Arc::clone(&config),
|
||||
Some(auth_manager),
|
||||
model_family,
|
||||
otel_manager,
|
||||
provider,
|
||||
effort,
|
||||
summary,
|
||||
conversation_id,
|
||||
session_source,
|
||||
);
|
||||
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input = vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".into(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
}];
|
||||
|
||||
let mut stream = client.stream(&prompt).await.expect("stream failed");
|
||||
while let Some(event) = stream.next().await {
|
||||
if matches!(event, Ok(ResponseEvent::Completed { .. })) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let request = request_recorder.single_request();
|
||||
assert_eq!(request.header("content-encoding"), None);
|
||||
assert_eq!(
|
||||
request.header("content-type").as_deref(),
|
||||
Some("application/json")
|
||||
);
|
||||
}
|
||||
|
||||
@@ -29,6 +29,8 @@ use codex_protocol::models::ReasoningItemReasoningSummary;
|
||||
use codex_protocol::models::WebSearchAction;
|
||||
use codex_protocol::openai_models::ReasoningEffort;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use core_test_support::RequestBodyExt;
|
||||
use core_test_support::body_contains;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::load_sse_fixture_with_id;
|
||||
use core_test_support::responses::ev_completed_with_tokens;
|
||||
@@ -51,7 +53,6 @@ use uuid::Uuid;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::body_string_contains;
|
||||
use wiremock::matchers::header_regex;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
@@ -507,7 +508,7 @@ async fn chatgpt_auth_sends_correct_request() {
|
||||
let request_authorization = request.headers.get("authorization").unwrap();
|
||||
let request_originator = request.headers.get("originator").unwrap();
|
||||
let request_chatgpt_account_id = request.headers.get("chatgpt-account-id").unwrap();
|
||||
let request_body = request.body_json::<serde_json::Value>().unwrap();
|
||||
let request_body = request.json_body::<serde_json::Value>();
|
||||
|
||||
assert_eq!(
|
||||
request_conversation_id.to_str().unwrap(),
|
||||
@@ -1495,7 +1496,7 @@ async fn context_window_error_sets_total_tokens_to_model_window() -> anyhow::Res
|
||||
|
||||
mount_sse_once_match(
|
||||
&server,
|
||||
body_string_contains("trigger context window"),
|
||||
body_contains("trigger context window"),
|
||||
sse_failed(
|
||||
"resp_context_window",
|
||||
"context_length_exceeded",
|
||||
@@ -1506,7 +1507,7 @@ async fn context_window_error_sets_total_tokens_to_model_window() -> anyhow::Res
|
||||
|
||||
mount_sse_once_match(
|
||||
&server,
|
||||
body_string_contains("seed turn"),
|
||||
body_contains("seed turn"),
|
||||
sse_completed("resp_seed"),
|
||||
)
|
||||
.await;
|
||||
@@ -1882,8 +1883,7 @@ async fn history_dedupes_streamed_and_final_messages_across_turns() {
|
||||
]);
|
||||
|
||||
let r3_input_array = requests[2]
|
||||
.body_json::<serde_json::Value>()
|
||||
.unwrap()
|
||||
.json_body::<serde_json::Value>()
|
||||
.get("input")
|
||||
.and_then(|v| v.as_array())
|
||||
.cloned()
|
||||
|
||||
@@ -17,6 +17,7 @@ use codex_core::protocol::SandboxPolicy;
|
||||
use codex_core::protocol::WarningEvent;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use core_test_support::RequestBodyExt;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::responses::ev_local_shell_call;
|
||||
use core_test_support::responses::ev_reasoning_item;
|
||||
@@ -132,7 +133,6 @@ async fn summarize_context_three_requests_and_instructions() {
|
||||
|
||||
// SSE 3: minimal completed; we only need to capture the request body.
|
||||
let sse3 = sse(vec![ev_completed("r3")]);
|
||||
|
||||
// Mount the three expected requests in sequence so the assertions below can
|
||||
// inspect them without relying on specific prompt markers.
|
||||
let request_log = mount_sse_sequence(&server, vec![sse1, sse2, sse3]).await;
|
||||
@@ -361,7 +361,8 @@ async fn manual_compact_uses_custom_prompt() {
|
||||
let requests = get_responses_requests(&server).await;
|
||||
let body = requests
|
||||
.iter()
|
||||
.find_map(|req| req.body_json::<serde_json::Value>().ok())
|
||||
.map(core_test_support::RequestBodyExt::json_body::<serde_json::Value>)
|
||||
.next()
|
||||
.expect("summary request body");
|
||||
|
||||
let input = body
|
||||
@@ -591,9 +592,7 @@ async fn multiple_auto_compact_per_task_runs_after_token_limit_hit() {
|
||||
// collect the requests payloads from the model
|
||||
let requests_payloads = get_responses_requests(&server).await;
|
||||
|
||||
let body = requests_payloads[0]
|
||||
.body_json::<serde_json::Value>()
|
||||
.unwrap();
|
||||
let body = requests_payloads[0].json_body::<serde_json::Value>();
|
||||
let input = body.get("input").and_then(|v| v.as_array()).unwrap();
|
||||
|
||||
fn normalize_inputs(values: &[serde_json::Value]) -> Vec<serde_json::Value> {
|
||||
@@ -634,9 +633,7 @@ async fn multiple_auto_compact_per_task_runs_after_token_limit_hit() {
|
||||
prefixed_third_summary.as_str(),
|
||||
];
|
||||
for (i, expected_summary) in compaction_indices.into_iter().zip(expected_summaries) {
|
||||
let body = requests_payloads.clone()[i]
|
||||
.body_json::<serde_json::Value>()
|
||||
.unwrap();
|
||||
let body = requests_payloads.clone()[i].json_body::<serde_json::Value>();
|
||||
let input = body.get("input").and_then(|v| v.as_array()).unwrap();
|
||||
let input = normalize_inputs(input);
|
||||
assert_eq!(input.len(), 3);
|
||||
@@ -999,7 +996,7 @@ async fn multiple_auto_compact_per_task_runs_after_token_limit_hit() {
|
||||
]);
|
||||
|
||||
for (i, request) in requests_payloads.iter().enumerate() {
|
||||
let body = request.body_json::<serde_json::Value>().unwrap();
|
||||
let body = request.json_body::<serde_json::Value>();
|
||||
let input = body.get("input").and_then(|v| v.as_array()).unwrap();
|
||||
let expected_input = expected_requests_inputs[i].as_array().unwrap();
|
||||
assert_eq!(normalize_inputs(input), normalize_inputs(expected_input));
|
||||
@@ -1038,30 +1035,30 @@ async fn auto_compact_runs_after_token_limit_hit() {
|
||||
let prefixed_auto_summary = AUTO_SUMMARY_TEXT;
|
||||
|
||||
let first_matcher = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
let body = req.text_body();
|
||||
body.contains(FIRST_AUTO_MSG)
|
||||
&& !body.contains(SECOND_AUTO_MSG)
|
||||
&& !body_contains_text(body, SUMMARIZATION_PROMPT)
|
||||
&& !body_contains_text(body.as_str(), SUMMARIZATION_PROMPT)
|
||||
};
|
||||
mount_sse_once_match(&server, first_matcher, sse1).await;
|
||||
|
||||
let second_matcher = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
let body = req.text_body();
|
||||
body.contains(SECOND_AUTO_MSG)
|
||||
&& body.contains(FIRST_AUTO_MSG)
|
||||
&& !body_contains_text(body, SUMMARIZATION_PROMPT)
|
||||
&& !body_contains_text(body.as_str(), SUMMARIZATION_PROMPT)
|
||||
};
|
||||
mount_sse_once_match(&server, second_matcher, sse2).await;
|
||||
|
||||
let third_matcher = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
body_contains_text(body, SUMMARIZATION_PROMPT)
|
||||
let body = req.text_body();
|
||||
body_contains_text(body.as_str(), SUMMARIZATION_PROMPT)
|
||||
};
|
||||
mount_sse_once_match(&server, third_matcher, sse3).await;
|
||||
|
||||
let fourth_matcher = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
body.contains(POST_AUTO_USER_MSG) && !body_contains_text(body, SUMMARIZATION_PROMPT)
|
||||
let body = req.text_body();
|
||||
body.contains(POST_AUTO_USER_MSG)
|
||||
&& !body_contains_text(body.as_str(), SUMMARIZATION_PROMPT)
|
||||
};
|
||||
mount_sse_once_match(&server, fourth_matcher, sse4).await;
|
||||
|
||||
@@ -1126,10 +1123,7 @@ async fn auto_compact_runs_after_token_limit_hit() {
|
||||
requests.len()
|
||||
);
|
||||
let is_auto_compact = |req: &wiremock::Request| {
|
||||
body_contains_text(
|
||||
std::str::from_utf8(&req.body).unwrap_or(""),
|
||||
SUMMARIZATION_PROMPT,
|
||||
)
|
||||
body_contains_text(req.text_body().as_str(), SUMMARIZATION_PROMPT)
|
||||
};
|
||||
let auto_compact_count = requests.iter().filter(|req| is_auto_compact(req)).count();
|
||||
assert_eq!(
|
||||
@@ -1151,20 +1145,16 @@ async fn auto_compact_runs_after_token_limit_hit() {
|
||||
.enumerate()
|
||||
.rev()
|
||||
.find_map(|(idx, req)| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
(body.contains(POST_AUTO_USER_MSG) && !body_contains_text(body, SUMMARIZATION_PROMPT))
|
||||
let body = req.text_body();
|
||||
(body.contains(POST_AUTO_USER_MSG) && !body_contains_text(&body, SUMMARIZATION_PROMPT))
|
||||
.then_some(idx)
|
||||
})
|
||||
.expect("follow-up request missing");
|
||||
assert_eq!(follow_up_index, 3, "follow-up request should be last");
|
||||
|
||||
let body_first = requests[0].body_json::<serde_json::Value>().unwrap();
|
||||
let body_auto = requests[auto_compact_index]
|
||||
.body_json::<serde_json::Value>()
|
||||
.unwrap();
|
||||
let body_follow_up = requests[follow_up_index]
|
||||
.body_json::<serde_json::Value>()
|
||||
.unwrap();
|
||||
let body_first = requests[0].json_body::<serde_json::Value>();
|
||||
let body_auto = requests[auto_compact_index].json_body::<serde_json::Value>();
|
||||
let body_follow_up = requests[follow_up_index].json_body::<serde_json::Value>();
|
||||
let instructions = body_auto
|
||||
.get("instructions")
|
||||
.and_then(|v| v.as_str())
|
||||
@@ -1375,24 +1365,24 @@ async fn auto_compact_persists_rollout_entries() {
|
||||
]);
|
||||
|
||||
let first_matcher = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
let body = req.text_body();
|
||||
body.contains(FIRST_AUTO_MSG)
|
||||
&& !body.contains(SECOND_AUTO_MSG)
|
||||
&& !body_contains_text(body, SUMMARIZATION_PROMPT)
|
||||
&& !body_contains_text(body.as_str(), SUMMARIZATION_PROMPT)
|
||||
};
|
||||
mount_sse_once_match(&server, first_matcher, sse1).await;
|
||||
|
||||
let second_matcher = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
let body = req.text_body();
|
||||
body.contains(SECOND_AUTO_MSG)
|
||||
&& body.contains(FIRST_AUTO_MSG)
|
||||
&& !body_contains_text(body, SUMMARIZATION_PROMPT)
|
||||
&& !body_contains_text(body.as_str(), SUMMARIZATION_PROMPT)
|
||||
};
|
||||
mount_sse_once_match(&server, second_matcher, sse2).await;
|
||||
|
||||
let third_matcher = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
body_contains_text(body, SUMMARIZATION_PROMPT)
|
||||
let body = req.text_body();
|
||||
body_contains_text(body.as_str(), SUMMARIZATION_PROMPT)
|
||||
};
|
||||
mount_sse_once_match(&server, third_matcher, sse3).await;
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ use codex_core::protocol::Op;
|
||||
use codex_core::protocol::WarningEvent;
|
||||
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use core_test_support::RequestBodyExt;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
@@ -796,8 +797,9 @@ async fn mount_initial_flow(server: &MockServer) {
|
||||
let sse5 = sse(vec![ev_completed("r5")]);
|
||||
|
||||
let match_first = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
let body = req.text_body();
|
||||
body.contains("\"text\":\"hello world\"")
|
||||
&& !body_contains_text(body.as_str(), SUMMARIZATION_PROMPT)
|
||||
&& !body.contains(&format!("\"text\":\"{SUMMARY_TEXT}\""))
|
||||
&& !body.contains("\"text\":\"AFTER_COMPACT\"")
|
||||
&& !body.contains("\"text\":\"AFTER_RESUME\"")
|
||||
@@ -806,13 +808,13 @@ async fn mount_initial_flow(server: &MockServer) {
|
||||
mount_sse_once_match(server, match_first, sse1).await;
|
||||
|
||||
let match_compact = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
body_contains_text(body, SUMMARIZATION_PROMPT) || body.contains(&json_fragment(FIRST_REPLY))
|
||||
let body = req.text_body();
|
||||
body_contains_text(body.as_str(), SUMMARIZATION_PROMPT)
|
||||
};
|
||||
mount_sse_once_match(server, match_compact, sse2).await;
|
||||
|
||||
let match_after_compact = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
let body = req.text_body();
|
||||
body.contains("\"text\":\"AFTER_COMPACT\"")
|
||||
&& !body.contains("\"text\":\"AFTER_RESUME\"")
|
||||
&& !body.contains("\"text\":\"AFTER_FORK\"")
|
||||
@@ -820,13 +822,13 @@ async fn mount_initial_flow(server: &MockServer) {
|
||||
mount_sse_once_match(server, match_after_compact, sse3).await;
|
||||
|
||||
let match_after_resume = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
let body = req.text_body();
|
||||
body.contains("\"text\":\"AFTER_RESUME\"")
|
||||
};
|
||||
mount_sse_once_match(server, match_after_resume, sse4).await;
|
||||
|
||||
let match_after_fork = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
let body = req.text_body();
|
||||
body.contains("\"text\":\"AFTER_FORK\"")
|
||||
};
|
||||
mount_sse_once_match(server, match_after_fork, sse5).await;
|
||||
@@ -840,13 +842,13 @@ async fn mount_second_compact_flow(server: &MockServer) {
|
||||
let sse7 = sse(vec![ev_completed("r7")]);
|
||||
|
||||
let match_second_compact = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
body.contains("AFTER_FORK")
|
||||
let body = req.text_body();
|
||||
body_contains_text(body.as_str(), SUMMARIZATION_PROMPT) && body.contains("AFTER_FORK")
|
||||
};
|
||||
mount_sse_once_match(server, match_second_compact, sse6).await;
|
||||
|
||||
let match_after_second_resume = |req: &wiremock::Request| {
|
||||
let body = std::str::from_utf8(&req.body).unwrap_or("");
|
||||
let body = req.text_body();
|
||||
body.contains(&format!("\"text\":\"{AFTER_SECOND_RESUME}\""))
|
||||
};
|
||||
mount_sse_once_match(server, match_after_second_resume, sse7).await;
|
||||
|
||||
@@ -6,6 +6,7 @@ use codex_core::protocol::Op;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use core_test_support::RequestBodyExt;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
@@ -54,7 +55,7 @@ async fn codex_returns_json_result(model: String) -> anyhow::Result<()> {
|
||||
|
||||
let expected_schema: serde_json::Value = serde_json::from_str(SCHEMA)?;
|
||||
let match_json_text_param = move |req: &wiremock::Request| {
|
||||
let body: serde_json::Value = serde_json::from_slice(&req.body).unwrap_or_default();
|
||||
let body: serde_json::Value = req.json_body();
|
||||
let Some(text) = body.get("text") else {
|
||||
return false;
|
||||
};
|
||||
|
||||
@@ -21,6 +21,7 @@ use codex_core::protocol::RolloutItem;
|
||||
use codex_core::protocol::RolloutLine;
|
||||
use codex_core::review_format::render_review_output_text;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use core_test_support::RequestBodyExt;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::load_sse_fixture_with_id_from_str;
|
||||
use core_test_support::responses::get_responses_requests;
|
||||
@@ -430,7 +431,7 @@ async fn review_uses_custom_review_model_from_config() {
|
||||
let request = requests
|
||||
.first()
|
||||
.expect("expected POST request to /responses");
|
||||
let body = request.body_json::<serde_json::Value>().unwrap();
|
||||
let body = request.json_body::<serde_json::Value>();
|
||||
assert_eq!(body["model"].as_str().unwrap(), "gpt-5.1");
|
||||
|
||||
server.verify().await;
|
||||
@@ -551,7 +552,7 @@ async fn review_input_isolated_from_parent_history() {
|
||||
let request = requests
|
||||
.first()
|
||||
.expect("expected POST request to /responses");
|
||||
let body = request.body_json::<serde_json::Value>().unwrap();
|
||||
let body = request.json_body::<serde_json::Value>();
|
||||
let input = body["input"].as_array().expect("input array");
|
||||
assert!(
|
||||
input.len() >= 2,
|
||||
@@ -676,7 +677,7 @@ async fn review_history_surfaces_in_parent_session() {
|
||||
// Critically, no messages from the review thread should appear.
|
||||
let requests = get_responses_requests(&server).await;
|
||||
assert_eq!(requests.len(), 2);
|
||||
let body = requests[1].body_json::<serde_json::Value>().unwrap();
|
||||
let body = requests[1].json_body::<serde_json::Value>();
|
||||
let input = body["input"].as_array().expect("input array");
|
||||
|
||||
// Must include the followup as the last item for this turn
|
||||
|
||||
@@ -3,6 +3,7 @@ use codex_core::WireApi;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use core_test_support::body_contains;
|
||||
use core_test_support::load_sse_fixture_with_id;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
@@ -11,7 +12,6 @@ use core_test_support::wait_for_event;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::body_string_contains;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
@@ -38,7 +38,7 @@ async fn continue_after_stream_error() {
|
||||
// so the failing request should only occur once.
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(body_string_contains("first message"))
|
||||
.and(body_contains("first message"))
|
||||
.respond_with(fail)
|
||||
.up_to_n_times(2)
|
||||
.mount(&server)
|
||||
@@ -50,7 +50,7 @@ async fn continue_after_stream_error() {
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(body_string_contains("follow up"))
|
||||
.and(body_contains("follow up"))
|
||||
.respond_with(ok)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
|
||||
Reference in New Issue
Block a user