diff --git a/codex-rs/app-server/src/attestation.rs b/codex-rs/app-server/src/attestation.rs new file mode 100644 index 0000000000..adfb415bab --- /dev/null +++ b/codex-rs/app-server/src/attestation.rs @@ -0,0 +1,204 @@ +use std::sync::Arc; + +use axum::http::HeaderValue; +use codex_app_server_protocol::AttestationGenerateParams; +use codex_app_server_protocol::AttestationGenerateResponse; +use codex_app_server_protocol::ServerRequestPayload; +use codex_core::AttestationContext; +use codex_core::AttestationProvider; +use codex_core::GenerateAttestationFuture; +use serde::Serialize; +use tokio::time::Duration; +use tokio::time::timeout; +use tracing::warn; + +use crate::outgoing_message::OutgoingMessageSender; +use crate::thread_state::ThreadStateManager; + +const ATTESTATION_GENERATE_TIMEOUT: Duration = Duration::from_millis(100); + +pub(crate) fn app_server_attestation_provider( + outgoing: Arc, + thread_state_manager: ThreadStateManager, +) -> Arc { + Arc::new(AppServerAttestationProvider { + outgoing, + thread_state_manager, + }) +} + +struct AppServerAttestationProvider { + outgoing: Arc, + thread_state_manager: ThreadStateManager, +} + +impl std::fmt::Debug for AppServerAttestationProvider { + fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter + .debug_struct("AppServerAttestationProvider") + .finish() + } +} + +impl AttestationProvider for AppServerAttestationProvider { + fn header_for_request(&self, context: AttestationContext) -> GenerateAttestationFuture<'_> { + let outgoing = self.outgoing.clone(); + let thread_state_manager = self.thread_state_manager.clone(); + Box::pin(async move { + if !context.uses_chatgpt_auth { + return None; + } + + request_attestation_header_value_with_timeout( + outgoing, + thread_state_manager, + ATTESTATION_GENERATE_TIMEOUT, + ) + .await + .and_then(|value| HeaderValue::from_bytes(value.as_bytes()).ok()) + }) + } +} + +async fn request_attestation_header_value_with_timeout( + outgoing: Arc, + thread_state_manager: ThreadStateManager, + timeout_duration: Duration, +) -> Option { + let connection_id = thread_state_manager + .first_attestation_capable_connection() + .await?; + + let connection_ids = [connection_id]; + let (request_id, rx) = outgoing + .send_request_to_connections( + Some(&connection_ids), + ServerRequestPayload::AttestationGenerate(AttestationGenerateParams {}), + /*thread_id*/ None, + ) + .await; + + let result = match timeout(timeout_duration, rx).await { + Ok(Ok(Ok(result))) => result, + Ok(Ok(Err(err))) => { + warn!( + code = err.code, + message = %err.message, + "attestation generation request failed" + ); + return app_server_attestation_header_value( + AppServerAttestationStatus::RequestFailed, + None, + ); + } + Ok(Err(err)) => { + warn!("attestation generation request canceled: {err}"); + return app_server_attestation_header_value( + AppServerAttestationStatus::RequestCanceled, + None, + ); + } + Err(_) => { + let _canceled = outgoing.cancel_request(&request_id).await; + warn!( + timeout_seconds = timeout_duration.as_secs(), + "attestation generation request timed out" + ); + return app_server_attestation_header_value(AppServerAttestationStatus::Timeout, None); + } + }; + + match serde_json::from_value::(result) { + Ok(response) => app_server_attestation_header_value( + AppServerAttestationStatus::Ok, + Some(&response.header_value), + ), + Err(err) => { + warn!("failed to deserialize attestation generation response: {err}"); + app_server_attestation_header_value(AppServerAttestationStatus::MalformedResponse, None) + } + } +} + +#[derive(Clone, Copy)] +enum AppServerAttestationStatus { + Ok, + Timeout, + RequestFailed, + RequestCanceled, + MalformedResponse, +} + +impl AppServerAttestationStatus { + const fn code(self) -> u8 { + match self { + Self::Ok => 0, + Self::Timeout => 1, + Self::RequestFailed => 2, + Self::RequestCanceled => 3, + Self::MalformedResponse => 4, + } + } +} + +#[derive(Serialize)] +struct AppServerAttestationEnvelope<'a> { + v: u8, + s: u8, + #[serde(skip_serializing_if = "Option::is_none")] + t: Option<&'a str>, +} + +fn app_server_attestation_header_value( + status: AppServerAttestationStatus, + token: Option<&str>, +) -> Option { + serde_json::to_string(&AppServerAttestationEnvelope { + v: 1, + s: status.code(), + t: token, + }) + .map_err(|err| warn!("failed to serialize app-server attestation envelope: {err}")) + .ok() +} + +#[cfg(test)] +mod tests { + use super::AppServerAttestationStatus; + use super::app_server_attestation_header_value; + use pretty_assertions::assert_eq; + + #[test] + fn app_server_attestation_header_value_wraps_opaque_client_payloads() { + assert_eq!( + app_server_attestation_header_value( + AppServerAttestationStatus::Ok, + Some("v1.opaque-client-payload"), + ), + Some(r#"{"v":1,"s":0,"t":"v1.opaque-client-payload"}"#.to_string()) + ); + } + + #[test] + fn app_server_attestation_header_value_reports_app_server_failures() { + assert_eq!( + app_server_attestation_header_value(AppServerAttestationStatus::Timeout, None), + Some(r#"{"v":1,"s":1}"#.to_string()) + ); + assert_eq!( + app_server_attestation_header_value(AppServerAttestationStatus::RequestFailed, None), + Some(r#"{"v":1,"s":2}"#.to_string()) + ); + assert_eq!( + app_server_attestation_header_value(AppServerAttestationStatus::RequestCanceled, None), + Some(r#"{"v":1,"s":3}"#.to_string()) + ); + assert_eq!( + app_server_attestation_header_value( + AppServerAttestationStatus::MalformedResponse, + None + ), + Some(r#"{"v":1,"s":4}"#.to_string()) + ); + } +} diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index 9bb7937e19..dd63971705 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -74,6 +74,7 @@ use tracing_subscriber::util::SubscriberInitExt; mod analytics_utils; mod app_server_tracing; +mod attestation; mod bespoke_event_handling; mod command_exec; mod config; diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index 68aca7ee75..ea89f0b036 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use std::sync::OnceLock; use std::sync::atomic::AtomicBool; +use crate::attestation::app_server_attestation_provider; use crate::config_manager::ConfigManager; use crate::connection_rpc_gate::ConnectionRpcGate; use crate::error_code::invalid_request; @@ -39,11 +40,8 @@ use crate::thread_state::ThreadStateManager; use crate::transport::AppServerTransport; use crate::transport::RemoteControlHandle; use async_trait::async_trait; -use axum::http::HeaderValue; use codex_analytics::AnalyticsEventsClient; use codex_analytics::AppServerRpcTransport; -use codex_app_server_protocol::AttestationGenerateParams; -use codex_app_server_protocol::AttestationGenerateResponse; use codex_app_server_protocol::AuthMode as LoginAuthMode; use codex_app_server_protocol::ChatgptAuthTokensRefreshParams; use codex_app_server_protocol::ChatgptAuthTokensRefreshReason; @@ -62,9 +60,6 @@ use codex_app_server_protocol::ServerRequestPayload; use codex_app_server_protocol::experimental_required_message; use codex_arg0::Arg0DispatchPaths; use codex_chatgpt::workspace_settings; -use codex_core::AttestationContext; -use codex_core::AttestationProvider; -use codex_core::GenerateAttestationFuture; use codex_core::ThreadManager; use codex_core::config::Config; use codex_core::thread_store_from_config; @@ -80,7 +75,6 @@ use codex_protocol::protocol::SessionSource; use codex_protocol::protocol::W3cTraceContext; use codex_rollout::StateDbHandle; use codex_state::log_db::LogDbLayer; -use serde::Serialize; use tokio::sync::Mutex; use tokio::sync::Semaphore; use tokio::sync::broadcast; @@ -88,52 +82,9 @@ use tokio::sync::watch; use tokio::time::Duration; use tokio::time::timeout; use tracing::Instrument; -use tracing::warn; -const ATTESTATION_GENERATE_TIMEOUT: Duration = Duration::from_millis(100); const EXTERNAL_AUTH_REFRESH_TIMEOUT: Duration = Duration::from_secs(10); -#[derive(Clone, Copy)] -enum AppServerAttestationStatus { - Ok, - Timeout, - RequestFailed, - RequestCanceled, - MalformedResponse, -} - -impl AppServerAttestationStatus { - const fn code(self) -> u8 { - match self { - Self::Ok => 0, - Self::Timeout => 1, - Self::RequestFailed => 2, - Self::RequestCanceled => 3, - Self::MalformedResponse => 4, - } - } -} - -#[derive(Serialize)] -struct AppServerAttestationEnvelope<'a> { - v: u8, - s: u8, - #[serde(skip_serializing_if = "Option::is_none")] - t: Option<&'a str>, -} - -fn app_server_attestation_header_value( - status: AppServerAttestationStatus, - token: Option<&str>, -) -> String { - serde_json::to_string(&AppServerAttestationEnvelope { - v: 1, - s: status.code(), - t: token, - }) - .expect("app-server attestation envelope should serialize") -} - #[derive(Clone)] struct ExternalAuthRefreshBridge { outgoing: Arc, @@ -202,115 +153,6 @@ impl ExternalAuth for ExternalAuthRefreshBridge { } } -fn app_server_attestation_provider( - outgoing: Arc, - thread_state_manager: ThreadStateManager, -) -> Arc { - Arc::new(AppServerAttestationProvider { - outgoing, - thread_state_manager, - }) -} - -struct AppServerAttestationProvider { - outgoing: Arc, - thread_state_manager: ThreadStateManager, -} - -impl std::fmt::Debug for AppServerAttestationProvider { - fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter - .debug_struct("AppServerAttestationProvider") - .finish() - } -} - -impl AttestationProvider for AppServerAttestationProvider { - fn header_for_request(&self, context: AttestationContext) -> GenerateAttestationFuture<'_> { - let outgoing = self.outgoing.clone(); - let thread_state_manager = self.thread_state_manager.clone(); - Box::pin(async move { - if !context.uses_chatgpt_auth { - return None; - } - - request_attestation_header_value_with_timeout( - outgoing, - thread_state_manager, - ATTESTATION_GENERATE_TIMEOUT, - ) - .await - .and_then(|value| HeaderValue::from_bytes(value.as_bytes()).ok()) - }) - } -} - -async fn request_attestation_header_value_with_timeout( - outgoing: Arc, - thread_state_manager: ThreadStateManager, - timeout_duration: Duration, -) -> Option { - let connection_id = thread_state_manager - .first_attestation_capable_connection() - .await?; - - let connection_ids = [connection_id]; - let (request_id, rx) = outgoing - .send_request_to_connections( - Some(&connection_ids), - ServerRequestPayload::AttestationGenerate(AttestationGenerateParams {}), - /*thread_id*/ None, - ) - .await; - - let result = match timeout(timeout_duration, rx).await { - Ok(Ok(Ok(result))) => result, - Ok(Ok(Err(err))) => { - warn!( - code = err.code, - message = %err.message, - "attestation generation request failed" - ); - return Some(app_server_attestation_header_value( - AppServerAttestationStatus::RequestFailed, - None, - )); - } - Ok(Err(err)) => { - warn!("attestation generation request canceled: {err}"); - return Some(app_server_attestation_header_value( - AppServerAttestationStatus::RequestCanceled, - None, - )); - } - Err(_) => { - let _canceled = outgoing.cancel_request(&request_id).await; - warn!( - timeout_seconds = timeout_duration.as_secs(), - "attestation generation request timed out" - ); - return Some(app_server_attestation_header_value( - AppServerAttestationStatus::Timeout, - None, - )); - } - }; - - match serde_json::from_value::(result) { - Ok(response) => Some(app_server_attestation_header_value( - AppServerAttestationStatus::Ok, - Some(&response.header_value), - )), - Err(err) => { - warn!("failed to deserialize attestation generation response: {err}"); - Some(app_server_attestation_header_value( - AppServerAttestationStatus::MalformedResponse, - None, - )) - } - } -} - pub(crate) struct MessageProcessor { outgoing: Arc, account_processor: AccountRequestProcessor, @@ -1456,10 +1298,6 @@ impl MessageProcessor { } } -#[cfg(test)] -#[path = "message_processor_attestation_tests.rs"] -mod message_processor_attestation_tests; - #[cfg(test)] #[path = "message_processor_tracing_tests.rs"] mod message_processor_tracing_tests; diff --git a/codex-rs/app-server/src/message_processor_attestation_tests.rs b/codex-rs/app-server/src/message_processor_attestation_tests.rs deleted file mode 100644 index 86b0b706ae..0000000000 --- a/codex-rs/app-server/src/message_processor_attestation_tests.rs +++ /dev/null @@ -1,34 +0,0 @@ -use super::AppServerAttestationStatus; -use super::app_server_attestation_header_value; -use pretty_assertions::assert_eq; - -#[test] -fn app_server_attestation_header_value_wraps_opaque_client_payloads() { - assert_eq!( - app_server_attestation_header_value( - AppServerAttestationStatus::Ok, - Some("v1.opaque-client-payload"), - ), - r#"{"v":1,"s":0,"t":"v1.opaque-client-payload"}"# - ); -} - -#[test] -fn app_server_attestation_header_value_reports_app_server_failures() { - assert_eq!( - app_server_attestation_header_value(AppServerAttestationStatus::Timeout, None), - r#"{"v":1,"s":1}"# - ); - assert_eq!( - app_server_attestation_header_value(AppServerAttestationStatus::RequestFailed, None), - r#"{"v":1,"s":2}"# - ); - assert_eq!( - app_server_attestation_header_value(AppServerAttestationStatus::RequestCanceled, None), - r#"{"v":1,"s":3}"# - ); - assert_eq!( - app_server_attestation_header_value(AppServerAttestationStatus::MalformedResponse, None), - r#"{"v":1,"s":4}"# - ); -}