mirror of
https://github.com/openai/codex.git
synced 2026-05-23 12:34:25 +00:00
feat: use thread agent task auth for inference
This commit is contained in:
@@ -722,7 +722,9 @@ impl AgentControl {
|
||||
} else {
|
||||
state.send_op(agent_id, Op::Shutdown {}).await
|
||||
};
|
||||
thread.wait_until_terminated().await;
|
||||
if result.is_ok() || matches!(result, Err(CodexErr::InternalAgentDied)) {
|
||||
thread.wait_until_terminated().await;
|
||||
}
|
||||
result
|
||||
} else {
|
||||
state.send_op(agent_id, Op::Shutdown {}).await
|
||||
|
||||
@@ -116,8 +116,11 @@ use crate::util::emit_feedback_auth_recovery_tags;
|
||||
use codex_api::map_api_error;
|
||||
use codex_feedback::FeedbackRequestTags;
|
||||
use codex_feedback::emit_feedback_request_tags_with_auth_env;
|
||||
use codex_login::auth::AgentIdentityAuthPolicy;
|
||||
use codex_login::auth_env_telemetry::AuthEnvTelemetry;
|
||||
use codex_login::auth_env_telemetry::collect_auth_env_telemetry;
|
||||
use codex_model_provider::AgentTaskExternalRef;
|
||||
use codex_model_provider::ProviderAuthScope;
|
||||
use codex_model_provider::SharedModelProvider;
|
||||
use codex_model_provider::create_model_provider;
|
||||
#[cfg(test)]
|
||||
@@ -170,6 +173,8 @@ struct ModelClientState {
|
||||
provider: SharedModelProvider,
|
||||
auth_env_telemetry: AuthEnvTelemetry,
|
||||
session_source: SessionSource,
|
||||
agent_identity_policy: AgentIdentityAuthPolicy,
|
||||
chatgpt_base_url: Option<String>,
|
||||
model_verbosity: Option<VerbosityConfig>,
|
||||
enable_request_compression: bool,
|
||||
include_timing_metrics: bool,
|
||||
@@ -321,6 +326,39 @@ impl ModelClient {
|
||||
include_timing_metrics: bool,
|
||||
beta_features_header: Option<String>,
|
||||
attestation_provider: Option<Arc<dyn AttestationProvider>>,
|
||||
) -> Self {
|
||||
Self::new_with_agent_identity_policy(
|
||||
auth_manager,
|
||||
session_id,
|
||||
thread_id,
|
||||
installation_id,
|
||||
provider_info,
|
||||
session_source,
|
||||
AgentIdentityAuthPolicy::JwtOnly,
|
||||
/*chatgpt_base_url*/ None,
|
||||
model_verbosity,
|
||||
enable_request_compression,
|
||||
include_timing_metrics,
|
||||
beta_features_header,
|
||||
attestation_provider,
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new_with_agent_identity_policy(
|
||||
auth_manager: Option<Arc<AuthManager>>,
|
||||
session_id: SessionId,
|
||||
thread_id: ThreadId,
|
||||
installation_id: String,
|
||||
provider_info: ModelProviderInfo,
|
||||
session_source: SessionSource,
|
||||
agent_identity_policy: AgentIdentityAuthPolicy,
|
||||
chatgpt_base_url: Option<String>,
|
||||
model_verbosity: Option<VerbosityConfig>,
|
||||
enable_request_compression: bool,
|
||||
include_timing_metrics: bool,
|
||||
beta_features_header: Option<String>,
|
||||
attestation_provider: Option<Arc<dyn AttestationProvider>>,
|
||||
) -> Self {
|
||||
let model_provider = create_model_provider(provider_info, auth_manager);
|
||||
let codex_api_key_env_enabled = model_provider
|
||||
@@ -339,6 +377,8 @@ impl ModelClient {
|
||||
provider: model_provider,
|
||||
auth_env_telemetry,
|
||||
session_source,
|
||||
agent_identity_policy,
|
||||
chatgpt_base_url,
|
||||
model_verbosity,
|
||||
enable_request_compression,
|
||||
include_timing_metrics,
|
||||
@@ -785,7 +825,11 @@ impl ModelClient {
|
||||
async fn current_client_setup(&self) -> Result<CurrentClientSetup> {
|
||||
let auth = self.state.provider.auth().await;
|
||||
let api_provider = self.state.provider.api_provider().await?;
|
||||
let api_auth = self.state.provider.api_auth().await?;
|
||||
let api_auth = self
|
||||
.state
|
||||
.provider
|
||||
.api_auth_for_scope(self.provider_auth_scope())
|
||||
.await?;
|
||||
Ok(CurrentClientSetup {
|
||||
auth,
|
||||
api_provider,
|
||||
@@ -793,6 +837,15 @@ impl ModelClient {
|
||||
})
|
||||
}
|
||||
|
||||
fn provider_auth_scope(&self) -> ProviderAuthScope {
|
||||
ProviderAuthScope::Thread {
|
||||
external_ref: AgentTaskExternalRef::new(self.state.thread_id.to_string()),
|
||||
agent_identity_policy: self.state.agent_identity_policy,
|
||||
session_source: self.state.session_source.clone(),
|
||||
chatgpt_base_url: self.state.chatgpt_base_url.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Opens a websocket connection using the same header and telemetry wiring as normal turns.
|
||||
///
|
||||
/// Both startup prewarm and in-turn `needs_new` reconnects call this path so handshake
|
||||
@@ -931,6 +984,10 @@ impl Drop for ModelClientSession {
|
||||
}
|
||||
|
||||
impl ModelClientSession {
|
||||
async fn current_client_setup(&self) -> Result<CurrentClientSetup> {
|
||||
self.client.current_client_setup().await
|
||||
}
|
||||
|
||||
pub(crate) fn reset_websocket_session(&mut self) {
|
||||
self.websocket_session.connection = None;
|
||||
self.websocket_session.last_request = None;
|
||||
@@ -1077,7 +1134,7 @@ impl ModelClientSession {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let client_setup = self.client.current_client_setup().await.map_err(|err| {
|
||||
let client_setup = self.current_client_setup().await.map_err(|err| {
|
||||
ApiError::Stream(format!(
|
||||
"failed to build websocket prewarm client setup: {err}"
|
||||
))
|
||||
@@ -1224,7 +1281,7 @@ impl ModelClientSession {
|
||||
.map(AuthManager::unauthorized_recovery);
|
||||
let mut pending_retry = PendingUnauthorizedRetry::default();
|
||||
loop {
|
||||
let client_setup = self.client.current_client_setup().await?;
|
||||
let client_setup = self.current_client_setup().await?;
|
||||
let transport = ReqwestTransport::new(build_reqwest_client());
|
||||
let request_auth_context = AuthRequestTelemetryContext::new(
|
||||
client_setup.auth.as_ref().map(CodexAuth::auth_mode),
|
||||
@@ -1340,7 +1397,7 @@ impl ModelClientSession {
|
||||
.map(AuthManager::unauthorized_recovery);
|
||||
let mut pending_retry = PendingUnauthorizedRetry::default();
|
||||
loop {
|
||||
let client_setup = self.client.current_client_setup().await?;
|
||||
let client_setup = self.current_client_setup().await?;
|
||||
let request_auth_context = AuthRequestTelemetryContext::new(
|
||||
client_setup.auth.as_ref().map(CodexAuth::auth_mode),
|
||||
client_setup.api_auth.as_ref(),
|
||||
|
||||
@@ -15,7 +15,10 @@ use codex_api::ResponseEvent;
|
||||
use codex_app_server_protocol::AuthMode;
|
||||
use codex_login::AuthManager;
|
||||
use codex_login::CodexAuth;
|
||||
use codex_login::auth::AgentIdentityAuthPolicy;
|
||||
use codex_model_provider::AgentTaskExternalRef;
|
||||
use codex_model_provider::BearerAuthProvider;
|
||||
use codex_model_provider::ProviderAuthScope;
|
||||
use codex_model_provider_info::CHATGPT_CODEX_BASE_URL;
|
||||
use codex_model_provider_info::ModelProviderInfo;
|
||||
use codex_model_provider_info::WireApi;
|
||||
@@ -61,8 +64,15 @@ use tracing_subscriber::registry::LookupSpan;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
|
||||
fn test_model_client(session_source: SessionSource) -> ModelClient {
|
||||
test_model_client_with_thread_id(ThreadId::new(), session_source)
|
||||
}
|
||||
|
||||
fn test_model_client_with_thread_id(
|
||||
conversation_id: ThreadId,
|
||||
session_source: SessionSource,
|
||||
) -> ModelClient {
|
||||
let provider = create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses);
|
||||
let thread_id = ThreadId::new();
|
||||
let thread_id = conversation_id;
|
||||
ModelClient::new(
|
||||
/*auth_manager*/ None,
|
||||
thread_id.into(),
|
||||
@@ -78,6 +88,23 @@ fn test_model_client(session_source: SessionSource) -> ModelClient {
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_auth_scope_uses_thread_id_as_session_ref() {
|
||||
let conversation_id =
|
||||
ThreadId::from_string("018f4f4c-43f5-7b28-8e24-000000000001").expect("valid thread id");
|
||||
let client = test_model_client_with_thread_id(conversation_id, SessionSource::Cli);
|
||||
|
||||
assert_eq!(
|
||||
client.provider_auth_scope(),
|
||||
ProviderAuthScope::Thread {
|
||||
external_ref: AgentTaskExternalRef::new(conversation_id.to_string()),
|
||||
agent_identity_policy: AgentIdentityAuthPolicy::JwtOnly,
|
||||
session_source: SessionSource::Cli,
|
||||
chatgpt_base_url: None,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
fn test_model_info() -> ModelInfo {
|
||||
serde_json::from_value(json!({
|
||||
"slug": "gpt-test",
|
||||
|
||||
@@ -928,13 +928,19 @@ impl Session {
|
||||
live_thread: live_thread_init.as_ref().cloned(),
|
||||
thread_store: Arc::clone(&thread_store),
|
||||
attestation_provider: attestation_provider.clone(),
|
||||
model_client: ModelClient::new(
|
||||
model_client: ModelClient::new_with_agent_identity_policy(
|
||||
Some(Arc::clone(&auth_manager)),
|
||||
session_id,
|
||||
thread_id,
|
||||
installation_id.clone(),
|
||||
session_configuration.provider.clone(),
|
||||
session_configuration.session_source.clone(),
|
||||
if config.features.enabled(Feature::UseAgentIdentity) {
|
||||
codex_login::auth::AgentIdentityAuthPolicy::JwtOrChatgpt
|
||||
} else {
|
||||
codex_login::auth::AgentIdentityAuthPolicy::JwtOnly
|
||||
},
|
||||
Some(config.chatgpt_base_url.clone()),
|
||||
config.model_verbosity,
|
||||
config.features.enabled(Feature::EnableRequestCompression),
|
||||
config.features.enabled(Feature::RuntimeMetrics),
|
||||
@@ -1119,8 +1125,6 @@ impl Session {
|
||||
anyhow::bail!("required MCP servers failed to initialize: {details}");
|
||||
}
|
||||
}
|
||||
sess.schedule_startup_prewarm(session_configuration.base_instructions.clone())
|
||||
.await;
|
||||
let session_start_source = match &initial_history {
|
||||
InitialHistory::Resumed(_) => codex_hooks::SessionStartSource::Resume,
|
||||
InitialHistory::New | InitialHistory::Forked(_) => {
|
||||
@@ -1131,6 +1135,8 @@ impl Session {
|
||||
|
||||
// record_initial_history can emit events. We record only after the SessionConfiguredEvent is emitted.
|
||||
sess.record_initial_history(initial_history).await;
|
||||
sess.schedule_startup_prewarm(session_configuration.base_instructions.clone())
|
||||
.await;
|
||||
{
|
||||
let mut state = sess.state.lock().await;
|
||||
state.set_pending_session_start_source(Some(session_start_source));
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use codex_agent_identity::AgentIdentityKey;
|
||||
use codex_agent_identity::AgentRuntimeId;
|
||||
use codex_agent_identity::AgentTaskExternalRef;
|
||||
use codex_agent_identity::AgentTaskId;
|
||||
use codex_agent_identity::AgentTaskKind;
|
||||
use codex_agent_identity::RegisteredAgentTask;
|
||||
use codex_agent_identity::normalize_chatgpt_base_url;
|
||||
use codex_agent_identity::register_agent_task;
|
||||
use codex_agent_identity::register_agent_task_with_external_ref;
|
||||
use codex_protocol::account::PlanType as AccountPlanType;
|
||||
use tokio::sync::OnceCell;
|
||||
|
||||
@@ -15,14 +22,20 @@ const DEFAULT_CHATGPT_BACKEND_BASE_URL: &str = "https://chatgpt.com/backend-api"
|
||||
#[derive(Debug)]
|
||||
pub struct AgentIdentityAuth {
|
||||
record: AgentIdentityAuthRecord,
|
||||
process_task_id: Arc<OnceCell<String>>,
|
||||
task_ids: Arc<Mutex<HashMap<AgentTaskCacheKey, Arc<OnceCell<String>>>>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
enum AgentTaskCacheKey {
|
||||
Process,
|
||||
Thread(AgentTaskExternalRef),
|
||||
}
|
||||
|
||||
impl Clone for AgentIdentityAuth {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
record: self.record.clone(),
|
||||
process_task_id: Arc::clone(&self.process_task_id),
|
||||
task_ids: Arc::clone(&self.task_ids),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -31,7 +44,7 @@ impl AgentIdentityAuth {
|
||||
pub fn new(record: AgentIdentityAuthRecord) -> Self {
|
||||
Self {
|
||||
record,
|
||||
process_task_id: Arc::new(OnceCell::new()),
|
||||
task_ids: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,24 +52,58 @@ impl AgentIdentityAuth {
|
||||
&self.record
|
||||
}
|
||||
|
||||
pub fn process_task_id(&self) -> Option<&str> {
|
||||
self.process_task_id.get().map(String::as_str)
|
||||
pub fn process_task_id(&self) -> Option<String> {
|
||||
self.task_id_for_initialized_key(&AgentTaskCacheKey::Process)
|
||||
}
|
||||
|
||||
pub async fn ensure_runtime(&self, chatgpt_base_url: Option<String>) -> std::io::Result<()> {
|
||||
self.process_task_id
|
||||
.get_or_try_init(|| async {
|
||||
let base_url = normalize_chatgpt_base_url(
|
||||
chatgpt_base_url
|
||||
.as_deref()
|
||||
.unwrap_or(DEFAULT_CHATGPT_BACKEND_BASE_URL),
|
||||
);
|
||||
register_agent_task(&build_reqwest_client(), &base_url, self.key())
|
||||
.await
|
||||
.map_err(std::io::Error::other)
|
||||
})
|
||||
self.task_id_for_key(
|
||||
AgentTaskCacheKey::Process,
|
||||
chatgpt_base_url,
|
||||
/*external_ref*/ None,
|
||||
)
|
||||
.await
|
||||
.map(|_| ())
|
||||
}
|
||||
|
||||
pub async fn registered_thread_task(
|
||||
&self,
|
||||
external_ref: AgentTaskExternalRef,
|
||||
chatgpt_base_url: Option<String>,
|
||||
) -> std::io::Result<RegisteredAgentTask> {
|
||||
let task_id = self
|
||||
.task_id_for_key(
|
||||
AgentTaskCacheKey::Thread(external_ref.clone()),
|
||||
chatgpt_base_url,
|
||||
Some(external_ref),
|
||||
)
|
||||
.await?;
|
||||
Ok(self.registered_task(task_id, AgentTaskKind::Thread))
|
||||
}
|
||||
|
||||
pub async fn register_task(&self, chatgpt_base_url: Option<String>) -> std::io::Result<String> {
|
||||
self.register_task_with_external_ref(chatgpt_base_url, /*external_ref*/ None)
|
||||
.await
|
||||
.map(|_| ())
|
||||
}
|
||||
|
||||
async fn register_task_with_external_ref(
|
||||
&self,
|
||||
chatgpt_base_url: Option<String>,
|
||||
external_ref: Option<&AgentTaskExternalRef>,
|
||||
) -> std::io::Result<String> {
|
||||
let base_url = normalize_chatgpt_base_url(
|
||||
chatgpt_base_url
|
||||
.as_deref()
|
||||
.unwrap_or(DEFAULT_CHATGPT_BACKEND_BASE_URL),
|
||||
);
|
||||
register_agent_task_with_external_ref(
|
||||
&build_reqwest_client(),
|
||||
&base_url,
|
||||
self.key(),
|
||||
external_ref,
|
||||
)
|
||||
.await
|
||||
.map_err(std::io::Error::other)
|
||||
}
|
||||
|
||||
pub fn account_id(&self) -> &str {
|
||||
@@ -84,4 +131,196 @@ impl AgentIdentityAuth {
|
||||
private_key_pkcs8_base64: &self.record.agent_private_key,
|
||||
}
|
||||
}
|
||||
|
||||
async fn task_id_for_key(
|
||||
&self,
|
||||
key: AgentTaskCacheKey,
|
||||
chatgpt_base_url: Option<String>,
|
||||
external_ref: Option<AgentTaskExternalRef>,
|
||||
) -> std::io::Result<String> {
|
||||
let slot = self.task_slot(key)?;
|
||||
slot.get_or_try_init(|| async {
|
||||
self.register_task_with_external_ref(chatgpt_base_url, external_ref.as_ref())
|
||||
.await
|
||||
})
|
||||
.await
|
||||
.cloned()
|
||||
}
|
||||
|
||||
fn task_slot(&self, key: AgentTaskCacheKey) -> std::io::Result<Arc<OnceCell<String>>> {
|
||||
let mut task_ids = self
|
||||
.task_ids
|
||||
.lock()
|
||||
.map_err(|_| std::io::Error::other("failed to lock agent task cache"))?;
|
||||
Ok(task_ids
|
||||
.entry(key)
|
||||
.or_insert_with(|| Arc::new(OnceCell::new()))
|
||||
.clone())
|
||||
}
|
||||
|
||||
fn task_id_for_initialized_key(&self, key: &AgentTaskCacheKey) -> Option<String> {
|
||||
let task_ids = self.task_ids.lock().ok()?;
|
||||
task_ids.get(key)?.get().cloned()
|
||||
}
|
||||
|
||||
fn registered_task(&self, task_id: String, kind: AgentTaskKind) -> RegisteredAgentTask {
|
||||
RegisteredAgentTask::new(
|
||||
AgentRuntimeId::new(self.record.agent_runtime_id.clone()),
|
||||
AgentTaskId::new(task_id),
|
||||
kind,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use codex_agent_identity::generate_agent_key_material;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::body_partial_json;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn agent_identity_record(private_key: String) -> AgentIdentityAuthRecord {
|
||||
AgentIdentityAuthRecord {
|
||||
agent_runtime_id: "agent-runtime-1".to_string(),
|
||||
agent_private_key: private_key,
|
||||
account_id: "account-1".to_string(),
|
||||
chatgpt_user_id: "user-1".to_string(),
|
||||
email: "agent@example.com".to_string(),
|
||||
plan_type: AccountPlanType::Plus,
|
||||
chatgpt_account_is_fedramp: false,
|
||||
registered_at: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn agent_identity_auth() -> AgentIdentityAuth {
|
||||
let key_material = generate_agent_key_material().expect("generate key material");
|
||||
AgentIdentityAuth::new(agent_identity_record(key_material.private_key_pkcs8_base64))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn registered_thread_task_registers_once_per_external_ref() -> anyhow::Result<()> {
|
||||
let auth = agent_identity_auth();
|
||||
let server = MockServer::start().await;
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/agent/agent-runtime-1/task/register"))
|
||||
.and(body_partial_json(json!({
|
||||
"external_task_ref": "thread-1",
|
||||
})))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
|
||||
"task_id": "task-thread-1",
|
||||
})))
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let first = auth
|
||||
.registered_thread_task(AgentTaskExternalRef::new("thread-1"), Some(server.uri()))
|
||||
.await?;
|
||||
let second = auth
|
||||
.registered_thread_task(AgentTaskExternalRef::new("thread-1"), Some(server.uri()))
|
||||
.await?;
|
||||
|
||||
assert_eq!(first, second);
|
||||
assert_eq!(
|
||||
first,
|
||||
RegisteredAgentTask::new(
|
||||
AgentRuntimeId::new("agent-runtime-1"),
|
||||
AgentTaskId::new("task-thread-1"),
|
||||
AgentTaskKind::Thread,
|
||||
)
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn registered_thread_task_uses_distinct_external_refs() -> anyhow::Result<()> {
|
||||
let auth = agent_identity_auth();
|
||||
let server = MockServer::start().await;
|
||||
for (external_ref, task_id) in
|
||||
[("thread-1", "task-thread-1"), ("thread-2", "task-thread-2")]
|
||||
{
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/agent/agent-runtime-1/task/register"))
|
||||
.and(body_partial_json(json!({
|
||||
"external_task_ref": external_ref,
|
||||
})))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
|
||||
"task_id": task_id,
|
||||
})))
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
}
|
||||
|
||||
let first = auth
|
||||
.registered_thread_task(AgentTaskExternalRef::new("thread-1"), Some(server.uri()))
|
||||
.await?;
|
||||
let second = auth
|
||||
.registered_thread_task(AgentTaskExternalRef::new("thread-2"), Some(server.uri()))
|
||||
.await?;
|
||||
|
||||
assert_eq!(first.task_id.as_str(), "task-thread-1");
|
||||
assert_eq!(second.task_id.as_str(), "task-thread-2");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn failed_thread_task_registration_can_retry() -> anyhow::Result<()> {
|
||||
let auth = agent_identity_auth();
|
||||
let server = MockServer::start().await;
|
||||
let request_count = Arc::new(AtomicUsize::new(0));
|
||||
let response_count = Arc::clone(&request_count);
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/agent/agent-runtime-1/task/register"))
|
||||
.and(body_partial_json(json!({
|
||||
"external_task_ref": "thread-1",
|
||||
})))
|
||||
.respond_with(move |_request: &wiremock::Request| {
|
||||
if response_count.fetch_add(1, Ordering::SeqCst) == 0 {
|
||||
ResponseTemplate::new(500)
|
||||
} else {
|
||||
ResponseTemplate::new(200).set_body_json(json!({
|
||||
"task_id": "task-thread-1",
|
||||
}))
|
||||
}
|
||||
})
|
||||
.expect(2)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
auth.registered_thread_task(AgentTaskExternalRef::new("thread-1"), Some(server.uri()))
|
||||
.await
|
||||
.expect_err("first registration should fail");
|
||||
let task = auth
|
||||
.registered_thread_task(AgentTaskExternalRef::new("thread-1"), Some(server.uri()))
|
||||
.await?;
|
||||
|
||||
assert_eq!(request_count.load(Ordering::SeqCst), 2);
|
||||
assert_eq!(task.task_id.as_str(), "task-thread-1");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn task_slots_are_shared_across_clones() {
|
||||
let auth = agent_identity_auth();
|
||||
let cloned = auth.clone();
|
||||
let slot = auth
|
||||
.task_slot(AgentTaskCacheKey::Process)
|
||||
.expect("task slot should be available");
|
||||
slot.set("process-task-1".to_string())
|
||||
.expect("process task should be unset");
|
||||
|
||||
assert_eq!(cloned.process_task_id(), Some("process-task-1".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -233,8 +233,8 @@ async fn chatgpt_auth_registers_agent_identity_when_enabled() -> anyhow::Result<
|
||||
agent_auth.record().agent_runtime_id,
|
||||
reused.record().agent_runtime_id
|
||||
);
|
||||
assert_eq!(agent_auth.process_task_id(), Some("task-123"));
|
||||
assert_eq!(reused.process_task_id(), Some("task-123"));
|
||||
assert_eq!(agent_auth.process_task_id(), Some("task-123".to_string()));
|
||||
assert_eq!(reused.process_task_id(), Some("task-123".to_string()));
|
||||
assert_eq!(agent_auth.record().agent_runtime_id, "agent-runtime-123");
|
||||
assert_eq!(agent_auth.record().account_id, "account-123");
|
||||
assert_eq!(agent_auth.record().chatgpt_user_id, "user-12345");
|
||||
@@ -853,7 +853,10 @@ async fn load_auth_reads_access_token_from_env() {
|
||||
panic!("env auth should load as agent identity");
|
||||
};
|
||||
assert_eq!(agent_identity.record(), &expected_record);
|
||||
assert_eq!(agent_identity.process_task_id(), Some("task-123"));
|
||||
assert_eq!(
|
||||
agent_identity.process_task_id(),
|
||||
Some("task-123".to_string())
|
||||
);
|
||||
assert!(
|
||||
!get_auth_file(codex_home.path()).exists(),
|
||||
"env auth should not write auth.json"
|
||||
|
||||
@@ -18,6 +18,7 @@ use codex_protocol::account::ProviderAccount;
|
||||
use codex_protocol::error::Result;
|
||||
use codex_protocol::openai_models::ModelsResponse;
|
||||
|
||||
use crate::auth::ProviderAuthScope;
|
||||
use crate::provider::ModelProvider;
|
||||
use crate::provider::ProviderAccountResult;
|
||||
use crate::provider::ProviderAccountState;
|
||||
@@ -96,6 +97,10 @@ impl ModelProvider for AmazonBedrockModelProvider {
|
||||
resolve_provider_auth(&self.aws).await
|
||||
}
|
||||
|
||||
async fn api_auth_for_scope(&self, _scope: ProviderAuthScope) -> Result<SharedAuthProvider> {
|
||||
resolve_provider_auth(&self.aws).await
|
||||
}
|
||||
|
||||
fn models_manager(
|
||||
&self,
|
||||
_codex_home: PathBuf,
|
||||
|
||||
@@ -2,39 +2,74 @@ use std::sync::Arc;
|
||||
|
||||
use codex_agent_identity::AgentIdentityKey;
|
||||
use codex_agent_identity::AgentTaskAuthorizationTarget;
|
||||
use codex_agent_identity::AgentTaskExternalRef;
|
||||
use codex_agent_identity::RegisteredAgentTask;
|
||||
use codex_agent_identity::authorization_header_for_agent_task;
|
||||
use codex_agent_identity::authorization_header_for_registered_task;
|
||||
use codex_api::AuthProvider;
|
||||
use codex_api::SharedAuthProvider;
|
||||
use codex_login::AuthManager;
|
||||
use codex_login::CodexAuth;
|
||||
use codex_login::auth::AgentIdentityAuth;
|
||||
use codex_login::auth::AgentIdentityAuthPolicy;
|
||||
use codex_model_provider_info::ModelProviderInfo;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use http::HeaderMap;
|
||||
use http::HeaderValue;
|
||||
|
||||
use crate::bearer_auth_provider::BearerAuthProvider;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum ProviderAuthScope {
|
||||
/// Use the provider's default auth. Agent Identity auth uses its process task here.
|
||||
UnscopedProcess,
|
||||
/// Use a task-scoped Agent Assertion for work tied to a Codex thread.
|
||||
Thread {
|
||||
external_ref: AgentTaskExternalRef,
|
||||
agent_identity_policy: AgentIdentityAuthPolicy,
|
||||
session_source: SessionSource,
|
||||
chatgpt_base_url: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct AgentIdentityAuthProvider {
|
||||
auth: codex_login::auth::AgentIdentityAuth,
|
||||
auth: AgentIdentityAuth,
|
||||
task: Option<RegisteredAgentTask>,
|
||||
}
|
||||
|
||||
impl AuthProvider for AgentIdentityAuthProvider {
|
||||
fn add_auth_headers(&self, headers: &mut HeaderMap) {
|
||||
let record = self.auth.record();
|
||||
let Some(task_id) = self.auth.process_task_id() else {
|
||||
return;
|
||||
let header_value = match self.task.as_ref() {
|
||||
Some(task) => authorization_header_for_registered_task(
|
||||
AgentIdentityKey {
|
||||
agent_runtime_id: &record.agent_runtime_id,
|
||||
private_key_pkcs8_base64: &record.agent_private_key,
|
||||
},
|
||||
task,
|
||||
)
|
||||
.map_err(std::io::Error::other),
|
||||
None => self
|
||||
.auth
|
||||
.process_task_id()
|
||||
.ok_or_else(|| {
|
||||
std::io::Error::other("agent identity process task is not initialized")
|
||||
})
|
||||
.and_then(|task_id| {
|
||||
authorization_header_for_agent_task(
|
||||
AgentIdentityKey {
|
||||
agent_runtime_id: &record.agent_runtime_id,
|
||||
private_key_pkcs8_base64: &record.agent_private_key,
|
||||
},
|
||||
AgentTaskAuthorizationTarget {
|
||||
agent_runtime_id: &record.agent_runtime_id,
|
||||
task_id: &task_id,
|
||||
},
|
||||
)
|
||||
.map_err(std::io::Error::other)
|
||||
}),
|
||||
};
|
||||
let header_value = authorization_header_for_agent_task(
|
||||
AgentIdentityKey {
|
||||
agent_runtime_id: &record.agent_runtime_id,
|
||||
private_key_pkcs8_base64: &record.agent_private_key,
|
||||
},
|
||||
AgentTaskAuthorizationTarget {
|
||||
agent_runtime_id: &record.agent_runtime_id,
|
||||
task_id,
|
||||
},
|
||||
)
|
||||
.map_err(std::io::Error::other);
|
||||
|
||||
if let Ok(header_value) = header_value
|
||||
&& let Ok(header) = HeaderValue::from_str(&header_value)
|
||||
@@ -78,20 +113,61 @@ pub(crate) fn auth_manager_for_provider(
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn resolve_provider_auth(
|
||||
pub(crate) async fn resolve_provider_auth(
|
||||
auth_manager: Option<Arc<AuthManager>>,
|
||||
auth: Option<&CodexAuth>,
|
||||
provider: &ModelProviderInfo,
|
||||
scope: ProviderAuthScope,
|
||||
) -> codex_protocol::error::Result<SharedAuthProvider> {
|
||||
if let Some(auth) = bearer_auth_for_provider(provider)? {
|
||||
return Ok(Arc::new(auth));
|
||||
}
|
||||
|
||||
if provider_uses_first_party_auth_path(provider)
|
||||
&& let ProviderAuthScope::Thread {
|
||||
external_ref,
|
||||
agent_identity_policy,
|
||||
session_source,
|
||||
chatgpt_base_url,
|
||||
} = scope
|
||||
&& let Some(agent_identity_auth) =
|
||||
agent_identity_auth_for_scope(auth_manager, auth, agent_identity_policy, session_source)
|
||||
.await?
|
||||
{
|
||||
let task = agent_identity_auth
|
||||
.registered_thread_task(external_ref, chatgpt_base_url)
|
||||
.await?;
|
||||
return Ok(auth_provider_from_agent_task(agent_identity_auth, task));
|
||||
}
|
||||
|
||||
Ok(match auth {
|
||||
Some(auth) => auth_provider_from_auth(auth),
|
||||
None => unauthenticated_auth_provider(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn agent_identity_auth_for_scope(
|
||||
auth_manager: Option<Arc<AuthManager>>,
|
||||
auth: Option<&CodexAuth>,
|
||||
policy: AgentIdentityAuthPolicy,
|
||||
session_source: SessionSource,
|
||||
) -> codex_protocol::error::Result<Option<AgentIdentityAuth>> {
|
||||
if let Some(auth_manager) = auth_manager {
|
||||
return auth_manager
|
||||
.agent_identity_auth(policy, session_source)
|
||||
.await
|
||||
.map_err(Into::into);
|
||||
}
|
||||
|
||||
Ok(match auth {
|
||||
Some(CodexAuth::AgentIdentity(auth)) => Some(auth.clone()),
|
||||
Some(CodexAuth::ApiKey(_))
|
||||
| Some(CodexAuth::Chatgpt(_))
|
||||
| Some(CodexAuth::ChatgptAuthTokens(_))
|
||||
| None => None,
|
||||
})
|
||||
}
|
||||
|
||||
fn bearer_auth_for_provider(
|
||||
provider: &ModelProviderInfo,
|
||||
) -> codex_protocol::error::Result<Option<BearerAuthProvider>> {
|
||||
@@ -106,12 +182,21 @@ fn bearer_auth_for_provider(
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub fn provider_uses_first_party_auth_path(provider: &ModelProviderInfo) -> bool {
|
||||
provider.requires_openai_auth
|
||||
&& provider.env_key.is_none()
|
||||
&& provider.experimental_bearer_token.is_none()
|
||||
&& provider.auth.is_none()
|
||||
&& provider.aws.is_none()
|
||||
}
|
||||
|
||||
/// Builds request-header auth for a first-party Codex auth snapshot.
|
||||
pub fn auth_provider_from_auth(auth: &CodexAuth) -> SharedAuthProvider {
|
||||
match auth {
|
||||
CodexAuth::AgentIdentity(auth) => {
|
||||
Arc::new(AgentIdentityAuthProvider { auth: auth.clone() })
|
||||
}
|
||||
CodexAuth::AgentIdentity(auth) => Arc::new(AgentIdentityAuthProvider {
|
||||
auth: auth.clone(),
|
||||
task: None,
|
||||
}),
|
||||
CodexAuth::ApiKey(_) | CodexAuth::Chatgpt(_) | CodexAuth::ChatgptAuthTokens(_) => {
|
||||
Arc::new(BearerAuthProvider {
|
||||
token: auth.get_token().ok(),
|
||||
@@ -122,19 +207,173 @@ pub fn auth_provider_from_auth(auth: &CodexAuth) -> SharedAuthProvider {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn auth_provider_from_agent_task(
|
||||
auth: AgentIdentityAuth,
|
||||
task: RegisteredAgentTask,
|
||||
) -> SharedAuthProvider {
|
||||
Arc::new(AgentIdentityAuthProvider {
|
||||
auth,
|
||||
task: Some(task),
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use codex_agent_identity::AgentRuntimeId;
|
||||
use codex_agent_identity::AgentTaskId;
|
||||
use codex_agent_identity::AgentTaskKind;
|
||||
use codex_agent_identity::generate_agent_key_material;
|
||||
use codex_login::auth::AgentIdentityAuthRecord;
|
||||
use codex_model_provider_info::WireApi;
|
||||
use codex_model_provider_info::create_oss_provider_with_base_url;
|
||||
use codex_protocol::account::PlanType;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::body_partial_json;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn unauthenticated_auth_provider_adds_no_headers() {
|
||||
fn agent_identity_auth(chatgpt_account_is_fedramp: bool) -> AgentIdentityAuth {
|
||||
let key_material = generate_agent_key_material().expect("generate key material");
|
||||
AgentIdentityAuth::new(AgentIdentityAuthRecord {
|
||||
agent_runtime_id: "agent-runtime-1".to_string(),
|
||||
agent_private_key: key_material.private_key_pkcs8_base64,
|
||||
account_id: "account-1".to_string(),
|
||||
chatgpt_user_id: "user-1".to_string(),
|
||||
email: "agent@example.com".to_string(),
|
||||
plan_type: PlanType::Plus,
|
||||
chatgpt_account_is_fedramp,
|
||||
registered_at: None,
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unauthenticated_auth_provider_adds_no_headers() {
|
||||
let provider =
|
||||
create_oss_provider_with_base_url("http://localhost:11434/v1", WireApi::Responses);
|
||||
let auth = resolve_provider_auth(/*auth*/ None, &provider).expect("auth should resolve");
|
||||
let auth = resolve_provider_auth(
|
||||
/*auth_manager*/ None,
|
||||
/*auth*/ None,
|
||||
&provider,
|
||||
ProviderAuthScope::UnscopedProcess,
|
||||
)
|
||||
.await
|
||||
.expect("auth should resolve");
|
||||
|
||||
assert!(auth.to_auth_headers().is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn first_party_thread_scope_uses_agent_assertion() {
|
||||
let server = MockServer::start().await;
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/agent/agent-runtime-1/task/register"))
|
||||
.and(body_partial_json(json!({
|
||||
"external_task_ref": "thread-1",
|
||||
})))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
|
||||
"task_id": "task-thread-1",
|
||||
})))
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
let auth = CodexAuth::AgentIdentity(agent_identity_auth(
|
||||
/*chatgpt_account_is_fedramp*/ false,
|
||||
));
|
||||
let provider = ModelProviderInfo::create_openai_provider(/*base_url*/ None);
|
||||
|
||||
let auth = resolve_provider_auth(
|
||||
/*auth_manager*/ None,
|
||||
Some(&auth),
|
||||
&provider,
|
||||
ProviderAuthScope::Thread {
|
||||
external_ref: AgentTaskExternalRef::new("thread-1"),
|
||||
agent_identity_policy: AgentIdentityAuthPolicy::JwtOnly,
|
||||
session_source: SessionSource::Cli,
|
||||
chatgpt_base_url: Some(server.uri()),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect("auth should resolve");
|
||||
|
||||
let headers = auth.to_auth_headers();
|
||||
assert!(
|
||||
headers
|
||||
.get(http::header::AUTHORIZATION)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.is_some_and(|value| value.starts_with("AgentAssertion "))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn agent_task_auth_provider_preserves_account_routing_headers() {
|
||||
let auth = agent_identity_auth(/*chatgpt_account_is_fedramp*/ true);
|
||||
let provider = auth_provider_from_agent_task(
|
||||
auth,
|
||||
RegisteredAgentTask::new(
|
||||
AgentRuntimeId::new("agent-runtime-1"),
|
||||
AgentTaskId::new("background-task-1"),
|
||||
AgentTaskKind::Background,
|
||||
),
|
||||
);
|
||||
|
||||
let headers = provider.to_auth_headers();
|
||||
|
||||
assert!(
|
||||
headers
|
||||
.get(http::header::AUTHORIZATION)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.is_some_and(|value| value.starts_with("AgentAssertion "))
|
||||
);
|
||||
assert_eq!(
|
||||
headers
|
||||
.get("ChatGPT-Account-ID")
|
||||
.and_then(|value| value.to_str().ok()),
|
||||
Some("account-1")
|
||||
);
|
||||
assert_eq!(
|
||||
headers
|
||||
.get("X-OpenAI-Fedramp")
|
||||
.and_then(|value| value.to_str().ok()),
|
||||
Some("true")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn provider_auth_ignores_thread_scope_for_non_openai_provider() {
|
||||
let provider =
|
||||
create_oss_provider_with_base_url("http://localhost:11434/v1", WireApi::Responses);
|
||||
|
||||
let auth = resolve_provider_auth(
|
||||
/*auth_manager*/ None,
|
||||
/*auth*/ None,
|
||||
&provider,
|
||||
ProviderAuthScope::Thread {
|
||||
external_ref: AgentTaskExternalRef::new("thread-1"),
|
||||
agent_identity_policy: AgentIdentityAuthPolicy::JwtOnly,
|
||||
session_source: SessionSource::Cli,
|
||||
chatgpt_base_url: None,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect("auth should resolve");
|
||||
|
||||
assert!(auth.to_auth_headers().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn first_party_auth_path_excludes_provider_specific_auth() {
|
||||
let mut env_key_provider =
|
||||
ModelProviderInfo::create_openai_provider(/*base_url*/ None);
|
||||
env_key_provider.env_key = Some("OPENAI_API_KEY".to_string());
|
||||
assert!(!provider_uses_first_party_auth_path(&env_key_provider));
|
||||
|
||||
let bedrock_provider = ModelProviderInfo::create_amazon_bedrock_provider(/*aws*/ None);
|
||||
assert!(!provider_uses_first_party_auth_path(&bedrock_provider));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,12 @@ mod bearer_auth_provider;
|
||||
mod models_endpoint;
|
||||
mod provider;
|
||||
|
||||
pub use codex_agent_identity::AgentTaskExternalRef;
|
||||
|
||||
pub use auth::ProviderAuthScope;
|
||||
pub use auth::auth_provider_from_agent_task;
|
||||
pub use auth::auth_provider_from_auth;
|
||||
pub use auth::provider_uses_first_party_auth_path;
|
||||
pub use auth::unauthenticated_auth_provider;
|
||||
pub use bearer_auth_provider::BearerAuthProvider;
|
||||
pub use bearer_auth_provider::BearerAuthProvider as CoreAuthProvider;
|
||||
|
||||
@@ -26,6 +26,7 @@ use codex_response_debug_context::telemetry_transport_error_message;
|
||||
use http::HeaderMap;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use crate::auth::ProviderAuthScope;
|
||||
use crate::auth::resolve_provider_auth;
|
||||
|
||||
const MODELS_REFRESH_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
@@ -87,7 +88,13 @@ impl ModelsEndpointClient for OpenAiModelsEndpoint {
|
||||
let auth = self.auth().await;
|
||||
let auth_mode = auth.as_ref().map(CodexAuth::auth_mode);
|
||||
let api_provider = self.provider_info.to_api_provider(auth_mode)?;
|
||||
let api_auth = resolve_provider_auth(auth.as_ref(), &self.provider_info)?;
|
||||
let api_auth = resolve_provider_auth(
|
||||
self.auth_manager.clone(),
|
||||
auth.as_ref(),
|
||||
&self.provider_info,
|
||||
ProviderAuthScope::UnscopedProcess,
|
||||
)
|
||||
.await?;
|
||||
let transport = ReqwestTransport::new(build_reqwest_client());
|
||||
let auth_telemetry = auth_header_telemetry(api_auth.as_ref());
|
||||
let request_telemetry: Arc<dyn RequestTelemetry> = Arc::new(ModelsRequestTelemetry {
|
||||
|
||||
@@ -14,6 +14,7 @@ use codex_protocol::account::ProviderAccount;
|
||||
use codex_protocol::openai_models::ModelsResponse;
|
||||
|
||||
use crate::amazon_bedrock::AmazonBedrockModelProvider;
|
||||
use crate::auth::ProviderAuthScope;
|
||||
use crate::auth::auth_manager_for_provider;
|
||||
use crate::auth::resolve_provider_auth;
|
||||
use crate::models_endpoint::OpenAiModelsEndpoint;
|
||||
@@ -129,8 +130,17 @@ pub trait ModelProvider: fmt::Debug + Send + Sync {
|
||||
|
||||
/// Returns the auth provider used to attach request credentials.
|
||||
async fn api_auth(&self) -> codex_protocol::error::Result<SharedAuthProvider> {
|
||||
self.api_auth_for_scope(ProviderAuthScope::UnscopedProcess)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns request credentials, optionally scoped to a Codex session task.
|
||||
async fn api_auth_for_scope(
|
||||
&self,
|
||||
scope: ProviderAuthScope,
|
||||
) -> codex_protocol::error::Result<SharedAuthProvider> {
|
||||
let auth = self.auth().await;
|
||||
resolve_provider_auth(auth.as_ref(), self.info())
|
||||
resolve_provider_auth(self.auth_manager(), auth.as_ref(), self.info(), scope).await
|
||||
}
|
||||
|
||||
/// Creates the model manager implementation appropriate for this provider.
|
||||
|
||||
Reference in New Issue
Block a user