codex/codex-rs/core/src/client_tests.rs
Jiaming Zhang 5f4d0ec343 [codex] request desktop attestation from app (#20619)
## Summary

TL;DR: teaches `codex-rs` / app-server to request a desktop-provided
attestation token and attach it as `x-oai-attestation` on the scoped
ChatGPT Codex request paths.

![DeviceCheck attestation
interface](https://raw.githubusercontent.com/openai/codex/dev/jm/devicecheck-diagram-assets/pr-assets/devicecheck-attestation-interface.png)

## Details

This PR teaches the Codex app-server runtime how to request and attach
an attestation token. It does not generate DeviceCheck tokens directly;
instead, it relies on the connected desktop app to advertise that it can
generate attestation and then asks that app for a fresh header value
when needed.

The flow is:

1. The Codex desktop app connects to app-server.
2. During `initialize`, the app can advertise that it supports
`requestAttestation`.
3. Before app-server calls selected ChatGPT Codex endpoints, it sends
the internal server request `attestation/generate` to the app (a sketch of
this exchange follows the list).
4. app-server receives a pre-encoded header value back.
5. app-server forwards that value as `x-oai-attestation` on the scoped
outbound requests.
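
For orientation, here is a minimal sketch of the step-3 exchange as serde
types. The names and fields are illustrative only; the real request/response
shape lives in `codex-app-server-protocol` and may differ.

```rust
use serde::{Deserialize, Serialize};

// Hypothetical shapes for the server-initiated `attestation/generate`
// request and its reply; the real types are defined in
// codex-app-server-protocol.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateAttestationParams {
    // Which outbound Codex call the header is for; purely illustrative.
    pub reason: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateAttestationResponse {
    // Pre-encoded `x-oai-attestation` value (a `v1.`-prefixed string in the
    // E2E run below); `None` means the app could not produce one.
    pub header_value: Option<String>,
}
```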

The code in this repo is mostly protocol and runtime plumbing: it adds
the app-server request/response shape, introduces an attestation
provider in core, wires that provider into Responses / compaction /
realtime setup paths, and covers the intended scoping with tests. The
signed macOS DeviceCheck generation remains owned by the desktop app PR.
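
On the core side, the provider surface that `client_tests.rs` exercises looks
roughly like the sketch below. `AttestationContext` is a stand-in and the
exact bounds are assumptions, but the method shape and the
`Option<http::HeaderValue>` output mirror the counting test provider in that
file.

```rust
use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;

// Stand-in for the real context type; it identifies the outbound request the
// header is being generated for.
#[derive(Debug, Clone, Copy)]
pub struct AttestationContext;

// Boxed future resolving to the pre-encoded header value, or `None` when no
// attestation should be attached.
pub type GenerateAttestationFuture<'a> =
    Pin<Box<dyn Future<Output = Option<http::HeaderValue>> + Send + 'a>>;

// Bounds are assumed; the provider is shared as `Arc<dyn AttestationProvider>`
// when constructing `ModelClient`.
pub trait AttestationProvider: Debug + Send + Sync {
    fn header_for_request(&self, context: AttestationContext) -> GenerateAttestationFuture<'_>;
}
```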

## Related PR

- Codex desktop app implementation:
https://github.com/openai/openai/pull/878649

## Validation

<details>
<summary>Tests run</summary>

```sh
cargo test -p codex-app-server-protocol
cargo test -p codex-core attestation --lib
cargo test -p codex-app-server --lib attestation
```

Also ran:

```sh
just fix -p codex-core
just fix -p codex-app-server
just fix -p codex-app-server-protocol
just fmt
just write-app-server-schema
```

</details>

<details>
<summary>E2E DeviceCheck validation</summary>

First validated the signed desktop app boundary directly: launched a
packaged signed `Codex.app`, sent `attestation/generate`, decoded the
returned `v1.` attestation header, and validated the extracted
DeviceCheck token with `personal/jm/verify_devicecheck_token.py` using
bundle ID `com.openai.codex`. Apple returned `status_code: 200` and
`is_ok: true`.

Then ran the fuller app + app-server flow. The packaged `Codex.app`
launched a current-branch app-server via `CODEX_CLI_PATH`, and a local
MITM proxy intercepted outbound `chatgpt.com` traffic. The app-server
requested `attestation/generate` from the real Electron app process, and
the intercepted `/backend-api/codex/responses` traffic included
`x-oai-attestation` on both routes:

```text
GET  /backend-api/codex/responses  Upgrade: websocket  x-oai-attestation: present
POST /backend-api/codex/responses  Upgrade: none       x-oai-attestation: present
```

The captured header decoded to a DeviceCheck token that also validated
with Apple for `com.openai.codex` (`status_code: 200`, `is_ok: true`,
team `2DC432GLL2`).

</details>

---------

Co-authored-by: Codex <noreply@openai.com>
2026-05-08 12:36:02 -07:00


use super::AuthRequestTelemetryContext;
use super::ModelClient;
use super::PendingUnauthorizedRetry;
use super::UnauthorizedRecoveryExecution;
use super::X_CODEX_INSTALLATION_ID_HEADER;
use super::X_CODEX_PARENT_THREAD_ID_HEADER;
use super::X_CODEX_TURN_METADATA_HEADER;
use super::X_CODEX_WINDOW_ID_HEADER;
use super::X_OPENAI_SUBAGENT_HEADER;
use crate::AttestationContext;
use crate::AttestationProvider;
use crate::GenerateAttestationFuture;
use codex_api::ApiError;
use codex_api::ResponseEvent;
use codex_app_server_protocol::AuthMode;
use codex_login::AuthManager;
use codex_login::CodexAuth;
use codex_model_provider::BearerAuthProvider;
use codex_model_provider_info::CHATGPT_CODEX_BASE_URL;
use codex_model_provider_info::ModelProviderInfo;
use codex_model_provider_info::WireApi;
use codex_model_provider_info::create_oss_provider_with_base_url;
use codex_otel::SessionTelemetry;
use codex_protocol::SessionId;
use codex_protocol::ThreadId;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::openai_models::ModelInfo;
use codex_protocol::protocol::InternalSessionSource;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use codex_rollout_trace::ExecutionStatus;
use codex_rollout_trace::InferenceTraceAttempt;
use codex_rollout_trace::InferenceTraceContext;
use codex_rollout_trace::RawTraceEventPayload;
use codex_rollout_trace::RolloutTrace;
use codex_rollout_trace::TraceWriter;
use codex_rollout_trace::replay_bundle;
use futures::StreamExt;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::collections::BTreeMap;
use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;
use tempfile::TempDir;
use tokio::sync::Notify;
use tracing::Event;
use tracing::Subscriber;
use tracing::field::Visit;
use tracing_subscriber::Layer;
use tracing_subscriber::layer::Context as LayerContext;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::registry::LookupSpan;
use tracing_subscriber::util::SubscriberInitExt;
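// Builds a ModelClient against a stub OSS Responses provider with no auth
// manager and no attestation provider; tests vary only the SessionSource.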
fn test_model_client(session_source: SessionSource) -> ModelClient {
let provider = create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses);
let thread_id = ThreadId::new();
ModelClient::new(
/*auth_manager*/ None,
thread_id.into(),
thread_id,
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
provider,
session_source,
/*model_verbosity*/ None,
/*enable_request_compression*/ false,
/*include_timing_metrics*/ false,
/*beta_features_header*/ None,
/*attestation_provider*/ None,
)
}
fn test_model_info() -> ModelInfo {
serde_json::from_value(json!({
"slug": "gpt-test",
"display_name": "gpt-test",
"description": "desc",
"default_reasoning_level": "medium",
"supported_reasoning_levels": [
{"effort": "medium", "description": "medium"}
],
"shell_type": "shell_command",
"visibility": "list",
"supported_in_api": true,
"priority": 1,
"upgrade": null,
"base_instructions": "base instructions",
"model_messages": null,
"supports_reasoning_summaries": false,
"support_verbosity": false,
"default_verbosity": null,
"apply_patch_tool_type": null,
"truncation_policy": {"mode": "bytes", "limit": 10000},
"supports_parallel_tool_calls": false,
"supports_image_detail_original": false,
"context_window": 272000,
"auto_compact_token_limit": null,
"experimental_supported_tools": []
}))
.expect("deserialize test model info")
}
fn test_session_telemetry() -> SessionTelemetry {
SessionTelemetry::new(
ThreadId::new(),
"gpt-test",
"gpt-test",
/*account_id*/ None,
/*account_email*/ None,
/*auth_mode*/ None,
"test-originator".to_string(),
/*log_user_prompts*/ false,
"test-terminal".to_string(),
SessionSource::Cli,
)
}
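// Test-only tracing visitor/layer pair that captures `feedback_tags` events
// so tests can assert on the request/response ids recorded while streaming.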
#[derive(Default)]
struct TagCollectorVisitor {
tags: BTreeMap<String, String>,
}
impl Visit for TagCollectorVisitor {
fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
self.tags
.insert(field.name().to_string(), value.to_string());
}
fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
self.tags
.insert(field.name().to_string(), format!("{value:?}"));
}
}
#[derive(Clone)]
struct TagCollectorLayer {
tags: Arc<Mutex<BTreeMap<String, String>>>,
}
impl<S> Layer<S> for TagCollectorLayer
where
S: Subscriber + for<'a> LookupSpan<'a>,
{
fn on_event(&self, event: &Event<'_>, _ctx: LayerContext<'_, S>) {
if event.metadata().target() != "feedback_tags" {
return;
}
let mut visitor = TagCollectorVisitor::default();
event.record(&mut visitor);
self.tags.lock().unwrap().extend(visitor.tags);
}
}
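// Creates a trace writer with a started thread and turn, then records a
// started inference attempt so cancellation tests have an attempt to close.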
fn started_inference_attempt(temp: &TempDir) -> anyhow::Result<InferenceTraceAttempt> {
let writer = Arc::new(TraceWriter::create(
temp.path(),
"trace-1".to_string(),
"rollout-1".to_string(),
"thread-root".to_string(),
)?);
writer.append(RawTraceEventPayload::ThreadStarted {
thread_id: "thread-root".to_string(),
agent_path: "/root".to_string(),
metadata_payload: None,
})?;
writer.append(RawTraceEventPayload::CodexTurnStarted {
codex_turn_id: "turn-1".to_string(),
thread_id: "thread-root".to_string(),
})?;
let inference_trace = InferenceTraceContext::enabled(
writer,
"thread-root".to_string(),
"turn-1".to_string(),
"gpt-test".to_string(),
"test-provider".to_string(),
);
let attempt = inference_trace.start_attempt();
attempt.record_started(&json!({
"model": "gpt-test",
"input": [{
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": "hello"}]
}],
}));
Ok(attempt)
}
fn output_message(id: &str, text: &str) -> ResponseItem {
ResponseItem::Message {
id: Some(id.to_string()),
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: text.to_string(),
}],
phase: None,
}
}
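// Replays the trace bundle repeatedly until the single inference call is
// reduced to Cancelled; the mapper records cancellation asynchronously after
// the consumer is dropped, so the terminal event may land slightly later.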
async fn replay_until_cancelled(temp: &TempDir) -> anyhow::Result<RolloutTrace> {
let mut rollout = replay_bundle(temp.path())?;
for _ in 0..50 {
let inference = rollout
.inference_calls
.values()
.next()
.expect("inference should be reduced");
if inference.execution.status == ExecutionStatus::Cancelled {
return Ok(rollout);
}
tokio::time::sleep(Duration::from_millis(10)).await;
rollout = replay_bundle(temp.path())?;
}
Ok(rollout)
}
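// Stream that yields queued events (then stays Pending) and pings `notify`
// once it has yielded `notify_after` items, so tests can wait until the
// mapper is wedged on a full downstream channel.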
struct NotifyAfterEventStream {
events: VecDeque<ResponseEvent>,
yielded: usize,
notify_after: usize,
notify: Arc<Notify>,
}
impl futures::Stream for NotifyAfterEventStream {
type Item = std::result::Result<ResponseEvent, ApiError>;
fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let Some(event) = self.events.pop_front() else {
return Poll::Pending;
};
self.yielded += 1;
if self.yielded == self.notify_after {
self.notify.notify_one();
}
Poll::Ready(Some(Ok(event)))
}
}
#[test]
fn build_subagent_headers_sets_other_subagent_label() {
let client = test_model_client(SessionSource::SubAgent(SubAgentSource::Other(
"memory_consolidation".to_string(),
)));
let headers = client.build_subagent_headers();
let value = headers
.get(X_OPENAI_SUBAGENT_HEADER)
.and_then(|value| value.to_str().ok());
assert_eq!(value, Some("memory_consolidation"));
}
#[test]
fn build_subagent_headers_sets_internal_memory_consolidation_label() {
let client = test_model_client(SessionSource::Internal(
InternalSessionSource::MemoryConsolidation,
));
let headers = client.build_subagent_headers();
let value = headers
.get(X_OPENAI_SUBAGENT_HEADER)
.and_then(|value| value.to_str().ok());
assert_eq!(value, Some("memory_consolidation"));
}
#[test]
fn build_ws_client_metadata_includes_window_lineage_and_turn_metadata() {
let parent_thread_id = ThreadId::new();
let client = test_model_client(SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
parent_thread_id,
depth: 2,
agent_path: None,
agent_nickname: None,
agent_role: None,
}));
client.advance_window_generation();
let client_metadata = client.build_ws_client_metadata(Some(r#"{"turn_id":"turn-123"}"#));
let thread_id = client.state.thread_id;
assert_eq!(
client_metadata,
std::collections::HashMap::from([
(
X_CODEX_INSTALLATION_ID_HEADER.to_string(),
"11111111-1111-4111-8111-111111111111".to_string(),
),
(
X_CODEX_WINDOW_ID_HEADER.to_string(),
format!("{thread_id}:1"),
),
(
X_OPENAI_SUBAGENT_HEADER.to_string(),
"collab_spawn".to_string(),
),
(
X_CODEX_PARENT_THREAD_ID_HEADER.to_string(),
parent_thread_id.to_string(),
),
(
X_CODEX_TURN_METADATA_HEADER.to_string(),
r#"{"turn_id":"turn-123"}"#.to_string(),
),
])
);
}
#[tokio::test]
async fn summarize_memories_returns_empty_for_empty_input() {
let client = test_model_client(SessionSource::Cli);
let model_info = test_model_info();
let session_telemetry = test_session_telemetry();
let output = client
.summarize_memories(
Vec::new(),
&model_info,
/*effort*/ None,
&session_telemetry,
)
.await
.expect("empty summarize request should succeed");
assert_eq!(output.len(), 0);
}
#[tokio::test]
async fn dropped_response_stream_traces_cancelled_partial_output() -> anyhow::Result<()> {
let temp = TempDir::new()?;
let attempt = started_inference_attempt(&temp)?;
// The provider has produced one complete output item, but no terminal
// response.completed event. The harness has enough information to keep this
// item in history, so the trace should preserve it when the stream is
// abandoned.
let item = output_message("msg-1", "partial answer");
let api_stream = futures::stream::iter([Ok(ResponseEvent::OutputItemDone(item))])
.chain(futures::stream::pending());
let (mut stream, _) = super::map_response_events(
/*upstream_request_id*/ None,
api_stream,
test_session_telemetry(),
attempt,
);
let observed = stream
.next()
.await
.expect("mapped stream should yield output item")?;
assert!(matches!(observed, ResponseEvent::OutputItemDone(_)));
// Dropping the consumer is how turn interruption/preemption stops polling
// the provider stream. The mapper task observes that drop asynchronously
// and records cancellation using the output items it has already seen.
drop(stream);
// Cancellation is recorded by the mapper task after Drop wakes it, so the
// replay may need a short wait before the terminal event appears on disk.
let rollout = replay_until_cancelled(&temp).await?;
let inference = rollout
.inference_calls
.values()
.next()
.expect("inference should be reduced");
assert_eq!(inference.execution.status, ExecutionStatus::Cancelled);
assert_eq!(inference.response_item_ids.len(), 1);
assert_eq!(rollout.raw_payloads.len(), 2);
Ok(())
}
#[tokio::test]
async fn response_stream_records_last_model_feedback_ids() {
let tags = Arc::new(Mutex::new(BTreeMap::new()));
let _guard = tracing_subscriber::registry()
.with(TagCollectorLayer { tags: tags.clone() })
.set_default();
let api_stream = futures::stream::iter([
Ok(ResponseEvent::Created),
Ok(ResponseEvent::Completed {
response_id: "resp-123".to_string(),
token_usage: None,
end_turn: Some(true),
}),
]);
let (mut stream, _) = super::map_response_events(
Some("req-123".to_string()),
api_stream,
test_session_telemetry(),
InferenceTraceAttempt::disabled(),
);
while stream.next().await.is_some() {}
let tags = tags.lock().unwrap().clone();
assert_eq!(
tags.get("last_model_request_id").map(String::as_str),
Some("\"req-123\"")
);
assert_eq!(
tags.get("last_model_response_id").map(String::as_str),
Some("\"resp-123\"")
);
}
#[tokio::test]
async fn dropped_backpressured_response_stream_traces_cancelled_partial_output()
-> anyhow::Result<()> {
let temp = TempDir::new()?;
let attempt = started_inference_attempt(&temp)?;
let backpressured_item_yielded = Arc::new(Notify::new());
let mut events = VecDeque::new();
for _ in 0..super::RESPONSE_STREAM_CHANNEL_CAPACITY {
events.push_back(ResponseEvent::Created);
}
events.push_back(ResponseEvent::OutputItemDone(output_message(
"msg-1",
"partial answer",
)));
let api_stream = NotifyAfterEventStream {
events,
yielded: 0,
notify_after: super::RESPONSE_STREAM_CHANNEL_CAPACITY + 1,
notify: Arc::clone(&backpressured_item_yielded),
};
let (stream, _) = super::map_response_events(
/*upstream_request_id*/ None,
api_stream,
test_session_telemetry(),
attempt,
);
// Fill the mapper channel with non-terminal events, then yield one output
// item. The mapper has observed that item and is blocked trying to send it
// downstream, so dropping the consumer covers the send-failure path rather
// than the `consumer_dropped` select branch.
backpressured_item_yielded.notified().await;
drop(stream);
let rollout = replay_until_cancelled(&temp).await?;
let inference = rollout
.inference_calls
.values()
.next()
.expect("inference should be reduced");
assert_eq!(inference.execution.status, ExecutionStatus::Cancelled);
assert_eq!(inference.response_item_ids.len(), 1);
assert_eq!(rollout.raw_payloads.len(), 2);
Ok(())
}
#[test]
fn auth_request_telemetry_context_tracks_attached_auth_and_retry_phase() {
let auth_context = AuthRequestTelemetryContext::new(
Some(AuthMode::Chatgpt),
&BearerAuthProvider::for_test(Some("access-token"), Some("workspace-123")),
PendingUnauthorizedRetry::from_recovery(UnauthorizedRecoveryExecution {
mode: "managed",
phase: "refresh_token",
}),
);
assert_eq!(auth_context.auth_mode, Some("Chatgpt"));
assert!(auth_context.auth_header_attached);
assert_eq!(auth_context.auth_header_name, Some("authorization"));
assert!(auth_context.retry_after_unauthorized);
assert_eq!(auth_context.recovery_mode, Some("managed"));
assert_eq!(auth_context.recovery_phase, Some("refresh_token"));
}
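// Builds a ModelClient whose attestation provider counts how many header
// values it hands out. With `include_attestation` the client uses the
// ChatGPT Codex provider plus dummy ChatGPT auth; otherwise it uses a stub
// OSS provider, where attestation should never be requested.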
fn model_client_with_counting_attestation(
include_attestation: bool,
) -> (ModelClient, Arc<AtomicUsize>) {
#[derive(Debug)]
struct CountingAttestationProvider {
calls: Arc<AtomicUsize>,
}
impl AttestationProvider for CountingAttestationProvider {
fn header_for_request(
&self,
_context: AttestationContext,
) -> GenerateAttestationFuture<'_> {
let calls = self.calls.clone();
Box::pin(async move {
let call = calls.fetch_add(1, Ordering::Relaxed) + 1;
Some(http::HeaderValue::from_bytes(format!("v1.header-{call}").as_bytes()).unwrap())
})
}
}
let attestation_calls = Arc::new(AtomicUsize::new(0));
let (auth_manager, provider) = if include_attestation {
(
Some(AuthManager::from_auth_for_testing(
CodexAuth::create_dummy_chatgpt_auth_for_testing(),
)),
ModelProviderInfo::create_openai_provider(Some(CHATGPT_CODEX_BASE_URL.to_string())),
)
} else {
(
None,
create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses),
)
};
let model_client = ModelClient::new(
auth_manager,
SessionId::new(),
ThreadId::new(),
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
provider,
SessionSource::Exec,
/*model_verbosity*/ None,
/*enable_request_compression*/ false,
/*include_timing_metrics*/ false,
/*beta_features_header*/ None,
Some(Arc::new(CountingAttestationProvider {
calls: attestation_calls.clone(),
})),
);
(model_client, attestation_calls)
}
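// The ChatGPT Codex websocket handshake should request exactly one header
// value from the provider and attach it as x-oai-attestation.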
#[tokio::test]
async fn websocket_handshake_includes_attestation_for_chatgpt_codex_responses() {
let (model_client, attestation_calls) =
model_client_with_counting_attestation(/*include_attestation*/ true);
let headers = model_client
.build_websocket_headers(/*turn_state*/ None, /*turn_metadata_header*/ None)
.await;
assert_eq!(
headers
.get(crate::attestation::X_OAI_ATTESTATION_HEADER)
.and_then(|value| value.to_str().ok()),
Some("v1.header-1"),
);
assert_eq!(attestation_calls.load(Ordering::Relaxed), 1);
}
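// Non ChatGPT Codex providers must never trigger attestation generation, so
// none of the per-path header maps should gain x-oai-attestation.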
#[tokio::test]
async fn non_chatgpt_codex_endpoints_omit_attestation_generation() {
let (model_client, attestation_calls) =
model_client_with_counting_attestation(/*include_attestation*/ false);
let mut response_headers = http::HeaderMap::new();
if let Some(header_value) = model_client.generate_attestation_header_for().await {
response_headers.insert(crate::attestation::X_OAI_ATTESTATION_HEADER, header_value);
}
let mut compaction_headers = http::HeaderMap::new();
if let Some(header_value) = model_client.generate_attestation_header_for().await {
compaction_headers.insert(crate::attestation::X_OAI_ATTESTATION_HEADER, header_value);
}
let mut realtime_headers = http::HeaderMap::new();
if let Some(header_value) = model_client.generate_attestation_header_for().await {
realtime_headers.insert(crate::attestation::X_OAI_ATTESTATION_HEADER, header_value);
}
assert_eq!(
response_headers.get(crate::attestation::X_OAI_ATTESTATION_HEADER),
None,
);
assert_eq!(
compaction_headers.get(crate::attestation::X_OAI_ATTESTATION_HEADER),
None,
);
assert_eq!(
realtime_headers.get(crate::attestation::X_OAI_ATTESTATION_HEADER),
None,
);
assert_eq!(attestation_calls.load(Ordering::Relaxed), 0);
}