use super::AttestationPurpose;
use super::AuthRequestTelemetryContext;
use super::ModelClient;
use super::PendingUnauthorizedRetry;
use super::UnauthorizedRecoveryExecution;
use super::X_CODEX_INSTALLATION_ID_HEADER;
use super::X_CODEX_PARENT_THREAD_ID_HEADER;
use super::X_CODEX_TURN_METADATA_HEADER;
use super::X_CODEX_WINDOW_ID_HEADER;
use super::X_OPENAI_SUBAGENT_HEADER;
use crate::AttestationProvider;
use codex_api::ApiError;
use codex_api::ResponseEvent;
use codex_app_server_protocol::AuthMode;
use codex_model_provider::BearerAuthProvider;
use codex_model_provider_info::CHATGPT_CODEX_BASE_URL;
use codex_model_provider_info::WireApi;
use codex_model_provider_info::create_oss_provider_with_base_url;
use codex_otel::SessionTelemetry;
use codex_protocol::SessionId;
use codex_protocol::ThreadId;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::openai_models::ModelInfo;
use codex_protocol::protocol::InternalSessionSource;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use codex_rollout_trace::ExecutionStatus;
use codex_rollout_trace::InferenceTraceAttempt;
use codex_rollout_trace::InferenceTraceContext;
use codex_rollout_trace::RawTraceEventPayload;
use codex_rollout_trace::RolloutTrace;
use codex_rollout_trace::TraceWriter;
use codex_rollout_trace::replay_bundle;
use futures::StreamExt;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::collections::BTreeMap;
use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;
use tempfile::TempDir;
use tokio::sync::Notify;
use tracing::Event;
use tracing::Subscriber;
use tracing::field::Visit;
use tracing_subscriber::Layer;
use tracing_subscriber::layer::Context as LayerContext;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::registry::LookupSpan;
use tracing_subscriber::util::SubscriberInitExt;

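/// Builds a minimal `ModelClient` for the header and metadata tests below:
/// no auth manager, a fixed installation id, and an OSS `Responses` provider
/// pointed at a placeholder base URL.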
fn test_model_client(session_source: SessionSource) -> ModelClient {
    let provider = create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses);
    let thread_id = ThreadId::new();
    ModelClient::new(
        /*auth_manager*/ None,
        thread_id.into(),
        thread_id,
        /*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
        provider,
        session_source,
        /*model_verbosity*/ None,
        /*enable_request_compression*/ false,
        /*include_timing_metrics*/ false,
        /*beta_features_header*/ None,
    )
}

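/// Constructs a `codex_api::Provider` for a given base URL. `uses_chatgpt_auth`
/// is derived by comparing the URL (sans trailing slash) against
/// `CHATGPT_CODEX_BASE_URL`, which is what the attestation tests below rely on.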
fn api_provider(base_url: &str) -> codex_api::Provider {
    codex_api::Provider {
        name: "test".to_string(),
        base_url: base_url.to_string(),
        uses_chatgpt_auth: base_url
            .trim_end_matches('/')
            .eq_ignore_ascii_case(CHATGPT_CODEX_BASE_URL),
        query_params: None,
        headers: http::HeaderMap::new(),
        retry: codex_api::RetryConfig {
            max_attempts: 1,
            base_delay: Duration::from_millis(1),
            retry_429: false,
            retry_5xx: true,
            retry_transport: true,
        },
        stream_idle_timeout: Duration::from_secs(1),
    }
}

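/// Deserializes a representative `ModelInfo` fixture from inline JSON so the
/// tests do not depend on any real model catalog entry.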
fn test_model_info() -> ModelInfo {
    serde_json::from_value(json!({
        "slug": "gpt-test",
        "display_name": "gpt-test",
        "description": "desc",
        "default_reasoning_level": "medium",
        "supported_reasoning_levels": [
            {"effort": "medium", "description": "medium"}
        ],
        "shell_type": "shell_command",
        "visibility": "list",
        "supported_in_api": true,
        "priority": 1,
        "upgrade": null,
        "base_instructions": "base instructions",
        "model_messages": null,
        "supports_reasoning_summaries": false,
        "support_verbosity": false,
        "default_verbosity": null,
        "apply_patch_tool_type": null,
        "truncation_policy": {"mode": "bytes", "limit": 10000},
        "supports_parallel_tool_calls": false,
        "supports_image_detail_original": false,
        "context_window": 272000,
        "auto_compact_token_limit": null,
        "experimental_supported_tools": []
    }))
    .expect("deserialize test model info")
}

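/// Builds a `SessionTelemetry` fixture with anonymous account fields and
/// prompt logging disabled.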
fn test_session_telemetry() -> SessionTelemetry {
    SessionTelemetry::new(
        ThreadId::new(),
        "gpt-test",
        "gpt-test",
        /*account_id*/ None,
        /*account_email*/ None,
        /*auth_mode*/ None,
        "test-originator".to_string(),
        /*log_user_prompts*/ false,
        "test-terminal".to_string(),
        SessionSource::Cli,
    )
}

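/// Visitor that flattens a tracing event's fields into string key/value pairs.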
#[derive(Default)]
struct TagCollectorVisitor {
    tags: BTreeMap<String, String>,
}

impl Visit for TagCollectorVisitor {
    fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
        self.tags
            .insert(field.name().to_string(), value.to_string());
    }

    fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
        self.tags
            .insert(field.name().to_string(), format!("{value:?}"));
    }
}

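/// Layer that captures the fields of events emitted with the `feedback_tags`
/// target into a shared map, so tests can assert on recorded telemetry tags.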
#[derive(Clone)]
struct TagCollectorLayer {
    tags: Arc<Mutex<BTreeMap<String, String>>>,
}

impl<S> Layer<S> for TagCollectorLayer
where
    S: Subscriber + for<'a> LookupSpan<'a>,
{
    fn on_event(&self, event: &Event<'_>, _ctx: LayerContext<'_, S>) {
        if event.metadata().target() != "feedback_tags" {
            return;
        }
        let mut visitor = TagCollectorVisitor::default();
        event.record(&mut visitor);
        self.tags.lock().unwrap().extend(visitor.tags);
    }
}

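/// Creates a trace writer in `temp`, appends the thread/turn started events,
/// and returns an inference attempt that has already recorded its request
/// payload, so tests only need to exercise the streaming/cancellation path.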
fn started_inference_attempt(temp: &TempDir) -> anyhow::Result<InferenceTraceAttempt> {
    let writer = Arc::new(TraceWriter::create(
        temp.path(),
        "trace-1".to_string(),
        "rollout-1".to_string(),
        "thread-root".to_string(),
    )?);
    writer.append(RawTraceEventPayload::ThreadStarted {
        thread_id: "thread-root".to_string(),
        agent_path: "/root".to_string(),
        metadata_payload: None,
    })?;
    writer.append(RawTraceEventPayload::CodexTurnStarted {
        codex_turn_id: "turn-1".to_string(),
        thread_id: "thread-root".to_string(),
    })?;

    let inference_trace = InferenceTraceContext::enabled(
        writer,
        "thread-root".to_string(),
        "turn-1".to_string(),
        "gpt-test".to_string(),
        "test-provider".to_string(),
    );
    let attempt = inference_trace.start_attempt();
    attempt.record_started(&json!({
        "model": "gpt-test",
        "input": [{
            "type": "message",
            "role": "user",
            "content": [{"type": "input_text", "text": "hello"}]
        }],
    }));
    Ok(attempt)
}

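/// Builds an assistant `ResponseItem::Message` with a single output-text item.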
fn output_message(id: &str, text: &str) -> ResponseItem {
    ResponseItem::Message {
        id: Some(id.to_string()),
        role: "assistant".to_string(),
        content: vec![ContentItem::OutputText {
            text: text.to_string(),
        }],
        phase: None,
    }
}

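/// Replays the trace bundle until the (single) inference call reports
/// `Cancelled`, retrying every 10ms for up to 50 attempts; returns the last
/// replayed rollout either way so callers can assert on the final state.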
async fn replay_until_cancelled(temp: &TempDir) -> anyhow::Result<RolloutTrace> {
    let mut rollout = replay_bundle(temp.path())?;
    for _ in 0..50 {
        let inference = rollout
            .inference_calls
            .values()
            .next()
            .expect("inference should be reduced");
        if inference.execution.status == ExecutionStatus::Cancelled {
            return Ok(rollout);
        }
        tokio::time::sleep(Duration::from_millis(10)).await;
        rollout = replay_bundle(temp.path())?;
    }
    Ok(rollout)
}

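/// Test stream that yields its queued events one per poll and fires `notify`
/// after the `notify_after`-th event. Once the queue is empty it pends forever
/// (it never registers a waker, which is acceptable here because the consumer
/// side is dropped rather than awaited to completion).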
struct NotifyAfterEventStream {
    events: VecDeque<ResponseEvent>,
    yielded: usize,
    notify_after: usize,
    notify: Arc<Notify>,
}

impl futures::Stream for NotifyAfterEventStream {
    type Item = std::result::Result<ResponseEvent, ApiError>;

    fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let Some(event) = self.events.pop_front() else {
            return Poll::Pending;
        };
        self.yielded += 1;
        if self.yielded == self.notify_after {
            self.notify.notify_one();
        }
        Poll::Ready(Some(Ok(event)))
    }
}

#[test]
fn build_subagent_headers_sets_other_subagent_label() {
    let client = test_model_client(SessionSource::SubAgent(SubAgentSource::Other(
        "memory_consolidation".to_string(),
    )));
    let headers = client.build_subagent_headers();
    let value = headers
        .get(X_OPENAI_SUBAGENT_HEADER)
        .and_then(|value| value.to_str().ok());
    assert_eq!(value, Some("memory_consolidation"));
}

#[test]
fn build_subagent_headers_sets_internal_memory_consolidation_label() {
    let client = test_model_client(SessionSource::Internal(
        InternalSessionSource::MemoryConsolidation,
    ));
    let headers = client.build_subagent_headers();
    let value = headers
        .get(X_OPENAI_SUBAGENT_HEADER)
        .and_then(|value| value.to_str().ok());
    assert_eq!(value, Some("memory_consolidation"));
}

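// Window lineage: `advance_window_generation` bumps the generation that is
// encoded into the `{thread_id}:{generation}` window id asserted below.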
#[test]
fn build_ws_client_metadata_includes_window_lineage_and_turn_metadata() {
    let parent_thread_id = ThreadId::new();
    let client = test_model_client(SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
        parent_thread_id,
        depth: 2,
        agent_path: None,
        agent_nickname: None,
        agent_role: None,
    }));

    client.advance_window_generation();

    let client_metadata = client.build_ws_client_metadata(Some(r#"{"turn_id":"turn-123"}"#));
    let thread_id = client.state.thread_id;
    assert_eq!(
        client_metadata,
        std::collections::HashMap::from([
            (
                X_CODEX_INSTALLATION_ID_HEADER.to_string(),
                "11111111-1111-4111-8111-111111111111".to_string(),
            ),
            (
                X_CODEX_WINDOW_ID_HEADER.to_string(),
                format!("{thread_id}:1"),
            ),
            (
                X_OPENAI_SUBAGENT_HEADER.to_string(),
                "collab_spawn".to_string(),
            ),
            (
                X_CODEX_PARENT_THREAD_ID_HEADER.to_string(),
                parent_thread_id.to_string(),
            ),
            (
                X_CODEX_TURN_METADATA_HEADER.to_string(),
                r#"{"turn_id":"turn-123"}"#.to_string(),
            ),
        ])
    );
}

#[tokio::test]
async fn summarize_memories_returns_empty_for_empty_input() {
    let client = test_model_client(SessionSource::Cli);
    let model_info = test_model_info();
    let session_telemetry = test_session_telemetry();

    let output = client
        .summarize_memories(
            Vec::new(),
            &model_info,
            /*effort*/ None,
            &session_telemetry,
        )
        .await
        .expect("empty summarize request should succeed");
    assert_eq!(output.len(), 0);
}

#[tokio::test]
async fn dropped_response_stream_traces_cancelled_partial_output() -> anyhow::Result<()> {
    let temp = TempDir::new()?;
    let attempt = started_inference_attempt(&temp)?;

    // The provider has produced one complete output item, but no terminal
    // response.completed event. The harness has enough information to keep this
    // item in history, so the trace should preserve it when the stream is
    // abandoned.
    let item = output_message("msg-1", "partial answer");
    let api_stream = futures::stream::iter([Ok(ResponseEvent::OutputItemDone(item))])
        .chain(futures::stream::pending());
    let (mut stream, _) = super::map_response_events(
        /*upstream_request_id*/ None,
        api_stream,
        test_session_telemetry(),
        attempt,
    );

    let observed = stream
        .next()
        .await
        .expect("mapped stream should yield output item")?;
    assert!(matches!(observed, ResponseEvent::OutputItemDone(_)));

    // Dropping the consumer is how turn interruption/preemption stops polling
    // the provider stream. The mapper task observes that drop asynchronously
    // and records cancellation using the output items it has already seen.
    drop(stream);

    // Cancellation is recorded by the mapper task after Drop wakes it, so the
    // replay may need a short wait before the terminal event appears on disk.
    let rollout = replay_until_cancelled(&temp).await?;
    let inference = rollout
        .inference_calls
        .values()
        .next()
        .expect("inference should be reduced");

    assert_eq!(inference.execution.status, ExecutionStatus::Cancelled);
    assert_eq!(inference.response_item_ids.len(), 1);
    assert_eq!(rollout.raw_payloads.len(), 2);

    Ok(())
}

#[tokio::test]
async fn response_stream_records_last_model_feedback_ids() {
    let tags = Arc::new(Mutex::new(BTreeMap::new()));
    let _guard = tracing_subscriber::registry()
        .with(TagCollectorLayer { tags: tags.clone() })
        .set_default();

    let api_stream = futures::stream::iter([
        Ok(ResponseEvent::Created),
        Ok(ResponseEvent::Completed {
            response_id: "resp-123".to_string(),
            token_usage: None,
            end_turn: Some(true),
        }),
    ]);
    let (mut stream, _) = super::map_response_events(
        Some("req-123".to_string()),
        api_stream,
        test_session_telemetry(),
        InferenceTraceAttempt::disabled(),
    );

    while stream.next().await.is_some() {}

    let tags = tags.lock().unwrap().clone();
    assert_eq!(
        tags.get("last_model_request_id").map(String::as_str),
        Some("\"req-123\"")
    );
    assert_eq!(
        tags.get("last_model_response_id").map(String::as_str),
        Some("\"resp-123\"")
    );
}

#[tokio::test]
async fn dropped_backpressured_response_stream_traces_cancelled_partial_output()
-> anyhow::Result<()> {
    let temp = TempDir::new()?;
    let attempt = started_inference_attempt(&temp)?;
    let backpressured_item_yielded = Arc::new(Notify::new());
    let mut events = VecDeque::new();
    for _ in 0..super::RESPONSE_STREAM_CHANNEL_CAPACITY {
        events.push_back(ResponseEvent::Created);
    }
    events.push_back(ResponseEvent::OutputItemDone(output_message(
        "msg-1",
        "partial answer",
    )));
    let api_stream = NotifyAfterEventStream {
        events,
        yielded: 0,
        notify_after: super::RESPONSE_STREAM_CHANNEL_CAPACITY + 1,
        notify: Arc::clone(&backpressured_item_yielded),
    };

    let (stream, _) = super::map_response_events(
        /*upstream_request_id*/ None,
        api_stream,
        test_session_telemetry(),
        attempt,
    );

    // Fill the mapper channel with non-terminal events, then yield one output
    // item. The mapper has observed that item and is blocked trying to send it
    // downstream, so dropping the consumer covers the send-failure path rather
    // than the `consumer_dropped` select branch.
    backpressured_item_yielded.notified().await;
    drop(stream);

    let rollout = replay_until_cancelled(&temp).await?;
    let inference = rollout
        .inference_calls
        .values()
        .next()
        .expect("inference should be reduced");

    assert_eq!(inference.execution.status, ExecutionStatus::Cancelled);
    assert_eq!(inference.response_item_ids.len(), 1);
    assert_eq!(rollout.raw_payloads.len(), 2);

    Ok(())
}

#[test]
fn auth_request_telemetry_context_tracks_attached_auth_and_retry_phase() {
    let auth_context = AuthRequestTelemetryContext::new(
        Some(AuthMode::Chatgpt),
        &BearerAuthProvider::for_test(Some("access-token"), Some("workspace-123")),
        PendingUnauthorizedRetry::from_recovery(UnauthorizedRecoveryExecution {
            mode: "managed",
            phase: "refresh_token",
        }),
    );

    assert_eq!(auth_context.auth_mode, Some("Chatgpt"));
    assert!(auth_context.auth_header_attached);
    assert_eq!(auth_context.auth_header_name, Some("authorization"));
    assert!(auth_context.retry_after_unauthorized);
    assert_eq!(auth_context.recovery_mode, Some("managed"));
    assert_eq!(auth_context.recovery_phase, Some("refresh_token"));
}

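/// Builds a `ModelClient` whose attestation provider counts invocations and
/// returns a distinct `v1.header-{n}` value per call, so tests can assert both
/// how many attestations were generated and that each request got a fresh
/// header.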
fn model_client_with_counting_attestation() -> (ModelClient, Arc<AtomicUsize>) {
    let attestation_calls = Arc::new(AtomicUsize::new(0));
    let calls = attestation_calls.clone();
    let model_client = ModelClient::new_with_attestation_provider(
        /*auth_manager*/ None,
        SessionId::new(),
        ThreadId::new(),
        /*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
        create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses),
        SessionSource::Exec,
        /*model_verbosity*/ None,
        /*enable_request_compression*/ false,
        /*include_timing_metrics*/ false,
        /*beta_features_header*/ None,
        Some(AttestationProvider::new(move || {
            let calls = calls.clone();
            Box::pin(async move {
                let call = calls.fetch_add(1, Ordering::Relaxed) + 1;
                Some(format!("v1.header-{call}"))
            })
        })),
    );
    (model_client, attestation_calls)
}

#[test]
fn should_send_attestation_for_allowed_chatgpt_codex_purposes() {
    let provider = api_provider("https://chatgpt.com/backend-api/codex/");

    for purpose in [
        AttestationPurpose::Response,
        AttestationPurpose::Compaction,
        AttestationPurpose::RealtimeWebrtcCallSetup,
    ] {
        assert!(super::should_send_attestation(&provider, purpose));
    }
}

#[test]
fn should_not_send_attestation_for_non_chatgpt_codex_provider() {
    let provider = api_provider("https://api.openai.com/v1");

    assert!(!super::should_send_attestation(
        &provider,
        AttestationPurpose::Response,
    ));
}

#[tokio::test]
async fn responses_generate_fresh_attestation_headers_for_chatgpt_codex() {
    let provider = api_provider("https://chatgpt.com/backend-api/codex/");
    let (model_client, attestation_calls) = model_client_with_counting_attestation();
    let mut first_headers = http::HeaderMap::new();
    let mut second_headers = http::HeaderMap::new();

    model_client
        .extend_attestation_header_for(&mut first_headers, &provider, AttestationPurpose::Response)
        .await;
    model_client
        .extend_attestation_header_for(&mut second_headers, &provider, AttestationPurpose::Response)
        .await;

    assert_eq!(
        first_headers
            .get(crate::attestation::X_OAI_ATTESTATION_HEADER)
            .and_then(|value| value.to_str().ok()),
        Some("v1.header-1"),
    );
    assert_eq!(
        second_headers
            .get(crate::attestation::X_OAI_ATTESTATION_HEADER)
            .and_then(|value| value.to_str().ok()),
        Some("v1.header-2"),
    );
    assert_eq!(attestation_calls.load(Ordering::Relaxed), 2);
}

#[tokio::test]
async fn websocket_handshake_includes_attestation_for_chatgpt_codex_responses() {
    let provider = api_provider("https://chatgpt.com/backend-api/codex/");
    let (model_client, attestation_calls) = model_client_with_counting_attestation();

    let headers = model_client
        .build_websocket_headers(
            &provider, /*turn_state*/ None, /*turn_metadata_header*/ None,
        )
        .await;

    assert_eq!(
        headers
            .get(crate::attestation::X_OAI_ATTESTATION_HEADER)
            .and_then(|value| value.to_str().ok()),
        Some("v1.header-1"),
    );
    assert_eq!(attestation_calls.load(Ordering::Relaxed), 1);
}

#[tokio::test]
async fn compact_generate_fresh_attestation_headers_for_chatgpt_codex() {
    let provider = api_provider("https://chatgpt.com/backend-api/codex/");
    let (model_client, attestation_calls) = model_client_with_counting_attestation();
    let mut first_headers = http::HeaderMap::new();
    let mut second_headers = http::HeaderMap::new();

    model_client
        .extend_attestation_header_for(
            &mut first_headers,
            &provider,
            AttestationPurpose::Compaction,
        )
        .await;
    model_client
        .extend_attestation_header_for(
            &mut second_headers,
            &provider,
            AttestationPurpose::Compaction,
        )
        .await;

    assert_eq!(
        first_headers
            .get(crate::attestation::X_OAI_ATTESTATION_HEADER)
            .and_then(|value| value.to_str().ok()),
        Some("v1.header-1"),
    );
    assert_eq!(
        second_headers
            .get(crate::attestation::X_OAI_ATTESTATION_HEADER)
            .and_then(|value| value.to_str().ok()),
        Some("v1.header-2"),
    );
    assert_eq!(attestation_calls.load(Ordering::Relaxed), 2);
}

#[tokio::test]
async fn realtime_setup_generate_fresh_attestation_headers_for_chatgpt_codex() {
    let provider = api_provider("https://chatgpt.com/backend-api/codex/");
    let (model_client, attestation_calls) = model_client_with_counting_attestation();
    let mut first_headers = http::HeaderMap::new();
    let mut second_headers = http::HeaderMap::new();

    model_client
        .extend_attestation_header_for(
            &mut first_headers,
            &provider,
            AttestationPurpose::RealtimeWebrtcCallSetup,
        )
        .await;
    model_client
        .extend_attestation_header_for(
            &mut second_headers,
            &provider,
            AttestationPurpose::RealtimeWebrtcCallSetup,
        )
        .await;

    assert_eq!(
        first_headers
            .get(crate::attestation::X_OAI_ATTESTATION_HEADER)
            .and_then(|value| value.to_str().ok()),
        Some("v1.header-1"),
    );
    assert_eq!(
        second_headers
            .get(crate::attestation::X_OAI_ATTESTATION_HEADER)
            .and_then(|value| value.to_str().ok()),
        Some("v1.header-2"),
    );
    assert_eq!(attestation_calls.load(Ordering::Relaxed), 2);
}

#[tokio::test]
async fn non_chatgpt_codex_endpoints_omit_attestation_generation() {
    let provider = api_provider("https://api.openai.com/v1");
    let (model_client, attestation_calls) = model_client_with_counting_attestation();
    let mut response_headers = http::HeaderMap::new();

    model_client
        .extend_attestation_header_for(
            &mut response_headers,
            &provider,
            AttestationPurpose::Response,
        )
        .await;
    let mut compaction_headers = http::HeaderMap::new();
    model_client
        .extend_attestation_header_for(
            &mut compaction_headers,
            &provider,
            AttestationPurpose::Compaction,
        )
        .await;
    let mut realtime_headers = http::HeaderMap::new();
    model_client
        .extend_attestation_header_for(
            &mut realtime_headers,
            &provider,
            AttestationPurpose::RealtimeWebrtcCallSetup,
        )
        .await;

    assert_eq!(
        response_headers.get(crate::attestation::X_OAI_ATTESTATION_HEADER),
        None,
    );
    assert_eq!(
        compaction_headers.get(crate::attestation::X_OAI_ATTESTATION_HEADER),
        None,
    );
    assert_eq!(
        realtime_headers.get(crate::attestation::X_OAI_ATTESTATION_HEADER),
        None,
    );
    assert_eq!(attestation_calls.load(Ordering::Relaxed), 0);
}