mirror of
https://github.com/openai/codex.git
synced 2026-05-18 10:12:59 +00:00
codex: move app-server attestation into module
This commit is contained in:
204
codex-rs/app-server/src/attestation.rs
Normal file
204
codex-rs/app-server/src/attestation.rs
Normal file
@@ -0,0 +1,204 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::http::HeaderValue;
|
||||
use codex_app_server_protocol::AttestationGenerateParams;
|
||||
use codex_app_server_protocol::AttestationGenerateResponse;
|
||||
use codex_app_server_protocol::ServerRequestPayload;
|
||||
use codex_core::AttestationContext;
|
||||
use codex_core::AttestationProvider;
|
||||
use codex_core::GenerateAttestationFuture;
|
||||
use serde::Serialize;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::thread_state::ThreadStateManager;
|
||||
|
||||
const ATTESTATION_GENERATE_TIMEOUT: Duration = Duration::from_millis(100);
|
||||
|
||||
pub(crate) fn app_server_attestation_provider(
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
thread_state_manager: ThreadStateManager,
|
||||
) -> Arc<dyn AttestationProvider> {
|
||||
Arc::new(AppServerAttestationProvider {
|
||||
outgoing,
|
||||
thread_state_manager,
|
||||
})
|
||||
}
|
||||
|
||||
struct AppServerAttestationProvider {
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
thread_state_manager: ThreadStateManager,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AppServerAttestationProvider {
|
||||
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
formatter
|
||||
.debug_struct("AppServerAttestationProvider")
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl AttestationProvider for AppServerAttestationProvider {
|
||||
fn header_for_request(&self, context: AttestationContext) -> GenerateAttestationFuture<'_> {
|
||||
let outgoing = self.outgoing.clone();
|
||||
let thread_state_manager = self.thread_state_manager.clone();
|
||||
Box::pin(async move {
|
||||
if !context.uses_chatgpt_auth {
|
||||
return None;
|
||||
}
|
||||
|
||||
request_attestation_header_value_with_timeout(
|
||||
outgoing,
|
||||
thread_state_manager,
|
||||
ATTESTATION_GENERATE_TIMEOUT,
|
||||
)
|
||||
.await
|
||||
.and_then(|value| HeaderValue::from_bytes(value.as_bytes()).ok())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn request_attestation_header_value_with_timeout(
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
thread_state_manager: ThreadStateManager,
|
||||
timeout_duration: Duration,
|
||||
) -> Option<String> {
|
||||
let connection_id = thread_state_manager
|
||||
.first_attestation_capable_connection()
|
||||
.await?;
|
||||
|
||||
let connection_ids = [connection_id];
|
||||
let (request_id, rx) = outgoing
|
||||
.send_request_to_connections(
|
||||
Some(&connection_ids),
|
||||
ServerRequestPayload::AttestationGenerate(AttestationGenerateParams {}),
|
||||
/*thread_id*/ None,
|
||||
)
|
||||
.await;
|
||||
|
||||
let result = match timeout(timeout_duration, rx).await {
|
||||
Ok(Ok(Ok(result))) => result,
|
||||
Ok(Ok(Err(err))) => {
|
||||
warn!(
|
||||
code = err.code,
|
||||
message = %err.message,
|
||||
"attestation generation request failed"
|
||||
);
|
||||
return app_server_attestation_header_value(
|
||||
AppServerAttestationStatus::RequestFailed,
|
||||
None,
|
||||
);
|
||||
}
|
||||
Ok(Err(err)) => {
|
||||
warn!("attestation generation request canceled: {err}");
|
||||
return app_server_attestation_header_value(
|
||||
AppServerAttestationStatus::RequestCanceled,
|
||||
None,
|
||||
);
|
||||
}
|
||||
Err(_) => {
|
||||
let _canceled = outgoing.cancel_request(&request_id).await;
|
||||
warn!(
|
||||
timeout_seconds = timeout_duration.as_secs(),
|
||||
"attestation generation request timed out"
|
||||
);
|
||||
return app_server_attestation_header_value(AppServerAttestationStatus::Timeout, None);
|
||||
}
|
||||
};
|
||||
|
||||
match serde_json::from_value::<AttestationGenerateResponse>(result) {
|
||||
Ok(response) => app_server_attestation_header_value(
|
||||
AppServerAttestationStatus::Ok,
|
||||
Some(&response.header_value),
|
||||
),
|
||||
Err(err) => {
|
||||
warn!("failed to deserialize attestation generation response: {err}");
|
||||
app_server_attestation_header_value(AppServerAttestationStatus::MalformedResponse, None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum AppServerAttestationStatus {
|
||||
Ok,
|
||||
Timeout,
|
||||
RequestFailed,
|
||||
RequestCanceled,
|
||||
MalformedResponse,
|
||||
}
|
||||
|
||||
impl AppServerAttestationStatus {
|
||||
const fn code(self) -> u8 {
|
||||
match self {
|
||||
Self::Ok => 0,
|
||||
Self::Timeout => 1,
|
||||
Self::RequestFailed => 2,
|
||||
Self::RequestCanceled => 3,
|
||||
Self::MalformedResponse => 4,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct AppServerAttestationEnvelope<'a> {
|
||||
v: u8,
|
||||
s: u8,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
t: Option<&'a str>,
|
||||
}
|
||||
|
||||
fn app_server_attestation_header_value(
|
||||
status: AppServerAttestationStatus,
|
||||
token: Option<&str>,
|
||||
) -> Option<String> {
|
||||
serde_json::to_string(&AppServerAttestationEnvelope {
|
||||
v: 1,
|
||||
s: status.code(),
|
||||
t: token,
|
||||
})
|
||||
.map_err(|err| warn!("failed to serialize app-server attestation envelope: {err}"))
|
||||
.ok()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::AppServerAttestationStatus;
|
||||
use super::app_server_attestation_header_value;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn app_server_attestation_header_value_wraps_opaque_client_payloads() {
|
||||
assert_eq!(
|
||||
app_server_attestation_header_value(
|
||||
AppServerAttestationStatus::Ok,
|
||||
Some("v1.opaque-client-payload"),
|
||||
),
|
||||
Some(r#"{"v":1,"s":0,"t":"v1.opaque-client-payload"}"#.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_server_attestation_header_value_reports_app_server_failures() {
|
||||
assert_eq!(
|
||||
app_server_attestation_header_value(AppServerAttestationStatus::Timeout, None),
|
||||
Some(r#"{"v":1,"s":1}"#.to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
app_server_attestation_header_value(AppServerAttestationStatus::RequestFailed, None),
|
||||
Some(r#"{"v":1,"s":2}"#.to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
app_server_attestation_header_value(AppServerAttestationStatus::RequestCanceled, None),
|
||||
Some(r#"{"v":1,"s":3}"#.to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
app_server_attestation_header_value(
|
||||
AppServerAttestationStatus::MalformedResponse,
|
||||
None
|
||||
),
|
||||
Some(r#"{"v":1,"s":4}"#.to_string())
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -74,6 +74,7 @@ use tracing_subscriber::util::SubscriberInitExt;
|
||||
|
||||
mod analytics_utils;
|
||||
mod app_server_tracing;
|
||||
mod attestation;
|
||||
mod bespoke_event_handling;
|
||||
mod command_exec;
|
||||
mod config;
|
||||
|
||||
@@ -4,6 +4,7 @@ use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
|
||||
use crate::attestation::app_server_attestation_provider;
|
||||
use crate::config_manager::ConfigManager;
|
||||
use crate::connection_rpc_gate::ConnectionRpcGate;
|
||||
use crate::error_code::invalid_request;
|
||||
@@ -39,11 +40,8 @@ use crate::thread_state::ThreadStateManager;
|
||||
use crate::transport::AppServerTransport;
|
||||
use crate::transport::RemoteControlHandle;
|
||||
use async_trait::async_trait;
|
||||
use axum::http::HeaderValue;
|
||||
use codex_analytics::AnalyticsEventsClient;
|
||||
use codex_analytics::AppServerRpcTransport;
|
||||
use codex_app_server_protocol::AttestationGenerateParams;
|
||||
use codex_app_server_protocol::AttestationGenerateResponse;
|
||||
use codex_app_server_protocol::AuthMode as LoginAuthMode;
|
||||
use codex_app_server_protocol::ChatgptAuthTokensRefreshParams;
|
||||
use codex_app_server_protocol::ChatgptAuthTokensRefreshReason;
|
||||
@@ -62,9 +60,6 @@ use codex_app_server_protocol::ServerRequestPayload;
|
||||
use codex_app_server_protocol::experimental_required_message;
|
||||
use codex_arg0::Arg0DispatchPaths;
|
||||
use codex_chatgpt::workspace_settings;
|
||||
use codex_core::AttestationContext;
|
||||
use codex_core::AttestationProvider;
|
||||
use codex_core::GenerateAttestationFuture;
|
||||
use codex_core::ThreadManager;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::thread_store_from_config;
|
||||
@@ -80,7 +75,6 @@ use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::W3cTraceContext;
|
||||
use codex_rollout::StateDbHandle;
|
||||
use codex_state::log_db::LogDbLayer;
|
||||
use serde::Serialize;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::Semaphore;
|
||||
use tokio::sync::broadcast;
|
||||
@@ -88,52 +82,9 @@ use tokio::sync::watch;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
use tracing::Instrument;
|
||||
use tracing::warn;
|
||||
|
||||
const ATTESTATION_GENERATE_TIMEOUT: Duration = Duration::from_millis(100);
|
||||
const EXTERNAL_AUTH_REFRESH_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum AppServerAttestationStatus {
|
||||
Ok,
|
||||
Timeout,
|
||||
RequestFailed,
|
||||
RequestCanceled,
|
||||
MalformedResponse,
|
||||
}
|
||||
|
||||
impl AppServerAttestationStatus {
|
||||
const fn code(self) -> u8 {
|
||||
match self {
|
||||
Self::Ok => 0,
|
||||
Self::Timeout => 1,
|
||||
Self::RequestFailed => 2,
|
||||
Self::RequestCanceled => 3,
|
||||
Self::MalformedResponse => 4,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct AppServerAttestationEnvelope<'a> {
|
||||
v: u8,
|
||||
s: u8,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
t: Option<&'a str>,
|
||||
}
|
||||
|
||||
fn app_server_attestation_header_value(
|
||||
status: AppServerAttestationStatus,
|
||||
token: Option<&str>,
|
||||
) -> String {
|
||||
serde_json::to_string(&AppServerAttestationEnvelope {
|
||||
v: 1,
|
||||
s: status.code(),
|
||||
t: token,
|
||||
})
|
||||
.expect("app-server attestation envelope should serialize")
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ExternalAuthRefreshBridge {
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
@@ -202,115 +153,6 @@ impl ExternalAuth for ExternalAuthRefreshBridge {
|
||||
}
|
||||
}
|
||||
|
||||
fn app_server_attestation_provider(
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
thread_state_manager: ThreadStateManager,
|
||||
) -> Arc<dyn AttestationProvider> {
|
||||
Arc::new(AppServerAttestationProvider {
|
||||
outgoing,
|
||||
thread_state_manager,
|
||||
})
|
||||
}
|
||||
|
||||
struct AppServerAttestationProvider {
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
thread_state_manager: ThreadStateManager,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AppServerAttestationProvider {
|
||||
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
formatter
|
||||
.debug_struct("AppServerAttestationProvider")
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl AttestationProvider for AppServerAttestationProvider {
|
||||
fn header_for_request(&self, context: AttestationContext) -> GenerateAttestationFuture<'_> {
|
||||
let outgoing = self.outgoing.clone();
|
||||
let thread_state_manager = self.thread_state_manager.clone();
|
||||
Box::pin(async move {
|
||||
if !context.uses_chatgpt_auth {
|
||||
return None;
|
||||
}
|
||||
|
||||
request_attestation_header_value_with_timeout(
|
||||
outgoing,
|
||||
thread_state_manager,
|
||||
ATTESTATION_GENERATE_TIMEOUT,
|
||||
)
|
||||
.await
|
||||
.and_then(|value| HeaderValue::from_bytes(value.as_bytes()).ok())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn request_attestation_header_value_with_timeout(
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
thread_state_manager: ThreadStateManager,
|
||||
timeout_duration: Duration,
|
||||
) -> Option<String> {
|
||||
let connection_id = thread_state_manager
|
||||
.first_attestation_capable_connection()
|
||||
.await?;
|
||||
|
||||
let connection_ids = [connection_id];
|
||||
let (request_id, rx) = outgoing
|
||||
.send_request_to_connections(
|
||||
Some(&connection_ids),
|
||||
ServerRequestPayload::AttestationGenerate(AttestationGenerateParams {}),
|
||||
/*thread_id*/ None,
|
||||
)
|
||||
.await;
|
||||
|
||||
let result = match timeout(timeout_duration, rx).await {
|
||||
Ok(Ok(Ok(result))) => result,
|
||||
Ok(Ok(Err(err))) => {
|
||||
warn!(
|
||||
code = err.code,
|
||||
message = %err.message,
|
||||
"attestation generation request failed"
|
||||
);
|
||||
return Some(app_server_attestation_header_value(
|
||||
AppServerAttestationStatus::RequestFailed,
|
||||
None,
|
||||
));
|
||||
}
|
||||
Ok(Err(err)) => {
|
||||
warn!("attestation generation request canceled: {err}");
|
||||
return Some(app_server_attestation_header_value(
|
||||
AppServerAttestationStatus::RequestCanceled,
|
||||
None,
|
||||
));
|
||||
}
|
||||
Err(_) => {
|
||||
let _canceled = outgoing.cancel_request(&request_id).await;
|
||||
warn!(
|
||||
timeout_seconds = timeout_duration.as_secs(),
|
||||
"attestation generation request timed out"
|
||||
);
|
||||
return Some(app_server_attestation_header_value(
|
||||
AppServerAttestationStatus::Timeout,
|
||||
None,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
match serde_json::from_value::<AttestationGenerateResponse>(result) {
|
||||
Ok(response) => Some(app_server_attestation_header_value(
|
||||
AppServerAttestationStatus::Ok,
|
||||
Some(&response.header_value),
|
||||
)),
|
||||
Err(err) => {
|
||||
warn!("failed to deserialize attestation generation response: {err}");
|
||||
Some(app_server_attestation_header_value(
|
||||
AppServerAttestationStatus::MalformedResponse,
|
||||
None,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct MessageProcessor {
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
account_processor: AccountRequestProcessor,
|
||||
@@ -1456,10 +1298,6 @@ impl MessageProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "message_processor_attestation_tests.rs"]
|
||||
mod message_processor_attestation_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "message_processor_tracing_tests.rs"]
|
||||
mod message_processor_tracing_tests;
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
use super::AppServerAttestationStatus;
|
||||
use super::app_server_attestation_header_value;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn app_server_attestation_header_value_wraps_opaque_client_payloads() {
|
||||
assert_eq!(
|
||||
app_server_attestation_header_value(
|
||||
AppServerAttestationStatus::Ok,
|
||||
Some("v1.opaque-client-payload"),
|
||||
),
|
||||
r#"{"v":1,"s":0,"t":"v1.opaque-client-payload"}"#
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_server_attestation_header_value_reports_app_server_failures() {
|
||||
assert_eq!(
|
||||
app_server_attestation_header_value(AppServerAttestationStatus::Timeout, None),
|
||||
r#"{"v":1,"s":1}"#
|
||||
);
|
||||
assert_eq!(
|
||||
app_server_attestation_header_value(AppServerAttestationStatus::RequestFailed, None),
|
||||
r#"{"v":1,"s":2}"#
|
||||
);
|
||||
assert_eq!(
|
||||
app_server_attestation_header_value(AppServerAttestationStatus::RequestCanceled, None),
|
||||
r#"{"v":1,"s":3}"#
|
||||
);
|
||||
assert_eq!(
|
||||
app_server_attestation_header_value(AppServerAttestationStatus::MalformedResponse, None),
|
||||
r#"{"v":1,"s":4}"#
|
||||
);
|
||||
}
|
||||
Reference in New Issue
Block a user