codex: move app-server attestation into module

This commit is contained in:
Jiaming Zhang
2026-05-07 09:14:43 -07:00
parent 8ee7bf6abc
commit e77b359ade
4 changed files with 206 additions and 197 deletions

View File

@@ -0,0 +1,204 @@
use std::sync::Arc;
use axum::http::HeaderValue;
use codex_app_server_protocol::AttestationGenerateParams;
use codex_app_server_protocol::AttestationGenerateResponse;
use codex_app_server_protocol::ServerRequestPayload;
use codex_core::AttestationContext;
use codex_core::AttestationProvider;
use codex_core::GenerateAttestationFuture;
use serde::Serialize;
use tokio::time::Duration;
use tokio::time::timeout;
use tracing::warn;
use crate::outgoing_message::OutgoingMessageSender;
use crate::thread_state::ThreadStateManager;
const ATTESTATION_GENERATE_TIMEOUT: Duration = Duration::from_millis(100);
pub(crate) fn app_server_attestation_provider(
outgoing: Arc<OutgoingMessageSender>,
thread_state_manager: ThreadStateManager,
) -> Arc<dyn AttestationProvider> {
Arc::new(AppServerAttestationProvider {
outgoing,
thread_state_manager,
})
}
struct AppServerAttestationProvider {
outgoing: Arc<OutgoingMessageSender>,
thread_state_manager: ThreadStateManager,
}
impl std::fmt::Debug for AppServerAttestationProvider {
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter
.debug_struct("AppServerAttestationProvider")
.finish()
}
}
impl AttestationProvider for AppServerAttestationProvider {
fn header_for_request(&self, context: AttestationContext) -> GenerateAttestationFuture<'_> {
let outgoing = self.outgoing.clone();
let thread_state_manager = self.thread_state_manager.clone();
Box::pin(async move {
if !context.uses_chatgpt_auth {
return None;
}
request_attestation_header_value_with_timeout(
outgoing,
thread_state_manager,
ATTESTATION_GENERATE_TIMEOUT,
)
.await
.and_then(|value| HeaderValue::from_bytes(value.as_bytes()).ok())
})
}
}
async fn request_attestation_header_value_with_timeout(
outgoing: Arc<OutgoingMessageSender>,
thread_state_manager: ThreadStateManager,
timeout_duration: Duration,
) -> Option<String> {
let connection_id = thread_state_manager
.first_attestation_capable_connection()
.await?;
let connection_ids = [connection_id];
let (request_id, rx) = outgoing
.send_request_to_connections(
Some(&connection_ids),
ServerRequestPayload::AttestationGenerate(AttestationGenerateParams {}),
/*thread_id*/ None,
)
.await;
let result = match timeout(timeout_duration, rx).await {
Ok(Ok(Ok(result))) => result,
Ok(Ok(Err(err))) => {
warn!(
code = err.code,
message = %err.message,
"attestation generation request failed"
);
return app_server_attestation_header_value(
AppServerAttestationStatus::RequestFailed,
None,
);
}
Ok(Err(err)) => {
warn!("attestation generation request canceled: {err}");
return app_server_attestation_header_value(
AppServerAttestationStatus::RequestCanceled,
None,
);
}
Err(_) => {
let _canceled = outgoing.cancel_request(&request_id).await;
warn!(
timeout_seconds = timeout_duration.as_secs(),
"attestation generation request timed out"
);
return app_server_attestation_header_value(AppServerAttestationStatus::Timeout, None);
}
};
match serde_json::from_value::<AttestationGenerateResponse>(result) {
Ok(response) => app_server_attestation_header_value(
AppServerAttestationStatus::Ok,
Some(&response.header_value),
),
Err(err) => {
warn!("failed to deserialize attestation generation response: {err}");
app_server_attestation_header_value(AppServerAttestationStatus::MalformedResponse, None)
}
}
}
#[derive(Clone, Copy)]
enum AppServerAttestationStatus {
Ok,
Timeout,
RequestFailed,
RequestCanceled,
MalformedResponse,
}
impl AppServerAttestationStatus {
const fn code(self) -> u8 {
match self {
Self::Ok => 0,
Self::Timeout => 1,
Self::RequestFailed => 2,
Self::RequestCanceled => 3,
Self::MalformedResponse => 4,
}
}
}
#[derive(Serialize)]
struct AppServerAttestationEnvelope<'a> {
v: u8,
s: u8,
#[serde(skip_serializing_if = "Option::is_none")]
t: Option<&'a str>,
}
fn app_server_attestation_header_value(
status: AppServerAttestationStatus,
token: Option<&str>,
) -> Option<String> {
serde_json::to_string(&AppServerAttestationEnvelope {
v: 1,
s: status.code(),
t: token,
})
.map_err(|err| warn!("failed to serialize app-server attestation envelope: {err}"))
.ok()
}
#[cfg(test)]
mod tests {
use super::AppServerAttestationStatus;
use super::app_server_attestation_header_value;
use pretty_assertions::assert_eq;
#[test]
fn app_server_attestation_header_value_wraps_opaque_client_payloads() {
assert_eq!(
app_server_attestation_header_value(
AppServerAttestationStatus::Ok,
Some("v1.opaque-client-payload"),
),
Some(r#"{"v":1,"s":0,"t":"v1.opaque-client-payload"}"#.to_string())
);
}
#[test]
fn app_server_attestation_header_value_reports_app_server_failures() {
assert_eq!(
app_server_attestation_header_value(AppServerAttestationStatus::Timeout, None),
Some(r#"{"v":1,"s":1}"#.to_string())
);
assert_eq!(
app_server_attestation_header_value(AppServerAttestationStatus::RequestFailed, None),
Some(r#"{"v":1,"s":2}"#.to_string())
);
assert_eq!(
app_server_attestation_header_value(AppServerAttestationStatus::RequestCanceled, None),
Some(r#"{"v":1,"s":3}"#.to_string())
);
assert_eq!(
app_server_attestation_header_value(
AppServerAttestationStatus::MalformedResponse,
None
),
Some(r#"{"v":1,"s":4}"#.to_string())
);
}
}

View File

@@ -74,6 +74,7 @@ use tracing_subscriber::util::SubscriberInitExt;
mod analytics_utils;
mod app_server_tracing;
mod attestation;
mod bespoke_event_handling;
mod command_exec;
mod config;

View File

@@ -4,6 +4,7 @@ use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::AtomicBool;
use crate::attestation::app_server_attestation_provider;
use crate::config_manager::ConfigManager;
use crate::connection_rpc_gate::ConnectionRpcGate;
use crate::error_code::invalid_request;
@@ -39,11 +40,8 @@ use crate::thread_state::ThreadStateManager;
use crate::transport::AppServerTransport;
use crate::transport::RemoteControlHandle;
use async_trait::async_trait;
use axum::http::HeaderValue;
use codex_analytics::AnalyticsEventsClient;
use codex_analytics::AppServerRpcTransport;
use codex_app_server_protocol::AttestationGenerateParams;
use codex_app_server_protocol::AttestationGenerateResponse;
use codex_app_server_protocol::AuthMode as LoginAuthMode;
use codex_app_server_protocol::ChatgptAuthTokensRefreshParams;
use codex_app_server_protocol::ChatgptAuthTokensRefreshReason;
@@ -62,9 +60,6 @@ use codex_app_server_protocol::ServerRequestPayload;
use codex_app_server_protocol::experimental_required_message;
use codex_arg0::Arg0DispatchPaths;
use codex_chatgpt::workspace_settings;
use codex_core::AttestationContext;
use codex_core::AttestationProvider;
use codex_core::GenerateAttestationFuture;
use codex_core::ThreadManager;
use codex_core::config::Config;
use codex_core::thread_store_from_config;
@@ -80,7 +75,6 @@ use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::W3cTraceContext;
use codex_rollout::StateDbHandle;
use codex_state::log_db::LogDbLayer;
use serde::Serialize;
use tokio::sync::Mutex;
use tokio::sync::Semaphore;
use tokio::sync::broadcast;
@@ -88,52 +82,9 @@ use tokio::sync::watch;
use tokio::time::Duration;
use tokio::time::timeout;
use tracing::Instrument;
use tracing::warn;
const ATTESTATION_GENERATE_TIMEOUT: Duration = Duration::from_millis(100);
const EXTERNAL_AUTH_REFRESH_TIMEOUT: Duration = Duration::from_secs(10);
#[derive(Clone, Copy)]
enum AppServerAttestationStatus {
Ok,
Timeout,
RequestFailed,
RequestCanceled,
MalformedResponse,
}
impl AppServerAttestationStatus {
const fn code(self) -> u8 {
match self {
Self::Ok => 0,
Self::Timeout => 1,
Self::RequestFailed => 2,
Self::RequestCanceled => 3,
Self::MalformedResponse => 4,
}
}
}
#[derive(Serialize)]
struct AppServerAttestationEnvelope<'a> {
v: u8,
s: u8,
#[serde(skip_serializing_if = "Option::is_none")]
t: Option<&'a str>,
}
fn app_server_attestation_header_value(
status: AppServerAttestationStatus,
token: Option<&str>,
) -> String {
serde_json::to_string(&AppServerAttestationEnvelope {
v: 1,
s: status.code(),
t: token,
})
.expect("app-server attestation envelope should serialize")
}
#[derive(Clone)]
struct ExternalAuthRefreshBridge {
outgoing: Arc<OutgoingMessageSender>,
@@ -202,115 +153,6 @@ impl ExternalAuth for ExternalAuthRefreshBridge {
}
}
fn app_server_attestation_provider(
outgoing: Arc<OutgoingMessageSender>,
thread_state_manager: ThreadStateManager,
) -> Arc<dyn AttestationProvider> {
Arc::new(AppServerAttestationProvider {
outgoing,
thread_state_manager,
})
}
struct AppServerAttestationProvider {
outgoing: Arc<OutgoingMessageSender>,
thread_state_manager: ThreadStateManager,
}
impl std::fmt::Debug for AppServerAttestationProvider {
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter
.debug_struct("AppServerAttestationProvider")
.finish()
}
}
impl AttestationProvider for AppServerAttestationProvider {
fn header_for_request(&self, context: AttestationContext) -> GenerateAttestationFuture<'_> {
let outgoing = self.outgoing.clone();
let thread_state_manager = self.thread_state_manager.clone();
Box::pin(async move {
if !context.uses_chatgpt_auth {
return None;
}
request_attestation_header_value_with_timeout(
outgoing,
thread_state_manager,
ATTESTATION_GENERATE_TIMEOUT,
)
.await
.and_then(|value| HeaderValue::from_bytes(value.as_bytes()).ok())
})
}
}
async fn request_attestation_header_value_with_timeout(
outgoing: Arc<OutgoingMessageSender>,
thread_state_manager: ThreadStateManager,
timeout_duration: Duration,
) -> Option<String> {
let connection_id = thread_state_manager
.first_attestation_capable_connection()
.await?;
let connection_ids = [connection_id];
let (request_id, rx) = outgoing
.send_request_to_connections(
Some(&connection_ids),
ServerRequestPayload::AttestationGenerate(AttestationGenerateParams {}),
/*thread_id*/ None,
)
.await;
let result = match timeout(timeout_duration, rx).await {
Ok(Ok(Ok(result))) => result,
Ok(Ok(Err(err))) => {
warn!(
code = err.code,
message = %err.message,
"attestation generation request failed"
);
return Some(app_server_attestation_header_value(
AppServerAttestationStatus::RequestFailed,
None,
));
}
Ok(Err(err)) => {
warn!("attestation generation request canceled: {err}");
return Some(app_server_attestation_header_value(
AppServerAttestationStatus::RequestCanceled,
None,
));
}
Err(_) => {
let _canceled = outgoing.cancel_request(&request_id).await;
warn!(
timeout_seconds = timeout_duration.as_secs(),
"attestation generation request timed out"
);
return Some(app_server_attestation_header_value(
AppServerAttestationStatus::Timeout,
None,
));
}
};
match serde_json::from_value::<AttestationGenerateResponse>(result) {
Ok(response) => Some(app_server_attestation_header_value(
AppServerAttestationStatus::Ok,
Some(&response.header_value),
)),
Err(err) => {
warn!("failed to deserialize attestation generation response: {err}");
Some(app_server_attestation_header_value(
AppServerAttestationStatus::MalformedResponse,
None,
))
}
}
}
pub(crate) struct MessageProcessor {
outgoing: Arc<OutgoingMessageSender>,
account_processor: AccountRequestProcessor,
@@ -1456,10 +1298,6 @@ impl MessageProcessor {
}
}
#[cfg(test)]
#[path = "message_processor_attestation_tests.rs"]
mod message_processor_attestation_tests;
#[cfg(test)]
#[path = "message_processor_tracing_tests.rs"]
mod message_processor_tracing_tests;

View File

@@ -1,34 +0,0 @@
use super::AppServerAttestationStatus;
use super::app_server_attestation_header_value;
use pretty_assertions::assert_eq;
#[test]
fn app_server_attestation_header_value_wraps_opaque_client_payloads() {
assert_eq!(
app_server_attestation_header_value(
AppServerAttestationStatus::Ok,
Some("v1.opaque-client-payload"),
),
r#"{"v":1,"s":0,"t":"v1.opaque-client-payload"}"#
);
}
#[test]
fn app_server_attestation_header_value_reports_app_server_failures() {
assert_eq!(
app_server_attestation_header_value(AppServerAttestationStatus::Timeout, None),
r#"{"v":1,"s":1}"#
);
assert_eq!(
app_server_attestation_header_value(AppServerAttestationStatus::RequestFailed, None),
r#"{"v":1,"s":2}"#
);
assert_eq!(
app_server_attestation_header_value(AppServerAttestationStatus::RequestCanceled, None),
r#"{"v":1,"s":3}"#
);
assert_eq!(
app_server_attestation_header_value(AppServerAttestationStatus::MalformedResponse, None),
r#"{"v":1,"s":4}"#
);
}