mirror of
https://github.com/openai/codex.git
synced 2026-05-18 02:02:30 +00:00
refactor(attestation): move policy behind provider
Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
@@ -60,7 +60,9 @@ use codex_app_server_protocol::ServerRequestPayload;
|
||||
use codex_app_server_protocol::experimental_required_message;
|
||||
use codex_arg0::Arg0DispatchPaths;
|
||||
use codex_chatgpt::workspace_settings;
|
||||
use codex_core::AttestationContext;
|
||||
use codex_core::AttestationProvider;
|
||||
use codex_core::GenerateAttestationFuture;
|
||||
use codex_core::ThreadManager;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::thread_store_from_config;
|
||||
@@ -158,18 +160,45 @@ impl ExternalAuth for ExternalAuthRefreshBridge {
|
||||
fn app_server_attestation_provider(
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
attestation_connection_ids: Arc<Mutex<HashSet<ConnectionId>>>,
|
||||
) -> AttestationProvider {
|
||||
AttestationProvider::new(move || {
|
||||
let outgoing = outgoing.clone();
|
||||
let attestation_connection_ids = attestation_connection_ids.clone();
|
||||
Box::pin(request_attestation_header_value_with_timeout(
|
||||
outgoing,
|
||||
attestation_connection_ids,
|
||||
ATTESTATION_GENERATE_TIMEOUT,
|
||||
))
|
||||
) -> Arc<dyn AttestationProvider> {
|
||||
Arc::new(AppServerAttestationProvider {
|
||||
outgoing,
|
||||
attestation_connection_ids,
|
||||
})
|
||||
}
|
||||
|
||||
struct AppServerAttestationProvider {
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
attestation_connection_ids: Arc<Mutex<HashSet<ConnectionId>>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AppServerAttestationProvider {
|
||||
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
formatter
|
||||
.debug_struct("AppServerAttestationProvider")
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl AttestationProvider for AppServerAttestationProvider {
|
||||
fn generate_header_value(&self, context: AttestationContext) -> GenerateAttestationFuture<'_> {
|
||||
let outgoing = self.outgoing.clone();
|
||||
let attestation_connection_ids = self.attestation_connection_ids.clone();
|
||||
Box::pin(async move {
|
||||
if !context.uses_chatgpt_auth {
|
||||
return None;
|
||||
}
|
||||
|
||||
request_attestation_header_value_with_timeout(
|
||||
outgoing,
|
||||
attestation_connection_ids,
|
||||
ATTESTATION_GENERATE_TIMEOUT,
|
||||
)
|
||||
.await
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn request_attestation_header_value_with_timeout(
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
attestation_connection_ids: Arc<Mutex<HashSet<ConnectionId>>>,
|
||||
|
||||
@@ -1,39 +1,22 @@
|
||||
use std::fmt;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use http::HeaderValue;
|
||||
|
||||
pub(crate) const X_OAI_ATTESTATION_HEADER: &str = "x-oai-attestation";
|
||||
|
||||
type GenerateAttestationFuture = Pin<Box<dyn Future<Output = Option<String>> + Send>>;
|
||||
type GenerateAttestationCallback = dyn Fn() -> GenerateAttestationFuture + Send + Sync + 'static;
|
||||
pub type GenerateAttestationFuture<'a> = Pin<Box<dyn Future<Output = Option<String>> + Send + 'a>>;
|
||||
|
||||
/// Session-scoped source for just-in-time attestation header values.
|
||||
/// Request context that host integrations can use when deciding whether to
|
||||
/// generate an attestation header value.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct AttestationContext {
|
||||
pub uses_chatgpt_auth: bool,
|
||||
}
|
||||
|
||||
/// Host integration boundary for just-in-time attestation header values.
|
||||
///
|
||||
/// Host integrations provide the opaque string expected by the upstream
|
||||
/// `x-oai-attestation` header. Core validates only that it is legal as an HTTP
|
||||
/// header value before forwarding it.
|
||||
#[derive(Clone)]
|
||||
pub struct AttestationProvider {
|
||||
generate: Arc<GenerateAttestationCallback>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for AttestationProvider {
|
||||
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
formatter.debug_struct("AttestationProvider").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl AttestationProvider {
|
||||
pub fn new(generate: impl Fn() -> GenerateAttestationFuture + Send + Sync + 'static) -> Self {
|
||||
Self {
|
||||
generate: Arc::new(generate),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn generate_header(&self) -> Option<HeaderValue> {
|
||||
HeaderValue::from_str(&(self.generate)().await?).ok()
|
||||
}
|
||||
/// Implementations own the policy for when attestation should be attempted and
|
||||
/// return the opaque string expected by the upstream `x-oai-attestation`
|
||||
/// header. Core only forwards valid HTTP header values returned by the host.
|
||||
pub trait AttestationProvider: std::fmt::Debug + Send + Sync {
|
||||
fn generate_header_value(&self, context: AttestationContext) -> GenerateAttestationFuture<'_>;
|
||||
}
|
||||
|
||||
@@ -105,6 +105,7 @@ use tracing::instrument;
|
||||
use tracing::trace;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::attestation::AttestationContext;
|
||||
use crate::attestation::AttestationProvider;
|
||||
use crate::attestation::X_OAI_ATTESTATION_HEADER;
|
||||
use crate::client_common::Prompt;
|
||||
@@ -172,7 +173,7 @@ struct ModelClientState {
|
||||
enable_request_compression: bool,
|
||||
include_timing_metrics: bool,
|
||||
beta_features_header: Option<String>,
|
||||
attestation_provider: Option<AttestationProvider>,
|
||||
attestation_provider: Option<Arc<dyn AttestationProvider>>,
|
||||
disable_websockets: AtomicBool,
|
||||
cached_websocket_session: StdMutex<WebsocketSession>,
|
||||
}
|
||||
@@ -317,7 +318,7 @@ impl ModelClient {
|
||||
enable_request_compression: bool,
|
||||
include_timing_metrics: bool,
|
||||
beta_features_header: Option<String>,
|
||||
attestation_provider: Option<AttestationProvider>,
|
||||
attestation_provider: Option<Arc<dyn AttestationProvider>>,
|
||||
) -> Self {
|
||||
let model_provider = create_model_provider(provider_info, auth_manager);
|
||||
let codex_api_key_env_enabled = model_provider
|
||||
@@ -649,22 +650,18 @@ impl ModelClient {
|
||||
client_metadata
|
||||
}
|
||||
|
||||
async fn generate_attestation_header(&self) -> Option<HeaderValue> {
|
||||
self.state
|
||||
.attestation_provider
|
||||
.as_ref()?
|
||||
.generate_header()
|
||||
.await
|
||||
}
|
||||
|
||||
async fn generate_attestation_header_for(
|
||||
&self,
|
||||
provider: &codex_api::Provider,
|
||||
) -> Option<HeaderValue> {
|
||||
if !provider.uses_chatgpt_auth {
|
||||
return None;
|
||||
}
|
||||
self.generate_attestation_header().await
|
||||
self.state
|
||||
.attestation_provider
|
||||
.as_ref()?
|
||||
.generate_header_value(AttestationContext {
|
||||
uses_chatgpt_auth: provider.uses_chatgpt_auth,
|
||||
})
|
||||
.await
|
||||
.and_then(|value| HeaderValue::from_str(&value).ok())
|
||||
}
|
||||
|
||||
/// Builds request telemetry for unary API calls (e.g., Compact endpoint).
|
||||
|
||||
@@ -7,7 +7,9 @@ use super::X_CODEX_PARENT_THREAD_ID_HEADER;
|
||||
use super::X_CODEX_TURN_METADATA_HEADER;
|
||||
use super::X_CODEX_WINDOW_ID_HEADER;
|
||||
use super::X_OPENAI_SUBAGENT_HEADER;
|
||||
use crate::AttestationContext;
|
||||
use crate::AttestationProvider;
|
||||
use crate::GenerateAttestationFuture;
|
||||
use codex_api::ApiError;
|
||||
use codex_api::ResponseEvent;
|
||||
use codex_app_server_protocol::AuthMode;
|
||||
@@ -494,8 +496,29 @@ fn auth_request_telemetry_context_tracks_attached_auth_and_retry_phase() {
|
||||
}
|
||||
|
||||
fn model_client_with_counting_attestation() -> (ModelClient, Arc<AtomicUsize>) {
|
||||
#[derive(Debug)]
|
||||
struct CountingAttestationProvider {
|
||||
calls: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl AttestationProvider for CountingAttestationProvider {
|
||||
fn generate_header_value(
|
||||
&self,
|
||||
context: AttestationContext,
|
||||
) -> GenerateAttestationFuture<'_> {
|
||||
let calls = self.calls.clone();
|
||||
Box::pin(async move {
|
||||
if !context.uses_chatgpt_auth {
|
||||
return None;
|
||||
}
|
||||
|
||||
let call = calls.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
Some(format!("v1.header-{call}"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
let attestation_calls = Arc::new(AtomicUsize::new(0));
|
||||
let calls = attestation_calls.clone();
|
||||
let model_client = ModelClient::new(
|
||||
/*auth_manager*/ None,
|
||||
SessionId::new(),
|
||||
@@ -507,12 +530,8 @@ fn model_client_with_counting_attestation() -> (ModelClient, Arc<AtomicUsize>) {
|
||||
/*enable_request_compression*/ false,
|
||||
/*include_timing_metrics*/ false,
|
||||
/*beta_features_header*/ None,
|
||||
Some(AttestationProvider::new(move || {
|
||||
let calls = calls.clone();
|
||||
Box::pin(async move {
|
||||
let call = calls.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
Some(format!("v1.header-{call}"))
|
||||
})
|
||||
Some(Arc::new(CountingAttestationProvider {
|
||||
calls: attestation_calls.clone(),
|
||||
})),
|
||||
);
|
||||
(model_client, attestation_calls)
|
||||
|
||||
@@ -178,7 +178,9 @@ mod tasks;
|
||||
mod user_shell_command;
|
||||
pub mod util;
|
||||
|
||||
pub use attestation::AttestationContext;
|
||||
pub use attestation::AttestationProvider;
|
||||
pub use attestation::GenerateAttestationFuture;
|
||||
pub use client::ModelClient;
|
||||
pub use client::ModelClientSession;
|
||||
pub use client::X_CODEX_INSTALLATION_ID_HEADER;
|
||||
|
||||
@@ -413,7 +413,7 @@ pub(crate) struct CodexSpawnArgs {
|
||||
pub(crate) environment_selections: ResolvedTurnEnvironments,
|
||||
pub(crate) analytics_events_client: Option<AnalyticsEventsClient>,
|
||||
pub(crate) thread_store: Arc<dyn ThreadStore>,
|
||||
pub(crate) attestation_provider: Option<AttestationProvider>,
|
||||
pub(crate) attestation_provider: Option<Arc<dyn AttestationProvider>>,
|
||||
}
|
||||
|
||||
pub(crate) const INITIAL_SUBMIT_ID: &str = "";
|
||||
|
||||
@@ -370,7 +370,7 @@ impl Session {
|
||||
analytics_events_client: Option<AnalyticsEventsClient>,
|
||||
thread_store: Arc<dyn ThreadStore>,
|
||||
parent_rollout_thread_trace: ThreadTraceContext,
|
||||
attestation_provider: Option<AttestationProvider>,
|
||||
attestation_provider: Option<Arc<dyn AttestationProvider>>,
|
||||
) -> anyhow::Result<Arc<Self>> {
|
||||
debug!(
|
||||
"Configuring session: model={}; provider={:?}",
|
||||
|
||||
@@ -67,7 +67,7 @@ pub(crate) struct SessionServices {
|
||||
pub(crate) state_db: Option<StateDbHandle>,
|
||||
pub(crate) live_thread: Option<LiveThread>,
|
||||
pub(crate) thread_store: Arc<dyn ThreadStore>,
|
||||
pub(crate) attestation_provider: Option<AttestationProvider>,
|
||||
pub(crate) attestation_provider: Option<Arc<dyn AttestationProvider>>,
|
||||
/// Session-scoped model client shared across turns.
|
||||
pub(crate) model_client: ModelClient,
|
||||
pub(crate) code_mode_service: CodeModeService,
|
||||
|
||||
@@ -250,7 +250,7 @@ pub(crate) struct ThreadManagerState {
|
||||
mcp_manager: Arc<McpManager>,
|
||||
skills_watcher: Arc<SkillsWatcher>,
|
||||
thread_store: Arc<dyn ThreadStore>,
|
||||
attestation_provider: Option<AttestationProvider>,
|
||||
attestation_provider: Option<Arc<dyn AttestationProvider>>,
|
||||
session_source: SessionSource,
|
||||
installation_id: String,
|
||||
analytics_events_client: Option<AnalyticsEventsClient>,
|
||||
@@ -295,7 +295,7 @@ impl ThreadManager {
|
||||
thread_store: Arc<dyn ThreadStore>,
|
||||
state_db: Option<StateDbHandle>,
|
||||
installation_id: String,
|
||||
attestation_provider: Option<AttestationProvider>,
|
||||
attestation_provider: Option<Arc<dyn AttestationProvider>>,
|
||||
) -> Self {
|
||||
let codex_home = config.codex_home.clone();
|
||||
let restriction_product = session_source.restriction_product();
|
||||
|
||||
Reference in New Issue
Block a user