mirror of
https://github.com/openai/codex.git
synced 2026-05-25 13:34:51 +00:00
Refresh stale agent tasks lazily
This commit is contained in:
@@ -34,6 +34,7 @@ mod task_registration;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) use assertion::AgentAssertionEnvelope;
|
||||
pub(crate) use assertion::AgentTaskRuntimeMismatch;
|
||||
pub(crate) use task_registration::RegisteredAgentTask;
|
||||
|
||||
const AGENT_REGISTRATION_TIMEOUT: Duration = Duration::from_secs(15);
|
||||
@@ -451,6 +452,9 @@ impl AgentIdentityBinding {
|
||||
}
|
||||
|
||||
fn from_auth(auth: &CodexAuth, forced_workspace_id: Option<String>) -> Option<Self> {
|
||||
// AgentAssertion is currently supported only for ChatGPT-backed Codex sessions. API-key
|
||||
// sessions keep using their API key until the registration service supports API-key
|
||||
// identity binding.
|
||||
if !auth.is_chatgpt_auth() {
|
||||
return None;
|
||||
}
|
||||
|
||||
@@ -7,10 +7,21 @@ use base64::engine::general_purpose::URL_SAFE_NO_PAD;
|
||||
use ed25519_dalek::Signer as _;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
use tracing::debug;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error(
|
||||
"agent task runtime {agent_runtime_id} does not match stored agent identity {stored_agent_runtime_id}"
|
||||
)]
|
||||
pub(crate) struct AgentTaskRuntimeMismatch {
|
||||
pub(crate) agent_runtime_id: String,
|
||||
pub(crate) task_id: String,
|
||||
pub(crate) stored_agent_runtime_id: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub(crate) struct AgentAssertionEnvelope {
|
||||
pub(crate) agent_runtime_id: String,
|
||||
@@ -31,12 +42,14 @@ impl AgentIdentityManager {
|
||||
let Some(stored_identity) = self.ensure_registered_identity().await? else {
|
||||
return Ok(None);
|
||||
};
|
||||
anyhow::ensure!(
|
||||
stored_identity.agent_runtime_id == agent_task.agent_runtime_id,
|
||||
"agent task runtime {} does not match stored agent identity {}",
|
||||
agent_task.agent_runtime_id,
|
||||
stored_identity.agent_runtime_id
|
||||
);
|
||||
if stored_identity.agent_runtime_id != agent_task.agent_runtime_id {
|
||||
return Err(AgentTaskRuntimeMismatch {
|
||||
agent_runtime_id: agent_task.agent_runtime_id.clone(),
|
||||
task_id: agent_task.task_id.clone(),
|
||||
stored_agent_runtime_id: stored_identity.agent_runtime_id,
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
let timestamp = Utc::now().to_rfc3339_opts(SecondsFormat::Secs, true);
|
||||
let envelope = AgentAssertionEnvelope {
|
||||
@@ -176,6 +189,39 @@ mod tests {
|
||||
.expect("signature should verify");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn authorization_header_for_task_reports_runtime_mismatch() {
|
||||
let codex_home = tempfile::tempdir().expect("tempdir");
|
||||
let auth = make_chatgpt_auth(codex_home.path(), "account-123", Some("user-123"));
|
||||
let auth_manager = AuthManager::from_auth_for_testing(auth);
|
||||
let manager = AgentIdentityManager::new_for_tests(
|
||||
auth_manager,
|
||||
/*feature_enabled*/ true,
|
||||
"https://chatgpt.com/backend-api/".to_string(),
|
||||
SessionSource::Cli,
|
||||
);
|
||||
manager
|
||||
.seed_generated_identity_for_tests("agent-current")
|
||||
.await
|
||||
.expect("seed test identity");
|
||||
let agent_task = RegisteredAgentTask {
|
||||
agent_runtime_id: "agent-stale".to_string(),
|
||||
task_id: "task-123".to_string(),
|
||||
registered_at: "2026-03-23T12:00:00Z".to_string(),
|
||||
};
|
||||
|
||||
let error = manager
|
||||
.authorization_header_for_task(&agent_task)
|
||||
.await
|
||||
.expect_err("stale task should be reported");
|
||||
let mismatch = error
|
||||
.downcast_ref::<AgentTaskRuntimeMismatch>()
|
||||
.expect("runtime mismatch error");
|
||||
assert_eq!(mismatch.agent_runtime_id, "agent-stale");
|
||||
assert_eq!(mismatch.task_id, "task-123");
|
||||
assert_eq!(mismatch.stored_agent_runtime_id, "agent-current");
|
||||
}
|
||||
|
||||
fn make_chatgpt_auth(
|
||||
codex_home: &std::path::Path,
|
||||
account_id: &str,
|
||||
|
||||
@@ -32,6 +32,7 @@ use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use crate::agent_identity::AgentIdentityManager;
|
||||
use crate::agent_identity::AgentTaskRuntimeMismatch;
|
||||
use crate::agent_identity::RegisteredAgentTask;
|
||||
use codex_api::ApiError;
|
||||
use codex_api::CompactClient as ApiCompactClient;
|
||||
@@ -159,7 +160,7 @@ struct ModelClientState {
|
||||
include_timing_metrics: bool,
|
||||
beta_features_header: Option<String>,
|
||||
disable_websockets: AtomicBool,
|
||||
cached_websocket_session: StdMutex<WebsocketSession>,
|
||||
cached_websocket_session: StdMutex<CachedWebsocketSession>,
|
||||
}
|
||||
|
||||
/// Resolved API client setup for a single request attempt.
|
||||
@@ -244,6 +245,12 @@ struct WebsocketSession {
|
||||
connection_reused: StdMutex<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct CachedWebsocketSession {
|
||||
agent_task: Option<RegisteredAgentTask>,
|
||||
websocket_session: WebsocketSession,
|
||||
}
|
||||
|
||||
impl WebsocketSession {
|
||||
fn set_connection_reused(&self, connection_reused: bool) {
|
||||
*self
|
||||
@@ -360,7 +367,7 @@ impl ModelClient {
|
||||
include_timing_metrics,
|
||||
beta_features_header,
|
||||
disable_websockets: AtomicBool::new(false),
|
||||
cached_websocket_session: StdMutex::new(WebsocketSession::default()),
|
||||
cached_websocket_session: StdMutex::new(CachedWebsocketSession::default()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
@@ -377,18 +384,15 @@ impl ModelClient {
|
||||
&self,
|
||||
agent_task: Option<RegisteredAgentTask>,
|
||||
) -> ModelClientSession {
|
||||
let cache_websocket_session_on_drop = agent_task.is_none();
|
||||
let websocket_session = if agent_task.is_some() {
|
||||
drop(self.take_cached_websocket_session());
|
||||
WebsocketSession::default()
|
||||
} else {
|
||||
self.take_cached_websocket_session()
|
||||
};
|
||||
// WebSocket auth is bound to the task that opened the connection. Reuse only when the
|
||||
// cached connection was created for the same task, and drop mismatched taskless/task-scoped
|
||||
// sessions rather than mixing auth contexts.
|
||||
let websocket_session = self.take_cached_websocket_session(agent_task.as_ref());
|
||||
ModelClientSession {
|
||||
client: self.clone(),
|
||||
websocket_session,
|
||||
agent_task,
|
||||
cache_websocket_session_on_drop,
|
||||
cache_websocket_session_on_drop: true,
|
||||
turn_state: Arc::new(OnceLock::new()),
|
||||
}
|
||||
}
|
||||
@@ -401,12 +405,12 @@ impl ModelClient {
|
||||
self.state
|
||||
.window_generation
|
||||
.store(window_generation, Ordering::Relaxed);
|
||||
self.store_cached_websocket_session(WebsocketSession::default());
|
||||
self.clear_cached_websocket_session();
|
||||
}
|
||||
|
||||
pub(crate) fn advance_window_generation(&self) {
|
||||
self.state.window_generation.fetch_add(1, Ordering::Relaxed);
|
||||
self.store_cached_websocket_session(WebsocketSession::default());
|
||||
self.clear_cached_websocket_session();
|
||||
}
|
||||
|
||||
fn current_window_id(&self) -> String {
|
||||
@@ -415,21 +419,44 @@ impl ModelClient {
|
||||
format!("{conversation_id}:{window_generation}")
|
||||
}
|
||||
|
||||
fn take_cached_websocket_session(&self) -> WebsocketSession {
|
||||
fn take_cached_websocket_session(
|
||||
&self,
|
||||
agent_task: Option<&RegisteredAgentTask>,
|
||||
) -> WebsocketSession {
|
||||
let mut cached_websocket_session = self
|
||||
.state
|
||||
.cached_websocket_session
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
std::mem::take(&mut *cached_websocket_session)
|
||||
if cached_websocket_session.agent_task.as_ref() == agent_task {
|
||||
return std::mem::take(&mut *cached_websocket_session).websocket_session;
|
||||
}
|
||||
|
||||
*cached_websocket_session = CachedWebsocketSession::default();
|
||||
WebsocketSession::default()
|
||||
}
|
||||
|
||||
fn store_cached_websocket_session(&self, websocket_session: WebsocketSession) {
|
||||
fn store_cached_websocket_session(
|
||||
&self,
|
||||
agent_task: Option<RegisteredAgentTask>,
|
||||
websocket_session: WebsocketSession,
|
||||
) {
|
||||
*self
|
||||
.state
|
||||
.cached_websocket_session
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner) = websocket_session;
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner) = CachedWebsocketSession {
|
||||
agent_task,
|
||||
websocket_session,
|
||||
};
|
||||
}
|
||||
|
||||
fn clear_cached_websocket_session(&self) {
|
||||
*self
|
||||
.state
|
||||
.cached_websocket_session
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner) = CachedWebsocketSession::default();
|
||||
}
|
||||
|
||||
pub(crate) fn force_http_fallback(
|
||||
@@ -449,7 +476,7 @@ impl ModelClient {
|
||||
);
|
||||
}
|
||||
|
||||
self.store_cached_websocket_session(WebsocketSession::default());
|
||||
self.clear_cached_websocket_session();
|
||||
activated
|
||||
}
|
||||
|
||||
@@ -727,6 +754,15 @@ impl ModelClient {
|
||||
.authorization_header_for_task(agent_task)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
if let Some(mismatch) = err.downcast_ref::<AgentTaskRuntimeMismatch>() {
|
||||
debug!(
|
||||
agent_runtime_id = %mismatch.agent_runtime_id,
|
||||
task_id = %mismatch.task_id,
|
||||
stored_agent_runtime_id = %mismatch.stored_agent_runtime_id,
|
||||
"agent task no longer matches stored identity"
|
||||
);
|
||||
return CodexErr::AgentTaskStale;
|
||||
}
|
||||
CodexErr::Stream(
|
||||
format!("failed to build agent assertion authorization: {err}"),
|
||||
None,
|
||||
@@ -883,12 +919,16 @@ impl Drop for ModelClientSession {
|
||||
let websocket_session = std::mem::take(&mut self.websocket_session);
|
||||
if self.cache_websocket_session_on_drop {
|
||||
self.client
|
||||
.store_cached_websocket_session(websocket_session);
|
||||
.store_cached_websocket_session(self.agent_task.clone(), websocket_session);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelClientSession {
|
||||
pub(crate) fn agent_task(&self) -> Option<&RegisteredAgentTask> {
|
||||
self.agent_task.as_ref()
|
||||
}
|
||||
|
||||
pub(crate) fn disable_cached_websocket_session_on_drop(&mut self) {
|
||||
self.cache_websocket_session_on_drop = false;
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ use codex_model_provider_info::create_oss_provider_with_base_url;
|
||||
use codex_otel::SessionTelemetry;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use codex_protocol::error::CodexErr;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::openai_models::ModelInfo;
|
||||
@@ -393,6 +394,35 @@ async fn responses_http_uses_agent_assertion_when_agent_task_is_present() {
|
||||
assert_eq!(request.header("chatgpt-account-id"), None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn responses_http_reports_stale_agent_task_when_identity_changed() {
|
||||
let provider = create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses);
|
||||
let (_codex_home, client, mut agent_task, _stored_identity) =
|
||||
model_client_with_agent_task(provider).await;
|
||||
agent_task.agent_runtime_id = "agent-stale".to_string();
|
||||
let model_info = test_model_info();
|
||||
let session_telemetry = test_session_telemetry();
|
||||
let mut client_session = client.new_session_with_agent_task(Some(agent_task));
|
||||
|
||||
let error = match client_session
|
||||
.stream(
|
||||
&test_prompt("hello"),
|
||||
&model_info,
|
||||
&session_telemetry,
|
||||
/*effort*/ None,
|
||||
ReasoningSummary::Auto,
|
||||
/*service_tier*/ None,
|
||||
/*turn_metadata_header*/ None,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => panic!("stale task should be reported before sending a request"),
|
||||
Err(error) => error,
|
||||
};
|
||||
|
||||
assert!(matches!(error, CodexErr::AgentTaskStale));
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn websocket_agent_task_bypasses_cached_bearer_prewarm() {
|
||||
core_test_support::skip_if_no_network!();
|
||||
@@ -469,3 +499,80 @@ async fn websocket_agent_task_bypasses_cached_bearer_prewarm() {
|
||||
|
||||
server.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn websocket_agent_task_reuses_cached_connection_for_same_task() {
|
||||
core_test_support::skip_if_no_network!();
|
||||
|
||||
let server = responses::start_websocket_server(vec![vec![
|
||||
vec![
|
||||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_completed("resp-1"),
|
||||
],
|
||||
vec![
|
||||
responses::ev_response_created("resp-2"),
|
||||
responses::ev_completed("resp-2"),
|
||||
],
|
||||
]])
|
||||
.await;
|
||||
let mut provider =
|
||||
create_oss_provider_with_base_url(&format!("{}/v1", server.uri()), WireApi::Responses);
|
||||
provider.supports_websockets = true;
|
||||
provider.websocket_connect_timeout_ms = Some(5_000);
|
||||
let (_codex_home, client, agent_task, stored_identity) =
|
||||
model_client_with_agent_task(provider).await;
|
||||
let model_info = test_model_info();
|
||||
let session_telemetry = test_session_telemetry();
|
||||
let prompt = test_prompt("hello");
|
||||
|
||||
{
|
||||
let mut first_session = client.new_session_with_agent_task(Some(agent_task.clone()));
|
||||
let mut stream = first_session
|
||||
.stream(
|
||||
&prompt,
|
||||
&model_info,
|
||||
&session_telemetry,
|
||||
/*effort*/ None,
|
||||
ReasoningSummary::Auto,
|
||||
/*service_tier*/ None,
|
||||
/*turn_metadata_header*/ None,
|
||||
)
|
||||
.await
|
||||
.expect("first agent task stream should succeed");
|
||||
drain_stream_to_completion(&mut stream)
|
||||
.await
|
||||
.expect("first agent task websocket stream should complete");
|
||||
}
|
||||
|
||||
let mut second_session = client.new_session_with_agent_task(Some(agent_task.clone()));
|
||||
let mut stream = second_session
|
||||
.stream(
|
||||
&prompt,
|
||||
&model_info,
|
||||
&session_telemetry,
|
||||
/*effort*/ None,
|
||||
ReasoningSummary::Auto,
|
||||
/*service_tier*/ None,
|
||||
/*turn_metadata_header*/ None,
|
||||
)
|
||||
.await
|
||||
.expect("second agent task stream should succeed");
|
||||
drain_stream_to_completion(&mut stream)
|
||||
.await
|
||||
.expect("second agent task websocket stream should complete");
|
||||
|
||||
let handshakes = server.handshakes();
|
||||
assert_eq!(handshakes.len(), 1);
|
||||
let agent_authorization = handshakes[0]
|
||||
.header("authorization")
|
||||
.expect("agent handshake should include authorization");
|
||||
assert_agent_assertion_header(
|
||||
&agent_authorization,
|
||||
&stored_identity,
|
||||
&agent_task.agent_runtime_id,
|
||||
&agent_task.task_id,
|
||||
);
|
||||
assert_eq!(server.single_connection().len(), 2);
|
||||
|
||||
server.shutdown().await;
|
||||
}
|
||||
|
||||
@@ -6354,6 +6354,14 @@ pub(crate) async fn run_turn(
|
||||
.await;
|
||||
user_prompt_submit_outcome.additional_contexts
|
||||
};
|
||||
let agent_task_registration = if sess.services.agent_identity_manager.is_enabled() {
|
||||
let sess = Arc::clone(&sess);
|
||||
Some(tokio::spawn(async move {
|
||||
sess.ensure_agent_task_registered().await
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
sess.services
|
||||
.analytics_events_client
|
||||
.track_app_mentioned(tracking.clone(), mentioned_app_invocations);
|
||||
@@ -6375,7 +6383,15 @@ pub(crate) async fn run_turn(
|
||||
}))
|
||||
.await;
|
||||
}
|
||||
let agent_task = match sess.ensure_agent_task_registered().await {
|
||||
let agent_task_result = match agent_task_registration {
|
||||
Some(registration) => registration.await.unwrap_or_else(|error| {
|
||||
Err(anyhow::anyhow!(
|
||||
"agent task registration task failed: {error}"
|
||||
))
|
||||
}),
|
||||
None => sess.ensure_agent_task_registered().await,
|
||||
};
|
||||
let agent_task = match agent_task_result {
|
||||
Ok(agent_task) => agent_task,
|
||||
Err(error) => {
|
||||
warn!(error = %error, "agent task registration failed");
|
||||
@@ -7031,6 +7047,7 @@ async fn run_sampling_request(
|
||||
)
|
||||
.await;
|
||||
let mut retries = 0;
|
||||
let mut stale_agent_task_refreshed = false;
|
||||
loop {
|
||||
let err = match try_run_sampling_request(
|
||||
tool_runtime.clone(),
|
||||
@@ -7052,6 +7069,40 @@ async fn run_sampling_request(
|
||||
sess.set_total_tokens_full(&turn_context).await;
|
||||
return Err(CodexErr::ContextWindowExceeded);
|
||||
}
|
||||
Err(CodexErr::AgentTaskStale) => {
|
||||
if stale_agent_task_refreshed {
|
||||
return Err(CodexErr::AgentTaskStale);
|
||||
}
|
||||
stale_agent_task_refreshed = true;
|
||||
let stale_agent_task = client_session.agent_task().cloned();
|
||||
client_session.disable_cached_websocket_session_on_drop();
|
||||
if let Some(stale_agent_task) = stale_agent_task.as_ref() {
|
||||
sess.clear_cached_agent_task(stale_agent_task).await;
|
||||
}
|
||||
match sess.ensure_agent_task_registered().await {
|
||||
Ok(Some(agent_task)) => {
|
||||
*client_session = sess
|
||||
.services
|
||||
.model_client
|
||||
.new_session_with_agent_task(Some(agent_task));
|
||||
retries = 0;
|
||||
continue;
|
||||
}
|
||||
Ok(None) => {
|
||||
return Err(CodexErr::Stream(
|
||||
"agent assertion task became unavailable after identity changed"
|
||||
.to_string(),
|
||||
None,
|
||||
));
|
||||
}
|
||||
Err(error) => {
|
||||
return Err(CodexErr::Stream(
|
||||
format!("failed to refresh stale agent task: {error}"),
|
||||
None,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(CodexErr::UsageLimitReached(e)) => {
|
||||
let rate_limits = e.rate_limits.clone();
|
||||
if let Some(rate_limits) = rate_limits {
|
||||
|
||||
@@ -13,6 +13,7 @@ use crate::codex::INITIAL_SUBMIT_ID;
|
||||
use crate::codex::Session;
|
||||
use crate::codex::build_prompt;
|
||||
use crate::codex::built_tools;
|
||||
use codex_app_server_protocol::AuthMode;
|
||||
use codex_otel::STARTUP_PREWARM_AGE_AT_FIRST_TURN_METRIC;
|
||||
use codex_otel::STARTUP_PREWARM_DURATION_METRIC;
|
||||
use codex_otel::SessionTelemetry;
|
||||
@@ -157,6 +158,15 @@ impl SessionStartupPrewarmHandle {
|
||||
|
||||
impl Session {
|
||||
pub(crate) async fn schedule_startup_prewarm(self: &Arc<Self>, base_instructions: String) {
|
||||
if self.services.agent_identity_manager.is_enabled()
|
||||
&& self.services.auth_manager.auth_mode() != Some(AuthMode::ApiKey)
|
||||
{
|
||||
info!(
|
||||
"skipping startup websocket prewarm because agent identity requires task-scoped auth"
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
let session_telemetry = self.services.session_telemetry.clone();
|
||||
let websocket_connect_timeout = self.provider().await.websocket_connect_timeout();
|
||||
let started_at = Instant::now();
|
||||
|
||||
@@ -76,6 +76,8 @@ pub enum CodexErr {
|
||||
/// Optionally includes the requested delay before retrying the turn.
|
||||
#[error("stream disconnected before completion: {0}")]
|
||||
Stream(String, Option<Duration>),
|
||||
#[error("agent task no longer matches the current agent identity")]
|
||||
AgentTaskStale,
|
||||
#[error(
|
||||
"Codex ran out of room in the model's context window. Start a new thread or clear earlier history before retrying."
|
||||
)]
|
||||
@@ -183,6 +185,7 @@ impl CodexErr {
|
||||
| CodexErr::ContextWindowExceeded
|
||||
| CodexErr::ThreadNotFound(_)
|
||||
| CodexErr::AgentLimitReached { .. }
|
||||
| CodexErr::AgentTaskStale
|
||||
| CodexErr::Spawn
|
||||
| CodexErr::SessionConfiguredNotFirstEvent
|
||||
| CodexErr::UsageLimitReached(_)
|
||||
|
||||
Reference in New Issue
Block a user