mirror of
https://github.com/openai/codex.git
synced 2026-05-18 02:02:30 +00:00
Defer agent identity registration until auth exists
# Conflicts: # codex-rs/login/src/auth/auth_tests.rs # codex-rs/login/src/auth/manager.rs
This commit is contained in:
@@ -21,7 +21,6 @@ use ed25519_dalek::pkcs8::DecodePrivateKey;
|
||||
use ed25519_dalek::pkcs8::EncodePrivateKey;
|
||||
use rand::TryRngCore;
|
||||
use rand::rngs::OsRng;
|
||||
use reqwest::StatusCode;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use tokio::sync::Mutex;
|
||||
@@ -108,6 +107,10 @@ impl AgentIdentityManager {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_enabled(&self) -> bool {
|
||||
self.feature_enabled
|
||||
}
|
||||
|
||||
pub(crate) async fn ensure_registered_identity(&self) -> Result<Option<StoredAgentIdentity>> {
|
||||
if !self.feature_enabled {
|
||||
return Ok(None);
|
||||
@@ -153,69 +156,45 @@ impl AgentIdentityManager {
|
||||
};
|
||||
|
||||
let client = create_client();
|
||||
let urls = agent_registration_urls(&self.chatgpt_base_url);
|
||||
let url = agent_registration_url(&self.chatgpt_base_url);
|
||||
let response = client
|
||||
.post(&url)
|
||||
.bearer_auth(&binding.access_token)
|
||||
.header("chatgpt-account-id", &binding.chatgpt_account_id)
|
||||
.json(&request_body)
|
||||
.timeout(AGENT_REGISTRATION_TIMEOUT)
|
||||
.send()
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!("failed to send agent identity registration request to {url}")
|
||||
})?;
|
||||
|
||||
for (index, url) in urls.iter().enumerate() {
|
||||
let response = client
|
||||
.post(url)
|
||||
.bearer_auth(&binding.access_token)
|
||||
.header("chatgpt-account-id", &binding.chatgpt_account_id)
|
||||
.json(&request_body)
|
||||
.timeout(AGENT_REGISTRATION_TIMEOUT)
|
||||
.send()
|
||||
if response.status().is_success() {
|
||||
let response_body = response
|
||||
.json::<RegisterAgentResponse>()
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!("failed to send agent identity registration request to {url}")
|
||||
})?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let response_body = response
|
||||
.json::<RegisterAgentResponse>()
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!("failed to parse agent identity response from {url}")
|
||||
})?;
|
||||
let stored_identity = StoredAgentIdentity {
|
||||
binding_id: binding.binding_id.clone(),
|
||||
chatgpt_account_id: binding.chatgpt_account_id.clone(),
|
||||
chatgpt_user_id: binding.chatgpt_user_id.clone(),
|
||||
agent_runtime_id: response_body.agent_runtime_id,
|
||||
private_key_pkcs8_base64: key_material.private_key_pkcs8_base64,
|
||||
public_key_ssh: key_material.public_key_ssh,
|
||||
registered_at: Utc::now().to_rfc3339_opts(SecondsFormat::Secs, true),
|
||||
abom: self.abom.clone(),
|
||||
};
|
||||
info!(
|
||||
agent_runtime_id = %stored_identity.agent_runtime_id,
|
||||
binding_id = %binding.binding_id,
|
||||
"registered agent identity"
|
||||
);
|
||||
return Ok(stored_identity);
|
||||
}
|
||||
|
||||
let status = response.status();
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
let is_last_candidate = index + 1 == urls.len();
|
||||
if !is_last_candidate
|
||||
&& matches!(
|
||||
status,
|
||||
StatusCode::NOT_FOUND | StatusCode::METHOD_NOT_ALLOWED
|
||||
)
|
||||
{
|
||||
debug!(
|
||||
url = %url,
|
||||
status = %status,
|
||||
"agent identity registration endpoint unavailable at candidate URL; trying fallback"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
anyhow::bail!(
|
||||
"agent identity registration failed with status {status} from {url}: {body}"
|
||||
.with_context(|| format!("failed to parse agent identity response from {url}"))?;
|
||||
let stored_identity = StoredAgentIdentity {
|
||||
binding_id: binding.binding_id.clone(),
|
||||
chatgpt_account_id: binding.chatgpt_account_id.clone(),
|
||||
chatgpt_user_id: binding.chatgpt_user_id.clone(),
|
||||
agent_runtime_id: response_body.agent_runtime_id,
|
||||
private_key_pkcs8_base64: key_material.private_key_pkcs8_base64,
|
||||
public_key_ssh: key_material.public_key_ssh,
|
||||
registered_at: Utc::now().to_rfc3339_opts(SecondsFormat::Secs, true),
|
||||
abom: self.abom.clone(),
|
||||
};
|
||||
info!(
|
||||
agent_runtime_id = %stored_identity.agent_runtime_id,
|
||||
binding_id = %binding.binding_id,
|
||||
"registered agent identity"
|
||||
);
|
||||
return Ok(stored_identity);
|
||||
}
|
||||
|
||||
anyhow::bail!("no candidate URLs were available for agent identity registration")
|
||||
let status = response.status();
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("agent identity registration failed with status {status} from {url}: {body}")
|
||||
}
|
||||
|
||||
fn load_stored_identity(
|
||||
@@ -410,15 +389,12 @@ fn secret_scope(binding: &AgentIdentityBinding) -> Result<SecretScope> {
|
||||
.context("agent identity binding must be a valid secrets scope")
|
||||
}
|
||||
|
||||
fn agent_registration_urls(chatgpt_base_url: &str) -> Vec<String> {
|
||||
fn agent_registration_url(chatgpt_base_url: &str) -> String {
|
||||
let trimmed = chatgpt_base_url.trim_end_matches('/');
|
||||
if let Some(root) = trimmed.strip_suffix("/backend-api") {
|
||||
return vec![
|
||||
format!("{root}/v1/agent/register"),
|
||||
format!("{trimmed}/v1/agent/register"),
|
||||
];
|
||||
return format!("{root}/v1/agent/register");
|
||||
}
|
||||
vec![format!("{trimmed}/v1/agent/register")]
|
||||
format!("{trimmed}/v1/agent/register")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -589,18 +565,12 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ensure_registered_identity_falls_back_to_backend_api_v1() {
|
||||
async fn ensure_registered_identity_uses_canonical_agent_registration_url() {
|
||||
let server = MockServer::start().await;
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/agent/register"))
|
||||
.respond_with(ResponseTemplate::new(404))
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/backend-api/v1/agent/register"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
|
||||
"agent_runtime_id": "agent_fallback",
|
||||
"agent_runtime_id": "agent_canonical",
|
||||
})))
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
@@ -628,7 +598,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("identity should be registered");
|
||||
assert_eq!(stored.agent_runtime_id, "agent_fallback");
|
||||
assert_eq!(stored.agent_runtime_id, "agent_canonical");
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -1479,22 +1479,55 @@ impl Session {
|
||||
}
|
||||
|
||||
fn start_agent_identity_registration(self: &Arc<Self>) {
|
||||
if !self.services.agent_identity_manager.is_enabled() {
|
||||
return;
|
||||
}
|
||||
|
||||
let weak_sess = Arc::downgrade(self);
|
||||
let mut auth_state_rx = self.services.auth_manager.subscribe_auth_state();
|
||||
tokio::spawn(async move {
|
||||
let Some(sess) = weak_sess.upgrade() else {
|
||||
return;
|
||||
};
|
||||
if let Err(error) = sess
|
||||
.services
|
||||
.agent_identity_manager
|
||||
.ensure_registered_identity()
|
||||
.await
|
||||
{
|
||||
warn!(error = %error, "agent identity registration failed");
|
||||
loop {
|
||||
let Some(sess) = weak_sess.upgrade() else {
|
||||
return;
|
||||
};
|
||||
match sess
|
||||
.services
|
||||
.agent_identity_manager
|
||||
.ensure_registered_identity()
|
||||
.await
|
||||
{
|
||||
Ok(Some(_)) => return,
|
||||
Ok(None) => {
|
||||
drop(sess);
|
||||
if auth_state_rx.changed().await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
sess.fail_agent_identity_registration(error).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn fail_agent_identity_registration(self: &Arc<Self>, error: anyhow::Error) {
|
||||
warn!(error = %error, "agent identity registration failed");
|
||||
let message = format!(
|
||||
"Agent identity registration failed. Codex cannot continue while `features.use_agent_identity` is enabled: {error}"
|
||||
);
|
||||
self.send_event_raw(Event {
|
||||
id: self.next_internal_sub_id(),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message,
|
||||
codex_error_info: Some(CodexErrorInfo::Other),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
handlers::shutdown(self, self.next_internal_sub_id()).await;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn make_turn_context(
|
||||
conversation_id: ThreadId,
|
||||
|
||||
@@ -110,6 +110,7 @@ use opentelemetry::trace::TraceId;
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
use tokio::time::timeout;
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
|
||||
use codex_protocol::mcp::CallToolResult as McpCallToolResult;
|
||||
@@ -3818,6 +3819,42 @@ pub(crate) async fn make_session_and_context_with_rx() -> (
|
||||
make_session_and_context_with_dynamic_tools_and_rx(Vec::new()).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fail_agent_identity_registration_emits_error_and_shutdown() {
|
||||
let (session, _turn_context, rx_event) = make_session_and_context_with_rx().await;
|
||||
|
||||
session
|
||||
.fail_agent_identity_registration(anyhow::anyhow!("registration exploded"))
|
||||
.await;
|
||||
|
||||
let error_event = timeout(Duration::from_secs(1), rx_event.recv())
|
||||
.await
|
||||
.expect("error event should arrive")
|
||||
.expect("error event should be readable");
|
||||
match error_event.msg {
|
||||
EventMsg::Error(ErrorEvent {
|
||||
message,
|
||||
codex_error_info,
|
||||
}) => {
|
||||
assert_eq!(
|
||||
message,
|
||||
"Agent identity registration failed. Codex cannot continue while `features.use_agent_identity` is enabled: registration exploded".to_string()
|
||||
);
|
||||
assert_eq!(codex_error_info, Some(CodexErrorInfo::Other));
|
||||
}
|
||||
other => panic!("expected error event, got {other:?}"),
|
||||
}
|
||||
|
||||
let shutdown_event = timeout(Duration::from_secs(1), rx_event.recv())
|
||||
.await
|
||||
.expect("shutdown event should arrive")
|
||||
.expect("shutdown event should be readable");
|
||||
match shutdown_event.msg {
|
||||
EventMsg::ShutdownComplete => {}
|
||||
other => panic!("expected shutdown event, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn refresh_mcp_servers_is_deferred_until_next_turn() {
|
||||
let (session, turn_context) = make_session_and_context().await;
|
||||
|
||||
@@ -16,6 +16,8 @@ use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
use tempfile::TempDir;
|
||||
use tempfile::tempdir;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
|
||||
#[tokio::test]
|
||||
async fn refresh_without_id_token() {
|
||||
@@ -474,6 +476,65 @@ exit 1
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_manager_notifies_when_auth_state_changes() {
|
||||
let dir = tempdir().unwrap();
|
||||
let manager = AuthManager::shared(
|
||||
dir.path().to_path_buf(),
|
||||
false,
|
||||
AuthCredentialsStoreMode::File,
|
||||
);
|
||||
let mut auth_state_rx = manager.subscribe_auth_state();
|
||||
|
||||
save_auth(
|
||||
dir.path(),
|
||||
&AuthDotJson {
|
||||
auth_mode: Some(ApiAuthMode::ApiKey),
|
||||
openai_api_key: Some("sk-test-key".to_string()),
|
||||
tokens: None,
|
||||
last_refresh: None,
|
||||
},
|
||||
AuthCredentialsStoreMode::File,
|
||||
)
|
||||
.expect("save auth");
|
||||
|
||||
assert!(
|
||||
manager.reload(),
|
||||
"reload should report a changed auth state"
|
||||
);
|
||||
timeout(Duration::from_secs(1), auth_state_rx.changed())
|
||||
.await
|
||||
.expect("auth change notification should arrive")
|
||||
.expect("auth state watch should remain open");
|
||||
|
||||
save_auth(
|
||||
dir.path(),
|
||||
&AuthDotJson {
|
||||
auth_mode: Some(ApiAuthMode::ApiKey),
|
||||
openai_api_key: Some("sk-updated-key".to_string()),
|
||||
tokens: None,
|
||||
last_refresh: None,
|
||||
},
|
||||
AuthCredentialsStoreMode::File,
|
||||
)
|
||||
.expect("save updated auth");
|
||||
|
||||
assert!(
|
||||
!manager.reload(),
|
||||
"reload remains mode-stable even when the underlying credentials change"
|
||||
);
|
||||
timeout(Duration::from_secs(1), auth_state_rx.changed())
|
||||
.await
|
||||
.expect("auth reload notification should still arrive")
|
||||
.expect("auth state watch should remain open");
|
||||
|
||||
manager.set_forced_chatgpt_workspace_id(Some("workspace-123".to_string()));
|
||||
timeout(Duration::from_secs(1), auth_state_rx.changed())
|
||||
.await
|
||||
.expect("workspace change notification should arrive")
|
||||
.expect("auth state watch should remain open");
|
||||
}
|
||||
|
||||
struct AuthFileParams {
|
||||
openai_api_key: Option<String>,
|
||||
chatgpt_plan_type: Option<String>,
|
||||
|
||||
@@ -13,6 +13,7 @@ use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::sync::RwLock;
|
||||
use tokio::sync::Mutex as AsyncMutex;
|
||||
use tokio::sync::watch;
|
||||
|
||||
use codex_app_server_protocol::AuthMode;
|
||||
use codex_app_server_protocol::AuthMode as ApiAuthMode;
|
||||
@@ -1106,6 +1107,7 @@ pub struct AuthManager {
|
||||
forced_chatgpt_workspace_id: RwLock<Option<String>>,
|
||||
refresh_lock: AsyncMutex<()>,
|
||||
external_auth: RwLock<Option<Arc<dyn ExternalAuth>>>,
|
||||
auth_state_tx: watch::Sender<()>,
|
||||
}
|
||||
|
||||
/// Configuration view required to construct a shared [`AuthManager`].
|
||||
@@ -1154,6 +1156,7 @@ impl AuthManager {
|
||||
enable_codex_api_key_env: bool,
|
||||
auth_credentials_store_mode: AuthCredentialsStoreMode,
|
||||
) -> Self {
|
||||
let (auth_state_tx, _) = watch::channel(());
|
||||
let managed_auth = load_auth(
|
||||
&codex_home,
|
||||
enable_codex_api_key_env,
|
||||
@@ -1172,11 +1175,13 @@ impl AuthManager {
|
||||
forced_chatgpt_workspace_id: RwLock::new(None),
|
||||
refresh_lock: AsyncMutex::new(()),
|
||||
external_auth: RwLock::new(None),
|
||||
auth_state_tx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an AuthManager with a specific CodexAuth, for testing only.
|
||||
pub fn from_auth_for_testing(auth: CodexAuth) -> Arc<Self> {
|
||||
let (auth_state_tx, _) = watch::channel(());
|
||||
let cached = CachedAuth {
|
||||
auth: Some(auth),
|
||||
permanent_refresh_failure: None,
|
||||
@@ -1190,11 +1195,13 @@ impl AuthManager {
|
||||
forced_chatgpt_workspace_id: RwLock::new(None),
|
||||
refresh_lock: AsyncMutex::new(()),
|
||||
external_auth: RwLock::new(None),
|
||||
auth_state_tx,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create an AuthManager with a specific CodexAuth and codex home, for testing only.
|
||||
pub fn from_auth_for_testing_with_home(auth: CodexAuth, codex_home: PathBuf) -> Arc<Self> {
|
||||
let (auth_state_tx, _) = watch::channel(());
|
||||
let cached = CachedAuth {
|
||||
auth: Some(auth),
|
||||
permanent_refresh_failure: None,
|
||||
@@ -1207,10 +1214,12 @@ impl AuthManager {
|
||||
forced_chatgpt_workspace_id: RwLock::new(None),
|
||||
refresh_lock: AsyncMutex::new(()),
|
||||
external_auth: RwLock::new(None),
|
||||
auth_state_tx,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn external_bearer_only(config: ModelProviderAuthInfo) -> Arc<Self> {
|
||||
let (auth_state_tx, _) = watch::channel(());
|
||||
Arc::new(Self {
|
||||
codex_home: PathBuf::from("non-existent"),
|
||||
inner: RwLock::new(CachedAuth {
|
||||
@@ -1224,6 +1233,7 @@ impl AuthManager {
|
||||
external_auth: RwLock::new(Some(
|
||||
Arc::new(BearerTokenRefresher::new(config)) as Arc<dyn ExternalAuth>
|
||||
)),
|
||||
auth_state_tx,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1363,6 +1373,7 @@ impl AuthManager {
|
||||
}
|
||||
tracing::info!("Reloaded auth, changed: {changed}");
|
||||
guard.auth = new_auth;
|
||||
self.auth_state_tx.send_replace(());
|
||||
changed
|
||||
} else {
|
||||
false
|
||||
@@ -1372,18 +1383,23 @@ impl AuthManager {
|
||||
pub fn set_external_auth(&self, external_auth: Arc<dyn ExternalAuth>) {
|
||||
if let Ok(mut guard) = self.external_auth.write() {
|
||||
*guard = Some(external_auth);
|
||||
self.auth_state_tx.send_replace(());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clear_external_auth(&self) {
|
||||
if let Ok(mut guard) = self.external_auth.write() {
|
||||
*guard = None;
|
||||
self.auth_state_tx.send_replace(());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_forced_chatgpt_workspace_id(&self, workspace_id: Option<String>) {
|
||||
if let Ok(mut guard) = self.forced_chatgpt_workspace_id.write() {
|
||||
if let Ok(mut guard) = self.forced_chatgpt_workspace_id.write()
|
||||
&& *guard != workspace_id
|
||||
{
|
||||
*guard = workspace_id;
|
||||
self.auth_state_tx.send_replace(());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1394,6 +1410,10 @@ impl AuthManager {
|
||||
.and_then(|guard| guard.clone())
|
||||
}
|
||||
|
||||
pub fn subscribe_auth_state(&self) -> watch::Receiver<()> {
|
||||
self.auth_state_tx.subscribe()
|
||||
}
|
||||
|
||||
pub fn has_external_auth(&self) -> bool {
|
||||
self.external_auth().is_some()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user