Defer agent identity registration until auth exists

# Conflicts:
#	codex-rs/login/src/auth/auth_tests.rs
#	codex-rs/login/src/auth/manager.rs
This commit is contained in:
adrian
2026-04-09 12:19:09 -07:00
parent b56a0bcdb9
commit ea8ec876fd
5 changed files with 206 additions and 85 deletions

View File

@@ -21,7 +21,6 @@ use ed25519_dalek::pkcs8::DecodePrivateKey;
use ed25519_dalek::pkcs8::EncodePrivateKey;
use rand::TryRngCore;
use rand::rngs::OsRng;
use reqwest::StatusCode;
use serde::Deserialize;
use serde::Serialize;
use tokio::sync::Mutex;
@@ -108,6 +107,10 @@ impl AgentIdentityManager {
}
}
pub(crate) fn is_enabled(&self) -> bool {
self.feature_enabled
}
pub(crate) async fn ensure_registered_identity(&self) -> Result<Option<StoredAgentIdentity>> {
if !self.feature_enabled {
return Ok(None);
@@ -153,69 +156,45 @@ impl AgentIdentityManager {
};
let client = create_client();
let urls = agent_registration_urls(&self.chatgpt_base_url);
let url = agent_registration_url(&self.chatgpt_base_url);
let response = client
.post(&url)
.bearer_auth(&binding.access_token)
.header("chatgpt-account-id", &binding.chatgpt_account_id)
.json(&request_body)
.timeout(AGENT_REGISTRATION_TIMEOUT)
.send()
.await
.with_context(|| {
format!("failed to send agent identity registration request to {url}")
})?;
for (index, url) in urls.iter().enumerate() {
let response = client
.post(url)
.bearer_auth(&binding.access_token)
.header("chatgpt-account-id", &binding.chatgpt_account_id)
.json(&request_body)
.timeout(AGENT_REGISTRATION_TIMEOUT)
.send()
if response.status().is_success() {
let response_body = response
.json::<RegisterAgentResponse>()
.await
.with_context(|| {
format!("failed to send agent identity registration request to {url}")
})?;
if response.status().is_success() {
let response_body = response
.json::<RegisterAgentResponse>()
.await
.with_context(|| {
format!("failed to parse agent identity response from {url}")
})?;
let stored_identity = StoredAgentIdentity {
binding_id: binding.binding_id.clone(),
chatgpt_account_id: binding.chatgpt_account_id.clone(),
chatgpt_user_id: binding.chatgpt_user_id.clone(),
agent_runtime_id: response_body.agent_runtime_id,
private_key_pkcs8_base64: key_material.private_key_pkcs8_base64,
public_key_ssh: key_material.public_key_ssh,
registered_at: Utc::now().to_rfc3339_opts(SecondsFormat::Secs, true),
abom: self.abom.clone(),
};
info!(
agent_runtime_id = %stored_identity.agent_runtime_id,
binding_id = %binding.binding_id,
"registered agent identity"
);
return Ok(stored_identity);
}
let status = response.status();
let body = response.text().await.unwrap_or_default();
let is_last_candidate = index + 1 == urls.len();
if !is_last_candidate
&& matches!(
status,
StatusCode::NOT_FOUND | StatusCode::METHOD_NOT_ALLOWED
)
{
debug!(
url = %url,
status = %status,
"agent identity registration endpoint unavailable at candidate URL; trying fallback"
);
continue;
}
anyhow::bail!(
"agent identity registration failed with status {status} from {url}: {body}"
.with_context(|| format!("failed to parse agent identity response from {url}"))?;
let stored_identity = StoredAgentIdentity {
binding_id: binding.binding_id.clone(),
chatgpt_account_id: binding.chatgpt_account_id.clone(),
chatgpt_user_id: binding.chatgpt_user_id.clone(),
agent_runtime_id: response_body.agent_runtime_id,
private_key_pkcs8_base64: key_material.private_key_pkcs8_base64,
public_key_ssh: key_material.public_key_ssh,
registered_at: Utc::now().to_rfc3339_opts(SecondsFormat::Secs, true),
abom: self.abom.clone(),
};
info!(
agent_runtime_id = %stored_identity.agent_runtime_id,
binding_id = %binding.binding_id,
"registered agent identity"
);
return Ok(stored_identity);
}
anyhow::bail!("no candidate URLs were available for agent identity registration")
let status = response.status();
let body = response.text().await.unwrap_or_default();
anyhow::bail!("agent identity registration failed with status {status} from {url}: {body}")
}
fn load_stored_identity(
@@ -410,15 +389,12 @@ fn secret_scope(binding: &AgentIdentityBinding) -> Result<SecretScope> {
.context("agent identity binding must be a valid secrets scope")
}
fn agent_registration_urls(chatgpt_base_url: &str) -> Vec<String> {
fn agent_registration_url(chatgpt_base_url: &str) -> String {
let trimmed = chatgpt_base_url.trim_end_matches('/');
if let Some(root) = trimmed.strip_suffix("/backend-api") {
return vec![
format!("{root}/v1/agent/register"),
format!("{trimmed}/v1/agent/register"),
];
return format!("{root}/v1/agent/register");
}
vec![format!("{trimmed}/v1/agent/register")]
format!("{trimmed}/v1/agent/register")
}
#[cfg(test)]
@@ -589,18 +565,12 @@ mod tests {
}
#[tokio::test]
async fn ensure_registered_identity_falls_back_to_backend_api_v1() {
async fn ensure_registered_identity_uses_canonical_agent_registration_url() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/agent/register"))
.respond_with(ResponseTemplate::new(404))
.expect(1)
.mount(&server)
.await;
Mock::given(method("POST"))
.and(path("/backend-api/v1/agent/register"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"agent_runtime_id": "agent_fallback",
"agent_runtime_id": "agent_canonical",
})))
.expect(1)
.mount(&server)
@@ -628,7 +598,7 @@ mod tests {
.await
.unwrap()
.expect("identity should be registered");
assert_eq!(stored.agent_runtime_id, "agent_fallback");
assert_eq!(stored.agent_runtime_id, "agent_canonical");
}
#[test]

View File

@@ -1479,22 +1479,55 @@ impl Session {
}
fn start_agent_identity_registration(self: &Arc<Self>) {
if !self.services.agent_identity_manager.is_enabled() {
return;
}
let weak_sess = Arc::downgrade(self);
let mut auth_state_rx = self.services.auth_manager.subscribe_auth_state();
tokio::spawn(async move {
let Some(sess) = weak_sess.upgrade() else {
return;
};
if let Err(error) = sess
.services
.agent_identity_manager
.ensure_registered_identity()
.await
{
warn!(error = %error, "agent identity registration failed");
loop {
let Some(sess) = weak_sess.upgrade() else {
return;
};
match sess
.services
.agent_identity_manager
.ensure_registered_identity()
.await
{
Ok(Some(_)) => return,
Ok(None) => {
drop(sess);
if auth_state_rx.changed().await.is_err() {
return;
}
}
Err(error) => {
sess.fail_agent_identity_registration(error).await;
return;
}
}
}
});
}
async fn fail_agent_identity_registration(self: &Arc<Self>, error: anyhow::Error) {
warn!(error = %error, "agent identity registration failed");
let message = format!(
"Agent identity registration failed. Codex cannot continue while `features.use_agent_identity` is enabled: {error}"
);
self.send_event_raw(Event {
id: self.next_internal_sub_id(),
msg: EventMsg::Error(ErrorEvent {
message,
codex_error_info: Some(CodexErrorInfo::Other),
}),
})
.await;
handlers::shutdown(self, self.next_internal_sub_id()).await;
}
#[allow(clippy::too_many_arguments)]
fn make_turn_context(
conversation_id: ThreadId,

View File

@@ -110,6 +110,7 @@ use opentelemetry::trace::TraceId;
use std::path::Path;
use std::time::Duration;
use tokio::time::sleep;
use tokio::time::timeout;
use tracing_opentelemetry::OpenTelemetrySpanExt;
use codex_protocol::mcp::CallToolResult as McpCallToolResult;
@@ -3818,6 +3819,42 @@ pub(crate) async fn make_session_and_context_with_rx() -> (
make_session_and_context_with_dynamic_tools_and_rx(Vec::new()).await
}
#[tokio::test]
async fn fail_agent_identity_registration_emits_error_and_shutdown() {
let (session, _turn_context, rx_event) = make_session_and_context_with_rx().await;
session
.fail_agent_identity_registration(anyhow::anyhow!("registration exploded"))
.await;
let error_event = timeout(Duration::from_secs(1), rx_event.recv())
.await
.expect("error event should arrive")
.expect("error event should be readable");
match error_event.msg {
EventMsg::Error(ErrorEvent {
message,
codex_error_info,
}) => {
assert_eq!(
message,
"Agent identity registration failed. Codex cannot continue while `features.use_agent_identity` is enabled: registration exploded".to_string()
);
assert_eq!(codex_error_info, Some(CodexErrorInfo::Other));
}
other => panic!("expected error event, got {other:?}"),
}
let shutdown_event = timeout(Duration::from_secs(1), rx_event.recv())
.await
.expect("shutdown event should arrive")
.expect("shutdown event should be readable");
match shutdown_event.msg {
EventMsg::ShutdownComplete => {}
other => panic!("expected shutdown event, got {other:?}"),
}
}
#[tokio::test]
async fn refresh_mcp_servers_is_deferred_until_next_turn() {
let (session, turn_context) = make_session_and_context().await;

View File

@@ -16,6 +16,8 @@ use serde_json::json;
use std::sync::Arc;
use tempfile::TempDir;
use tempfile::tempdir;
use tokio::time::Duration;
use tokio::time::timeout;
#[tokio::test]
async fn refresh_without_id_token() {
@@ -474,6 +476,65 @@ exit 1
}
}
#[tokio::test]
async fn auth_manager_notifies_when_auth_state_changes() {
let dir = tempdir().unwrap();
let manager = AuthManager::shared(
dir.path().to_path_buf(),
false,
AuthCredentialsStoreMode::File,
);
let mut auth_state_rx = manager.subscribe_auth_state();
save_auth(
dir.path(),
&AuthDotJson {
auth_mode: Some(ApiAuthMode::ApiKey),
openai_api_key: Some("sk-test-key".to_string()),
tokens: None,
last_refresh: None,
},
AuthCredentialsStoreMode::File,
)
.expect("save auth");
assert!(
manager.reload(),
"reload should report a changed auth state"
);
timeout(Duration::from_secs(1), auth_state_rx.changed())
.await
.expect("auth change notification should arrive")
.expect("auth state watch should remain open");
save_auth(
dir.path(),
&AuthDotJson {
auth_mode: Some(ApiAuthMode::ApiKey),
openai_api_key: Some("sk-updated-key".to_string()),
tokens: None,
last_refresh: None,
},
AuthCredentialsStoreMode::File,
)
.expect("save updated auth");
assert!(
!manager.reload(),
"reload remains mode-stable even when the underlying credentials change"
);
timeout(Duration::from_secs(1), auth_state_rx.changed())
.await
.expect("auth reload notification should still arrive")
.expect("auth state watch should remain open");
manager.set_forced_chatgpt_workspace_id(Some("workspace-123".to_string()));
timeout(Duration::from_secs(1), auth_state_rx.changed())
.await
.expect("workspace change notification should arrive")
.expect("auth state watch should remain open");
}
struct AuthFileParams {
openai_api_key: Option<String>,
chatgpt_plan_type: Option<String>,

View File

@@ -13,6 +13,7 @@ use std::sync::Arc;
use std::sync::Mutex;
use std::sync::RwLock;
use tokio::sync::Mutex as AsyncMutex;
use tokio::sync::watch;
use codex_app_server_protocol::AuthMode;
use codex_app_server_protocol::AuthMode as ApiAuthMode;
@@ -1106,6 +1107,7 @@ pub struct AuthManager {
forced_chatgpt_workspace_id: RwLock<Option<String>>,
refresh_lock: AsyncMutex<()>,
external_auth: RwLock<Option<Arc<dyn ExternalAuth>>>,
auth_state_tx: watch::Sender<()>,
}
/// Configuration view required to construct a shared [`AuthManager`].
@@ -1154,6 +1156,7 @@ impl AuthManager {
enable_codex_api_key_env: bool,
auth_credentials_store_mode: AuthCredentialsStoreMode,
) -> Self {
let (auth_state_tx, _) = watch::channel(());
let managed_auth = load_auth(
&codex_home,
enable_codex_api_key_env,
@@ -1172,11 +1175,13 @@ impl AuthManager {
forced_chatgpt_workspace_id: RwLock::new(None),
refresh_lock: AsyncMutex::new(()),
external_auth: RwLock::new(None),
auth_state_tx,
}
}
/// Create an AuthManager with a specific CodexAuth, for testing only.
pub fn from_auth_for_testing(auth: CodexAuth) -> Arc<Self> {
let (auth_state_tx, _) = watch::channel(());
let cached = CachedAuth {
auth: Some(auth),
permanent_refresh_failure: None,
@@ -1190,11 +1195,13 @@ impl AuthManager {
forced_chatgpt_workspace_id: RwLock::new(None),
refresh_lock: AsyncMutex::new(()),
external_auth: RwLock::new(None),
auth_state_tx,
})
}
/// Create an AuthManager with a specific CodexAuth and codex home, for testing only.
pub fn from_auth_for_testing_with_home(auth: CodexAuth, codex_home: PathBuf) -> Arc<Self> {
let (auth_state_tx, _) = watch::channel(());
let cached = CachedAuth {
auth: Some(auth),
permanent_refresh_failure: None,
@@ -1207,10 +1214,12 @@ impl AuthManager {
forced_chatgpt_workspace_id: RwLock::new(None),
refresh_lock: AsyncMutex::new(()),
external_auth: RwLock::new(None),
auth_state_tx,
})
}
pub fn external_bearer_only(config: ModelProviderAuthInfo) -> Arc<Self> {
let (auth_state_tx, _) = watch::channel(());
Arc::new(Self {
codex_home: PathBuf::from("non-existent"),
inner: RwLock::new(CachedAuth {
@@ -1224,6 +1233,7 @@ impl AuthManager {
external_auth: RwLock::new(Some(
Arc::new(BearerTokenRefresher::new(config)) as Arc<dyn ExternalAuth>
)),
auth_state_tx,
})
}
@@ -1363,6 +1373,7 @@ impl AuthManager {
}
tracing::info!("Reloaded auth, changed: {changed}");
guard.auth = new_auth;
self.auth_state_tx.send_replace(());
changed
} else {
false
@@ -1372,18 +1383,23 @@ impl AuthManager {
pub fn set_external_auth(&self, external_auth: Arc<dyn ExternalAuth>) {
if let Ok(mut guard) = self.external_auth.write() {
*guard = Some(external_auth);
self.auth_state_tx.send_replace(());
}
}
pub fn clear_external_auth(&self) {
if let Ok(mut guard) = self.external_auth.write() {
*guard = None;
self.auth_state_tx.send_replace(());
}
}
pub fn set_forced_chatgpt_workspace_id(&self, workspace_id: Option<String>) {
if let Ok(mut guard) = self.forced_chatgpt_workspace_id.write() {
if let Ok(mut guard) = self.forced_chatgpt_workspace_id.write()
&& *guard != workspace_id
{
*guard = workspace_id;
self.auth_state_tx.send_replace(());
}
}
@@ -1394,6 +1410,10 @@ impl AuthManager {
.and_then(|guard| guard.clone())
}
pub fn subscribe_auth_state(&self) -> watch::Receiver<()> {
self.auth_state_tx.subscribe()
}
pub fn has_external_auth(&self) -> bool {
self.external_auth().is_some()
}