Compare commits

...

3 Commits

Author SHA1 Message Date
adrian
d0400ba0d5 feat: use thread agent task auth for inference 2026-05-05 12:54:30 -07:00
adrian
27cecd38da feat: add chatgpt agent identity opt-in 2026-05-05 12:54:30 -07:00
adrian
baf2618196 feat: add agent task identity primitives 2026-05-04 17:02:27 -07:00
19 changed files with 1520 additions and 189 deletions

View File

@@ -34,6 +34,8 @@ const AGENT_TASK_REGISTRATION_TIMEOUT: Duration = Duration::from_secs(30);
const AGENT_IDENTITY_JWKS_TIMEOUT: Duration = Duration::from_secs(10);
const AGENT_IDENTITY_JWT_AUDIENCE: &str = "codex-app-server";
const AGENT_IDENTITY_JWT_ISSUER: &str = "https://chatgpt.com/codex-backend/agent-identity";
const AGENT_REGISTRATION_TIMEOUT: Duration = Duration::from_secs(15);
const AGENT_IDENTITY_BISCUIT_TIMEOUT: Duration = Duration::from_secs(15);
/// Stored key material for a registered agent identity.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -49,6 +51,133 @@ pub struct AgentTaskAuthorizationTarget<'a> {
pub task_id: &'a str,
}
/// Runtime identity that owns one or more registered agent tasks.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct AgentRuntimeId(String);
impl AgentRuntimeId {
pub fn new(agent_runtime_id: impl Into<String>) -> Self {
Self(agent_runtime_id.into())
}
pub fn as_str(&self) -> &str {
&self.0
}
pub fn into_string(self) -> String {
self.0
}
}
impl From<String> for AgentRuntimeId {
fn from(agent_runtime_id: String) -> Self {
Self::new(agent_runtime_id)
}
}
impl From<&str> for AgentRuntimeId {
fn from(agent_runtime_id: &str) -> Self {
Self::new(agent_runtime_id)
}
}
/// Task identifier granted to an agent runtime for a scoped objective.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct AgentTaskId(String);
impl AgentTaskId {
pub fn new(task_id: impl Into<String>) -> Self {
Self(task_id.into())
}
pub fn as_str(&self) -> &str {
&self.0
}
pub fn into_string(self) -> String {
self.0
}
}
impl From<String> for AgentTaskId {
fn from(task_id: String) -> Self {
Self::new(task_id)
}
}
impl From<&str> for AgentTaskId {
fn from(task_id: &str) -> Self {
Self::new(task_id)
}
}
/// Caller-owned stable reference that HAI can resolve to an opaque task id.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct AgentTaskExternalRef(String);
impl AgentTaskExternalRef {
pub fn new(external_ref: impl Into<String>) -> Self {
Self(external_ref.into())
}
pub fn as_str(&self) -> &str {
&self.0
}
pub fn into_string(self) -> String {
self.0
}
}
impl From<String> for AgentTaskExternalRef {
fn from(external_ref: String) -> Self {
Self::new(external_ref)
}
}
impl From<&str> for AgentTaskExternalRef {
fn from(external_ref: &str) -> Self {
Self::new(external_ref)
}
}
/// Purpose of a registered task binding.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum AgentTaskKind {
Thread,
Background,
}
/// Registered task binding used to authorize work for an agent runtime.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct RegisteredAgentTask {
pub agent_runtime_id: AgentRuntimeId,
pub task_id: AgentTaskId,
pub kind: AgentTaskKind,
}
impl RegisteredAgentTask {
pub fn new(
agent_runtime_id: impl Into<AgentRuntimeId>,
task_id: impl Into<AgentTaskId>,
kind: AgentTaskKind,
) -> Self {
Self {
agent_runtime_id: agent_runtime_id.into(),
task_id: task_id.into(),
kind,
}
}
pub fn authorization_target(&self) -> AgentTaskAuthorizationTarget<'_> {
AgentTaskAuthorizationTarget {
agent_runtime_id: self.agent_runtime_id.as_str(),
task_id: self.task_id.as_str(),
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct AgentBillOfMaterials {
pub agent_version: String,
@@ -86,9 +215,11 @@ struct AgentAssertionEnvelope {
}
#[derive(Serialize)]
struct RegisterTaskRequest {
struct RegisterTaskRequest<'a> {
timestamp: String,
signature: String,
#[serde(skip_serializing_if = "Option::is_none")]
external_task_ref: Option<&'a str>,
}
#[derive(Deserialize)]
@@ -103,6 +234,18 @@ struct RegisterTaskResponse {
encrypted_task_id_camel: Option<String>,
}
#[derive(Debug, Serialize)]
struct RegisterAgentRequest {
abom: AgentBillOfMaterials,
agent_public_key: String,
capabilities: Vec<String>,
}
#[derive(Debug, Deserialize)]
struct RegisterAgentResponse {
agent_runtime_id: String,
}
pub fn authorization_header_for_agent_task(
key: AgentIdentityKey<'_>,
target: AgentTaskAuthorizationTarget<'_>,
@@ -125,6 +268,13 @@ pub fn authorization_header_for_agent_task(
Ok(format!("AgentAssertion {serialized_assertion}"))
}
pub fn authorization_header_for_registered_task(
key: AgentIdentityKey<'_>,
task: &RegisteredAgentTask,
) -> Result<String> {
authorization_header_for_agent_task(key, task.authorization_target())
}
pub async fn fetch_agent_identity_jwks(
client: &reqwest::Client,
chatgpt_base_url: &str,
@@ -197,11 +347,22 @@ pub async fn register_agent_task(
client: &reqwest::Client,
chatgpt_base_url: &str,
key: AgentIdentityKey<'_>,
) -> Result<String> {
register_agent_task_with_external_ref(client, chatgpt_base_url, key, /*external_ref*/ None)
.await
}
pub async fn register_agent_task_with_external_ref(
client: &reqwest::Client,
chatgpt_base_url: &str,
key: AgentIdentityKey<'_>,
external_ref: Option<&AgentTaskExternalRef>,
) -> Result<String> {
let timestamp = Utc::now().to_rfc3339_opts(SecondsFormat::Secs, true);
let request = RegisterTaskRequest {
signature: sign_task_registration_payload(key, &timestamp)?,
timestamp,
external_task_ref: external_ref.map(AgentTaskExternalRef::as_str),
};
let url = agent_task_registration_url(chatgpt_base_url, key.agent_runtime_id);
@@ -231,6 +392,70 @@ pub async fn register_agent_task(
task_id_from_register_task_response(key, response)
}
pub async fn register_agent_identity(
client: &reqwest::Client,
chatgpt_base_url: &str,
access_token: &str,
key_material: &GeneratedAgentKeyMaterial,
abom: AgentBillOfMaterials,
) -> Result<AgentRuntimeId> {
let url = agent_registration_url(chatgpt_base_url);
let human_biscuit =
mint_agent_identity_biscuit(client, chatgpt_base_url, access_token, "POST", &url).await?;
let request = RegisterAgentRequest {
abom,
agent_public_key: key_material.public_key_ssh.clone(),
capabilities: Vec::new(),
};
let response = client
.post(&url)
.header("X-OpenAI-Authorization", human_biscuit)
.json(&request)
.timeout(AGENT_REGISTRATION_TIMEOUT)
.send()
.await
.with_context(|| format!("failed to send agent identity registration request to {url}"))?
.error_for_status()
.with_context(|| format!("agent identity registration failed for {url}"))?
.json::<RegisterAgentResponse>()
.await
.with_context(|| format!("failed to parse agent identity response from {url}"))?;
Ok(AgentRuntimeId::new(response.agent_runtime_id))
}
async fn mint_agent_identity_biscuit(
client: &reqwest::Client,
chatgpt_base_url: &str,
access_token: &str,
target_method: &str,
target_url: &str,
) -> Result<String> {
let url = agent_identity_biscuit_url(chatgpt_base_url);
let request_id = agent_identity_request_id()?;
let response = client
.get(&url)
.bearer_auth(access_token)
.header("X-Request-Id", request_id)
.header("X-Original-Method", target_method)
.header("X-Original-Url", target_url)
.timeout(AGENT_IDENTITY_BISCUIT_TIMEOUT)
.send()
.await
.with_context(|| format!("failed to send agent identity biscuit request to {url}"))?
.error_for_status()
.with_context(|| format!("agent identity biscuit minting failed for {url}"))?;
response
.headers()
.get("x-openai-authorization")
.context("agent identity biscuit response did not include x-openai-authorization")?
.to_str()
.context("agent identity biscuit response header was not valid UTF-8")
.map(str::to_string)
}
fn task_id_from_register_task_response(
key: AgentIdentityKey<'_>,
response: RegisterTaskResponse,
@@ -331,6 +556,29 @@ pub fn agent_identity_request_id() -> Result<String> {
))
}
pub fn normalize_chatgpt_base_url(chatgpt_base_url: &str) -> String {
let mut base_url = chatgpt_base_url.trim_end_matches('/').to_string();
for suffix in [
"/wham/remote/control/server/enroll",
"/wham/remote/control/server",
] {
if let Some(stripped) = base_url.strip_suffix(suffix) {
base_url = stripped.to_string();
break;
}
}
if let Some(stripped) = base_url.strip_suffix("/codex") {
base_url = stripped.to_string();
}
if (base_url.starts_with("https://chatgpt.com")
|| base_url.starts_with("https://chat.openai.com"))
&& !base_url.contains("/backend-api")
{
base_url = format!("{base_url}/backend-api");
}
base_url
}
pub fn build_abom(session_source: SessionSource) -> AgentBillOfMaterials {
AgentBillOfMaterials {
agent_version: env!("CARGO_PKG_VERSION").to_string(),
@@ -412,6 +660,63 @@ mod tests {
use super::*;
#[test]
fn registered_agent_task_builds_authorization_target() {
let task = RegisteredAgentTask::new(
"agent-runtime-123",
"task-thread-456",
AgentTaskKind::Thread,
);
assert_eq!(
task.authorization_target(),
AgentTaskAuthorizationTarget {
agent_runtime_id: "agent-runtime-123",
task_id: "task-thread-456",
}
);
}
#[test]
fn register_task_request_omits_external_ref_by_default() {
let request = RegisterTaskRequest {
timestamp: "2026-04-23T00:00:00Z".to_string(),
signature: "signature".to_string(),
external_task_ref: None,
};
let serialized = serde_json::to_value(request).expect("serialize request");
assert_eq!(
serialized,
serde_json::json!({
"timestamp": "2026-04-23T00:00:00Z",
"signature": "signature",
})
);
}
#[test]
fn register_task_request_includes_external_ref_when_provided() {
let external_ref = AgentTaskExternalRef::new("thread-123");
let request = RegisterTaskRequest {
timestamp: "2026-04-23T00:00:00Z".to_string(),
signature: "signature".to_string(),
external_task_ref: Some(external_ref.as_str()),
};
let serialized = serde_json::to_value(request).expect("serialize request");
assert_eq!(
serialized,
serde_json::json!({
"timestamp": "2026-04-23T00:00:00Z",
"signature": "signature",
"external_task_ref": "thread-123",
})
);
}
#[test]
fn authorization_header_for_agent_task_serializes_signed_agent_assertion() {
let signing_key = SigningKey::from_bytes(&[7u8; 32]);
@@ -489,6 +794,41 @@ mod tests {
);
}
#[test]
fn authorization_header_for_registered_task_uses_existing_wire_shape() {
let signing_key = SigningKey::from_bytes(&[7u8; 32]);
let private_key = signing_key
.to_pkcs8_der()
.expect("encode test key material");
let private_key_pkcs8_base64 = BASE64_STANDARD.encode(private_key.as_bytes());
let key = AgentIdentityKey {
agent_runtime_id: "agent-123",
private_key_pkcs8_base64: &private_key_pkcs8_base64,
};
let task = RegisteredAgentTask::new("agent-123", "task-123", AgentTaskKind::Background);
let header = authorization_header_for_registered_task(key, &task)
.expect("build registered task assertion header");
let token = header
.strip_prefix("AgentAssertion ")
.expect("agent assertion scheme");
let payload = URL_SAFE_NO_PAD
.decode(token)
.expect("valid base64url payload");
let envelope: AgentAssertionEnvelope =
serde_json::from_slice(&payload).expect("valid assertion envelope");
assert_eq!(
envelope,
AgentAssertionEnvelope {
agent_runtime_id: "agent-123".to_string(),
task_id: "task-123".to_string(),
timestamp: envelope.timestamp.clone(),
signature: envelope.signature.clone(),
}
);
}
#[test]
fn decode_agent_identity_jwt_reads_claims() {
let jwt = jwt_with_payload(serde_json::json!({
@@ -703,6 +1043,14 @@ J1bwkqKZTB5dHolX9A58e/xXnfZ5P8f3Z83+Izap3FwqQulk7b1WO1MQcHuVg2NN
.expect("test JWKS should parse")
}
#[test]
fn normalize_chatgpt_base_url_strips_codex_before_backend_api() {
assert_eq!(
normalize_chatgpt_base_url("https://chatgpt.com/codex"),
"https://chatgpt.com/backend-api"
);
}
#[test]
fn agent_identity_jwks_url_uses_backend_api_base_url() {
assert_eq!(

View File

@@ -837,34 +837,13 @@ mod tests {
use serde_json::json;
use std::collections::BTreeMap;
use std::collections::VecDeque;
use std::ffi::OsString;
use std::future::pending;
use std::io::Read;
use std::io::Write;
use std::net::TcpListener;
use std::path::Path;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::thread;
use tempfile::TempDir;
use tempfile::tempdir;
struct EnvVarGuard {
key: &'static str,
original: Option<OsString>,
}
impl Drop for EnvVarGuard {
fn drop(&mut self) {
unsafe {
match &self.original {
Some(value) => std::env::set_var(self.key, value),
None => std::env::remove_var(self.key),
}
}
}
}
fn write_auth_json(codex_home: &Path, value: serde_json::Value) -> std::io::Result<()> {
std::fs::write(codex_home.join("auth.json"), serde_json::to_string(&value)?)?;
Ok(())
@@ -1219,25 +1198,6 @@ mod tests {
#[tokio::test]
async fn cloud_requirements_eligible_auth_allows_agent_identity_business_plan() {
let listener = TcpListener::bind("127.0.0.1:0").expect("bind task registration server");
let addr = listener
.local_addr()
.expect("task registration server addr");
let server = thread::spawn(move || {
let (mut stream, _) = listener.accept().expect("accept task registration request");
let mut request = [0; 4096];
let _ = stream
.read(&mut request)
.expect("read task registration request");
let body = r#"{"task_id":"task-123"}"#;
write!(
stream,
"HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
body.len(),
body
)
.expect("write task registration response");
});
let record = AgentIdentityAuthRecord {
agent_runtime_id: "agent-runtime-123".to_string(),
agent_private_key: "MC4CAQAwBQYDK2VwBCIEIDQg14jybCLydjHQwXeBzsDM7oB6BSAenodx6oCovQ/D"
@@ -1247,21 +1207,9 @@ mod tests {
email: "user@example.com".to_string(),
plan_type: PlanType::Business,
chatgpt_account_is_fedramp: false,
registered_at: None,
};
let authapi_base_url = format!("http://{addr}/backend-api");
let original_authapi_base_url = std::env::var_os("CODEX_AGENT_IDENTITY_AUTHAPI_BASE_URL");
unsafe {
std::env::set_var("CODEX_AGENT_IDENTITY_AUTHAPI_BASE_URL", &authapi_base_url);
}
let _authapi_guard = EnvVarGuard {
key: "CODEX_AGENT_IDENTITY_AUTHAPI_BASE_URL",
original: original_authapi_base_url,
};
let auth = AgentIdentityAuth::load(record)
.await
.map(CodexAuth::AgentIdentity)
.expect("agent identity auth");
server.join().expect("task registration server joined");
let auth = CodexAuth::AgentIdentity(AgentIdentityAuth::new(record));
assert!(cloud_requirements_eligible_auth(&auth));
}

View File

@@ -580,6 +580,9 @@
"unified_exec": {
"type": "boolean"
},
"use_agent_identity": {
"type": "boolean"
},
"use_legacy_landlock": {
"type": "boolean"
},
@@ -4069,6 +4072,9 @@
"unified_exec": {
"type": "boolean"
},
"use_agent_identity": {
"type": "boolean"
},
"use_legacy_landlock": {
"type": "boolean"
},

View File

@@ -705,11 +705,15 @@ impl AgentControl {
let result = if let Ok(thread) = state.get_thread(agent_id).await {
thread.codex.session.ensure_rollout_materialized().await;
thread.codex.session.flush_rollout().await?;
if matches!(thread.agent_status().await, AgentStatus::Shutdown) {
let result = if matches!(thread.agent_status().await, AgentStatus::Shutdown) {
Ok(String::new())
} else {
state.send_op(agent_id, Op::Shutdown {}).await
};
if result.is_ok() || matches!(result, Err(CodexErr::InternalAgentDied)) {
thread.wait_until_terminated().await;
}
result
} else {
state.send_op(agent_id, Op::Shutdown {}).await
};

View File

@@ -113,8 +113,11 @@ use crate::util::emit_feedback_auth_recovery_tags;
use codex_api::map_api_error;
use codex_feedback::FeedbackRequestTags;
use codex_feedback::emit_feedback_request_tags_with_auth_env;
use codex_login::auth::AgentIdentityAuthPolicy;
use codex_login::auth_env_telemetry::AuthEnvTelemetry;
use codex_login::auth_env_telemetry::collect_auth_env_telemetry;
use codex_model_provider::AgentTaskExternalRef;
use codex_model_provider::ProviderAuthScope;
use codex_model_provider::SharedModelProvider;
use codex_model_provider::create_model_provider;
#[cfg(test)]
@@ -158,6 +161,8 @@ struct ModelClientState {
provider: SharedModelProvider,
auth_env_telemetry: AuthEnvTelemetry,
session_source: SessionSource,
agent_identity_policy: AgentIdentityAuthPolicy,
chatgpt_base_url: Option<String>,
model_verbosity: Option<VerbosityConfig>,
enable_request_compression: bool,
include_timing_metrics: bool,
@@ -305,6 +310,35 @@ impl ModelClient {
enable_request_compression: bool,
include_timing_metrics: bool,
beta_features_header: Option<String>,
) -> Self {
Self::new_with_agent_identity_policy(
auth_manager,
conversation_id,
installation_id,
provider_info,
session_source,
AgentIdentityAuthPolicy::JwtOnly,
/*chatgpt_base_url*/ None,
model_verbosity,
enable_request_compression,
include_timing_metrics,
beta_features_header,
)
}
#[allow(clippy::too_many_arguments)]
pub fn new_with_agent_identity_policy(
auth_manager: Option<Arc<AuthManager>>,
conversation_id: ThreadId,
installation_id: String,
provider_info: ModelProviderInfo,
session_source: SessionSource,
agent_identity_policy: AgentIdentityAuthPolicy,
chatgpt_base_url: Option<String>,
model_verbosity: Option<VerbosityConfig>,
enable_request_compression: bool,
include_timing_metrics: bool,
beta_features_header: Option<String>,
) -> Self {
let model_provider = create_model_provider(provider_info, auth_manager);
let codex_api_key_env_enabled = model_provider
@@ -321,6 +355,8 @@ impl ModelClient {
provider: model_provider,
auth_env_telemetry,
session_source,
agent_identity_policy,
chatgpt_base_url,
model_verbosity,
enable_request_compression,
include_timing_metrics,
@@ -744,7 +780,11 @@ impl ModelClient {
async fn current_client_setup(&self) -> Result<CurrentClientSetup> {
let auth = self.state.provider.auth().await;
let api_provider = self.state.provider.api_provider().await?;
let api_auth = self.state.provider.api_auth().await?;
let api_auth = self
.state
.provider
.api_auth_for_scope(self.provider_auth_scope())
.await?;
Ok(CurrentClientSetup {
auth,
api_provider,
@@ -752,6 +792,15 @@ impl ModelClient {
})
}
fn provider_auth_scope(&self) -> ProviderAuthScope {
ProviderAuthScope::Thread {
external_ref: AgentTaskExternalRef::new(self.state.conversation_id.to_string()),
agent_identity_policy: self.state.agent_identity_policy,
session_source: self.state.session_source.clone(),
chatgpt_base_url: self.state.chatgpt_base_url.clone(),
}
}
/// Opens a websocket connection using the same header and telemetry wiring as normal turns.
///
/// Both startup prewarm and in-turn `needs_new` reconnects call this path so handshake
@@ -884,6 +933,10 @@ impl Drop for ModelClientSession {
}
impl ModelClientSession {
async fn current_client_setup(&self) -> Result<CurrentClientSetup> {
self.client.current_client_setup().await
}
pub(crate) fn reset_websocket_session(&mut self) {
self.websocket_session.connection = None;
self.websocket_session.last_request = None;
@@ -1013,7 +1066,7 @@ impl ModelClientSession {
return Ok(());
}
let client_setup = self.client.current_client_setup().await.map_err(|err| {
let client_setup = self.current_client_setup().await.map_err(|err| {
ApiError::Stream(format!(
"failed to build websocket prewarm client setup: {err}"
))
@@ -1176,7 +1229,7 @@ impl ModelClientSession {
.map(AuthManager::unauthorized_recovery);
let mut pending_retry = PendingUnauthorizedRetry::default();
loop {
let client_setup = self.client.current_client_setup().await?;
let client_setup = self.current_client_setup().await?;
let transport = ReqwestTransport::new(build_reqwest_client());
let request_auth_context = AuthRequestTelemetryContext::new(
client_setup.auth.as_ref().map(CodexAuth::auth_mode),
@@ -1289,7 +1342,7 @@ impl ModelClientSession {
.map(AuthManager::unauthorized_recovery);
let mut pending_retry = PendingUnauthorizedRetry::default();
loop {
let client_setup = self.client.current_client_setup().await?;
let client_setup = self.current_client_setup().await?;
let request_auth_context = AuthRequestTelemetryContext::new(
client_setup.auth.as_ref().map(CodexAuth::auth_mode),
client_setup.api_auth.as_ref(),

View File

@@ -10,7 +10,10 @@ use super::X_OPENAI_SUBAGENT_HEADER;
use codex_api::ApiError;
use codex_api::ResponseEvent;
use codex_app_server_protocol::AuthMode;
use codex_login::auth::AgentIdentityAuthPolicy;
use codex_model_provider::AgentTaskExternalRef;
use codex_model_provider::BearerAuthProvider;
use codex_model_provider::ProviderAuthScope;
use codex_model_provider_info::WireApi;
use codex_model_provider_info::create_oss_provider_with_base_url;
use codex_otel::SessionTelemetry;
@@ -51,10 +54,17 @@ use tracing_subscriber::registry::LookupSpan;
use tracing_subscriber::util::SubscriberInitExt;
fn test_model_client(session_source: SessionSource) -> ModelClient {
test_model_client_with_thread_id(ThreadId::new(), session_source)
}
fn test_model_client_with_thread_id(
conversation_id: ThreadId,
session_source: SessionSource,
) -> ModelClient {
let provider = create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses);
ModelClient::new(
/*auth_manager*/ None,
ThreadId::new(),
conversation_id,
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
provider,
session_source,
@@ -65,6 +75,23 @@ fn test_model_client(session_source: SessionSource) -> ModelClient {
)
}
#[test]
fn provider_auth_scope_uses_thread_id_as_session_ref() {
let conversation_id =
ThreadId::from_string("018f4f4c-43f5-7b28-8e24-000000000001").expect("valid thread id");
let client = test_model_client_with_thread_id(conversation_id, SessionSource::Cli);
assert_eq!(
client.provider_auth_scope(),
ProviderAuthScope::Thread {
external_ref: AgentTaskExternalRef::new(conversation_id.to_string()),
agent_identity_policy: AgentIdentityAuthPolicy::JwtOnly,
session_source: SessionSource::Cli,
chatgpt_base_url: None,
}
);
}
fn test_model_info() -> ModelInfo {
serde_json::from_value(json!({
"slug": "gpt-test",

View File

@@ -865,12 +865,18 @@ impl Session {
state_db: state_db_ctx.clone(),
live_thread: live_thread_init.as_ref().cloned(),
thread_store: Arc::clone(&thread_store),
model_client: ModelClient::new(
model_client: ModelClient::new_with_agent_identity_policy(
Some(Arc::clone(&auth_manager)),
conversation_id,
installation_id,
session_configuration.provider.clone(),
session_configuration.session_source.clone(),
if config.features.enabled(Feature::UseAgentIdentity) {
codex_login::auth::AgentIdentityAuthPolicy::JwtOrChatgpt
} else {
codex_login::auth::AgentIdentityAuthPolicy::JwtOnly
},
Some(config.chatgpt_base_url.clone()),
config.model_verbosity,
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
@@ -1038,8 +1044,6 @@ impl Session {
anyhow::bail!("required MCP servers failed to initialize: {details}");
}
}
sess.schedule_startup_prewarm(session_configuration.base_instructions.clone())
.await;
let session_start_source = match &initial_history {
InitialHistory::Resumed(_) => codex_hooks::SessionStartSource::Resume,
InitialHistory::New | InitialHistory::Forked(_) => {
@@ -1050,6 +1054,8 @@ impl Session {
// record_initial_history can emit events. We record only after the SessionConfiguredEvent is emitted.
sess.record_initial_history(initial_history).await;
sess.schedule_startup_prewarm(session_configuration.base_instructions.clone())
.await;
{
let mut state = sess.state.lock().await;
state.set_pending_session_start_source(Some(session_start_source));

View File

@@ -229,6 +229,8 @@ pub enum Feature {
ResponsesWebsocketsV2,
/// Enable remote compaction v2 over the normal Responses API.
RemoteCompactionV2,
/// Use Agent Identity for ChatGPT-authenticated sessions.
UseAgentIdentity,
/// Enable workspace dependency support.
WorkspaceDependencies,
}
@@ -1129,6 +1131,12 @@ pub const FEATURES: &[FeatureSpec] = &[
stage: Stage::UnderDevelopment,
default_enabled: false,
},
FeatureSpec {
id: Feature::UseAgentIdentity,
key: "use_agent_identity",
stage: Stage::UnderDevelopment,
default_enabled: false,
},
FeatureSpec {
id: Feature::WorkspaceDependencies,
key: "workspace_dependencies",

View File

@@ -1,43 +1,109 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;
use codex_agent_identity::AgentIdentityKey;
use codex_agent_identity::register_agent_task;
use codex_agent_identity::AgentRuntimeId;
use codex_agent_identity::AgentTaskExternalRef;
use codex_agent_identity::AgentTaskId;
use codex_agent_identity::AgentTaskKind;
use codex_agent_identity::RegisteredAgentTask;
use codex_agent_identity::normalize_chatgpt_base_url;
use codex_agent_identity::register_agent_task_with_external_ref;
use codex_protocol::account::PlanType as AccountPlanType;
use std::env;
use tokio::sync::OnceCell;
use crate::default_client::build_reqwest_client;
use super::storage::AgentIdentityAuthRecord;
const PROD_AGENT_IDENTITY_AUTHAPI_BASE_URL: &str = "https://auth.openai.com/api/accounts";
const CODEX_AGENT_IDENTITY_AUTHAPI_BASE_URL_ENV_VAR: &str = "CODEX_AGENT_IDENTITY_AUTHAPI_BASE_URL";
const DEFAULT_CHATGPT_BACKEND_BASE_URL: &str = "https://chatgpt.com/backend-api";
#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct AgentIdentityAuth {
record: AgentIdentityAuthRecord,
process_task_id: String,
task_ids: Arc<Mutex<HashMap<AgentTaskCacheKey, Arc<OnceCell<String>>>>>,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum AgentTaskCacheKey {
Process,
Thread(AgentTaskExternalRef),
}
impl Clone for AgentIdentityAuth {
fn clone(&self) -> Self {
Self {
record: self.record.clone(),
task_ids: Arc::clone(&self.task_ids),
}
}
}
impl AgentIdentityAuth {
pub async fn load(record: AgentIdentityAuthRecord) -> std::io::Result<Self> {
let agent_identity_authapi_base_url = agent_identity_authapi_base_url();
let process_task_id = register_agent_task(
&build_reqwest_client(),
&agent_identity_authapi_base_url,
key(&record),
)
.await
.map_err(std::io::Error::other)?;
Ok(Self {
pub fn new(record: AgentIdentityAuthRecord) -> Self {
Self {
record,
process_task_id,
})
task_ids: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn record(&self) -> &AgentIdentityAuthRecord {
&self.record
}
pub fn process_task_id(&self) -> &str {
&self.process_task_id
pub fn process_task_id(&self) -> Option<String> {
self.task_id_for_initialized_key(&AgentTaskCacheKey::Process)
}
pub async fn ensure_runtime(&self, chatgpt_base_url: Option<String>) -> std::io::Result<()> {
self.task_id_for_key(
AgentTaskCacheKey::Process,
chatgpt_base_url,
/*external_ref*/ None,
)
.await
.map(|_| ())
}
pub async fn registered_thread_task(
&self,
external_ref: AgentTaskExternalRef,
chatgpt_base_url: Option<String>,
) -> std::io::Result<RegisteredAgentTask> {
let task_id = self
.task_id_for_key(
AgentTaskCacheKey::Thread(external_ref.clone()),
chatgpt_base_url,
Some(external_ref),
)
.await?;
Ok(self.registered_task(task_id, AgentTaskKind::Thread))
}
pub async fn register_task(&self, chatgpt_base_url: Option<String>) -> std::io::Result<String> {
self.register_task_with_external_ref(chatgpt_base_url, /*external_ref*/ None)
.await
}
async fn register_task_with_external_ref(
&self,
chatgpt_base_url: Option<String>,
external_ref: Option<&AgentTaskExternalRef>,
) -> std::io::Result<String> {
let base_url = normalize_chatgpt_base_url(
chatgpt_base_url
.as_deref()
.unwrap_or(DEFAULT_CHATGPT_BACKEND_BASE_URL),
);
register_agent_task_with_external_ref(
&build_reqwest_client(),
&base_url,
self.key(),
external_ref,
)
.await
.map_err(std::io::Error::other)
}
pub fn account_id(&self) -> &str {
@@ -59,82 +125,202 @@ impl AgentIdentityAuth {
pub fn is_fedramp_account(&self) -> bool {
self.record.chatgpt_account_is_fedramp
}
}
fn key(&self) -> AgentIdentityKey<'_> {
AgentIdentityKey {
agent_runtime_id: &self.record.agent_runtime_id,
private_key_pkcs8_base64: &self.record.agent_private_key,
}
}
fn agent_identity_authapi_base_url() -> String {
env::var(CODEX_AGENT_IDENTITY_AUTHAPI_BASE_URL_ENV_VAR)
.ok()
.map(|base_url| base_url.trim().trim_end_matches('/').to_string())
.filter(|base_url| !base_url.is_empty())
.unwrap_or_else(|| PROD_AGENT_IDENTITY_AUTHAPI_BASE_URL.to_string())
}
async fn task_id_for_key(
&self,
key: AgentTaskCacheKey,
chatgpt_base_url: Option<String>,
external_ref: Option<AgentTaskExternalRef>,
) -> std::io::Result<String> {
let slot = self.task_slot(key)?;
slot.get_or_try_init(|| async {
self.register_task_with_external_ref(chatgpt_base_url, external_ref.as_ref())
.await
})
.await
.cloned()
}
fn key(record: &AgentIdentityAuthRecord) -> AgentIdentityKey<'_> {
AgentIdentityKey {
agent_runtime_id: &record.agent_runtime_id,
private_key_pkcs8_base64: &record.agent_private_key,
fn task_slot(&self, key: AgentTaskCacheKey) -> std::io::Result<Arc<OnceCell<String>>> {
let mut task_ids = self
.task_ids
.lock()
.map_err(|_| std::io::Error::other("failed to lock agent task cache"))?;
Ok(task_ids
.entry(key)
.or_insert_with(|| Arc::new(OnceCell::new()))
.clone())
}
fn task_id_for_initialized_key(&self, key: &AgentTaskCacheKey) -> Option<String> {
let task_ids = self.task_ids.lock().ok()?;
task_ids.get(key)?.get().cloned()
}
fn registered_task(&self, task_id: String, kind: AgentTaskKind) -> RegisteredAgentTask {
RegisteredAgentTask::new(
AgentRuntimeId::new(self.record.agent_runtime_id.clone()),
AgentTaskId::new(task_id),
kind,
)
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use codex_agent_identity::generate_agent_key_material;
use pretty_assertions::assert_eq;
use serde_json::json;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::body_partial_json;
use wiremock::matchers::method;
use wiremock::matchers::path;
use super::*;
use serial_test::serial;
#[test]
#[serial(codex_auth_env)]
fn agent_identity_authapi_base_url_prefers_env_value() {
let _guard = EnvVarGuard::set(
CODEX_AGENT_IDENTITY_AUTHAPI_BASE_URL_ENV_VAR,
"https://authapi.example.test/api/accounts/",
);
assert_eq!(
agent_identity_authapi_base_url(),
"https://authapi.example.test/api/accounts"
);
}
#[test]
#[serial(codex_auth_env)]
fn agent_identity_authapi_base_url_uses_prod_authapi_by_default() {
let _guard = EnvVarGuard::remove(CODEX_AGENT_IDENTITY_AUTHAPI_BASE_URL_ENV_VAR);
assert_eq!(
agent_identity_authapi_base_url(),
PROD_AGENT_IDENTITY_AUTHAPI_BASE_URL
);
}
struct EnvVarGuard {
key: &'static str,
original: Option<std::ffi::OsString>,
}
impl EnvVarGuard {
fn set(key: &'static str, value: &str) -> Self {
let original = env::var_os(key);
unsafe {
env::set_var(key, value);
}
Self { key, original }
}
fn remove(key: &'static str) -> Self {
let original = env::var_os(key);
unsafe {
env::remove_var(key);
}
Self { key, original }
fn agent_identity_record(private_key: String) -> AgentIdentityAuthRecord {
AgentIdentityAuthRecord {
agent_runtime_id: "agent-runtime-1".to_string(),
agent_private_key: private_key,
account_id: "account-1".to_string(),
chatgpt_user_id: "user-1".to_string(),
email: "agent@example.com".to_string(),
plan_type: AccountPlanType::Plus,
chatgpt_account_is_fedramp: false,
registered_at: None,
}
}
impl Drop for EnvVarGuard {
fn drop(&mut self) {
unsafe {
match &self.original {
Some(value) => env::set_var(self.key, value),
None => env::remove_var(self.key),
fn agent_identity_auth() -> AgentIdentityAuth {
let key_material = generate_agent_key_material().expect("generate key material");
AgentIdentityAuth::new(agent_identity_record(key_material.private_key_pkcs8_base64))
}
#[tokio::test]
async fn registered_thread_task_registers_once_per_external_ref() -> anyhow::Result<()> {
let auth = agent_identity_auth();
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/agent/agent-runtime-1/task/register"))
.and(body_partial_json(json!({
"external_task_ref": "thread-1",
})))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"task_id": "task-thread-1",
})))
.expect(1)
.mount(&server)
.await;
let first = auth
.registered_thread_task(AgentTaskExternalRef::new("thread-1"), Some(server.uri()))
.await?;
let second = auth
.registered_thread_task(AgentTaskExternalRef::new("thread-1"), Some(server.uri()))
.await?;
assert_eq!(first, second);
assert_eq!(
first,
RegisteredAgentTask::new(
AgentRuntimeId::new("agent-runtime-1"),
AgentTaskId::new("task-thread-1"),
AgentTaskKind::Thread,
)
);
Ok(())
}
#[tokio::test]
async fn registered_thread_task_uses_distinct_external_refs() -> anyhow::Result<()> {
let auth = agent_identity_auth();
let server = MockServer::start().await;
for (external_ref, task_id) in
[("thread-1", "task-thread-1"), ("thread-2", "task-thread-2")]
{
Mock::given(method("POST"))
.and(path("/v1/agent/agent-runtime-1/task/register"))
.and(body_partial_json(json!({
"external_task_ref": external_ref,
})))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"task_id": task_id,
})))
.expect(1)
.mount(&server)
.await;
}
let first = auth
.registered_thread_task(AgentTaskExternalRef::new("thread-1"), Some(server.uri()))
.await?;
let second = auth
.registered_thread_task(AgentTaskExternalRef::new("thread-2"), Some(server.uri()))
.await?;
assert_eq!(first.task_id.as_str(), "task-thread-1");
assert_eq!(second.task_id.as_str(), "task-thread-2");
Ok(())
}
#[tokio::test]
async fn failed_thread_task_registration_can_retry() -> anyhow::Result<()> {
let auth = agent_identity_auth();
let server = MockServer::start().await;
let request_count = Arc::new(AtomicUsize::new(0));
let response_count = Arc::clone(&request_count);
Mock::given(method("POST"))
.and(path("/v1/agent/agent-runtime-1/task/register"))
.and(body_partial_json(json!({
"external_task_ref": "thread-1",
})))
.respond_with(move |_request: &wiremock::Request| {
if response_count.fetch_add(1, Ordering::SeqCst) == 0 {
ResponseTemplate::new(500)
} else {
ResponseTemplate::new(200).set_body_json(json!({
"task_id": "task-thread-1",
}))
}
}
}
})
.expect(2)
.mount(&server)
.await;
auth.registered_thread_task(AgentTaskExternalRef::new("thread-1"), Some(server.uri()))
.await
.expect_err("first registration should fail");
let task = auth
.registered_thread_task(AgentTaskExternalRef::new("thread-1"), Some(server.uri()))
.await?;
assert_eq!(request_count.load(Ordering::SeqCst), 2);
assert_eq!(task.task_id.as_str(), "task-thread-1");
Ok(())
}
#[test]
fn task_slots_are_shared_across_clones() {
let auth = agent_identity_auth();
let cloned = auth.clone();
let slot = auth
.task_slot(AgentTaskCacheKey::Process)
.expect("task slot should be available");
slot.set("process-task-1".to_string())
.expect("process task should be unset");
assert_eq!(cloned.process_task_id(), Some("process-task-1".to_string()));
}
}

View File

@@ -1,4 +1,5 @@
use super::*;
use crate::auth::storage::AgentIdentityStorage;
use crate::auth::storage::FileAuthStorage;
use crate::auth::storage::get_auth_file;
use crate::token_data::IdTokenInfo;
@@ -6,6 +7,7 @@ use codex_app_server_protocol::AuthMode;
use codex_protocol::account::PlanType as AccountPlanType;
use codex_protocol::auth::KnownPlan as InternalKnownPlan;
use codex_protocol::auth::PlanType as InternalPlanType;
use codex_protocol::protocol::SessionSource;
use base64::Engine;
use codex_protocol::config_types::ForcedLoginMethod;
@@ -19,6 +21,7 @@ use tempfile::tempdir;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::header;
use wiremock::matchers::method;
use wiremock::matchers::path;
@@ -114,7 +117,9 @@ async fn login_with_agent_identity_writes_only_token() {
.expect("auth.json should parse");
assert_eq!(auth.auth_mode, Some(AuthMode::AgentIdentity));
assert_eq!(
auth.agent_identity.as_deref(),
auth.agent_identity
.as_ref()
.and_then(AgentIdentityStorage::as_jwt),
Some(agent_identity.as_str())
);
assert!(auth.tokens.is_none(), "tokens should be cleared");
@@ -142,6 +147,109 @@ async fn login_with_agent_identity_rejects_invalid_jwt() {
);
}
#[tokio::test]
async fn chatgpt_auth_registers_agent_identity_when_enabled() -> anyhow::Result<()> {
let codex_home = tempdir()?;
write_auth_file(
AuthFileParams {
openai_api_key: None,
chatgpt_plan_type: Some("pro".to_string()),
chatgpt_account_id: Some("account-123".to_string()),
},
codex_home.path(),
)?;
let auth = super::load_auth(
codex_home.path(),
/*enable_codex_api_key_env*/ false,
AuthCredentialsStoreMode::File,
/*chatgpt_base_url*/ None,
)
.await?
.expect("auth should load");
assert!(
auth.agent_identity_auth(
AgentIdentityAuthPolicy::JwtOnly,
/*chatgpt_base_url*/ None,
/*forced_chatgpt_workspace_id*/ None,
SessionSource::Cli,
)
.await?
.is_none()
);
let server = MockServer::start().await;
let target_url = format!("{}/v1/agent/register", server.uri());
Mock::given(method("GET"))
.and(path("/authenticate_app_v2"))
.and(header("authorization", "Bearer test-access-token"))
.and(header("x-original-method", "POST"))
.and(header("x-original-url", target_url))
.respond_with(
ResponseTemplate::new(/*s*/ 200)
.insert_header("x-openai-authorization", "human-biscuit"),
)
.expect(/*r*/ 1)
.mount(&server)
.await;
Mock::given(method("POST"))
.and(path("/v1/agent/register"))
.and(header("x-openai-authorization", "human-biscuit"))
.respond_with(ResponseTemplate::new(/*s*/ 200).set_body_json(json!({
"agent_runtime_id": "agent-runtime-123",
})))
.expect(/*r*/ 1)
.mount(&server)
.await;
let agent_auth = auth
.agent_identity_auth(
AgentIdentityAuthPolicy::JwtOrChatgpt,
Some(server.uri()),
/*forced_chatgpt_workspace_id*/ None,
SessionSource::Cli,
)
.await?
.expect("agent identity should register");
let reused = auth
.agent_identity_auth(
AgentIdentityAuthPolicy::JwtOrChatgpt,
Some(server.uri()),
/*forced_chatgpt_workspace_id*/ None,
SessionSource::Cli,
)
.await?
.expect("agent identity should be reused");
Mock::given(method("POST"))
.and(path("/v1/agent/agent-runtime-123/task/register"))
.respond_with(ResponseTemplate::new(/*s*/ 200).set_body_json(json!({
"task_id": "task-123",
})))
.expect(/*r*/ 1)
.mount(&server)
.await;
agent_auth.ensure_runtime(Some(server.uri())).await?;
reused.ensure_runtime(Some(server.uri())).await?;
assert_eq!(
agent_auth.record().agent_runtime_id,
reused.record().agent_runtime_id
);
assert_eq!(agent_auth.process_task_id(), Some("task-123".to_string()));
assert_eq!(reused.process_task_id(), Some("task-123".to_string()));
assert_eq!(agent_auth.record().agent_runtime_id, "agent-runtime-123");
assert_eq!(agent_auth.record().account_id, "account-123");
assert_eq!(agent_auth.record().chatgpt_user_id, "user-12345");
assert_eq!(
auth.get_agent_identity("account-123")
.expect("identity should persist")
.agent_runtime_id,
"agent-runtime-123"
);
Ok(())
}
#[tokio::test]
async fn login_with_agent_identity_rejects_unsigned_jwt() {
let dir = tempdir().unwrap();
@@ -746,7 +854,10 @@ async fn load_auth_reads_agent_identity_from_env() {
panic!("env auth should load as agent identity");
};
assert_eq!(agent_identity.record(), &expected_record);
assert_eq!(agent_identity.process_task_id(), "task-123");
assert_eq!(
agent_identity.process_task_id(),
Some("task-123".to_string())
);
assert!(
!get_auth_file(codex_home.path()).exists(),
"env auth should not write auth.json"
@@ -923,6 +1034,7 @@ fn agent_identity_record(account_id: &str) -> AgentIdentityAuthRecord {
email: "user@example.com".to_string(),
plan_type: AccountPlanType::Pro,
chatgpt_account_is_fedramp: false,
registered_at: None,
}
}

View File

@@ -16,8 +16,13 @@ use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use tokio::sync::Semaphore;
use codex_agent_identity::build_abom;
use codex_agent_identity::decode_agent_identity_jwt;
use codex_agent_identity::fetch_agent_identity_jwks;
use codex_agent_identity::generate_agent_key_material;
use codex_agent_identity::normalize_chatgpt_base_url;
use codex_agent_identity::public_key_ssh_from_private_key_pkcs8_base64;
use codex_agent_identity::register_agent_identity;
use codex_app_server_protocol::AuthMode;
use codex_app_server_protocol::AuthMode as ApiAuthMode;
use codex_protocol::config_types::ForcedLoginMethod;
@@ -27,6 +32,7 @@ use super::external_bearer::BearerTokenRefresher;
use super::revoke::revoke_auth_tokens;
pub use crate::auth::agent_identity::AgentIdentityAuth;
pub use crate::auth::storage::AgentIdentityAuthRecord;
use crate::auth::storage::AgentIdentityStorage;
pub use crate::auth::storage::AuthDotJson;
use crate::auth::storage::AuthStorageBackend;
use crate::auth::storage::create_auth_storage;
@@ -42,6 +48,7 @@ use codex_protocol::account::PlanType as AccountPlanType;
use codex_protocol::auth::PlanType as InternalPlanType;
use codex_protocol::auth::RefreshTokenFailedError;
use codex_protocol::auth::RefreshTokenFailedReason;
use codex_protocol::protocol::SessionSource;
use serde_json::Value;
use thiserror::Error;
@@ -54,6 +61,15 @@ pub enum CodexAuth {
AgentIdentity(AgentIdentityAuth),
}
/// Policy for resolving Agent Identity auth from a broader Codex auth snapshot.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AgentIdentityAuthPolicy {
/// Only use an existing Agent Identity JWT/runtime auth.
JwtOnly,
/// Use an Agent Identity JWT/runtime auth, or register one from ChatGPT auth.
JwtOrChatgpt,
}
impl PartialEq for CodexAuth {
fn eq(&self, other: &Self) -> bool {
self.api_auth_mode() == other.api_auth_mode()
@@ -79,10 +95,38 @@ pub struct ChatgptAuthTokens {
#[derive(Debug, Clone)]
struct ChatgptAuthState {
auth_dot_json: Arc<Mutex<Option<AuthDotJson>>>,
agent_identity_auth: Arc<Mutex<Option<AgentIdentityAuth>>>,
client: CodexHttpClient,
}
impl ChatgptAuthState {
fn new(auth_dot_json: AuthDotJson) -> Self {
let agent_identity_auth = auth_dot_json
.agent_identity
.as_ref()
.and_then(AgentIdentityStorage::as_record)
.cloned()
.map(AgentIdentityAuth::new);
Self {
auth_dot_json: Arc::new(Mutex::new(Some(auth_dot_json))),
agent_identity_auth: Arc::new(Mutex::new(agent_identity_auth)),
client: create_client(),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct ChatgptAgentIdentityBinding {
account_id: String,
chatgpt_user_id: String,
email: String,
plan_type: AccountPlanType,
chatgpt_account_is_fedramp: bool,
access_token: String,
}
const TOKEN_REFRESH_INTERVAL: i64 = 8;
const DEFAULT_CHATGPT_BACKEND_BASE_URL: &str = "https://chatgpt.com/backend-api";
const REFRESH_TOKEN_EXPIRED_MESSAGE: &str = "Your access token could not be refreshed because your refresh token has expired. Please log out and sign in again.";
const REFRESH_TOKEN_REUSED_MESSAGE: &str = "Your access token could not be refreshed because your refresh token was already used. Please log out and sign in again.";
@@ -90,7 +134,6 @@ const REFRESH_TOKEN_INVALIDATED_MESSAGE: &str = "Your access token could not be
const REFRESH_TOKEN_UNKNOWN_MESSAGE: &str =
"Your access token could not be refreshed. Please log out and sign in again.";
const REFRESH_TOKEN_ACCOUNT_MISMATCH_MESSAGE: &str = "Your access token could not be refreshed because you have since logged out or signed in to another account. Please sign in again.";
const DEFAULT_CHATGPT_BACKEND_BASE_URL: &str = "https://chatgpt.com/backend-api";
const REFRESH_TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
pub(super) const REVOKE_TOKEN_URL: &str = "https://auth.openai.com/oauth/revoke";
pub const REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR: &str = "CODEX_REFRESH_TOKEN_URL_OVERRIDE";
@@ -203,7 +246,6 @@ impl CodexAuth {
chatgpt_base_url: Option<&str>,
) -> std::io::Result<Self> {
let auth_mode = auth_dot_json.resolved_mode();
let client = create_client();
if auth_mode == ApiAuthMode::ApiKey {
let Some(api_key) = auth_dot_json.openai_api_key.as_deref() else {
return Err(std::io::Error::other("API key auth is missing a key."));
@@ -211,19 +253,20 @@ impl CodexAuth {
return Ok(Self::from_api_key(api_key));
}
if auth_mode == ApiAuthMode::AgentIdentity {
let Some(agent_identity) = auth_dot_json.agent_identity else {
let Some(agent_identity) = auth_dot_json
.agent_identity
.as_ref()
.and_then(AgentIdentityStorage::as_jwt)
else {
return Err(std::io::Error::other(
"agent identity auth is missing an agent identity token.",
));
};
return Self::from_agent_identity_jwt(&agent_identity, chatgpt_base_url).await;
return Self::from_agent_identity_jwt(agent_identity, chatgpt_base_url).await;
}
let storage_mode = auth_dot_json.storage_mode(auth_credentials_store_mode);
let state = ChatgptAuthState {
auth_dot_json: Arc::new(Mutex::new(Some(auth_dot_json))),
client,
};
let state = ChatgptAuthState::new(auth_dot_json);
match auth_mode {
ApiAuthMode::Chatgpt => {
@@ -261,7 +304,9 @@ impl CodexAuth {
.trim_end_matches('/')
.to_string();
let record = verified_agent_identity_record(jwt, &base_url).await?;
Ok(Self::AgentIdentity(AgentIdentityAuth::load(record).await?))
let auth = AgentIdentityAuth::new(record);
auth.ensure_runtime(Some(base_url)).await?;
Ok(Self::AgentIdentity(auth))
}
pub fn auth_mode(&self) -> AuthMode {
@@ -408,6 +453,184 @@ impl CodexAuth {
self.get_current_auth_json().and_then(|t| t.tokens)
}
pub fn get_agent_identity(&self, account_id: &str) -> Option<AgentIdentityAuthRecord> {
self.get_current_auth_json()
.and_then(|auth| auth.agent_identity)
.and_then(|identity| identity.as_record().cloned())
.filter(|identity| identity.account_id == account_id)
}
pub fn set_agent_identity(&self, record: AgentIdentityAuthRecord) -> std::io::Result<()> {
let agent_identity_auth = self.agent_identity_auth_for_record(record.clone())?;
match self {
Self::Chatgpt(auth) => auth
.update_auth_json(|auth_dot_json| {
auth_dot_json.agent_identity = Some(AgentIdentityStorage::Record(record));
true
})
.map(|_| ()),
Self::ChatgptAuthTokens(auth) => auth.update_auth_json_in_memory(|auth_dot_json| {
auth_dot_json.agent_identity = Some(AgentIdentityStorage::Record(record));
}),
Self::ApiKey(_) | Self::AgentIdentity(_) => Ok(()),
}?;
self.set_cached_agent_identity_auth(Some(agent_identity_auth))
}
pub fn remove_agent_identity(&self) -> std::io::Result<bool> {
let removed = match self {
Self::Chatgpt(auth) => {
auth.update_auth_json(|auth_dot_json| auth_dot_json.agent_identity.take().is_some())
}
Self::ChatgptAuthTokens(auth) => {
let mut removed = false;
auth.update_auth_json_in_memory(|auth_dot_json| {
removed = auth_dot_json.agent_identity.take().is_some();
})?;
Ok(removed)
}
Self::ApiKey(_) | Self::AgentIdentity(_) => Ok(false),
}?;
if removed {
self.set_cached_agent_identity_auth(/*auth*/ None)?;
}
Ok(removed)
}
fn cached_agent_identity_auth(
&self,
binding: &ChatgptAgentIdentityBinding,
) -> Option<AgentIdentityAuth> {
let auth = self.cached_agent_identity_auth_value()?;
if agent_identity_record_matches_binding(auth.record(), binding)
&& public_key_ssh_from_private_key_pkcs8_base64(&auth.record().agent_private_key)
.is_ok()
{
Some(auth)
} else {
None
}
}
fn agent_identity_auth_for_record(
&self,
record: AgentIdentityAuthRecord,
) -> std::io::Result<AgentIdentityAuth> {
if let Some(auth) = self.cached_agent_identity_auth_value()
&& auth.record() == &record
{
return Ok(auth);
}
Ok(AgentIdentityAuth::new(record))
}
fn cached_agent_identity_auth_value(&self) -> Option<AgentIdentityAuth> {
let state = self.chatgpt_state()?;
let auth = state.agent_identity_auth.lock().ok()?;
auth.clone()
}
fn set_cached_agent_identity_auth(
&self,
auth: Option<AgentIdentityAuth>,
) -> std::io::Result<()> {
let Some(state) = self.chatgpt_state() else {
return Ok(());
};
let mut cached = state
.agent_identity_auth
.lock()
.map_err(|_| std::io::Error::other("failed to lock agent identity cache"))?;
*cached = auth;
Ok(())
}
fn chatgpt_state(&self) -> Option<&ChatgptAuthState> {
match self {
Self::Chatgpt(auth) => Some(&auth.state),
Self::ChatgptAuthTokens(auth) => Some(&auth.state),
Self::ApiKey(_) | Self::AgentIdentity(_) => None,
}
}
pub async fn agent_identity_auth(
&self,
policy: AgentIdentityAuthPolicy,
chatgpt_base_url: Option<String>,
forced_chatgpt_workspace_id: Option<String>,
session_source: SessionSource,
) -> std::io::Result<Option<AgentIdentityAuth>> {
match self {
Self::AgentIdentity(auth) => Ok(Some(auth.clone())),
Self::ApiKey(_) => Ok(None),
Self::Chatgpt(_) | Self::ChatgptAuthTokens(_) => {
if policy == AgentIdentityAuthPolicy::JwtOnly {
return Ok(None);
}
self.ensure_chatgpt_agent_identity(
chatgpt_base_url,
forced_chatgpt_workspace_id,
session_source,
)
.await
.map(Some)
}
}
}
async fn ensure_chatgpt_agent_identity(
&self,
chatgpt_base_url: Option<String>,
forced_chatgpt_workspace_id: Option<String>,
session_source: SessionSource,
) -> std::io::Result<AgentIdentityAuth> {
let binding = ChatgptAgentIdentityBinding::from_auth(self, forced_chatgpt_workspace_id)
.ok_or_else(|| std::io::Error::other("ChatGPT auth is unavailable"))?;
if let Some(auth) = self.cached_agent_identity_auth(&binding) {
return Ok(auth);
}
if let Some(record) = self.get_agent_identity(&binding.account_id)
&& agent_identity_record_matches_binding(&record, &binding)
&& public_key_ssh_from_private_key_pkcs8_base64(&record.agent_private_key).is_ok()
{
let auth = self.agent_identity_auth_for_record(record)?;
self.set_cached_agent_identity_auth(Some(auth.clone()))?;
return Ok(auth);
}
let key_material = generate_agent_key_material().map_err(std::io::Error::other)?;
let base_url = normalize_chatgpt_base_url(
chatgpt_base_url
.as_deref()
.unwrap_or(DEFAULT_CHATGPT_BACKEND_BASE_URL),
);
let runtime_id = register_agent_identity(
&build_reqwest_client(),
&base_url,
&binding.access_token,
&key_material,
build_abom(session_source),
)
.await
.map_err(std::io::Error::other)?;
let record = AgentIdentityAuthRecord {
agent_runtime_id: runtime_id.into_string(),
agent_private_key: key_material.private_key_pkcs8_base64,
account_id: binding.account_id,
chatgpt_user_id: binding.chatgpt_user_id,
email: binding.email,
plan_type: binding.plan_type,
chatgpt_account_is_fedramp: binding.chatgpt_account_is_fedramp,
registered_at: Some(
Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Secs, /*use_z*/ true),
),
};
self.set_agent_identity(record.clone())?;
self.agent_identity_auth_for_record(record)
}
/// Consider this private to integration tests.
pub fn create_dummy_chatgpt_auth_for_testing() -> Self {
let auth_dot_json = AuthDotJson {
@@ -423,11 +646,7 @@ impl CodexAuth {
agent_identity: None,
};
let client = create_client();
let state = ChatgptAuthState {
auth_dot_json: Arc::new(Mutex::new(Some(auth_dot_json))),
client,
};
let state = ChatgptAuthState::new(auth_dot_json);
let dummy_auth_id = NEXT_DUMMY_AUTH_ID.fetch_add(1, Ordering::Relaxed);
let storage = create_auth_storage(
PathBuf::from(format!("dummy-chatgpt-auth-{dummy_auth_id}")),
@@ -443,6 +662,44 @@ impl CodexAuth {
}
}
impl ChatgptAgentIdentityBinding {
fn from_auth(auth: &CodexAuth, forced_workspace_id: Option<String>) -> Option<Self> {
if !auth.is_chatgpt_auth() {
return None;
}
let token_data = auth.get_token_data().ok()?;
let account_id = forced_workspace_id
.filter(|value| !value.is_empty())
.or(token_data
.account_id
.clone()
.filter(|value| !value.is_empty()))
.or(token_data.id_token.chatgpt_account_id.clone())?;
let chatgpt_user_id = token_data
.id_token
.chatgpt_user_id
.clone()
.filter(|value| !value.is_empty())?;
Some(Self {
account_id,
chatgpt_user_id,
email: token_data.id_token.email.clone().unwrap_or_default(),
plan_type: auth.account_plan_type().unwrap_or(AccountPlanType::Unknown),
chatgpt_account_is_fedramp: auth.is_fedramp_account(),
access_token: token_data.access_token,
})
}
}
fn agent_identity_record_matches_binding(
record: &AgentIdentityAuthRecord,
binding: &ChatgptAgentIdentityBinding,
) -> bool {
record.account_id == binding.account_id && record.chatgpt_user_id == binding.chatgpt_user_id
}
impl ChatgptAuth {
fn current_auth_json(&self) -> Option<AuthDotJson> {
#[expect(clippy::unwrap_used)]
@@ -460,6 +717,45 @@ impl ChatgptAuth {
fn client(&self) -> &CodexHttpClient {
&self.state.client
}
fn update_auth_json(
&self,
update: impl FnOnce(&mut AuthDotJson) -> bool,
) -> std::io::Result<bool> {
let mut guard = self
.state
.auth_dot_json
.lock()
.map_err(|_| std::io::Error::other("failed to lock auth state"))?;
let mut auth = guard
.clone()
.ok_or_else(|| std::io::Error::other("auth data is not available"))?;
let changed = update(&mut auth);
if changed {
self.storage.save(&auth)?;
*guard = Some(auth);
}
Ok(changed)
}
}
impl ChatgptAuthTokens {
fn update_auth_json_in_memory(
&self,
update: impl FnOnce(&mut AuthDotJson),
) -> std::io::Result<()> {
let mut guard = self
.state
.auth_dot_json
.lock()
.map_err(|_| std::io::Error::other("failed to lock auth state"))?;
let mut auth = guard
.clone()
.ok_or_else(|| std::io::Error::other("auth data is not available"))?;
update(&mut auth);
*guard = Some(auth);
Ok(())
}
}
pub const OPENAI_API_KEY_ENV_VAR: &str = "OPENAI_API_KEY";
@@ -557,7 +853,7 @@ pub async fn login_with_agent_identity(
openai_api_key: None,
tokens: None,
last_refresh: None,
agent_identity: Some(agent_identity.to_string()),
agent_identity: Some(AgentIdentityStorage::Jwt(agent_identity.to_string())),
};
save_auth(codex_home, &auth_dot_json, auth_credentials_store_mode)
}
@@ -1249,6 +1545,7 @@ pub struct AuthManager {
forced_chatgpt_workspace_id: RwLock<Option<String>>,
chatgpt_base_url: Option<String>,
refresh_lock: Semaphore,
agent_identity_lock: Semaphore,
external_auth: RwLock<Option<Arc<dyn ExternalAuth>>>,
}
@@ -1323,6 +1620,7 @@ impl AuthManager {
forced_chatgpt_workspace_id: RwLock::new(None),
chatgpt_base_url,
refresh_lock: Semaphore::new(/*permits*/ 1),
agent_identity_lock: Semaphore::new(/*permits*/ 1),
external_auth: RwLock::new(None),
}
}
@@ -1342,6 +1640,7 @@ impl AuthManager {
forced_chatgpt_workspace_id: RwLock::new(None),
chatgpt_base_url: None,
refresh_lock: Semaphore::new(/*permits*/ 1),
agent_identity_lock: Semaphore::new(/*permits*/ 1),
external_auth: RwLock::new(None),
})
}
@@ -1360,6 +1659,7 @@ impl AuthManager {
forced_chatgpt_workspace_id: RwLock::new(None),
chatgpt_base_url: None,
refresh_lock: Semaphore::new(/*permits*/ 1),
agent_identity_lock: Semaphore::new(/*permits*/ 1),
external_auth: RwLock::new(None),
})
}
@@ -1376,6 +1676,7 @@ impl AuthManager {
forced_chatgpt_workspace_id: RwLock::new(None),
chatgpt_base_url: None,
refresh_lock: Semaphore::new(/*permits*/ 1),
agent_identity_lock: Semaphore::new(/*permits*/ 1),
external_auth: RwLock::new(Some(
Arc::new(BearerTokenRefresher::new(config)) as Arc<dyn ExternalAuth>
)),
@@ -1415,6 +1716,38 @@ impl AuthManager {
self.auth_cached()
}
pub async fn agent_identity_auth(
&self,
policy: AgentIdentityAuthPolicy,
session_source: SessionSource,
) -> std::io::Result<Option<AgentIdentityAuth>> {
let Some(auth) = self.auth().await else {
return Ok(None);
};
if policy == AgentIdentityAuthPolicy::JwtOrChatgpt && auth.is_chatgpt_auth() {
let _permit = self
.agent_identity_lock
.acquire()
.await
.map_err(std::io::Error::other)?;
return auth
.agent_identity_auth(
policy,
self.chatgpt_base_url.clone(),
self.forced_chatgpt_workspace_id(),
session_source,
)
.await;
}
auth.agent_identity_auth(
policy,
self.chatgpt_base_url.clone(),
self.forced_chatgpt_workspace_id(),
session_source,
)
.await
}
/// Force a reload of the auth information from auth.json. Returns
/// whether the auth value changed.
pub async fn reload(&self) -> bool {

View File

@@ -44,7 +44,30 @@ pub struct AuthDotJson {
pub last_refresh: Option<DateTime<Utc>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub agent_identity: Option<String>,
pub agent_identity: Option<AgentIdentityStorage>,
}
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Eq)]
#[serde(untagged)]
pub enum AgentIdentityStorage {
Jwt(String),
Record(AgentIdentityAuthRecord),
}
impl AgentIdentityStorage {
pub(crate) fn as_jwt(&self) -> Option<&str> {
match self {
Self::Jwt(jwt) => Some(jwt),
Self::Record(_) => None,
}
}
pub(crate) fn as_record(&self) -> Option<&AgentIdentityAuthRecord> {
match self {
Self::Jwt(_) => None,
Self::Record(record) => Some(record),
}
}
}
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Eq)]
@@ -56,6 +79,8 @@ pub struct AgentIdentityAuthRecord {
pub email: String,
pub plan_type: AccountPlanType,
pub chatgpt_account_is_fedramp: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub registered_at: Option<String>,
}
impl AgentIdentityAuthRecord {
@@ -77,6 +102,7 @@ impl From<AgentIdentityJwtClaims> for AgentIdentityAuthRecord {
email: claims.email,
plan_type: claims.plan_type.into(),
chatgpt_account_is_fedramp: claims.chatgpt_account_is_fedramp,
registered_at: None,
}
}
}

View File

@@ -72,7 +72,7 @@ async fn file_storage_round_trips_agent_identity_auth() -> anyhow::Result<()> {
openai_api_key: None,
tokens: None,
last_refresh: None,
agent_identity: Some(agent_identity),
agent_identity: Some(AgentIdentityStorage::Jwt(agent_identity)),
};
storage.save(&auth_dot_json)?;
@@ -107,7 +107,11 @@ async fn file_storage_loads_agent_identity_as_jwt() -> anyhow::Result<()> {
let loaded = storage.load()?;
assert_eq!(
loaded.expect("auth should load").agent_identity.as_deref(),
loaded
.expect("auth should load")
.agent_identity
.as_ref()
.and_then(AgentIdentityStorage::as_jwt),
Some(agent_identity_jwt.as_str())
);
Ok(())

View File

@@ -17,6 +17,7 @@ pub use server::ServerOptions;
pub use server::ShutdownHandle;
pub use server::run_login_server;
pub use auth::AgentIdentityAuthPolicy;
pub use auth::AuthConfig;
pub use auth::AuthDotJson;
pub use auth::AuthManager;

View File

@@ -17,6 +17,7 @@ use codex_protocol::account::ProviderAccount;
use codex_protocol::error::Result;
use codex_protocol::openai_models::ModelsResponse;
use crate::auth::ProviderAuthScope;
use crate::provider::ModelProvider;
use crate::provider::ProviderAccountResult;
use crate::provider::ProviderAccountState;
@@ -91,6 +92,10 @@ impl ModelProvider for AmazonBedrockModelProvider {
resolve_provider_auth(&self.aws).await
}
async fn api_auth_for_scope(&self, _scope: ProviderAuthScope) -> Result<SharedAuthProvider> {
resolve_provider_auth(&self.aws).await
}
fn models_manager(
&self,
_codex_home: PathBuf,

View File

@@ -2,36 +2,74 @@ use std::sync::Arc;
use codex_agent_identity::AgentIdentityKey;
use codex_agent_identity::AgentTaskAuthorizationTarget;
use codex_agent_identity::AgentTaskExternalRef;
use codex_agent_identity::RegisteredAgentTask;
use codex_agent_identity::authorization_header_for_agent_task;
use codex_agent_identity::authorization_header_for_registered_task;
use codex_api::AuthProvider;
use codex_api::SharedAuthProvider;
use codex_login::AuthManager;
use codex_login::CodexAuth;
use codex_login::auth::AgentIdentityAuth;
use codex_login::auth::AgentIdentityAuthPolicy;
use codex_model_provider_info::ModelProviderInfo;
use codex_protocol::protocol::SessionSource;
use http::HeaderMap;
use http::HeaderValue;
use crate::bearer_auth_provider::BearerAuthProvider;
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ProviderAuthScope {
/// Use the provider's default auth. Agent Identity auth uses its process task here.
UnscopedProcess,
/// Use a task-scoped Agent Assertion for work tied to a Codex thread.
Thread {
external_ref: AgentTaskExternalRef,
agent_identity_policy: AgentIdentityAuthPolicy,
session_source: SessionSource,
chatgpt_base_url: Option<String>,
},
}
#[derive(Clone, Debug)]
struct AgentIdentityAuthProvider {
auth: codex_login::auth::AgentIdentityAuth,
auth: AgentIdentityAuth,
task: Option<RegisteredAgentTask>,
}
impl AuthProvider for AgentIdentityAuthProvider {
fn add_auth_headers(&self, headers: &mut HeaderMap) {
let record = self.auth.record();
let header_value = authorization_header_for_agent_task(
AgentIdentityKey {
agent_runtime_id: &record.agent_runtime_id,
private_key_pkcs8_base64: &record.agent_private_key,
},
AgentTaskAuthorizationTarget {
agent_runtime_id: &record.agent_runtime_id,
task_id: self.auth.process_task_id(),
},
)
.map_err(std::io::Error::other);
let header_value = match self.task.as_ref() {
Some(task) => authorization_header_for_registered_task(
AgentIdentityKey {
agent_runtime_id: &record.agent_runtime_id,
private_key_pkcs8_base64: &record.agent_private_key,
},
task,
)
.map_err(std::io::Error::other),
None => self
.auth
.process_task_id()
.ok_or_else(|| {
std::io::Error::other("agent identity process task is not initialized")
})
.and_then(|task_id| {
authorization_header_for_agent_task(
AgentIdentityKey {
agent_runtime_id: &record.agent_runtime_id,
private_key_pkcs8_base64: &record.agent_private_key,
},
AgentTaskAuthorizationTarget {
agent_runtime_id: &record.agent_runtime_id,
task_id: &task_id,
},
)
.map_err(std::io::Error::other)
}),
};
if let Ok(header_value) = header_value
&& let Ok(header) = HeaderValue::from_str(&header_value)
@@ -75,20 +113,61 @@ pub(crate) fn auth_manager_for_provider(
}
}
pub(crate) fn resolve_provider_auth(
pub(crate) async fn resolve_provider_auth(
auth_manager: Option<Arc<AuthManager>>,
auth: Option<&CodexAuth>,
provider: &ModelProviderInfo,
scope: ProviderAuthScope,
) -> codex_protocol::error::Result<SharedAuthProvider> {
if let Some(auth) = bearer_auth_for_provider(provider)? {
return Ok(Arc::new(auth));
}
if provider_uses_first_party_auth_path(provider)
&& let ProviderAuthScope::Thread {
external_ref,
agent_identity_policy,
session_source,
chatgpt_base_url,
} = scope
&& let Some(agent_identity_auth) =
agent_identity_auth_for_scope(auth_manager, auth, agent_identity_policy, session_source)
.await?
{
let task = agent_identity_auth
.registered_thread_task(external_ref, chatgpt_base_url)
.await?;
return Ok(auth_provider_from_agent_task(agent_identity_auth, task));
}
Ok(match auth {
Some(auth) => auth_provider_from_auth(auth),
None => unauthenticated_auth_provider(),
})
}
async fn agent_identity_auth_for_scope(
auth_manager: Option<Arc<AuthManager>>,
auth: Option<&CodexAuth>,
policy: AgentIdentityAuthPolicy,
session_source: SessionSource,
) -> codex_protocol::error::Result<Option<AgentIdentityAuth>> {
if let Some(auth_manager) = auth_manager {
return auth_manager
.agent_identity_auth(policy, session_source)
.await
.map_err(Into::into);
}
Ok(match auth {
Some(CodexAuth::AgentIdentity(auth)) => Some(auth.clone()),
Some(CodexAuth::ApiKey(_))
| Some(CodexAuth::Chatgpt(_))
| Some(CodexAuth::ChatgptAuthTokens(_))
| None => None,
})
}
fn bearer_auth_for_provider(
provider: &ModelProviderInfo,
) -> codex_protocol::error::Result<Option<BearerAuthProvider>> {
@@ -103,12 +182,21 @@ fn bearer_auth_for_provider(
Ok(None)
}
pub fn provider_uses_first_party_auth_path(provider: &ModelProviderInfo) -> bool {
provider.requires_openai_auth
&& provider.env_key.is_none()
&& provider.experimental_bearer_token.is_none()
&& provider.auth.is_none()
&& provider.aws.is_none()
}
/// Builds request-header auth for a first-party Codex auth snapshot.
pub fn auth_provider_from_auth(auth: &CodexAuth) -> SharedAuthProvider {
match auth {
CodexAuth::AgentIdentity(auth) => {
Arc::new(AgentIdentityAuthProvider { auth: auth.clone() })
}
CodexAuth::AgentIdentity(auth) => Arc::new(AgentIdentityAuthProvider {
auth: auth.clone(),
task: None,
}),
CodexAuth::ApiKey(_) | CodexAuth::Chatgpt(_) | CodexAuth::ChatgptAuthTokens(_) => {
Arc::new(BearerAuthProvider {
token: auth.get_token().ok(),
@@ -119,19 +207,173 @@ pub fn auth_provider_from_auth(auth: &CodexAuth) -> SharedAuthProvider {
}
}
pub fn auth_provider_from_agent_task(
auth: AgentIdentityAuth,
task: RegisteredAgentTask,
) -> SharedAuthProvider {
Arc::new(AgentIdentityAuthProvider {
auth,
task: Some(task),
})
}
#[cfg(test)]
mod tests {
use codex_agent_identity::AgentRuntimeId;
use codex_agent_identity::AgentTaskId;
use codex_agent_identity::AgentTaskKind;
use codex_agent_identity::generate_agent_key_material;
use codex_login::auth::AgentIdentityAuthRecord;
use codex_model_provider_info::WireApi;
use codex_model_provider_info::create_oss_provider_with_base_url;
use codex_protocol::account::PlanType;
use pretty_assertions::assert_eq;
use serde_json::json;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::body_partial_json;
use wiremock::matchers::method;
use wiremock::matchers::path;
use super::*;
#[test]
fn unauthenticated_auth_provider_adds_no_headers() {
fn agent_identity_auth(chatgpt_account_is_fedramp: bool) -> AgentIdentityAuth {
let key_material = generate_agent_key_material().expect("generate key material");
AgentIdentityAuth::new(AgentIdentityAuthRecord {
agent_runtime_id: "agent-runtime-1".to_string(),
agent_private_key: key_material.private_key_pkcs8_base64,
account_id: "account-1".to_string(),
chatgpt_user_id: "user-1".to_string(),
email: "agent@example.com".to_string(),
plan_type: PlanType::Plus,
chatgpt_account_is_fedramp,
registered_at: None,
})
}
#[tokio::test]
async fn unauthenticated_auth_provider_adds_no_headers() {
let provider =
create_oss_provider_with_base_url("http://localhost:11434/v1", WireApi::Responses);
let auth = resolve_provider_auth(/*auth*/ None, &provider).expect("auth should resolve");
let auth = resolve_provider_auth(
/*auth_manager*/ None,
/*auth*/ None,
&provider,
ProviderAuthScope::UnscopedProcess,
)
.await
.expect("auth should resolve");
assert!(auth.to_auth_headers().is_empty());
}
#[tokio::test]
async fn first_party_thread_scope_uses_agent_assertion() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/agent/agent-runtime-1/task/register"))
.and(body_partial_json(json!({
"external_task_ref": "thread-1",
})))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"task_id": "task-thread-1",
})))
.expect(1)
.mount(&server)
.await;
let auth = CodexAuth::AgentIdentity(agent_identity_auth(
/*chatgpt_account_is_fedramp*/ false,
));
let provider = ModelProviderInfo::create_openai_provider(/*base_url*/ None);
let auth = resolve_provider_auth(
/*auth_manager*/ None,
Some(&auth),
&provider,
ProviderAuthScope::Thread {
external_ref: AgentTaskExternalRef::new("thread-1"),
agent_identity_policy: AgentIdentityAuthPolicy::JwtOnly,
session_source: SessionSource::Cli,
chatgpt_base_url: Some(server.uri()),
},
)
.await
.expect("auth should resolve");
let headers = auth.to_auth_headers();
assert!(
headers
.get(http::header::AUTHORIZATION)
.and_then(|value| value.to_str().ok())
.is_some_and(|value| value.starts_with("AgentAssertion "))
);
}
#[test]
fn agent_task_auth_provider_preserves_account_routing_headers() {
let auth = agent_identity_auth(/*chatgpt_account_is_fedramp*/ true);
let provider = auth_provider_from_agent_task(
auth,
RegisteredAgentTask::new(
AgentRuntimeId::new("agent-runtime-1"),
AgentTaskId::new("background-task-1"),
AgentTaskKind::Background,
),
);
let headers = provider.to_auth_headers();
assert!(
headers
.get(http::header::AUTHORIZATION)
.and_then(|value| value.to_str().ok())
.is_some_and(|value| value.starts_with("AgentAssertion "))
);
assert_eq!(
headers
.get("ChatGPT-Account-ID")
.and_then(|value| value.to_str().ok()),
Some("account-1")
);
assert_eq!(
headers
.get("X-OpenAI-Fedramp")
.and_then(|value| value.to_str().ok()),
Some("true")
);
}
#[tokio::test]
async fn provider_auth_ignores_thread_scope_for_non_openai_provider() {
let provider =
create_oss_provider_with_base_url("http://localhost:11434/v1", WireApi::Responses);
let auth = resolve_provider_auth(
/*auth_manager*/ None,
/*auth*/ None,
&provider,
ProviderAuthScope::Thread {
external_ref: AgentTaskExternalRef::new("thread-1"),
agent_identity_policy: AgentIdentityAuthPolicy::JwtOnly,
session_source: SessionSource::Cli,
chatgpt_base_url: None,
},
)
.await
.expect("auth should resolve");
assert!(auth.to_auth_headers().is_empty());
}
#[test]
fn first_party_auth_path_excludes_provider_specific_auth() {
let mut env_key_provider =
ModelProviderInfo::create_openai_provider(/*base_url*/ None);
env_key_provider.env_key = Some("OPENAI_API_KEY".to_string());
assert!(!provider_uses_first_party_auth_path(&env_key_provider));
let bedrock_provider = ModelProviderInfo::create_amazon_bedrock_provider(/*aws*/ None);
assert!(!provider_uses_first_party_auth_path(&bedrock_provider));
}
}

View File

@@ -4,7 +4,12 @@ mod bearer_auth_provider;
mod models_endpoint;
mod provider;
pub use codex_agent_identity::AgentTaskExternalRef;
pub use auth::ProviderAuthScope;
pub use auth::auth_provider_from_agent_task;
pub use auth::auth_provider_from_auth;
pub use auth::provider_uses_first_party_auth_path;
pub use auth::unauthenticated_auth_provider;
pub use bearer_auth_provider::BearerAuthProvider;
pub use bearer_auth_provider::BearerAuthProvider as CoreAuthProvider;

View File

@@ -26,6 +26,7 @@ use codex_response_debug_context::telemetry_transport_error_message;
use http::HeaderMap;
use tokio::time::timeout;
use crate::auth::ProviderAuthScope;
use crate::auth::resolve_provider_auth;
const MODELS_REFRESH_TIMEOUT: Duration = Duration::from_secs(5);
@@ -87,7 +88,13 @@ impl ModelsEndpointClient for OpenAiModelsEndpoint {
let auth = self.auth().await;
let auth_mode = auth.as_ref().map(CodexAuth::auth_mode);
let api_provider = self.provider_info.to_api_provider(auth_mode)?;
let api_auth = resolve_provider_auth(auth.as_ref(), &self.provider_info)?;
let api_auth = resolve_provider_auth(
self.auth_manager.clone(),
auth.as_ref(),
&self.provider_info,
ProviderAuthScope::UnscopedProcess,
)
.await?;
let transport = ReqwestTransport::new(build_reqwest_client());
let auth_telemetry = auth_header_telemetry(api_auth.as_ref());
let request_telemetry: Arc<dyn RequestTelemetry> = Arc::new(ModelsRequestTelemetry {

View File

@@ -14,6 +14,7 @@ use codex_protocol::account::ProviderAccount;
use codex_protocol::openai_models::ModelsResponse;
use crate::amazon_bedrock::AmazonBedrockModelProvider;
use crate::auth::ProviderAuthScope;
use crate::auth::auth_manager_for_provider;
use crate::auth::resolve_provider_auth;
use crate::models_endpoint::OpenAiModelsEndpoint;
@@ -113,8 +114,17 @@ pub trait ModelProvider: fmt::Debug + Send + Sync {
/// Returns the auth provider used to attach request credentials.
async fn api_auth(&self) -> codex_protocol::error::Result<SharedAuthProvider> {
self.api_auth_for_scope(ProviderAuthScope::UnscopedProcess)
.await
}
/// Returns request credentials, optionally scoped to a Codex session task.
async fn api_auth_for_scope(
&self,
scope: ProviderAuthScope,
) -> codex_protocol::error::Result<SharedAuthProvider> {
let auth = self.auth().await;
resolve_provider_auth(auth.as_ref(), self.info())
resolve_provider_auth(self.auth_manager(), auth.as_ref(), self.info(), scope).await
}
/// Creates the model manager implementation appropriate for this provider.