Files
codex/codex-rs/core/src/client.rs
jif-oai 20fedafff8 Trace logical websocket request after untraced warmup (#23581)
## Why

`prewarm_websocket` intentionally stays out of rollout inference
tracing, but the next traced websocket request can still reuse the
warmup `response_id` and send an empty `input` delta. If tracing records
that wire payload verbatim, replay sees an incremental request whose
parent was never traced and cannot reconstruct the conversation.

This fixes that at the producer boundary instead of relaxing
`rollout-trace` replay semantics around unresolved
`previous_response_id` values.

## What

- track whether the last websocket response came from an untraced warmup
and clear that state when the websocket session is reset or reconnected
- when a traced websocket request reuses that warmup parent, keep
sending the compressed websocket request on the wire but record the
logical `ResponsesApiRequest` in the rollout trace
- add a regression test that proves replay reconstructs the logical user
message even though the websocket follow-up carries
`previous_response_id = warm-1` with empty `input`
- update `InferenceTraceAttempt::record_started` docs to reflect that
callers may record a logical request rather than the exact transport
payload

## Testing

- `cargo test -p codex-core --test all
responses_websocket_request_prewarm_traces_logical_request`
2026-05-21 11:13:23 +02:00

2246 lines
86 KiB
Rust

//! Session- and turn-scoped helpers for talking to model provider APIs.
//!
//! `ModelClient` is intended to live for the lifetime of a Codex session and holds the stable
//! configuration and state needed to talk to a provider (auth, provider selection, conversation id,
//! and transport fallback state).
//!
//! Per-turn settings (model selection, reasoning controls, telemetry context, and turn metadata)
//! are passed explicitly to streaming and unary methods so that the turn lifetime is visible at the
//! call site.
//!
//! A [`ModelClientSession`] is created per turn and is used to stream one or more Responses API
//! requests during that turn. It caches a Responses WebSocket connection (opened lazily) and stores
//! per-turn state such as the `x-codex-turn-state` token used for sticky routing.
//!
//! WebSocket prewarm is a v2-only `response.create` with `generate=false`; it waits for completion
//! so the next request can reuse the same connection and `previous_response_id`.
//!
//! Turn execution performs prewarm as a best-effort step before the first stream request so the
//! subsequent request can reuse the same connection.
//!
//! ## Retry-Budget Tradeoff
//!
//! WebSocket prewarm is treated as the first websocket connection attempt for a turn. If it
//! fails, normal stream retry/fallback logic handles recovery on the same turn.
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::OnceLock;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use codex_api::ApiError;
use codex_api::AuthProvider;
use codex_api::CompactClient as ApiCompactClient;
use codex_api::CompactionInput as ApiCompactionInput;
use codex_api::Compression;
use codex_api::MemoriesClient as ApiMemoriesClient;
use codex_api::MemorySummarizeInput as ApiMemorySummarizeInput;
use codex_api::MemorySummarizeOutput as ApiMemorySummarizeOutput;
use codex_api::Provider as ApiProvider;
use codex_api::RawMemory as ApiRawMemory;
use codex_api::RealtimeCallClient as ApiRealtimeCallClient;
use codex_api::RealtimeSessionConfig as ApiRealtimeSessionConfig;
use codex_api::Reasoning;
use codex_api::RequestTelemetry;
use codex_api::ReqwestTransport;
use codex_api::ResponseCreateWsRequest;
use codex_api::ResponsesApiRequest;
use codex_api::ResponsesClient as ApiResponsesClient;
use codex_api::ResponsesOptions as ApiResponsesOptions;
use codex_api::ResponsesWebsocketClient as ApiWebSocketResponsesClient;
use codex_api::ResponsesWebsocketConnection as ApiWebSocketConnection;
use codex_api::ResponsesWsRequest;
use codex_api::SharedAuthProvider;
use codex_api::SseTelemetry;
use codex_api::TransportError;
use codex_api::WebsocketTelemetry;
use codex_api::auth_header_telemetry;
use codex_api::build_session_headers;
use codex_api::create_text_param_for_request;
use codex_api::response_create_client_metadata;
use codex_app_server_protocol::AuthMode;
use codex_login::AuthManager;
use codex_login::CodexAuth;
use codex_login::RefreshTokenError;
use codex_login::UnauthorizedRecovery;
use codex_login::default_client::build_reqwest_client;
use codex_otel::SessionTelemetry;
use codex_otel::current_span_w3c_trace_context;
use codex_protocol::SessionId;
use codex_protocol::ThreadId;
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
use codex_protocol::config_types::Verbosity as VerbosityConfig;
use codex_protocol::models::ResponseItem;
use codex_protocol::openai_models::ModelInfo;
use codex_protocol::openai_models::ReasoningEffort as ReasoningEffortConfig;
use codex_protocol::protocol::InternalSessionSource;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use codex_protocol::protocol::W3cTraceContext;
use codex_rollout_trace::CompactionTraceContext;
use codex_rollout_trace::InferenceTraceAttempt;
use codex_rollout_trace::InferenceTraceContext;
use codex_tools::create_tools_json_for_responses_api;
use eventsource_stream::Event;
use eventsource_stream::EventStreamError;
use futures::StreamExt;
use http::HeaderMap as ApiHeaderMap;
use http::HeaderValue;
use http::StatusCode as HttpStatusCode;
use reqwest::StatusCode;
use std::time::Duration;
use std::time::Instant;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::oneshot::error::TryRecvError;
use tokio_tungstenite::tungstenite::Error;
use tokio_tungstenite::tungstenite::Message;
use tokio_util::sync::CancellationToken;
use tracing::debug;
use tracing::instrument;
use tracing::trace;
use tracing::warn;
use crate::attestation::AttestationContext;
use crate::attestation::AttestationProvider;
use crate::attestation::X_OAI_ATTESTATION_HEADER;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::client_common::ResponseStream;
use crate::feedback_tags;
use crate::util::emit_feedback_auth_recovery_tags;
use codex_api::map_api_error;
use codex_feedback::FeedbackRequestTags;
use codex_feedback::emit_feedback_request_tags_with_auth_env;
use codex_login::auth_env_telemetry::AuthEnvTelemetry;
use codex_login::auth_env_telemetry::collect_auth_env_telemetry;
use codex_model_provider::SharedModelProvider;
use codex_model_provider::create_model_provider;
#[cfg(test)]
use codex_model_provider_info::DEFAULT_WEBSOCKET_CONNECT_TIMEOUT_MS;
use codex_model_provider_info::ModelProviderInfo;
use codex_model_provider_info::WireApi;
use codex_protocol::error::CodexErr;
use codex_protocol::error::Result;
use codex_response_debug_context::extract_response_debug_context;
use codex_response_debug_context::extract_response_debug_context_from_api_error;
use codex_response_debug_context::telemetry_api_error_message;
use codex_response_debug_context::telemetry_transport_error_message;
/// Header name used to opt in to beta API behavior; the websocket handshake sets it to
/// `RESPONSES_WEBSOCKETS_V2_BETA_HEADER_VALUE`.
pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
/// Header / client-metadata key carrying the installation id of this client.
pub const X_CODEX_INSTALLATION_ID_HEADER: &str = "x-codex-installation-id";
/// Header name for the sticky-routing turn-state token exchanged with the server.
pub const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
/// Header / client-metadata key carrying per-turn metadata.
pub const X_CODEX_TURN_METADATA_HEADER: &str = "x-codex-turn-metadata";
/// Header / client-metadata key carrying the parent thread id, when the session source yields one.
pub const X_CODEX_PARENT_THREAD_ID_HEADER: &str = "x-codex-parent-thread-id";
/// Header / client-metadata key carrying the window id (`{thread_id}:{window_generation}`).
pub const X_CODEX_WINDOW_ID_HEADER: &str = "x-codex-window-id";
/// Header set to `"true"` when the session source is the internal memory-consolidation flow.
pub const X_OPENAI_MEMGEN_REQUEST_HEADER: &str = "x-openai-memgen-request";
/// Header / client-metadata key carrying the subagent label, when the session source yields one.
pub const X_OPENAI_SUBAGENT_HEADER: &str = "x-openai-subagent";
/// Header set to `"true"` on the websocket handshake when timing metrics are enabled.
pub const X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER: &str =
"x-responsesapi-include-timing-metrics";
// NOTE(review): presumably a client-metadata key holding the websocket stream request start
// time in milliseconds; its use site is not nearby — confirm.
const X_CODEX_WS_STREAM_REQUEST_START_MS_CLIENT_METADATA_KEY: &str =
"x-codex-ws-stream-request-start-ms";
/// Value sent in the `OpenAI-Beta` header during the websocket handshake.
const RESPONSES_WEBSOCKETS_V2_BETA_HEADER_VALUE: &str = "responses_websockets=2026-02-06";
/// Route label for the Responses endpoint.
const RESPONSES_ENDPOINT: &str = "/responses";
/// Route label reported in telemetry for compaction requests.
const RESPONSES_COMPACT_ENDPOINT: &str = "/responses/compact";
/// Factor applied to the provider's stream idle timeout to get the compaction request timeout.
// `/responses/compact` is unary, so the timeout covers the full response rather than one idle
// period between stream events.
const COMPACT_REQUEST_TIMEOUT_IDLE_MULTIPLIER: u32 = 4;
/// Route label reported in telemetry for memory-summarization requests.
const MEMORIES_SUMMARIZE_ENDPOINT: &str = "/memories/trace_summarize";
/// Test-only websocket connect timeout built from the provider-info default.
#[cfg(test)]
pub(crate) const WEBSOCKET_CONNECT_TIMEOUT: Duration =
Duration::from_millis(DEFAULT_WEBSOCKET_CONNECT_TIMEOUT_MS);
/// Reasoning and service-tier settings for a single compaction request.
///
/// `ModelClient::compact_conversation_history` takes this by value and forwards each field to
/// `build_responses_request`.
pub(crate) struct CompactConversationRequestSettings {
/// Requested reasoning effort, if any.
pub(crate) effort: Option<ReasoningEffortConfig>,
/// Reasoning summary mode for the request.
pub(crate) summary: ReasoningSummaryConfig,
/// Requested service tier, if any.
pub(crate) service_tier: Option<String>,
}
/// Session-scoped state shared by all [`ModelClient`] clones.
///
/// This is intentionally kept minimal so `ModelClient` does not need to hold a full `Config`. Most
/// configuration is per turn and is passed explicitly to streaming/unary methods.
#[derive(Debug)]
struct ModelClientState {
// Sent to the provider through the session headers.
session_id: SessionId,
// Used for session headers, the window id, the prompt cache key, and attestation context.
thread_id: ThreadId,
// Combined with `thread_id` to form the window id; changing it resets the cached websocket
// session.
window_generation: AtomicU64,
// Sent as the installation-id header / client-metadata entry.
installation_id: String,
// Resolves auth and provider configuration for each request attempt.
provider: SharedModelProvider,
// Collected once at construction and attached to request telemetry.
auth_env_telemetry: AuthEnvTelemetry,
// Drives the subagent, parent-thread-id, and memgen identity headers.
session_source: SessionSource,
// Session verbosity preference; only applied when the model supports verbosity.
model_verbosity: Option<VerbosityConfig>,
// NOTE(review): presumably toggles request-body compression; use site not nearby — confirm.
enable_request_compression: bool,
// When true, the websocket handshake asks the server for timing metrics.
include_timing_metrics: bool,
// Optional beta-features header value forwarded to `build_responses_headers`.
beta_features_header: Option<String>,
// Whether the provider supports attestation; gates the attestation header.
include_attestation: bool,
// Source of attestation header values, when one is configured.
attestation_provider: Option<Arc<dyn AttestationProvider>>,
// Set once the session has fallen back to HTTP; checked by `responses_websocket_enabled`.
disable_websockets: AtomicBool,
// Websocket session parked between turns: taken by `new_session`, stored back on drop.
cached_websocket_session: StdMutex<WebsocketSession>,
}
/// Resolved API client setup for a single request attempt.
///
/// Keeping this as a single bundle ensures prewarm and normal request paths
/// share the same auth/provider setup flow.
struct CurrentClientSetup {
// Current auth from the model provider, if any; used here to report the auth mode in telemetry.
auth: Option<CodexAuth>,
// Provider configuration handed to the API clients.
api_provider: ApiProvider,
// Auth provider handed to the API clients and used to add auth headers.
api_auth: SharedAuthProvider,
}
/// Identifies the API route a request targets so telemetry can label it.
#[derive(Clone, Copy)]
struct RequestRouteTelemetry {
    /// Static route label, e.g. `"/responses/compact"`.
    endpoint: &'static str,
}

impl RequestRouteTelemetry {
    /// Wraps a static endpoint label for use in request telemetry.
    fn for_endpoint(endpoint: &'static str) -> Self {
        RequestRouteTelemetry { endpoint }
    }
}
/// A session-scoped client for model-provider API calls.
///
/// This holds configuration and state that should be shared across turns within a Codex session
/// (auth, provider selection, thread id, and transport fallback state).
///
/// WebSocket fallback is session-scoped: once a turn activates the HTTP fallback, subsequent turns
/// will also use HTTP for the remainder of the session.
///
/// Turn-scoped settings (model selection, reasoning controls, telemetry context, and turn
/// metadata) are passed explicitly to the relevant methods to keep turn lifetime visible at the
/// call site.
#[derive(Debug, Clone)]
pub struct ModelClient {
// Shared session state; cloning the client clones the `Arc`, not the state.
state: Arc<ModelClientState>,
}
/// A turn-scoped streaming session created from a [`ModelClient`].
///
/// The session establishes a Responses WebSocket connection lazily and reuses it across multiple
/// requests within the turn. It also caches per-turn state:
///
/// - The last full request, so subsequent calls can reuse incremental websocket request payloads
/// only when the current request is an incremental extension of the previous one.
/// - The `x-codex-turn-state` sticky-routing token, which must be replayed for all requests within
/// the same turn.
///
/// Create a fresh `ModelClientSession` for each Codex turn. Reusing it across turns would replay
/// the previous turn's sticky-routing token into the next turn, which violates the client/server
/// contract and can cause routing bugs.
pub struct ModelClientSession {
/// The session-scoped client this turn session was created from.
client: ModelClient,
/// Websocket state taken from the client's cache at creation and handed back to that cache
/// when this session is dropped.
websocket_session: WebsocketSession,
/// Turn state for sticky routing.
///
/// This is an `OnceLock` that stores the turn state value received from the server
/// on turn start via the `x-codex-turn-state` response header. Once set, this value
/// should be sent back to the server in the `x-codex-turn-state` request header for
/// all subsequent requests within the same turn to maintain sticky routing.
///
/// This is a contract between the client and server: we receive it at turn start,
/// keep sending it unchanged between turn requests (e.g., for retries, incremental
/// appends, or continuation requests), and must not send it between different turns.
turn_state: Arc<OnceLock<String>>,
}
/// Summary of the most recent websocket response, delivered through
/// `WebsocketSession::last_response_rx`.
#[derive(Debug, Clone)]
struct LastResponse {
// Id of the response; NOTE(review): presumably what a follow-up request sends as
// `previous_response_id` — confirm against the request-building code.
response_id: String,
// NOTE(review): presumably the items that response added to the conversation — confirm
// against the code that produces this value.
items_added: Vec<ResponseItem>,
}
/// Websocket state carried by a [`ModelClientSession`] and parked in the client between turns.
///
/// The `Default` value represents "no connection and no history"; it is what the cache is reset
/// to when the window generation changes or HTTP fallback is forced.
#[derive(Debug, Default)]
struct WebsocketSession {
// The open websocket connection, if one has been established.
connection: Option<ApiWebSocketConnection>,
// The last full request sent, kept so a later request can be sent incrementally.
last_request: Option<ResponsesApiRequest>,
// Receiver for the `LastResponse` of the most recent request, if one is pending or available.
last_response_rx: Option<oneshot::Receiver<LastResponse>>,
// True when the last response came from an untraced prewarm request; cleared on reset.
last_response_from_untraced_warmup: bool,
// Whether the current connection is a reused one; accessed through the accessor methods so it
// can be updated through `&self`.
connection_reused: StdMutex<bool>,
}
impl WebsocketSession {
    /// Records whether the current websocket connection is a reused one.
    ///
    /// A poisoned lock is recovered rather than propagated: the flag is a plain `bool`, so there
    /// is no partially-updated state to protect.
    fn set_connection_reused(&self, connection_reused: bool) {
        let mut flag = self
            .connection_reused
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        *flag = connection_reused;
    }

    /// Returns the most recently recorded connection-reuse flag.
    fn connection_reused(&self) -> bool {
        let flag = self
            .connection_reused
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        *flag
    }
}
/// Outcome of attempting to stream a request over the websocket transport.
// NOTE(review): the producer and consumer of this enum are not nearby; variant descriptions
// below follow the variant names and payloads — confirm against the streaming code.
enum WebsocketStreamOutcome {
// A response stream was obtained.
Stream(ResponseStream),
// No stream was produced; the caller is expected to use the HTTP transport instead.
FallbackToHttp,
}
/// Result of opening a WebRTC Realtime call.
///
/// The SDP answer goes back to the client. The call id and auth headers stay on the server so the
/// ordinary Realtime WebSocket machinery can join the same in-progress call as a sideband
/// controller.
pub(crate) struct RealtimeWebrtcCallStart {
/// SDP returned by the call-create response.
pub(crate) sdp: String,
/// Call id returned by the call-create response.
pub(crate) call_id: String,
/// Headers for the sideband join: the caller's extra headers (plus the attestation header when
/// one was generated) extended with the auth headers used to create the call.
pub(crate) sideband_headers: ApiHeaderMap,
}
/// Produces the auth headers for the sideband WebSocket join from the same API-auth material that
/// created the WebRTC call.
///
/// API-key sessions send that API bearer. ChatGPT-auth sessions send their bearer plus account id;
/// transceiver is responsible for accepting that same call-create identity on the direct
/// `api.openai.com` sideband path.
fn sideband_websocket_auth_headers(api_auth: &dyn AuthProvider) -> ApiHeaderMap {
    let mut auth_headers = ApiHeaderMap::new();
    // The auth provider decides which headers to attach; start from an empty map so only auth
    // material is returned.
    api_auth.add_auth_headers(&mut auth_headers);
    auth_headers
}
impl ModelClient {
/// Creates a new session-scoped `ModelClient`.
///
/// All arguments are expected to be stable for the lifetime of a Codex session. Per-turn values
/// are passed to [`ModelClientSession::stream`] (and other turn-scoped methods) explicitly.
///
/// The model provider is created from `provider_info` and `auth_manager`; auth-environment
/// telemetry and attestation support are derived from it once, here. The window generation
/// starts at 0, websockets start enabled, and the websocket cache starts empty.
#[allow(clippy::too_many_arguments)]
pub fn new(
    auth_manager: Option<Arc<AuthManager>>,
    session_id: SessionId,
    thread_id: ThreadId,
    installation_id: String,
    provider_info: ModelProviderInfo,
    session_source: SessionSource,
    model_verbosity: Option<VerbosityConfig>,
    enable_request_compression: bool,
    include_timing_metrics: bool,
    beta_features_header: Option<String>,
    attestation_provider: Option<Arc<dyn AttestationProvider>>,
) -> Self {
    let provider = create_model_provider(provider_info, auth_manager);
    let codex_api_key_env_enabled = provider
        .auth_manager()
        .is_some_and(|manager| manager.codex_api_key_env_enabled());
    let auth_env_telemetry =
        collect_auth_env_telemetry(provider.info(), codex_api_key_env_enabled);
    let include_attestation = provider.supports_attestation();

    let state = ModelClientState {
        session_id,
        thread_id,
        window_generation: AtomicU64::new(0),
        installation_id,
        provider,
        auth_env_telemetry,
        session_source,
        model_verbosity,
        enable_request_compression,
        include_timing_metrics,
        beta_features_header,
        include_attestation,
        attestation_provider,
        disable_websockets: AtomicBool::new(false),
        cached_websocket_session: StdMutex::new(WebsocketSession::default()),
    };
    Self {
        state: Arc::new(state),
    }
}
/// Creates a fresh turn-scoped streaming session.
///
/// No network I/O happens here; the session opens a websocket lazily when the first stream
/// request is issued. Any websocket session parked in the client's cache is moved into the new
/// session, and the turn state starts unset.
pub fn new_session(&self) -> ModelClientSession {
    let client = self.clone();
    let websocket_session = self.take_cached_websocket_session();
    let turn_state = Arc::new(OnceLock::new());
    ModelClientSession {
        client,
        websocket_session,
        turn_state,
    }
}
/// Returns the auth manager held by the session's model provider, if it has one.
pub(crate) fn auth_manager(&self) -> Option<Arc<AuthManager>> {
self.state.provider.auth_manager()
}
/// Sets the window generation to `window_generation` and discards the cached websocket session.
///
/// The window id sent to the provider is derived from the generation, so the cache is replaced
/// with an empty session whenever the generation is written.
pub(crate) fn set_window_generation(&self, window_generation: u64) {
    let generation = &self.state.window_generation;
    generation.store(window_generation, Ordering::Relaxed);
    let fresh_session = WebsocketSession::default();
    self.store_cached_websocket_session(fresh_session);
}
/// Increments the window generation by one and discards the cached websocket session.
pub(crate) fn advance_window_generation(&self) {
    let generation = &self.state.window_generation;
    generation.fetch_add(1, Ordering::Relaxed);
    let fresh_session = WebsocketSession::default();
    self.store_cached_websocket_session(fresh_session);
}
/// Formats the current window id as `<thread id>:<window generation>`.
fn current_window_id(&self) -> String {
    let state = &self.state;
    let window_generation = state.window_generation.load(Ordering::Relaxed);
    let thread_id = &state.thread_id;
    format!("{thread_id}:{window_generation}")
}
/// Moves the cached websocket session out of the client, leaving a default (empty) session in
/// its place.
///
/// A poisoned cache lock is recovered rather than propagated.
fn take_cached_websocket_session(&self) -> WebsocketSession {
    let cache = &self.state.cached_websocket_session;
    let mut slot = cache
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    std::mem::take(&mut *slot)
}
/// Replaces the client's cached websocket session with `websocket_session`.
///
/// The previously cached session is dropped. A poisoned cache lock is recovered rather than
/// propagated.
fn store_cached_websocket_session(&self, websocket_session: WebsocketSession) {
    let cache = &self.state.cached_websocket_session;
    let mut slot = cache
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    *slot = websocket_session;
}
/// Switches the session to the HTTP transport and discards the cached websocket session.
///
/// Returns `true` only when this call is the one that disabled websockets: the websocket
/// transport was enabled and the disable flag was previously unset. In that case a warning is
/// logged and the fallback counter is incremented. The cached websocket session is reset on
/// every call, whether or not the fallback was newly activated.
///
/// `_model_info` is accepted but not used.
pub(crate) fn force_http_fallback(
    &self,
    session_telemetry: &SessionTelemetry,
    _model_info: &ModelInfo,
) -> bool {
    let activated = if self.responses_websocket_enabled() {
        // `swap` returns the previous value, so only the first caller to flip the flag
        // observes `false` here.
        let was_disabled = self.state.disable_websockets.swap(true, Ordering::Relaxed);
        !was_disabled
    } else {
        false
    };

    if activated {
        warn!("falling back to HTTP");
        session_telemetry.counter(
            "codex.transport.fallback_to_http",
            /*inc*/ 1,
            &[("from_wire_api", "responses_websocket")],
        );
    }

    self.store_cached_websocket_session(WebsocketSession::default());
    activated
}
/// Compacts the current conversation history using the Compact endpoint.
///
/// This is a unary call (no streaming) that returns a new list of
/// `ResponseItem`s representing the compacted transcript.
///
/// The model selection and telemetry context are passed explicitly to keep `ModelClient`
/// session-scoped.
///
/// Returns an empty list without contacting the provider when `prompt.input` is empty. Errors
/// from client setup, request building, or the API call are returned to the caller; API errors
/// are converted with `map_api_error`. The attempt and its result are recorded on
/// `compaction_trace`.
pub(crate) async fn compact_conversation_history(
&self,
prompt: &Prompt,
model_info: &ModelInfo,
settings: CompactConversationRequestSettings,
session_telemetry: &SessionTelemetry,
compaction_trace: &CompactionTraceContext,
) -> Result<Vec<ResponseItem>> {
// Nothing to compact: skip auth resolution and the network round trip entirely.
if prompt.input.is_empty() {
return Ok(Vec::new());
}
let client_setup = self.current_client_setup().await?;
let transport = ReqwestTransport::new(build_reqwest_client());
let request_telemetry = Self::build_request_telemetry(
session_telemetry,
AuthRequestTelemetryContext::new(
client_setup.auth.as_ref().map(CodexAuth::auth_mode),
client_setup.api_auth.as_ref(),
PendingUnauthorizedRetry::default(),
),
RequestRouteTelemetry::for_endpoint(RESPONSES_COMPACT_ENDPOINT),
self.state.auth_env_telemetry.clone(),
);
// Build the request through the shared Responses request builder, then keep only the fields
// the compaction payload carries.
let request = self.build_responses_request(
&client_setup.api_provider,
prompt,
model_info,
settings.effort,
settings.summary,
settings.service_tier,
)?;
let ResponsesApiRequest {
model,
instructions,
input,
tools,
parallel_tool_calls,
reasoning,
service_tier,
prompt_cache_key,
text,
..
} = request;
let payload = ApiCompactionInput {
model: &model,
input: &input,
instructions: &instructions,
tools,
parallel_tool_calls,
reasoning,
service_tier: service_tier.as_deref(),
prompt_cache_key: prompt_cache_key.as_deref(),
text,
};
// Header assembly: installation id (skipped if it is not a valid header value), the shared
// Responses headers without turn state or turn metadata, identity headers, session headers,
// and the attestation header when one is generated.
let mut extra_headers = ApiHeaderMap::new();
if let Ok(header_value) = HeaderValue::from_str(&self.state.installation_id) {
extra_headers.insert(X_CODEX_INSTALLATION_ID_HEADER, header_value);
}
extra_headers.extend(build_responses_headers(
self.state.beta_features_header.as_deref(),
/*turn_state*/ None,
/*turn_metadata_header*/ None,
));
extra_headers.extend(self.build_responses_identity_headers());
extra_headers.extend(build_session_headers(
Some(self.state.session_id.to_string()),
Some(self.state.thread_id.to_string()),
));
if let Some(header_value) = self.generate_attestation_header_for().await {
extra_headers.insert(X_OAI_ATTESTATION_HEADER, header_value);
}
// The timeout is the stream idle timeout scaled by the compaction multiplier, saturating
// instead of overflowing.
let compact_request_timeout = client_setup
.api_provider
.stream_idle_timeout
.saturating_mul(COMPACT_REQUEST_TIMEOUT_IDLE_MULTIPLIER);
let client =
ApiCompactClient::new(transport, client_setup.api_provider, client_setup.api_auth)
.with_telemetry(Some(request_telemetry));
// Start the trace attempt before the call and record the outcome (success or error) after it.
let trace_attempt = compaction_trace.start_attempt(&payload);
let result = client
.compact_input(&payload, extra_headers, compact_request_timeout)
.await
.map_err(map_api_error);
trace_attempt.record_result(result.as_deref());
result
}
/// Creates a WebRTC Realtime call over HTTP and returns what is needed to join it as a sideband.
///
/// `extra_headers` are sent with the call-create request, with the attestation header added when
/// one is generated. The returned `sideband_headers` are a copy of those headers extended with
/// the auth headers from the same auth provider that creates the call.
///
/// Errors from client setup are returned as-is; API errors are converted with `map_api_error`.
pub(crate) async fn create_realtime_call_with_headers(
    &self,
    sdp: String,
    session_config: ApiRealtimeSessionConfig,
    mut extra_headers: ApiHeaderMap,
) -> Result<RealtimeWebrtcCallStart> {
    // Create the media call over HTTP first, then retain matching auth so realtime can attach
    // the server-side control WebSocket to the call id from that HTTP response.
    let client_setup = self.current_client_setup().await?;
    if let Some(attestation) = self.generate_attestation_header_for().await {
        extra_headers.insert(X_OAI_ATTESTATION_HEADER, attestation);
    }

    // Snapshot the request headers before they are consumed, then layer the auth headers on
    // top for the sideband join.
    let sideband_headers = {
        let mut headers = extra_headers.clone();
        headers.extend(sideband_websocket_auth_headers(
            client_setup.api_auth.as_ref(),
        ));
        headers
    };

    let transport = ReqwestTransport::new(build_reqwest_client());
    let call_client =
        ApiRealtimeCallClient::new(transport, client_setup.api_provider, client_setup.api_auth);
    let response = call_client
        .create_with_session_and_headers(sdp, session_config, extra_headers)
        .await
        .map_err(map_api_error)?;

    Ok(RealtimeWebrtcCallStart {
        sdp: response.sdp,
        call_id: response.call_id,
        sideband_headers,
    })
}
/// Builds memory summaries for each provided normalized raw memory.
///
/// This is a unary call (no streaming) to `/v1/memories/trace_summarize`.
///
/// The model selection, reasoning effort, and telemetry context are passed explicitly to keep
/// `ModelClient` session-scoped.
///
/// Returns an empty list without contacting the provider when `raw_memories` is empty. API
/// errors are converted with `map_api_error`.
pub async fn summarize_memories(
    &self,
    raw_memories: Vec<ApiRawMemory>,
    model_info: &ModelInfo,
    effort: Option<ReasoningEffortConfig>,
    session_telemetry: &SessionTelemetry,
) -> Result<Vec<ApiMemorySummarizeOutput>> {
    if raw_memories.is_empty() {
        return Ok(Vec::new());
    }

    let client_setup = self.current_client_setup().await?;
    let transport = ReqwestTransport::new(build_reqwest_client());

    let auth_context = AuthRequestTelemetryContext::new(
        client_setup.auth.as_ref().map(CodexAuth::auth_mode),
        client_setup.api_auth.as_ref(),
        PendingUnauthorizedRetry::default(),
    );
    let request_telemetry = Self::build_request_telemetry(
        session_telemetry,
        auth_context,
        RequestRouteTelemetry::for_endpoint(MEMORIES_SUMMARIZE_ENDPOINT),
        self.state.auth_env_telemetry.clone(),
    );

    let memories_client =
        ApiMemoriesClient::new(transport, client_setup.api_provider, client_setup.api_auth)
            .with_telemetry(Some(request_telemetry));

    // Only the effort is forwarded; no reasoning summary is requested for this endpoint.
    let reasoning = effort.map(|effort| Reasoning {
        effort: Some(effort),
        summary: None,
    });
    let payload = ApiMemorySummarizeInput {
        model: model_info.slug.clone(),
        raw_memories,
        reasoning,
    };

    let headers = self.build_subagent_headers();
    memories_client
        .summarize_input(&payload, headers)
        .await
        .map_err(map_api_error)
}
/// Builds the headers that describe which kind of agent is making the request.
///
/// Adds the subagent header when the session source yields a subagent label that is a valid
/// header value, and the memgen header (`"true"`) when the session source is the internal
/// memory-consolidation flow.
fn build_subagent_headers(&self) -> ApiHeaderMap {
    let session_source = &self.state.session_source;
    let mut headers = ApiHeaderMap::new();

    // A label that cannot be encoded as a header value is silently omitted.
    let subagent_value = subagent_header_value(session_source)
        .and_then(|subagent| HeaderValue::from_str(&subagent).ok());
    if let Some(value) = subagent_value {
        headers.insert(X_OPENAI_SUBAGENT_HEADER, value);
    }

    let is_memory_consolidation = matches!(
        session_source,
        SessionSource::Internal(InternalSessionSource::MemoryConsolidation)
    );
    if is_memory_consolidation {
        headers.insert(
            X_OPENAI_MEMGEN_REQUEST_HEADER,
            HeaderValue::from_static("true"),
        );
    }

    headers
}
/// Builds the identity headers for Responses requests.
///
/// Starts from the subagent headers, then adds the parent-thread-id header (when the session
/// source yields one) and the window-id header. Values that are not valid header values are
/// omitted.
fn build_responses_identity_headers(&self) -> ApiHeaderMap {
    let mut headers = self.build_subagent_headers();

    let parent_thread_value = parent_thread_id_header_value(&self.state.session_source)
        .and_then(|parent_thread_id| HeaderValue::from_str(&parent_thread_id).ok());
    if let Some(value) = parent_thread_value {
        headers.insert(X_CODEX_PARENT_THREAD_ID_HEADER, value);
    }

    let window_id = self.current_window_id();
    if let Ok(value) = HeaderValue::from_str(&window_id) {
        headers.insert(X_CODEX_WINDOW_ID_HEADER, value);
    }

    headers
}
fn build_ws_client_metadata(
&self,
turn_metadata_header: Option<&str>,
) -> HashMap<String, String> {
let mut client_metadata = HashMap::new();
client_metadata.insert(
X_CODEX_INSTALLATION_ID_HEADER.to_string(),
self.state.installation_id.clone(),
);
client_metadata.insert(
X_CODEX_WINDOW_ID_HEADER.to_string(),
self.current_window_id(),
);
if let Some(subagent) = subagent_header_value(&self.state.session_source) {
client_metadata.insert(X_OPENAI_SUBAGENT_HEADER.to_string(), subagent);
}
if let Some(parent_thread_id) = parent_thread_id_header_value(&self.state.session_source) {
client_metadata.insert(
X_CODEX_PARENT_THREAD_ID_HEADER.to_string(),
parent_thread_id,
);
}
if let Some(turn_metadata_header) = parse_turn_metadata_header(turn_metadata_header)
&& let Ok(turn_metadata) = turn_metadata_header.to_str()
{
client_metadata.insert(
X_CODEX_TURN_METADATA_HEADER.to_string(),
turn_metadata.to_string(),
);
}
client_metadata
}
/// Asks the attestation provider for a header value for this session's thread.
///
/// Returns `None` when attestation is not enabled for the provider, when no attestation provider
/// is configured, or when the provider itself returns no header.
async fn generate_attestation_header_for(&self) -> Option<HeaderValue> {
    let state = &self.state;
    let provider = match (state.include_attestation, state.attestation_provider.as_ref()) {
        (true, Some(provider)) => provider,
        _ => return None,
    };
    let context = AttestationContext {
        thread_id: state.thread_id,
    };
    provider.header_for_request(context).await
}
/// Builds request telemetry for unary API calls (e.g., Compact endpoint).
///
/// Wraps a new `ApiTelemetry` in an `Arc` and returns it as a `RequestTelemetry` trait object.
fn build_request_telemetry(
    session_telemetry: &SessionTelemetry,
    auth_context: AuthRequestTelemetryContext,
    request_route_telemetry: RequestRouteTelemetry,
    auth_env_telemetry: AuthEnvTelemetry,
) -> Arc<dyn RequestTelemetry> {
    let api_telemetry = ApiTelemetry::new(
        session_telemetry.clone(),
        auth_context,
        request_route_telemetry,
        auth_env_telemetry,
    );
    Arc::new(api_telemetry)
}
/// Builds the `reasoning` request parameter for a model.
///
/// Returns `None` when the model does not support reasoning summaries. Otherwise the effort is
/// the explicit `effort` or, failing that, the model's default reasoning level, and the summary
/// is omitted when `summary` is `ReasoningSummaryConfig::None`.
fn build_reasoning(
    model_info: &ModelInfo,
    effort: Option<ReasoningEffortConfig>,
    summary: ReasoningSummaryConfig,
) -> Option<Reasoning> {
    if !model_info.supports_reasoning_summaries {
        return None;
    }

    let summary = (summary != ReasoningSummaryConfig::None).then_some(summary);
    Some(Reasoning {
        effort: effort.or(model_info.default_reasoning_level),
        summary,
    })
}
/// Builds the full Responses API request for a prompt.
///
/// - `include` requests encrypted reasoning content only when a reasoning parameter is present.
/// - Verbosity uses the session preference, falling back to the model default, and only when the
///   model supports verbosity; a session preference on an unsupported model is logged and
///   ignored.
/// - The thread id is used as the prompt cache key.
/// - `store` is enabled only for Azure Responses endpoints.
///
/// Returns an error if the prompt's tools cannot be converted to their JSON form.
fn build_responses_request(
    &self,
    provider: &codex_api::Provider,
    prompt: &Prompt,
    model_info: &ModelInfo,
    effort: Option<ReasoningEffortConfig>,
    summary: ReasoningSummaryConfig,
    service_tier: Option<String>,
) -> Result<ResponsesApiRequest> {
    let state = &self.state;

    let input = prompt.get_formatted_input();
    let tools = create_tools_json_for_responses_api(&prompt.tools)?;
    let reasoning = Self::build_reasoning(model_info, effort, summary);
    let include = match &reasoning {
        Some(_) => vec!["reasoning.encrypted_content".to_string()],
        None => Vec::new(),
    };

    let configured_verbosity = state.model_verbosity;
    let verbosity = if !model_info.support_verbosity {
        if configured_verbosity.is_some() {
            warn!(
                "model_verbosity is set but ignored as the model does not support verbosity: {}",
                model_info.slug
            );
        }
        None
    } else {
        configured_verbosity.or(model_info.default_verbosity)
    };
    let text = create_text_param_for_request(
        verbosity,
        &prompt.output_schema,
        prompt.output_schema_strict,
    );

    let prompt_cache_key = Some(state.thread_id.to_string());
    let service_tier = model_info.service_tier_for_request(service_tier);
    let client_metadata = HashMap::from([(
        X_CODEX_INSTALLATION_ID_HEADER.to_string(),
        state.installation_id.clone(),
    )]);

    Ok(ResponsesApiRequest {
        model: model_info.slug.clone(),
        instructions: prompt.base_instructions.text.clone(),
        input,
        tools,
        tool_choice: "auto".to_string(),
        parallel_tool_calls: prompt.parallel_tool_calls,
        reasoning,
        store: provider.is_azure_responses_endpoint(),
        stream: true,
        include,
        service_tier,
        prompt_cache_key,
        text,
        client_metadata: Some(client_metadata),
    })
}
/// Returns whether the Responses-over-WebSocket transport is active for this session.
///
/// WebSocket use requires that the provider supports websockets and that the session has not
/// fallen back to HTTP.
pub fn responses_websocket_enabled(&self) -> bool {
    let state = &self.state;
    state.provider.info().supports_websockets
        && !state.disable_websockets.load(Ordering::Relaxed)
}
/// Returns auth + provider configuration resolved from the current session auth state.
///
/// This centralizes setup used by both prewarm and normal request paths so they stay in
/// lockstep when auth/provider resolution changes.
///
/// Fails if the provider cannot produce its API provider configuration or API auth.
async fn current_client_setup(&self) -> Result<CurrentClientSetup> {
    let provider = &self.state.provider;
    let auth = provider.auth().await;
    let api_provider = provider.api_provider().await?;
    let api_auth = provider.api_auth().await?;
    let setup = CurrentClientSetup {
        auth,
        api_provider,
        api_auth,
    };
    Ok(setup)
}
/// Opens a websocket connection using the same header and telemetry wiring as normal turns.
///
/// Both startup prewarm and in-turn `needs_new` reconnects call this path so handshake
/// behavior remains consistent across both flows.
///
/// The connect attempt is bounded by the provider's websocket connect timeout; exceeding it
/// yields `ApiError::Transport(TransportError::Timeout)`. Whatever the outcome, the attempt is
/// recorded in session telemetry and feedback tags before the result is returned unchanged.
#[allow(clippy::too_many_arguments)]
async fn connect_websocket(
&self,
session_telemetry: &SessionTelemetry,
api_provider: codex_api::Provider,
api_auth: SharedAuthProvider,
turn_state: Option<Arc<OnceLock<String>>>,
turn_metadata_header: Option<&str>,
auth_context: AuthRequestTelemetryContext,
request_route_telemetry: RequestRouteTelemetry,
) -> std::result::Result<ApiWebSocketConnection, ApiError> {
let headers = self
.build_websocket_headers(turn_state.as_ref(), turn_metadata_header)
.await;
let websocket_telemetry = ModelClientSession::build_websocket_telemetry(
session_telemetry,
auth_context,
request_route_telemetry,
self.state.auth_env_telemetry.clone(),
);
let websocket_connect_timeout = self.state.provider.info().websocket_connect_timeout();
// Timed from just before the connect call so the recorded duration covers the whole attempt,
// including a timeout.
let start = Instant::now();
let result = match tokio::time::timeout(
websocket_connect_timeout,
ApiWebSocketResponsesClient::new(api_provider, api_auth).connect(
headers,
codex_login::default_client::default_headers(),
turn_state,
Some(websocket_telemetry),
),
)
.await
{
Ok(result) => result,
// The timer elapsed before the handshake finished: report it as a transport timeout.
Err(_) => Err(ApiError::Transport(TransportError::Timeout)),
};
// Derive the error details reported to telemetry; all of these are empty/None on success.
let error_message = result.as_ref().err().map(telemetry_api_error_message);
let response_debug = result
.as_ref()
.err()
.map(extract_response_debug_context_from_api_error)
.unwrap_or_default();
let status = result.as_ref().err().and_then(api_error_http_status);
session_telemetry.record_websocket_connect(
start.elapsed(),
status,
error_message.as_deref(),
auth_context.auth_header_attached,
auth_context.auth_header_name,
auth_context.retry_after_unauthorized,
auth_context.recovery_mode,
auth_context.recovery_phase,
request_route_telemetry.endpoint,
/*connection_reused*/ false,
response_debug.request_id.as_deref(),
response_debug.cf_ray.as_deref(),
response_debug.auth_error.as_deref(),
response_debug.auth_error_code.as_deref(),
);
emit_feedback_request_tags_with_auth_env(
&FeedbackRequestTags {
endpoint: request_route_telemetry.endpoint,
auth_header_attached: auth_context.auth_header_attached,
auth_header_name: auth_context.auth_header_name,
auth_mode: auth_context.auth_mode,
auth_retry_after_unauthorized: Some(auth_context.retry_after_unauthorized),
auth_recovery_mode: auth_context.recovery_mode,
auth_recovery_phase: auth_context.recovery_phase,
auth_connection_reused: Some(false),
auth_request_id: response_debug.request_id.as_deref(),
auth_cf_ray: response_debug.cf_ray.as_deref(),
auth_error: response_debug.auth_error.as_deref(),
auth_error_code: response_debug.auth_error_code.as_deref(),
// The follow-up fields are populated only when this attempt is a retry after an
// unauthorized response.
auth_recovery_followup_success: auth_context
.retry_after_unauthorized
.then_some(result.is_ok()),
auth_recovery_followup_status: auth_context
.retry_after_unauthorized
.then_some(status)
.flatten(),
},
&self.state.auth_env_telemetry,
);
result
}
/// Builds websocket handshake headers for both prewarm and turn-time reconnect.
///
/// Callers should pass the current turn-state lock when available so sticky-routing state is
/// replayed on reconnect within the same turn.
///
/// Later insertions replace earlier values for the same header name, so the order below is
/// significant: shared Responses headers, client request id (the thread id), session headers,
/// identity headers, attestation, the websocket beta header, and finally the optional
/// timing-metrics header.
async fn build_websocket_headers(
    &self,
    turn_state: Option<&Arc<OnceLock<String>>>,
    turn_metadata_header: Option<&str>,
) -> ApiHeaderMap {
    let state = &self.state;
    let parsed_turn_metadata = parse_turn_metadata_header(turn_metadata_header);
    let session_id = state.session_id.to_string();
    let thread_id = state.thread_id.to_string();

    let mut headers = build_responses_headers(
        state.beta_features_header.as_deref(),
        turn_state,
        parsed_turn_metadata.as_ref(),
    );

    if let Ok(request_id) = HeaderValue::from_str(&thread_id) {
        headers.insert("x-client-request-id", request_id);
    }
    headers.extend(build_session_headers(Some(session_id), Some(thread_id)));
    headers.extend(self.build_responses_identity_headers());

    if let Some(attestation) = self.generate_attestation_header_for().await {
        headers.insert(X_OAI_ATTESTATION_HEADER, attestation);
    }

    let beta_value = HeaderValue::from_static(RESPONSES_WEBSOCKETS_V2_BETA_HEADER_VALUE);
    headers.insert(OPENAI_BETA_HEADER, beta_value);

    if state.include_timing_metrics {
        headers.insert(
            X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER,
            HeaderValue::from_static("true"),
        );
    }

    headers
}
}
impl Drop for ModelClientSession {
    /// Hands this session's websocket state back to the owning client's cache, leaving a
    /// default `WebsocketSession` behind in the value being dropped.
    fn drop(&mut self) {
        let cached_session = std::mem::take(&mut self.websocket_session);
        self.client.store_cached_websocket_session(cached_session);
    }
}
impl ModelClientSession {
/// Forgets the current websocket connection together with the request/response bookkeeping
/// used for incremental requests, and marks the connection as not reused.
pub(crate) fn reset_websocket_session(&mut self) {
    let session = &mut self.websocket_session;
    session.connection = None;
    session.last_request = None;
    session.last_response_rx = None;
    session.last_response_from_untraced_warmup = false;
    session.set_connection_reused(/*connection_reused*/ false);
}
/// Tells the server, over the current websocket connection, that `response_id` has been
/// processed.
///
/// This is best effort: nothing is sent when there is no connection, and a send failure is
/// only logged at debug level.
pub(crate) async fn send_response_processed(&self, response_id: &str) {
    if let Some(connection) = self.websocket_session.connection.as_ref() {
        let outcome = connection
            .send_response_processed(response_id.to_string())
            .await;
        if let Err(err) = outcome {
            debug!("failed to send response.processed websocket request: {err}");
        }
    }
}
#[allow(clippy::too_many_arguments)]
/// Produces the transport options and request-scoped headers shared by every Responses API
/// request of this session.
///
/// Building them in a single place keeps the request-scoped headers identical whichever
/// transport ends up carrying the request.
async fn build_responses_options(
    &self,
    turn_metadata_header: Option<&str>,
    compression: Compression,
) -> ApiResponsesOptions {
    let metadata = parse_turn_metadata_header(turn_metadata_header);

    // Shared Responses headers first, then identity headers, then the optional attestation.
    let mut extra_headers = build_responses_headers(
        self.client.state.beta_features_header.as_deref(),
        Some(&self.turn_state),
        metadata.as_ref(),
    );
    extra_headers.extend(self.client.build_responses_identity_headers());
    if let Some(attestation) = self.client.generate_attestation_header_for().await {
        extra_headers.insert(X_OAI_ATTESTATION_HEADER, attestation);
    }

    ApiResponsesOptions {
        session_id: Some(self.client.state.session_id.to_string()),
        thread_id: Some(self.client.state.thread_id.to_string()),
        session_source: Some(self.client.state.session_source.clone()),
        extra_headers,
        compression,
        turn_state: Some(Arc::clone(&self.turn_state)),
    }
}
/// Returns the input items that must be sent when `request` can be expressed as an
/// incremental extension of the previous websocket request, or `None` when a full request is
/// required.
///
/// A request qualifies only when:
/// - a previous request is recorded on the websocket session,
/// - every non-`input` field is equal to that previous request, and
/// - `request.input` starts with the previous input followed by the output items the server
///   returned for it (`last_response.items_added`), which are treated as part of the baseline
///   so they are not resent.
///
/// When `allow_empty_delta` is `false`, a request that adds nothing beyond the baseline is
/// rejected as well.
fn get_incremental_items(
    &self,
    request: &ResponsesApiRequest,
    last_response: Option<&LastResponse>,
    allow_empty_delta: bool,
) -> Option<Vec<ResponseItem>> {
    let previous_request = self.websocket_session.last_request.as_ref()?;

    // Compare everything except `input` by blanking it on a copy of each request.
    let mut previous_without_input = previous_request.clone();
    previous_without_input.input.clear();
    let mut request_without_input = request.clone();
    request_without_input.input.clear();
    if previous_without_input != request_without_input {
        trace!(
            "incremental request failed, properties didn't match {previous_without_input:?} != {request_without_input:?}"
        );
        return None;
    }

    // The baseline is "previous input ++ server output items". Peel both prefixes off the new
    // input on borrowed slices instead of cloning them into a temporary baseline vector.
    let items_added: &[ResponseItem] = last_response
        .map(|response| response.items_added.as_slice())
        .unwrap_or_default();
    let delta = request
        .input
        .strip_prefix(previous_request.input.as_slice())
        .and_then(|after_previous_input| after_previous_input.strip_prefix(items_added))
        .filter(|delta| allow_empty_delta || !delta.is_empty());

    match delta {
        Some(delta) => Some(delta.to_vec()),
        None => {
            trace!("incremental request failed, items didn't match");
            None
        }
    }
}
fn get_last_response(&mut self) -> Option<LastResponse> {
self.websocket_session
.last_response_rx
.take()
.and_then(|mut receiver| match receiver.try_recv() {
Ok(last_response) => Some(last_response),
Err(TryRecvError::Closed) | Err(TryRecvError::Empty) => None,
})
}
/// Chooses between sending `payload` as a full request and sending an incremental request
/// that references the previous response.
///
/// Returns the websocket request together with a flag that is `true` only when the request is
/// incremental and the referenced previous response came from an untraced warmup.
fn prepare_websocket_request(
    &mut self,
    payload: ResponseCreateWsRequest,
    request: &ResponsesApiRequest,
) -> (ResponsesWsRequest, bool) {
    // An incremental request needs a delivered previous response, a matching previous
    // request, and a non-empty previous response id, checked in that order.
    let incremental = self.get_last_response().and_then(|last_response| {
        let delta = self.get_incremental_items(
            request,
            Some(&last_response),
            /*allow_empty_delta*/ true,
        )?;
        if last_response.response_id.is_empty() {
            trace!("incremental request failed, no previous response id");
            return None;
        }
        Some((last_response.response_id, delta))
    });

    match incremental {
        Some((previous_response_id, input)) => (
            ResponsesWsRequest::ResponseCreate(ResponseCreateWsRequest {
                previous_response_id: Some(previous_response_id),
                input,
                ..payload
            }),
            self.websocket_session.last_response_from_untraced_warmup,
        ),
        None => (ResponsesWsRequest::ResponseCreate(payload), false),
    }
}
/// Opens a websocket ahead of time for this turn-scoped client session, when possible.
///
/// Only the connection is established here; no prompt payload is ever sent. The call is a
/// no-op when websockets are disabled or the session already holds a connection.
pub async fn preconnect_websocket(
    &mut self,
    session_telemetry: &SessionTelemetry,
    _model_info: &ModelInfo,
) -> std::result::Result<(), ApiError> {
    if !self.client.responses_websocket_enabled()
        || self.websocket_session.connection.is_some()
    {
        return Ok(());
    }

    let client_setup = match self.client.current_client_setup().await {
        Ok(setup) => setup,
        Err(err) => {
            return Err(ApiError::Stream(format!(
                "failed to build websocket prewarm client setup: {err}"
            )));
        }
    };
    let auth_context = AuthRequestTelemetryContext::new(
        client_setup.auth.as_ref().map(CodexAuth::auth_mode),
        client_setup.api_auth.as_ref(),
        PendingUnauthorizedRetry::default(),
    );

    let fresh_connection = self
        .client
        .connect_websocket(
            session_telemetry,
            client_setup.api_provider,
            client_setup.api_auth,
            Some(Arc::clone(&self.turn_state)),
            /*turn_metadata_header*/ None,
            auth_context,
            RequestRouteTelemetry::for_endpoint(RESPONSES_ENDPOINT),
        )
        .await?;

    let session = &mut self.websocket_session;
    session.connection = Some(fresh_connection);
    session.set_connection_reused(/*connection_reused*/ false);
    Ok(())
}
/// Returns a websocket connection for this turn, reusing the session's connection while it
/// is still open and opening a new one otherwise.
///
/// Opening a new connection first clears the incremental-request bookkeeping (`last_request`,
/// `last_response_rx` and the untraced-warmup flag), so the next request on it is sent in
/// full. The session's "connection reused" flag is updated to reflect which path was taken.
///
/// # Errors
///
/// Returns the error from the connect attempt. When that error is a transport timeout the
/// whole websocket session is reset before returning.
#[instrument(
    name = "model_client.websocket_connection",
    level = "info",
    skip_all,
    fields(
        provider = %self.client.state.provider.info().name,
        wire_api = %self.client.state.provider.info().wire_api,
        transport = "responses_websocket",
        api.path = "responses",
        turn.has_metadata_header = params.turn_metadata_header.is_some()
    )
)]
async fn websocket_connection(
    &mut self,
    params: WebsocketConnectParams<'_>,
) -> std::result::Result<&ApiWebSocketConnection, ApiError> {
    let WebsocketConnectParams {
        session_telemetry,
        api_provider,
        api_auth,
        turn_metadata_header,
        options,
        auth_context,
        request_route_telemetry,
    } = params;

    let needs_new = match self.websocket_session.connection.as_ref() {
        Some(conn) => conn.is_closed().await,
        None => true,
    };

    if needs_new {
        self.websocket_session.last_request = None;
        self.websocket_session.last_response_rx = None;
        self.websocket_session.last_response_from_untraced_warmup = false;
        // Prefer the turn state carried by the request options; fall back to the session's.
        let turn_state = options
            .turn_state
            .clone()
            .unwrap_or_else(|| Arc::clone(&self.turn_state));
        let new_conn = match self
            .client
            .connect_websocket(
                session_telemetry,
                api_provider,
                api_auth,
                Some(turn_state),
                turn_metadata_header,
                auth_context,
                request_route_telemetry,
            )
            .await
        {
            Ok(new_conn) => new_conn,
            Err(err) => {
                if matches!(err, ApiError::Transport(TransportError::Timeout)) {
                    self.reset_websocket_session();
                }
                return Err(err);
            }
        };
        self.websocket_session.connection = Some(new_conn);
        self.websocket_session
            .set_connection_reused(/*connection_reused*/ false);
    } else {
        self.websocket_session
            .set_connection_reused(/*connection_reused*/ true);
    }

    // Build the error lazily so the success path does not allocate the message string.
    self.websocket_session
        .connection
        .as_ref()
        .ok_or_else(|| ApiError::Stream("websocket connection is unavailable".to_string()))
}
/// Picks the request-body compression for Responses API calls.
///
/// Zstd is used only when request compression is enabled for the client, the given auth uses
/// the Codex backend, and the provider is OpenAI; every other combination sends uncompressed
/// bodies.
fn responses_request_compression(&self, auth: Option<&CodexAuth>) -> Compression {
    let state = &self.client.state;
    let use_zstd = state.enable_request_compression
        && auth.is_some_and(CodexAuth::uses_codex_backend)
        && state.provider.info().is_openai();
    match use_zstd {
        true => Compression::Zstd,
        false => Compression::None,
    }
}
/// Streams a turn via the OpenAI Responses API.
///
/// Handles reasoning summaries, verbosity, and the `text` controls used for output schemas.
///
/// Each loop iteration is one HTTP request attempt with its own inference-trace attempt. The
/// loop repeats only after a 401 for which `handle_unauthorized` ran a recovery step; every
/// other outcome returns.
///
/// # Errors
///
/// Returns an error when the client setup or request cannot be built, when unauthorized
/// recovery fails or is unavailable, or when the request fails with any other API error
/// (mapped through `map_api_error`).
#[allow(clippy::too_many_arguments)]
#[instrument(
name = "model_client.stream_responses_api",
level = "info",
skip_all,
fields(
model = %model_info.slug,
wire_api = %self.client.state.provider.info().wire_api,
transport = "responses_http",
http.method = "POST",
api.path = "responses",
turn.has_metadata_header = turn_metadata_header.is_some()
)
)]
async fn stream_responses_api(
&self,
prompt: &Prompt,
model_info: &ModelInfo,
session_telemetry: &SessionTelemetry,
effort: Option<ReasoningEffortConfig>,
summary: ReasoningSummaryConfig,
service_tier: Option<String>,
turn_metadata_header: Option<&str>,
inference_trace: &InferenceTraceContext,
) -> Result<ResponseStream> {
let auth_manager = self.client.state.provider.auth_manager();
let mut auth_recovery = auth_manager
.as_ref()
.map(AuthManager::unauthorized_recovery);
// Starts out as "no retry pending"; replaced after a successful unauthorized recovery so the
// next attempt's telemetry is marked as a retry.
let mut pending_retry = PendingUnauthorizedRetry::default();
loop {
// The client setup is fetched again on every attempt.
let client_setup = self.client.current_client_setup().await?;
let transport = ReqwestTransport::new(build_reqwest_client());
let request_auth_context = AuthRequestTelemetryContext::new(
client_setup.auth.as_ref().map(CodexAuth::auth_mode),
client_setup.api_auth.as_ref(),
pending_retry,
);
let (request_telemetry, sse_telemetry) = Self::build_streaming_telemetry(
session_telemetry,
request_auth_context,
RequestRouteTelemetry::for_endpoint(RESPONSES_ENDPOINT),
self.client.state.auth_env_telemetry.clone(),
);
let compression = self.responses_request_compression(client_setup.auth.as_ref());
let mut options = self
.build_responses_options(turn_metadata_header, compression)
.await;
let request = self.client.build_responses_request(
&client_setup.api_provider,
prompt,
model_info,
effort,
summary,
service_tier.clone(),
)?;
// One trace attempt per request attempt; it may add its own headers to the outgoing request.
let inference_trace_attempt = inference_trace.start_attempt();
inference_trace_attempt.add_request_headers(&mut options.extra_headers);
inference_trace_attempt.record_started(&request);
let client = ApiResponsesClient::new(
transport,
client_setup.api_provider,
client_setup.api_auth,
)
.with_telemetry(Some(request_telemetry), Some(sse_telemetry));
let stream_result = client.stream_request(request, options).await;
match stream_result {
Ok(stream) => {
// The second tuple element is the last-response receiver, which only the websocket
// transport stores; it is dropped here.
let (stream, _) = map_response_stream(
stream,
session_telemetry.clone(),
inference_trace_attempt,
);
return Ok(stream);
}
Err(ApiError::Transport(
unauthorized_transport @ TransportError::Http { status, .. },
)) if status == StatusCode::UNAUTHORIZED => {
// Record the failed attempt, then try one recovery step. `?` returns when recovery
// fails or no step is available; otherwise loop around and retry.
let response_debug_context =
extract_response_debug_context(&unauthorized_transport);
inference_trace_attempt.record_failed(
&unauthorized_transport,
response_debug_context.request_id.as_deref(),
/*output_items*/ &[],
);
pending_retry = PendingUnauthorizedRetry::from_recovery(
handle_unauthorized(
unauthorized_transport,
&mut auth_recovery,
session_telemetry,
)
.await?,
);
continue;
}
Err(err) => {
// Extract the debug context before `map_api_error` consumes the error.
let response_debug_context =
extract_response_debug_context_from_api_error(&err);
let err = map_api_error(err);
inference_trace_attempt.record_failed(
&err,
response_debug_context.request_id.as_deref(),
/*output_items*/ &[],
);
return Err(err);
}
}
}
}
/// Streams a turn via the Responses API over WebSocket transport.
///
/// Returns `WebsocketStreamOutcome::FallbackToHttp` when the websocket handshake is answered
/// with HTTP 426 (Upgrade Required), and `WebsocketStreamOutcome::Stream` once a request has
/// been sent on the connection.
///
/// When `warmup` is `true` the request is sent with `generate = false` and is excluded from
/// inference tracing. The loop repeats only after a 401 during connection setup for which
/// `handle_unauthorized` ran a recovery step.
///
/// # Errors
///
/// Returns an error when the client setup or request cannot be built, when connecting fails
/// for a reason other than 426 or a recoverable 401, or when sending the request fails.
#[allow(clippy::too_many_arguments)]
#[instrument(
name = "model_client.stream_responses_websocket",
level = "info",
skip_all,
fields(
model = %model_info.slug,
wire_api = %self.client.state.provider.info().wire_api,
transport = "responses_websocket",
api.path = "responses",
turn.has_metadata_header = turn_metadata_header.is_some(),
websocket.warmup = warmup
)
)]
async fn stream_responses_websocket(
&mut self,
prompt: &Prompt,
model_info: &ModelInfo,
session_telemetry: &SessionTelemetry,
effort: Option<ReasoningEffortConfig>,
summary: ReasoningSummaryConfig,
service_tier: Option<String>,
turn_metadata_header: Option<&str>,
warmup: bool,
request_trace: Option<W3cTraceContext>,
inference_trace: &InferenceTraceContext,
) -> Result<WebsocketStreamOutcome> {
let auth_manager = self.client.state.provider.auth_manager();
let mut auth_recovery = auth_manager
.as_ref()
.map(AuthManager::unauthorized_recovery);
// Starts out as "no retry pending"; replaced after a successful unauthorized recovery.
let mut pending_retry = PendingUnauthorizedRetry::default();
loop {
let client_setup = self.client.current_client_setup().await?;
let request_auth_context = AuthRequestTelemetryContext::new(
client_setup.auth.as_ref().map(CodexAuth::auth_mode),
client_setup.api_auth.as_ref(),
pending_retry,
);
let compression = self.responses_request_compression(client_setup.auth.as_ref());
let options = self
.build_responses_options(turn_metadata_header, compression)
.await;
// `request` is the logical Responses request; `ws_payload` is its websocket form with
// client metadata attached.
let request = self.client.build_responses_request(
&client_setup.api_provider,
prompt,
model_info,
effort,
summary,
service_tier.clone(),
)?;
let mut ws_payload = ResponseCreateWsRequest {
client_metadata: response_create_client_metadata(
Some(self.client.build_ws_client_metadata(turn_metadata_header)),
request_trace.as_ref(),
),
..ResponseCreateWsRequest::from(&request)
};
if warmup {
ws_payload.generate = Some(false);
}
// Make sure a usable connection exists. The returned reference is discarded here and the
// connection is read back from the session below.
match self
.websocket_connection(WebsocketConnectParams {
session_telemetry,
api_provider: client_setup.api_provider,
api_auth: client_setup.api_auth,
turn_metadata_header,
options: &options,
auth_context: request_auth_context,
request_route_telemetry: RequestRouteTelemetry::for_endpoint(
RESPONSES_ENDPOINT,
),
})
.await
{
Ok(_) => {}
Err(ApiError::Transport(TransportError::Http { status, .. }))
if status == StatusCode::UPGRADE_REQUIRED =>
{
return Ok(WebsocketStreamOutcome::FallbackToHttp);
}
Err(ApiError::Transport(
unauthorized_transport @ TransportError::Http { status, .. },
)) if status == StatusCode::UNAUTHORIZED => {
// `?` returns when recovery fails or no step is available; otherwise retry the whole
// attempt.
pending_retry = PendingUnauthorizedRetry::from_recovery(
handle_unauthorized(
unauthorized_transport,
&mut auth_recovery,
session_telemetry,
)
.await?,
);
continue;
}
Err(err) => return Err(map_api_error(err)),
}
// Full-versus-incremental is decided only after the connection step, because opening a new
// connection clears the previous-request state this decision reads.
let (mut ws_request, previous_response_id_from_untraced_warmup) =
self.prepare_websocket_request(ws_payload, &request);
let inference_trace_attempt = if warmup {
// Prewarm sends `generate=false`; it is connection setup, not a
// model inference attempt that should appear in rollout traces.
InferenceTraceAttempt::disabled()
} else {
inference_trace.start_attempt()
};
stamp_ws_stream_request_start_ms(&mut ws_request);
if previous_response_id_from_untraced_warmup {
// The transport can reuse an untraced warmup response id and omit the
// already-sent input, but rollout replay needs the logical model-visible
// request rather than the compressed websocket delta.
inference_trace_attempt.record_started(&request);
} else {
inference_trace_attempt.record_started(&ws_request);
}
// Remember the logical request and whether it was a warmup; the next call uses both to
// build its delta and to choose what to trace.
self.websocket_session.last_request = Some(request);
self.websocket_session.last_response_from_untraced_warmup = warmup;
let websocket_connection =
self.websocket_session.connection.as_ref().ok_or_else(|| {
map_api_error(ApiError::Stream(
"websocket connection is unavailable".to_string(),
))
})?;
let stream_result = websocket_connection
.stream_request(ws_request, self.websocket_session.connection_reused())
.await
.map_err(|err| {
// Extract the debug context before `map_api_error` consumes the error.
let response_debug_context =
extract_response_debug_context_from_api_error(&err);
let err = map_api_error(err);
inference_trace_attempt.record_failed(
&err,
response_debug_context.request_id.as_deref(),
/*output_items*/ &[],
);
err
})?;
// Despite the local name, this receiver yields the `LastResponse` (response id and output
// items) that `map_response_events` sends on completion.
let (stream, last_request_rx) = map_response_stream(
stream_result,
session_telemetry.clone(),
inference_trace_attempt,
);
self.websocket_session.last_response_rx = Some(last_request_rx);
return Ok(WebsocketStreamOutcome::Stream(stream));
}
}
/// Creates the request-level and SSE-level telemetry handles for a streaming API call.
///
/// Both returned handles point at the same underlying `ApiTelemetry` value.
fn build_streaming_telemetry(
    session_telemetry: &SessionTelemetry,
    auth_context: AuthRequestTelemetryContext,
    request_route_telemetry: RequestRouteTelemetry,
    auth_env_telemetry: AuthEnvTelemetry,
) -> (Arc<dyn RequestTelemetry>, Arc<dyn SseTelemetry>) {
    let shared = Arc::new(ApiTelemetry::new(
        session_telemetry.clone(),
        auth_context,
        request_route_telemetry,
        auth_env_telemetry,
    ));
    let for_requests: Arc<dyn RequestTelemetry> = shared.clone();
    let for_sse: Arc<dyn SseTelemetry> = shared;
    (for_requests, for_sse)
}
/// Creates the telemetry handle used by the Responses API WebSocket transport.
fn build_websocket_telemetry(
    session_telemetry: &SessionTelemetry,
    auth_context: AuthRequestTelemetryContext,
    request_route_telemetry: RequestRouteTelemetry,
    auth_env_telemetry: AuthEnvTelemetry,
) -> Arc<dyn WebsocketTelemetry> {
    let api_telemetry = ApiTelemetry::new(
        session_telemetry.clone(),
        auth_context,
        request_route_telemetry,
        auth_env_telemetry,
    );
    let handle: Arc<dyn WebsocketTelemetry> = Arc::new(api_telemetry);
    handle
}
#[allow(clippy::too_many_arguments)]
/// Sends an untraced warmup request over the websocket and waits for it to complete.
///
/// Does nothing when websockets are disabled or when a request has already been recorded on
/// this websocket session. If the server asks for an HTTP fallback, the session is switched
/// to the fallback transport and the call still succeeds.
///
/// # Errors
///
/// Returns the error from sending the warmup request, or the first error event received
/// before the warmup's `Completed` event.
pub async fn prewarm_websocket(
    &mut self,
    prompt: &Prompt,
    model_info: &ModelInfo,
    session_telemetry: &SessionTelemetry,
    effort: Option<ReasoningEffortConfig>,
    summary: ReasoningSummaryConfig,
    service_tier: Option<String>,
    turn_metadata_header: Option<&str>,
) -> Result<()> {
    if !self.client.responses_websocket_enabled()
        || self.websocket_session.last_request.is_some()
    {
        return Ok(());
    }

    let untraced = InferenceTraceContext::disabled();
    let outcome = self
        .stream_responses_websocket(
            prompt,
            model_info,
            session_telemetry,
            effort,
            summary,
            service_tier,
            turn_metadata_header,
            /*warmup*/ true,
            current_span_w3c_trace_context(),
            &untraced,
        )
        .await?;

    match outcome {
        WebsocketStreamOutcome::Stream(mut events) => {
            // Wait for the v2 warmup request to complete before sending the first turn request.
            while let Some(event) = events.next().await {
                if matches!(event?, ResponseEvent::Completed { .. }) {
                    break;
                }
            }
        }
        WebsocketStreamOutcome::FallbackToHttp => {
            self.try_switch_fallback_transport(session_telemetry, model_info);
        }
    }
    Ok(())
}
#[allow(clippy::too_many_arguments)]
/// Streams one model request for the current turn.
///
/// Per-turn settings (model selection, reasoning settings, telemetry context and turn
/// metadata) are supplied explicitly by the caller. The Responses WebSocket transport is tried
/// first while it is enabled; when it reports that an HTTP fallback is required, the session
/// switches transports and the request is sent over the HTTP Responses API instead. The trace
/// context may be enabled or disabled but is always passed explicitly, so neither transport
/// path needs a separate trace/no-trace branch.
pub async fn stream(
    &mut self,
    prompt: &Prompt,
    model_info: &ModelInfo,
    session_telemetry: &SessionTelemetry,
    effort: Option<ReasoningEffortConfig>,
    summary: ReasoningSummaryConfig,
    service_tier: Option<String>,
    turn_metadata_header: Option<&str>,
    inference_trace: &InferenceTraceContext,
) -> Result<ResponseStream> {
    let wire_api = self.client.state.provider.info().wire_api;
    match wire_api {
        WireApi::Responses => {
            if self.client.responses_websocket_enabled() {
                let outcome = self
                    .stream_responses_websocket(
                        prompt,
                        model_info,
                        session_telemetry,
                        effort,
                        summary,
                        service_tier.clone(),
                        turn_metadata_header,
                        /*warmup*/ false,
                        current_span_w3c_trace_context(),
                        inference_trace,
                    )
                    .await?;
                match outcome {
                    WebsocketStreamOutcome::Stream(stream) => return Ok(stream),
                    WebsocketStreamOutcome::FallbackToHttp => {
                        self.try_switch_fallback_transport(session_telemetry, model_info);
                    }
                }
            }

            // Reached when websockets are disabled or the websocket path asked for HTTP.
            self.stream_responses_api(
                prompt,
                model_info,
                session_telemetry,
                effort,
                summary,
                service_tier,
                turn_metadata_header,
                inference_trace,
            )
            .await
        }
    }
}
/// Forces this Codex session onto the HTTP transport and discards all websocket state.
///
/// Used once the provider retry budget is exhausted, so that later requests go over HTTP.
///
/// Returns `true` when this call is the one that activated the fallback and `false` when the
/// fallback was already active. The websocket session is reset in both cases.
pub(crate) fn try_switch_fallback_transport(
    &mut self,
    session_telemetry: &SessionTelemetry,
    model_info: &ModelInfo,
) -> bool {
    let newly_activated = self
        .client
        .force_http_fallback(session_telemetry, model_info);
    self.websocket_session = WebsocketSession::default();
    newly_activated
}
}
/// Converts optional per-turn metadata text into an HTTP header value.
///
/// A value that is not a valid header value is treated exactly like a missing one, so callers
/// compare and forward metadata through the same sanitization used when headers are built.
fn parse_turn_metadata_header(turn_metadata_header: Option<&str>) -> Option<HeaderValue> {
    let raw = turn_metadata_header?;
    HeaderValue::from_str(raw).ok()
}
/// Writes the current Unix time in milliseconds into a `ResponseCreate` request's client
/// metadata, creating the metadata map if the request has none.
///
/// Intended to be called right before the request goes out on the socket so the stamp
/// reflects realistic transport timing. Other request kinds are left untouched.
fn stamp_ws_stream_request_start_ms(request: &mut ResponsesWsRequest) {
    if let ResponsesWsRequest::ResponseCreate(payload) = request {
        let metadata = payload.client_metadata.get_or_insert_with(HashMap::new);
        let started_at_ms = crate::turn_timing::now_unix_timestamp_ms().to_string();
        metadata.insert(
            X_CODEX_WS_STREAM_REQUEST_START_MS_CLIENT_METADATA_KEY.to_string(),
            started_at_ms,
        );
    }
}
/// Creates the Codex-specific extra headers that accompany Responses API requests.
///
/// Headers produced, each only when its source value is present and valid:
///
/// - `x-codex-beta-features`: comma-separated beta feature keys enabled for the session;
///   omitted when the value is empty.
/// - `x-codex-turn-state`: sticky routing token captured earlier in the turn; omitted until
///   the turn-state lock has been set.
/// - `x-codex-turn-metadata`: optional per-turn metadata for observability.
fn build_responses_headers(
    beta_features_header: Option<&str>,
    turn_state: Option<&Arc<OnceLock<String>>>,
    turn_metadata_header: Option<&HeaderValue>,
) -> ApiHeaderMap {
    let mut header_map = ApiHeaderMap::new();

    let beta_features = beta_features_header
        .filter(|value| !value.is_empty())
        .and_then(|value| HeaderValue::from_str(value).ok());
    if let Some(beta_features) = beta_features {
        header_map.insert("x-codex-beta-features", beta_features);
    }

    let sticky_state = turn_state
        .and_then(|lock| lock.get())
        .and_then(|state| HeaderValue::from_str(state).ok());
    if let Some(sticky_state) = sticky_state {
        header_map.insert(X_CODEX_TURN_STATE_HEADER, sticky_state);
    }

    if let Some(metadata) = turn_metadata_header {
        header_map.insert(X_CODEX_TURN_METADATA_HEADER, metadata.clone());
    }

    header_map
}
/// Maps a session source to the label sent for sub-agent sessions.
///
/// Returns `None` for sources that are neither a sub-agent nor the internal memory
/// consolidation session.
fn subagent_header_value(session_source: &SessionSource) -> Option<String> {
    let label = match session_source {
        SessionSource::SubAgent(SubAgentSource::Review) => "review",
        SessionSource::SubAgent(SubAgentSource::Compact) => "compact",
        SessionSource::SubAgent(SubAgentSource::MemoryConsolidation)
        | SessionSource::Internal(InternalSessionSource::MemoryConsolidation) => {
            "memory_consolidation"
        }
        SessionSource::SubAgent(SubAgentSource::ThreadSpawn { .. }) => "collab_spawn",
        SessionSource::SubAgent(SubAgentSource::Other(label)) => label.as_str(),
        SessionSource::Cli
        | SessionSource::VSCode
        | SessionSource::Exec
        | SessionSource::Mcp
        | SessionSource::Custom(_)
        | SessionSource::Unknown => return None,
    };
    Some(label.to_string())
}
/// Returns the parent thread id, rendered as a string, for sessions spawned as a sub-agent
/// thread; every other session source yields `None`.
fn parent_thread_id_header_value(session_source: &SessionSource) -> Option<String> {
    let parent_thread_id = match session_source {
        SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
            parent_thread_id, ..
        }) => parent_thread_id,
        SessionSource::Cli
        | SessionSource::VSCode
        | SessionSource::Exec
        | SessionSource::Mcp
        | SessionSource::Custom(_)
        | SessionSource::Internal(_)
        | SessionSource::SubAgent(_)
        | SessionSource::Unknown => return None,
    };
    Some(parent_thread_id.to_string())
}
/// Capacity of the bounded channel that carries mapped response events to the consumer.
const RESPONSE_STREAM_CHANNEL_CAPACITY: usize = 1600;
/// Reason recorded on the inference trace when forwarding stops because the consumer side of
/// the response stream went away before a terminal event was seen.
const STREAM_DROPPED_REASON: &str = "response stream dropped before provider terminal event";
/// Adapts a `codex_api::ResponseStream` into the crate's `ResponseStream`.
///
/// The upstream request id is moved out of the API stream and passed separately to
/// `map_response_events`; the stream forwarded there carries `None` in its place. Returns the
/// mapped stream together with the receiver for the completed response's `LastResponse`.
fn map_response_stream(
    api_stream: codex_api::ResponseStream,
    session_telemetry: SessionTelemetry,
    inference_trace_attempt: InferenceTraceAttempt,
) -> (ResponseStream, oneshot::Receiver<LastResponse>) {
    let mut api_stream = api_stream;
    let upstream_request_id = api_stream.upstream_request_id.take();
    map_response_events(
        upstream_request_id,
        api_stream,
        session_telemetry,
        inference_trace_attempt,
    )
}
/// Spawns a task that forwards events from `api_stream` into a bounded channel and returns
/// the consumer side of that channel as a `ResponseStream`.
///
/// While forwarding, the task:
/// - collects every `OutputItemDone` item,
/// - on `Completed`, reports token usage to session telemetry, records the trace attempt as
///   completed, and sends a `LastResponse` (response id plus collected items) through the
///   returned oneshot receiver,
/// - maps API errors with `map_api_error`, records them on the trace attempt, and forwards
///   them as `Err` events,
/// - records the trace attempt as cancelled and stops when the consumer goes away before a
///   terminal event.
///
/// `upstream_request_id`, when present, is attached to feedback tags and trace records.
fn map_response_events<S>(
upstream_request_id: Option<String>,
api_stream: S,
session_telemetry: SessionTelemetry,
inference_trace_attempt: InferenceTraceAttempt,
) -> (ResponseStream, oneshot::Receiver<LastResponse>)
where
S: futures::Stream<Item = std::result::Result<ResponseEvent, ApiError>>
+ Unpin
+ Send
+ 'static,
{
let (tx_event, rx_event) =
mpsc::channel::<Result<ResponseEvent>>(RESPONSE_STREAM_CHANNEL_CAPACITY);
let (tx_last_response, rx_last_response) = oneshot::channel::<LastResponse>();
// One clone of the token is watched by the task and the other is handed to the returned
// `ResponseStream`. Presumably that stream cancels it when dropped; confirm in its `Drop` impl.
let consumer_dropped = CancellationToken::new();
let consumer_dropped_for_stream = consumer_dropped.clone();
tokio::spawn(async move {
// Limits the session-telemetry failure report to the first error event.
let mut logged_error = false;
// Wrapped in `Option` so the oneshot sender can be consumed at most once.
let mut tx_last_response = Some(tx_last_response);
let mut items_added: Vec<ResponseItem> = Vec::new();
let mut api_stream = api_stream;
let upstream_request_id = upstream_request_id.as_deref();
if let Some(upstream_request_id) = upstream_request_id {
feedback_tags!(last_model_request_id = upstream_request_id);
}
loop {
// Wait for the next upstream event, or stop as soon as the cancellation token fires.
let event = tokio::select! {
_ = consumer_dropped.cancelled() => {
inference_trace_attempt.record_cancelled(
STREAM_DROPPED_REASON,
upstream_request_id,
&items_added,
);
return;
}
event = api_stream.next() => event,
};
let Some(event) = event else {
break;
};
match event {
Ok(ResponseEvent::OutputItemDone(item)) => {
// Keep a copy: the collected items feed the trace records and the `LastResponse`.
items_added.push(item.clone());
if tx_event
.send(Ok(ResponseEvent::OutputItemDone(item)))
.await
.is_err()
{
inference_trace_attempt.record_cancelled(
STREAM_DROPPED_REASON,
upstream_request_id,
&items_added,
);
return;
}
}
Ok(ResponseEvent::Completed {
response_id,
token_usage,
end_turn,
}) => {
feedback_tags!(last_model_response_id = &response_id);
if let Some(usage) = &token_usage {
session_telemetry.sse_event_completed(
usage.input_tokens,
usage.output_tokens,
Some(usage.cached_input_tokens),
Some(usage.reasoning_output_tokens),
usage.total_tokens,
);
}
inference_trace_attempt.record_completed(
&response_id,
upstream_request_id,
&token_usage,
&items_added,
);
// The send result is ignored: the receiving side may already have been dropped.
if let Some(sender) = tx_last_response.take() {
let _ = sender.send(LastResponse {
response_id: response_id.clone(),
items_added: std::mem::take(&mut items_added),
});
}
// Completion is already recorded above, so a vanished consumer is not recorded as a
// cancellation here.
if tx_event
.send(Ok(ResponseEvent::Completed {
response_id,
token_usage,
end_turn,
}))
.await
.is_err()
{
return;
}
}
Ok(event) => {
if tx_event.send(Ok(event)).await.is_err() {
inference_trace_attempt.record_cancelled(
STREAM_DROPPED_REASON,
upstream_request_id,
&items_added,
);
return;
}
}
Err(err) => {
let response_debug_context =
extract_response_debug_context_from_api_error(&err);
// Prefer the request id known up front; otherwise use the one carried by the error.
let upstream_request_id =
upstream_request_id.or(response_debug_context.request_id.as_deref());
if let Some(upstream_request_id) = upstream_request_id {
feedback_tags!(last_model_request_id = upstream_request_id);
}
let mapped = map_api_error(err);
inference_trace_attempt.record_failed(
&mapped,
upstream_request_id,
&items_added,
);
if !logged_error {
// NOTE(review): `see_event_completed_failed` looks like a typo of `sse_…`; the method is
// defined outside this block, so it is left as is.
session_telemetry.see_event_completed_failed(&mapped);
logged_error = true;
}
// The loop keeps reading after an error; it only stops here if the consumer is gone.
if tx_event.send(Err(mapped)).await.is_err() {
return;
}
}
}
}
// NOTE(review): this runs whenever the upstream stream ends, including after a `Completed`
// event or an error has already been recorded above. Presumably `InferenceTraceAttempt`
// keeps only the first terminal record; confirm.
inference_trace_attempt.record_failed(
"stream closed before response.completed",
upstream_request_id,
&items_added,
);
});
(
ResponseStream {
rx_event,
consumer_dropped: consumer_dropped_for_stream,
},
rx_last_response,
)
}
/// Identifies the unauthorized-recovery step that `handle_unauthorized` ran successfully.
///
/// The values are carried into the next request attempt's telemetry through
/// `PendingUnauthorizedRetry::from_recovery`.
#[derive(Clone, Copy, Debug)]
struct UnauthorizedRecoveryExecution {
/// Recovery mode name, as reported by the recovery's `mode_name()`.
mode: &'static str,
/// Recovery step name, as reported by the recovery's `step_name()`.
phase: &'static str,
}
/// Describes whether the next request attempt is a retry that follows an unauthorized
/// recovery step.
///
/// The `Default` value means "not a retry": the flag is `false` and both names are `None`.
#[derive(Clone, Copy, Debug, Default)]
struct PendingUnauthorizedRetry {
/// `true` when the attempt follows a recovery step that ran after a 401.
retry_after_unauthorized: bool,
/// Mode name of the recovery step that preceded the retry, if any.
recovery_mode: Option<&'static str>,
/// Step name of the recovery step that preceded the retry, if any.
recovery_phase: Option<&'static str>,
}
impl PendingUnauthorizedRetry {
fn from_recovery(recovery: UnauthorizedRecoveryExecution) -> Self {
Self {
retry_after_unauthorized: true,
recovery_mode: Some(recovery.mode),
recovery_phase: Some(recovery.phase),
}
}
}
/// Auth-related facts about one request attempt, attached to the telemetry emitted for it.
#[derive(Clone, Copy, Debug, Default)]
struct AuthRequestTelemetryContext {
/// Label for the auth mode in use (`"ApiKey"` or `"Chatgpt"`), or `None` when unknown.
auth_mode: Option<&'static str>,
/// Whether an auth header is attached, as reported by `auth_header_telemetry`.
auth_header_attached: bool,
/// Name of the attached auth header, as reported by `auth_header_telemetry`.
auth_header_name: Option<&'static str>,
/// `true` when this attempt is a retry that follows an unauthorized recovery step.
retry_after_unauthorized: bool,
/// Mode name of the recovery step that preceded this attempt, if any.
recovery_mode: Option<&'static str>,
/// Step name of the recovery step that preceded this attempt, if any.
recovery_phase: Option<&'static str>,
}
impl AuthRequestTelemetryContext {
fn new(
auth_mode: Option<AuthMode>,
api_auth: &dyn AuthProvider,
retry: PendingUnauthorizedRetry,
) -> Self {
let auth_telemetry = auth_header_telemetry(api_auth);
Self {
auth_mode: auth_mode.map(|mode| match mode {
AuthMode::ApiKey => "ApiKey",
AuthMode::Chatgpt | AuthMode::ChatgptAuthTokens | AuthMode::AgentIdentity => {
"Chatgpt"
}
}),
auth_header_attached: auth_telemetry.attached,
auth_header_name: auth_telemetry.name,
retry_after_unauthorized: retry.retry_after_unauthorized,
recovery_mode: retry.recovery_mode,
recovery_phase: retry.recovery_phase,
}
}
}
/// Arguments for `ModelClientSession::websocket_connection`, bundled into one value.
struct WebsocketConnectParams<'a> {
/// Telemetry sink passed through to the connect call.
session_telemetry: &'a SessionTelemetry,
/// Provider to connect to; consumed when a new connection is opened.
api_provider: codex_api::Provider,
/// Auth provider used for the handshake; consumed when a new connection is opened.
api_auth: SharedAuthProvider,
/// Raw per-turn metadata passed through to the connect call.
turn_metadata_header: Option<&'a str>,
/// Request options; `websocket_connection` reads only `turn_state` from them.
options: &'a ApiResponsesOptions,
/// Auth telemetry context for the connect attempt.
auth_context: AuthRequestTelemetryContext,
/// Route telemetry for the connect attempt.
request_route_telemetry: RequestRouteTelemetry,
}
/// Handles a 401 response by running the next unauthorized-recovery step, if one is available.
///
/// On success the executed step's mode and phase are returned and the caller should retry the
/// API call. Every outcome is reported both to session telemetry and to the feedback tags.
///
/// # Errors
///
/// - `CodexErr::RefreshTokenFailed` when the recovery step fails permanently.
/// - `CodexErr::Io` when the recovery step fails transiently.
/// - The original transport error, mapped through `map_api_error`, when there is no recovery
///   or it has no further step to run.
async fn handle_unauthorized(
transport: TransportError,
auth_recovery: &mut Option<UnauthorizedRecovery>,
session_telemetry: &SessionTelemetry,
) -> Result<UnauthorizedRecoveryExecution> {
// Request id, cf-ray and auth error details taken from the 401 response, for telemetry.
let debug = extract_response_debug_context(&transport);
if let Some(recovery) = auth_recovery
&& recovery.has_next()
{
// Capture the names before `next()` runs so telemetry names the step being executed.
let mode = recovery.mode_name();
let phase = recovery.step_name();
return match recovery.next().await {
Ok(step_result) => {
session_telemetry.record_auth_recovery(
mode,
phase,
"recovery_succeeded",
debug.request_id.as_deref(),
debug.cf_ray.as_deref(),
debug.auth_error.as_deref(),
debug.auth_error_code.as_deref(),
/*recovery_reason*/ None,
step_result.auth_state_changed(),
);
emit_feedback_auth_recovery_tags(
mode,
phase,
"recovery_succeeded",
debug.request_id.as_deref(),
debug.cf_ray.as_deref(),
debug.auth_error.as_deref(),
debug.auth_error_code.as_deref(),
);
Ok(UnauthorizedRecoveryExecution { mode, phase })
}
Err(RefreshTokenError::Permanent(failed)) => {
session_telemetry.record_auth_recovery(
mode,
phase,
"recovery_failed_permanent",
debug.request_id.as_deref(),
debug.cf_ray.as_deref(),
debug.auth_error.as_deref(),
debug.auth_error_code.as_deref(),
/*recovery_reason*/ None,
/*auth_state_changed*/ None,
);
emit_feedback_auth_recovery_tags(
mode,
phase,
"recovery_failed_permanent",
debug.request_id.as_deref(),
debug.cf_ray.as_deref(),
debug.auth_error.as_deref(),
debug.auth_error_code.as_deref(),
);
Err(CodexErr::RefreshTokenFailed(failed))
}
Err(RefreshTokenError::Transient(other)) => {
session_telemetry.record_auth_recovery(
mode,
phase,
"recovery_failed_transient",
debug.request_id.as_deref(),
debug.cf_ray.as_deref(),
debug.auth_error.as_deref(),
debug.auth_error_code.as_deref(),
/*recovery_reason*/ None,
/*auth_state_changed*/ None,
);
emit_feedback_auth_recovery_tags(
mode,
phase,
"recovery_failed_transient",
debug.request_id.as_deref(),
debug.cf_ray.as_deref(),
debug.auth_error.as_deref(),
debug.auth_error_code.as_deref(),
);
Err(CodexErr::Io(other))
}
};
}
// No step was run: either there is no recovery at all, or it has nothing left to try.
let (mode, phase, recovery_reason) = match auth_recovery.as_ref() {
Some(recovery) => (
recovery.mode_name(),
recovery.step_name(),
Some(recovery.unavailable_reason()),
),
None => ("none", "none", Some("auth_manager_missing")),
};
session_telemetry.record_auth_recovery(
mode,
phase,
"recovery_not_run",
debug.request_id.as_deref(),
debug.cf_ray.as_deref(),
debug.auth_error.as_deref(),
debug.auth_error_code.as_deref(),
recovery_reason,
/*auth_state_changed*/ None,
);
emit_feedback_auth_recovery_tags(
mode,
phase,
"recovery_not_run",
debug.request_id.as_deref(),
debug.cf_ray.as_deref(),
debug.auth_error.as_deref(),
debug.auth_error_code.as_deref(),
);
Err(map_api_error(ApiError::Transport(transport)))
}
fn api_error_http_status(error: &ApiError) -> Option<u16> {
match error {
ApiError::Transport(TransportError::Http { status, .. }) => Some(status.as_u16()),
_ => None,
}
}
/// Telemetry adapter that implements the request, SSE, and websocket telemetry callbacks by
/// forwarding measurements to the session telemetry and the feedback request tags.
struct ApiTelemetry {
/// Session telemetry sink that request, SSE, and websocket measurements are recorded on.
session_telemetry: SessionTelemetry,
/// Auth details for the request (header attachment, auth mode, recovery mode/phase, and
/// whether this is a retry after an unauthorized response).
auth_context: AuthRequestTelemetryContext,
/// Route information; its `endpoint` is attached to request telemetry and feedback tags.
request_route_telemetry: RequestRouteTelemetry,
/// Auth environment telemetry passed along when feedback request tags are emitted.
auth_env_telemetry: AuthEnvTelemetry,
}
impl ApiTelemetry {
/// Builds an `ApiTelemetry` by storing each argument in the field of the same name.
fn new(
session_telemetry: SessionTelemetry,
auth_context: AuthRequestTelemetryContext,
request_route_telemetry: RequestRouteTelemetry,
auth_env_telemetry: AuthEnvTelemetry,
) -> Self {
Self {
session_telemetry,
auth_context,
request_route_telemetry,
auth_env_telemetry,
}
}
}
impl RequestTelemetry for ApiTelemetry {
fn on_request(
&self,
attempt: u64,
status: Option<HttpStatusCode>,
error: Option<&TransportError>,
duration: Duration,
) {
let error_message = error.map(telemetry_transport_error_message);
let status = status.map(|s| s.as_u16());
let debug = error
.map(extract_response_debug_context)
.unwrap_or_default();
self.session_telemetry.record_api_request(
attempt,
status,
error_message.as_deref(),
duration,
self.auth_context.auth_header_attached,
self.auth_context.auth_header_name,
self.auth_context.retry_after_unauthorized,
self.auth_context.recovery_mode,
self.auth_context.recovery_phase,
self.request_route_telemetry.endpoint,
debug.request_id.as_deref(),
debug.cf_ray.as_deref(),
debug.auth_error.as_deref(),
debug.auth_error_code.as_deref(),
);
emit_feedback_request_tags_with_auth_env(
&FeedbackRequestTags {
endpoint: self.request_route_telemetry.endpoint,
auth_header_attached: self.auth_context.auth_header_attached,
auth_header_name: self.auth_context.auth_header_name,
auth_mode: self.auth_context.auth_mode,
auth_retry_after_unauthorized: Some(self.auth_context.retry_after_unauthorized),
auth_recovery_mode: self.auth_context.recovery_mode,
auth_recovery_phase: self.auth_context.recovery_phase,
auth_connection_reused: None,
auth_request_id: debug.request_id.as_deref(),
auth_cf_ray: debug.cf_ray.as_deref(),
auth_error: debug.auth_error.as_deref(),
auth_error_code: debug.auth_error_code.as_deref(),
auth_recovery_followup_success: self
.auth_context
.retry_after_unauthorized
.then_some(error.is_none()),
auth_recovery_followup_status: self
.auth_context
.retry_after_unauthorized
.then_some(status)
.flatten(),
},
&self.auth_env_telemetry,
);
}
}
impl SseTelemetry for ApiTelemetry {
/// Forwards the result of one SSE poll, together with how long it took, to the session
/// telemetry's SSE event log.
///
/// `result` is passed through untouched; this adapter does not inspect it.
fn on_sse_poll(
&self,
result: &std::result::Result<
Option<std::result::Result<Event, EventStreamError<TransportError>>>,
tokio::time::error::Elapsed,
>,
duration: Duration,
) {
self.session_telemetry.log_sse_event(result, duration);
}
}
impl WebsocketTelemetry for ApiTelemetry {
fn on_ws_request(&self, duration: Duration, error: Option<&ApiError>, connection_reused: bool) {
let error_message = error.map(telemetry_api_error_message);
let status = error.and_then(api_error_http_status);
let debug = error
.map(extract_response_debug_context_from_api_error)
.unwrap_or_default();
self.session_telemetry.record_websocket_request(
duration,
error_message.as_deref(),
connection_reused,
);
emit_feedback_request_tags_with_auth_env(
&FeedbackRequestTags {
endpoint: self.request_route_telemetry.endpoint,
auth_header_attached: self.auth_context.auth_header_attached,
auth_header_name: self.auth_context.auth_header_name,
auth_mode: self.auth_context.auth_mode,
auth_retry_after_unauthorized: Some(self.auth_context.retry_after_unauthorized),
auth_recovery_mode: self.auth_context.recovery_mode,
auth_recovery_phase: self.auth_context.recovery_phase,
auth_connection_reused: Some(connection_reused),
auth_request_id: debug.request_id.as_deref(),
auth_cf_ray: debug.cf_ray.as_deref(),
auth_error: debug.auth_error.as_deref(),
auth_error_code: debug.auth_error_code.as_deref(),
auth_recovery_followup_success: self
.auth_context
.retry_after_unauthorized
.then_some(error.is_none()),
auth_recovery_followup_status: self
.auth_context
.retry_after_unauthorized
.then_some(status)
.flatten(),
},
&self.auth_env_telemetry,
);
}
fn on_ws_event(
&self,
result: &std::result::Result<Option<std::result::Result<Message, Error>>, ApiError>,
duration: Duration,
) {
self.session_telemetry
.record_websocket_event(result, duration);
}
}
// Unit tests for this module are compiled only in test builds and are loaded from
// `client_tests.rs` via the `#[path]` attribute.
#[cfg(test)]
#[path = "client_tests.rs"]
mod tests;