From ede0da1cf24e98437b18ab489f3af925375f0739 Mon Sep 17 00:00:00 2001 From: adrian Date: Wed, 22 Apr 2026 16:18:44 -0700 Subject: [PATCH] feat: use thread agent task auth for inference --- codex-rs/core/src/agent/control.rs | 4 +- codex-rs/core/src/client.rs | 65 +++- codex-rs/core/src/client_tests.rs | 29 +- codex-rs/core/src/session/session.rs | 12 +- codex-rs/login/src/auth/agent_identity.rs | 275 +++++++++++++++-- codex-rs/login/src/auth/auth_tests.rs | 9 +- .../model-provider/src/amazon_bedrock/mod.rs | 5 + codex-rs/model-provider/src/auth.rs | 281 ++++++++++++++++-- codex-rs/model-provider/src/lib.rs | 5 + .../model-provider/src/models_endpoint.rs | 9 +- codex-rs/model-provider/src/provider.rs | 12 +- 11 files changed, 653 insertions(+), 53 deletions(-) diff --git a/codex-rs/core/src/agent/control.rs b/codex-rs/core/src/agent/control.rs index 71bc026a13..3584bf41c3 100644 --- a/codex-rs/core/src/agent/control.rs +++ b/codex-rs/core/src/agent/control.rs @@ -722,7 +722,9 @@ impl AgentControl { } else { state.send_op(agent_id, Op::Shutdown {}).await }; - thread.wait_until_terminated().await; + if result.is_ok() || matches!(result, Err(CodexErr::InternalAgentDied)) { + thread.wait_until_terminated().await; + } result } else { state.send_op(agent_id, Op::Shutdown {}).await diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index f604a63458..19abe88944 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -116,8 +116,11 @@ use crate::util::emit_feedback_auth_recovery_tags; use codex_api::map_api_error; use codex_feedback::FeedbackRequestTags; use codex_feedback::emit_feedback_request_tags_with_auth_env; +use codex_login::auth::AgentIdentityAuthPolicy; use codex_login::auth_env_telemetry::AuthEnvTelemetry; use codex_login::auth_env_telemetry::collect_auth_env_telemetry; +use codex_model_provider::AgentTaskExternalRef; +use codex_model_provider::ProviderAuthScope; use codex_model_provider::SharedModelProvider; use codex_model_provider::create_model_provider; #[cfg(test)] @@ -170,6 +173,8 @@ struct ModelClientState { provider: SharedModelProvider, auth_env_telemetry: AuthEnvTelemetry, session_source: SessionSource, + agent_identity_policy: AgentIdentityAuthPolicy, + chatgpt_base_url: Option, model_verbosity: Option, enable_request_compression: bool, include_timing_metrics: bool, @@ -321,6 +326,39 @@ impl ModelClient { include_timing_metrics: bool, beta_features_header: Option, attestation_provider: Option>, + ) -> Self { + Self::new_with_agent_identity_policy( + auth_manager, + session_id, + thread_id, + installation_id, + provider_info, + session_source, + AgentIdentityAuthPolicy::JwtOnly, + /*chatgpt_base_url*/ None, + model_verbosity, + enable_request_compression, + include_timing_metrics, + beta_features_header, + attestation_provider, + ) + } + + #[allow(clippy::too_many_arguments)] + pub fn new_with_agent_identity_policy( + auth_manager: Option>, + session_id: SessionId, + thread_id: ThreadId, + installation_id: String, + provider_info: ModelProviderInfo, + session_source: SessionSource, + agent_identity_policy: AgentIdentityAuthPolicy, + chatgpt_base_url: Option, + model_verbosity: Option, + enable_request_compression: bool, + include_timing_metrics: bool, + beta_features_header: Option, + attestation_provider: Option>, ) -> Self { let model_provider = create_model_provider(provider_info, auth_manager); let codex_api_key_env_enabled = model_provider @@ -339,6 +377,8 @@ impl ModelClient { provider: model_provider, auth_env_telemetry, session_source, + agent_identity_policy, + chatgpt_base_url, model_verbosity, enable_request_compression, include_timing_metrics, @@ -785,7 +825,11 @@ impl ModelClient { async fn current_client_setup(&self) -> Result { let auth = self.state.provider.auth().await; let api_provider = self.state.provider.api_provider().await?; - let api_auth = self.state.provider.api_auth().await?; + let api_auth = self + .state + .provider + .api_auth_for_scope(self.provider_auth_scope()) + .await?; Ok(CurrentClientSetup { auth, api_provider, @@ -793,6 +837,15 @@ impl ModelClient { }) } + fn provider_auth_scope(&self) -> ProviderAuthScope { + ProviderAuthScope::Thread { + external_ref: AgentTaskExternalRef::new(self.state.thread_id.to_string()), + agent_identity_policy: self.state.agent_identity_policy, + session_source: self.state.session_source.clone(), + chatgpt_base_url: self.state.chatgpt_base_url.clone(), + } + } + /// Opens a websocket connection using the same header and telemetry wiring as normal turns. /// /// Both startup prewarm and in-turn `needs_new` reconnects call this path so handshake @@ -931,6 +984,10 @@ impl Drop for ModelClientSession { } impl ModelClientSession { + async fn current_client_setup(&self) -> Result { + self.client.current_client_setup().await + } + pub(crate) fn reset_websocket_session(&mut self) { self.websocket_session.connection = None; self.websocket_session.last_request = None; @@ -1077,7 +1134,7 @@ impl ModelClientSession { return Ok(()); } - let client_setup = self.client.current_client_setup().await.map_err(|err| { + let client_setup = self.current_client_setup().await.map_err(|err| { ApiError::Stream(format!( "failed to build websocket prewarm client setup: {err}" )) @@ -1224,7 +1281,7 @@ impl ModelClientSession { .map(AuthManager::unauthorized_recovery); let mut pending_retry = PendingUnauthorizedRetry::default(); loop { - let client_setup = self.client.current_client_setup().await?; + let client_setup = self.current_client_setup().await?; let transport = ReqwestTransport::new(build_reqwest_client()); let request_auth_context = AuthRequestTelemetryContext::new( client_setup.auth.as_ref().map(CodexAuth::auth_mode), @@ -1340,7 +1397,7 @@ impl ModelClientSession { .map(AuthManager::unauthorized_recovery); let mut pending_retry = PendingUnauthorizedRetry::default(); loop { - let client_setup = self.client.current_client_setup().await?; + let client_setup = self.current_client_setup().await?; let request_auth_context = AuthRequestTelemetryContext::new( client_setup.auth.as_ref().map(CodexAuth::auth_mode), client_setup.api_auth.as_ref(), diff --git a/codex-rs/core/src/client_tests.rs b/codex-rs/core/src/client_tests.rs index b9d9172c83..711939f3ec 100644 --- a/codex-rs/core/src/client_tests.rs +++ b/codex-rs/core/src/client_tests.rs @@ -15,7 +15,10 @@ use codex_api::ResponseEvent; use codex_app_server_protocol::AuthMode; use codex_login::AuthManager; use codex_login::CodexAuth; +use codex_login::auth::AgentIdentityAuthPolicy; +use codex_model_provider::AgentTaskExternalRef; use codex_model_provider::BearerAuthProvider; +use codex_model_provider::ProviderAuthScope; use codex_model_provider_info::CHATGPT_CODEX_BASE_URL; use codex_model_provider_info::ModelProviderInfo; use codex_model_provider_info::WireApi; @@ -61,8 +64,15 @@ use tracing_subscriber::registry::LookupSpan; use tracing_subscriber::util::SubscriberInitExt; fn test_model_client(session_source: SessionSource) -> ModelClient { + test_model_client_with_thread_id(ThreadId::new(), session_source) +} + +fn test_model_client_with_thread_id( + conversation_id: ThreadId, + session_source: SessionSource, +) -> ModelClient { let provider = create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses); - let thread_id = ThreadId::new(); + let thread_id = conversation_id; ModelClient::new( /*auth_manager*/ None, thread_id.into(), @@ -78,6 +88,23 @@ fn test_model_client(session_source: SessionSource) -> ModelClient { ) } +#[test] +fn provider_auth_scope_uses_thread_id_as_session_ref() { + let conversation_id = + ThreadId::from_string("018f4f4c-43f5-7b28-8e24-000000000001").expect("valid thread id"); + let client = test_model_client_with_thread_id(conversation_id, SessionSource::Cli); + + assert_eq!( + client.provider_auth_scope(), + ProviderAuthScope::Thread { + external_ref: AgentTaskExternalRef::new(conversation_id.to_string()), + agent_identity_policy: AgentIdentityAuthPolicy::JwtOnly, + session_source: SessionSource::Cli, + chatgpt_base_url: None, + } + ); +} + fn test_model_info() -> ModelInfo { serde_json::from_value(json!({ "slug": "gpt-test", diff --git a/codex-rs/core/src/session/session.rs b/codex-rs/core/src/session/session.rs index 925ee2df59..41672df411 100644 --- a/codex-rs/core/src/session/session.rs +++ b/codex-rs/core/src/session/session.rs @@ -928,13 +928,19 @@ impl Session { live_thread: live_thread_init.as_ref().cloned(), thread_store: Arc::clone(&thread_store), attestation_provider: attestation_provider.clone(), - model_client: ModelClient::new( + model_client: ModelClient::new_with_agent_identity_policy( Some(Arc::clone(&auth_manager)), session_id, thread_id, installation_id.clone(), session_configuration.provider.clone(), session_configuration.session_source.clone(), + if config.features.enabled(Feature::UseAgentIdentity) { + codex_login::auth::AgentIdentityAuthPolicy::JwtOrChatgpt + } else { + codex_login::auth::AgentIdentityAuthPolicy::JwtOnly + }, + Some(config.chatgpt_base_url.clone()), config.model_verbosity, config.features.enabled(Feature::EnableRequestCompression), config.features.enabled(Feature::RuntimeMetrics), @@ -1119,8 +1125,6 @@ impl Session { anyhow::bail!("required MCP servers failed to initialize: {details}"); } } - sess.schedule_startup_prewarm(session_configuration.base_instructions.clone()) - .await; let session_start_source = match &initial_history { InitialHistory::Resumed(_) => codex_hooks::SessionStartSource::Resume, InitialHistory::New | InitialHistory::Forked(_) => { @@ -1131,6 +1135,8 @@ impl Session { // record_initial_history can emit events. We record only after the SessionConfiguredEvent is emitted. sess.record_initial_history(initial_history).await; + sess.schedule_startup_prewarm(session_configuration.base_instructions.clone()) + .await; { let mut state = sess.state.lock().await; state.set_pending_session_start_source(Some(session_start_source)); diff --git a/codex-rs/login/src/auth/agent_identity.rs b/codex-rs/login/src/auth/agent_identity.rs index 3f55aaf901..1181a926c5 100644 --- a/codex-rs/login/src/auth/agent_identity.rs +++ b/codex-rs/login/src/auth/agent_identity.rs @@ -1,8 +1,15 @@ +use std::collections::HashMap; use std::sync::Arc; +use std::sync::Mutex; use codex_agent_identity::AgentIdentityKey; +use codex_agent_identity::AgentRuntimeId; +use codex_agent_identity::AgentTaskExternalRef; +use codex_agent_identity::AgentTaskId; +use codex_agent_identity::AgentTaskKind; +use codex_agent_identity::RegisteredAgentTask; use codex_agent_identity::normalize_chatgpt_base_url; -use codex_agent_identity::register_agent_task; +use codex_agent_identity::register_agent_task_with_external_ref; use codex_protocol::account::PlanType as AccountPlanType; use tokio::sync::OnceCell; @@ -15,14 +22,20 @@ const DEFAULT_CHATGPT_BACKEND_BASE_URL: &str = "https://chatgpt.com/backend-api" #[derive(Debug)] pub struct AgentIdentityAuth { record: AgentIdentityAuthRecord, - process_task_id: Arc>, + task_ids: Arc>>>>, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +enum AgentTaskCacheKey { + Process, + Thread(AgentTaskExternalRef), } impl Clone for AgentIdentityAuth { fn clone(&self) -> Self { Self { record: self.record.clone(), - process_task_id: Arc::clone(&self.process_task_id), + task_ids: Arc::clone(&self.task_ids), } } } @@ -31,7 +44,7 @@ impl AgentIdentityAuth { pub fn new(record: AgentIdentityAuthRecord) -> Self { Self { record, - process_task_id: Arc::new(OnceCell::new()), + task_ids: Arc::new(Mutex::new(HashMap::new())), } } @@ -39,24 +52,58 @@ impl AgentIdentityAuth { &self.record } - pub fn process_task_id(&self) -> Option<&str> { - self.process_task_id.get().map(String::as_str) + pub fn process_task_id(&self) -> Option { + self.task_id_for_initialized_key(&AgentTaskCacheKey::Process) } pub async fn ensure_runtime(&self, chatgpt_base_url: Option) -> std::io::Result<()> { - self.process_task_id - .get_or_try_init(|| async { - let base_url = normalize_chatgpt_base_url( - chatgpt_base_url - .as_deref() - .unwrap_or(DEFAULT_CHATGPT_BACKEND_BASE_URL), - ); - register_agent_task(&build_reqwest_client(), &base_url, self.key()) - .await - .map_err(std::io::Error::other) - }) + self.task_id_for_key( + AgentTaskCacheKey::Process, + chatgpt_base_url, + /*external_ref*/ None, + ) + .await + .map(|_| ()) + } + + pub async fn registered_thread_task( + &self, + external_ref: AgentTaskExternalRef, + chatgpt_base_url: Option, + ) -> std::io::Result { + let task_id = self + .task_id_for_key( + AgentTaskCacheKey::Thread(external_ref.clone()), + chatgpt_base_url, + Some(external_ref), + ) + .await?; + Ok(self.registered_task(task_id, AgentTaskKind::Thread)) + } + + pub async fn register_task(&self, chatgpt_base_url: Option) -> std::io::Result { + self.register_task_with_external_ref(chatgpt_base_url, /*external_ref*/ None) .await - .map(|_| ()) + } + + async fn register_task_with_external_ref( + &self, + chatgpt_base_url: Option, + external_ref: Option<&AgentTaskExternalRef>, + ) -> std::io::Result { + let base_url = normalize_chatgpt_base_url( + chatgpt_base_url + .as_deref() + .unwrap_or(DEFAULT_CHATGPT_BACKEND_BASE_URL), + ); + register_agent_task_with_external_ref( + &build_reqwest_client(), + &base_url, + self.key(), + external_ref, + ) + .await + .map_err(std::io::Error::other) } pub fn account_id(&self) -> &str { @@ -84,4 +131,196 @@ impl AgentIdentityAuth { private_key_pkcs8_base64: &self.record.agent_private_key, } } + + async fn task_id_for_key( + &self, + key: AgentTaskCacheKey, + chatgpt_base_url: Option, + external_ref: Option, + ) -> std::io::Result { + let slot = self.task_slot(key)?; + slot.get_or_try_init(|| async { + self.register_task_with_external_ref(chatgpt_base_url, external_ref.as_ref()) + .await + }) + .await + .cloned() + } + + fn task_slot(&self, key: AgentTaskCacheKey) -> std::io::Result>> { + let mut task_ids = self + .task_ids + .lock() + .map_err(|_| std::io::Error::other("failed to lock agent task cache"))?; + Ok(task_ids + .entry(key) + .or_insert_with(|| Arc::new(OnceCell::new())) + .clone()) + } + + fn task_id_for_initialized_key(&self, key: &AgentTaskCacheKey) -> Option { + let task_ids = self.task_ids.lock().ok()?; + task_ids.get(key)?.get().cloned() + } + + fn registered_task(&self, task_id: String, kind: AgentTaskKind) -> RegisteredAgentTask { + RegisteredAgentTask::new( + AgentRuntimeId::new(self.record.agent_runtime_id.clone()), + AgentTaskId::new(task_id), + kind, + ) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use std::sync::atomic::AtomicUsize; + use std::sync::atomic::Ordering; + + use codex_agent_identity::generate_agent_key_material; + use pretty_assertions::assert_eq; + use serde_json::json; + use wiremock::Mock; + use wiremock::MockServer; + use wiremock::ResponseTemplate; + use wiremock::matchers::body_partial_json; + use wiremock::matchers::method; + use wiremock::matchers::path; + + use super::*; + + fn agent_identity_record(private_key: String) -> AgentIdentityAuthRecord { + AgentIdentityAuthRecord { + agent_runtime_id: "agent-runtime-1".to_string(), + agent_private_key: private_key, + account_id: "account-1".to_string(), + chatgpt_user_id: "user-1".to_string(), + email: "agent@example.com".to_string(), + plan_type: AccountPlanType::Plus, + chatgpt_account_is_fedramp: false, + registered_at: None, + } + } + + fn agent_identity_auth() -> AgentIdentityAuth { + let key_material = generate_agent_key_material().expect("generate key material"); + AgentIdentityAuth::new(agent_identity_record(key_material.private_key_pkcs8_base64)) + } + + #[tokio::test] + async fn registered_thread_task_registers_once_per_external_ref() -> anyhow::Result<()> { + let auth = agent_identity_auth(); + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/agent/agent-runtime-1/task/register")) + .and(body_partial_json(json!({ + "external_task_ref": "thread-1", + }))) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "task_id": "task-thread-1", + }))) + .expect(1) + .mount(&server) + .await; + + let first = auth + .registered_thread_task(AgentTaskExternalRef::new("thread-1"), Some(server.uri())) + .await?; + let second = auth + .registered_thread_task(AgentTaskExternalRef::new("thread-1"), Some(server.uri())) + .await?; + + assert_eq!(first, second); + assert_eq!( + first, + RegisteredAgentTask::new( + AgentRuntimeId::new("agent-runtime-1"), + AgentTaskId::new("task-thread-1"), + AgentTaskKind::Thread, + ) + ); + Ok(()) + } + + #[tokio::test] + async fn registered_thread_task_uses_distinct_external_refs() -> anyhow::Result<()> { + let auth = agent_identity_auth(); + let server = MockServer::start().await; + for (external_ref, task_id) in + [("thread-1", "task-thread-1"), ("thread-2", "task-thread-2")] + { + Mock::given(method("POST")) + .and(path("/v1/agent/agent-runtime-1/task/register")) + .and(body_partial_json(json!({ + "external_task_ref": external_ref, + }))) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "task_id": task_id, + }))) + .expect(1) + .mount(&server) + .await; + } + + let first = auth + .registered_thread_task(AgentTaskExternalRef::new("thread-1"), Some(server.uri())) + .await?; + let second = auth + .registered_thread_task(AgentTaskExternalRef::new("thread-2"), Some(server.uri())) + .await?; + + assert_eq!(first.task_id.as_str(), "task-thread-1"); + assert_eq!(second.task_id.as_str(), "task-thread-2"); + Ok(()) + } + + #[tokio::test] + async fn failed_thread_task_registration_can_retry() -> anyhow::Result<()> { + let auth = agent_identity_auth(); + let server = MockServer::start().await; + let request_count = Arc::new(AtomicUsize::new(0)); + let response_count = Arc::clone(&request_count); + Mock::given(method("POST")) + .and(path("/v1/agent/agent-runtime-1/task/register")) + .and(body_partial_json(json!({ + "external_task_ref": "thread-1", + }))) + .respond_with(move |_request: &wiremock::Request| { + if response_count.fetch_add(1, Ordering::SeqCst) == 0 { + ResponseTemplate::new(500) + } else { + ResponseTemplate::new(200).set_body_json(json!({ + "task_id": "task-thread-1", + })) + } + }) + .expect(2) + .mount(&server) + .await; + + auth.registered_thread_task(AgentTaskExternalRef::new("thread-1"), Some(server.uri())) + .await + .expect_err("first registration should fail"); + let task = auth + .registered_thread_task(AgentTaskExternalRef::new("thread-1"), Some(server.uri())) + .await?; + + assert_eq!(request_count.load(Ordering::SeqCst), 2); + assert_eq!(task.task_id.as_str(), "task-thread-1"); + Ok(()) + } + + #[test] + fn task_slots_are_shared_across_clones() { + let auth = agent_identity_auth(); + let cloned = auth.clone(); + let slot = auth + .task_slot(AgentTaskCacheKey::Process) + .expect("task slot should be available"); + slot.set("process-task-1".to_string()) + .expect("process task should be unset"); + + assert_eq!(cloned.process_task_id(), Some("process-task-1".to_string())); + } } diff --git a/codex-rs/login/src/auth/auth_tests.rs b/codex-rs/login/src/auth/auth_tests.rs index 9de20b57b4..8ad1a88193 100644 --- a/codex-rs/login/src/auth/auth_tests.rs +++ b/codex-rs/login/src/auth/auth_tests.rs @@ -233,8 +233,8 @@ async fn chatgpt_auth_registers_agent_identity_when_enabled() -> anyhow::Result< agent_auth.record().agent_runtime_id, reused.record().agent_runtime_id ); - assert_eq!(agent_auth.process_task_id(), Some("task-123")); - assert_eq!(reused.process_task_id(), Some("task-123")); + assert_eq!(agent_auth.process_task_id(), Some("task-123".to_string())); + assert_eq!(reused.process_task_id(), Some("task-123".to_string())); assert_eq!(agent_auth.record().agent_runtime_id, "agent-runtime-123"); assert_eq!(agent_auth.record().account_id, "account-123"); assert_eq!(agent_auth.record().chatgpt_user_id, "user-12345"); @@ -853,7 +853,10 @@ async fn load_auth_reads_access_token_from_env() { panic!("env auth should load as agent identity"); }; assert_eq!(agent_identity.record(), &expected_record); - assert_eq!(agent_identity.process_task_id(), Some("task-123")); + assert_eq!( + agent_identity.process_task_id(), + Some("task-123".to_string()) + ); assert!( !get_auth_file(codex_home.path()).exists(), "env auth should not write auth.json" diff --git a/codex-rs/model-provider/src/amazon_bedrock/mod.rs b/codex-rs/model-provider/src/amazon_bedrock/mod.rs index 3940f73fcd..c1c47b4969 100644 --- a/codex-rs/model-provider/src/amazon_bedrock/mod.rs +++ b/codex-rs/model-provider/src/amazon_bedrock/mod.rs @@ -18,6 +18,7 @@ use codex_protocol::account::ProviderAccount; use codex_protocol::error::Result; use codex_protocol::openai_models::ModelsResponse; +use crate::auth::ProviderAuthScope; use crate::provider::ModelProvider; use crate::provider::ProviderAccountResult; use crate::provider::ProviderAccountState; @@ -96,6 +97,10 @@ impl ModelProvider for AmazonBedrockModelProvider { resolve_provider_auth(&self.aws).await } + async fn api_auth_for_scope(&self, _scope: ProviderAuthScope) -> Result { + resolve_provider_auth(&self.aws).await + } + fn models_manager( &self, _codex_home: PathBuf, diff --git a/codex-rs/model-provider/src/auth.rs b/codex-rs/model-provider/src/auth.rs index 2edc052dad..37cea5d878 100644 --- a/codex-rs/model-provider/src/auth.rs +++ b/codex-rs/model-provider/src/auth.rs @@ -2,39 +2,74 @@ use std::sync::Arc; use codex_agent_identity::AgentIdentityKey; use codex_agent_identity::AgentTaskAuthorizationTarget; +use codex_agent_identity::AgentTaskExternalRef; +use codex_agent_identity::RegisteredAgentTask; use codex_agent_identity::authorization_header_for_agent_task; +use codex_agent_identity::authorization_header_for_registered_task; use codex_api::AuthProvider; use codex_api::SharedAuthProvider; use codex_login::AuthManager; use codex_login::CodexAuth; +use codex_login::auth::AgentIdentityAuth; +use codex_login::auth::AgentIdentityAuthPolicy; use codex_model_provider_info::ModelProviderInfo; +use codex_protocol::protocol::SessionSource; use http::HeaderMap; use http::HeaderValue; use crate::bearer_auth_provider::BearerAuthProvider; +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ProviderAuthScope { + /// Use the provider's default auth. Agent Identity auth uses its process task here. + UnscopedProcess, + /// Use a task-scoped Agent Assertion for work tied to a Codex thread. + Thread { + external_ref: AgentTaskExternalRef, + agent_identity_policy: AgentIdentityAuthPolicy, + session_source: SessionSource, + chatgpt_base_url: Option, + }, +} + #[derive(Clone, Debug)] struct AgentIdentityAuthProvider { - auth: codex_login::auth::AgentIdentityAuth, + auth: AgentIdentityAuth, + task: Option, } impl AuthProvider for AgentIdentityAuthProvider { fn add_auth_headers(&self, headers: &mut HeaderMap) { let record = self.auth.record(); - let Some(task_id) = self.auth.process_task_id() else { - return; + let header_value = match self.task.as_ref() { + Some(task) => authorization_header_for_registered_task( + AgentIdentityKey { + agent_runtime_id: &record.agent_runtime_id, + private_key_pkcs8_base64: &record.agent_private_key, + }, + task, + ) + .map_err(std::io::Error::other), + None => self + .auth + .process_task_id() + .ok_or_else(|| { + std::io::Error::other("agent identity process task is not initialized") + }) + .and_then(|task_id| { + authorization_header_for_agent_task( + AgentIdentityKey { + agent_runtime_id: &record.agent_runtime_id, + private_key_pkcs8_base64: &record.agent_private_key, + }, + AgentTaskAuthorizationTarget { + agent_runtime_id: &record.agent_runtime_id, + task_id: &task_id, + }, + ) + .map_err(std::io::Error::other) + }), }; - let header_value = authorization_header_for_agent_task( - AgentIdentityKey { - agent_runtime_id: &record.agent_runtime_id, - private_key_pkcs8_base64: &record.agent_private_key, - }, - AgentTaskAuthorizationTarget { - agent_runtime_id: &record.agent_runtime_id, - task_id, - }, - ) - .map_err(std::io::Error::other); if let Ok(header_value) = header_value && let Ok(header) = HeaderValue::from_str(&header_value) @@ -78,20 +113,61 @@ pub(crate) fn auth_manager_for_provider( } } -pub(crate) fn resolve_provider_auth( +pub(crate) async fn resolve_provider_auth( + auth_manager: Option>, auth: Option<&CodexAuth>, provider: &ModelProviderInfo, + scope: ProviderAuthScope, ) -> codex_protocol::error::Result { if let Some(auth) = bearer_auth_for_provider(provider)? { return Ok(Arc::new(auth)); } + if provider_uses_first_party_auth_path(provider) + && let ProviderAuthScope::Thread { + external_ref, + agent_identity_policy, + session_source, + chatgpt_base_url, + } = scope + && let Some(agent_identity_auth) = + agent_identity_auth_for_scope(auth_manager, auth, agent_identity_policy, session_source) + .await? + { + let task = agent_identity_auth + .registered_thread_task(external_ref, chatgpt_base_url) + .await?; + return Ok(auth_provider_from_agent_task(agent_identity_auth, task)); + } + Ok(match auth { Some(auth) => auth_provider_from_auth(auth), None => unauthenticated_auth_provider(), }) } +async fn agent_identity_auth_for_scope( + auth_manager: Option>, + auth: Option<&CodexAuth>, + policy: AgentIdentityAuthPolicy, + session_source: SessionSource, +) -> codex_protocol::error::Result> { + if let Some(auth_manager) = auth_manager { + return auth_manager + .agent_identity_auth(policy, session_source) + .await + .map_err(Into::into); + } + + Ok(match auth { + Some(CodexAuth::AgentIdentity(auth)) => Some(auth.clone()), + Some(CodexAuth::ApiKey(_)) + | Some(CodexAuth::Chatgpt(_)) + | Some(CodexAuth::ChatgptAuthTokens(_)) + | None => None, + }) +} + fn bearer_auth_for_provider( provider: &ModelProviderInfo, ) -> codex_protocol::error::Result> { @@ -106,12 +182,21 @@ fn bearer_auth_for_provider( Ok(None) } +pub fn provider_uses_first_party_auth_path(provider: &ModelProviderInfo) -> bool { + provider.requires_openai_auth + && provider.env_key.is_none() + && provider.experimental_bearer_token.is_none() + && provider.auth.is_none() + && provider.aws.is_none() +} + /// Builds request-header auth for a first-party Codex auth snapshot. pub fn auth_provider_from_auth(auth: &CodexAuth) -> SharedAuthProvider { match auth { - CodexAuth::AgentIdentity(auth) => { - Arc::new(AgentIdentityAuthProvider { auth: auth.clone() }) - } + CodexAuth::AgentIdentity(auth) => Arc::new(AgentIdentityAuthProvider { + auth: auth.clone(), + task: None, + }), CodexAuth::ApiKey(_) | CodexAuth::Chatgpt(_) | CodexAuth::ChatgptAuthTokens(_) => { Arc::new(BearerAuthProvider { token: auth.get_token().ok(), @@ -122,19 +207,173 @@ pub fn auth_provider_from_auth(auth: &CodexAuth) -> SharedAuthProvider { } } +pub fn auth_provider_from_agent_task( + auth: AgentIdentityAuth, + task: RegisteredAgentTask, +) -> SharedAuthProvider { + Arc::new(AgentIdentityAuthProvider { + auth, + task: Some(task), + }) +} + #[cfg(test)] mod tests { + use codex_agent_identity::AgentRuntimeId; + use codex_agent_identity::AgentTaskId; + use codex_agent_identity::AgentTaskKind; + use codex_agent_identity::generate_agent_key_material; + use codex_login::auth::AgentIdentityAuthRecord; use codex_model_provider_info::WireApi; use codex_model_provider_info::create_oss_provider_with_base_url; + use codex_protocol::account::PlanType; + use pretty_assertions::assert_eq; + use serde_json::json; + use wiremock::Mock; + use wiremock::MockServer; + use wiremock::ResponseTemplate; + use wiremock::matchers::body_partial_json; + use wiremock::matchers::method; + use wiremock::matchers::path; use super::*; - #[test] - fn unauthenticated_auth_provider_adds_no_headers() { + fn agent_identity_auth(chatgpt_account_is_fedramp: bool) -> AgentIdentityAuth { + let key_material = generate_agent_key_material().expect("generate key material"); + AgentIdentityAuth::new(AgentIdentityAuthRecord { + agent_runtime_id: "agent-runtime-1".to_string(), + agent_private_key: key_material.private_key_pkcs8_base64, + account_id: "account-1".to_string(), + chatgpt_user_id: "user-1".to_string(), + email: "agent@example.com".to_string(), + plan_type: PlanType::Plus, + chatgpt_account_is_fedramp, + registered_at: None, + }) + } + + #[tokio::test] + async fn unauthenticated_auth_provider_adds_no_headers() { let provider = create_oss_provider_with_base_url("http://localhost:11434/v1", WireApi::Responses); - let auth = resolve_provider_auth(/*auth*/ None, &provider).expect("auth should resolve"); + let auth = resolve_provider_auth( + /*auth_manager*/ None, + /*auth*/ None, + &provider, + ProviderAuthScope::UnscopedProcess, + ) + .await + .expect("auth should resolve"); assert!(auth.to_auth_headers().is_empty()); } + + #[tokio::test] + async fn first_party_thread_scope_uses_agent_assertion() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/agent/agent-runtime-1/task/register")) + .and(body_partial_json(json!({ + "external_task_ref": "thread-1", + }))) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "task_id": "task-thread-1", + }))) + .expect(1) + .mount(&server) + .await; + let auth = CodexAuth::AgentIdentity(agent_identity_auth( + /*chatgpt_account_is_fedramp*/ false, + )); + let provider = ModelProviderInfo::create_openai_provider(/*base_url*/ None); + + let auth = resolve_provider_auth( + /*auth_manager*/ None, + Some(&auth), + &provider, + ProviderAuthScope::Thread { + external_ref: AgentTaskExternalRef::new("thread-1"), + agent_identity_policy: AgentIdentityAuthPolicy::JwtOnly, + session_source: SessionSource::Cli, + chatgpt_base_url: Some(server.uri()), + }, + ) + .await + .expect("auth should resolve"); + + let headers = auth.to_auth_headers(); + assert!( + headers + .get(http::header::AUTHORIZATION) + .and_then(|value| value.to_str().ok()) + .is_some_and(|value| value.starts_with("AgentAssertion ")) + ); + } + + #[test] + fn agent_task_auth_provider_preserves_account_routing_headers() { + let auth = agent_identity_auth(/*chatgpt_account_is_fedramp*/ true); + let provider = auth_provider_from_agent_task( + auth, + RegisteredAgentTask::new( + AgentRuntimeId::new("agent-runtime-1"), + AgentTaskId::new("background-task-1"), + AgentTaskKind::Background, + ), + ); + + let headers = provider.to_auth_headers(); + + assert!( + headers + .get(http::header::AUTHORIZATION) + .and_then(|value| value.to_str().ok()) + .is_some_and(|value| value.starts_with("AgentAssertion ")) + ); + assert_eq!( + headers + .get("ChatGPT-Account-ID") + .and_then(|value| value.to_str().ok()), + Some("account-1") + ); + assert_eq!( + headers + .get("X-OpenAI-Fedramp") + .and_then(|value| value.to_str().ok()), + Some("true") + ); + } + + #[tokio::test] + async fn provider_auth_ignores_thread_scope_for_non_openai_provider() { + let provider = + create_oss_provider_with_base_url("http://localhost:11434/v1", WireApi::Responses); + + let auth = resolve_provider_auth( + /*auth_manager*/ None, + /*auth*/ None, + &provider, + ProviderAuthScope::Thread { + external_ref: AgentTaskExternalRef::new("thread-1"), + agent_identity_policy: AgentIdentityAuthPolicy::JwtOnly, + session_source: SessionSource::Cli, + chatgpt_base_url: None, + }, + ) + .await + .expect("auth should resolve"); + + assert!(auth.to_auth_headers().is_empty()); + } + + #[test] + fn first_party_auth_path_excludes_provider_specific_auth() { + let mut env_key_provider = + ModelProviderInfo::create_openai_provider(/*base_url*/ None); + env_key_provider.env_key = Some("OPENAI_API_KEY".to_string()); + assert!(!provider_uses_first_party_auth_path(&env_key_provider)); + + let bedrock_provider = ModelProviderInfo::create_amazon_bedrock_provider(/*aws*/ None); + assert!(!provider_uses_first_party_auth_path(&bedrock_provider)); + } } diff --git a/codex-rs/model-provider/src/lib.rs b/codex-rs/model-provider/src/lib.rs index 4e4660812b..c6dffbd2fe 100644 --- a/codex-rs/model-provider/src/lib.rs +++ b/codex-rs/model-provider/src/lib.rs @@ -4,7 +4,12 @@ mod bearer_auth_provider; mod models_endpoint; mod provider; +pub use codex_agent_identity::AgentTaskExternalRef; + +pub use auth::ProviderAuthScope; +pub use auth::auth_provider_from_agent_task; pub use auth::auth_provider_from_auth; +pub use auth::provider_uses_first_party_auth_path; pub use auth::unauthenticated_auth_provider; pub use bearer_auth_provider::BearerAuthProvider; pub use bearer_auth_provider::BearerAuthProvider as CoreAuthProvider; diff --git a/codex-rs/model-provider/src/models_endpoint.rs b/codex-rs/model-provider/src/models_endpoint.rs index 8a72beea70..ddf503834e 100644 --- a/codex-rs/model-provider/src/models_endpoint.rs +++ b/codex-rs/model-provider/src/models_endpoint.rs @@ -26,6 +26,7 @@ use codex_response_debug_context::telemetry_transport_error_message; use http::HeaderMap; use tokio::time::timeout; +use crate::auth::ProviderAuthScope; use crate::auth::resolve_provider_auth; const MODELS_REFRESH_TIMEOUT: Duration = Duration::from_secs(5); @@ -87,7 +88,13 @@ impl ModelsEndpointClient for OpenAiModelsEndpoint { let auth = self.auth().await; let auth_mode = auth.as_ref().map(CodexAuth::auth_mode); let api_provider = self.provider_info.to_api_provider(auth_mode)?; - let api_auth = resolve_provider_auth(auth.as_ref(), &self.provider_info)?; + let api_auth = resolve_provider_auth( + self.auth_manager.clone(), + auth.as_ref(), + &self.provider_info, + ProviderAuthScope::UnscopedProcess, + ) + .await?; let transport = ReqwestTransport::new(build_reqwest_client()); let auth_telemetry = auth_header_telemetry(api_auth.as_ref()); let request_telemetry: Arc = Arc::new(ModelsRequestTelemetry { diff --git a/codex-rs/model-provider/src/provider.rs b/codex-rs/model-provider/src/provider.rs index 1ef7c22962..1e685ef59a 100644 --- a/codex-rs/model-provider/src/provider.rs +++ b/codex-rs/model-provider/src/provider.rs @@ -14,6 +14,7 @@ use codex_protocol::account::ProviderAccount; use codex_protocol::openai_models::ModelsResponse; use crate::amazon_bedrock::AmazonBedrockModelProvider; +use crate::auth::ProviderAuthScope; use crate::auth::auth_manager_for_provider; use crate::auth::resolve_provider_auth; use crate::models_endpoint::OpenAiModelsEndpoint; @@ -129,8 +130,17 @@ pub trait ModelProvider: fmt::Debug + Send + Sync { /// Returns the auth provider used to attach request credentials. async fn api_auth(&self) -> codex_protocol::error::Result { + self.api_auth_for_scope(ProviderAuthScope::UnscopedProcess) + .await + } + + /// Returns request credentials, optionally scoped to a Codex session task. + async fn api_auth_for_scope( + &self, + scope: ProviderAuthScope, + ) -> codex_protocol::error::Result { let auth = self.auth().await; - resolve_provider_auth(auth.as_ref(), self.info()) + resolve_provider_auth(self.auth_manager(), auth.as_ref(), self.info(), scope).await } /// Creates the model manager implementation appropriate for this provider.