diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 59d23a4a07..1946bea061 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -2385,6 +2385,7 @@ dependencies = [ "codex-terminal-detection", "codex-utils-template", "core_test_support", + "ed25519-dalek", "keyring", "once_cell", "os_info", diff --git a/codex-rs/codex-api/src/files.rs b/codex-rs/codex-api/src/files.rs index d1e2840066..a938242934 100644 --- a/codex-rs/codex-api/src/files.rs +++ b/codex-rs/codex-api/src/files.rs @@ -308,6 +308,18 @@ mod tests { ChatGptTestAuth } + #[derive(Clone, Copy)] + struct AgentAssertionTestAuth; + + impl AuthProvider for AgentAssertionTestAuth { + fn add_auth_headers(&self, headers: &mut reqwest::header::HeaderMap) { + headers.insert( + reqwest::header::AUTHORIZATION, + HeaderValue::from_static("AgentAssertion test-assertion"), + ); + } + } + fn base_url_for(server: &MockServer) -> String { format!("{}/backend-api", server.uri()) } @@ -317,6 +329,7 @@ mod tests { let server = MockServer::start().await; Mock::given(method("POST")) .and(path("/backend-api/files")) + .and(header("authorization", "Bearer token")) .and(header("chatgpt-account-id", "account_id")) .and(body_json(serde_json::json!({ "file_name": "hello.txt", @@ -377,4 +390,52 @@ mod tests { assert_eq!(uploaded.mime_type, Some("text/plain".to_string())); assert_eq!(finalize_attempts.load(Ordering::SeqCst), 2); } + + #[tokio::test] + async fn upload_local_file_uses_authorization_header_value() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/backend-api/files")) + .and(header("authorization", "AgentAssertion test-assertion")) + .and(body_json(serde_json::json!({ + "file_name": "hello.txt", + "file_size": 5, + "use_case": "codex", + }))) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(serde_json::json!({"file_id": "file_123", "upload_url": format!("{}/upload/file_123", server.uri())})), + ) + .mount(&server) + .await; + Mock::given(method("PUT")) + .and(path("/upload/file_123")) + .and(header("content-length", "5")) + .respond_with(ResponseTemplate::new(200)) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/backend-api/files/file_123/uploaded")) + .and(header("authorization", "AgentAssertion test-assertion")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "status": "success", + "download_url": format!("{}/download/file_123", server.uri()), + "file_name": "hello.txt", + "mime_type": "text/plain", + "file_size_bytes": 5 + }))) + .mount(&server) + .await; + + let base_url = base_url_for(&server); + let dir = TempDir::new().expect("temp dir"); + let path = dir.path().join("hello.txt"); + tokio::fs::write(&path, b"hello").await.expect("write file"); + + let uploaded = upload_local_file(&base_url, &AgentAssertionTestAuth, &path) + .await + .expect("upload succeeds"); + + assert_eq!(uploaded.file_id, "file_123"); + } } diff --git a/codex-rs/core/src/agent_identity.rs b/codex-rs/core/src/agent_identity.rs index 8e05dfc6e8..a77fd8fa4a 100644 --- a/codex-rs/core/src/agent_identity.rs +++ b/codex-rs/core/src/agent_identity.rs @@ -27,12 +27,12 @@ use tracing::debug; use tracing::info; use tracing::warn; +use crate::config::Config; + mod task_registration; pub(crate) use task_registration::RegisteredAgentTask; -use crate::config::Config; - const AGENT_REGISTRATION_TIMEOUT: Duration = Duration::from_secs(15); const AGENT_IDENTITY_BISCUIT_TIMEOUT: Duration = Duration::from_secs(15); @@ -335,7 +335,7 @@ impl AgentIdentityManager { } #[cfg(test)] - fn new_for_tests( + pub(crate) fn new_for_tests( auth_manager: Arc, feature_enabled: bool, chatgpt_base_url: String, @@ -349,6 +349,30 @@ impl AgentIdentityManager { ensure_lock: Arc::new(Mutex::new(())), } } + + #[cfg(test)] + pub(crate) async fn seed_generated_identity_for_tests( + &self, + agent_runtime_id: &str, + ) -> Result { + let (auth, binding) = self + .current_auth_binding() + .await + .context("test agent identity requires ChatGPT auth")?; + let key_material = generate_agent_key_material()?; + let stored_identity = StoredAgentIdentity { + binding_id: binding.binding_id.clone(), + chatgpt_account_id: binding.chatgpt_account_id.clone(), + chatgpt_user_id: binding.chatgpt_user_id, + agent_runtime_id: agent_runtime_id.to_string(), + private_key_pkcs8_base64: key_material.private_key_pkcs8_base64, + public_key_ssh: key_material.public_key_ssh, + registered_at: Utc::now().to_rfc3339_opts(SecondsFormat::Secs, true), + abom: self.abom.clone(), + }; + self.store_identity(&auth, &stored_identity)?; + Ok(stored_identity) + } } impl StoredAgentIdentity { @@ -579,7 +603,7 @@ mod tests { .and(path("/v1/agent/register")) .and(header("x-openai-authorization", "human-biscuit")) .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "agent_runtime_id": "agent_123", + "agent_runtime_id": "agent-123", }))) .expect(1) .mount(&server) @@ -605,7 +629,7 @@ mod tests { .unwrap() .expect("identity should be reused"); - assert_eq!(first.agent_runtime_id, "agent_123"); + assert_eq!(first.agent_runtime_id, "agent-123"); assert_eq!(first, second); assert_eq!(first.abom.agent_harness_id, "codex-cli"); assert_eq!(first.chatgpt_account_id, "account-123"); @@ -621,7 +645,7 @@ mod tests { .and(path("/v1/agent/register")) .and(header("x-openai-authorization", "human-biscuit")) .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "agent_runtime_id": "agent_456", + "agent_runtime_id": "agent-456", }))) .expect(1) .mount(&server) @@ -653,11 +677,11 @@ mod tests { .unwrap() .expect("identity should be registered"); - assert_eq!(stored.agent_runtime_id, "agent_456"); + assert_eq!(stored.agent_runtime_id, "agent-456"); let persisted = auth .get_agent_identity(&binding.chatgpt_account_id) .expect("stored identity"); - assert_eq!(persisted.agent_runtime_id, "agent_456"); + assert_eq!(persisted.agent_runtime_id, "agent-456"); } #[tokio::test] diff --git a/codex-rs/core/src/agent_identity/task_registration.rs b/codex-rs/core/src/agent_identity/task_registration.rs index 53bb272cbc..d00f7d25e4 100644 --- a/codex-rs/core/src/agent_identity/task_registration.rs +++ b/codex-rs/core/src/agent_identity/task_registration.rs @@ -2,6 +2,7 @@ use std::time::Duration; use anyhow::Context; use anyhow::Result; +use codex_login::AgentTaskAuthorizationTarget; use codex_protocol::protocol::SessionAgentTask; use crypto_box::SecretKey as Curve25519SecretKey; use ed25519_dalek::Signer as _; @@ -102,6 +103,13 @@ impl AgentIdentityManager { } impl RegisteredAgentTask { + pub(crate) fn authorization_target(&self) -> AgentTaskAuthorizationTarget<'_> { + AgentTaskAuthorizationTarget { + agent_runtime_id: &self.agent_runtime_id, + task_id: &self.task_id, + } + } + pub(crate) fn to_session_agent_task(&self) -> SessionAgentTask { SessionAgentTask { agent_runtime_id: self.agent_runtime_id.clone(), diff --git a/codex-rs/core/src/arc_monitor.rs b/codex-rs/core/src/arc_monitor.rs index ecd7f39666..99767872d5 100644 --- a/codex-rs/core/src/arc_monitor.rs +++ b/codex-rs/core/src/arc_monitor.rs @@ -13,6 +13,7 @@ use codex_login::CodexAuth; use codex_login::default_client::build_reqwest_client; use codex_protocol::models::MessagePhase; use codex_protocol::models::ResponseItem; +use reqwest::header::AUTHORIZATION; const ARC_MONITOR_TIMEOUT: Duration = Duration::from_secs(30); const CODEX_ARC_MONITOR_ENDPOINT_OVERRIDE: &str = "CODEX_ARC_MONITOR_ENDPOINT_OVERRIDE"; @@ -109,13 +110,31 @@ pub(crate) async fn monitor_action( }, None => None, }; - let token = if let Some(token) = read_non_empty_env_var(CODEX_ARC_MONITOR_TOKEN) { - token + let (authorization_header_value, account_id) = if let Some(token) = + read_non_empty_env_var(CODEX_ARC_MONITOR_TOKEN) + { + ( + format!("Bearer {token}"), + auth.as_ref().and_then(CodexAuth::get_account_id), + ) + } else if let Some(authorization_header_value) = + match sess.authorization_header_for_current_agent_task().await { + Ok(authorization_header_value) => authorization_header_value, + Err(err) => { + warn!( + error = %err, + "skipping safety monitor because agent assertion authorization is unavailable" + ); + return ArcMonitorOutcome::Ok; + } + } + { + (authorization_header_value, None) } else { let Some(auth) = auth.as_ref() else { return ArcMonitorOutcome::Ok; }; - match auth.get_token() { + let token = match auth.get_token() { Ok(token) => token, Err(err) => { warn!( @@ -124,7 +143,8 @@ pub(crate) async fn monitor_action( ); return ArcMonitorOutcome::Ok; } - } + }; + (format!("Bearer {token}"), auth.get_account_id()) }; let url = read_non_empty_env_var(CODEX_ARC_MONITOR_ENDPOINT_OVERRIDE).unwrap_or_else(|| { @@ -147,8 +167,8 @@ pub(crate) async fn monitor_action( .post(&url) .timeout(ARC_MONITOR_TIMEOUT) .json(&body) - .bearer_auth(token); - if let Some(account_id) = auth.as_ref().and_then(CodexAuth::get_account_id) { + .header(AUTHORIZATION, authorization_header_value); + if let Some(account_id) = account_id { request = request.header("chatgpt-account-id", account_id); } diff --git a/codex-rs/core/src/arc_monitor_tests.rs b/codex-rs/core/src/arc_monitor_tests.rs index 6a345e31e3..6bf1aa98f0 100644 --- a/codex-rs/core/src/arc_monitor_tests.rs +++ b/codex-rs/core/src/arc_monitor_tests.rs @@ -9,17 +9,37 @@ use wiremock::MockServer; use wiremock::ResponseTemplate; use wiremock::matchers::body_json; use wiremock::matchers::header; +use wiremock::matchers::header_regex; use wiremock::matchers::method; use wiremock::matchers::path; use super::*; +use crate::agent_identity::AgentIdentityManager; +use crate::agent_identity::RegisteredAgentTask; use crate::session::tests::make_session_and_context; +use chrono::Utc; +use codex_login::AuthCredentialsStoreMode; +use codex_login::AuthDotJson; +use codex_login::AuthManager; +use codex_login::CodexAuth; +use codex_login::save_auth; +use codex_login::token_data::IdTokenInfo; +use codex_login::token_data::TokenData; use codex_protocol::models::ContentItem; use codex_protocol::models::LocalShellAction; use codex_protocol::models::LocalShellExecAction; use codex_protocol::models::LocalShellStatus; use codex_protocol::models::MessagePhase; use codex_protocol::models::ResponseItem; +use codex_protocol::protocol::SessionSource; +use tempfile::tempdir; + +const TEST_ID_TOKEN: &str = concat!( + "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.", + "eyJodHRwczovL2FwaS5vcGVuYWkuY29tL2F1dGgiOnsiY2hhdGdwdF91c2VyX2lk", + "IjpudWxsLCJjaGF0Z3B0X2FjY291bnRfaWQiOiJhY2NvdW50X2lkIn19.", + "c2ln", +); struct EnvVarGuard { key: &'static str, @@ -49,6 +69,58 @@ impl Drop for EnvVarGuard { } } +async fn install_cached_agent_task_auth( + session: &mut Session, + turn_context: &mut TurnContext, + chatgpt_base_url: String, +) { + let auth_dir = tempdir().expect("temp auth dir"); + let auth_json = AuthDotJson { + auth_mode: Some(codex_app_server_protocol::AuthMode::Chatgpt), + openai_api_key: None, + tokens: Some(TokenData { + id_token: IdTokenInfo { + email: None, + chatgpt_plan_type: None, + chatgpt_user_id: None, + chatgpt_account_id: Some("account_id".to_string()), + chatgpt_account_is_fedramp: false, + raw_jwt: TEST_ID_TOKEN.to_string(), + }, + access_token: "Access Token".to_string(), + refresh_token: "test".to_string(), + account_id: Some("account_id".to_string()), + }), + last_refresh: Some(Utc::now()), + agent_identity: None, + }; + save_auth(auth_dir.path(), &auth_json, AuthCredentialsStoreMode::File).expect("save test auth"); + let auth = CodexAuth::from_auth_storage(auth_dir.path(), AuthCredentialsStoreMode::File) + .expect("load test auth") + .expect("test auth"); + let auth_manager = AuthManager::from_auth_for_testing(auth); + let agent_identity_manager = Arc::new(AgentIdentityManager::new_for_tests( + Arc::clone(&auth_manager), + /*feature_enabled*/ true, + chatgpt_base_url, + SessionSource::Exec, + )); + let stored_identity = agent_identity_manager + .seed_generated_identity_for_tests("agent-123") + .await + .expect("seed test identity"); + session.services.auth_manager = Arc::clone(&auth_manager); + session.services.agent_identity_manager = agent_identity_manager; + turn_context.auth_manager = Some(auth_manager); + session + .cache_agent_task_for_tests(RegisteredAgentTask { + agent_runtime_id: stored_identity.agent_runtime_id, + task_id: "task-123".to_string(), + registered_at: "2026-04-15T00:00:00Z".to_string(), + }) + .await; +} + #[tokio::test] async fn build_arc_monitor_request_includes_relevant_history_and_null_policies() { let (session, mut turn_context) = make_session_and_context().await; @@ -247,6 +319,80 @@ async fn build_arc_monitor_request_includes_relevant_history_and_null_policies() ); } +#[tokio::test] +#[serial(arc_monitor_env)] +async fn monitor_action_uses_agent_assertion_for_cached_task() { + let server = MockServer::start().await; + let (mut session, mut turn_context) = make_session_and_context().await; + install_cached_agent_task_auth(&mut session, &mut turn_context, server.uri()).await; + + let mut config = (*turn_context.config).clone(); + config.chatgpt_base_url = server.uri(); + turn_context.config = Arc::new(config); + + session + .record_into_history( + &[ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "please run the tool".to_string(), + }], + end_turn: None, + phase: None, + }], + &turn_context, + ) + .await; + + Mock::given(method("POST")) + .and(path("/codex/safety/arc")) + .and(header_regex("authorization", r"^AgentAssertion .+")) + .and(body_json(serde_json::json!({ + "metadata": { + "codex_thread_id": session.conversation_id.to_string(), + "codex_turn_id": turn_context.sub_id.clone(), + "conversation_id": session.conversation_id.to_string(), + "protection_client_callsite": "normal", + }, + "messages": [{ + "role": "user", + "content": [{ + "type": "input_text", + "text": "please run the tool", + }], + }], + "policies": { + "developer": null, + "user": null, + }, + "action": { + "tool": "mcp_tool_call", + }, + }))) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "outcome": "ok", + "short_reason": "", + "rationale": "", + "risk_score": 1, + "risk_level": "low", + "evidence": [], + }))) + .expect(1) + .mount(&server) + .await; + + let outcome = monitor_action( + &session, + &turn_context, + serde_json::json!({ "tool": "mcp_tool_call" }), + "normal", + ) + .await; + + assert_eq!(outcome, ArcMonitorOutcome::Ok); +} + #[tokio::test] #[serial(arc_monitor_env)] async fn monitor_action_posts_expected_arc_request() { diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index c5b7c46122..c19b3b2d6e 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -31,6 +31,7 @@ use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicU64; use std::sync::atomic::Ordering; +use crate::agent_identity::RegisteredAgentTask; use codex_api::ApiError; use codex_api::AuthProvider; use codex_api::CompactClient as ApiCompactClient; @@ -95,6 +96,7 @@ use tokio::sync::oneshot; use tokio::sync::oneshot::error::TryRecvError; use tokio_tungstenite::tungstenite::Error; use tokio_tungstenite::tungstenite::Message; +use tracing::debug; use tracing::instrument; use tracing::trace; use tracing::warn; @@ -109,6 +111,7 @@ use codex_feedback::FeedbackRequestTags; use codex_feedback::emit_feedback_request_tags_with_auth_env; use codex_login::auth_env_telemetry::AuthEnvTelemetry; use codex_login::auth_env_telemetry::collect_auth_env_telemetry; +use codex_model_provider::AuthorizationHeaderAuthProvider; use codex_model_provider::SharedModelProvider; use codex_model_provider::create_model_provider; #[cfg(test)] @@ -212,6 +215,8 @@ pub struct ModelClient { pub struct ModelClientSession { client: ModelClient, websocket_session: WebsocketSession, + agent_task: Option, + cache_websocket_session_on_drop: bool, /// Turn state for sticky routing. /// /// This is an `OnceLock` that stores the turn state value received from the server @@ -329,9 +334,25 @@ impl ModelClient { /// This constructor does not perform network I/O itself; the session opens a websocket lazily /// when the first stream request is issued. pub fn new_session(&self) -> ModelClientSession { + self.new_session_with_agent_task(/*agent_task*/ None) + } + + pub(crate) fn new_session_with_agent_task( + &self, + agent_task: Option, + ) -> ModelClientSession { + let cache_websocket_session_on_drop = agent_task.is_none(); + let websocket_session = if agent_task.is_some() { + drop(self.take_cached_websocket_session()); + WebsocketSession::default() + } else { + self.take_cached_websocket_session() + }; ModelClientSession { client: self.clone(), - websocket_session: self.take_cached_websocket_session(), + websocket_session, + agent_task, + cache_websocket_session_on_drop, turn_state: Arc::new(OnceLock::new()), } } @@ -414,7 +435,7 @@ impl ModelClient { if prompt.input.is_empty() { return Ok(Vec::new()); } - let client_setup = self.current_client_setup().await?; + let client_setup = self.current_client_setup(/*agent_task*/ None).await?; let transport = ReqwestTransport::new(build_reqwest_client()); let request_telemetry = Self::build_request_telemetry( session_telemetry, @@ -478,7 +499,7 @@ impl ModelClient { ) -> Result { // Create the media call over HTTP first, then retain matching auth so realtime can attach // the server-side control WebSocket to the call id from that HTTP response. - let client_setup = self.current_client_setup().await?; + let client_setup = self.current_client_setup(/*agent_task*/ None).await?; let mut sideband_headers = extra_headers.clone(); sideband_headers.extend(sideband_websocket_auth_headers( client_setup.api_auth.as_ref(), @@ -513,7 +534,7 @@ impl ModelClient { return Ok(Vec::new()); } - let client_setup = self.current_client_setup().await?; + let client_setup = self.current_client_setup(/*agent_task*/ None).await?; let transport = ReqwestTransport::new(build_reqwest_client()); let request_telemetry = Self::build_request_telemetry( session_telemetry, @@ -654,10 +675,46 @@ impl ModelClient { /// /// This centralizes setup used by both prewarm and normal request paths so they stay in /// lockstep when auth/provider resolution changes. - async fn current_client_setup(&self) -> Result { + async fn current_client_setup( + &self, + agent_task: Option<&RegisteredAgentTask>, + ) -> Result { let auth = self.state.provider.auth().await; let api_provider = self.state.provider.api_provider().await?; - let api_auth = self.state.provider.api_auth().await?; + let auth_manager = self.state.provider.auth_manager(); + let api_auth = match (agent_task, auth_manager.as_ref(), auth.as_ref()) { + (Some(agent_task), Some(auth_manager), Some(auth)) => { + if let Some(authorization_header_value) = auth_manager + .chatgpt_agent_task_authorization_header_for_auth( + auth, + agent_task.authorization_target(), + ) + .map_err(|err| { + CodexErr::Stream( + format!("failed to build agent assertion authorization: {err}"), + None, + ) + })? + { + debug!( + agent_runtime_id = %agent_task.agent_runtime_id, + task_id = %agent_task.task_id, + "using agent assertion authorization for downstream request" + ); + let mut auth_provider = AuthorizationHeaderAuthProvider::new( + Some(authorization_header_value), + /*account_id*/ None, + ); + if auth.is_fedramp_account() { + auth_provider = auth_provider.with_fedramp_routing_header(); + } + Arc::new(auth_provider) + } else { + self.state.provider.api_auth().await? + } + } + _ => self.state.provider.api_auth().await?, + }; Ok(CurrentClientSetup { auth, api_provider, @@ -791,12 +848,18 @@ impl ModelClient { impl Drop for ModelClientSession { fn drop(&mut self) { let websocket_session = std::mem::take(&mut self.websocket_session); - self.client - .store_cached_websocket_session(websocket_session); + if self.cache_websocket_session_on_drop { + self.client + .store_cached_websocket_session(websocket_session); + } } } impl ModelClientSession { + pub(crate) fn disable_cached_websocket_session_on_drop(&mut self) { + self.cache_websocket_session_on_drop = false; + } + pub(crate) fn reset_websocket_session(&mut self) { self.websocket_session.connection = None; self.websocket_session.last_request = None; @@ -998,11 +1061,15 @@ impl ModelClientSession { return Ok(()); } - let client_setup = self.client.current_client_setup().await.map_err(|err| { - ApiError::Stream(format!( - "failed to build websocket prewarm client setup: {err}" - )) - })?; + let client_setup = self + .client + .current_client_setup(self.agent_task.as_ref()) + .await + .map_err(|err| { + ApiError::Stream(format!( + "failed to build websocket prewarm client setup: {err}" + )) + })?; let auth_context = AuthRequestTelemetryContext::new( client_setup.auth.as_ref().map(CodexAuth::auth_mode), client_setup.api_auth.as_ref(), @@ -1156,7 +1223,10 @@ impl ModelClientSession { .map(AuthManager::unauthorized_recovery); let mut pending_retry = PendingUnauthorizedRetry::default(); loop { - let client_setup = self.client.current_client_setup().await?; + let client_setup = self + .client + .current_client_setup(self.agent_task.as_ref()) + .await?; let transport = ReqwestTransport::new(build_reqwest_client()); let request_auth_context = AuthRequestTelemetryContext::new( client_setup.auth.as_ref().map(CodexAuth::auth_mode), @@ -1245,7 +1315,10 @@ impl ModelClientSession { .map(AuthManager::unauthorized_recovery); let mut pending_retry = PendingUnauthorizedRetry::default(); loop { - let client_setup = self.client.current_client_setup().await?; + let client_setup = self + .client + .current_client_setup(self.agent_task.as_ref()) + .await?; let request_auth_context = AuthRequestTelemetryContext::new( client_setup.auth.as_ref().map(CodexAuth::auth_mode), client_setup.api_auth.as_ref(), diff --git a/codex-rs/core/src/client_tests.rs b/codex-rs/core/src/client_tests.rs index f4575b26a0..357a892264 100644 --- a/codex-rs/core/src/client_tests.rs +++ b/codex-rs/core/src/client_tests.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use super::AuthRequestTelemetryContext; use super::ModelClient; use super::PendingUnauthorizedRetry; @@ -7,17 +9,36 @@ use super::X_CODEX_PARENT_THREAD_ID_HEADER; use super::X_CODEX_TURN_METADATA_HEADER; use super::X_CODEX_WINDOW_ID_HEADER; use super::X_OPENAI_SUBAGENT_HEADER; +use crate::Prompt; +use crate::ResponseEvent; +use crate::agent_identity::AgentIdentityManager; +use crate::agent_identity::RegisteredAgentTask; +use crate::agent_identity::StoredAgentIdentity; +use base64::Engine as _; +use base64::engine::general_purpose::URL_SAFE_NO_PAD; use codex_app_server_protocol::AuthMode; +use codex_login::AuthManager; +use codex_login::CodexAuth; use codex_model_provider::BearerAuthProvider; +use codex_model_provider_info::ModelProviderInfo; use codex_model_provider_info::WireApi; use codex_model_provider_info::create_oss_provider_with_base_url; use codex_otel::SessionTelemetry; use codex_protocol::ThreadId; +use codex_protocol::config_types::ReasoningSummary; +use codex_protocol::models::ContentItem; +use codex_protocol::models::ResponseItem; use codex_protocol::openai_models::ModelInfo; use codex_protocol::protocol::SessionSource; use codex_protocol::protocol::SubAgentSource; +use core_test_support::responses; +use ed25519_dalek::Signature; +use ed25519_dalek::Verifier as _; +use futures::StreamExt; use pretty_assertions::assert_eq; +use serde::Deserialize; use serde_json::json; +use tempfile::TempDir; fn test_model_client(session_source: SessionSource) -> ModelClient { let provider = create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses); @@ -79,6 +100,118 @@ fn test_session_telemetry() -> SessionTelemetry { ) } +fn test_prompt(text: &str) -> Prompt { + Prompt { + input: vec![ResponseItem::Message { + id: None, + role: "user".into(), + content: vec![ContentItem::InputText { + text: text.to_string(), + }], + end_turn: None, + phase: None, + }], + ..Prompt::default() + } +} + +async fn drain_stream_to_completion(stream: &mut crate::ResponseStream) -> anyhow::Result<()> { + while let Some(event) = stream.next().await { + if matches!(event?, ResponseEvent::Completed { .. }) { + break; + } + } + Ok(()) +} + +async fn model_client_with_agent_task( + provider: ModelProviderInfo, +) -> ( + TempDir, + ModelClient, + RegisteredAgentTask, + StoredAgentIdentity, +) { + let codex_home = tempfile::tempdir().expect("tempdir"); + let auth_manager = + AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); + let agent_identity_manager = Arc::new(AgentIdentityManager::new_for_tests( + Arc::clone(&auth_manager), + /*feature_enabled*/ true, + "https://chatgpt.com/backend-api/".to_string(), + SessionSource::Cli, + )); + let stored_identity = agent_identity_manager + .seed_generated_identity_for_tests("agent-123") + .await + .expect("seed test identity"); + let agent_task = RegisteredAgentTask { + agent_runtime_id: stored_identity.agent_runtime_id.clone(), + task_id: "task-123".to_string(), + registered_at: "2026-03-23T12:00:00Z".to_string(), + }; + let client = ModelClient::new( + Some(auth_manager), + ThreadId::new(), + /*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(), + provider, + SessionSource::Cli, + /*model_verbosity*/ None, + /*enable_request_compression*/ false, + /*include_timing_metrics*/ false, + /*beta_features_header*/ None, + ); + (codex_home, client, agent_task, stored_identity) +} + +#[derive(Debug, Deserialize)] +struct AgentAssertionEnvelope { + agent_runtime_id: String, + task_id: String, + timestamp: String, + signature: String, +} + +fn assert_agent_assertion_header( + authorization_header: &str, + stored_identity: &StoredAgentIdentity, + expected_agent_runtime_id: &str, + expected_task_id: &str, +) { + let token = authorization_header + .strip_prefix("AgentAssertion ") + .expect("agent assertion authorization scheme"); + let envelope: AgentAssertionEnvelope = serde_json::from_slice( + &URL_SAFE_NO_PAD + .decode(token) + .expect("base64url-encoded agent assertion"), + ) + .expect("valid agent assertion envelope"); + + assert_eq!(envelope.agent_runtime_id, expected_agent_runtime_id); + assert_eq!(envelope.task_id, expected_task_id); + + let signature = Signature::from_slice( + &base64::engine::general_purpose::STANDARD + .decode(&envelope.signature) + .expect("base64 signature"), + ) + .expect("signature bytes"); + stored_identity + .signing_key() + .expect("signing key") + .verifying_key() + .verify( + format!( + "{}:{}:{}", + envelope.agent_runtime_id, envelope.task_id, envelope.timestamp + ) + .as_bytes(), + &signature, + ) + .expect("signature should verify"); +} + #[test] fn build_subagent_headers_sets_other_subagent_label() { let client = test_model_client(SessionSource::SubAgent(SubAgentSource::Other( @@ -169,3 +302,130 @@ fn auth_request_telemetry_context_tracks_attached_auth_and_retry_phase() { assert_eq!(auth_context.recovery_mode, Some("managed")); assert_eq!(auth_context.recovery_phase, Some("refresh_token")); } + +#[tokio::test] +async fn responses_http_uses_agent_assertion_when_agent_task_is_present() { + core_test_support::skip_if_no_network!(); + + let server = responses::start_mock_server().await; + let request_recorder = responses::mount_sse_once( + &server, + responses::sse(vec![ + responses::ev_response_created("resp-1"), + responses::ev_completed("resp-1"), + ]), + ) + .await; + let provider = + create_oss_provider_with_base_url(&format!("{}/v1", server.uri()), WireApi::Responses); + let (_codex_home, client, agent_task, stored_identity) = + model_client_with_agent_task(provider).await; + let model_info = test_model_info(); + let session_telemetry = test_session_telemetry(); + let mut client_session = client.new_session_with_agent_task(Some(agent_task.clone())); + + let mut stream = client_session + .stream( + &test_prompt("hello"), + &model_info, + &session_telemetry, + /*effort*/ None, + ReasoningSummary::Auto, + /*service_tier*/ None, + /*turn_metadata_header*/ None, + ) + .await + .expect("stream request should succeed"); + drain_stream_to_completion(&mut stream) + .await + .expect("stream should complete"); + + let request = request_recorder.single_request(); + let authorization = request + .header("authorization") + .expect("authorization header should be present"); + assert_agent_assertion_header( + &authorization, + &stored_identity, + &agent_task.agent_runtime_id, + &agent_task.task_id, + ); + assert_eq!(request.header("chatgpt-account-id"), None); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn websocket_agent_task_bypasses_cached_bearer_prewarm() { + core_test_support::skip_if_no_network!(); + + let server = responses::start_websocket_server(vec![ + vec![vec![ + responses::ev_response_created("resp-prewarm"), + responses::ev_completed("resp-prewarm"), + ]], + vec![vec![ + responses::ev_response_created("resp-1"), + responses::ev_completed("resp-1"), + ]], + ]) + .await; + let mut provider = + create_oss_provider_with_base_url(&format!("{}/v1", server.uri()), WireApi::Responses); + provider.supports_websockets = true; + provider.websocket_connect_timeout_ms = Some(5_000); + let (_codex_home, client, agent_task, stored_identity) = + model_client_with_agent_task(provider).await; + let model_info = test_model_info(); + let session_telemetry = test_session_telemetry(); + let prompt = test_prompt("hello"); + + let mut prewarm_session = client.new_session(); + prewarm_session + .prewarm_websocket( + &prompt, + &model_info, + &session_telemetry, + /*effort*/ None, + ReasoningSummary::Auto, + /*service_tier*/ None, + /*turn_metadata_header*/ None, + ) + .await + .expect("bearer prewarm should succeed"); + drop(prewarm_session); + + let mut agent_task_session = client.new_session_with_agent_task(Some(agent_task.clone())); + let mut stream = agent_task_session + .stream( + &prompt, + &model_info, + &session_telemetry, + /*effort*/ None, + ReasoningSummary::Auto, + /*service_tier*/ None, + /*turn_metadata_header*/ None, + ) + .await + .expect("agent task stream should succeed"); + drain_stream_to_completion(&mut stream) + .await + .expect("agent task websocket stream should complete"); + + let handshakes = server.handshakes(); + assert_eq!(handshakes.len(), 2); + assert_eq!( + handshakes[0].header("authorization"), + Some("Bearer Access Token".to_string()) + ); + let agent_authorization = handshakes[1] + .header("authorization") + .expect("agent handshake should include authorization"); + assert_agent_assertion_header( + &agent_authorization, + &stored_identity, + &agent_task.agent_runtime_id, + &agent_task.task_id, + ); + assert_eq!(handshakes[1].header("chatgpt-account-id"), None); + + server.shutdown().await; +} diff --git a/codex-rs/core/src/mcp_openai_file.rs b/codex-rs/core/src/mcp_openai_file.rs index d6e6d1f9c0..324eb4ad1a 100644 --- a/codex-rs/core/src/mcp_openai_file.rs +++ b/codex-rs/core/src/mcp_openai_file.rs @@ -12,8 +12,10 @@ use crate::session::session::Session; use crate::session::turn_context::TurnContext; +use codex_api::AuthProvider; use codex_api::upload_local_file; use codex_login::CodexAuth; +use codex_model_provider::AuthorizationHeaderAuthProvider; use codex_model_provider::BearerAuthProvider; use serde_json::Value as JsonValue; @@ -40,9 +42,14 @@ pub(crate) async fn rewrite_mcp_tool_arguments_for_openai_files( let Some(value) = arguments.get(field_name) else { continue; }; - let Some(uploaded_value) = - rewrite_argument_value_for_openai_files(turn_context, auth.as_ref(), field_name, value) - .await? + let Some(uploaded_value) = rewrite_argument_value_for_openai_files( + sess, + turn_context, + auth.as_ref(), + field_name, + value, + ) + .await? else { continue; }; @@ -57,6 +64,7 @@ pub(crate) async fn rewrite_mcp_tool_arguments_for_openai_files( } async fn rewrite_argument_value_for_openai_files( + sess: &Session, turn_context: &TurnContext, auth: Option<&CodexAuth>, field_name: &str, @@ -65,6 +73,7 @@ async fn rewrite_argument_value_for_openai_files( match value { JsonValue::String(path_or_file_ref) => { let rewritten = build_uploaded_local_argument_value( + sess, turn_context, auth, field_name, @@ -81,6 +90,7 @@ async fn rewrite_argument_value_for_openai_files( return Ok(None); }; let rewritten = build_uploaded_local_argument_value( + sess, turn_context, auth, field_name, @@ -97,6 +107,7 @@ async fn rewrite_argument_value_for_openai_files( } async fn build_uploaded_local_argument_value( + sess: &Session, turn_context: &TurnContext, auth: Option<&CodexAuth>, field_name: &str, @@ -109,17 +120,32 @@ async fn build_uploaded_local_argument_value( "ChatGPT auth is required to upload local files for Codex Apps tools".to_string(), ); }; - let token_data = auth - .get_token_data() - .map_err(|error| format!("failed to read ChatGPT auth for file upload: {error}"))?; - let upload_auth = BearerAuthProvider { - token: Some(token_data.access_token), - account_id: token_data.account_id, - is_fedramp_account: auth.is_fedramp_account(), + let upload_auth: Box = if let Some(authorization_header_value) = sess + .authorization_header_for_current_agent_task() + .await + .map_err(|error| format!("failed to build agent assertion authorization: {error}"))? + { + let mut auth_provider = AuthorizationHeaderAuthProvider::new( + Some(authorization_header_value), + /*account_id*/ None, + ); + if auth.is_fedramp_account() { + auth_provider = auth_provider.with_fedramp_routing_header(); + } + Box::new(auth_provider) + } else { + let token_data = auth + .get_token_data() + .map_err(|error| format!("failed to read ChatGPT auth for file upload: {error}"))?; + Box::new(BearerAuthProvider { + token: Some(token_data.access_token), + account_id: token_data.account_id, + is_fedramp_account: auth.is_fedramp_account(), + }) }; let uploaded = upload_local_file( turn_context.config.chatgpt_base_url.trim_end_matches('/'), - &upload_auth, + upload_auth.as_ref(), &resolved_path, ) .await @@ -142,12 +168,82 @@ async fn build_uploaded_local_argument_value( #[cfg(test)] mod tests { use super::*; + use crate::agent_identity::AgentIdentityManager; + use crate::agent_identity::RegisteredAgentTask; use crate::session::tests::make_session_and_context; + use chrono::Utc; + use codex_login::AuthCredentialsStoreMode; + use codex_login::AuthDotJson; + use codex_login::AuthManager; + use codex_login::save_auth; + use codex_login::token_data::IdTokenInfo; + use codex_login::token_data::TokenData; + use codex_protocol::protocol::SessionSource; use codex_utils_absolute_path::AbsolutePathBuf; use pretty_assertions::assert_eq; use std::sync::Arc; use tempfile::tempdir; + const TEST_ID_TOKEN: &str = concat!( + "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.", + "eyJodHRwczovL2FwaS5vcGVuYWkuY29tL2F1dGgiOnsiY2hhdGdwdF91c2VyX2lk", + "IjpudWxsLCJjaGF0Z3B0X2FjY291bnRfaWQiOiJhY2NvdW50X2lkIn19.", + "c2ln", + ); + + async fn install_cached_agent_task_auth( + session: &mut Session, + turn_context: &mut TurnContext, + chatgpt_base_url: String, + ) { + let auth_dir = tempdir().expect("temp auth dir"); + let auth_json = AuthDotJson { + auth_mode: Some(codex_app_server_protocol::AuthMode::Chatgpt), + openai_api_key: None, + tokens: Some(TokenData { + id_token: IdTokenInfo { + email: None, + chatgpt_plan_type: None, + chatgpt_user_id: None, + chatgpt_account_id: Some("account_id".to_string()), + chatgpt_account_is_fedramp: false, + raw_jwt: TEST_ID_TOKEN.to_string(), + }, + access_token: "Access Token".to_string(), + refresh_token: "test".to_string(), + account_id: Some("account_id".to_string()), + }), + last_refresh: Some(Utc::now()), + agent_identity: None, + }; + save_auth(auth_dir.path(), &auth_json, AuthCredentialsStoreMode::File) + .expect("save test auth"); + let auth = CodexAuth::from_auth_storage(auth_dir.path(), AuthCredentialsStoreMode::File) + .expect("load test auth") + .expect("test auth"); + let auth_manager = AuthManager::from_auth_for_testing(auth); + let agent_identity_manager = Arc::new(AgentIdentityManager::new_for_tests( + Arc::clone(&auth_manager), + /*feature_enabled*/ true, + chatgpt_base_url, + SessionSource::Exec, + )); + let stored_identity = agent_identity_manager + .seed_generated_identity_for_tests("agent-123") + .await + .expect("seed test identity"); + session.services.auth_manager = Arc::clone(&auth_manager); + session.services.agent_identity_manager = agent_identity_manager; + turn_context.auth_manager = Some(auth_manager); + session + .cache_agent_task_for_tests(RegisteredAgentTask { + agent_runtime_id: stored_identity.agent_runtime_id, + task_id: "task-123".to_string(), + registered_at: "2026-04-15T00:00:00Z".to_string(), + }) + .await; + } + #[tokio::test] async fn openai_file_argument_rewrite_requires_declared_file_params() { let (session, turn_context) = make_session_and_context().await; @@ -212,7 +308,7 @@ mod tests { .mount(&server) .await; - let (_, mut turn_context) = make_session_and_context().await; + let (session, mut turn_context) = make_session_and_context().await; let auth = CodexAuth::create_dummy_chatgpt_auth_for_testing(); let dir = tempdir().expect("temp dir"); let local_path = dir.path().join("file_report.csv"); @@ -226,6 +322,7 @@ mod tests { turn_context.config = Arc::new(config); let rewritten = build_uploaded_local_argument_value( + &session, &turn_context, Some(&auth), "file", @@ -293,7 +390,7 @@ mod tests { .mount(&server) .await; - let (_, mut turn_context) = make_session_and_context().await; + let (session, mut turn_context) = make_session_and_context().await; let auth = CodexAuth::create_dummy_chatgpt_auth_for_testing(); let dir = tempdir().expect("temp dir"); let local_path = dir.path().join("file_report.csv"); @@ -306,6 +403,7 @@ mod tests { config.chatgpt_base_url = format!("{}/backend-api", server.uri()); turn_context.config = Arc::new(config); let rewritten = rewrite_argument_value_for_openai_files( + &session, &turn_context, Some(&auth), "file", @@ -405,7 +503,7 @@ mod tests { .mount(&server) .await; - let (_, mut turn_context) = make_session_and_context().await; + let (session, mut turn_context) = make_session_and_context().await; let auth = CodexAuth::create_dummy_chatgpt_auth_for_testing(); let dir = tempdir().expect("temp dir"); tokio::fs::write(dir.path().join("one.csv"), b"one") @@ -420,6 +518,7 @@ mod tests { config.chatgpt_base_url = format!("{}/backend-api", server.uri()); turn_context.config = Arc::new(config); let rewritten = rewrite_argument_value_for_openai_files( + &session, &turn_context, Some(&auth), "files", @@ -471,4 +570,88 @@ mod tests { assert!(error.contains("failed to upload")); assert!(error.contains("file")); } + + #[tokio::test] + async fn build_uploaded_local_argument_value_uses_agent_assertion_for_cached_task() { + use wiremock::Mock; + use wiremock::MockServer; + use wiremock::ResponseTemplate; + use wiremock::matchers::body_json; + use wiremock::matchers::header_regex; + use wiremock::matchers::method; + use wiremock::matchers::path; + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/backend-api/files")) + .and(header_regex("authorization", r"^AgentAssertion .+")) + .and(body_json(serde_json::json!({ + "file_name": "file_report.csv", + "file_size": 5, + "use_case": "codex", + }))) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "file_id": "file_123", + "upload_url": format!("{}/upload/file_123", server.uri()), + }))) + .expect(1) + .mount(&server) + .await; + Mock::given(method("PUT")) + .and(path("/upload/file_123")) + .respond_with(ResponseTemplate::new(200)) + .expect(1) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/backend-api/files/file_123/uploaded")) + .and(header_regex("authorization", r"^AgentAssertion .+")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "status": "success", + "download_url": format!("{}/download/file_123", server.uri()), + "file_name": "file_report.csv", + "mime_type": "text/csv", + "file_size_bytes": 5, + }))) + .expect(1) + .mount(&server) + .await; + + let (mut session, mut turn_context) = make_session_and_context().await; + let auth = CodexAuth::create_dummy_chatgpt_auth_for_testing(); + let dir = tempdir().expect("temp dir"); + let local_path = dir.path().join("file_report.csv"); + tokio::fs::write(&local_path, b"hello") + .await + .expect("write local file"); + turn_context.cwd = AbsolutePathBuf::try_from(dir.path()).expect("absolute path"); + + let mut config = (*turn_context.config).clone(); + config.chatgpt_base_url = format!("{}/backend-api", server.uri()); + turn_context.config = Arc::new(config); + install_cached_agent_task_auth(&mut session, &mut turn_context, server.uri()).await; + + let rewritten = build_uploaded_local_argument_value( + &session, + &turn_context, + Some(&auth), + "file", + /*index*/ None, + "file_report.csv", + ) + .await + .expect("rewrite should upload the local file"); + + assert_eq!( + rewritten, + serde_json::json!({ + "download_url": format!("{}/download/file_123", server.uri()), + "file_id": "file_123", + "mime_type": "text/csv", + "file_name": "file_report.csv", + "uri": "sediment://file_123", + "file_size_bytes": 5, + }) + ); + } } diff --git a/codex-rs/core/src/session/agent_task_lifecycle.rs b/codex-rs/core/src/session/agent_task_lifecycle.rs index 789386344a..888848741d 100644 --- a/codex-rs/core/src/session/agent_task_lifecycle.rs +++ b/codex-rs/core/src/session/agent_task_lifecycle.rs @@ -101,6 +101,11 @@ impl Session { agent_task } + #[cfg(test)] + pub(crate) async fn cache_agent_task_for_tests(&self, agent_task: RegisteredAgentTask) { + self.cache_agent_task(agent_task).await; + } + pub(super) async fn cached_agent_task_for_current_identity( &self, ) -> Option { @@ -134,6 +139,33 @@ impl Session { None } + pub(crate) async fn authorization_header_for_current_agent_task( + &self, + ) -> anyhow::Result> { + let Some(agent_task) = self.cached_agent_task_for_current_identity().await else { + return Ok(None); + }; + + let Some(auth) = self.services.auth_manager.auth().await else { + return Ok(None); + }; + let authorization_header_value = self + .services + .auth_manager + .chatgpt_agent_task_authorization_header_for_auth( + &auth, + agent_task.authorization_target(), + )?; + if authorization_header_value.is_some() { + debug!( + agent_runtime_id = %agent_task.agent_runtime_id, + task_id = %agent_task.task_id, + "using agent assertion authorization for current task request" + ); + } + Ok(authorization_header_value) + } + pub(super) async fn ensure_agent_task_registered( &self, ) -> anyhow::Result> { diff --git a/codex-rs/core/src/session/session.rs b/codex-rs/core/src/session/session.rs index 406cb1b270..b3e6802412 100644 --- a/codex-rs/core/src/session/session.rs +++ b/codex-rs/core/src/session/session.rs @@ -619,6 +619,11 @@ impl Session { config.analytics_enabled, ) }); + let agent_identity_manager = Arc::new(AgentIdentityManager::new( + config.as_ref(), + Arc::clone(&auth_manager), + session_configuration.session_source.clone(), + )); let services = SessionServices { // Initialize the MCP connection manager with an uninitialized // instance. It will be replaced with one created via @@ -641,11 +646,7 @@ impl Session { hooks, rollout: Mutex::new(rollout_recorder), user_shell: Arc::new(default_shell), - agent_identity_manager: Arc::new(AgentIdentityManager::new( - config.as_ref(), - Arc::clone(&auth_manager), - session_configuration.session_source.clone(), - )), + agent_identity_manager: Arc::clone(&agent_identity_manager), shell_snapshot_tx, show_raw_agent_reasoning: config.show_raw_agent_reasoning, exec_policy, diff --git a/codex-rs/core/src/session/turn.rs b/codex-rs/core/src/session/turn.rs index 33c8402a05..75d1bdea7f 100644 --- a/codex-rs/core/src/session/turn.rs +++ b/codex-rs/core/src/session/turn.rs @@ -333,20 +333,23 @@ pub(crate) async fn run_turn( })) .await; } - if let Err(error) = sess.ensure_agent_task_registered().await { - warn!(error = %error, "agent task registration failed"); - sess.send_event( - turn_context.as_ref(), - EventMsg::Error(ErrorEvent { - message: format!( - "Agent task registration failed. Please try again; Codex will attempt to register the task again on the next turn: {error}" - ), - codex_error_info: Some(CodexErrorInfo::Other), - }), - ) - .await; - return None; - } + let agent_task = match sess.ensure_agent_task_registered().await { + Ok(agent_task) => agent_task, + Err(error) => { + warn!(error = %error, "agent task registration failed"); + sess.send_event( + turn_context.as_ref(), + EventMsg::Error(ErrorEvent { + message: format!( + "Agent task registration failed. Please try again; Codex will attempt to register the task again on the next turn: {error}" + ), + codex_error_info: Some(CodexErrorInfo::Other), + }), + ) + .await; + return None; + } + }; if !skill_items.is_empty() { sess.record_conversation_items(&turn_context, &skill_items) @@ -371,8 +374,21 @@ pub(crate) async fn run_turn( // `ModelClientSession` is turn-scoped and caches WebSocket + sticky routing state, so we reuse // one instance across retries within this turn. - let mut client_session = - prewarmed_client_session.unwrap_or_else(|| sess.services.model_client.new_session()); + let mut prewarmed_client_session = prewarmed_client_session; + if agent_task.is_some() + && let Some(prewarmed_client_session) = prewarmed_client_session.as_mut() + { + prewarmed_client_session.disable_cached_websocket_session_on_drop(); + } + let mut client_session = if let Some(agent_task) = agent_task { + sess.services + .model_client + .new_session_with_agent_task(Some(agent_task)) + } else if let Some(prewarmed_client_session) = prewarmed_client_session.take() { + prewarmed_client_session + } else { + sess.services.model_client.new_session() + }; // Pending input is drained into history before building the next model request. // However, we defer that drain until after sampling in two cases: // 1. At the start of a turn, so the fresh user prompt in `input` gets sampled first. diff --git a/codex-rs/login/Cargo.toml b/codex-rs/login/Cargo.toml index d5303ea54c..16d1836024 100644 --- a/codex-rs/login/Cargo.toml +++ b/codex-rs/login/Cargo.toml @@ -8,6 +8,7 @@ license.workspace = true workspace = true [dependencies] +anyhow = { workspace = true } async-trait = { workspace = true } base64 = { workspace = true } chrono = { workspace = true, features = ["serde"] } @@ -20,6 +21,7 @@ codex-otel = { workspace = true } codex-protocol = { workspace = true } codex-terminal-detection = { workspace = true } codex-utils-template = { workspace = true } +ed25519-dalek = { workspace = true } once_cell = { workspace = true } os_info = { workspace = true } rand = { workspace = true } @@ -42,7 +44,6 @@ urlencoding = { workspace = true } webbrowser = { workspace = true } [dev-dependencies] -anyhow = { workspace = true } core_test_support = { workspace = true } keyring = { workspace = true } pretty_assertions = { workspace = true } diff --git a/codex-rs/login/src/auth/agent_assertion.rs b/codex-rs/login/src/auth/agent_assertion.rs new file mode 100644 index 0000000000..e4a1731f65 --- /dev/null +++ b/codex-rs/login/src/auth/agent_assertion.rs @@ -0,0 +1,172 @@ +use std::collections::BTreeMap; + +use anyhow::Context; +use anyhow::Result; +use base64::Engine as _; +use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; +use base64::engine::general_purpose::URL_SAFE_NO_PAD; +use chrono::SecondsFormat; +use chrono::Utc; +use ed25519_dalek::Signer as _; +use ed25519_dalek::SigningKey; +use ed25519_dalek::pkcs8::DecodePrivateKey; +use serde::Deserialize; +use serde::Serialize; + +use super::storage::AgentIdentityAuthRecord; + +/// Task binding to use when constructing a task-scoped AgentAssertion. +/// +/// The caller owns the task lifecycle. `AuthManager` only uses this target to +/// sign an authorization header with the stored agent identity key material. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct AgentTaskAuthorizationTarget<'a> { + pub agent_runtime_id: &'a str, + pub task_id: &'a str, +} + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +struct AgentAssertionEnvelope { + agent_runtime_id: String, + task_id: String, + timestamp: String, + signature: String, +} + +pub(super) fn authorization_header_for_agent_task( + record: &AgentIdentityAuthRecord, + target: AgentTaskAuthorizationTarget<'_>, +) -> Result { + anyhow::ensure!( + record.agent_runtime_id == target.agent_runtime_id, + "agent task runtime {} does not match stored agent identity {}", + target.agent_runtime_id, + record.agent_runtime_id + ); + + let timestamp = Utc::now().to_rfc3339_opts(SecondsFormat::Secs, true); + let envelope = AgentAssertionEnvelope { + agent_runtime_id: target.agent_runtime_id.to_string(), + task_id: target.task_id.to_string(), + timestamp: timestamp.clone(), + signature: sign_agent_assertion_payload(record, target, ×tamp)?, + }; + let serialized_assertion = serialize_agent_assertion(&envelope)?; + Ok(format!("AgentAssertion {serialized_assertion}")) +} + +fn sign_agent_assertion_payload( + record: &AgentIdentityAuthRecord, + target: AgentTaskAuthorizationTarget<'_>, + timestamp: &str, +) -> Result { + let signing_key = signing_key_from_agent_private_key(&record.agent_private_key)?; + let payload = format!("{}:{}:{timestamp}", target.agent_runtime_id, target.task_id); + Ok(BASE64_STANDARD.encode(signing_key.sign(payload.as_bytes()).to_bytes())) +} + +fn serialize_agent_assertion(envelope: &AgentAssertionEnvelope) -> Result { + let payload = serde_json::to_vec(&BTreeMap::from([ + ("agent_runtime_id", envelope.agent_runtime_id.as_str()), + ("signature", envelope.signature.as_str()), + ("task_id", envelope.task_id.as_str()), + ("timestamp", envelope.timestamp.as_str()), + ])) + .context("failed to serialize agent assertion envelope")?; + Ok(URL_SAFE_NO_PAD.encode(payload)) +} + +fn signing_key_from_agent_private_key(agent_private_key: &str) -> Result { + let private_key = BASE64_STANDARD + .decode(agent_private_key) + .context("stored agent identity private key is not valid base64")?; + SigningKey::from_pkcs8_der(&private_key) + .context("stored agent identity private key is not valid PKCS#8") +} + +#[cfg(test)] +mod tests { + use ed25519_dalek::Signature; + use ed25519_dalek::Verifier as _; + use ed25519_dalek::pkcs8::EncodePrivateKey; + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn authorization_header_for_agent_task_serializes_signed_agent_assertion() { + let record = test_agent_identity_record("agent-123"); + let target = AgentTaskAuthorizationTarget { + agent_runtime_id: "agent-123", + task_id: "task-123", + }; + + let header = authorization_header_for_agent_task(&record, target) + .expect("build agent assertion header"); + let token = header + .strip_prefix("AgentAssertion ") + .expect("agent assertion scheme"); + let payload = URL_SAFE_NO_PAD + .decode(token) + .expect("valid base64url payload"); + let envelope: AgentAssertionEnvelope = + serde_json::from_slice(&payload).expect("valid assertion envelope"); + + assert_eq!( + envelope, + AgentAssertionEnvelope { + agent_runtime_id: "agent-123".to_string(), + task_id: "task-123".to_string(), + timestamp: envelope.timestamp.clone(), + signature: envelope.signature.clone(), + } + ); + let signature_bytes = BASE64_STANDARD + .decode(&envelope.signature) + .expect("valid base64 signature"); + let signature = Signature::from_slice(&signature_bytes).expect("valid signature bytes"); + signing_key_from_agent_private_key(&record.agent_private_key) + .expect("signing key") + .verifying_key() + .verify( + format!( + "{}:{}:{}", + envelope.agent_runtime_id, envelope.task_id, envelope.timestamp + ) + .as_bytes(), + &signature, + ) + .expect("signature should verify"); + } + + #[test] + fn authorization_header_for_agent_task_rejects_mismatched_runtime() { + let record = test_agent_identity_record("agent-123"); + let target = AgentTaskAuthorizationTarget { + agent_runtime_id: "agent-456", + task_id: "task-123", + }; + + let error = authorization_header_for_agent_task(&record, target) + .expect_err("runtime mismatch should fail"); + + assert_eq!( + error.to_string(), + "agent task runtime agent-456 does not match stored agent identity agent-123" + ); + } + + fn test_agent_identity_record(agent_runtime_id: &str) -> AgentIdentityAuthRecord { + let signing_key = SigningKey::from_bytes(&[7u8; 32]); + let private_key = signing_key + .to_pkcs8_der() + .expect("encode test key material"); + AgentIdentityAuthRecord { + workspace_id: "account-123".to_string(), + chatgpt_user_id: Some("user-123".to_string()), + agent_runtime_id: agent_runtime_id.to_string(), + agent_private_key: BASE64_STANDARD.encode(private_key.as_bytes()), + registered_at: "2026-03-23T12:00:00Z".to_string(), + } + } +} diff --git a/codex-rs/login/src/auth/manager.rs b/codex-rs/login/src/auth/manager.rs index c67d73fedf..bcfec41979 100644 --- a/codex-rs/login/src/auth/manager.rs +++ b/codex-rs/login/src/auth/manager.rs @@ -1,3 +1,4 @@ +use anyhow::Context; use async_trait::async_trait; use chrono::Utc; use reqwest::StatusCode; @@ -20,6 +21,8 @@ use codex_app_server_protocol::AuthMode as ApiAuthMode; use codex_protocol::config_types::ForcedLoginMethod; use codex_protocol::config_types::ModelProviderAuthInfo; +use super::agent_assertion; +use super::agent_assertion::AgentTaskAuthorizationTarget; use super::external_bearer::BearerTokenRefresher; use super::revoke::revoke_auth_tokens; pub use crate::auth::storage::AgentIdentityAuthRecord; @@ -1483,6 +1486,43 @@ impl AuthManager { .and_then(|guard| guard.clone()) } + pub fn chatgpt_agent_task_authorization_header_for_auth( + &self, + auth: &CodexAuth, + target: AgentTaskAuthorizationTarget<'_>, + ) -> anyhow::Result> { + let Some(record) = self.agent_identity_for_chatgpt_auth(auth)? else { + return Ok(None); + }; + agent_assertion::authorization_header_for_agent_task(&record, target).map(Some) + } + + fn agent_identity_for_chatgpt_auth( + &self, + auth: &CodexAuth, + ) -> anyhow::Result> { + if !auth.is_chatgpt_auth() { + return Ok(None); + } + + let token_data = auth + .get_token_data() + .context("ChatGPT token data is not available")?; + let workspace_id = self + .forced_chatgpt_workspace_id() + .filter(|value| !value.is_empty()) + .or(token_data.account_id.filter(|value| !value.is_empty())); + + let Some(workspace_id) = workspace_id else { + return Ok(None); + }; + let Some(record) = auth.get_agent_identity(&workspace_id) else { + anyhow::bail!("agent identity is not available for workspace {workspace_id}"); + }; + + Ok(Some(record)) + } + pub fn subscribe_auth_state(&self) -> watch::Receiver<()> { self.auth_state_tx.subscribe() } diff --git a/codex-rs/login/src/auth/mod.rs b/codex-rs/login/src/auth/mod.rs index b927f9a775..62ab467d08 100644 --- a/codex-rs/login/src/auth/mod.rs +++ b/codex-rs/login/src/auth/mod.rs @@ -1,3 +1,4 @@ +mod agent_assertion; pub mod default_client; pub mod error; mod storage; @@ -7,6 +8,7 @@ mod external_bearer; mod manager; mod revoke; +pub use agent_assertion::AgentTaskAuthorizationTarget; pub use error::RefreshTokenFailedError; pub use error::RefreshTokenFailedReason; pub use manager::*; diff --git a/codex-rs/login/src/lib.rs b/codex-rs/login/src/lib.rs index d819b0946d..046d878e87 100644 --- a/codex-rs/login/src/lib.rs +++ b/codex-rs/login/src/lib.rs @@ -18,6 +18,7 @@ pub use server::ShutdownHandle; pub use server::run_login_server; pub use auth::AgentIdentityAuthRecord; +pub use auth::AgentTaskAuthorizationTarget; pub use auth::AuthConfig; pub use auth::AuthDotJson; pub use auth::AuthManager; diff --git a/codex-rs/model-provider/src/bearer_auth_provider.rs b/codex-rs/model-provider/src/bearer_auth_provider.rs index 5a24ca6f78..970574c752 100644 --- a/codex-rs/model-provider/src/bearer_auth_provider.rs +++ b/codex-rs/model-provider/src/bearer_auth_provider.rs @@ -38,6 +38,55 @@ impl AuthProvider for BearerAuthProvider { } } +/// Auth provider for callers that already resolved the complete Authorization header value. +#[derive(Clone, Default)] +pub struct AuthorizationHeaderAuthProvider { + pub authorization_header_value: Option, + pub account_id: Option, + pub is_fedramp_account: bool, +} + +impl AuthorizationHeaderAuthProvider { + pub fn new(authorization_header_value: Option, account_id: Option) -> Self { + Self { + authorization_header_value, + account_id, + is_fedramp_account: false, + } + } + + pub fn for_test(authorization_header_value: Option<&str>, account_id: Option<&str>) -> Self { + Self { + authorization_header_value: authorization_header_value.map(str::to_string), + account_id: account_id.map(str::to_string), + is_fedramp_account: false, + } + } + + pub fn with_fedramp_routing_header(mut self) -> Self { + self.is_fedramp_account = true; + self + } +} + +impl AuthProvider for AuthorizationHeaderAuthProvider { + fn add_auth_headers(&self, headers: &mut HeaderMap) { + if let Some(authorization_header_value) = self.authorization_header_value.as_ref() + && let Ok(header) = HeaderValue::from_str(authorization_header_value) + { + let _ = headers.insert(http::header::AUTHORIZATION, header); + } + if let Some(account_id) = self.account_id.as_ref() + && let Ok(header) = HeaderValue::from_str(account_id) + { + let _ = headers.insert("ChatGPT-Account-ID", header); + } + if self.is_fedramp_account { + let _ = headers.insert("X-OpenAI-Fedramp", HeaderValue::from_static("true")); + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -99,4 +148,54 @@ mod tests { Some("true") ); } + + #[test] + fn authorization_header_auth_provider_supports_non_bearer_authorization_headers() { + let auth = AuthorizationHeaderAuthProvider::for_test( + Some("AgentAssertion opaque-token"), + Some("workspace-123"), + ); + let mut headers = HeaderMap::new(); + + auth.add_auth_headers(&mut headers); + + assert_eq!( + headers + .get(http::header::AUTHORIZATION) + .and_then(|value| value.to_str().ok()), + Some("AgentAssertion opaque-token") + ); + assert_eq!( + headers + .get("ChatGPT-Account-ID") + .and_then(|value| value.to_str().ok()), + Some("workspace-123") + ); + assert_eq!( + codex_api::auth_header_telemetry(&auth), + codex_api::AuthHeaderTelemetry { + attached: true, + name: Some("authorization"), + } + ); + } + + #[test] + fn authorization_header_auth_provider_adds_fedramp_routing_header_when_enabled() { + let auth = AuthorizationHeaderAuthProvider::for_test( + Some("AgentAssertion opaque-token"), + Some("workspace-123"), + ) + .with_fedramp_routing_header(); + let mut headers = HeaderMap::new(); + + auth.add_auth_headers(&mut headers); + + assert_eq!( + headers + .get("X-OpenAI-Fedramp") + .and_then(|value| value.to_str().ok()), + Some("true") + ); + } } diff --git a/codex-rs/model-provider/src/lib.rs b/codex-rs/model-provider/src/lib.rs index f240c47db0..1874f37e31 100644 --- a/codex-rs/model-provider/src/lib.rs +++ b/codex-rs/model-provider/src/lib.rs @@ -2,6 +2,7 @@ mod auth; mod bearer_auth_provider; mod provider; +pub use bearer_auth_provider::AuthorizationHeaderAuthProvider; pub use bearer_auth_provider::BearerAuthProvider; pub use bearer_auth_provider::BearerAuthProvider as CoreAuthProvider; pub use provider::ModelProvider;