refactor(attestation): move policy behind provider

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Jiaming Zhang
2026-05-06 21:43:17 -07:00
parent 92a0b57d36
commit bb55e050d7
9 changed files with 96 additions and 66 deletions

View File

@@ -60,7 +60,9 @@ use codex_app_server_protocol::ServerRequestPayload;
use codex_app_server_protocol::experimental_required_message;
use codex_arg0::Arg0DispatchPaths;
use codex_chatgpt::workspace_settings;
use codex_core::AttestationContext;
use codex_core::AttestationProvider;
use codex_core::GenerateAttestationFuture;
use codex_core::ThreadManager;
use codex_core::config::Config;
use codex_core::thread_store_from_config;
@@ -158,18 +160,45 @@ impl ExternalAuth for ExternalAuthRefreshBridge {
fn app_server_attestation_provider(
outgoing: Arc<OutgoingMessageSender>,
attestation_connection_ids: Arc<Mutex<HashSet<ConnectionId>>>,
) -> AttestationProvider {
AttestationProvider::new(move || {
let outgoing = outgoing.clone();
let attestation_connection_ids = attestation_connection_ids.clone();
Box::pin(request_attestation_header_value_with_timeout(
outgoing,
attestation_connection_ids,
ATTESTATION_GENERATE_TIMEOUT,
))
) -> Arc<dyn AttestationProvider> {
Arc::new(AppServerAttestationProvider {
outgoing,
attestation_connection_ids,
})
}
struct AppServerAttestationProvider {
outgoing: Arc<OutgoingMessageSender>,
attestation_connection_ids: Arc<Mutex<HashSet<ConnectionId>>>,
}
impl std::fmt::Debug for AppServerAttestationProvider {
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter
.debug_struct("AppServerAttestationProvider")
.finish()
}
}
impl AttestationProvider for AppServerAttestationProvider {
fn generate_header_value(&self, context: AttestationContext) -> GenerateAttestationFuture<'_> {
let outgoing = self.outgoing.clone();
let attestation_connection_ids = self.attestation_connection_ids.clone();
Box::pin(async move {
if !context.uses_chatgpt_auth {
return None;
}
request_attestation_header_value_with_timeout(
outgoing,
attestation_connection_ids,
ATTESTATION_GENERATE_TIMEOUT,
)
.await
})
}
}
async fn request_attestation_header_value_with_timeout(
outgoing: Arc<OutgoingMessageSender>,
attestation_connection_ids: Arc<Mutex<HashSet<ConnectionId>>>,

View File

@@ -1,39 +1,22 @@
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use http::HeaderValue;
pub(crate) const X_OAI_ATTESTATION_HEADER: &str = "x-oai-attestation";
type GenerateAttestationFuture = Pin<Box<dyn Future<Output = Option<String>> + Send>>;
type GenerateAttestationCallback = dyn Fn() -> GenerateAttestationFuture + Send + Sync + 'static;
pub type GenerateAttestationFuture<'a> = Pin<Box<dyn Future<Output = Option<String>> + Send + 'a>>;
/// Session-scoped source for just-in-time attestation header values.
/// Request context that host integrations can use when deciding whether to
/// generate an attestation header value.
#[derive(Clone, Copy, Debug)]
pub struct AttestationContext {
pub uses_chatgpt_auth: bool,
}
/// Host integration boundary for just-in-time attestation header values.
///
/// Host integrations provide the opaque string expected by the upstream
/// `x-oai-attestation` header. Core validates only that it is legal as an HTTP
/// header value before forwarding it.
#[derive(Clone)]
pub struct AttestationProvider {
generate: Arc<GenerateAttestationCallback>,
}
impl fmt::Debug for AttestationProvider {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.debug_struct("AttestationProvider").finish()
}
}
impl AttestationProvider {
pub fn new(generate: impl Fn() -> GenerateAttestationFuture + Send + Sync + 'static) -> Self {
Self {
generate: Arc::new(generate),
}
}
pub(crate) async fn generate_header(&self) -> Option<HeaderValue> {
HeaderValue::from_str(&(self.generate)().await?).ok()
}
/// Implementations own the policy for when attestation should be attempted and
/// return the opaque string expected by the upstream `x-oai-attestation`
/// header. Core only forwards valid HTTP header values returned by the host.
pub trait AttestationProvider: std::fmt::Debug + Send + Sync {
fn generate_header_value(&self, context: AttestationContext) -> GenerateAttestationFuture<'_>;
}

View File

@@ -105,6 +105,7 @@ use tracing::instrument;
use tracing::trace;
use tracing::warn;
use crate::attestation::AttestationContext;
use crate::attestation::AttestationProvider;
use crate::attestation::X_OAI_ATTESTATION_HEADER;
use crate::client_common::Prompt;
@@ -172,7 +173,7 @@ struct ModelClientState {
enable_request_compression: bool,
include_timing_metrics: bool,
beta_features_header: Option<String>,
attestation_provider: Option<AttestationProvider>,
attestation_provider: Option<Arc<dyn AttestationProvider>>,
disable_websockets: AtomicBool,
cached_websocket_session: StdMutex<WebsocketSession>,
}
@@ -317,7 +318,7 @@ impl ModelClient {
enable_request_compression: bool,
include_timing_metrics: bool,
beta_features_header: Option<String>,
attestation_provider: Option<AttestationProvider>,
attestation_provider: Option<Arc<dyn AttestationProvider>>,
) -> Self {
let model_provider = create_model_provider(provider_info, auth_manager);
let codex_api_key_env_enabled = model_provider
@@ -649,22 +650,18 @@ impl ModelClient {
client_metadata
}
async fn generate_attestation_header(&self) -> Option<HeaderValue> {
self.state
.attestation_provider
.as_ref()?
.generate_header()
.await
}
async fn generate_attestation_header_for(
&self,
provider: &codex_api::Provider,
) -> Option<HeaderValue> {
if !provider.uses_chatgpt_auth {
return None;
}
self.generate_attestation_header().await
self.state
.attestation_provider
.as_ref()?
.generate_header_value(AttestationContext {
uses_chatgpt_auth: provider.uses_chatgpt_auth,
})
.await
.and_then(|value| HeaderValue::from_str(&value).ok())
}
/// Builds request telemetry for unary API calls (e.g., Compact endpoint).

View File

@@ -7,7 +7,9 @@ use super::X_CODEX_PARENT_THREAD_ID_HEADER;
use super::X_CODEX_TURN_METADATA_HEADER;
use super::X_CODEX_WINDOW_ID_HEADER;
use super::X_OPENAI_SUBAGENT_HEADER;
use crate::AttestationContext;
use crate::AttestationProvider;
use crate::GenerateAttestationFuture;
use codex_api::ApiError;
use codex_api::ResponseEvent;
use codex_app_server_protocol::AuthMode;
@@ -494,8 +496,29 @@ fn auth_request_telemetry_context_tracks_attached_auth_and_retry_phase() {
}
fn model_client_with_counting_attestation() -> (ModelClient, Arc<AtomicUsize>) {
#[derive(Debug)]
struct CountingAttestationProvider {
calls: Arc<AtomicUsize>,
}
impl AttestationProvider for CountingAttestationProvider {
fn generate_header_value(
&self,
context: AttestationContext,
) -> GenerateAttestationFuture<'_> {
let calls = self.calls.clone();
Box::pin(async move {
if !context.uses_chatgpt_auth {
return None;
}
let call = calls.fetch_add(1, Ordering::Relaxed) + 1;
Some(format!("v1.header-{call}"))
})
}
}
let attestation_calls = Arc::new(AtomicUsize::new(0));
let calls = attestation_calls.clone();
let model_client = ModelClient::new(
/*auth_manager*/ None,
SessionId::new(),
@@ -507,12 +530,8 @@ fn model_client_with_counting_attestation() -> (ModelClient, Arc<AtomicUsize>) {
/*enable_request_compression*/ false,
/*include_timing_metrics*/ false,
/*beta_features_header*/ None,
Some(AttestationProvider::new(move || {
let calls = calls.clone();
Box::pin(async move {
let call = calls.fetch_add(1, Ordering::Relaxed) + 1;
Some(format!("v1.header-{call}"))
})
Some(Arc::new(CountingAttestationProvider {
calls: attestation_calls.clone(),
})),
);
(model_client, attestation_calls)

View File

@@ -178,7 +178,9 @@ mod tasks;
mod user_shell_command;
pub mod util;
pub use attestation::AttestationContext;
pub use attestation::AttestationProvider;
pub use attestation::GenerateAttestationFuture;
pub use client::ModelClient;
pub use client::ModelClientSession;
pub use client::X_CODEX_INSTALLATION_ID_HEADER;

View File

@@ -413,7 +413,7 @@ pub(crate) struct CodexSpawnArgs {
pub(crate) environment_selections: ResolvedTurnEnvironments,
pub(crate) analytics_events_client: Option<AnalyticsEventsClient>,
pub(crate) thread_store: Arc<dyn ThreadStore>,
pub(crate) attestation_provider: Option<AttestationProvider>,
pub(crate) attestation_provider: Option<Arc<dyn AttestationProvider>>,
}
pub(crate) const INITIAL_SUBMIT_ID: &str = "";

View File

@@ -370,7 +370,7 @@ impl Session {
analytics_events_client: Option<AnalyticsEventsClient>,
thread_store: Arc<dyn ThreadStore>,
parent_rollout_thread_trace: ThreadTraceContext,
attestation_provider: Option<AttestationProvider>,
attestation_provider: Option<Arc<dyn AttestationProvider>>,
) -> anyhow::Result<Arc<Self>> {
debug!(
"Configuring session: model={}; provider={:?}",

View File

@@ -67,7 +67,7 @@ pub(crate) struct SessionServices {
pub(crate) state_db: Option<StateDbHandle>,
pub(crate) live_thread: Option<LiveThread>,
pub(crate) thread_store: Arc<dyn ThreadStore>,
pub(crate) attestation_provider: Option<AttestationProvider>,
pub(crate) attestation_provider: Option<Arc<dyn AttestationProvider>>,
/// Session-scoped model client shared across turns.
pub(crate) model_client: ModelClient,
pub(crate) code_mode_service: CodeModeService,

View File

@@ -250,7 +250,7 @@ pub(crate) struct ThreadManagerState {
mcp_manager: Arc<McpManager>,
skills_watcher: Arc<SkillsWatcher>,
thread_store: Arc<dyn ThreadStore>,
attestation_provider: Option<AttestationProvider>,
attestation_provider: Option<Arc<dyn AttestationProvider>>,
session_source: SessionSource,
installation_id: String,
analytics_events_client: Option<AnalyticsEventsClient>,
@@ -295,7 +295,7 @@ impl ThreadManager {
thread_store: Arc<dyn ThreadStore>,
state_db: Option<StateDbHandle>,
installation_id: String,
attestation_provider: Option<AttestationProvider>,
attestation_provider: Option<Arc<dyn AttestationProvider>>,
) -> Self {
let codex_home = config.codex_home.clone();
let restriction_product = session_source.restriction_product();