codex/codex-rs/app-server/src/thread_state.rs
Jiaming Zhang 5f4d0ec343 [codex] request desktop attestation from app (#20619)
## Summary

TL;DR: teaches `codex-rs` / app-server to request a desktop-provided
attestation token and attach it as `x-oai-attestation` on the scoped
ChatGPT Codex request paths.

![DeviceCheck attestation
interface](https://raw.githubusercontent.com/openai/codex/dev/jm/devicecheck-diagram-assets/pr-assets/devicecheck-attestation-interface.png)

## Details

This PR teaches the Codex app-server runtime how to request and attach
an attestation token. It does not generate DeviceCheck tokens directly;
instead, it relies on the connected desktop app to advertise that it can
generate attestation and then asks that app for a fresh header value
when needed.

The flow is:

1. The Codex desktop app connects to app-server.
2. During `initialize`, the app can advertise that it supports
`requestAttestation`.
3. Before app-server calls selected ChatGPT Codex endpoints, it sends
the internal server request `attestation/generate` to the app.
4. app-server receives a pre-encoded header value back.
5. app-server forwards that value as `x-oai-attestation` on the scoped
outbound requests.
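
The flow above can be sketched roughly as follows. This is a minimal, standard-library-only illustration, not the code from this PR: the `AttestationClient` trait, `FakeDesktopApp`, `is_scoped_path`, `build_headers`, the path-prefix scoping rule, and the placeholder token are all hypothetical names and assumptions. Only the `x-oai-attestation` header name, the `requestAttestation` capability, the `attestation/generate` request, and the `/backend-api/codex/responses` path come from the description.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for the connected desktop app.
trait AttestationClient {
    /// Whether the client advertised `requestAttestation` during `initialize`.
    fn supports_attestation(&self) -> bool;
    /// Models the `attestation/generate` server request.
    /// Returns a pre-encoded header value, or `None` if none is available.
    fn generate_attestation(&self) -> Option<String>;
}

/// Hypothetical fake client used only for this sketch.
struct FakeDesktopApp {
    request_attestation: bool,
}

impl AttestationClient for FakeDesktopApp {
    fn supports_attestation(&self) -> bool {
        self.request_attestation
    }

    fn generate_attestation(&self) -> Option<String> {
        // Placeholder value; the real token format is owned by the desktop app.
        self.request_attestation
            .then(|| "v1.placeholder-token".to_string())
    }
}

/// Assumption: scoping is modeled here as a simple path-prefix check.
fn is_scoped_path(path: &str) -> bool {
    path.starts_with("/backend-api/codex/")
}

/// Builds outbound headers, attaching `x-oai-attestation` only for scoped
/// paths and only when the client advertised the capability.
fn build_headers(client: &dyn AttestationClient, path: &str) -> HashMap<String, String> {
    let mut headers = HashMap::new();
    if is_scoped_path(path) && client.supports_attestation() {
        if let Some(value) = client.generate_attestation() {
            headers.insert("x-oai-attestation".to_string(), value);
        }
    }
    headers
}
```

In this sketch a client that never advertised the capability, or a request outside the scoped prefix, produces no attestation header at all rather than an empty one.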

The code in this repo is mostly protocol and runtime plumbing: it adds
the app-server request/response shape, introduces an attestation
provider in core, wires that provider into Responses / compaction /
realtime setup paths, and covers the intended scoping with tests. The
signed macOS DeviceCheck generation remains owned by the desktop app PR.
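
Part of that plumbing is choosing which connected client to ask. The sketch below is a simplified, synchronous version of the selection done by `first_attestation_capable_connection_for_thread` in `thread_state.rs`. It assumes `ConnectionId` wraps a `u64` and reduces the capabilities struct to a single `bool`; both are illustrative simplifications, not the real types.

```rust
use std::collections::{HashMap, HashSet};

/// Assumption: the real `ConnectionId` wraps an ordered integer.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct ConnectionId(u64);

/// Returns the lowest-numbered subscriber that advertised attestation support.
/// `HashSet` iteration order is unspecified, so `min_by_key` is what makes
/// the choice deterministic.
fn first_attestation_capable(
    subscribers: &HashSet<ConnectionId>,
    capabilities: &HashMap<ConnectionId, bool>,
) -> Option<ConnectionId> {
    subscribers
        .iter()
        .copied()
        // Connections with no recorded capabilities are skipped.
        .filter(|id| capabilities.get(id).copied().unwrap_or(false))
        .min_by_key(|id| id.0)
}
```

Picking the minimum ID rather than the first element seen keeps the result stable across calls, since repeated iteration over a hash set is not guaranteed to visit connections in the same order.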

## Related PR

- Codex desktop app implementation:
https://github.com/openai/openai/pull/878649

## Validation

<details>
<summary>Tests run</summary>

```sh
cargo test -p codex-app-server-protocol
cargo test -p codex-core attestation --lib
cargo test -p codex-app-server --lib attestation
```

Also ran:

```sh
just fix -p codex-core
just fix -p codex-app-server
just fix -p codex-app-server-protocol
just fmt
just write-app-server-schema
```

</details>

<details>
<summary>E2E DeviceCheck validation</summary>

First validated the signed desktop app boundary directly: launched a
packaged signed `Codex.app`, sent `attestation/generate`, decoded the
returned `v1.` attestation header, and validated the extracted
DeviceCheck token with `personal/jm/verify_devicecheck_token.py` using
bundle ID `com.openai.codex`. Apple returned `status_code: 200` and
`is_ok: true`.

Then ran the fuller app + app-server flow. The packaged `Codex.app`
launched a current-branch app-server via `CODEX_CLI_PATH`, and a local
MITM proxy intercepted outbound `chatgpt.com` traffic. The app-server
requested `attestation/generate` from the real Electron app process, and
the intercepted `/backend-api/codex/responses` traffic included
`x-oai-attestation` on both routes:

```text
GET  /backend-api/codex/responses  Upgrade: websocket  x-oai-attestation: present
POST /backend-api/codex/responses  Upgrade: none       x-oai-attestation: present
```

The captured header decoded to a DeviceCheck token that also validated
with Apple for `com.openai.codex` (`status_code: 200`, `is_ok: true`,
team `2DC432GLL2`).

</details>

---------

Co-authored-by: Codex <noreply@openai.com>
2026-05-08 12:36:02 -07:00


use crate::outgoing_message::ConnectionId;
use crate::outgoing_message::ConnectionRequestId;
use codex_app_server_protocol::RequestId;
use codex_app_server_protocol::ThreadGoal;
use codex_app_server_protocol::ThreadHistoryBuilder;
use codex_app_server_protocol::Turn;
use codex_app_server_protocol::TurnError;
use codex_core::CodexThread;
use codex_core::ThreadConfigSnapshot;
use codex_protocol::ThreadId;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::RolloutItem;
use codex_rollout::state_db::StateDbHandle;
use codex_utils_absolute_path::AbsolutePathBuf;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use std::sync::Weak;
use tokio::sync::Mutex;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::watch;
use tracing::error;
type PendingInterruptQueue = Vec<ConnectionRequestId>;
pub(crate) struct PendingThreadResumeRequest {
    pub(crate) request_id: ConnectionRequestId,
    pub(crate) history_items: Vec<RolloutItem>,
    pub(crate) config_snapshot: ThreadConfigSnapshot,
    pub(crate) instruction_sources: Vec<AbsolutePathBuf>,
    pub(crate) thread_summary: codex_app_server_protocol::Thread,
    pub(crate) emit_thread_goal_update: bool,
    pub(crate) thread_goal_state_db: Option<StateDbHandle>,
    pub(crate) include_turns: bool,
}
// ThreadListenerCommand runs operations in the thread listener's context so that they are serialized (ordered) with the listener's other work.
pub(crate) enum ThreadListenerCommand {
    // SendThreadResumeResponse resumes an already-running thread by sending the thread's history to the client and atomically subscribing it to new updates.
    SendThreadResumeResponse(Box<PendingThreadResumeRequest>),
    // EmitThreadGoalUpdated orders app-server goal updates with running-thread resume responses.
    EmitThreadGoalUpdated {
        goal: ThreadGoal,
    },
    // EmitThreadGoalCleared orders app-server goal clears with running-thread resume responses.
    EmitThreadGoalCleared,
    // EmitThreadGoalSnapshot reads and emits the latest goal state in listener order.
    EmitThreadGoalSnapshot {
        state_db: StateDbHandle,
    },
    // ResolveServerRequest notifies the client that the request has been resolved.
    // It is executed in the thread listener's context to ensure that the resolved notification is ordered with regard to the request itself.
    ResolveServerRequest {
        request_id: RequestId,
        completion_tx: oneshot::Sender<()>,
    },
}
/// Accumulates the latest per-conversation state (e.g. the last error) while a turn runs.
#[derive(Default, Clone)]
pub(crate) struct TurnSummary {
    pub(crate) started_at: Option<i64>,
    pub(crate) command_execution_started: HashSet<String>,
    pub(crate) last_error: Option<TurnError>,
}
#[derive(Default)]
pub(crate) struct ThreadState {
    pub(crate) pending_interrupts: PendingInterruptQueue,
    pub(crate) pending_rollbacks: Option<ConnectionRequestId>,
    pub(crate) turn_summary: TurnSummary,
    pub(crate) last_terminal_turn_id: Option<String>,
    pub(crate) cancel_tx: Option<oneshot::Sender<()>>,
    pub(crate) experimental_raw_events: bool,
    pub(crate) listener_generation: u64,
    listener_command_tx: Option<mpsc::UnboundedSender<ThreadListenerCommand>>,
    current_turn_history: ThreadHistoryBuilder,
    listener_thread: Option<Weak<CodexThread>>,
}
impl ThreadState {
    pub(crate) fn listener_matches(&self, conversation: &Arc<CodexThread>) -> bool {
        self.listener_thread
            .as_ref()
            .and_then(Weak::upgrade)
            .is_some_and(|existing| Arc::ptr_eq(&existing, conversation))
    }
    pub(crate) fn set_listener(
        &mut self,
        cancel_tx: oneshot::Sender<()>,
        conversation: &Arc<CodexThread>,
    ) -> (mpsc::UnboundedReceiver<ThreadListenerCommand>, u64) {
        if let Some(previous) = self.cancel_tx.replace(cancel_tx) {
            let _ = previous.send(());
        }
        self.listener_generation = self.listener_generation.wrapping_add(1);
        let (listener_command_tx, listener_command_rx) = mpsc::unbounded_channel();
        self.listener_command_tx = Some(listener_command_tx);
        self.listener_thread = Some(Arc::downgrade(conversation));
        (listener_command_rx, self.listener_generation)
    }
    pub(crate) fn clear_listener(&mut self) {
        if let Some(cancel_tx) = self.cancel_tx.take() {
            let _ = cancel_tx.send(());
        }
        self.listener_command_tx = None;
        self.current_turn_history.reset();
        self.listener_thread = None;
    }
    pub(crate) fn set_experimental_raw_events(&mut self, enabled: bool) {
        self.experimental_raw_events = enabled;
    }
    pub(crate) fn listener_command_tx(
        &self,
    ) -> Option<mpsc::UnboundedSender<ThreadListenerCommand>> {
        self.listener_command_tx.clone()
    }
    pub(crate) fn active_turn_snapshot(&self) -> Option<Turn> {
        self.current_turn_history.active_turn_snapshot()
    }
    pub(crate) fn track_current_turn_event(&mut self, event_turn_id: &str, event: &EventMsg) {
        if let EventMsg::TurnStarted(payload) = event {
            self.turn_summary.started_at = payload.started_at;
        }
        self.current_turn_history.handle_event(event);
        if matches!(event, EventMsg::TurnAborted(_) | EventMsg::TurnComplete(_))
            && !self.current_turn_history.has_active_turn()
        {
            self.last_terminal_turn_id = Some(event_turn_id.to_string());
            self.current_turn_history.reset();
        }
    }
}
pub(crate) async fn resolve_server_request_on_thread_listener(
    thread_state: &Arc<Mutex<ThreadState>>,
    request_id: RequestId,
) {
    let (completion_tx, completion_rx) = oneshot::channel();
    let listener_command_tx = {
        let state = thread_state.lock().await;
        state.listener_command_tx()
    };
    let Some(listener_command_tx) = listener_command_tx else {
        error!("failed to remove pending client request: thread listener is not running");
        return;
    };
    if listener_command_tx
        .send(ThreadListenerCommand::ResolveServerRequest {
            request_id,
            completion_tx,
        })
        .is_err()
    {
        error!(
            "failed to remove pending client request: thread listener command channel is closed"
        );
        return;
    }
    if let Err(err) = completion_rx.await {
        error!("failed to remove pending client request: {err}");
    }
}
struct ThreadEntry {
    state: Arc<Mutex<ThreadState>>,
    connection_ids: HashSet<ConnectionId>,
    has_connections_watcher: watch::Sender<bool>,
}
impl Default for ThreadEntry {
    fn default() -> Self {
        Self {
            state: Arc::new(Mutex::new(ThreadState::default())),
            connection_ids: HashSet::new(),
            has_connections_watcher: watch::channel(false).0,
        }
    }
}
impl ThreadEntry {
    fn update_has_connections(&self) {
        let _ = self.has_connections_watcher.send_if_modified(|current| {
            let prev = *current;
            *current = !self.connection_ids.is_empty();
            prev != *current
        });
    }
}
#[derive(Default)]
struct ThreadStateManagerInner {
    live_connections: HashMap<ConnectionId, ConnectionCapabilities>,
    threads: HashMap<ThreadId, ThreadEntry>,
    thread_ids_by_connection: HashMap<ConnectionId, HashSet<ThreadId>>,
}
#[derive(Clone, Copy, Default)]
pub(crate) struct ConnectionCapabilities {
    pub(crate) request_attestation: bool,
}
#[derive(Clone, Default)]
pub(crate) struct ThreadStateManager {
    state: Arc<Mutex<ThreadStateManagerInner>>,
}
impl ThreadStateManager {
    pub(crate) fn new() -> Self {
        Self::default()
    }
    pub(crate) async fn connection_initialized(
        &self,
        connection_id: ConnectionId,
        capabilities: ConnectionCapabilities,
    ) {
        self.state
            .lock()
            .await
            .live_connections
            .insert(connection_id, capabilities);
    }
    pub(crate) async fn first_attestation_capable_connection_for_thread(
        &self,
        thread_id: ThreadId,
    ) -> Option<ConnectionId> {
        let state = self.state.lock().await;
        state
            .threads
            .get(&thread_id)?
            .connection_ids
            .iter()
            .filter_map(|connection_id| {
                state
                    .live_connections
                    .get(connection_id)?
                    .request_attestation
                    .then_some(*connection_id)
            })
            .min_by_key(|connection_id| connection_id.0)
    }
    pub(crate) async fn subscribed_connection_ids(&self, thread_id: ThreadId) -> Vec<ConnectionId> {
        let state = self.state.lock().await;
        state
            .threads
            .get(&thread_id)
            .map(|thread_entry| thread_entry.connection_ids.iter().copied().collect())
            .unwrap_or_default()
    }
    pub(crate) async fn thread_state(&self, thread_id: ThreadId) -> Arc<Mutex<ThreadState>> {
        let mut state = self.state.lock().await;
        state.threads.entry(thread_id).or_default().state.clone()
    }
    pub(crate) async fn remove_thread_state(&self, thread_id: ThreadId) {
        let thread_state = {
            let mut state = self.state.lock().await;
            let thread_state = state
                .threads
                .remove(&thread_id)
                .map(|thread_entry| thread_entry.state);
            state.thread_ids_by_connection.retain(|_, thread_ids| {
                thread_ids.remove(&thread_id);
                !thread_ids.is_empty()
            });
            thread_state
        };
        if let Some(thread_state) = thread_state {
            let mut thread_state = thread_state.lock().await;
            tracing::debug!(
                thread_id = %thread_id,
                listener_generation = thread_state.listener_generation,
                had_listener = thread_state.cancel_tx.is_some(),
                had_active_turn = thread_state.active_turn_snapshot().is_some(),
                "clearing thread listener during thread-state teardown"
            );
            thread_state.clear_listener();
        }
    }
    pub(crate) async fn clear_all_listeners(&self) {
        let thread_states = {
            let state = self.state.lock().await;
            state
                .threads
                .iter()
                .map(|(thread_id, thread_entry)| (*thread_id, thread_entry.state.clone()))
                .collect::<Vec<_>>()
        };
        for (thread_id, thread_state) in thread_states {
            let mut thread_state = thread_state.lock().await;
            tracing::debug!(
                thread_id = %thread_id,
                listener_generation = thread_state.listener_generation,
                had_listener = thread_state.cancel_tx.is_some(),
                had_active_turn = thread_state.active_turn_snapshot().is_some(),
                "clearing thread listener during app-server shutdown"
            );
            thread_state.clear_listener();
        }
    }
    pub(crate) async fn unsubscribe_connection_from_thread(
        &self,
        thread_id: ThreadId,
        connection_id: ConnectionId,
    ) -> bool {
        {
            let mut state = self.state.lock().await;
            if !state.threads.contains_key(&thread_id) {
                return false;
            }
            if !state
                .thread_ids_by_connection
                .get(&connection_id)
                .is_some_and(|thread_ids| thread_ids.contains(&thread_id))
            {
                return false;
            }
            if let Some(thread_ids) = state.thread_ids_by_connection.get_mut(&connection_id) {
                thread_ids.remove(&thread_id);
                if thread_ids.is_empty() {
                    state.thread_ids_by_connection.remove(&connection_id);
                }
            }
            if let Some(thread_entry) = state.threads.get_mut(&thread_id) {
                thread_entry.connection_ids.remove(&connection_id);
                thread_entry.update_has_connections();
            }
        };
        true
    }
    #[cfg(test)]
    pub(crate) async fn has_subscribers(&self, thread_id: ThreadId) -> bool {
        self.state
            .lock()
            .await
            .threads
            .get(&thread_id)
            .is_some_and(|thread_entry| !thread_entry.connection_ids.is_empty())
    }
    pub(crate) async fn try_ensure_connection_subscribed(
        &self,
        thread_id: ThreadId,
        connection_id: ConnectionId,
        experimental_raw_events: bool,
    ) -> Option<Arc<Mutex<ThreadState>>> {
        let thread_state = {
            let mut state = self.state.lock().await;
            if !state.live_connections.contains_key(&connection_id) {
                return None;
            }
            state
                .thread_ids_by_connection
                .entry(connection_id)
                .or_default()
                .insert(thread_id);
            let thread_entry = state.threads.entry(thread_id).or_default();
            thread_entry.connection_ids.insert(connection_id);
            thread_entry.update_has_connections();
            thread_entry.state.clone()
        };
        {
            let mut thread_state_guard = thread_state.lock().await;
            if experimental_raw_events {
                thread_state_guard.set_experimental_raw_events(/*enabled*/ true);
            }
        }
        Some(thread_state)
    }
    pub(crate) async fn try_add_connection_to_thread(
        &self,
        thread_id: ThreadId,
        connection_id: ConnectionId,
    ) -> bool {
        let mut state = self.state.lock().await;
        if !state.live_connections.contains_key(&connection_id) {
            return false;
        }
        state
            .thread_ids_by_connection
            .entry(connection_id)
            .or_default()
            .insert(thread_id);
        let thread_entry = state.threads.entry(thread_id).or_default();
        thread_entry.connection_ids.insert(connection_id);
        thread_entry.update_has_connections();
        true
    }
    pub(crate) async fn remove_connection(&self, connection_id: ConnectionId) -> Vec<ThreadId> {
        {
            let mut state = self.state.lock().await;
            state.live_connections.remove(&connection_id);
            let thread_ids = state
                .thread_ids_by_connection
                .remove(&connection_id)
                .unwrap_or_default();
            for thread_id in &thread_ids {
                if let Some(thread_entry) = state.threads.get_mut(thread_id) {
                    thread_entry.connection_ids.remove(&connection_id);
                    thread_entry.update_has_connections();
                }
            }
            thread_ids
                .into_iter()
                .filter(|thread_id| {
                    state
                        .threads
                        .get(thread_id)
                        .is_some_and(|thread_entry| thread_entry.connection_ids.is_empty())
                })
                .collect::<Vec<_>>()
        }
    }
    pub(crate) async fn subscribe_to_has_connections(
        &self,
        thread_id: ThreadId,
    ) -> Option<watch::Receiver<bool>> {
        let state = self.state.lock().await;
        state
            .threads
            .get(&thread_id)
            .map(|thread_entry| thread_entry.has_connections_watcher.subscribe())
    }
}