Compare commits

...

1 Commits

Author SHA1 Message Date
pakrym-oai
6ce5930139 Send deferred request 2026-02-09 14:12:50 -08:00
7 changed files with 328 additions and 157 deletions

View File

@@ -174,6 +174,8 @@ pub struct ResponseCreateWsRequest {
pub instructions: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub previous_response_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub defer: Option<bool>,
pub input: Vec<ResponseItem>,
pub tools: Vec<Value>,
pub tool_choice: String,

View File

@@ -127,6 +127,24 @@ impl ResponsesWebsocketConnection {
Ok(ResponseStream { rx_event })
}
/// Sends `request` through `stream_request` and drains the resulting event
/// stream until a `ResponseEvent::Completed` event is observed.
///
/// All other successful events are discarded.
///
/// # Errors
///
/// Returns the error from `stream_request` if the request cannot be started,
/// the first error event yielded by the stream, or an [`ApiError::Stream`] if
/// the stream ends before a completed event arrives.
pub async fn prewarm_deferred_create(
    &self,
    request: ResponsesWsRequest,
) -> Result<(), ApiError> {
    let mut events = self.stream_request(request).await?;
    loop {
        match events.next().await {
            // The response finished; nothing else needs to be read.
            Some(Ok(ResponseEvent::Completed { .. })) => return Ok(()),
            // Intermediate events carry nothing this method needs.
            Some(Ok(_)) => {}
            Some(Err(err)) => return Err(err),
            // The stream ended without ever reporting completion.
            None => {
                return Err(ApiError::Stream(
                    "stream closed before response.completed".to_string(),
                ));
            }
        }
    }
}
}
pub struct ResponsesWebsocketClient<A: AuthProvider> {

View File

@@ -12,19 +12,9 @@
//! requests during that turn. It caches a Responses WebSocket connection (opened lazily) and stores
//! per-turn state such as the `x-codex-turn-state` token used for sticky routing.
//!
//! Prewarm is intentionally handshake-only: it may warm a socket and capture sticky-routing
//! state, but the first `response.create` payload is still sent only when a turn starts.
//!
//! Startup prewarm is owned by turn-scoped callers (for example, a pre-created regular task). When
//! a warmed [`ModelClientSession`] is available, turn execution can reuse it; otherwise the turn
//! lazily opens a websocket on first stream call.
//!
//! ## Retry-Budget Tradeoff
//!
//! Startup prewarm is treated as the first websocket connection attempt for the first turn. If
//! it fails, the stream attempt fails and the retry/fallback loop decides whether to retry or fall
//! back. This avoids duplicate handshakes but means a failed prewarm can consume one retry
//! budget slot before any turn payload is sent.
//! Turn-scoped callers can optionally prewarm the websocket request state by sending a deferred
//! `response.create` (with empty input). When prewarmed this way, the first model request sends
//! `response.append` with the actual input items.
use std::sync::Arc;
use std::sync::OnceLock;
@@ -167,7 +157,10 @@ pub struct ModelClient {
pub struct ModelClientSession {
client: ModelClient,
connection: Option<ApiWebSocketConnection>,
websocket_last_items: Vec<ResponseItem>,
/// `None` means there is no previous websocket request in this turn.
/// `Some` means there was a previous request and the next request may append
/// if its inputs have this vector as a prefix.
websocket_last_items: Option<Vec<ResponseItem>>,
websocket_last_response_id: Option<String>,
websocket_last_response_id_rx: Option<oneshot::Receiver<String>>,
/// Turn state for sticky routing.
@@ -231,7 +224,7 @@ impl ModelClient {
ModelClientSession {
client: self.clone(),
connection: None,
websocket_last_items: Vec::new(),
websocket_last_items: None,
websocket_last_response_id: None,
websocket_last_response_id_rx: None,
turn_state: Arc::new(OnceLock::new()),
@@ -525,10 +518,10 @@ impl ModelClientSession {
// Checks whether the current request input is an incremental append to the previous request.
// If the new request's items start with every item from the previous request and add at least
// one more, we build a `response.append` request; otherwise we start with a fresh
// `response.create` request.
let previous_len = self.websocket_last_items.len();
let can_append = previous_len > 0
&& input_items.starts_with(&self.websocket_last_items)
&& previous_len < input_items.len();
let previous_items = self.websocket_last_items.as_ref()?;
let previous_len = previous_items.len();
let can_append =
input_items.starts_with(previous_items) && previous_len < input_items.len();
if can_append {
Some(input_items[previous_len..].to_vec())
} else {
@@ -566,6 +559,7 @@ impl ModelClientSession {
options: &ApiResponsesOptions,
input: Vec<ResponseItem>,
previous_response_id: Option<String>,
defer: Option<bool>,
) -> ResponsesWsRequest {
let ApiResponsesOptions {
reasoning,
@@ -581,6 +575,7 @@ impl ModelClientSession {
model: model_slug.to_string(),
instructions: api_prompt.instructions.clone(),
previous_response_id,
defer,
input,
tools: api_prompt.tools.clone(),
tool_choice: "auto".to_string(),
@@ -605,6 +600,16 @@ impl ModelClientSession {
let responses_websockets_v2_enabled = self.client.responses_websockets_v2_enabled();
let incremental_items = self.get_incremental_items(&api_prompt.input);
if let Some(append_items) = incremental_items {
if self
.websocket_last_items
.as_ref()
.is_some_and(Vec::is_empty)
{
return ResponsesWsRequest::ResponseAppend(ResponseAppendWsRequest {
input: append_items,
});
}
if responses_websockets_v2_enabled
&& let Some(previous_response_id) = self.websocket_previous_response_id()
{
@@ -614,6 +619,7 @@ impl ModelClientSession {
options,
append_items,
Some(previous_response_id),
None,
);
}
@@ -630,42 +636,136 @@ impl ModelClientSession {
options,
api_prompt.input.clone(),
None,
None,
)
}
/// Sends a deferred `response.create` request with empty input over the
/// session's existing websocket connection and waits for it to complete.
///
/// The request is built from `model_slug`, `api_prompt` and `options` with no
/// input items, no previous response id, and `defer` set to `Some(true)`. On
/// success the session records an empty item list as its previous websocket
/// request.
///
/// # Errors
///
/// Returns [`ApiError::Stream`] if the session has no websocket connection,
/// and propagates any error returned by `prewarm_deferred_create`.
async fn prewarm_websocket_deferred_response(
    &mut self,
    model_slug: &str,
    api_prompt: &ApiPrompt,
    options: &ApiResponsesOptions,
) -> std::result::Result<(), ApiError> {
    let request = self.prepare_websocket_create_request(
        model_slug,
        api_prompt,
        options,
        Vec::new(),
        None,
        Some(true),
    );
    self.connection
        .as_ref()
        // Build the error lazily so the message is only allocated when the
        // connection is actually missing.
        .ok_or_else(|| ApiError::Stream("websocket connection is unavailable".to_string()))?
        .prewarm_deferred_create(request)
        .await?;
    // `Some(empty)` (as opposed to `None`) records that a request was already
    // sent on this connection, with no input items.
    self.websocket_last_items = Some(Vec::new());
    Ok(())
}
/// Prewarms the turn-scoped websocket request state by sending a deferred
/// `response.create` (empty input), so the next model request can use
/// `response.append` with actual items.
pub async fn prewarm_websocket(
&mut self,
prompt: &Prompt,
model_info: &ModelInfo,
effort: Option<ReasoningEffortConfig>,
summary: ReasoningSummaryConfig,
otel_manager: &OtelManager,
turn_metadata_header: Option<&str>,
) -> std::result::Result<(), ApiError> {
if !self.client.responses_websocket_enabled() || self.client.disable_websockets() {
return Ok(());
}
if self.connection.is_some() {
if self.websocket_last_items.is_some() {
return Ok(());
}
let client_setup = self.client.current_client_setup().await.map_err(|err| {
ApiError::Stream(format!(
"failed to build websocket prewarm client setup: {err}"
))
let auth_manager = self.client.state.auth_manager.clone();
let api_prompt = Self::build_responses_request(prompt).map_err(|err| {
ApiError::Stream(format!("failed to build websocket prewarm prompt: {err}"))
})?;
let mut auth_recovery = auth_manager
.as_ref()
.map(super::auth::AuthManager::unauthorized_recovery);
let connection = self
.client
.connect_websocket(
otel_manager,
client_setup.api_provider,
client_setup.api_auth,
Some(Arc::clone(&self.turn_state)),
loop {
let client_setup = self.client.current_client_setup().await.map_err(|err| {
ApiError::Stream(format!(
"failed to build websocket prewarm client setup: {err}"
))
})?;
let compression = self.responses_request_compression(client_setup.auth.as_ref());
let options = self.build_responses_options(
prompt,
model_info,
effort,
summary,
turn_metadata_header,
)
.await?;
self.connection = Some(connection);
Ok(())
compression,
);
match self
.websocket_connection(
otel_manager,
client_setup.api_provider,
client_setup.api_auth,
turn_metadata_header,
&options,
)
.await
{
Ok(_) => {}
Err(ApiError::Transport(TransportError::Http { status, .. }))
if status == StatusCode::UPGRADE_REQUIRED =>
{
self.try_switch_fallback_transport(otel_manager);
return Ok(());
}
Err(ApiError::Transport(
unauthorized_transport @ TransportError::Http { status, .. },
)) if status == StatusCode::UNAUTHORIZED => {
handle_unauthorized(unauthorized_transport, &mut auth_recovery)
.await
.map_err(|err| {
ApiError::Stream(format!(
"websocket prewarm auth recovery failed: {err}"
))
})?;
continue;
}
Err(err) => return Err(err),
}
match self
.prewarm_websocket_deferred_response(&model_info.slug, &api_prompt, &options)
.await
{
Ok(()) => return Ok(()),
Err(ApiError::Transport(TransportError::Http { status, .. }))
if status == StatusCode::UPGRADE_REQUIRED =>
{
self.try_switch_fallback_transport(otel_manager);
return Ok(());
}
Err(ApiError::Transport(
unauthorized_transport @ TransportError::Http { status, .. },
)) if status == StatusCode::UNAUTHORIZED => {
handle_unauthorized(unauthorized_transport, &mut auth_recovery)
.await
.map_err(|err| {
ApiError::Stream(format!(
"websocket prewarm auth recovery failed: {err}"
))
})?;
continue;
}
Err(err) => return Err(err),
}
}
}
/// Returns a websocket connection for this turn.
@@ -683,7 +783,7 @@ impl ModelClientSession {
};
if needs_new {
self.websocket_last_items.clear();
self.websocket_last_items = None;
self.websocket_last_response_id = None;
self.websocket_last_response_id_rx = None;
let turn_state = options
@@ -858,7 +958,7 @@ impl ModelClientSession {
.stream_request(request)
.await
.map_err(map_api_error)?;
self.websocket_last_items = api_prompt.input.clone();
self.websocket_last_items = Some(api_prompt.input.clone());
let (last_response_id_sender, last_response_id_receiver) = oneshot::channel();
self.websocket_last_response_id_rx = Some(last_response_id_receiver);
let mut last_response_id_sender = Some(last_response_id_sender);
@@ -967,7 +1067,7 @@ impl ModelClientSession {
);
self.connection = None;
self.websocket_last_items.clear();
self.websocket_last_items = None;
}
activated
}

View File

@@ -14,6 +14,7 @@ use crate::agent::MAX_THREAD_SPAWN_DEPTH;
use crate::agent::agent_status_from_event;
use crate::analytics_client::AnalyticsEventsClient;
use crate::analytics_client::build_track_events_context;
use crate::api_bridge::map_api_error;
use crate::compact;
use crate::compact::run_inline_auto_compact_task;
use crate::compact::should_use_remote_compact_task;
@@ -1083,17 +1084,7 @@ impl Session {
),
};
let turn_metadata_header = resolve_turn_metadata_header_with_timeout(
build_turn_metadata_header(session_configuration.cwd.clone()),
None,
)
.boxed();
let startup_regular_task = RegularTask::with_startup_prewarm(
services.model_client.clone(),
services.otel_manager.clone(),
turn_metadata_header,
);
state.set_startup_regular_task(startup_regular_task);
state.set_startup_regular_task(RegularTask);
let sess = Arc::new(Session {
conversation_id,
@@ -4150,6 +4141,69 @@ async fn run_sampling_request(
let mut retries = 0;
loop {
if let Err(err) = client_session
.prewarm_websocket(
&prompt,
&turn_context.model_info,
turn_context.reasoning_effort,
turn_context.reasoning_summary,
&turn_context.otel_manager,
turn_metadata_header,
)
.await
{
let err = map_api_error(err);
match err {
CodexErr::ContextWindowExceeded => {
sess.set_total_tokens_full(&turn_context).await;
return Err(CodexErr::ContextWindowExceeded);
}
CodexErr::UsageLimitReached(e) => {
if let Some(rate_limits) = e.rate_limits.clone() {
sess.update_rate_limits(&turn_context, rate_limits).await;
}
return Err(CodexErr::UsageLimitReached(e));
}
_ => {}
}
if !err.is_retryable() {
return Err(err);
}
let max_retries = turn_context.provider.stream_max_retries();
if retries >= max_retries
&& client_session.try_switch_fallback_transport(&turn_context.otel_manager)
{
sess.send_event(
&turn_context,
EventMsg::Warning(WarningEvent {
message: format!(
"Falling back from WebSockets to HTTPS transport. {err:#}"
),
}),
)
.await;
retries = 0;
} else if retries < max_retries {
retries += 1;
let delay = backoff(retries);
warn!(
"websocket prewarm failed - retrying sampling request ({retries}/{max_retries} in {delay:?})...",
);
sess.notify_stream_error(
&turn_context,
format!("Reconnecting... {retries}/{max_retries}"),
err,
)
.await;
tokio::time::sleep(delay).await;
} else {
return Err(err);
}
continue;
}
let err = match try_run_sampling_request(
Arc::clone(&router),
Arc::clone(&sess),

View File

@@ -1,80 +1,22 @@
use std::sync::Arc;
use std::sync::Mutex;
use crate::client::ModelClient;
use crate::client::ModelClientSession;
use crate::codex::TurnContext;
use crate::codex::run_turn;
use crate::state::TaskKind;
use async_trait::async_trait;
use codex_otel::OtelManager;
use codex_protocol::user_input::UserInput;
use futures::future::BoxFuture;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::Instrument;
use tracing::trace_span;
use tracing::warn;
use super::SessionTask;
use super::SessionTaskContext;
type PrewarmedSessionTask = JoinHandle<Option<ModelClientSession>>;
pub(crate) struct RegularTask {
prewarmed_session_task: Mutex<Option<PrewarmedSessionTask>>,
}
pub(crate) struct RegularTask;
impl Default for RegularTask {
fn default() -> Self {
Self {
prewarmed_session_task: Mutex::new(None),
}
}
}
impl RegularTask {
pub(crate) fn with_startup_prewarm(
model_client: ModelClient,
otel_manager: OtelManager,
turn_metadata_header: BoxFuture<'static, Option<String>>,
) -> Self {
let prewarmed_session_task = tokio::spawn(async move {
let mut client_session = model_client.new_session();
let turn_metadata_header = turn_metadata_header.await;
match client_session
.prewarm_websocket(&otel_manager, turn_metadata_header.as_deref())
.await
{
Ok(()) => Some(client_session),
Err(err) => {
warn!("startup websocket prewarm task failed: {err}");
None
}
}
});
Self {
prewarmed_session_task: Mutex::new(Some(prewarmed_session_task)),
}
}
async fn take_prewarmed_session(&self) -> Option<ModelClientSession> {
let prewarmed_session_task = self
.prewarmed_session_task
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.take();
match prewarmed_session_task {
Some(task) => match task.await {
Ok(client_session) => client_session,
Err(err) => {
warn!("startup websocket prewarm task join failed: {err}");
None
}
},
None => None,
}
Self
}
}
@@ -97,15 +39,8 @@ impl SessionTask for RegularTask {
sess.services
.otel_manager
.apply_traceparent_parent(&run_turn_span);
let prewarmed_client_session = self.take_prewarmed_session().await;
run_turn(
sess,
ctx,
input,
prewarmed_client_session,
cancellation_token,
)
.instrument(run_turn_span)
.await
run_turn(sess, ctx, input, None, cancellation_token)
.instrument(run_turn_span)
.await
}
}

View File

@@ -22,6 +22,10 @@ async fn websocket_test_codex_shell_chain() -> Result<()> {
let call_id = "shell-command-call";
let server = start_websocket_server(vec![vec![
vec![
ev_response_created("resp-warm-1"),
ev_completed("resp-warm-1"),
],
vec![
ev_response_created("resp-1"),
ev_shell_command_call(call_id, "echo websocket"),
@@ -41,7 +45,7 @@ async fn websocket_test_codex_shell_chain() -> Result<()> {
test.submit_turn("run the echo command").await?;
let connection = server.single_connection();
assert_eq!(connection.len(), 2);
assert_eq!(connection.len(), 3);
let first = connection
.first()
@@ -51,11 +55,18 @@ async fn websocket_test_codex_shell_chain() -> Result<()> {
.get(1)
.expect("missing second request")
.body_json();
let third = connection
.get(2)
.expect("missing third request")
.body_json();
assert_eq!(first["type"].as_str(), Some("response.create"));
assert_eq!(first["defer"], Value::Bool(true));
assert_eq!(first["input"], Value::Array(vec![]));
assert_eq!(second["type"].as_str(), Some("response.append"));
assert_eq!(third["type"].as_str(), Some("response.append"));
let append_items = second
let append_items = third
.get("input")
.and_then(Value::as_array)
.expect("response.append input array");
@@ -75,40 +86,44 @@ async fn websocket_test_codex_shell_chain() -> Result<()> {
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_preconnect_happens_on_session_start() -> Result<()> {
async fn websocket_prewarm_happens_on_first_turn() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = start_websocket_server(vec![vec![vec![
ev_response_created("resp-1"),
ev_completed("resp-1"),
]]])
let server = start_websocket_server(vec![vec![
vec![
ev_response_created("resp-warm-1"),
ev_completed("resp-warm-1"),
],
vec![ev_response_created("resp-1"), ev_completed("resp-1")],
]])
.await;
let mut builder = test_codex();
let test = builder.build_with_websocket_server(&server).await?;
assert!(
server.wait_for_handshakes(1, Duration::from_secs(2)).await,
"expected websocket preconnect handshake during session startup"
);
test.submit_turn("hello").await?;
assert_eq!(server.handshakes().len(), 1);
assert_eq!(server.single_connection().len(), 1);
assert_eq!(server.single_connection().len(), 2);
server.shutdown().await;
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_first_turn_waits_for_inflight_preconnect() -> Result<()> {
async fn websocket_first_turn_waits_for_inflight_connect() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = start_websocket_server_with_headers(vec![WebSocketConnectionConfig {
requests: vec![vec![ev_response_created("resp-1"), ev_completed("resp-1")]],
requests: vec![
vec![
ev_response_created("resp-warm-1"),
ev_completed("resp-warm-1"),
],
vec![ev_response_created("resp-1"), ev_completed("resp-1")],
],
response_headers: Vec::new(),
// Delay handshake so submit_turn() observes startup preconnect as in-flight.
// Delay handshake so submit_turn() observes websocket connect as in-flight.
accept_delay: Some(Duration::from_millis(150)),
}])
.await;
@@ -118,7 +133,7 @@ async fn websocket_first_turn_waits_for_inflight_preconnect() -> Result<()> {
test.submit_turn("hello").await?;
assert_eq!(server.handshakes().len(), 1);
assert_eq!(server.single_connection().len(), 1);
assert_eq!(server.single_connection().len(), 2);
server.shutdown().await;
Ok(())
@@ -130,6 +145,10 @@ async fn websocket_v2_test_codex_shell_chain() -> Result<()> {
let call_id = "shell-command-call";
let server = start_websocket_server(vec![vec![
vec![
ev_response_created("resp-warm-1"),
ev_completed("resp-warm-1"),
],
vec![
ev_response_created("resp-1"),
ev_shell_command_call(call_id, "echo websocket"),
@@ -151,7 +170,7 @@ async fn websocket_v2_test_codex_shell_chain() -> Result<()> {
test.submit_turn("run the echo command").await?;
let connection = server.single_connection();
assert_eq!(connection.len(), 2);
assert_eq!(connection.len(), 3);
let first = connection
.first()
@@ -161,12 +180,19 @@ async fn websocket_v2_test_codex_shell_chain() -> Result<()> {
.get(1)
.expect("missing second request")
.body_json();
let third = connection
.get(2)
.expect("missing third request")
.body_json();
assert_eq!(first["type"].as_str(), Some("response.create"));
assert_eq!(second["type"].as_str(), Some("response.create"));
assert_eq!(second["previous_response_id"].as_str(), Some("resp-1"));
assert_eq!(first["defer"], Value::Bool(true));
assert_eq!(first["input"], Value::Array(vec![]));
assert_eq!(second["type"].as_str(), Some("response.append"));
assert_eq!(third["type"].as_str(), Some("response.create"));
assert_eq!(third["previous_response_id"].as_str(), Some("resp-1"));
let create_items = second
let create_items = third
.get("input")
.and_then(Value::as_array)
.expect("response.create input array");

View File

@@ -95,23 +95,41 @@ async fn responses_websocket_streams_request() {
async fn responses_websocket_preconnect_reuses_connection() {
skip_if_no_network!();
let server = start_websocket_server(vec![vec![vec![
ev_response_created("resp-1"),
ev_completed("resp-1"),
]]])
let server = start_websocket_server(vec![vec![
vec![
ev_response_created("resp-warm-1"),
ev_completed("resp-warm-1"),
],
vec![ev_response_created("resp-1"), ev_completed("resp-1")],
]])
.await;
let harness = websocket_harness(&server).await;
let mut client_session = harness.client.new_session();
let prompt = prompt_with_input(vec![message_item("hello")]);
client_session
.prewarm_websocket(&harness.otel_manager, None)
.prewarm_websocket(
&prompt,
&harness.model_info,
harness.effort,
harness.summary,
&harness.otel_manager,
None,
)
.await
.expect("websocket prewarm failed");
let prompt = prompt_with_input(vec![message_item("hello")]);
stream_until_complete(&mut client_session, &harness, &prompt).await;
assert_eq!(server.handshakes().len(), 1);
assert_eq!(server.single_connection().len(), 1);
let connection = server.single_connection();
assert_eq!(connection.len(), 2);
let first = connection.first().expect("missing request").body_json();
let second = connection.get(1).expect("missing request").body_json();
assert_eq!(first["type"].as_str(), Some("response.create"));
assert_eq!(first["defer"], json!(true));
assert_eq!(first["input"], json!([]));
assert_eq!(second["type"].as_str(), Some("response.append"));
assert_eq!(second["input"], serde_json::to_value(prompt.input).unwrap());
server.shutdown().await;
}
@@ -120,19 +138,29 @@ async fn responses_websocket_preconnect_reuses_connection() {
async fn responses_websocket_preconnect_is_reused_even_with_header_changes() {
skip_if_no_network!();
let server = start_websocket_server(vec![vec![vec![
ev_response_created("resp-1"),
ev_completed("resp-1"),
]]])
let server = start_websocket_server(vec![vec![
vec![
ev_response_created("resp-warm-1"),
ev_completed("resp-warm-1"),
],
vec![ev_response_created("resp-1"), ev_completed("resp-1")],
]])
.await;
let harness = websocket_harness(&server).await;
let mut client_session = harness.client.new_session();
let prompt = prompt_with_input(vec![message_item("hello")]);
client_session
.prewarm_websocket(&harness.otel_manager, None)
.prewarm_websocket(
&prompt,
&harness.model_info,
harness.effort,
harness.summary,
&harness.otel_manager,
None,
)
.await
.expect("websocket prewarm failed");
let prompt = prompt_with_input(vec![message_item("hello")]);
let mut stream = client_session
.stream(
&prompt,
@@ -152,7 +180,15 @@ async fn responses_websocket_preconnect_is_reused_even_with_header_changes() {
}
assert_eq!(server.handshakes().len(), 1);
assert_eq!(server.single_connection().len(), 1);
let connection = server.single_connection();
assert_eq!(connection.len(), 2);
let first = connection.first().expect("missing request").body_json();
let second = connection.get(1).expect("missing request").body_json();
assert_eq!(first["type"].as_str(), Some("response.create"));
assert_eq!(first["defer"], json!(true));
assert_eq!(first["input"], json!([]));
assert_eq!(second["type"].as_str(), Some("response.append"));
assert_eq!(second["input"], serde_json::to_value(prompt.input).unwrap());
server.shutdown().await;
}