Files
codex/codex-rs/core/src/client_tests.rs
2026-05-06 17:32:36 -07:00

711 lines
23 KiB
Rust

use super::AttestationPurpose;
use super::AuthRequestTelemetryContext;
use super::ModelClient;
use super::PendingUnauthorizedRetry;
use super::UnauthorizedRecoveryExecution;
use super::X_CODEX_INSTALLATION_ID_HEADER;
use super::X_CODEX_PARENT_THREAD_ID_HEADER;
use super::X_CODEX_TURN_METADATA_HEADER;
use super::X_CODEX_WINDOW_ID_HEADER;
use super::X_OPENAI_SUBAGENT_HEADER;
use crate::AttestationProvider;
use codex_api::ApiError;
use codex_api::ResponseEvent;
use codex_app_server_protocol::AuthMode;
use codex_model_provider::BearerAuthProvider;
use codex_model_provider_info::CHATGPT_CODEX_BASE_URL;
use codex_model_provider_info::WireApi;
use codex_model_provider_info::create_oss_provider_with_base_url;
use codex_otel::SessionTelemetry;
use codex_protocol::SessionId;
use codex_protocol::ThreadId;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::openai_models::ModelInfo;
use codex_protocol::protocol::InternalSessionSource;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use codex_rollout_trace::ExecutionStatus;
use codex_rollout_trace::InferenceTraceAttempt;
use codex_rollout_trace::InferenceTraceContext;
use codex_rollout_trace::RawTraceEventPayload;
use codex_rollout_trace::RolloutTrace;
use codex_rollout_trace::TraceWriter;
use codex_rollout_trace::replay_bundle;
use futures::StreamExt;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::collections::BTreeMap;
use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;
use tempfile::TempDir;
use tokio::sync::Notify;
use tracing::Event;
use tracing::Subscriber;
use tracing::field::Visit;
use tracing_subscriber::Layer;
use tracing_subscriber::layer::Context as LayerContext;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::registry::LookupSpan;
use tracing_subscriber::util::SubscriberInitExt;
/// Builds a `ModelClient` against a stub OSS provider for header/metadata
/// tests. Auth, verbosity, compression, timing metrics, and beta features are
/// all disabled so only the `session_source`-driven behavior varies per test.
fn test_model_client(session_source: SessionSource) -> ModelClient {
    let thread_id = ThreadId::new();
    let provider = create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses);
    let installation_id = "11111111-1111-4111-8111-111111111111".to_string();
    ModelClient::new(
        /*auth_manager*/ None,
        thread_id.into(),
        thread_id,
        installation_id,
        provider,
        session_source,
        /*model_verbosity*/ None,
        /*enable_request_compression*/ false,
        /*include_timing_metrics*/ false,
        /*beta_features_header*/ None,
    )
}
/// Builds a minimal `codex_api::Provider` for attestation tests.
///
/// `uses_chatgpt_auth` is derived by comparing `base_url` (with any trailing
/// slash stripped, ASCII-case-insensitively) to the ChatGPT Codex backend URL.
fn api_provider(base_url: &str) -> codex_api::Provider {
    let normalized = base_url.trim_end_matches('/');
    let uses_chatgpt_auth = normalized.eq_ignore_ascii_case(CHATGPT_CODEX_BASE_URL);
    // Single attempt with a tiny delay: tests never want real retry backoff.
    let retry = codex_api::RetryConfig {
        max_attempts: 1,
        base_delay: Duration::from_millis(1),
        retry_429: false,
        retry_5xx: true,
        retry_transport: true,
    };
    codex_api::Provider {
        name: "test".to_string(),
        base_url: base_url.to_string(),
        uses_chatgpt_auth,
        query_params: None,
        headers: http::HeaderMap::new(),
        retry,
        stream_idle_timeout: Duration::from_secs(1),
    }
}
/// Deserializes a minimal but complete `ModelInfo` fixture for tests.
fn test_model_info() -> ModelInfo {
    let raw = json!({
        "slug": "gpt-test",
        "display_name": "gpt-test",
        "description": "desc",
        "default_reasoning_level": "medium",
        "supported_reasoning_levels": [
            {"effort": "medium", "description": "medium"}
        ],
        "shell_type": "shell_command",
        "visibility": "list",
        "supported_in_api": true,
        "priority": 1,
        "upgrade": null,
        "base_instructions": "base instructions",
        "model_messages": null,
        "supports_reasoning_summaries": false,
        "support_verbosity": false,
        "default_verbosity": null,
        "apply_patch_tool_type": null,
        "truncation_policy": {"mode": "bytes", "limit": 10000},
        "supports_parallel_tool_calls": false,
        "supports_image_detail_original": false,
        "context_window": 272000,
        "auto_compact_token_limit": null,
        "experimental_supported_tools": []
    });
    serde_json::from_value(raw).expect("deserialize test model info")
}
/// Builds a CLI-sourced `SessionTelemetry` fixture with no account or auth
/// information and user-prompt logging disabled.
fn test_session_telemetry() -> SessionTelemetry {
    let thread_id = ThreadId::new();
    SessionTelemetry::new(
        thread_id,
        "gpt-test",
        "gpt-test",
        /*account_id*/ None,
        /*account_email*/ None,
        /*auth_mode*/ None,
        String::from("test-originator"),
        /*log_user_prompts*/ false,
        String::from("test-terminal"),
        SessionSource::Cli,
    )
}
/// Field visitor that collects every recorded field into a
/// name → rendered-value map.
#[derive(Default)]
struct TagCollectorVisitor {
    // Field name → rendered value, sorted for deterministic assertions.
    tags: BTreeMap<String, String>,
}
impl Visit for TagCollectorVisitor {
    /// String fields are stored verbatim.
    fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
        let name = field.name().to_string();
        self.tags.insert(name, value.to_owned());
    }
    /// Non-string fields are rendered through their `Debug` impl.
    fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
        let name = field.name().to_string();
        self.tags.insert(name, format!("{value:?}"));
    }
}
/// Tracing layer that captures fields of events targeted at `feedback_tags`
/// into a shared map for test assertions.
#[derive(Clone)]
struct TagCollectorLayer {
    tags: Arc<Mutex<BTreeMap<String, String>>>,
}
impl<S> Layer<S> for TagCollectorLayer
where
    S: Subscriber + for<'a> LookupSpan<'a>,
{
    fn on_event(&self, event: &Event<'_>, _ctx: LayerContext<'_, S>) {
        // Only events explicitly targeted at `feedback_tags` are collected.
        if event.metadata().target() == "feedback_tags" {
            let mut visitor = TagCollectorVisitor::default();
            event.record(&mut visitor);
            let mut shared = self.tags.lock().unwrap();
            shared.extend(visitor.tags);
        }
    }
}
/// Creates a trace writer under `temp`, seeds it with a started thread and
/// turn, and returns an inference attempt that has already recorded its
/// request payload. Tests replay the bundle from `temp` afterwards to inspect
/// the recorded execution status.
fn started_inference_attempt(temp: &TempDir) -> anyhow::Result<InferenceTraceAttempt> {
let writer = Arc::new(TraceWriter::create(
temp.path(),
"trace-1".to_string(),
"rollout-1".to_string(),
"thread-root".to_string(),
)?);
// Seed the trace with a started thread and turn so the attempt below has
// enclosing context when the bundle is replayed.
writer.append(RawTraceEventPayload::ThreadStarted {
thread_id: "thread-root".to_string(),
agent_path: "/root".to_string(),
metadata_payload: None,
})?;
writer.append(RawTraceEventPayload::CodexTurnStarted {
codex_turn_id: "turn-1".to_string(),
thread_id: "thread-root".to_string(),
})?;
let inference_trace = InferenceTraceContext::enabled(
writer,
"thread-root".to_string(),
"turn-1".to_string(),
"gpt-test".to_string(),
"test-provider".to_string(),
);
let attempt = inference_trace.start_attempt();
// Record a minimal request payload so the attempt is in the "started" state
// before any response events are mapped onto it.
attempt.record_started(&json!({
"model": "gpt-test",
"input": [{
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": "hello"}]
}],
}));
Ok(attempt)
}
/// Builds an assistant `ResponseItem::Message` carrying a single
/// output-text content item.
fn output_message(id: &str, text: &str) -> ResponseItem {
    let content = vec![ContentItem::OutputText {
        text: text.to_string(),
    }];
    ResponseItem::Message {
        id: Some(id.to_string()),
        role: "assistant".to_string(),
        content,
        phase: None,
    }
}
/// Polls the trace bundle in `temp` until its single inference call reads as
/// `Cancelled`, or until 50 re-reads (10 ms apart) have elapsed; the last
/// replayed rollout is returned either way.
async fn replay_until_cancelled(temp: &TempDir) -> anyhow::Result<RolloutTrace> {
    let mut rollout = replay_bundle(temp.path())?;
    let mut retries_left = 50;
    loop {
        let inference = rollout
            .inference_calls
            .values()
            .next()
            .expect("inference should be reduced");
        // Cancellation is recorded asynchronously, so it may take a few
        // re-reads for the terminal state to show up on disk.
        if inference.execution.status == ExecutionStatus::Cancelled || retries_left == 0 {
            return Ok(rollout);
        }
        retries_left -= 1;
        tokio::time::sleep(Duration::from_millis(10)).await;
        rollout = replay_bundle(temp.path())?;
    }
}
/// Test stream that yields a fixed queue of events and fires `notify` once
/// `notify_after` events have been handed out. After the queue is drained it
/// returns `Poll::Pending` forever.
struct NotifyAfterEventStream {
// Events yielded in FIFO order.
events: VecDeque<ResponseEvent>,
// Number of events yielded so far.
yielded: usize,
// Fire `notify` when exactly this many events have been yielded.
notify_after: usize,
notify: Arc<Notify>,
}
impl futures::Stream for NotifyAfterEventStream {
type Item = std::result::Result<ResponseEvent, ApiError>;
fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let Some(event) = self.events.pop_front() else {
// Deliberately does not register the waker: once drained, this
// stream stalls forever, which is what the drop/cancellation tests
// rely on to keep the mapper blocked.
return Poll::Pending;
};
self.yielded += 1;
if self.yielded == self.notify_after {
self.notify.notify_one();
}
Poll::Ready(Some(Ok(event)))
}
}
/// An `Other` sub-agent source surfaces its label verbatim in the
/// sub-agent header.
#[test]
fn build_subagent_headers_sets_other_subagent_label() {
    let source = SessionSource::SubAgent(SubAgentSource::Other(
        "memory_consolidation".to_string(),
    ));
    let client = test_model_client(source);
    let headers = client.build_subagent_headers();
    let header = headers.get(X_OPENAI_SUBAGENT_HEADER);
    let value = header.and_then(|value| value.to_str().ok());
    assert_eq!(value, Some("memory_consolidation"));
}
/// The internal memory-consolidation source maps to the same
/// `memory_consolidation` sub-agent header label.
#[test]
fn build_subagent_headers_sets_internal_memory_consolidation_label() {
    let source = SessionSource::Internal(InternalSessionSource::MemoryConsolidation);
    let client = test_model_client(source);
    let headers = client.build_subagent_headers();
    let header = headers.get(X_OPENAI_SUBAGENT_HEADER);
    let value = header.and_then(|value| value.to_str().ok());
    assert_eq!(value, Some("memory_consolidation"));
}
/// WS client metadata for a spawned sub-agent thread carries the installation
/// id, the generation-suffixed window id, the `collab_spawn` sub-agent label,
/// the parent thread id, and the caller-supplied turn metadata.
#[test]
fn build_ws_client_metadata_includes_window_lineage_and_turn_metadata() {
    let parent_thread_id = ThreadId::new();
    let client = test_model_client(SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
        parent_thread_id,
        depth: 2,
        agent_path: None,
        agent_nickname: None,
        agent_role: None,
    }));
    // A single generation bump makes the window id suffix `:1`.
    client.advance_window_generation();
    let client_metadata = client.build_ws_client_metadata(Some(r#"{"turn_id":"turn-123"}"#));
    let thread_id = client.state.thread_id;
    let mut expected = std::collections::HashMap::new();
    expected.insert(
        X_CODEX_INSTALLATION_ID_HEADER.to_string(),
        "11111111-1111-4111-8111-111111111111".to_string(),
    );
    expected.insert(
        X_CODEX_WINDOW_ID_HEADER.to_string(),
        format!("{thread_id}:1"),
    );
    expected.insert(
        X_OPENAI_SUBAGENT_HEADER.to_string(),
        "collab_spawn".to_string(),
    );
    expected.insert(
        X_CODEX_PARENT_THREAD_ID_HEADER.to_string(),
        parent_thread_id.to_string(),
    );
    expected.insert(
        X_CODEX_TURN_METADATA_HEADER.to_string(),
        r#"{"turn_id":"turn-123"}"#.to_string(),
    );
    assert_eq!(client_metadata, expected);
}
/// Summarizing an empty memory list succeeds and yields an empty result
/// without contacting the provider.
#[tokio::test]
async fn summarize_memories_returns_empty_for_empty_input() {
    let client = test_model_client(SessionSource::Cli);
    let model_info = test_model_info();
    let session_telemetry = test_session_telemetry();
    let result = client
        .summarize_memories(
            Vec::new(),
            &model_info,
            /*effort*/ None,
            &session_telemetry,
        )
        .await;
    let output = result.expect("empty summarize request should succeed");
    assert_eq!(output.len(), 0);
}
/// Dropping the consumer mid-stream must record the attempt as `Cancelled`
/// while preserving the output items that were already completed.
#[tokio::test]
async fn dropped_response_stream_traces_cancelled_partial_output() -> anyhow::Result<()> {
let temp = TempDir::new()?;
let attempt = started_inference_attempt(&temp)?;
// The provider has produced one complete output item, but no terminal
// response.completed event. The harness has enough information to keep this
// item in history, so the trace should preserve it when the stream is
// abandoned.
let item = output_message("msg-1", "partial answer");
let api_stream = futures::stream::iter([Ok(ResponseEvent::OutputItemDone(item))])
.chain(futures::stream::pending());
let (mut stream, _) = super::map_response_events(
/*upstream_request_id*/ None,
api_stream,
test_session_telemetry(),
attempt,
);
let observed = stream
.next()
.await
.expect("mapped stream should yield output item")?;
assert!(matches!(observed, ResponseEvent::OutputItemDone(_)));
// Dropping the consumer is how turn interruption/preemption stops polling
// the provider stream. The mapper task observes that drop asynchronously
// and records cancellation using the output items it has already seen.
drop(stream);
// Cancellation is recorded by the mapper task after Drop wakes it, so the
// replay may need a short wait before the terminal event appears on disk.
let rollout = replay_until_cancelled(&temp).await?;
let inference = rollout
.inference_calls
.values()
.next()
.expect("inference should be reduced");
assert_eq!(inference.execution.status, ExecutionStatus::Cancelled);
// The single completed item observed before the drop survives as partial output.
assert_eq!(inference.response_item_ids.len(), 1);
// NOTE(review): 2 raw payloads presumably = recorded request + cancellation
// terminal event — confirm against the trace writer's payload accounting.
assert_eq!(rollout.raw_payloads.len(), 2);
Ok(())
}
/// Draining a completed response stream emits `feedback_tags` events that
/// record the last model request and response ids (JSON-quoted).
#[tokio::test]
async fn response_stream_records_last_model_feedback_ids() {
    let tags = Arc::new(Mutex::new(BTreeMap::new()));
    let _guard = tracing_subscriber::registry()
        .with(TagCollectorLayer {
            tags: Arc::clone(&tags),
        })
        .set_default();
    let events = [
        Ok(ResponseEvent::Created),
        Ok(ResponseEvent::Completed {
            response_id: "resp-123".to_string(),
            token_usage: None,
            end_turn: Some(true),
        }),
    ];
    let (mut stream, _) = super::map_response_events(
        Some("req-123".to_string()),
        futures::stream::iter(events),
        test_session_telemetry(),
        InferenceTraceAttempt::disabled(),
    );
    // Drain to completion so the mapper records the feedback ids.
    while stream.next().await.is_some() {}
    let collected = tags.lock().unwrap().clone();
    assert_eq!(
        collected.get("last_model_request_id").map(String::as_str),
        Some("\"req-123\"")
    );
    assert_eq!(
        collected.get("last_model_response_id").map(String::as_str),
        Some("\"resp-123\"")
    );
}
/// Backpressured variant of
/// `dropped_response_stream_traces_cancelled_partial_output`: the consumer is
/// dropped while the mapper is blocked sending downstream, exercising the
/// send-failure path instead of the consumer-drop select branch.
#[tokio::test]
async fn dropped_backpressured_response_stream_traces_cancelled_partial_output()
-> anyhow::Result<()> {
let temp = TempDir::new()?;
let attempt = started_inference_attempt(&temp)?;
let backpressured_item_yielded = Arc::new(Notify::new());
// Exactly enough non-terminal events to fill the mapper's bounded channel,
// followed by one real output item that cannot be forwarded.
let mut events = VecDeque::new();
for _ in 0..super::RESPONSE_STREAM_CHANNEL_CAPACITY {
events.push_back(ResponseEvent::Created);
}
events.push_back(ResponseEvent::OutputItemDone(output_message(
"msg-1",
"partial answer",
)));
// `notify_after` fires once the output item (capacity + 1) has been yielded
// upstream, i.e. once the mapper is holding it and blocked on send.
let api_stream = NotifyAfterEventStream {
events,
yielded: 0,
notify_after: super::RESPONSE_STREAM_CHANNEL_CAPACITY + 1,
notify: Arc::clone(&backpressured_item_yielded),
};
let (stream, _) = super::map_response_events(
/*upstream_request_id*/ None,
api_stream,
test_session_telemetry(),
attempt,
);
// Fill the mapper channel with non-terminal events, then yield one output
// item. The mapper has observed that item and is blocked trying to send it
// downstream, so dropping the consumer covers the send-failure path rather
// than the `consumer_dropped` select branch.
backpressured_item_yielded.notified().await;
drop(stream);
let rollout = replay_until_cancelled(&temp).await?;
let inference = rollout
.inference_calls
.values()
.next()
.expect("inference should be reduced");
assert_eq!(inference.execution.status, ExecutionStatus::Cancelled);
// The item the mapper saw before the drop must survive as partial output.
assert_eq!(inference.response_item_ids.len(), 1);
assert_eq!(rollout.raw_payloads.len(), 2);
Ok(())
}
/// A telemetry context built from an attached bearer auth and a pending
/// unauthorized-retry exposes the auth mode, header info, and recovery
/// mode/phase strings.
#[test]
fn auth_request_telemetry_context_tracks_attached_auth_and_retry_phase() {
    let recovery = UnauthorizedRecoveryExecution {
        mode: "managed",
        phase: "refresh_token",
    };
    let auth_provider = BearerAuthProvider::for_test(Some("access-token"), Some("workspace-123"));
    let auth_context = AuthRequestTelemetryContext::new(
        Some(AuthMode::Chatgpt),
        &auth_provider,
        PendingUnauthorizedRetry::from_recovery(recovery),
    );
    assert_eq!(auth_context.auth_mode, Some("Chatgpt"));
    assert!(auth_context.auth_header_attached);
    assert_eq!(auth_context.auth_header_name, Some("authorization"));
    assert!(auth_context.retry_after_unauthorized);
    assert_eq!(auth_context.recovery_mode, Some("managed"));
    assert_eq!(auth_context.recovery_phase, Some("refresh_token"));
}
/// Builds a `ModelClient` whose attestation provider returns a fresh
/// `v1.header-N` value on every invocation, plus the shared counter that
/// records how many times the provider was called.
fn model_client_with_counting_attestation() -> (ModelClient, Arc<AtomicUsize>) {
let attestation_calls = Arc::new(AtomicUsize::new(0));
let calls = attestation_calls.clone();
let model_client = ModelClient::new_with_attestation_provider(
/*auth_manager*/ None,
SessionId::new(),
ThreadId::new(),
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses),
SessionSource::Exec,
/*model_verbosity*/ None,
/*enable_request_compression*/ false,
/*include_timing_metrics*/ false,
/*beta_features_header*/ None,
Some(AttestationProvider::new(move || {
let calls = calls.clone();
Box::pin(async move {
// 1-based call number baked into the header value so tests can
// assert a fresh attestation was generated per request.
let call = calls.fetch_add(1, Ordering::Relaxed) + 1;
Some(format!("v1.header-{call}"))
})
})),
);
(model_client, attestation_calls)
}
/// All supported purposes trigger attestation against the ChatGPT Codex
/// backend provider.
#[test]
fn should_send_attestation_for_allowed_chatgpt_codex_purposes() {
    let provider = api_provider("https://chatgpt.com/backend-api/codex/");
    let purposes = [
        AttestationPurpose::Response,
        AttestationPurpose::Compaction,
        AttestationPurpose::RealtimeWebrtcCallSetup,
    ];
    for purpose in purposes {
        assert!(super::should_send_attestation(&provider, purpose));
    }
}
/// A non-ChatGPT-Codex provider must never trigger attestation.
#[test]
fn should_not_send_attestation_for_non_chatgpt_codex_provider() {
    let provider = api_provider("https://api.openai.com/v1");
    let sends = super::should_send_attestation(&provider, AttestationPurpose::Response);
    assert!(!sends);
}
/// Each response request against ChatGPT Codex gets a freshly generated
/// attestation header rather than a cached one.
#[tokio::test]
async fn responses_generate_fresh_attestation_headers_for_chatgpt_codex() {
    let provider = api_provider("https://chatgpt.com/backend-api/codex/");
    let (model_client, attestation_calls) = model_client_with_counting_attestation();
    // Pull the attestation header value out of a header map as an owned string.
    let attestation_value = |headers: &http::HeaderMap| -> Option<String> {
        headers
            .get(crate::attestation::X_OAI_ATTESTATION_HEADER)
            .and_then(|value| value.to_str().ok())
            .map(str::to_owned)
    };
    let mut first_headers = http::HeaderMap::new();
    model_client
        .extend_attestation_header_for(&mut first_headers, &provider, AttestationPurpose::Response)
        .await;
    let mut second_headers = http::HeaderMap::new();
    model_client
        .extend_attestation_header_for(&mut second_headers, &provider, AttestationPurpose::Response)
        .await;
    assert_eq!(
        attestation_value(&first_headers),
        Some("v1.header-1".to_string())
    );
    assert_eq!(
        attestation_value(&second_headers),
        Some("v1.header-2".to_string())
    );
    assert_eq!(attestation_calls.load(Ordering::Relaxed), 2);
}
/// The websocket handshake against ChatGPT Codex generates exactly one
/// attestation header.
#[tokio::test]
async fn websocket_handshake_includes_attestation_for_chatgpt_codex_responses() {
    let provider = api_provider("https://chatgpt.com/backend-api/codex/");
    let (model_client, attestation_calls) = model_client_with_counting_attestation();
    let headers = model_client
        .build_websocket_headers(
            &provider, /*turn_state*/ None, /*turn_metadata_header*/ None,
        )
        .await;
    let attestation = headers
        .get(crate::attestation::X_OAI_ATTESTATION_HEADER)
        .and_then(|value| value.to_str().ok());
    assert_eq!(attestation, Some("v1.header-1"));
    assert_eq!(attestation_calls.load(Ordering::Relaxed), 1);
}
/// Compaction requests against ChatGPT Codex also generate a fresh
/// attestation header per request.
#[tokio::test]
async fn compact_generate_fresh_attestation_headers_for_chatgpt_codex() {
    let provider = api_provider("https://chatgpt.com/backend-api/codex/");
    let (model_client, attestation_calls) = model_client_with_counting_attestation();
    // Pull the attestation header value out of a header map as an owned string.
    let attestation_value = |headers: &http::HeaderMap| -> Option<String> {
        headers
            .get(crate::attestation::X_OAI_ATTESTATION_HEADER)
            .and_then(|value| value.to_str().ok())
            .map(str::to_owned)
    };
    let mut first_headers = http::HeaderMap::new();
    model_client
        .extend_attestation_header_for(
            &mut first_headers,
            &provider,
            AttestationPurpose::Compaction,
        )
        .await;
    let mut second_headers = http::HeaderMap::new();
    model_client
        .extend_attestation_header_for(
            &mut second_headers,
            &provider,
            AttestationPurpose::Compaction,
        )
        .await;
    assert_eq!(
        attestation_value(&first_headers),
        Some("v1.header-1".to_string())
    );
    assert_eq!(
        attestation_value(&second_headers),
        Some("v1.header-2".to_string())
    );
    assert_eq!(attestation_calls.load(Ordering::Relaxed), 2);
}
/// Realtime WebRTC call setup against ChatGPT Codex generates a fresh
/// attestation header per request.
#[tokio::test]
async fn realtime_setup_generate_fresh_attestation_headers_for_chatgpt_codex() {
    let provider = api_provider("https://chatgpt.com/backend-api/codex/");
    let (model_client, attestation_calls) = model_client_with_counting_attestation();
    // Pull the attestation header value out of a header map as an owned string.
    let attestation_value = |headers: &http::HeaderMap| -> Option<String> {
        headers
            .get(crate::attestation::X_OAI_ATTESTATION_HEADER)
            .and_then(|value| value.to_str().ok())
            .map(str::to_owned)
    };
    let mut first_headers = http::HeaderMap::new();
    model_client
        .extend_attestation_header_for(
            &mut first_headers,
            &provider,
            AttestationPurpose::RealtimeWebrtcCallSetup,
        )
        .await;
    let mut second_headers = http::HeaderMap::new();
    model_client
        .extend_attestation_header_for(
            &mut second_headers,
            &provider,
            AttestationPurpose::RealtimeWebrtcCallSetup,
        )
        .await;
    assert_eq!(
        attestation_value(&first_headers),
        Some("v1.header-1".to_string())
    );
    assert_eq!(
        attestation_value(&second_headers),
        Some("v1.header-2".to_string())
    );
    assert_eq!(attestation_calls.load(Ordering::Relaxed), 2);
}
/// Against a non-ChatGPT-Codex provider, no purpose adds an attestation
/// header and the attestation provider is never invoked.
#[tokio::test]
async fn non_chatgpt_codex_endpoints_omit_attestation_generation() {
    let provider = api_provider("https://api.openai.com/v1");
    let (model_client, attestation_calls) = model_client_with_counting_attestation();
    let purposes = [
        AttestationPurpose::Response,
        AttestationPurpose::Compaction,
        AttestationPurpose::RealtimeWebrtcCallSetup,
    ];
    for purpose in purposes {
        let mut headers = http::HeaderMap::new();
        model_client
            .extend_attestation_header_for(&mut headers, &provider, purpose)
            .await;
        assert_eq!(
            headers.get(crate::attestation::X_OAI_ATTESTATION_HEADER),
            None,
        );
    }
    assert_eq!(attestation_calls.load(Ordering::Relaxed), 0);
}