use super::AttestationPurpose; use super::AuthRequestTelemetryContext; use super::ModelClient; use super::PendingUnauthorizedRetry; use super::UnauthorizedRecoveryExecution; use super::X_CODEX_INSTALLATION_ID_HEADER; use super::X_CODEX_PARENT_THREAD_ID_HEADER; use super::X_CODEX_TURN_METADATA_HEADER; use super::X_CODEX_WINDOW_ID_HEADER; use super::X_OPENAI_SUBAGENT_HEADER; use crate::AttestationProvider; use codex_api::ApiError; use codex_api::ResponseEvent; use codex_app_server_protocol::AuthMode; use codex_model_provider::BearerAuthProvider; use codex_model_provider_info::CHATGPT_CODEX_BASE_URL; use codex_model_provider_info::WireApi; use codex_model_provider_info::create_oss_provider_with_base_url; use codex_otel::SessionTelemetry; use codex_protocol::SessionId; use codex_protocol::ThreadId; use codex_protocol::models::ContentItem; use codex_protocol::models::ResponseItem; use codex_protocol::openai_models::ModelInfo; use codex_protocol::protocol::InternalSessionSource; use codex_protocol::protocol::SessionSource; use codex_protocol::protocol::SubAgentSource; use codex_rollout_trace::ExecutionStatus; use codex_rollout_trace::InferenceTraceAttempt; use codex_rollout_trace::InferenceTraceContext; use codex_rollout_trace::RawTraceEventPayload; use codex_rollout_trace::RolloutTrace; use codex_rollout_trace::TraceWriter; use codex_rollout_trace::replay_bundle; use futures::StreamExt; use pretty_assertions::assert_eq; use serde_json::json; use std::collections::BTreeMap; use std::collections::VecDeque; use std::pin::Pin; use std::sync::Arc; use std::sync::Mutex; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; use std::task::Context; use std::task::Poll; use std::time::Duration; use tempfile::TempDir; use tokio::sync::Notify; use tracing::Event; use tracing::Subscriber; use tracing::field::Visit; use tracing_subscriber::Layer; use tracing_subscriber::layer::Context as LayerContext; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::registry::LookupSpan; use tracing_subscriber::util::SubscriberInitExt; fn test_model_client(session_source: SessionSource) -> ModelClient { let provider = create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses); let thread_id = ThreadId::new(); ModelClient::new( /*auth_manager*/ None, thread_id.into(), thread_id, /*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(), provider, session_source, /*model_verbosity*/ None, /*enable_request_compression*/ false, /*include_timing_metrics*/ false, /*beta_features_header*/ None, ) } fn api_provider(base_url: &str) -> codex_api::Provider { codex_api::Provider { name: "test".to_string(), base_url: base_url.to_string(), uses_chatgpt_auth: base_url .trim_end_matches('/') .eq_ignore_ascii_case(CHATGPT_CODEX_BASE_URL), query_params: None, headers: http::HeaderMap::new(), retry: codex_api::RetryConfig { max_attempts: 1, base_delay: Duration::from_millis(1), retry_429: false, retry_5xx: true, retry_transport: true, }, stream_idle_timeout: Duration::from_secs(1), } } fn test_model_info() -> ModelInfo { serde_json::from_value(json!({ "slug": "gpt-test", "display_name": "gpt-test", "description": "desc", "default_reasoning_level": "medium", "supported_reasoning_levels": [ {"effort": "medium", "description": "medium"} ], "shell_type": "shell_command", "visibility": "list", "supported_in_api": true, "priority": 1, "upgrade": null, "base_instructions": "base instructions", "model_messages": null, "supports_reasoning_summaries": false, "support_verbosity": false, "default_verbosity": null, "apply_patch_tool_type": null, "truncation_policy": {"mode": "bytes", "limit": 10000}, "supports_parallel_tool_calls": false, "supports_image_detail_original": false, "context_window": 272000, "auto_compact_token_limit": null, "experimental_supported_tools": [] })) .expect("deserialize test model info") } fn test_session_telemetry() -> SessionTelemetry { SessionTelemetry::new( ThreadId::new(), "gpt-test", "gpt-test", /*account_id*/ None, /*account_email*/ None, /*auth_mode*/ None, "test-originator".to_string(), /*log_user_prompts*/ false, "test-terminal".to_string(), SessionSource::Cli, ) } #[derive(Default)] struct TagCollectorVisitor { tags: BTreeMap, } impl Visit for TagCollectorVisitor { fn record_str(&mut self, field: &tracing::field::Field, value: &str) { self.tags .insert(field.name().to_string(), value.to_string()); } fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) { self.tags .insert(field.name().to_string(), format!("{value:?}")); } } #[derive(Clone)] struct TagCollectorLayer { tags: Arc>>, } impl Layer for TagCollectorLayer where S: Subscriber + for<'a> LookupSpan<'a>, { fn on_event(&self, event: &Event<'_>, _ctx: LayerContext<'_, S>) { if event.metadata().target() != "feedback_tags" { return; } let mut visitor = TagCollectorVisitor::default(); event.record(&mut visitor); self.tags.lock().unwrap().extend(visitor.tags); } } fn started_inference_attempt(temp: &TempDir) -> anyhow::Result { let writer = Arc::new(TraceWriter::create( temp.path(), "trace-1".to_string(), "rollout-1".to_string(), "thread-root".to_string(), )?); writer.append(RawTraceEventPayload::ThreadStarted { thread_id: "thread-root".to_string(), agent_path: "/root".to_string(), metadata_payload: None, })?; writer.append(RawTraceEventPayload::CodexTurnStarted { codex_turn_id: "turn-1".to_string(), thread_id: "thread-root".to_string(), })?; let inference_trace = InferenceTraceContext::enabled( writer, "thread-root".to_string(), "turn-1".to_string(), "gpt-test".to_string(), "test-provider".to_string(), ); let attempt = inference_trace.start_attempt(); attempt.record_started(&json!({ "model": "gpt-test", "input": [{ "type": "message", "role": "user", "content": [{"type": "input_text", "text": "hello"}] }], })); Ok(attempt) } fn output_message(id: &str, text: &str) -> ResponseItem { ResponseItem::Message { id: Some(id.to_string()), role: "assistant".to_string(), content: vec![ContentItem::OutputText { text: text.to_string(), }], phase: None, } } async fn replay_until_cancelled(temp: &TempDir) -> anyhow::Result { let mut rollout = replay_bundle(temp.path())?; for _ in 0..50 { let inference = rollout .inference_calls .values() .next() .expect("inference should be reduced"); if inference.execution.status == ExecutionStatus::Cancelled { return Ok(rollout); } tokio::time::sleep(Duration::from_millis(10)).await; rollout = replay_bundle(temp.path())?; } Ok(rollout) } struct NotifyAfterEventStream { events: VecDeque, yielded: usize, notify_after: usize, notify: Arc, } impl futures::Stream for NotifyAfterEventStream { type Item = std::result::Result; fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { let Some(event) = self.events.pop_front() else { return Poll::Pending; }; self.yielded += 1; if self.yielded == self.notify_after { self.notify.notify_one(); } Poll::Ready(Some(Ok(event))) } } #[test] fn build_subagent_headers_sets_other_subagent_label() { let client = test_model_client(SessionSource::SubAgent(SubAgentSource::Other( "memory_consolidation".to_string(), ))); let headers = client.build_subagent_headers(); let value = headers .get(X_OPENAI_SUBAGENT_HEADER) .and_then(|value| value.to_str().ok()); assert_eq!(value, Some("memory_consolidation")); } #[test] fn build_subagent_headers_sets_internal_memory_consolidation_label() { let client = test_model_client(SessionSource::Internal( InternalSessionSource::MemoryConsolidation, )); let headers = client.build_subagent_headers(); let value = headers .get(X_OPENAI_SUBAGENT_HEADER) .and_then(|value| value.to_str().ok()); assert_eq!(value, Some("memory_consolidation")); } #[test] fn build_ws_client_metadata_includes_window_lineage_and_turn_metadata() { let parent_thread_id = ThreadId::new(); let client = test_model_client(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { parent_thread_id, depth: 2, agent_path: None, agent_nickname: None, agent_role: None, })); client.advance_window_generation(); let client_metadata = client.build_ws_client_metadata(Some(r#"{"turn_id":"turn-123"}"#)); let thread_id = client.state.thread_id; assert_eq!( client_metadata, std::collections::HashMap::from([ ( X_CODEX_INSTALLATION_ID_HEADER.to_string(), "11111111-1111-4111-8111-111111111111".to_string(), ), ( X_CODEX_WINDOW_ID_HEADER.to_string(), format!("{thread_id}:1"), ), ( X_OPENAI_SUBAGENT_HEADER.to_string(), "collab_spawn".to_string(), ), ( X_CODEX_PARENT_THREAD_ID_HEADER.to_string(), parent_thread_id.to_string(), ), ( X_CODEX_TURN_METADATA_HEADER.to_string(), r#"{"turn_id":"turn-123"}"#.to_string(), ), ]) ); } #[tokio::test] async fn summarize_memories_returns_empty_for_empty_input() { let client = test_model_client(SessionSource::Cli); let model_info = test_model_info(); let session_telemetry = test_session_telemetry(); let output = client .summarize_memories( Vec::new(), &model_info, /*effort*/ None, &session_telemetry, ) .await .expect("empty summarize request should succeed"); assert_eq!(output.len(), 0); } #[tokio::test] async fn dropped_response_stream_traces_cancelled_partial_output() -> anyhow::Result<()> { let temp = TempDir::new()?; let attempt = started_inference_attempt(&temp)?; // The provider has produced one complete output item, but no terminal // response.completed event. The harness has enough information to keep this // item in history, so the trace should preserve it when the stream is // abandoned. let item = output_message("msg-1", "partial answer"); let api_stream = futures::stream::iter([Ok(ResponseEvent::OutputItemDone(item))]) .chain(futures::stream::pending()); let (mut stream, _) = super::map_response_events( /*upstream_request_id*/ None, api_stream, test_session_telemetry(), attempt, ); let observed = stream .next() .await .expect("mapped stream should yield output item")?; assert!(matches!(observed, ResponseEvent::OutputItemDone(_))); // Dropping the consumer is how turn interruption/preemption stops polling // the provider stream. The mapper task observes that drop asynchronously // and records cancellation using the output items it has already seen. drop(stream); // Cancellation is recorded by the mapper task after Drop wakes it, so the // replay may need a short wait before the terminal event appears on disk. let rollout = replay_until_cancelled(&temp).await?; let inference = rollout .inference_calls .values() .next() .expect("inference should be reduced"); assert_eq!(inference.execution.status, ExecutionStatus::Cancelled); assert_eq!(inference.response_item_ids.len(), 1); assert_eq!(rollout.raw_payloads.len(), 2); Ok(()) } #[tokio::test] async fn response_stream_records_last_model_feedback_ids() { let tags = Arc::new(Mutex::new(BTreeMap::new())); let _guard = tracing_subscriber::registry() .with(TagCollectorLayer { tags: tags.clone() }) .set_default(); let api_stream = futures::stream::iter([ Ok(ResponseEvent::Created), Ok(ResponseEvent::Completed { response_id: "resp-123".to_string(), token_usage: None, end_turn: Some(true), }), ]); let (mut stream, _) = super::map_response_events( Some("req-123".to_string()), api_stream, test_session_telemetry(), InferenceTraceAttempt::disabled(), ); while stream.next().await.is_some() {} let tags = tags.lock().unwrap().clone(); assert_eq!( tags.get("last_model_request_id").map(String::as_str), Some("\"req-123\"") ); assert_eq!( tags.get("last_model_response_id").map(String::as_str), Some("\"resp-123\"") ); } #[tokio::test] async fn dropped_backpressured_response_stream_traces_cancelled_partial_output() -> anyhow::Result<()> { let temp = TempDir::new()?; let attempt = started_inference_attempt(&temp)?; let backpressured_item_yielded = Arc::new(Notify::new()); let mut events = VecDeque::new(); for _ in 0..super::RESPONSE_STREAM_CHANNEL_CAPACITY { events.push_back(ResponseEvent::Created); } events.push_back(ResponseEvent::OutputItemDone(output_message( "msg-1", "partial answer", ))); let api_stream = NotifyAfterEventStream { events, yielded: 0, notify_after: super::RESPONSE_STREAM_CHANNEL_CAPACITY + 1, notify: Arc::clone(&backpressured_item_yielded), }; let (stream, _) = super::map_response_events( /*upstream_request_id*/ None, api_stream, test_session_telemetry(), attempt, ); // Fill the mapper channel with non-terminal events, then yield one output // item. The mapper has observed that item and is blocked trying to send it // downstream, so dropping the consumer covers the send-failure path rather // than the `consumer_dropped` select branch. backpressured_item_yielded.notified().await; drop(stream); let rollout = replay_until_cancelled(&temp).await?; let inference = rollout .inference_calls .values() .next() .expect("inference should be reduced"); assert_eq!(inference.execution.status, ExecutionStatus::Cancelled); assert_eq!(inference.response_item_ids.len(), 1); assert_eq!(rollout.raw_payloads.len(), 2); Ok(()) } #[test] fn auth_request_telemetry_context_tracks_attached_auth_and_retry_phase() { let auth_context = AuthRequestTelemetryContext::new( Some(AuthMode::Chatgpt), &BearerAuthProvider::for_test(Some("access-token"), Some("workspace-123")), PendingUnauthorizedRetry::from_recovery(UnauthorizedRecoveryExecution { mode: "managed", phase: "refresh_token", }), ); assert_eq!(auth_context.auth_mode, Some("Chatgpt")); assert!(auth_context.auth_header_attached); assert_eq!(auth_context.auth_header_name, Some("authorization")); assert!(auth_context.retry_after_unauthorized); assert_eq!(auth_context.recovery_mode, Some("managed")); assert_eq!(auth_context.recovery_phase, Some("refresh_token")); } fn model_client_with_counting_attestation() -> (ModelClient, Arc) { let attestation_calls = Arc::new(AtomicUsize::new(0)); let calls = attestation_calls.clone(); let model_client = ModelClient::new_with_attestation_provider( /*auth_manager*/ None, SessionId::new(), ThreadId::new(), /*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(), create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses), SessionSource::Exec, /*model_verbosity*/ None, /*enable_request_compression*/ false, /*include_timing_metrics*/ false, /*beta_features_header*/ None, Some(AttestationProvider::new(move || { let calls = calls.clone(); Box::pin(async move { let call = calls.fetch_add(1, Ordering::Relaxed) + 1; Some(format!("v1.header-{call}")) }) })), ); (model_client, attestation_calls) } #[test] fn should_send_attestation_for_allowed_chatgpt_codex_purposes() { let provider = api_provider("https://chatgpt.com/backend-api/codex/"); for purpose in [ AttestationPurpose::Response, AttestationPurpose::Compaction, AttestationPurpose::RealtimeWebrtcCallSetup, ] { assert!(super::should_send_attestation(&provider, purpose)); } } #[test] fn should_not_send_attestation_for_non_chatgpt_codex_provider() { let provider = api_provider("https://api.openai.com/v1"); assert!(!super::should_send_attestation( &provider, AttestationPurpose::Response, )); } #[tokio::test] async fn responses_generate_fresh_attestation_headers_for_chatgpt_codex() { let provider = api_provider("https://chatgpt.com/backend-api/codex/"); let (model_client, attestation_calls) = model_client_with_counting_attestation(); let mut first_headers = http::HeaderMap::new(); let mut second_headers = http::HeaderMap::new(); model_client .extend_attestation_header_for(&mut first_headers, &provider, AttestationPurpose::Response) .await; model_client .extend_attestation_header_for(&mut second_headers, &provider, AttestationPurpose::Response) .await; assert_eq!( first_headers .get(crate::attestation::X_OAI_ATTESTATION_HEADER) .and_then(|value| value.to_str().ok()), Some("v1.header-1"), ); assert_eq!( second_headers .get(crate::attestation::X_OAI_ATTESTATION_HEADER) .and_then(|value| value.to_str().ok()), Some("v1.header-2"), ); assert_eq!(attestation_calls.load(Ordering::Relaxed), 2); } #[tokio::test] async fn websocket_handshake_includes_attestation_for_chatgpt_codex_responses() { let provider = api_provider("https://chatgpt.com/backend-api/codex/"); let (model_client, attestation_calls) = model_client_with_counting_attestation(); let headers = model_client .build_websocket_headers( &provider, /*turn_state*/ None, /*turn_metadata_header*/ None, ) .await; assert_eq!( headers .get(crate::attestation::X_OAI_ATTESTATION_HEADER) .and_then(|value| value.to_str().ok()), Some("v1.header-1"), ); assert_eq!(attestation_calls.load(Ordering::Relaxed), 1); } #[tokio::test] async fn compact_generate_fresh_attestation_headers_for_chatgpt_codex() { let provider = api_provider("https://chatgpt.com/backend-api/codex/"); let (model_client, attestation_calls) = model_client_with_counting_attestation(); let mut first_headers = http::HeaderMap::new(); let mut second_headers = http::HeaderMap::new(); model_client .extend_attestation_header_for( &mut first_headers, &provider, AttestationPurpose::Compaction, ) .await; model_client .extend_attestation_header_for( &mut second_headers, &provider, AttestationPurpose::Compaction, ) .await; assert_eq!( first_headers .get(crate::attestation::X_OAI_ATTESTATION_HEADER) .and_then(|value| value.to_str().ok()), Some("v1.header-1"), ); assert_eq!( second_headers .get(crate::attestation::X_OAI_ATTESTATION_HEADER) .and_then(|value| value.to_str().ok()), Some("v1.header-2"), ); assert_eq!(attestation_calls.load(Ordering::Relaxed), 2); } #[tokio::test] async fn realtime_setup_generate_fresh_attestation_headers_for_chatgpt_codex() { let provider = api_provider("https://chatgpt.com/backend-api/codex/"); let (model_client, attestation_calls) = model_client_with_counting_attestation(); let mut first_headers = http::HeaderMap::new(); let mut second_headers = http::HeaderMap::new(); model_client .extend_attestation_header_for( &mut first_headers, &provider, AttestationPurpose::RealtimeWebrtcCallSetup, ) .await; model_client .extend_attestation_header_for( &mut second_headers, &provider, AttestationPurpose::RealtimeWebrtcCallSetup, ) .await; assert_eq!( first_headers .get(crate::attestation::X_OAI_ATTESTATION_HEADER) .and_then(|value| value.to_str().ok()), Some("v1.header-1"), ); assert_eq!( second_headers .get(crate::attestation::X_OAI_ATTESTATION_HEADER) .and_then(|value| value.to_str().ok()), Some("v1.header-2"), ); assert_eq!(attestation_calls.load(Ordering::Relaxed), 2); } #[tokio::test] async fn non_chatgpt_codex_endpoints_omit_attestation_generation() { let provider = api_provider("https://api.openai.com/v1"); let (model_client, attestation_calls) = model_client_with_counting_attestation(); let mut response_headers = http::HeaderMap::new(); model_client .extend_attestation_header_for( &mut response_headers, &provider, AttestationPurpose::Response, ) .await; let mut compaction_headers = http::HeaderMap::new(); model_client .extend_attestation_header_for( &mut compaction_headers, &provider, AttestationPurpose::Compaction, ) .await; let mut realtime_headers = http::HeaderMap::new(); model_client .extend_attestation_header_for( &mut realtime_headers, &provider, AttestationPurpose::RealtimeWebrtcCallSetup, ) .await; assert_eq!( response_headers.get(crate::attestation::X_OAI_ATTESTATION_HEADER), None, ); assert_eq!( compaction_headers.get(crate::attestation::X_OAI_ATTESTATION_HEADER), None, ); assert_eq!( realtime_headers.get(crate::attestation::X_OAI_ATTESTATION_HEADER), None, ); assert_eq!(attestation_calls.load(Ordering::Relaxed), 0); }