From bb55e050d75c7b52af5b0977bb660cfa0ee14f70 Mon Sep 17 00:00:00 2001 From: Jiaming Zhang Date: Wed, 6 May 2026 21:43:17 -0700 Subject: [PATCH] refactor(attestation): move policy behind provider Co-authored-by: Codex --- codex-rs/app-server/src/message_processor.rs | 47 ++++++++++++++++---- codex-rs/core/src/attestation.rs | 45 ++++++------------- codex-rs/core/src/client.rs | 25 +++++------ codex-rs/core/src/client_tests.rs | 33 +++++++++++--- codex-rs/core/src/lib.rs | 2 + codex-rs/core/src/session/mod.rs | 2 +- codex-rs/core/src/session/session.rs | 2 +- codex-rs/core/src/state/service.rs | 2 +- codex-rs/core/src/thread_manager.rs | 4 +- 9 files changed, 96 insertions(+), 66 deletions(-) diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index e486023806..f950661180 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -60,7 +60,9 @@ use codex_app_server_protocol::ServerRequestPayload; use codex_app_server_protocol::experimental_required_message; use codex_arg0::Arg0DispatchPaths; use codex_chatgpt::workspace_settings; +use codex_core::AttestationContext; use codex_core::AttestationProvider; +use codex_core::GenerateAttestationFuture; use codex_core::ThreadManager; use codex_core::config::Config; use codex_core::thread_store_from_config; @@ -158,18 +160,45 @@ impl ExternalAuth for ExternalAuthRefreshBridge { fn app_server_attestation_provider( outgoing: Arc, attestation_connection_ids: Arc>>, -) -> AttestationProvider { - AttestationProvider::new(move || { - let outgoing = outgoing.clone(); - let attestation_connection_ids = attestation_connection_ids.clone(); - Box::pin(request_attestation_header_value_with_timeout( - outgoing, - attestation_connection_ids, - ATTESTATION_GENERATE_TIMEOUT, - )) +) -> Arc { + Arc::new(AppServerAttestationProvider { + outgoing, + attestation_connection_ids, }) } +struct AppServerAttestationProvider { + outgoing: Arc, + attestation_connection_ids: Arc>>, +} + +impl std::fmt::Debug for AppServerAttestationProvider { + fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter + .debug_struct("AppServerAttestationProvider") + .finish() + } +} + +impl AttestationProvider for AppServerAttestationProvider { + fn generate_header_value(&self, context: AttestationContext) -> GenerateAttestationFuture<'_> { + let outgoing = self.outgoing.clone(); + let attestation_connection_ids = self.attestation_connection_ids.clone(); + Box::pin(async move { + if !context.uses_chatgpt_auth { + return None; + } + + request_attestation_header_value_with_timeout( + outgoing, + attestation_connection_ids, + ATTESTATION_GENERATE_TIMEOUT, + ) + .await + }) + } +} + async fn request_attestation_header_value_with_timeout( outgoing: Arc, attestation_connection_ids: Arc>>, diff --git a/codex-rs/core/src/attestation.rs b/codex-rs/core/src/attestation.rs index aa637304be..8e5bdcf538 100644 --- a/codex-rs/core/src/attestation.rs +++ b/codex-rs/core/src/attestation.rs @@ -1,39 +1,22 @@ -use std::fmt; use std::future::Future; use std::pin::Pin; -use std::sync::Arc; - -use http::HeaderValue; pub(crate) const X_OAI_ATTESTATION_HEADER: &str = "x-oai-attestation"; -type GenerateAttestationFuture = Pin> + Send>>; -type GenerateAttestationCallback = dyn Fn() -> GenerateAttestationFuture + Send + Sync + 'static; +pub type GenerateAttestationFuture<'a> = Pin> + Send + 'a>>; -/// Session-scoped source for just-in-time attestation header values. +/// Request context that host integrations can use when deciding whether to +/// generate an attestation header value. +#[derive(Clone, Copy, Debug)] +pub struct AttestationContext { + pub uses_chatgpt_auth: bool, +} + +/// Host integration boundary for just-in-time attestation header values. /// -/// Host integrations provide the opaque string expected by the upstream -/// `x-oai-attestation` header. Core validates only that it is legal as an HTTP -/// header value before forwarding it. -#[derive(Clone)] -pub struct AttestationProvider { - generate: Arc, -} - -impl fmt::Debug for AttestationProvider { - fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - formatter.debug_struct("AttestationProvider").finish() - } -} - -impl AttestationProvider { - pub fn new(generate: impl Fn() -> GenerateAttestationFuture + Send + Sync + 'static) -> Self { - Self { - generate: Arc::new(generate), - } - } - - pub(crate) async fn generate_header(&self) -> Option { - HeaderValue::from_str(&(self.generate)().await?).ok() - } +/// Implementations own the policy for when attestation should be attempted and +/// return the opaque string expected by the upstream `x-oai-attestation` +/// header. Core only forwards valid HTTP header values returned by the host. +pub trait AttestationProvider: std::fmt::Debug + Send + Sync { + fn generate_header_value(&self, context: AttestationContext) -> GenerateAttestationFuture<'_>; } diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 9a0ec4e2ca..21b69064a7 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -105,6 +105,7 @@ use tracing::instrument; use tracing::trace; use tracing::warn; +use crate::attestation::AttestationContext; use crate::attestation::AttestationProvider; use crate::attestation::X_OAI_ATTESTATION_HEADER; use crate::client_common::Prompt; @@ -172,7 +173,7 @@ struct ModelClientState { enable_request_compression: bool, include_timing_metrics: bool, beta_features_header: Option, - attestation_provider: Option, + attestation_provider: Option>, disable_websockets: AtomicBool, cached_websocket_session: StdMutex, } @@ -317,7 +318,7 @@ impl ModelClient { enable_request_compression: bool, include_timing_metrics: bool, beta_features_header: Option, - attestation_provider: Option, + attestation_provider: Option>, ) -> Self { let model_provider = create_model_provider(provider_info, auth_manager); let codex_api_key_env_enabled = model_provider @@ -649,22 +650,18 @@ impl ModelClient { client_metadata } - async fn generate_attestation_header(&self) -> Option { - self.state - .attestation_provider - .as_ref()? - .generate_header() - .await - } - async fn generate_attestation_header_for( &self, provider: &codex_api::Provider, ) -> Option { - if !provider.uses_chatgpt_auth { - return None; - } - self.generate_attestation_header().await + self.state + .attestation_provider + .as_ref()? + .generate_header_value(AttestationContext { + uses_chatgpt_auth: provider.uses_chatgpt_auth, + }) + .await + .and_then(|value| HeaderValue::from_str(&value).ok()) } /// Builds request telemetry for unary API calls (e.g., Compact endpoint). diff --git a/codex-rs/core/src/client_tests.rs b/codex-rs/core/src/client_tests.rs index 4bfa5b8a5a..38058e2baa 100644 --- a/codex-rs/core/src/client_tests.rs +++ b/codex-rs/core/src/client_tests.rs @@ -7,7 +7,9 @@ use super::X_CODEX_PARENT_THREAD_ID_HEADER; use super::X_CODEX_TURN_METADATA_HEADER; use super::X_CODEX_WINDOW_ID_HEADER; use super::X_OPENAI_SUBAGENT_HEADER; +use crate::AttestationContext; use crate::AttestationProvider; +use crate::GenerateAttestationFuture; use codex_api::ApiError; use codex_api::ResponseEvent; use codex_app_server_protocol::AuthMode; @@ -494,8 +496,29 @@ fn auth_request_telemetry_context_tracks_attached_auth_and_retry_phase() { } fn model_client_with_counting_attestation() -> (ModelClient, Arc) { + #[derive(Debug)] + struct CountingAttestationProvider { + calls: Arc, + } + + impl AttestationProvider for CountingAttestationProvider { + fn generate_header_value( + &self, + context: AttestationContext, + ) -> GenerateAttestationFuture<'_> { + let calls = self.calls.clone(); + Box::pin(async move { + if !context.uses_chatgpt_auth { + return None; + } + + let call = calls.fetch_add(1, Ordering::Relaxed) + 1; + Some(format!("v1.header-{call}")) + }) + } + } + let attestation_calls = Arc::new(AtomicUsize::new(0)); - let calls = attestation_calls.clone(); let model_client = ModelClient::new( /*auth_manager*/ None, SessionId::new(), @@ -507,12 +530,8 @@ fn model_client_with_counting_attestation() -> (ModelClient, Arc) { /*enable_request_compression*/ false, /*include_timing_metrics*/ false, /*beta_features_header*/ None, - Some(AttestationProvider::new(move || { - let calls = calls.clone(); - Box::pin(async move { - let call = calls.fetch_add(1, Ordering::Relaxed) + 1; - Some(format!("v1.header-{call}")) - }) + Some(Arc::new(CountingAttestationProvider { + calls: attestation_calls.clone(), })), ); (model_client, attestation_calls) diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index c178b5d4e8..57ac9b4a59 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -178,7 +178,9 @@ mod tasks; mod user_shell_command; pub mod util; +pub use attestation::AttestationContext; pub use attestation::AttestationProvider; +pub use attestation::GenerateAttestationFuture; pub use client::ModelClient; pub use client::ModelClientSession; pub use client::X_CODEX_INSTALLATION_ID_HEADER; diff --git a/codex-rs/core/src/session/mod.rs b/codex-rs/core/src/session/mod.rs index 07ef550f51..d3a00b3108 100644 --- a/codex-rs/core/src/session/mod.rs +++ b/codex-rs/core/src/session/mod.rs @@ -413,7 +413,7 @@ pub(crate) struct CodexSpawnArgs { pub(crate) environment_selections: ResolvedTurnEnvironments, pub(crate) analytics_events_client: Option, pub(crate) thread_store: Arc, - pub(crate) attestation_provider: Option, + pub(crate) attestation_provider: Option>, } pub(crate) const INITIAL_SUBMIT_ID: &str = ""; diff --git a/codex-rs/core/src/session/session.rs b/codex-rs/core/src/session/session.rs index f5faccf0be..1a790314d5 100644 --- a/codex-rs/core/src/session/session.rs +++ b/codex-rs/core/src/session/session.rs @@ -370,7 +370,7 @@ impl Session { analytics_events_client: Option, thread_store: Arc, parent_rollout_thread_trace: ThreadTraceContext, - attestation_provider: Option, + attestation_provider: Option>, ) -> anyhow::Result> { debug!( "Configuring session: model={}; provider={:?}", diff --git a/codex-rs/core/src/state/service.rs b/codex-rs/core/src/state/service.rs index 4506c0054c..0dba931296 100644 --- a/codex-rs/core/src/state/service.rs +++ b/codex-rs/core/src/state/service.rs @@ -67,7 +67,7 @@ pub(crate) struct SessionServices { pub(crate) state_db: Option, pub(crate) live_thread: Option, pub(crate) thread_store: Arc, - pub(crate) attestation_provider: Option, + pub(crate) attestation_provider: Option>, /// Session-scoped model client shared across turns. pub(crate) model_client: ModelClient, pub(crate) code_mode_service: CodeModeService, diff --git a/codex-rs/core/src/thread_manager.rs b/codex-rs/core/src/thread_manager.rs index 2b8cfa8b39..d2c433d44e 100644 --- a/codex-rs/core/src/thread_manager.rs +++ b/codex-rs/core/src/thread_manager.rs @@ -250,7 +250,7 @@ pub(crate) struct ThreadManagerState { mcp_manager: Arc, skills_watcher: Arc, thread_store: Arc, - attestation_provider: Option, + attestation_provider: Option>, session_source: SessionSource, installation_id: String, analytics_events_client: Option, @@ -295,7 +295,7 @@ impl ThreadManager { thread_store: Arc, state_db: Option, installation_id: String, - attestation_provider: Option, + attestation_provider: Option>, ) -> Self { let codex_home = config.codex_home.clone(); let restriction_product = session_source.restriction_product();