mirror of
https://github.com/openai/codex.git
synced 2026-04-24 22:54:54 +00:00
Move warmup to the task level (#11216)
Instead of storing a special connection on the client level make the regular task responsible for establishing a normal client session and open a connection on it. Then when the turn is started we pass in a pre-established session.
This commit is contained in:
@@ -9,26 +9,24 @@
|
||||
//! call site.
|
||||
//!
|
||||
//! A [`ModelClientSession`] is created per turn and is used to stream one or more Responses API
|
||||
//! requests during that turn. It caches a Responses WebSocket connection (opened lazily, or reused
|
||||
//! from a session-level preconnect) and stores per-turn state such as the `x-codex-turn-state`
|
||||
//! token used for sticky routing.
|
||||
//! requests during that turn. It caches a Responses WebSocket connection (opened lazily) and stores
|
||||
//! per-turn state such as the `x-codex-turn-state` token used for sticky routing.
|
||||
//!
|
||||
//! Preconnect is intentionally handshake-only: it may warm a socket and capture sticky-routing
|
||||
//! Prewarm is intentionally handshake-only: it may warm a socket and capture sticky-routing
|
||||
//! state, but the first `response.create` payload is still sent only when a turn starts.
|
||||
//!
|
||||
//! Internally, startup preconnect stores a single task handle. On first use in a turn, the session
|
||||
//! awaits that task and adopts the warmed socket if it succeeds; if it fails, the stream attempt
|
||||
//! fails and the normal retry/fallback loop decides what to do next.
|
||||
//! Startup prewarm is owned by turn-scoped callers (for example, a pre-created regular task). When
|
||||
//! a warmed [`ModelClientSession`] is available, turn execution can reuse it; otherwise the turn
|
||||
//! lazily opens a websocket on first stream call.
|
||||
//!
|
||||
//! ## Retry-Budget Tradeoff
|
||||
//!
|
||||
//! Startup preconnect is treated as the first websocket connection attempt for the first turn. If
|
||||
//! Startup prewarm is treated as the first websocket connection attempt for the first turn. If
|
||||
//! it fails, the stream attempt fails and the retry/fallback loop decides whether to retry or fall
|
||||
//! back. This avoids duplicate handshakes but means a failed preconnect can consume one retry
|
||||
//! back. This avoids duplicate handshakes but means a failed prewarm can consume one retry
|
||||
//! budget slot before any turn payload is sent.
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::sync::OnceLock;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
@@ -73,7 +71,6 @@ use codex_protocol::protocol::SessionSource;
|
||||
use eventsource_stream::Event;
|
||||
use eventsource_stream::EventStreamError;
|
||||
use futures::StreamExt;
|
||||
use futures::future::BoxFuture;
|
||||
use http::HeaderMap as ApiHeaderMap;
|
||||
use http::HeaderValue;
|
||||
use http::StatusCode as HttpStatusCode;
|
||||
@@ -83,7 +80,6 @@ use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::sync::oneshot::error::TryRecvError;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_tungstenite::tungstenite::Error;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::warn;
|
||||
@@ -109,17 +105,11 @@ pub const X_CODEX_TURN_METADATA_HEADER: &str = "x-codex-turn-metadata";
|
||||
pub const X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER: &str =
|
||||
"x-responsesapi-include-timing-metrics";
|
||||
const RESPONSES_WEBSOCKETS_V2_BETA_HEADER_VALUE: &str = "responses_websockets=2026-02-06";
|
||||
|
||||
struct PreconnectedWebSocket {
|
||||
connection: ApiWebSocketConnection,
|
||||
turn_state: Option<String>,
|
||||
}
|
||||
|
||||
type PreconnectTask = JoinHandle<Option<PreconnectedWebSocket>>;
|
||||
/// Session-scoped state shared by all [`ModelClient`] clones.
|
||||
///
|
||||
/// This is intentionally kept minimal so `ModelClient` does not need to hold a full `Config`. Most
|
||||
/// configuration is per turn and is passed explicitly to streaming/unary methods.
|
||||
#[derive(Debug)]
|
||||
struct ModelClientState {
|
||||
auth_manager: Option<Arc<AuthManager>>,
|
||||
conversation_id: ThreadId,
|
||||
@@ -132,40 +122,11 @@ struct ModelClientState {
|
||||
include_timing_metrics: bool,
|
||||
beta_features_header: Option<String>,
|
||||
disable_websockets: AtomicBool,
|
||||
|
||||
preconnect: Mutex<Option<PreconnectTask>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ModelClientState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ModelClientState")
|
||||
.field("auth_manager", &self.auth_manager)
|
||||
.field("conversation_id", &self.conversation_id)
|
||||
.field("provider", &self.provider)
|
||||
.field("session_source", &self.session_source)
|
||||
.field("model_verbosity", &self.model_verbosity)
|
||||
.field(
|
||||
"enable_responses_websockets",
|
||||
&self.enable_responses_websockets,
|
||||
)
|
||||
.field(
|
||||
"enable_request_compression",
|
||||
&self.enable_request_compression,
|
||||
)
|
||||
.field("include_timing_metrics", &self.include_timing_metrics)
|
||||
.field("beta_features_header", &self.beta_features_header)
|
||||
.field(
|
||||
"disable_websockets",
|
||||
&self.disable_websockets.load(Ordering::Relaxed),
|
||||
)
|
||||
.field("preconnect", &"<opaque>")
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolved API client setup for a single request attempt.
|
||||
///
|
||||
/// Keeping this as a single bundle ensures preconnect and normal request paths
|
||||
/// Keeping this as a single bundle ensures prewarm and normal request paths
|
||||
/// share the same auth/provider setup flow.
|
||||
struct CurrentClientSetup {
|
||||
auth: Option<CodexAuth>,
|
||||
@@ -192,17 +153,14 @@ pub struct ModelClient {
|
||||
|
||||
/// A turn-scoped streaming session created from a [`ModelClient`].
|
||||
///
|
||||
/// The session establishes a Responses WebSocket connection lazily (or adopts a preconnected one)
|
||||
/// and reuses it across multiple requests within the turn. It also caches per-turn state:
|
||||
/// The session establishes a Responses WebSocket connection lazily and reuses it across multiple
|
||||
/// requests within the turn. It also caches per-turn state:
|
||||
///
|
||||
/// - The last request's input items, so subsequent calls can use `response.append` when the input
|
||||
/// is an incremental extension of the previous request.
|
||||
/// - The `x-codex-turn-state` sticky-routing token, which must be replayed for all requests within
|
||||
/// the same turn.
|
||||
///
|
||||
/// When startup preconnect is still running, first use of this session awaits that in-flight task
|
||||
/// before opening a new websocket so preconnect acts as the first connection attempt for the turn.
|
||||
///
|
||||
/// Create a fresh `ModelClientSession` for each Codex turn. Reusing it across turns would replay
|
||||
/// the previous turn's sticky-routing token into the next turn, which violates the client/server
|
||||
/// contract and can cause routing bugs.
|
||||
@@ -261,16 +219,14 @@ impl ModelClient {
|
||||
include_timing_metrics,
|
||||
beta_features_header,
|
||||
disable_websockets: AtomicBool::new(false),
|
||||
preconnect: Mutex::new(None),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a fresh turn-scoped streaming session.
|
||||
///
|
||||
/// This constructor does not perform network I/O itself. The returned session either adopts a
|
||||
/// previously preconnected websocket or opens a websocket lazily when the first stream request
|
||||
/// is issued.
|
||||
/// This constructor does not perform network I/O itself; the session opens a websocket lazily
|
||||
/// when the first stream request is issued.
|
||||
pub fn new_session(&self) -> ModelClientSession {
|
||||
ModelClientSession {
|
||||
client: self.clone(),
|
||||
@@ -282,79 +238,6 @@ impl ModelClient {
|
||||
}
|
||||
}
|
||||
|
||||
/// Spawns a best-effort task that warms a websocket for the first turn.
|
||||
///
|
||||
/// This call performs only connection setup; it never sends prompt payloads.
|
||||
///
|
||||
/// A timeout when computing turn metadata is treated the same as "no metadata" so startup
|
||||
/// cannot block indefinitely on optional preconnect context.
|
||||
pub fn pre_establish_connection(
|
||||
&self,
|
||||
otel_manager: OtelManager,
|
||||
turn_metadata_header: BoxFuture<'static, Option<String>>,
|
||||
) {
|
||||
if !self.responses_websocket_enabled() || self.disable_websockets() {
|
||||
return;
|
||||
}
|
||||
|
||||
let model_client = self.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
let turn_metadata_header = turn_metadata_header.await;
|
||||
|
||||
model_client
|
||||
.preconnect(&otel_manager, turn_metadata_header.as_deref())
|
||||
.await
|
||||
});
|
||||
self.set_preconnected_task(Some(handle));
|
||||
}
|
||||
|
||||
/// Opportunistically pre-establishes a Responses WebSocket connection for this session.
|
||||
///
|
||||
/// This method is best-effort: it returns an error on setup/connect failure and the caller
|
||||
/// can decide whether to ignore it. A successful preconnect reduces first-turn latency but
|
||||
/// never sends an initial prompt; the first `response.create` is still sent only when a turn
|
||||
/// starts.
|
||||
///
|
||||
/// The preconnected slot is single-consumer and single-use: the next `ModelClientSession` may
|
||||
/// adopt it once, after which later turns either keep using that same turn-local connection or
|
||||
/// create a new one.
|
||||
async fn preconnect(
|
||||
&self,
|
||||
otel_manager: &OtelManager,
|
||||
turn_metadata_header: Option<&str>,
|
||||
) -> Option<PreconnectedWebSocket> {
|
||||
if !self.responses_websocket_enabled() || self.disable_websockets() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let client_setup = self
|
||||
.current_client_setup()
|
||||
.await
|
||||
.map_err(|err| {
|
||||
ApiError::Stream(format!(
|
||||
"failed to build websocket preconnect client setup: {err}"
|
||||
))
|
||||
})
|
||||
.ok()?;
|
||||
|
||||
let turn_state = Arc::new(OnceLock::new());
|
||||
let connection = self
|
||||
.connect_websocket(
|
||||
otel_manager,
|
||||
client_setup.api_provider,
|
||||
client_setup.api_auth,
|
||||
Some(Arc::clone(&turn_state)),
|
||||
turn_metadata_header,
|
||||
)
|
||||
.await
|
||||
.ok()?;
|
||||
|
||||
Some(PreconnectedWebSocket {
|
||||
connection,
|
||||
turn_state: turn_state.get().cloned(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Compacts the current conversation history using the Compact endpoint.
|
||||
///
|
||||
/// This is a unary call (no streaming) that returns a new list of
|
||||
@@ -475,7 +358,7 @@ impl ModelClient {
|
||||
|
||||
/// Returns auth + provider configuration resolved from the current session auth state.
|
||||
///
|
||||
/// This centralizes setup used by both preconnect and normal request paths so they stay in
|
||||
/// This centralizes setup used by both prewarm and normal request paths so they stay in
|
||||
/// lockstep when auth/provider resolution changes.
|
||||
async fn current_client_setup(&self) -> Result<CurrentClientSetup> {
|
||||
let auth = match self.state.auth_manager.as_ref() {
|
||||
@@ -496,7 +379,7 @@ impl ModelClient {
|
||||
|
||||
/// Opens a websocket connection using the same header and telemetry wiring as normal turns.
|
||||
///
|
||||
/// Both startup preconnect and in-turn `needs_new` reconnects call this path so handshake
|
||||
/// Both startup prewarm and in-turn `needs_new` reconnects call this path so handshake
|
||||
/// behavior remains consistent across both flows.
|
||||
async fn connect_websocket(
|
||||
&self,
|
||||
@@ -513,7 +396,7 @@ impl ModelClient {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Builds websocket handshake headers for both preconnect and turn-time reconnect.
|
||||
/// Builds websocket handshake headers for both prewarm and turn-time reconnect.
|
||||
///
|
||||
/// Callers should pass the current turn-state lock when available so sticky-routing state is
|
||||
/// replayed on reconnect within the same turn.
|
||||
@@ -548,28 +431,6 @@ impl ModelClient {
|
||||
}
|
||||
headers
|
||||
}
|
||||
|
||||
/// Consumes the warmed websocket task slot.
|
||||
fn take_preconnected_task(&self) -> Option<PreconnectTask> {
|
||||
let mut state = self
|
||||
.state
|
||||
.preconnect
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
state.take()
|
||||
}
|
||||
|
||||
fn set_preconnected_task(&self, task: Option<PreconnectTask>) {
|
||||
let mut state = self
|
||||
.state
|
||||
.preconnect
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
if let Some(running_task) = state.take() {
|
||||
running_task.abort();
|
||||
}
|
||||
*state = task;
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelClientSession {
|
||||
@@ -772,13 +633,42 @@ impl ModelClientSession {
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns a websocket connection for this turn, reusing preconnect when possible.
|
||||
/// Opportunistically warms a websocket for this turn-scoped client session.
|
||||
///
|
||||
/// This method first tries to adopt the session-level preconnect slot, then falls back to a
|
||||
/// fresh websocket handshake only when the turn has no live connection. If startup preconnect
|
||||
/// is still running, it is awaited first so that task acts as the first connection attempt for
|
||||
/// this turn instead of racing a second handshake. If that attempt fails, the normal connect
|
||||
/// and stream retry flow continues unchanged.
|
||||
/// This performs only connection setup; it never sends prompt payloads.
|
||||
pub async fn prewarm_websocket(
|
||||
&mut self,
|
||||
otel_manager: &OtelManager,
|
||||
turn_metadata_header: Option<&str>,
|
||||
) -> std::result::Result<(), ApiError> {
|
||||
if !self.client.responses_websocket_enabled() || self.client.disable_websockets() {
|
||||
return Ok(());
|
||||
}
|
||||
if self.connection.is_some() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let client_setup = self.client.current_client_setup().await.map_err(|err| {
|
||||
ApiError::Stream(format!(
|
||||
"failed to build websocket prewarm client setup: {err}"
|
||||
))
|
||||
})?;
|
||||
|
||||
let connection = self
|
||||
.client
|
||||
.connect_websocket(
|
||||
otel_manager,
|
||||
client_setup.api_provider,
|
||||
client_setup.api_auth,
|
||||
Some(Arc::clone(&self.turn_state)),
|
||||
turn_metadata_header,
|
||||
)
|
||||
.await?;
|
||||
self.connection = Some(connection);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns a websocket connection for this turn.
|
||||
async fn websocket_connection(
|
||||
&mut self,
|
||||
otel_manager: &OtelManager,
|
||||
@@ -787,27 +677,6 @@ impl ModelClientSession {
|
||||
turn_metadata_header: Option<&str>,
|
||||
options: &ApiResponsesOptions,
|
||||
) -> std::result::Result<&ApiWebSocketConnection, ApiError> {
|
||||
// Prefer the session-level preconnect slot before creating a new websocket.
|
||||
if self.connection.is_none()
|
||||
&& let Some(task) = self.client.take_preconnected_task()
|
||||
{
|
||||
match task.await {
|
||||
Ok(Some(preconnected)) => {
|
||||
let PreconnectedWebSocket {
|
||||
connection,
|
||||
turn_state,
|
||||
} = preconnected;
|
||||
if let Some(turn_state) = turn_state {
|
||||
let _ = self.turn_state.set(turn_state);
|
||||
}
|
||||
self.connection = Some(connection);
|
||||
}
|
||||
_ => {
|
||||
warn!("startup websocket preconnect task failed");
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
let needs_new = match self.connection.as_ref() {
|
||||
Some(conn) => conn.is_closed().await,
|
||||
None => true,
|
||||
@@ -1083,8 +952,7 @@ impl ModelClientSession {
|
||||
/// Permanently disables WebSockets for this Codex session and resets WebSocket state.
|
||||
///
|
||||
/// This is used after exhausting the provider retry budget, to force subsequent requests onto
|
||||
/// the HTTP transport. It also clears any warmed websocket preconnect state so future turns
|
||||
/// cannot accidentally adopt a stale socket after fallback has been activated.
|
||||
/// the HTTP transport.
|
||||
///
|
||||
/// Returns `true` if this call activated fallback, or `false` if fallback was already active.
|
||||
pub(crate) fn try_switch_fallback_transport(&mut self, otel_manager: &OtelManager) -> bool {
|
||||
@@ -1098,7 +966,6 @@ impl ModelClientSession {
|
||||
&[("from_wire_api", "responses_websocket")],
|
||||
);
|
||||
|
||||
self.client.set_preconnected_task(None);
|
||||
self.connection = None;
|
||||
self.websocket_last_items.clear();
|
||||
}
|
||||
|
||||
@@ -200,6 +200,7 @@ use crate::state::SessionServices;
|
||||
use crate::state::SessionState;
|
||||
use crate::state_db;
|
||||
use crate::tasks::GhostSnapshotTask;
|
||||
use crate::tasks::RegularTask;
|
||||
use crate::tasks::ReviewTask;
|
||||
use crate::tasks::SessionTask;
|
||||
use crate::tasks::SessionTaskContext;
|
||||
@@ -566,8 +567,8 @@ impl TurnContext {
|
||||
|
||||
/// Resolves the per-turn metadata header under a shared timeout policy.
|
||||
///
|
||||
/// This uses the same timeout helper as websocket startup preconnect so both turn execution
|
||||
/// and background preconnect observe identical "timeout means best-effort fallback" behavior.
|
||||
/// This uses the same timeout helper as websocket startup prewarm so both turn execution and
|
||||
/// background prewarm observe identical "timeout means best-effort fallback" behavior.
|
||||
pub async fn resolve_turn_metadata_header(&self) -> Option<String> {
|
||||
resolve_turn_metadata_header_with_timeout(
|
||||
self.build_turn_metadata_header(),
|
||||
@@ -579,7 +580,7 @@ impl TurnContext {
|
||||
/// Starts best-effort background computation of turn metadata.
|
||||
///
|
||||
/// This warms the cached value used by [`TurnContext::resolve_turn_metadata_header`] so turns
|
||||
/// and websocket preconnect are less likely to pay metadata construction latency on demand.
|
||||
/// and websocket prewarm are less likely to pay metadata construction latency on demand.
|
||||
pub fn spawn_turn_metadata_header_task(self: &Arc<Self>) {
|
||||
let context = Arc::clone(self);
|
||||
tokio::spawn(async move {
|
||||
@@ -1044,7 +1045,7 @@ impl Session {
|
||||
}
|
||||
};
|
||||
session_configuration.thread_name = thread_name.clone();
|
||||
let state = SessionState::new(session_configuration.clone());
|
||||
let mut state = SessionState::new(session_configuration.clone());
|
||||
|
||||
let services = SessionServices {
|
||||
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::default())),
|
||||
@@ -1082,6 +1083,18 @@ impl Session {
|
||||
),
|
||||
};
|
||||
|
||||
let turn_metadata_header = resolve_turn_metadata_header_with_timeout(
|
||||
build_turn_metadata_header(session_configuration.cwd.clone()),
|
||||
None,
|
||||
)
|
||||
.boxed();
|
||||
let startup_regular_task = RegularTask::with_startup_prewarm(
|
||||
services.model_client.clone(),
|
||||
services.otel_manager.clone(),
|
||||
turn_metadata_header,
|
||||
);
|
||||
state.set_startup_regular_task(startup_regular_task);
|
||||
|
||||
let sess = Arc::new(Session {
|
||||
conversation_id,
|
||||
tx_event: tx_event.clone(),
|
||||
@@ -1094,18 +1107,6 @@ impl Session {
|
||||
next_internal_sub_id: AtomicU64::new(0),
|
||||
});
|
||||
|
||||
// Warm a websocket in the background so the first turn can reuse it.
|
||||
// This performs only connection setup; user input is still sent later via response.create
|
||||
// when submit_turn() runs.
|
||||
let turn_metadata_header = resolve_turn_metadata_header_with_timeout(
|
||||
build_turn_metadata_header(session_configuration.cwd.clone()),
|
||||
None,
|
||||
)
|
||||
.boxed();
|
||||
sess.services
|
||||
.model_client
|
||||
.pre_establish_connection(sess.services.otel_manager.clone(), turn_metadata_header);
|
||||
|
||||
// Dispatch the SessionConfiguredEvent first and then report any errors.
|
||||
// If resuming, include converted initial messages in the payload so UIs can render them immediately.
|
||||
let initial_messages = initial_history.get_event_msgs();
|
||||
@@ -1493,6 +1494,11 @@ impl Session {
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn take_startup_regular_task(&self) -> Option<RegularTask> {
|
||||
let mut state = self.state.lock().await;
|
||||
state.take_startup_regular_task()
|
||||
}
|
||||
|
||||
async fn get_config(&self) -> std::sync::Arc<Config> {
|
||||
let state = self.state.lock().await;
|
||||
state
|
||||
@@ -2849,7 +2855,6 @@ mod handlers {
|
||||
use crate::review_prompts::resolve_review_request;
|
||||
use crate::rollout::session_index;
|
||||
use crate::tasks::CompactTask;
|
||||
use crate::tasks::RegularTask;
|
||||
use crate::tasks::UndoTask;
|
||||
use crate::tasks::UserShellCommandMode;
|
||||
use crate::tasks::UserShellCommandTask;
|
||||
@@ -2992,7 +2997,8 @@ mod handlers {
|
||||
|
||||
sess.refresh_mcp_servers_if_requested(¤t_context)
|
||||
.await;
|
||||
sess.spawn_task(Arc::clone(¤t_context), items, RegularTask)
|
||||
let regular_task = sess.take_startup_regular_task().await.unwrap_or_default();
|
||||
sess.spawn_task(Arc::clone(¤t_context), items, regular_task)
|
||||
.await;
|
||||
*previous_context = Some(current_context);
|
||||
}
|
||||
@@ -3708,6 +3714,7 @@ pub(crate) async fn run_turn(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
input: Vec<UserInput>,
|
||||
prewarmed_client_session: Option<ModelClientSession>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
if input.is_empty() {
|
||||
@@ -3825,7 +3832,8 @@ pub(crate) async fn run_turn(
|
||||
let turn_metadata_header = turn_context.resolve_turn_metadata_header().await;
|
||||
// `ModelClientSession` is turn-scoped and caches WebSocket + sticky routing state, so we reuse
|
||||
// one instance across retries within this turn.
|
||||
let mut client_session = sess.services.model_client.new_session();
|
||||
let mut client_session =
|
||||
prewarmed_client_session.unwrap_or_else(|| sess.services.model_client.new_session());
|
||||
|
||||
loop {
|
||||
// Note that pending_input would be something like a message the user
|
||||
|
||||
@@ -9,6 +9,7 @@ use crate::context_manager::ContextManager;
|
||||
use crate::protocol::RateLimitSnapshot;
|
||||
use crate::protocol::TokenUsage;
|
||||
use crate::protocol::TokenUsageInfo;
|
||||
use crate::tasks::RegularTask;
|
||||
use crate::truncate::TruncationPolicy;
|
||||
|
||||
/// Persistent, session-scoped state previously stored directly on `Session`.
|
||||
@@ -26,6 +27,8 @@ pub(crate) struct SessionState {
|
||||
pub(crate) initial_context_seeded: bool,
|
||||
/// Previous rollout model for one-shot model-switch handling on first turn after resume.
|
||||
pub(crate) pending_resume_previous_model: Option<String>,
|
||||
/// Startup regular task pre-created during session initialization.
|
||||
pub(crate) startup_regular_task: Option<RegularTask>,
|
||||
}
|
||||
|
||||
impl SessionState {
|
||||
@@ -41,6 +44,7 @@ impl SessionState {
|
||||
mcp_dependency_prompted: HashSet::new(),
|
||||
initial_context_seeded: false,
|
||||
pending_resume_previous_model: None,
|
||||
startup_regular_task: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,6 +132,14 @@ impl SessionState {
|
||||
pub(crate) fn dependency_env(&self) -> HashMap<String, String> {
|
||||
self.dependency_env.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn set_startup_regular_task(&mut self, task: RegularTask) {
|
||||
self.startup_regular_task = Some(task);
|
||||
}
|
||||
|
||||
pub(crate) fn take_startup_regular_task(&mut self) -> Option<RegularTask> {
|
||||
self.startup_regular_task.take()
|
||||
}
|
||||
}
|
||||
|
||||
// Sometimes new snapshots don't include credits or plan information.
|
||||
|
||||
@@ -1,19 +1,82 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use crate::client::ModelClient;
|
||||
use crate::client::ModelClientSession;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::codex::run_turn;
|
||||
use crate::state::TaskKind;
|
||||
use async_trait::async_trait;
|
||||
use codex_otel::OtelManager;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use futures::future::BoxFuture;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::Instrument;
|
||||
use tracing::trace_span;
|
||||
use tracing::warn;
|
||||
|
||||
use super::SessionTask;
|
||||
use super::SessionTaskContext;
|
||||
|
||||
#[derive(Clone, Copy, Default)]
|
||||
pub(crate) struct RegularTask;
|
||||
type PrewarmedSessionTask = JoinHandle<Option<ModelClientSession>>;
|
||||
|
||||
pub(crate) struct RegularTask {
|
||||
prewarmed_session_task: Mutex<Option<PrewarmedSessionTask>>,
|
||||
}
|
||||
|
||||
impl Default for RegularTask {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
prewarmed_session_task: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RegularTask {
|
||||
pub(crate) fn with_startup_prewarm(
|
||||
model_client: ModelClient,
|
||||
otel_manager: OtelManager,
|
||||
turn_metadata_header: BoxFuture<'static, Option<String>>,
|
||||
) -> Self {
|
||||
let prewarmed_session_task = tokio::spawn(async move {
|
||||
let mut client_session = model_client.new_session();
|
||||
let turn_metadata_header = turn_metadata_header.await;
|
||||
match client_session
|
||||
.prewarm_websocket(&otel_manager, turn_metadata_header.as_deref())
|
||||
.await
|
||||
{
|
||||
Ok(()) => Some(client_session),
|
||||
Err(err) => {
|
||||
warn!("startup websocket prewarm task failed: {err}");
|
||||
None
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
prewarmed_session_task: Mutex::new(Some(prewarmed_session_task)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn take_prewarmed_session(&self) -> Option<ModelClientSession> {
|
||||
let prewarmed_session_task = self
|
||||
.prewarmed_session_task
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.take();
|
||||
match prewarmed_session_task {
|
||||
Some(task) => match task.await {
|
||||
Ok(client_session) => client_session,
|
||||
Err(err) => {
|
||||
warn!("startup websocket prewarm task join failed: {err}");
|
||||
None
|
||||
}
|
||||
},
|
||||
None => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SessionTask for RegularTask {
|
||||
@@ -34,8 +97,15 @@ impl SessionTask for RegularTask {
|
||||
sess.services
|
||||
.otel_manager
|
||||
.apply_traceparent_parent(&run_turn_span);
|
||||
run_turn(sess, ctx, input, cancellation_token)
|
||||
.instrument(run_turn_span)
|
||||
.await
|
||||
let prewarmed_client_session = self.take_prewarmed_session().await;
|
||||
run_turn(
|
||||
sess,
|
||||
ctx,
|
||||
input,
|
||||
prewarmed_client_session,
|
||||
cancellation_token,
|
||||
)
|
||||
.instrument(run_turn_span)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//! Helpers for computing and resolving optional per-turn metadata headers.
|
||||
//!
|
||||
//! This module owns both metadata construction and the shared timeout policy used by
|
||||
//! turn execution and startup websocket preconnect. Keeping timeout behavior centralized
|
||||
//! turn execution and startup websocket prewarm. Keeping timeout behavior centralized
|
||||
//! ensures both call sites treat timeout as the same best-effort fallback condition.
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
@@ -23,7 +23,7 @@ pub(crate) const TURN_METADATA_HEADER_TIMEOUT: Duration = Duration::from_millis(
|
||||
/// On timeout, this logs a warning and returns the provided fallback header.
|
||||
///
|
||||
/// Keeping this helper centralized avoids drift between turn-time metadata resolution and startup
|
||||
/// websocket preconnect, both of which need identical timeout semantics.
|
||||
/// websocket prewarm, both of which need identical timeout semantics.
|
||||
pub(crate) async fn resolve_turn_metadata_header_with_timeout<F>(
|
||||
build_header: F,
|
||||
fallback_on_timeout: Option<String>,
|
||||
|
||||
@@ -35,7 +35,6 @@ use core_test_support::responses::start_websocket_server_with_headers;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use futures::FutureExt;
|
||||
use futures::StreamExt;
|
||||
use opentelemetry_sdk::metrics::InMemoryMetricExporter;
|
||||
use pretty_assertions::assert_eq;
|
||||
@@ -103,11 +102,11 @@ async fn responses_websocket_preconnect_reuses_connection() {
|
||||
.await;
|
||||
|
||||
let harness = websocket_harness(&server).await;
|
||||
harness
|
||||
.client
|
||||
.pre_establish_connection(harness.otel_manager.clone(), async { None }.boxed());
|
||||
|
||||
let mut client_session = harness.client.new_session();
|
||||
client_session
|
||||
.prewarm_websocket(&harness.otel_manager, None)
|
||||
.await
|
||||
.expect("websocket prewarm failed");
|
||||
let prompt = prompt_with_input(vec![message_item("hello")]);
|
||||
stream_until_complete(&mut client_session, &harness, &prompt).await;
|
||||
|
||||
@@ -128,11 +127,11 @@ async fn responses_websocket_preconnect_is_reused_even_with_header_changes() {
|
||||
.await;
|
||||
|
||||
let harness = websocket_harness(&server).await;
|
||||
harness
|
||||
.client
|
||||
.pre_establish_connection(harness.otel_manager.clone(), async { None }.boxed());
|
||||
|
||||
let mut client_session = harness.client.new_session();
|
||||
client_session
|
||||
.prewarm_websocket(&harness.otel_manager, None)
|
||||
.await
|
||||
.expect("websocket prewarm failed");
|
||||
let prompt = prompt_with_input(vec![message_item("hello")]);
|
||||
let mut stream = client_session
|
||||
.stream(
|
||||
|
||||
Reference in New Issue
Block a user