feat: use thread agent task auth for inference

This commit is contained in:
adrian
2026-04-22 16:18:44 -07:00
parent 6ffcd6eb15
commit ede0da1cf2
11 changed files with 653 additions and 53 deletions

View File

@@ -722,7 +722,9 @@ impl AgentControl {
} else {
state.send_op(agent_id, Op::Shutdown {}).await
};
thread.wait_until_terminated().await;
if result.is_ok() || matches!(result, Err(CodexErr::InternalAgentDied)) {
thread.wait_until_terminated().await;
}
result
} else {
state.send_op(agent_id, Op::Shutdown {}).await

View File

@@ -116,8 +116,11 @@ use crate::util::emit_feedback_auth_recovery_tags;
use codex_api::map_api_error;
use codex_feedback::FeedbackRequestTags;
use codex_feedback::emit_feedback_request_tags_with_auth_env;
use codex_login::auth::AgentIdentityAuthPolicy;
use codex_login::auth_env_telemetry::AuthEnvTelemetry;
use codex_login::auth_env_telemetry::collect_auth_env_telemetry;
use codex_model_provider::AgentTaskExternalRef;
use codex_model_provider::ProviderAuthScope;
use codex_model_provider::SharedModelProvider;
use codex_model_provider::create_model_provider;
#[cfg(test)]
@@ -170,6 +173,8 @@ struct ModelClientState {
provider: SharedModelProvider,
auth_env_telemetry: AuthEnvTelemetry,
session_source: SessionSource,
agent_identity_policy: AgentIdentityAuthPolicy,
chatgpt_base_url: Option<String>,
model_verbosity: Option<VerbosityConfig>,
enable_request_compression: bool,
include_timing_metrics: bool,
@@ -321,6 +326,39 @@ impl ModelClient {
include_timing_metrics: bool,
beta_features_header: Option<String>,
attestation_provider: Option<Arc<dyn AttestationProvider>>,
) -> Self {
Self::new_with_agent_identity_policy(
auth_manager,
session_id,
thread_id,
installation_id,
provider_info,
session_source,
AgentIdentityAuthPolicy::JwtOnly,
/*chatgpt_base_url*/ None,
model_verbosity,
enable_request_compression,
include_timing_metrics,
beta_features_header,
attestation_provider,
)
}
#[allow(clippy::too_many_arguments)]
pub fn new_with_agent_identity_policy(
auth_manager: Option<Arc<AuthManager>>,
session_id: SessionId,
thread_id: ThreadId,
installation_id: String,
provider_info: ModelProviderInfo,
session_source: SessionSource,
agent_identity_policy: AgentIdentityAuthPolicy,
chatgpt_base_url: Option<String>,
model_verbosity: Option<VerbosityConfig>,
enable_request_compression: bool,
include_timing_metrics: bool,
beta_features_header: Option<String>,
attestation_provider: Option<Arc<dyn AttestationProvider>>,
) -> Self {
let model_provider = create_model_provider(provider_info, auth_manager);
let codex_api_key_env_enabled = model_provider
@@ -339,6 +377,8 @@ impl ModelClient {
provider: model_provider,
auth_env_telemetry,
session_source,
agent_identity_policy,
chatgpt_base_url,
model_verbosity,
enable_request_compression,
include_timing_metrics,
@@ -785,7 +825,11 @@ impl ModelClient {
async fn current_client_setup(&self) -> Result<CurrentClientSetup> {
let auth = self.state.provider.auth().await;
let api_provider = self.state.provider.api_provider().await?;
let api_auth = self.state.provider.api_auth().await?;
let api_auth = self
.state
.provider
.api_auth_for_scope(self.provider_auth_scope())
.await?;
Ok(CurrentClientSetup {
auth,
api_provider,
@@ -793,6 +837,15 @@ impl ModelClient {
})
}
fn provider_auth_scope(&self) -> ProviderAuthScope {
ProviderAuthScope::Thread {
external_ref: AgentTaskExternalRef::new(self.state.thread_id.to_string()),
agent_identity_policy: self.state.agent_identity_policy,
session_source: self.state.session_source.clone(),
chatgpt_base_url: self.state.chatgpt_base_url.clone(),
}
}
/// Opens a websocket connection using the same header and telemetry wiring as normal turns.
///
/// Both startup prewarm and in-turn `needs_new` reconnects call this path so handshake
@@ -931,6 +984,10 @@ impl Drop for ModelClientSession {
}
impl ModelClientSession {
async fn current_client_setup(&self) -> Result<CurrentClientSetup> {
self.client.current_client_setup().await
}
pub(crate) fn reset_websocket_session(&mut self) {
self.websocket_session.connection = None;
self.websocket_session.last_request = None;
@@ -1077,7 +1134,7 @@ impl ModelClientSession {
return Ok(());
}
let client_setup = self.client.current_client_setup().await.map_err(|err| {
let client_setup = self.current_client_setup().await.map_err(|err| {
ApiError::Stream(format!(
"failed to build websocket prewarm client setup: {err}"
))
@@ -1224,7 +1281,7 @@ impl ModelClientSession {
.map(AuthManager::unauthorized_recovery);
let mut pending_retry = PendingUnauthorizedRetry::default();
loop {
let client_setup = self.client.current_client_setup().await?;
let client_setup = self.current_client_setup().await?;
let transport = ReqwestTransport::new(build_reqwest_client());
let request_auth_context = AuthRequestTelemetryContext::new(
client_setup.auth.as_ref().map(CodexAuth::auth_mode),
@@ -1340,7 +1397,7 @@ impl ModelClientSession {
.map(AuthManager::unauthorized_recovery);
let mut pending_retry = PendingUnauthorizedRetry::default();
loop {
let client_setup = self.client.current_client_setup().await?;
let client_setup = self.current_client_setup().await?;
let request_auth_context = AuthRequestTelemetryContext::new(
client_setup.auth.as_ref().map(CodexAuth::auth_mode),
client_setup.api_auth.as_ref(),

View File

@@ -15,7 +15,10 @@ use codex_api::ResponseEvent;
use codex_app_server_protocol::AuthMode;
use codex_login::AuthManager;
use codex_login::CodexAuth;
use codex_login::auth::AgentIdentityAuthPolicy;
use codex_model_provider::AgentTaskExternalRef;
use codex_model_provider::BearerAuthProvider;
use codex_model_provider::ProviderAuthScope;
use codex_model_provider_info::CHATGPT_CODEX_BASE_URL;
use codex_model_provider_info::ModelProviderInfo;
use codex_model_provider_info::WireApi;
@@ -61,8 +64,15 @@ use tracing_subscriber::registry::LookupSpan;
use tracing_subscriber::util::SubscriberInitExt;
fn test_model_client(session_source: SessionSource) -> ModelClient {
test_model_client_with_thread_id(ThreadId::new(), session_source)
}
fn test_model_client_with_thread_id(
conversation_id: ThreadId,
session_source: SessionSource,
) -> ModelClient {
let provider = create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses);
let thread_id = ThreadId::new();
let thread_id = conversation_id;
ModelClient::new(
/*auth_manager*/ None,
thread_id.into(),
@@ -78,6 +88,23 @@ fn test_model_client(session_source: SessionSource) -> ModelClient {
)
}
#[test]
fn provider_auth_scope_uses_thread_id_as_session_ref() {
let conversation_id =
ThreadId::from_string("018f4f4c-43f5-7b28-8e24-000000000001").expect("valid thread id");
let client = test_model_client_with_thread_id(conversation_id, SessionSource::Cli);
assert_eq!(
client.provider_auth_scope(),
ProviderAuthScope::Thread {
external_ref: AgentTaskExternalRef::new(conversation_id.to_string()),
agent_identity_policy: AgentIdentityAuthPolicy::JwtOnly,
session_source: SessionSource::Cli,
chatgpt_base_url: None,
}
);
}
fn test_model_info() -> ModelInfo {
serde_json::from_value(json!({
"slug": "gpt-test",

View File

@@ -928,13 +928,19 @@ impl Session {
live_thread: live_thread_init.as_ref().cloned(),
thread_store: Arc::clone(&thread_store),
attestation_provider: attestation_provider.clone(),
model_client: ModelClient::new(
model_client: ModelClient::new_with_agent_identity_policy(
Some(Arc::clone(&auth_manager)),
session_id,
thread_id,
installation_id.clone(),
session_configuration.provider.clone(),
session_configuration.session_source.clone(),
if config.features.enabled(Feature::UseAgentIdentity) {
codex_login::auth::AgentIdentityAuthPolicy::JwtOrChatgpt
} else {
codex_login::auth::AgentIdentityAuthPolicy::JwtOnly
},
Some(config.chatgpt_base_url.clone()),
config.model_verbosity,
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
@@ -1119,8 +1125,6 @@ impl Session {
anyhow::bail!("required MCP servers failed to initialize: {details}");
}
}
sess.schedule_startup_prewarm(session_configuration.base_instructions.clone())
.await;
let session_start_source = match &initial_history {
InitialHistory::Resumed(_) => codex_hooks::SessionStartSource::Resume,
InitialHistory::New | InitialHistory::Forked(_) => {
@@ -1131,6 +1135,8 @@ impl Session {
// record_initial_history can emit events. We record only after the SessionConfiguredEvent is emitted.
sess.record_initial_history(initial_history).await;
sess.schedule_startup_prewarm(session_configuration.base_instructions.clone())
.await;
{
let mut state = sess.state.lock().await;
state.set_pending_session_start_source(Some(session_start_source));

View File

@@ -1,8 +1,15 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;
use codex_agent_identity::AgentIdentityKey;
use codex_agent_identity::AgentRuntimeId;
use codex_agent_identity::AgentTaskExternalRef;
use codex_agent_identity::AgentTaskId;
use codex_agent_identity::AgentTaskKind;
use codex_agent_identity::RegisteredAgentTask;
use codex_agent_identity::normalize_chatgpt_base_url;
use codex_agent_identity::register_agent_task;
use codex_agent_identity::register_agent_task_with_external_ref;
use codex_protocol::account::PlanType as AccountPlanType;
use tokio::sync::OnceCell;
@@ -15,14 +22,20 @@ const DEFAULT_CHATGPT_BACKEND_BASE_URL: &str = "https://chatgpt.com/backend-api"
#[derive(Debug)]
pub struct AgentIdentityAuth {
record: AgentIdentityAuthRecord,
process_task_id: Arc<OnceCell<String>>,
task_ids: Arc<Mutex<HashMap<AgentTaskCacheKey, Arc<OnceCell<String>>>>>,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum AgentTaskCacheKey {
Process,
Thread(AgentTaskExternalRef),
}
impl Clone for AgentIdentityAuth {
fn clone(&self) -> Self {
Self {
record: self.record.clone(),
process_task_id: Arc::clone(&self.process_task_id),
task_ids: Arc::clone(&self.task_ids),
}
}
}
@@ -31,7 +44,7 @@ impl AgentIdentityAuth {
pub fn new(record: AgentIdentityAuthRecord) -> Self {
Self {
record,
process_task_id: Arc::new(OnceCell::new()),
task_ids: Arc::new(Mutex::new(HashMap::new())),
}
}
@@ -39,24 +52,58 @@ impl AgentIdentityAuth {
&self.record
}
pub fn process_task_id(&self) -> Option<&str> {
self.process_task_id.get().map(String::as_str)
pub fn process_task_id(&self) -> Option<String> {
self.task_id_for_initialized_key(&AgentTaskCacheKey::Process)
}
pub async fn ensure_runtime(&self, chatgpt_base_url: Option<String>) -> std::io::Result<()> {
self.process_task_id
.get_or_try_init(|| async {
let base_url = normalize_chatgpt_base_url(
chatgpt_base_url
.as_deref()
.unwrap_or(DEFAULT_CHATGPT_BACKEND_BASE_URL),
);
register_agent_task(&build_reqwest_client(), &base_url, self.key())
.await
.map_err(std::io::Error::other)
})
self.task_id_for_key(
AgentTaskCacheKey::Process,
chatgpt_base_url,
/*external_ref*/ None,
)
.await
.map(|_| ())
}
pub async fn registered_thread_task(
&self,
external_ref: AgentTaskExternalRef,
chatgpt_base_url: Option<String>,
) -> std::io::Result<RegisteredAgentTask> {
let task_id = self
.task_id_for_key(
AgentTaskCacheKey::Thread(external_ref.clone()),
chatgpt_base_url,
Some(external_ref),
)
.await?;
Ok(self.registered_task(task_id, AgentTaskKind::Thread))
}
pub async fn register_task(&self, chatgpt_base_url: Option<String>) -> std::io::Result<String> {
self.register_task_with_external_ref(chatgpt_base_url, /*external_ref*/ None)
.await
.map(|_| ())
}
async fn register_task_with_external_ref(
&self,
chatgpt_base_url: Option<String>,
external_ref: Option<&AgentTaskExternalRef>,
) -> std::io::Result<String> {
let base_url = normalize_chatgpt_base_url(
chatgpt_base_url
.as_deref()
.unwrap_or(DEFAULT_CHATGPT_BACKEND_BASE_URL),
);
register_agent_task_with_external_ref(
&build_reqwest_client(),
&base_url,
self.key(),
external_ref,
)
.await
.map_err(std::io::Error::other)
}
pub fn account_id(&self) -> &str {
@@ -84,4 +131,196 @@ impl AgentIdentityAuth {
private_key_pkcs8_base64: &self.record.agent_private_key,
}
}
async fn task_id_for_key(
&self,
key: AgentTaskCacheKey,
chatgpt_base_url: Option<String>,
external_ref: Option<AgentTaskExternalRef>,
) -> std::io::Result<String> {
let slot = self.task_slot(key)?;
slot.get_or_try_init(|| async {
self.register_task_with_external_ref(chatgpt_base_url, external_ref.as_ref())
.await
})
.await
.cloned()
}
fn task_slot(&self, key: AgentTaskCacheKey) -> std::io::Result<Arc<OnceCell<String>>> {
let mut task_ids = self
.task_ids
.lock()
.map_err(|_| std::io::Error::other("failed to lock agent task cache"))?;
Ok(task_ids
.entry(key)
.or_insert_with(|| Arc::new(OnceCell::new()))
.clone())
}
fn task_id_for_initialized_key(&self, key: &AgentTaskCacheKey) -> Option<String> {
let task_ids = self.task_ids.lock().ok()?;
task_ids.get(key)?.get().cloned()
}
fn registered_task(&self, task_id: String, kind: AgentTaskKind) -> RegisteredAgentTask {
RegisteredAgentTask::new(
AgentRuntimeId::new(self.record.agent_runtime_id.clone()),
AgentTaskId::new(task_id),
kind,
)
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use codex_agent_identity::generate_agent_key_material;
use pretty_assertions::assert_eq;
use serde_json::json;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::body_partial_json;
use wiremock::matchers::method;
use wiremock::matchers::path;
use super::*;
fn agent_identity_record(private_key: String) -> AgentIdentityAuthRecord {
AgentIdentityAuthRecord {
agent_runtime_id: "agent-runtime-1".to_string(),
agent_private_key: private_key,
account_id: "account-1".to_string(),
chatgpt_user_id: "user-1".to_string(),
email: "agent@example.com".to_string(),
plan_type: AccountPlanType::Plus,
chatgpt_account_is_fedramp: false,
registered_at: None,
}
}
fn agent_identity_auth() -> AgentIdentityAuth {
let key_material = generate_agent_key_material().expect("generate key material");
AgentIdentityAuth::new(agent_identity_record(key_material.private_key_pkcs8_base64))
}
#[tokio::test]
async fn registered_thread_task_registers_once_per_external_ref() -> anyhow::Result<()> {
let auth = agent_identity_auth();
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/agent/agent-runtime-1/task/register"))
.and(body_partial_json(json!({
"external_task_ref": "thread-1",
})))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"task_id": "task-thread-1",
})))
.expect(1)
.mount(&server)
.await;
let first = auth
.registered_thread_task(AgentTaskExternalRef::new("thread-1"), Some(server.uri()))
.await?;
let second = auth
.registered_thread_task(AgentTaskExternalRef::new("thread-1"), Some(server.uri()))
.await?;
assert_eq!(first, second);
assert_eq!(
first,
RegisteredAgentTask::new(
AgentRuntimeId::new("agent-runtime-1"),
AgentTaskId::new("task-thread-1"),
AgentTaskKind::Thread,
)
);
Ok(())
}
#[tokio::test]
async fn registered_thread_task_uses_distinct_external_refs() -> anyhow::Result<()> {
let auth = agent_identity_auth();
let server = MockServer::start().await;
for (external_ref, task_id) in
[("thread-1", "task-thread-1"), ("thread-2", "task-thread-2")]
{
Mock::given(method("POST"))
.and(path("/v1/agent/agent-runtime-1/task/register"))
.and(body_partial_json(json!({
"external_task_ref": external_ref,
})))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"task_id": task_id,
})))
.expect(1)
.mount(&server)
.await;
}
let first = auth
.registered_thread_task(AgentTaskExternalRef::new("thread-1"), Some(server.uri()))
.await?;
let second = auth
.registered_thread_task(AgentTaskExternalRef::new("thread-2"), Some(server.uri()))
.await?;
assert_eq!(first.task_id.as_str(), "task-thread-1");
assert_eq!(second.task_id.as_str(), "task-thread-2");
Ok(())
}
#[tokio::test]
async fn failed_thread_task_registration_can_retry() -> anyhow::Result<()> {
let auth = agent_identity_auth();
let server = MockServer::start().await;
let request_count = Arc::new(AtomicUsize::new(0));
let response_count = Arc::clone(&request_count);
Mock::given(method("POST"))
.and(path("/v1/agent/agent-runtime-1/task/register"))
.and(body_partial_json(json!({
"external_task_ref": "thread-1",
})))
.respond_with(move |_request: &wiremock::Request| {
if response_count.fetch_add(1, Ordering::SeqCst) == 0 {
ResponseTemplate::new(500)
} else {
ResponseTemplate::new(200).set_body_json(json!({
"task_id": "task-thread-1",
}))
}
})
.expect(2)
.mount(&server)
.await;
auth.registered_thread_task(AgentTaskExternalRef::new("thread-1"), Some(server.uri()))
.await
.expect_err("first registration should fail");
let task = auth
.registered_thread_task(AgentTaskExternalRef::new("thread-1"), Some(server.uri()))
.await?;
assert_eq!(request_count.load(Ordering::SeqCst), 2);
assert_eq!(task.task_id.as_str(), "task-thread-1");
Ok(())
}
#[test]
fn task_slots_are_shared_across_clones() {
let auth = agent_identity_auth();
let cloned = auth.clone();
let slot = auth
.task_slot(AgentTaskCacheKey::Process)
.expect("task slot should be available");
slot.set("process-task-1".to_string())
.expect("process task should be unset");
assert_eq!(cloned.process_task_id(), Some("process-task-1".to_string()));
}
}

View File

@@ -233,8 +233,8 @@ async fn chatgpt_auth_registers_agent_identity_when_enabled() -> anyhow::Result<
agent_auth.record().agent_runtime_id,
reused.record().agent_runtime_id
);
assert_eq!(agent_auth.process_task_id(), Some("task-123"));
assert_eq!(reused.process_task_id(), Some("task-123"));
assert_eq!(agent_auth.process_task_id(), Some("task-123".to_string()));
assert_eq!(reused.process_task_id(), Some("task-123".to_string()));
assert_eq!(agent_auth.record().agent_runtime_id, "agent-runtime-123");
assert_eq!(agent_auth.record().account_id, "account-123");
assert_eq!(agent_auth.record().chatgpt_user_id, "user-12345");
@@ -853,7 +853,10 @@ async fn load_auth_reads_access_token_from_env() {
panic!("env auth should load as agent identity");
};
assert_eq!(agent_identity.record(), &expected_record);
assert_eq!(agent_identity.process_task_id(), Some("task-123"));
assert_eq!(
agent_identity.process_task_id(),
Some("task-123".to_string())
);
assert!(
!get_auth_file(codex_home.path()).exists(),
"env auth should not write auth.json"

View File

@@ -18,6 +18,7 @@ use codex_protocol::account::ProviderAccount;
use codex_protocol::error::Result;
use codex_protocol::openai_models::ModelsResponse;
use crate::auth::ProviderAuthScope;
use crate::provider::ModelProvider;
use crate::provider::ProviderAccountResult;
use crate::provider::ProviderAccountState;
@@ -96,6 +97,10 @@ impl ModelProvider for AmazonBedrockModelProvider {
resolve_provider_auth(&self.aws).await
}
async fn api_auth_for_scope(&self, _scope: ProviderAuthScope) -> Result<SharedAuthProvider> {
resolve_provider_auth(&self.aws).await
}
fn models_manager(
&self,
_codex_home: PathBuf,

View File

@@ -2,39 +2,74 @@ use std::sync::Arc;
use codex_agent_identity::AgentIdentityKey;
use codex_agent_identity::AgentTaskAuthorizationTarget;
use codex_agent_identity::AgentTaskExternalRef;
use codex_agent_identity::RegisteredAgentTask;
use codex_agent_identity::authorization_header_for_agent_task;
use codex_agent_identity::authorization_header_for_registered_task;
use codex_api::AuthProvider;
use codex_api::SharedAuthProvider;
use codex_login::AuthManager;
use codex_login::CodexAuth;
use codex_login::auth::AgentIdentityAuth;
use codex_login::auth::AgentIdentityAuthPolicy;
use codex_model_provider_info::ModelProviderInfo;
use codex_protocol::protocol::SessionSource;
use http::HeaderMap;
use http::HeaderValue;
use crate::bearer_auth_provider::BearerAuthProvider;
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ProviderAuthScope {
/// Use the provider's default auth. Agent Identity auth uses its process task here.
UnscopedProcess,
/// Use a task-scoped Agent Assertion for work tied to a Codex thread.
Thread {
external_ref: AgentTaskExternalRef,
agent_identity_policy: AgentIdentityAuthPolicy,
session_source: SessionSource,
chatgpt_base_url: Option<String>,
},
}
#[derive(Clone, Debug)]
struct AgentIdentityAuthProvider {
auth: codex_login::auth::AgentIdentityAuth,
auth: AgentIdentityAuth,
task: Option<RegisteredAgentTask>,
}
impl AuthProvider for AgentIdentityAuthProvider {
fn add_auth_headers(&self, headers: &mut HeaderMap) {
let record = self.auth.record();
let Some(task_id) = self.auth.process_task_id() else {
return;
let header_value = match self.task.as_ref() {
Some(task) => authorization_header_for_registered_task(
AgentIdentityKey {
agent_runtime_id: &record.agent_runtime_id,
private_key_pkcs8_base64: &record.agent_private_key,
},
task,
)
.map_err(std::io::Error::other),
None => self
.auth
.process_task_id()
.ok_or_else(|| {
std::io::Error::other("agent identity process task is not initialized")
})
.and_then(|task_id| {
authorization_header_for_agent_task(
AgentIdentityKey {
agent_runtime_id: &record.agent_runtime_id,
private_key_pkcs8_base64: &record.agent_private_key,
},
AgentTaskAuthorizationTarget {
agent_runtime_id: &record.agent_runtime_id,
task_id: &task_id,
},
)
.map_err(std::io::Error::other)
}),
};
let header_value = authorization_header_for_agent_task(
AgentIdentityKey {
agent_runtime_id: &record.agent_runtime_id,
private_key_pkcs8_base64: &record.agent_private_key,
},
AgentTaskAuthorizationTarget {
agent_runtime_id: &record.agent_runtime_id,
task_id,
},
)
.map_err(std::io::Error::other);
if let Ok(header_value) = header_value
&& let Ok(header) = HeaderValue::from_str(&header_value)
@@ -78,20 +113,61 @@ pub(crate) fn auth_manager_for_provider(
}
}
pub(crate) fn resolve_provider_auth(
pub(crate) async fn resolve_provider_auth(
auth_manager: Option<Arc<AuthManager>>,
auth: Option<&CodexAuth>,
provider: &ModelProviderInfo,
scope: ProviderAuthScope,
) -> codex_protocol::error::Result<SharedAuthProvider> {
if let Some(auth) = bearer_auth_for_provider(provider)? {
return Ok(Arc::new(auth));
}
if provider_uses_first_party_auth_path(provider)
&& let ProviderAuthScope::Thread {
external_ref,
agent_identity_policy,
session_source,
chatgpt_base_url,
} = scope
&& let Some(agent_identity_auth) =
agent_identity_auth_for_scope(auth_manager, auth, agent_identity_policy, session_source)
.await?
{
let task = agent_identity_auth
.registered_thread_task(external_ref, chatgpt_base_url)
.await?;
return Ok(auth_provider_from_agent_task(agent_identity_auth, task));
}
Ok(match auth {
Some(auth) => auth_provider_from_auth(auth),
None => unauthenticated_auth_provider(),
})
}
async fn agent_identity_auth_for_scope(
auth_manager: Option<Arc<AuthManager>>,
auth: Option<&CodexAuth>,
policy: AgentIdentityAuthPolicy,
session_source: SessionSource,
) -> codex_protocol::error::Result<Option<AgentIdentityAuth>> {
if let Some(auth_manager) = auth_manager {
return auth_manager
.agent_identity_auth(policy, session_source)
.await
.map_err(Into::into);
}
Ok(match auth {
Some(CodexAuth::AgentIdentity(auth)) => Some(auth.clone()),
Some(CodexAuth::ApiKey(_))
| Some(CodexAuth::Chatgpt(_))
| Some(CodexAuth::ChatgptAuthTokens(_))
| None => None,
})
}
fn bearer_auth_for_provider(
provider: &ModelProviderInfo,
) -> codex_protocol::error::Result<Option<BearerAuthProvider>> {
@@ -106,12 +182,21 @@ fn bearer_auth_for_provider(
Ok(None)
}
pub fn provider_uses_first_party_auth_path(provider: &ModelProviderInfo) -> bool {
provider.requires_openai_auth
&& provider.env_key.is_none()
&& provider.experimental_bearer_token.is_none()
&& provider.auth.is_none()
&& provider.aws.is_none()
}
/// Builds request-header auth for a first-party Codex auth snapshot.
pub fn auth_provider_from_auth(auth: &CodexAuth) -> SharedAuthProvider {
match auth {
CodexAuth::AgentIdentity(auth) => {
Arc::new(AgentIdentityAuthProvider { auth: auth.clone() })
}
CodexAuth::AgentIdentity(auth) => Arc::new(AgentIdentityAuthProvider {
auth: auth.clone(),
task: None,
}),
CodexAuth::ApiKey(_) | CodexAuth::Chatgpt(_) | CodexAuth::ChatgptAuthTokens(_) => {
Arc::new(BearerAuthProvider {
token: auth.get_token().ok(),
@@ -122,19 +207,173 @@ pub fn auth_provider_from_auth(auth: &CodexAuth) -> SharedAuthProvider {
}
}
pub fn auth_provider_from_agent_task(
auth: AgentIdentityAuth,
task: RegisteredAgentTask,
) -> SharedAuthProvider {
Arc::new(AgentIdentityAuthProvider {
auth,
task: Some(task),
})
}
#[cfg(test)]
mod tests {
use codex_agent_identity::AgentRuntimeId;
use codex_agent_identity::AgentTaskId;
use codex_agent_identity::AgentTaskKind;
use codex_agent_identity::generate_agent_key_material;
use codex_login::auth::AgentIdentityAuthRecord;
use codex_model_provider_info::WireApi;
use codex_model_provider_info::create_oss_provider_with_base_url;
use codex_protocol::account::PlanType;
use pretty_assertions::assert_eq;
use serde_json::json;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::body_partial_json;
use wiremock::matchers::method;
use wiremock::matchers::path;
use super::*;
#[test]
fn unauthenticated_auth_provider_adds_no_headers() {
fn agent_identity_auth(chatgpt_account_is_fedramp: bool) -> AgentIdentityAuth {
let key_material = generate_agent_key_material().expect("generate key material");
AgentIdentityAuth::new(AgentIdentityAuthRecord {
agent_runtime_id: "agent-runtime-1".to_string(),
agent_private_key: key_material.private_key_pkcs8_base64,
account_id: "account-1".to_string(),
chatgpt_user_id: "user-1".to_string(),
email: "agent@example.com".to_string(),
plan_type: PlanType::Plus,
chatgpt_account_is_fedramp,
registered_at: None,
})
}
#[tokio::test]
async fn unauthenticated_auth_provider_adds_no_headers() {
let provider =
create_oss_provider_with_base_url("http://localhost:11434/v1", WireApi::Responses);
let auth = resolve_provider_auth(/*auth*/ None, &provider).expect("auth should resolve");
let auth = resolve_provider_auth(
/*auth_manager*/ None,
/*auth*/ None,
&provider,
ProviderAuthScope::UnscopedProcess,
)
.await
.expect("auth should resolve");
assert!(auth.to_auth_headers().is_empty());
}
#[tokio::test]
async fn first_party_thread_scope_uses_agent_assertion() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/agent/agent-runtime-1/task/register"))
.and(body_partial_json(json!({
"external_task_ref": "thread-1",
})))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"task_id": "task-thread-1",
})))
.expect(1)
.mount(&server)
.await;
let auth = CodexAuth::AgentIdentity(agent_identity_auth(
/*chatgpt_account_is_fedramp*/ false,
));
let provider = ModelProviderInfo::create_openai_provider(/*base_url*/ None);
let auth = resolve_provider_auth(
/*auth_manager*/ None,
Some(&auth),
&provider,
ProviderAuthScope::Thread {
external_ref: AgentTaskExternalRef::new("thread-1"),
agent_identity_policy: AgentIdentityAuthPolicy::JwtOnly,
session_source: SessionSource::Cli,
chatgpt_base_url: Some(server.uri()),
},
)
.await
.expect("auth should resolve");
let headers = auth.to_auth_headers();
assert!(
headers
.get(http::header::AUTHORIZATION)
.and_then(|value| value.to_str().ok())
.is_some_and(|value| value.starts_with("AgentAssertion "))
);
}
#[test]
fn agent_task_auth_provider_preserves_account_routing_headers() {
let auth = agent_identity_auth(/*chatgpt_account_is_fedramp*/ true);
let provider = auth_provider_from_agent_task(
auth,
RegisteredAgentTask::new(
AgentRuntimeId::new("agent-runtime-1"),
AgentTaskId::new("background-task-1"),
AgentTaskKind::Background,
),
);
let headers = provider.to_auth_headers();
assert!(
headers
.get(http::header::AUTHORIZATION)
.and_then(|value| value.to_str().ok())
.is_some_and(|value| value.starts_with("AgentAssertion "))
);
assert_eq!(
headers
.get("ChatGPT-Account-ID")
.and_then(|value| value.to_str().ok()),
Some("account-1")
);
assert_eq!(
headers
.get("X-OpenAI-Fedramp")
.and_then(|value| value.to_str().ok()),
Some("true")
);
}
#[tokio::test]
async fn provider_auth_ignores_thread_scope_for_non_openai_provider() {
let provider =
create_oss_provider_with_base_url("http://localhost:11434/v1", WireApi::Responses);
let auth = resolve_provider_auth(
/*auth_manager*/ None,
/*auth*/ None,
&provider,
ProviderAuthScope::Thread {
external_ref: AgentTaskExternalRef::new("thread-1"),
agent_identity_policy: AgentIdentityAuthPolicy::JwtOnly,
session_source: SessionSource::Cli,
chatgpt_base_url: None,
},
)
.await
.expect("auth should resolve");
assert!(auth.to_auth_headers().is_empty());
}
#[test]
fn first_party_auth_path_excludes_provider_specific_auth() {
let mut env_key_provider =
ModelProviderInfo::create_openai_provider(/*base_url*/ None);
env_key_provider.env_key = Some("OPENAI_API_KEY".to_string());
assert!(!provider_uses_first_party_auth_path(&env_key_provider));
let bedrock_provider = ModelProviderInfo::create_amazon_bedrock_provider(/*aws*/ None);
assert!(!provider_uses_first_party_auth_path(&bedrock_provider));
}
}

View File

@@ -4,7 +4,12 @@ mod bearer_auth_provider;
mod models_endpoint;
mod provider;
pub use codex_agent_identity::AgentTaskExternalRef;
pub use auth::ProviderAuthScope;
pub use auth::auth_provider_from_agent_task;
pub use auth::auth_provider_from_auth;
pub use auth::provider_uses_first_party_auth_path;
pub use auth::unauthenticated_auth_provider;
pub use bearer_auth_provider::BearerAuthProvider;
pub use bearer_auth_provider::BearerAuthProvider as CoreAuthProvider;

View File

@@ -26,6 +26,7 @@ use codex_response_debug_context::telemetry_transport_error_message;
use http::HeaderMap;
use tokio::time::timeout;
use crate::auth::ProviderAuthScope;
use crate::auth::resolve_provider_auth;
const MODELS_REFRESH_TIMEOUT: Duration = Duration::from_secs(5);
@@ -87,7 +88,13 @@ impl ModelsEndpointClient for OpenAiModelsEndpoint {
let auth = self.auth().await;
let auth_mode = auth.as_ref().map(CodexAuth::auth_mode);
let api_provider = self.provider_info.to_api_provider(auth_mode)?;
let api_auth = resolve_provider_auth(auth.as_ref(), &self.provider_info)?;
let api_auth = resolve_provider_auth(
self.auth_manager.clone(),
auth.as_ref(),
&self.provider_info,
ProviderAuthScope::UnscopedProcess,
)
.await?;
let transport = ReqwestTransport::new(build_reqwest_client());
let auth_telemetry = auth_header_telemetry(api_auth.as_ref());
let request_telemetry: Arc<dyn RequestTelemetry> = Arc::new(ModelsRequestTelemetry {

View File

@@ -14,6 +14,7 @@ use codex_protocol::account::ProviderAccount;
use codex_protocol::openai_models::ModelsResponse;
use crate::amazon_bedrock::AmazonBedrockModelProvider;
use crate::auth::ProviderAuthScope;
use crate::auth::auth_manager_for_provider;
use crate::auth::resolve_provider_auth;
use crate::models_endpoint::OpenAiModelsEndpoint;
@@ -129,8 +130,17 @@ pub trait ModelProvider: fmt::Debug + Send + Sync {
/// Returns the auth provider used to attach request credentials.
async fn api_auth(&self) -> codex_protocol::error::Result<SharedAuthProvider> {
self.api_auth_for_scope(ProviderAuthScope::UnscopedProcess)
.await
}
/// Returns request credentials, optionally scoped to a Codex session task.
async fn api_auth_for_scope(
&self,
scope: ProviderAuthScope,
) -> codex_protocol::error::Result<SharedAuthProvider> {
let auth = self.auth().await;
resolve_provider_auth(auth.as_ref(), self.info())
resolve_provider_auth(self.auth_manager(), auth.as_ref(), self.info(), scope).await
}
/// Creates the model manager implementation appropriate for this provider.