Files
codex/codex-rs/core/src/client.rs
pakrym-oai 2070d5bfd3 [codex] Add response.processed websocket request (#21284)
## Summary

- Add a `response.processed` websocket request payload and sender for
Responses API websockets.
- Send `response.processed` from `try_run_sampling_request` after a
response completes, local turn processing succeeds, and the
session-owned feature flag is enabled.
- Add websocket coverage for both enabled and disabled feature-flag
behavior.

## Validation

- `just fmt`
- `cargo test -p codex-core response_processed`
- `cargo test -p codex-api responses_websocket`
- `cargo test -p codex-features
responses_websocket_response_processed_is_under_development`
- `git diff --check`
- `just fix -p codex-api -p codex-core -p codex-features`
- `git diff --check origin/main...HEAD`
2026-05-06 09:58:46 -07:00

2177 lines
83 KiB
Rust

//! Session- and turn-scoped helpers for talking to model provider APIs.
//!
//! `ModelClient` is intended to live for the lifetime of a Codex session and holds the stable
//! configuration and state needed to talk to a provider (auth, provider selection, conversation id,
//! and transport fallback state).
//!
//! Per-turn settings (model selection, reasoning controls, telemetry context, and turn metadata)
//! are passed explicitly to streaming and unary methods so that the turn lifetime is visible at the
//! call site.
//!
//! A [`ModelClientSession`] is created per turn and is used to stream one or more Responses API
//! requests during that turn. It caches a Responses WebSocket connection (opened lazily) and stores
//! per-turn state such as the `x-codex-turn-state` token used for sticky routing.
//!
//! WebSocket prewarm is a v2-only `response.create` with `generate=false`; it waits for completion
//! so the next request can reuse the same connection and `previous_response_id`.
//!
//! Turn execution performs prewarm as a best-effort step before the first stream request so the
//! subsequent request can reuse the same connection.
//!
//! ## Retry-Budget Tradeoff
//!
//! WebSocket prewarm is treated as the first websocket connection attempt for a turn. If it
//! fails, normal stream retry/fallback logic handles recovery on the same turn.
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::OnceLock;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use codex_api::ApiError;
use codex_api::AuthProvider;
use codex_api::CompactClient as ApiCompactClient;
use codex_api::CompactionInput as ApiCompactionInput;
use codex_api::Compression;
use codex_api::MemoriesClient as ApiMemoriesClient;
use codex_api::MemorySummarizeInput as ApiMemorySummarizeInput;
use codex_api::MemorySummarizeOutput as ApiMemorySummarizeOutput;
use codex_api::Provider as ApiProvider;
use codex_api::RawMemory as ApiRawMemory;
use codex_api::RealtimeCallClient as ApiRealtimeCallClient;
use codex_api::RealtimeSessionConfig as ApiRealtimeSessionConfig;
use codex_api::Reasoning;
use codex_api::RequestTelemetry;
use codex_api::ReqwestTransport;
use codex_api::ResponseCreateWsRequest;
use codex_api::ResponsesApiRequest;
use codex_api::ResponsesClient as ApiResponsesClient;
use codex_api::ResponsesOptions as ApiResponsesOptions;
use codex_api::ResponsesWebsocketClient as ApiWebSocketResponsesClient;
use codex_api::ResponsesWebsocketConnection as ApiWebSocketConnection;
use codex_api::ResponsesWsRequest;
use codex_api::SharedAuthProvider;
use codex_api::SseTelemetry;
use codex_api::TransportError;
use codex_api::WebsocketTelemetry;
use codex_api::auth_header_telemetry;
use codex_api::build_session_headers;
use codex_api::create_text_param_for_request;
use codex_api::response_create_client_metadata;
use codex_app_server_protocol::AuthMode;
use codex_login::AuthManager;
use codex_login::CodexAuth;
use codex_login::RefreshTokenError;
use codex_login::UnauthorizedRecovery;
use codex_login::default_client::build_reqwest_client;
use codex_otel::SessionTelemetry;
use codex_otel::current_span_w3c_trace_context;
use codex_protocol::SessionId;
use codex_protocol::ThreadId;
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
use codex_protocol::config_types::Verbosity as VerbosityConfig;
use codex_protocol::models::ResponseItem;
use codex_protocol::openai_models::ModelInfo;
use codex_protocol::openai_models::ReasoningEffort as ReasoningEffortConfig;
use codex_protocol::protocol::InternalSessionSource;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use codex_protocol::protocol::W3cTraceContext;
use codex_rollout_trace::CompactionTraceContext;
use codex_rollout_trace::InferenceTraceAttempt;
use codex_rollout_trace::InferenceTraceContext;
use codex_tools::create_tools_json_for_responses_api;
use eventsource_stream::Event;
use eventsource_stream::EventStreamError;
use futures::StreamExt;
use http::HeaderMap as ApiHeaderMap;
use http::HeaderValue;
use http::StatusCode as HttpStatusCode;
use reqwest::StatusCode;
use std::time::Duration;
use std::time::Instant;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::oneshot::error::TryRecvError;
use tokio_tungstenite::tungstenite::Error;
use tokio_tungstenite::tungstenite::Message;
use tokio_util::sync::CancellationToken;
use tracing::debug;
use tracing::instrument;
use tracing::trace;
use tracing::warn;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::client_common::ResponseStream;
use crate::feedback_tags;
use crate::flags::CODEX_RS_SSE_FIXTURE;
use crate::util::emit_feedback_auth_recovery_tags;
use codex_api::map_api_error;
use codex_feedback::FeedbackRequestTags;
use codex_feedback::emit_feedback_request_tags_with_auth_env;
use codex_login::auth_env_telemetry::AuthEnvTelemetry;
use codex_login::auth_env_telemetry::collect_auth_env_telemetry;
use codex_model_provider::SharedModelProvider;
use codex_model_provider::create_model_provider;
#[cfg(test)]
use codex_model_provider_info::DEFAULT_WEBSOCKET_CONNECT_TIMEOUT_MS;
use codex_model_provider_info::ModelProviderInfo;
use codex_model_provider_info::WireApi;
use codex_protocol::error::CodexErr;
use codex_protocol::error::Result;
use codex_response_debug_context::extract_response_debug_context;
use codex_response_debug_context::extract_response_debug_context_from_api_error;
use codex_response_debug_context::telemetry_api_error_message;
use codex_response_debug_context::telemetry_transport_error_message;
/// Name of the `OpenAI-Beta` header; the websocket handshake sets it to the Responses
/// WebSocket v2 beta value.
pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
/// Header (and websocket client-metadata key) carrying the installation id.
pub const X_CODEX_INSTALLATION_ID_HEADER: &str = "x-codex-installation-id";
/// Header used for the per-turn sticky-routing token described in the module docs.
pub const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
/// Header (and websocket client-metadata key) carrying per-turn metadata.
pub const X_CODEX_TURN_METADATA_HEADER: &str = "x-codex-turn-metadata";
/// Header (and websocket client-metadata key) carrying the parent thread id, when the session
/// source provides one.
pub const X_CODEX_PARENT_THREAD_ID_HEADER: &str = "x-codex-parent-thread-id";
/// Header (and websocket client-metadata key) carrying the `{thread_id}:{window_generation}`
/// window id.
pub const X_CODEX_WINDOW_ID_HEADER: &str = "x-codex-window-id";
/// Header set to `"true"` when the session source is internal memory consolidation.
pub const X_OPENAI_MEMGEN_REQUEST_HEADER: &str = "x-openai-memgen-request";
/// Header (and websocket client-metadata key) identifying a sub-agent session source.
pub const X_OPENAI_SUBAGENT_HEADER: &str = "x-openai-subagent";
/// Header set to `"true"` on the websocket handshake when timing metrics are requested.
pub const X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER: &str =
"x-responsesapi-include-timing-metrics";
/// Value sent under [`OPENAI_BETA_HEADER`] on the websocket handshake.
const RESPONSES_WEBSOCKETS_V2_BETA_HEADER_VALUE: &str = "responses_websockets=2026-02-06";
// Endpoint paths used as route labels in request telemetry.
// NOTE(review): `RESPONSES_ENDPOINT` is not referenced in the visible part of the file;
// presumably the streaming paths use it — confirm.
const RESPONSES_ENDPOINT: &str = "/responses";
const RESPONSES_COMPACT_ENDPOINT: &str = "/responses/compact";
const MEMORIES_SUMMARIZE_ENDPOINT: &str = "/memories/trace_summarize";
/// Test-only websocket connect timeout derived from the provider-info default.
#[cfg(test)]
pub(crate) const WEBSOCKET_CONNECT_TIMEOUT: Duration =
Duration::from_millis(DEFAULT_WEBSOCKET_CONNECT_TIMEOUT_MS);
/// Per-call reasoning and service-tier settings for a conversation compaction request.
///
/// `ModelClient::compact_conversation_history` forwards these values into the Responses request
/// it builds for the Compact endpoint.
pub(crate) struct CompactConversationRequestSettings {
/// Requested reasoning effort; when `None`, request building falls back to the model's default
/// reasoning level.
pub(crate) effort: Option<ReasoningEffortConfig>,
/// Reasoning summary mode; the `None` variant results in no summary being requested.
pub(crate) summary: ReasoningSummaryConfig,
/// Optional service tier, forwarded unchanged into the request.
pub(crate) service_tier: Option<String>,
}
/// Session-scoped state shared by all [`ModelClient`] clones.
///
/// This is intentionally kept minimal so `ModelClient` does not need to hold a full `Config`. Most
/// configuration is per turn and is passed explicitly to streaming/unary methods.
#[derive(Debug)]
struct ModelClientState {
/// Session identifier sent in session headers.
session_id: SessionId,
/// Thread identifier; sent in session headers, used as the prompt cache key, and used as the
/// prefix of the window id.
thread_id: ThreadId,
/// Counter combined with `thread_id` to form the window id; changing it also clears the cached
/// websocket session.
window_generation: AtomicU64,
/// Installation id sent as a header and as request/client metadata.
installation_id: String,
/// Resolved model provider used to obtain auth and API provider configuration.
provider: SharedModelProvider,
/// Auth environment telemetry collected once at construction and attached to request telemetry.
auth_env_telemetry: AuthEnvTelemetry,
/// Origin of the session; drives the sub-agent, parent-thread, and memgen headers.
session_source: SessionSource,
/// Session-level verbosity override; ignored (with a warning) for models without verbosity
/// support.
model_verbosity: Option<VerbosityConfig>,
// NOTE(review): not read in the visible part of the file; presumably consulted when choosing
// request compression for streaming requests — confirm.
enable_request_compression: bool,
/// When true, the websocket handshake asks the server to include timing metrics.
include_timing_metrics: bool,
/// Optional beta-features header value passed to the Responses header builder.
beta_features_header: Option<String>,
/// Set once the session falls back to HTTP; disables the websocket transport afterwards.
disable_websockets: AtomicBool,
/// Websocket session handed to the next `ModelClientSession` and stored back when that
/// session is dropped.
cached_websocket_session: StdMutex<WebsocketSession>,
}
/// Resolved API client setup for a single request attempt.
///
/// Keeping this as a single bundle ensures prewarm and normal request paths
/// share the same auth/provider setup flow.
struct CurrentClientSetup {
/// Current session auth, if any; used here to report the auth mode in telemetry.
auth: Option<CodexAuth>,
/// Provider configuration handed to the API clients.
api_provider: ApiProvider,
/// Auth provider handed to the API clients to attach auth headers.
api_auth: SharedAuthProvider,
}
/// Identifies the API route a request targets so telemetry can be labelled with it.
#[derive(Clone, Copy)]
struct RequestRouteTelemetry {
    /// Endpoint path reported alongside request telemetry.
    endpoint: &'static str,
}

impl RequestRouteTelemetry {
    /// Creates route telemetry for the given static endpoint path.
    fn for_endpoint(endpoint: &'static str) -> Self {
        RequestRouteTelemetry { endpoint }
    }
}
/// A session-scoped client for model-provider API calls.
///
/// This holds configuration and state that should be shared across turns within a Codex session
/// (auth, provider selection, thread id, and transport fallback state).
///
/// WebSocket fallback is session-scoped: once a turn activates the HTTP fallback, subsequent turns
/// will also use HTTP for the remainder of the session.
///
/// Turn-scoped settings (model selection, reasoning controls, telemetry context, and turn
/// metadata) are passed explicitly to the relevant methods to keep turn lifetime visible at the
/// call site.
#[derive(Debug, Clone)]
pub struct ModelClient {
/// Shared session state; cloning the client only clones this `Arc`, so all clones observe the
/// same state.
state: Arc<ModelClientState>,
}
/// A turn-scoped streaming session created from a [`ModelClient`].
///
/// The session establishes a Responses WebSocket connection lazily and reuses it across multiple
/// requests within the turn. It also caches per-turn state:
///
/// - The last full request, so subsequent calls can reuse incremental websocket request payloads
/// only when the current request is an incremental extension of the previous one.
/// - The `x-codex-turn-state` sticky-routing token, which must be replayed for all requests within
/// the same turn.
///
/// Create a fresh `ModelClientSession` for each Codex turn. Reusing it across turns would replay
/// the previous turn's sticky-routing token into the next turn, which violates the client/server
/// contract and can cause routing bugs.
pub struct ModelClientSession {
/// Clone of the session-scoped client this turn session was created from.
client: ModelClient,
/// Websocket state taken from the client's cache when the session is created and stored back
/// into that cache when the session is dropped.
websocket_session: WebsocketSession,
/// Turn state for sticky routing.
///
/// This is an `OnceLock` that stores the turn state value received from the server
/// on turn start via the `x-codex-turn-state` response header. Once set, this value
/// should be sent back to the server in the `x-codex-turn-state` request header for
/// all subsequent requests within the same turn to maintain sticky routing.
///
/// This is a contract between the client and server: we receive it at turn start,
/// keep sending it unchanged between turn requests (e.g., for retries, incremental
/// appends, or continuation requests), and must not send it between different turns.
turn_state: Arc<OnceLock<String>>,
}
/// Summary of the most recently completed websocket response, used when computing incremental
/// request payloads.
#[derive(Debug, Clone)]
struct LastResponse {
/// Identifier of the completed response.
response_id: String,
/// Items the server added in that response; treated as part of the known baseline so they are
/// not resent.
items_added: Vec<ResponseItem>,
}
/// Websocket connection state that is moved between the client cache and a turn session.
#[derive(Debug, Default)]
struct WebsocketSession {
/// Open websocket connection, if one has been established.
connection: Option<ApiWebSocketConnection>,
/// Last full request, compared against the next request to decide whether an incremental
/// payload can be sent.
last_request: Option<ResponsesApiRequest>,
/// One-shot receiver for the previous response's summary; it is taken and polled without
/// blocking.
last_response_rx: Option<oneshot::Receiver<LastResponse>>,
/// Whether the current connection counts as reused; behind a mutex so it can be updated
/// through `&self`.
connection_reused: StdMutex<bool>,
}
impl WebsocketSession {
    /// Records whether the current connection should be reported as reused.
    ///
    /// A poisoned lock is recovered rather than propagated, since the guarded value is a plain
    /// flag that cannot be left in an inconsistent state.
    fn set_connection_reused(&self, connection_reused: bool) {
        let mut flag = self
            .connection_reused
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        *flag = connection_reused;
    }

    /// Returns the current value of the connection-reused flag, recovering from lock poisoning.
    fn connection_reused(&self) -> bool {
        let flag = self
            .connection_reused
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        *flag
    }
}
/// Outcome of attempting to stream a request over the websocket transport.
// NOTE(review): the producer and consumer of this enum are outside the visible part of the file.
enum WebsocketStreamOutcome {
/// The websocket request produced a response stream.
Stream(ResponseStream),
/// The caller should retry the request over HTTP instead.
FallbackToHttp,
}
/// Result of opening a WebRTC Realtime call.
///
/// The SDP answer goes back to the client. The call id and auth headers stay on the server so the
/// ordinary Realtime WebSocket machinery can join the same in-progress call as a sideband
/// controller.
pub(crate) struct RealtimeWebrtcCallStart {
/// SDP answer taken from the call-create response.
pub(crate) sdp: String,
/// Call id taken from the call-create response.
pub(crate) call_id: String,
/// Caller-supplied extra headers plus the auth headers used to create the call.
pub(crate) sideband_headers: ApiHeaderMap,
}
/// Produces the auth headers for the sideband WebSocket join from the same API-auth material
/// that created the WebRTC call.
///
/// API-key sessions send that API bearer. ChatGPT-auth sessions send their bearer plus account
/// id; transceiver is responsible for accepting that same call-create identity on the direct
/// `api.openai.com` sideband path.
fn sideband_websocket_auth_headers(api_auth: &dyn AuthProvider) -> ApiHeaderMap {
    let mut auth_headers = ApiHeaderMap::new();
    // The auth provider decides which headers to attach; start from an empty map so nothing else
    // leaks into the sideband handshake.
    api_auth.add_auth_headers(&mut auth_headers);
    auth_headers
}
impl ModelClient {
/// Builds a session-scoped `ModelClient`.
///
/// Every argument is expected to stay stable for the lifetime of a Codex session. Per-turn
/// values are passed explicitly to [`ModelClientSession::stream`] and the other turn-scoped
/// methods.
///
/// The client starts with window generation `0`, websockets not disabled, and an empty cached
/// websocket session.
#[allow(clippy::too_many_arguments)]
pub fn new(
    auth_manager: Option<Arc<AuthManager>>,
    session_id: SessionId,
    thread_id: ThreadId,
    installation_id: String,
    provider_info: ModelProviderInfo,
    session_source: SessionSource,
    model_verbosity: Option<VerbosityConfig>,
    enable_request_compression: bool,
    include_timing_metrics: bool,
    beta_features_header: Option<String>,
) -> Self {
    let provider = create_model_provider(provider_info, auth_manager);

    // Auth environment telemetry is collected once up front and reused for every request.
    let codex_api_key_env_enabled = provider
        .auth_manager()
        .is_some_and(|manager| manager.codex_api_key_env_enabled());
    let auth_env_telemetry = collect_auth_env_telemetry(provider.info(), codex_api_key_env_enabled);

    let state = ModelClientState {
        session_id,
        thread_id,
        window_generation: AtomicU64::new(0),
        installation_id,
        provider,
        auth_env_telemetry,
        session_source,
        model_verbosity,
        enable_request_compression,
        include_timing_metrics,
        beta_features_header,
        disable_websockets: AtomicBool::new(false),
        cached_websocket_session: StdMutex::new(WebsocketSession::default()),
    };
    Self {
        state: Arc::new(state),
    }
}
/// Starts a new turn-scoped streaming session.
///
/// No network I/O happens here; a websocket is opened lazily when the first stream request is
/// issued. The session takes over whatever websocket session is currently cached on the client
/// and begins with an unset turn state.
pub fn new_session(&self) -> ModelClientSession {
    let client = self.clone();
    let websocket_session = self.take_cached_websocket_session();
    ModelClientSession {
        client,
        websocket_session,
        turn_state: Arc::new(OnceLock::new()),
    }
}
/// Returns the auth manager held by the session's model provider, if any.
pub(crate) fn auth_manager(&self) -> Option<Arc<AuthManager>> {
    let provider = &self.state.provider;
    provider.auth_manager()
}
/// Sets the window generation to an explicit value and discards the cached websocket session.
pub(crate) fn set_window_generation(&self, window_generation: u64) {
    let generation = &self.state.window_generation;
    generation.store(window_generation, Ordering::Relaxed);
    // The window id changes with the generation, so the cached session is replaced with a fresh one.
    self.store_cached_websocket_session(WebsocketSession::default());
}
/// Increments the window generation by one and discards the cached websocket session.
pub(crate) fn advance_window_generation(&self) {
    let generation = &self.state.window_generation;
    generation.fetch_add(1, Ordering::Relaxed);
    // The window id changes with the generation, so the cached session is replaced with a fresh one.
    self.store_cached_websocket_session(WebsocketSession::default());
}
/// Formats the current window id as `{thread_id}:{window_generation}`.
fn current_window_id(&self) -> String {
    let generation = self.state.window_generation.load(Ordering::Relaxed);
    format!("{}:{}", self.state.thread_id, generation)
}
/// Moves the cached websocket session out of the client, leaving a default one in its place.
///
/// A poisoned lock is recovered rather than propagated.
fn take_cached_websocket_session(&self) -> WebsocketSession {
    let cache = &self.state.cached_websocket_session;
    let mut guard = cache
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    std::mem::take(&mut *guard)
}
/// Replaces the cached websocket session with `websocket_session`, dropping the previous one.
///
/// A poisoned lock is recovered rather than propagated.
fn store_cached_websocket_session(&self, websocket_session: WebsocketSession) {
    let cache = &self.state.cached_websocket_session;
    let mut guard = cache
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    *guard = websocket_session;
}
/// Switches the session to the HTTP transport and clears the cached websocket session.
///
/// Returns `true` only for the call that actually flips the fallback flag while websockets were
/// still enabled; that call also logs a warning and bumps the fallback counter. The cached
/// websocket session is cleared on every call.
pub(crate) fn force_http_fallback(
    &self,
    session_telemetry: &SessionTelemetry,
    _model_info: &ModelInfo,
) -> bool {
    // `&&` short-circuits: the flag is only swapped when websockets are currently enabled.
    let activated = self.responses_websocket_enabled()
        && !self.state.disable_websockets.swap(true, Ordering::Relaxed);

    if activated {
        warn!("falling back to HTTP");
        session_telemetry.counter(
            "codex.transport.fallback_to_http",
            /*inc*/ 1,
            &[("from_wire_api", "responses_websocket")],
        );
    }

    self.store_cached_websocket_session(WebsocketSession::default());
    activated
}
/// Compacts the current conversation history using the Compact endpoint.
///
/// This is a unary call (no streaming) that returns a new list of
/// `ResponseItem`s representing the compacted transcript.
///
/// The model selection and telemetry context are passed explicitly to keep `ModelClient`
/// session-scoped.
pub(crate) async fn compact_conversation_history(
&self,
prompt: &Prompt,
model_info: &ModelInfo,
settings: CompactConversationRequestSettings,
session_telemetry: &SessionTelemetry,
compaction_trace: &CompactionTraceContext,
) -> Result<Vec<ResponseItem>> {
if prompt.input.is_empty() {
return Ok(Vec::new());
}
let client_setup = self.current_client_setup().await?;
let transport = ReqwestTransport::new(build_reqwest_client());
let request_telemetry = Self::build_request_telemetry(
session_telemetry,
AuthRequestTelemetryContext::new(
client_setup.auth.as_ref().map(CodexAuth::auth_mode),
client_setup.api_auth.as_ref(),
PendingUnauthorizedRetry::default(),
),
RequestRouteTelemetry::for_endpoint(RESPONSES_COMPACT_ENDPOINT),
self.state.auth_env_telemetry.clone(),
);
let request = self.build_responses_request(
&client_setup.api_provider,
prompt,
model_info,
settings.effort,
settings.summary,
settings.service_tier,
)?;
let ResponsesApiRequest {
model,
instructions,
input,
tools,
parallel_tool_calls,
reasoning,
service_tier,
prompt_cache_key,
text,
..
} = request;
let client =
ApiCompactClient::new(transport, client_setup.api_provider, client_setup.api_auth)
.with_telemetry(Some(request_telemetry));
let payload = ApiCompactionInput {
model: &model,
input: &input,
instructions: &instructions,
tools,
parallel_tool_calls,
reasoning,
service_tier: service_tier.as_deref(),
prompt_cache_key: prompt_cache_key.as_deref(),
text,
};
let mut extra_headers = ApiHeaderMap::new();
if let Ok(header_value) = HeaderValue::from_str(&self.state.installation_id) {
extra_headers.insert(X_CODEX_INSTALLATION_ID_HEADER, header_value);
}
extra_headers.extend(build_responses_headers(
self.state.beta_features_header.as_deref(),
/*turn_state*/ None,
/*turn_metadata_header*/ None,
));
extra_headers.extend(self.build_responses_identity_headers());
extra_headers.extend(build_session_headers(
Some(self.state.session_id.to_string()),
Some(self.state.thread_id.to_string()),
));
let trace_attempt = compaction_trace.start_attempt(&payload);
let result = client
.compact_input(&payload, extra_headers)
.await
.map_err(map_api_error);
trace_attempt.record_result(result.as_deref());
result
}
pub(crate) async fn create_realtime_call_with_headers(
&self,
sdp: String,
session_config: ApiRealtimeSessionConfig,
extra_headers: ApiHeaderMap,
) -> Result<RealtimeWebrtcCallStart> {
// Create the media call over HTTP first, then retain matching auth so realtime can attach
// the server-side control WebSocket to the call id from that HTTP response.
let client_setup = self.current_client_setup().await?;
let mut sideband_headers = extra_headers.clone();
sideband_headers.extend(sideband_websocket_auth_headers(
client_setup.api_auth.as_ref(),
));
let transport = ReqwestTransport::new(build_reqwest_client());
let response =
ApiRealtimeCallClient::new(transport, client_setup.api_provider, client_setup.api_auth)
.create_with_session_and_headers(sdp, session_config, extra_headers)
.await
.map_err(map_api_error)?;
Ok(RealtimeWebrtcCallStart {
sdp: response.sdp,
call_id: response.call_id,
sideband_headers,
})
}
/// Builds memory summaries for each provided normalized raw memory.
///
/// This is a unary call (no streaming) to `/v1/memories/trace_summarize`.
///
/// The model selection, reasoning effort, and telemetry context are passed explicitly to keep
/// `ModelClient` session-scoped.
pub async fn summarize_memories(
&self,
raw_memories: Vec<ApiRawMemory>,
model_info: &ModelInfo,
effort: Option<ReasoningEffortConfig>,
session_telemetry: &SessionTelemetry,
) -> Result<Vec<ApiMemorySummarizeOutput>> {
if raw_memories.is_empty() {
return Ok(Vec::new());
}
let client_setup = self.current_client_setup().await?;
let transport = ReqwestTransport::new(build_reqwest_client());
let request_telemetry = Self::build_request_telemetry(
session_telemetry,
AuthRequestTelemetryContext::new(
client_setup.auth.as_ref().map(CodexAuth::auth_mode),
client_setup.api_auth.as_ref(),
PendingUnauthorizedRetry::default(),
),
RequestRouteTelemetry::for_endpoint(MEMORIES_SUMMARIZE_ENDPOINT),
self.state.auth_env_telemetry.clone(),
);
let client =
ApiMemoriesClient::new(transport, client_setup.api_provider, client_setup.api_auth)
.with_telemetry(Some(request_telemetry));
let payload = ApiMemorySummarizeInput {
model: model_info.slug.clone(),
raw_memories,
reasoning: effort.map(|effort| Reasoning {
effort: Some(effort),
summary: None,
}),
};
client
.summarize_input(&payload, self.build_subagent_headers())
.await
.map_err(map_api_error)
}
/// Builds the headers that describe the session source.
///
/// Adds the sub-agent header when the session source yields a value that is a valid header
/// value, and the memgen header when the session is an internal memory-consolidation session.
fn build_subagent_headers(&self) -> ApiHeaderMap {
    let session_source = &self.state.session_source;
    let mut headers = ApiHeaderMap::new();

    // Values that are not valid header values are silently skipped.
    let subagent_value = subagent_header_value(session_source)
        .and_then(|subagent| HeaderValue::from_str(&subagent).ok());
    if let Some(value) = subagent_value {
        headers.insert(X_OPENAI_SUBAGENT_HEADER, value);
    }

    let is_memory_consolidation = matches!(
        session_source,
        SessionSource::Internal(InternalSessionSource::MemoryConsolidation)
    );
    if is_memory_consolidation {
        headers.insert(
            X_OPENAI_MEMGEN_REQUEST_HEADER,
            HeaderValue::from_static("true"),
        );
    }

    headers
}
/// Builds the identity headers sent with Responses requests.
///
/// Starts from the sub-agent headers, then adds the parent thread id (when the session source
/// provides one) and the current window id. Values that are not valid header values are
/// skipped.
fn build_responses_identity_headers(&self) -> ApiHeaderMap {
    let mut headers = self.build_subagent_headers();

    let parent_thread_value = parent_thread_id_header_value(&self.state.session_source)
        .and_then(|parent_thread_id| HeaderValue::from_str(&parent_thread_id).ok());
    if let Some(value) = parent_thread_value {
        headers.insert(X_CODEX_PARENT_THREAD_ID_HEADER, value);
    }

    let window_id = self.current_window_id();
    if let Ok(value) = HeaderValue::from_str(&window_id) {
        headers.insert(X_CODEX_WINDOW_ID_HEADER, value);
    }

    headers
}
fn build_ws_client_metadata(
&self,
turn_metadata_header: Option<&str>,
) -> HashMap<String, String> {
let mut client_metadata = HashMap::new();
client_metadata.insert(
X_CODEX_INSTALLATION_ID_HEADER.to_string(),
self.state.installation_id.clone(),
);
client_metadata.insert(
X_CODEX_WINDOW_ID_HEADER.to_string(),
self.current_window_id(),
);
if let Some(subagent) = subagent_header_value(&self.state.session_source) {
client_metadata.insert(X_OPENAI_SUBAGENT_HEADER.to_string(), subagent);
}
if let Some(parent_thread_id) = parent_thread_id_header_value(&self.state.session_source) {
client_metadata.insert(
X_CODEX_PARENT_THREAD_ID_HEADER.to_string(),
parent_thread_id,
);
}
if let Some(turn_metadata_header) = parse_turn_metadata_header(turn_metadata_header)
&& let Ok(turn_metadata) = turn_metadata_header.to_str()
{
client_metadata.insert(
X_CODEX_TURN_METADATA_HEADER.to_string(),
turn_metadata.to_string(),
);
}
client_metadata
}
/// Creates the request telemetry sink used by unary API calls (e.g., the Compact endpoint).
///
/// The session telemetry handle is cloned into a new `ApiTelemetry`, which is returned as a
/// shared `RequestTelemetry` trait object.
fn build_request_telemetry(
    session_telemetry: &SessionTelemetry,
    auth_context: AuthRequestTelemetryContext,
    request_route_telemetry: RequestRouteTelemetry,
    auth_env_telemetry: AuthEnvTelemetry,
) -> Arc<dyn RequestTelemetry> {
    // The concrete `Arc<ApiTelemetry>` coerces to the trait object at the return site.
    Arc::new(ApiTelemetry::new(
        session_telemetry.clone(),
        auth_context,
        request_route_telemetry,
        auth_env_telemetry,
    ))
}
/// Builds the `reasoning` request parameter for a model.
///
/// Returns `None` when the model does not support reasoning summaries. Otherwise the effort is
/// the explicit `effort` or, failing that, the model's default reasoning level, and the summary
/// is omitted when the configured summary mode is `None`.
fn build_reasoning(
    model_info: &ModelInfo,
    effort: Option<ReasoningEffortConfig>,
    summary: ReasoningSummaryConfig,
) -> Option<Reasoning> {
    if !model_info.supports_reasoning_summaries {
        return None;
    }

    let summary = (summary != ReasoningSummaryConfig::None).then_some(summary);
    Some(Reasoning {
        effort: effort.or(model_info.default_reasoning_level),
        summary,
    })
}
/// Assembles a streaming Responses API request from the prompt and per-turn settings.
///
/// - Encrypted reasoning content is requested via `include` whenever reasoning is sent.
/// - Verbosity comes from the session override or the model default, and only for models that
///   support it; an override on an unsupported model is ignored with a warning.
/// - The thread id is used as the prompt cache key.
/// - `store` is enabled only for Azure Responses endpoints.
///
/// # Errors
///
/// Returns an error if the prompt's tools cannot be converted to Responses API JSON.
fn build_responses_request(
    &self,
    provider: &codex_api::Provider,
    prompt: &Prompt,
    model_info: &ModelInfo,
    effort: Option<ReasoningEffortConfig>,
    summary: ReasoningSummaryConfig,
    service_tier: Option<String>,
) -> Result<ResponsesApiRequest> {
    let state = &self.state;
    let input = prompt.get_formatted_input();
    let tools = create_tools_json_for_responses_api(&prompt.tools)?;

    let reasoning = Self::build_reasoning(model_info, effort, summary);
    let include = match reasoning {
        Some(_) => vec!["reasoning.encrypted_content".to_string()],
        None => Vec::new(),
    };

    let verbosity = match (model_info.support_verbosity, state.model_verbosity) {
        (true, configured) => configured.or(model_info.default_verbosity),
        (false, Some(_)) => {
            warn!(
                "model_verbosity is set but ignored as the model does not support verbosity: {}",
                model_info.slug
            );
            None
        }
        (false, None) => None,
    };
    let text = create_text_param_for_request(
        verbosity,
        &prompt.output_schema,
        prompt.output_schema_strict,
    );

    let client_metadata = HashMap::from([(
        X_CODEX_INSTALLATION_ID_HEADER.to_string(),
        state.installation_id.clone(),
    )]);

    Ok(ResponsesApiRequest {
        model: model_info.slug.clone(),
        instructions: prompt.base_instructions.text.clone(),
        input,
        tools,
        tool_choice: "auto".to_string(),
        parallel_tool_calls: prompt.parallel_tool_calls,
        reasoning,
        store: provider.is_azure_responses_endpoint(),
        stream: true,
        include,
        service_tier,
        prompt_cache_key: Some(state.thread_id.to_string()),
        text,
        client_metadata: Some(client_metadata),
    })
}
/// Reports whether the Responses-over-WebSocket transport is currently usable for this session.
///
/// WebSockets are used only when the provider supports them, the session has not fallen back to
/// HTTP, and no SSE fixture is configured.
pub fn responses_websocket_enabled(&self) -> bool {
    // Conditions are checked left to right and stop at the first one that rules websockets out.
    self.state.provider.info().supports_websockets
        && !self.state.disable_websockets.load(Ordering::Relaxed)
        && (*CODEX_RS_SSE_FIXTURE).is_none()
}
/// Resolves auth and provider configuration from the current session auth state.
///
/// Prewarm and normal request paths both go through this helper so they stay in lockstep when
/// auth/provider resolution changes.
///
/// # Errors
///
/// Propagates failures from resolving the API provider or the API auth provider.
async fn current_client_setup(&self) -> Result<CurrentClientSetup> {
    let provider = &self.state.provider;
    let auth = provider.auth().await;
    let api_provider = provider.api_provider().await?;
    let api_auth = provider.api_auth().await?;
    Ok(CurrentClientSetup {
        auth,
        api_provider,
        api_auth,
    })
}
/// Opens a websocket connection using the same header and telemetry wiring as normal turns.
///
/// Both startup prewarm and in-turn `needs_new` reconnects call this path so handshake
/// behavior remains consistent across both flows.
///
/// The connect attempt is bounded by the provider's websocket connect timeout; an elapsed timer
/// is reported as `ApiError::Transport(TransportError::Timeout)`. Whether the attempt succeeds
/// or fails, the outcome is recorded in session telemetry and in feedback request tags before
/// the result is returned.
#[allow(clippy::too_many_arguments)]
async fn connect_websocket(
&self,
session_telemetry: &SessionTelemetry,
api_provider: codex_api::Provider,
api_auth: SharedAuthProvider,
turn_state: Option<Arc<OnceLock<String>>>,
turn_metadata_header: Option<&str>,
auth_context: AuthRequestTelemetryContext,
request_route_telemetry: RequestRouteTelemetry,
) -> std::result::Result<ApiWebSocketConnection, ApiError> {
let headers = self.build_websocket_headers(turn_state.as_ref(), turn_metadata_header);
let websocket_telemetry = ModelClientSession::build_websocket_telemetry(
session_telemetry,
auth_context,
request_route_telemetry,
self.state.auth_env_telemetry.clone(),
);
let websocket_connect_timeout = self.state.provider.info().websocket_connect_timeout();
// Measure the whole connect attempt, including the time spent waiting on the timeout.
let start = Instant::now();
let result = match tokio::time::timeout(
websocket_connect_timeout,
ApiWebSocketResponsesClient::new(api_provider, api_auth).connect(
headers,
codex_login::default_client::default_headers(),
turn_state,
Some(websocket_telemetry),
),
)
.await
{
Ok(result) => result,
// The timer elapsed before the handshake finished.
Err(_) => Err(ApiError::Transport(TransportError::Timeout)),
};
// Everything below is derived from the error, if any; on success these are `None`/default.
let error_message = result.as_ref().err().map(telemetry_api_error_message);
let response_debug = result
.as_ref()
.err()
.map(extract_response_debug_context_from_api_error)
.unwrap_or_default();
let status = result.as_ref().err().and_then(api_error_http_status);
session_telemetry.record_websocket_connect(
start.elapsed(),
status,
error_message.as_deref(),
auth_context.auth_header_attached,
auth_context.auth_header_name,
auth_context.retry_after_unauthorized,
auth_context.recovery_mode,
auth_context.recovery_phase,
request_route_telemetry.endpoint,
// This path always opens a new connection, so it is never reported as reused.
/*connection_reused*/ false,
response_debug.request_id.as_deref(),
response_debug.cf_ray.as_deref(),
response_debug.auth_error.as_deref(),
response_debug.auth_error_code.as_deref(),
);
emit_feedback_request_tags_with_auth_env(
&FeedbackRequestTags {
endpoint: request_route_telemetry.endpoint,
auth_header_attached: auth_context.auth_header_attached,
auth_header_name: auth_context.auth_header_name,
auth_mode: auth_context.auth_mode,
auth_retry_after_unauthorized: Some(auth_context.retry_after_unauthorized),
auth_recovery_mode: auth_context.recovery_mode,
auth_recovery_phase: auth_context.recovery_phase,
auth_connection_reused: Some(false),
auth_request_id: response_debug.request_id.as_deref(),
auth_cf_ray: response_debug.cf_ray.as_deref(),
auth_error: response_debug.auth_error.as_deref(),
auth_error_code: response_debug.auth_error_code.as_deref(),
// The follow-up fields are only populated when this attempt is a retry after an
// unauthorized response.
auth_recovery_followup_success: auth_context
.retry_after_unauthorized
.then_some(result.is_ok()),
auth_recovery_followup_status: auth_context
.retry_after_unauthorized
.then_some(status)
.flatten(),
},
&self.state.auth_env_telemetry,
);
result
}
/// Assembles the websocket handshake headers used by both prewarm and turn-time reconnects.
///
/// Pass the current turn-state lock when one is available so that sticky-routing state is
/// replayed on reconnects within the same turn.
///
/// The thread id doubles as `x-client-request-id`. The websocket v2 beta header is always set,
/// and the timing-metrics header is added only when enabled for the session.
fn build_websocket_headers(
    &self,
    turn_state: Option<&Arc<OnceLock<String>>>,
    turn_metadata_header: Option<&str>,
) -> ApiHeaderMap {
    let state = &self.state;
    let parsed_turn_metadata = parse_turn_metadata_header(turn_metadata_header);
    let session_id = state.session_id.to_string();
    let thread_id = state.thread_id.to_string();

    // Later groups may overwrite earlier ones, so the order below is significant.
    let mut headers = build_responses_headers(
        state.beta_features_header.as_deref(),
        turn_state,
        parsed_turn_metadata.as_ref(),
    );
    if let Ok(request_id) = HeaderValue::from_str(&thread_id) {
        headers.insert("x-client-request-id", request_id);
    }
    headers.extend(build_session_headers(Some(session_id), Some(thread_id)));
    headers.extend(self.build_responses_identity_headers());
    headers.insert(
        OPENAI_BETA_HEADER,
        HeaderValue::from_static(RESPONSES_WEBSOCKETS_V2_BETA_HEADER_VALUE),
    );

    if state.include_timing_metrics {
        headers.insert(
            X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER,
            HeaderValue::from_static("true"),
        );
    }

    headers
}
}
impl Drop for ModelClientSession {
    /// Hands the turn's websocket session back to the client cache so that the next session
    /// created from the same client can pick it up.
    fn drop(&mut self) {
        let returned_session = std::mem::take(&mut self.websocket_session);
        let client = &self.client;
        client.store_cached_websocket_session(returned_session);
    }
}
impl ModelClientSession {
/// Clears all websocket state held by this turn session.
///
/// Drops the connection, the remembered last request, and the pending last-response receiver,
/// and marks the connection as not reused.
pub(crate) fn reset_websocket_session(&mut self) {
    let session = &mut self.websocket_session;
    session.connection = None;
    session.last_request = None;
    session.last_response_rx = None;
    session.set_connection_reused(/*connection_reused*/ false);
}
/// Sends a best-effort `response.processed` websocket request for `response_id`.
///
/// Does nothing when no websocket connection is open. A send failure is logged at debug level
/// and otherwise ignored.
pub(crate) async fn send_response_processed(&self, response_id: &str) {
    if let Some(connection) = self.websocket_session.connection.as_ref() {
        let outcome = connection
            .send_response_processed(response_id.to_string())
            .await;
        if let Err(err) = outcome {
            debug!("failed to send response.processed websocket request: {err}");
        }
    }
}
/// Builds shared Responses API transport options and request-body options.
///
/// Keeping option construction in one place ensures request-scoped headers are consistent
/// regardless of transport choice.
///
/// The extra headers are the Responses headers (beta features, the current turn state, and the
/// parsed turn metadata) extended with the client's identity headers. The turn-state lock is
/// also handed over in the options, sharing the same `Arc` as this session.
fn build_responses_options(
    &self,
    turn_metadata_header: Option<&str>,
    compression: Compression,
) -> ApiResponsesOptions {
    let state = &self.client.state;
    let parsed_turn_metadata = parse_turn_metadata_header(turn_metadata_header);

    let mut extra_headers = build_responses_headers(
        state.beta_features_header.as_deref(),
        Some(&self.turn_state),
        parsed_turn_metadata.as_ref(),
    );
    extra_headers.extend(self.client.build_responses_identity_headers());

    ApiResponsesOptions {
        session_id: Some(state.session_id.to_string()),
        thread_id: Some(state.thread_id.to_string()),
        session_source: Some(state.session_source.clone()),
        extra_headers,
        compression,
        turn_state: Some(Arc::clone(&self.turn_state)),
    }
}
/// Returns the input items that are new relative to the previous websocket request, or `None`
/// when the request cannot be expressed as an incremental extension.
///
/// An incremental delta is only produced when every non-input request field is unchanged and
/// `request.input` starts with the previous input followed by the items the server returned
/// in `last_response` (those are treated as part of the baseline so they are not resent).
/// When `allow_empty_delta` is `false`, a request that adds nothing to the baseline also
/// yields `None`.
fn get_incremental_items(
    &self,
    request: &ResponsesApiRequest,
    last_response: Option<&LastResponse>,
    allow_empty_delta: bool,
) -> Option<Vec<ResponseItem>> {
    let previous_request = self.websocket_session.last_request.as_ref()?;

    // Compare everything except `input`; any difference forces a full request.
    let mut previous_without_input = previous_request.clone();
    previous_without_input.input.clear();
    let mut request_without_input = request.clone();
    request_without_input.input.clear();
    if previous_without_input != request_without_input {
        trace!(
            "incremental request failed, properties didn't match {previous_without_input:?} != {request_without_input:?}"
        );
        return None;
    }

    // The baseline is the previous input followed by the server-returned items. Checking the
    // two prefixes on borrowed slices avoids copying the baseline into a temporary vector.
    let server_items: &[ResponseItem] = match last_response {
        Some(last_response) => last_response.items_added.as_slice(),
        None => &[],
    };
    let delta = request
        .input
        .strip_prefix(previous_request.input.as_slice())
        .and_then(|after_previous| after_previous.strip_prefix(server_items))
        .filter(|delta| allow_empty_delta || !delta.is_empty());

    match delta {
        Some(delta) => Some(delta.to_vec()),
        None => {
            trace!("incremental request failed, items didn't match");
            None
        }
    }
}
fn get_last_response(&mut self) -> Option<LastResponse> {
self.websocket_session
.last_response_rx
.take()
.and_then(|mut receiver| match receiver.try_recv() {
Ok(last_response) => Some(last_response),
Err(TryRecvError::Closed) | Err(TryRecvError::Empty) => None,
})
}
/// Chooses between a full `response.create` payload and an incremental one.
///
/// The incremental form (only the new input items plus `previous_response_id`) is used when a
/// previous response is available, the request is an extension of the previous one, and the
/// previous response carries a non-empty id. Otherwise `payload` is sent unchanged.
fn prepare_websocket_request(
    &mut self,
    payload: ResponseCreateWsRequest,
    request: &ResponsesApiRequest,
) -> ResponsesWsRequest {
    let last_response = match self.get_last_response() {
        Some(last_response) => last_response,
        None => return ResponsesWsRequest::ResponseCreate(payload),
    };
    let incremental_items = match self.get_incremental_items(
        request,
        Some(&last_response),
        /*allow_empty_delta*/ true,
    ) {
        Some(items) => items,
        None => return ResponsesWsRequest::ResponseCreate(payload),
    };
    if last_response.response_id.is_empty() {
        trace!("incremental request failed, no previous response id");
        return ResponsesWsRequest::ResponseCreate(payload);
    }
    let incremental_payload = ResponseCreateWsRequest {
        previous_response_id: Some(last_response.response_id),
        input: incremental_items,
        ..payload
    };
    ResponsesWsRequest::ResponseCreate(incremental_payload)
}
/// Opportunistically preconnects a websocket for this turn-scoped client session.
///
/// This performs only connection setup; it never sends prompt payloads. The call is a no-op
/// when the websocket transport is disabled or a connection is already held.
///
/// # Errors
///
/// Returns `ApiError::Stream` when the client setup cannot be built, or the error reported by
/// the websocket connect attempt.
pub async fn preconnect_websocket(
    &mut self,
    session_telemetry: &SessionTelemetry,
    _model_info: &ModelInfo,
) -> std::result::Result<(), ApiError> {
    let nothing_to_do = !self.client.responses_websocket_enabled()
        || self.websocket_session.connection.is_some();
    if nothing_to_do {
        return Ok(());
    }

    let client_setup = match self.client.current_client_setup().await {
        Ok(setup) => setup,
        Err(err) => {
            return Err(ApiError::Stream(format!(
                "failed to build websocket prewarm client setup: {err}"
            )));
        }
    };
    let auth_mode = client_setup.auth.as_ref().map(CodexAuth::auth_mode);
    let auth_context = AuthRequestTelemetryContext::new(
        auth_mode,
        client_setup.api_auth.as_ref(),
        PendingUnauthorizedRetry::default(),
    );
    let turn_state = Arc::clone(&self.turn_state);
    let route = RequestRouteTelemetry::for_endpoint(RESPONSES_ENDPOINT);

    let connection = self
        .client
        .connect_websocket(
            session_telemetry,
            client_setup.api_provider,
            client_setup.api_auth,
            Some(turn_state),
            /*turn_metadata_header*/ None,
            auth_context,
            route,
        )
        .await?;

    let session = &mut self.websocket_session;
    session.connection = Some(connection);
    session.set_connection_reused(/*connection_reused*/ false);
    Ok(())
}
/// Returns a websocket connection for this turn.
///
/// The cached connection is reused when it exists and does not report itself closed; the
/// session is then flagged as reusing its connection. Otherwise the remembered previous
/// request and pending last-response receiver are cleared and a new connection is opened
/// using `options.turn_state` when present, falling back to this session's turn state.
///
/// # Errors
///
/// Propagates the connect error. When that error is a transport timeout the whole websocket
/// session is reset first.
#[instrument(
    name = "model_client.websocket_connection",
    level = "info",
    skip_all,
    fields(
        provider = %self.client.state.provider.info().name,
        wire_api = %self.client.state.provider.info().wire_api,
        transport = "responses_websocket",
        api.path = "responses",
        turn.has_metadata_header = params.turn_metadata_header.is_some()
    )
)]
async fn websocket_connection(
    &mut self,
    params: WebsocketConnectParams<'_>,
) -> std::result::Result<&ApiWebSocketConnection, ApiError> {
    let WebsocketConnectParams {
        session_telemetry,
        api_provider,
        api_auth,
        turn_metadata_header,
        options,
        auth_context,
        request_route_telemetry,
    } = params;
    let needs_new = match self.websocket_session.connection.as_ref() {
        Some(conn) => conn.is_closed().await,
        None => true,
    };
    if needs_new {
        // Incremental-request bookkeeping belongs to the old connection and must not leak
        // into the new one.
        self.websocket_session.last_request = None;
        self.websocket_session.last_response_rx = None;
        let turn_state = options
            .turn_state
            .clone()
            .unwrap_or_else(|| Arc::clone(&self.turn_state));
        let new_conn = match self
            .client
            .connect_websocket(
                session_telemetry,
                api_provider,
                api_auth,
                Some(turn_state),
                turn_metadata_header,
                auth_context,
                request_route_telemetry,
            )
            .await
        {
            Ok(new_conn) => new_conn,
            Err(err) => {
                if matches!(err, ApiError::Transport(TransportError::Timeout)) {
                    self.reset_websocket_session();
                }
                return Err(err);
            }
        };
        self.websocket_session.connection = Some(new_conn);
        self.websocket_session
            .set_connection_reused(/*connection_reused*/ false);
    } else {
        self.websocket_session
            .set_connection_reused(/*connection_reused*/ true);
    }
    // Build the error lazily so the success path does not allocate the message string.
    self.websocket_session
        .connection
        .as_ref()
        .ok_or_else(|| ApiError::Stream("websocket connection is unavailable".to_string()))
}
/// Picks the request-body compression for a Responses API call.
///
/// Zstd is selected only when request compression is enabled for the session, `auth` is
/// present and uses the Codex backend, and the provider is OpenAI; otherwise no compression
/// is used.
fn responses_request_compression(&self, auth: Option<&CodexAuth>) -> Compression {
    let state = &self.client.state;
    let use_zstd = state.enable_request_compression
        && auth.is_some_and(CodexAuth::uses_codex_backend)
        && state.provider.info().is_openai();
    if use_zstd {
        Compression::Zstd
    } else {
        Compression::None
    }
}
/// Streams a turn via the OpenAI Responses API.
///
/// Handles SSE fixtures, reasoning summaries, verbosity, and the
/// `text` controls used for output schemas.
///
/// When `CODEX_RS_SSE_FIXTURE` is set the stream is read from that fixture and no request is
/// sent. Otherwise the request is attempted in a loop: a 401 response triggers
/// `handle_unauthorized` and, if recovery succeeds, another attempt; any other error is
/// mapped and returned.
#[allow(clippy::too_many_arguments)]
#[instrument(
name = "model_client.stream_responses_api",
level = "info",
skip_all,
fields(
model = %model_info.slug,
wire_api = %self.client.state.provider.info().wire_api,
transport = "responses_http",
http.method = "POST",
api.path = "responses",
turn.has_metadata_header = turn_metadata_header.is_some()
)
)]
async fn stream_responses_api(
&self,
prompt: &Prompt,
model_info: &ModelInfo,
session_telemetry: &SessionTelemetry,
effort: Option<ReasoningEffortConfig>,
summary: ReasoningSummaryConfig,
service_tier: Option<String>,
turn_metadata_header: Option<&str>,
inference_trace: &InferenceTraceContext,
) -> Result<ResponseStream> {
// Fixture mode: replay a recorded SSE stream with inference tracing disabled.
if let Some(path) = &*CODEX_RS_SSE_FIXTURE {
warn!(path, "Streaming from fixture");
let stream = codex_api::stream_from_fixture(
path,
self.client.state.provider.info().stream_idle_timeout(),
)
.map_err(map_api_error)?;
let (stream, _last_request_rx) = map_response_stream(
stream,
session_telemetry.clone(),
InferenceTraceAttempt::disabled(),
);
return Ok(stream);
}
let auth_manager = self.client.state.provider.auth_manager();
let mut auth_recovery = auth_manager
.as_ref()
.map(AuthManager::unauthorized_recovery);
let mut pending_retry = PendingUnauthorizedRetry::default();
loop {
// The client setup, transport and telemetry are rebuilt on every attempt;
// `pending_retry` tells telemetry whether this attempt follows a 401.
let client_setup = self.client.current_client_setup().await?;
let transport = ReqwestTransport::new(build_reqwest_client());
let request_auth_context = AuthRequestTelemetryContext::new(
client_setup.auth.as_ref().map(CodexAuth::auth_mode),
client_setup.api_auth.as_ref(),
pending_retry,
);
let (request_telemetry, sse_telemetry) = Self::build_streaming_telemetry(
session_telemetry,
request_auth_context,
RequestRouteTelemetry::for_endpoint(RESPONSES_ENDPOINT),
self.client.state.auth_env_telemetry.clone(),
);
let compression = self.responses_request_compression(client_setup.auth.as_ref());
let options = self.build_responses_options(turn_metadata_header, compression);
let request = self.client.build_responses_request(
&client_setup.api_provider,
prompt,
model_info,
effort,
summary,
service_tier.clone(),
)?;
// Each loop iteration is recorded as its own inference trace attempt.
let inference_trace_attempt = inference_trace.start_attempt();
inference_trace_attempt.record_started(&request);
let client = ApiResponsesClient::new(
transport,
client_setup.api_provider,
client_setup.api_auth,
)
.with_telemetry(Some(request_telemetry), Some(sse_telemetry));
let stream_result = client.stream_request(request, options).await;
match stream_result {
Ok(stream) => {
// The last-response receiver is only needed by the websocket transport.
let (stream, _) = map_response_stream(
stream,
session_telemetry.clone(),
inference_trace_attempt,
);
return Ok(stream);
}
Err(ApiError::Transport(
unauthorized_transport @ TransportError::Http { status, .. },
)) if status == StatusCode::UNAUTHORIZED => {
let response_debug_context =
extract_response_debug_context(&unauthorized_transport);
inference_trace_attempt.record_failed(
&unauthorized_transport,
response_debug_context.request_id.as_deref(),
/*output_items*/ &[],
);
// `?` returns the mapped error when recovery is unavailable or fails; otherwise
// the loop retries.
pending_retry = PendingUnauthorizedRetry::from_recovery(
handle_unauthorized(
unauthorized_transport,
&mut auth_recovery,
session_telemetry,
)
.await?,
);
continue;
}
Err(err) => {
// Extract the debug context before `map_api_error` consumes the error.
let response_debug_context =
extract_response_debug_context_from_api_error(&err);
let err = map_api_error(err);
inference_trace_attempt.record_failed(
&err,
response_debug_context.request_id.as_deref(),
/*output_items*/ &[],
);
return Err(err);
}
}
}
}
/// Streams a turn via the Responses API over WebSocket transport.
///
/// Returns `WebsocketStreamOutcome::FallbackToHttp` when the connect attempt is answered with
/// HTTP 426 (Upgrade Required). A 401 on connect goes through `handle_unauthorized` and, if
/// recovery succeeds, the loop retries. When `warmup` is true the payload is sent with
/// `generate = false` and the attempt is not recorded in the inference trace.
#[allow(clippy::too_many_arguments)]
#[instrument(
name = "model_client.stream_responses_websocket",
level = "info",
skip_all,
fields(
model = %model_info.slug,
wire_api = %self.client.state.provider.info().wire_api,
transport = "responses_websocket",
api.path = "responses",
turn.has_metadata_header = turn_metadata_header.is_some(),
websocket.warmup = warmup
)
)]
async fn stream_responses_websocket(
&mut self,
prompt: &Prompt,
model_info: &ModelInfo,
session_telemetry: &SessionTelemetry,
effort: Option<ReasoningEffortConfig>,
summary: ReasoningSummaryConfig,
service_tier: Option<String>,
turn_metadata_header: Option<&str>,
warmup: bool,
request_trace: Option<W3cTraceContext>,
inference_trace: &InferenceTraceContext,
) -> Result<WebsocketStreamOutcome> {
let auth_manager = self.client.state.provider.auth_manager();
let mut auth_recovery = auth_manager
.as_ref()
.map(AuthManager::unauthorized_recovery);
let mut pending_retry = PendingUnauthorizedRetry::default();
loop {
// Setup, options and request are rebuilt on every attempt; `pending_retry` tells
// telemetry whether this attempt follows a 401.
let client_setup = self.client.current_client_setup().await?;
let request_auth_context = AuthRequestTelemetryContext::new(
client_setup.auth.as_ref().map(CodexAuth::auth_mode),
client_setup.api_auth.as_ref(),
pending_retry,
);
let compression = self.responses_request_compression(client_setup.auth.as_ref());
let options = self.build_responses_options(turn_metadata_header, compression);
let request = self.client.build_responses_request(
&client_setup.api_provider,
prompt,
model_info,
effort,
summary,
service_tier.clone(),
)?;
let mut ws_payload = ResponseCreateWsRequest {
client_metadata: response_create_client_metadata(
Some(self.client.build_ws_client_metadata(turn_metadata_header)),
request_trace.as_ref(),
),
..ResponseCreateWsRequest::from(&request)
};
if warmup {
ws_payload.generate = Some(false);
}
// Only the success/failure of obtaining a connection matters here; the returned
// reference is discarded and the connection is re-read from the session below, after
// the intervening `&mut self` calls.
match self
.websocket_connection(WebsocketConnectParams {
session_telemetry,
api_provider: client_setup.api_provider,
api_auth: client_setup.api_auth,
turn_metadata_header,
options: &options,
auth_context: request_auth_context,
request_route_telemetry: RequestRouteTelemetry::for_endpoint(
RESPONSES_ENDPOINT,
),
})
.await
{
Ok(_) => {}
Err(ApiError::Transport(TransportError::Http { status, .. }))
if status == StatusCode::UPGRADE_REQUIRED =>
{
return Ok(WebsocketStreamOutcome::FallbackToHttp);
}
Err(ApiError::Transport(
unauthorized_transport @ TransportError::Http { status, .. },
)) if status == StatusCode::UNAUTHORIZED => {
// `?` returns the mapped error when recovery is unavailable or fails.
pending_retry = PendingUnauthorizedRetry::from_recovery(
handle_unauthorized(
unauthorized_transport,
&mut auth_recovery,
session_telemetry,
)
.await?,
);
continue;
}
Err(err) => return Err(map_api_error(err)),
}
// Decide between a full and an incremental payload against the previously remembered
// request, then remember this request as the new baseline.
let ws_request = self.prepare_websocket_request(ws_payload, &request);
self.websocket_session.last_request = Some(request);
let inference_trace_attempt = if warmup {
// Prewarm sends `generate=false`; it is connection setup, not a
// model inference attempt that should appear in rollout traces.
InferenceTraceAttempt::disabled()
} else {
inference_trace.start_attempt()
};
inference_trace_attempt.record_started(&ws_request);
let websocket_connection =
self.websocket_session.connection.as_ref().ok_or_else(|| {
map_api_error(ApiError::Stream(
"websocket connection is unavailable".to_string(),
))
})?;
let stream_result = websocket_connection
.stream_request(ws_request, self.websocket_session.connection_reused())
.await
.map_err(|err| {
// Extract the debug context before `map_api_error` consumes the error.
let response_debug_context =
extract_response_debug_context_from_api_error(&err);
let err = map_api_error(err);
inference_trace_attempt.record_failed(
&err,
response_debug_context.request_id.as_deref(),
/*output_items*/ &[],
);
err
})?;
// NOTE(review): despite its name, `last_request_rx` is the receiver for the
// `LastResponse` produced by `map_response_stream`.
let (stream, last_request_rx) = map_response_stream(
stream_result,
session_telemetry.clone(),
inference_trace_attempt,
);
self.websocket_session.last_response_rx = Some(last_request_rx);
return Ok(WebsocketStreamOutcome::Stream(stream));
}
}
/// Builds request and SSE telemetry for streaming API calls.
///
/// Both returned handles point at the same underlying `ApiTelemetry` value.
fn build_streaming_telemetry(
    session_telemetry: &SessionTelemetry,
    auth_context: AuthRequestTelemetryContext,
    request_route_telemetry: RequestRouteTelemetry,
    auth_env_telemetry: AuthEnvTelemetry,
) -> (Arc<dyn RequestTelemetry>, Arc<dyn SseTelemetry>) {
    let shared = Arc::new(ApiTelemetry::new(
        session_telemetry.clone(),
        auth_context,
        request_route_telemetry,
        auth_env_telemetry,
    ));
    (
        shared.clone() as Arc<dyn RequestTelemetry>,
        shared as Arc<dyn SseTelemetry>,
    )
}
/// Builds telemetry for the Responses API WebSocket transport.
///
/// The result is an `ApiTelemetry` value exposed through the `WebsocketTelemetry` trait.
fn build_websocket_telemetry(
    session_telemetry: &SessionTelemetry,
    auth_context: AuthRequestTelemetryContext,
    request_route_telemetry: RequestRouteTelemetry,
    auth_env_telemetry: AuthEnvTelemetry,
) -> Arc<dyn WebsocketTelemetry> {
    let api_telemetry = ApiTelemetry::new(
        session_telemetry.clone(),
        auth_context,
        request_route_telemetry,
        auth_env_telemetry,
    );
    Arc::new(api_telemetry)
}
/// Sends a warmup request over the websocket and waits for it to complete.
///
/// The call is a no-op when the websocket transport is disabled or a previous request is
/// already remembered for this session. The warmup runs with a disabled inference trace. If
/// the transport reports that it must fall back to HTTP, the session switches transports and
/// the call still succeeds.
///
/// # Errors
///
/// Returns the error from opening the stream, or the first error event seen before the
/// warmup completes.
#[allow(clippy::too_many_arguments)]
pub async fn prewarm_websocket(
    &mut self,
    prompt: &Prompt,
    model_info: &ModelInfo,
    session_telemetry: &SessionTelemetry,
    effort: Option<ReasoningEffortConfig>,
    summary: ReasoningSummaryConfig,
    service_tier: Option<String>,
    turn_metadata_header: Option<&str>,
) -> Result<()> {
    let skip = !self.client.responses_websocket_enabled()
        || self.websocket_session.last_request.is_some();
    if skip {
        return Ok(());
    }

    let disabled_trace = InferenceTraceContext::disabled();
    let request_trace = current_span_w3c_trace_context();
    let outcome = self
        .stream_responses_websocket(
            prompt,
            model_info,
            session_telemetry,
            effort,
            summary,
            service_tier,
            turn_metadata_header,
            /*warmup*/ true,
            request_trace,
            &disabled_trace,
        )
        .await?;

    match outcome {
        WebsocketStreamOutcome::Stream(mut stream) => {
            // Wait for the v2 warmup request to complete before sending the first turn request.
            while let Some(event) = stream.next().await {
                if matches!(event?, ResponseEvent::Completed { .. }) {
                    break;
                }
            }
        }
        WebsocketStreamOutcome::FallbackToHttp => {
            self.try_switch_fallback_transport(session_telemetry, model_info);
        }
    }
    Ok(())
}
/// Streams a single model request within the current turn.
///
/// The caller is responsible for passing per-turn settings explicitly (model selection,
/// reasoning settings, telemetry context, and turn metadata). This method will prefer the
/// Responses WebSocket transport when the provider supports it and it remains healthy, and will
/// fall back to the HTTP Responses API transport otherwise. The trace context may be enabled or
/// disabled, but is always explicit so transport paths do not need separate trace/no-trace
/// branches.
#[allow(clippy::too_many_arguments)]
pub async fn stream(
    &mut self,
    prompt: &Prompt,
    model_info: &ModelInfo,
    session_telemetry: &SessionTelemetry,
    effort: Option<ReasoningEffortConfig>,
    summary: ReasoningSummaryConfig,
    service_tier: Option<String>,
    turn_metadata_header: Option<&str>,
    inference_trace: &InferenceTraceContext,
) -> Result<ResponseStream> {
    match self.client.state.provider.info().wire_api {
        WireApi::Responses => {
            if self.client.responses_websocket_enabled() {
                let request_trace = current_span_w3c_trace_context();
                let websocket_outcome = self
                    .stream_responses_websocket(
                        prompt,
                        model_info,
                        session_telemetry,
                        effort,
                        summary,
                        service_tier.clone(),
                        turn_metadata_header,
                        /*warmup*/ false,
                        request_trace,
                        inference_trace,
                    )
                    .await?;
                match websocket_outcome {
                    WebsocketStreamOutcome::Stream(stream) => return Ok(stream),
                    // Fall through to the HTTP transport below.
                    WebsocketStreamOutcome::FallbackToHttp => {
                        self.try_switch_fallback_transport(session_telemetry, model_info);
                    }
                }
            }
            self.stream_responses_api(
                prompt,
                model_info,
                session_telemetry,
                effort,
                summary,
                service_tier,
                turn_metadata_header,
                inference_trace,
            )
            .await
        }
    }
}
/// Permanently disables WebSockets for this Codex session and resets WebSocket state.
///
/// This is used after exhausting the provider retry budget, to force subsequent requests onto
/// the HTTP transport.
///
/// Returns `true` if this call activated fallback, or `false` if fallback was already active.
pub(crate) fn try_switch_fallback_transport(
&mut self,
session_telemetry: &SessionTelemetry,
model_info: &ModelInfo,
) -> bool {
let activated = self
.client
.force_http_fallback(session_telemetry, model_info);
self.websocket_session = WebsocketSession::default();
activated
}
}
/// Parses per-turn metadata into an HTTP header value.
///
/// Invalid values are treated as absent so callers can compare and propagate
/// metadata with the same sanitization path used when constructing headers.
fn parse_turn_metadata_header(turn_metadata_header: Option<&str>) -> Option<HeaderValue> {
    let raw = turn_metadata_header?;
    HeaderValue::from_str(raw).ok()
}
/// Builds the extra headers attached to Responses API requests.
///
/// These headers implement Codex-specific conventions:
///
/// - `x-codex-beta-features`: comma-separated beta feature keys enabled for the session.
/// - `x-codex-turn-state`: sticky routing token captured earlier in the turn.
/// - `x-codex-turn-metadata`: optional per-turn metadata for observability.
///
/// A header is omitted when its source is absent, empty (beta features), not yet set (turn
/// state), or not a valid header value.
fn build_responses_headers(
    beta_features_header: Option<&str>,
    turn_state: Option<&Arc<OnceLock<String>>>,
    turn_metadata_header: Option<&HeaderValue>,
) -> ApiHeaderMap {
    let mut headers = ApiHeaderMap::new();

    let beta_features = beta_features_header
        .filter(|value| !value.is_empty())
        .and_then(|value| HeaderValue::from_str(value).ok());
    if let Some(header_value) = beta_features {
        headers.insert("x-codex-beta-features", header_value);
    }

    let sticky_state = turn_state
        .and_then(|cell| cell.get())
        .and_then(|state| HeaderValue::from_str(state).ok());
    if let Some(header_value) = sticky_state {
        headers.insert(X_CODEX_TURN_STATE_HEADER, header_value);
    }

    if let Some(header_value) = turn_metadata_header.cloned() {
        headers.insert(X_CODEX_TURN_METADATA_HEADER, header_value);
    }

    headers
}
/// Maps a session source to the sub-agent label used as a header value.
///
/// Sub-agent sources and the internal memory-consolidation source map to a label; every other
/// session source yields `None`.
fn subagent_header_value(session_source: &SessionSource) -> Option<String> {
    let value: &str = match session_source {
        SessionSource::SubAgent(SubAgentSource::Review) => "review",
        SessionSource::SubAgent(SubAgentSource::Compact) => "compact",
        SessionSource::SubAgent(SubAgentSource::MemoryConsolidation)
        | SessionSource::Internal(InternalSessionSource::MemoryConsolidation) => {
            "memory_consolidation"
        }
        SessionSource::SubAgent(SubAgentSource::ThreadSpawn { .. }) => "collab_spawn",
        SessionSource::SubAgent(SubAgentSource::Other(label)) => label.as_str(),
        SessionSource::Cli
        | SessionSource::VSCode
        | SessionSource::Exec
        | SessionSource::Mcp
        | SessionSource::Custom(_)
        | SessionSource::Unknown => return None,
    };
    Some(value.to_string())
}
/// Returns the parent thread id, as a string, for sessions spawned from another thread.
///
/// Only `SubAgentSource::ThreadSpawn` carries a parent thread id; every other session source
/// yields `None`.
fn parent_thread_id_header_value(session_source: &SessionSource) -> Option<String> {
    match session_source {
        SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
            parent_thread_id, ..
        }) => Some(parent_thread_id.to_string()),
        // Variants are listed explicitly (no `_` arm) so that adding a new session source
        // forces a decision here.
        SessionSource::SubAgent(_)
        | SessionSource::Internal(_)
        | SessionSource::Custom(_)
        | SessionSource::Cli
        | SessionSource::VSCode
        | SessionSource::Exec
        | SessionSource::Mcp
        | SessionSource::Unknown => None,
    }
}
/// Capacity of the bounded channel that carries mapped response events to the stream consumer.
const RESPONSE_STREAM_CHANNEL_CAPACITY: usize = 1600;
/// Reason recorded on the inference trace when the consumer goes away before the provider
/// sends a terminal event.
const STREAM_DROPPED_REASON: &str = "response stream dropped before provider terminal event";
/// Adapts a `codex_api::ResponseStream` into the crate-level `ResponseStream`.
///
/// The upstream request id is moved out of the API stream and passed to
/// `map_response_events` separately; the stream itself is forwarded with that field cleared.
/// The returned receiver yields the `LastResponse` produced by `map_response_events`.
fn map_response_stream(
    api_stream: codex_api::ResponseStream,
    session_telemetry: SessionTelemetry,
    inference_trace_attempt: InferenceTraceAttempt,
) -> (ResponseStream, oneshot::Receiver<LastResponse>) {
    let mut api_stream = api_stream;
    let upstream_request_id = api_stream.upstream_request_id.take();
    map_response_events(
        upstream_request_id,
        api_stream,
        session_telemetry,
        inference_trace_attempt,
    )
}
/// Spawns a task that forwards events from `api_stream` to a new `ResponseStream`, recording
/// telemetry and inference-trace outcomes along the way.
///
/// Returns the consumer-facing stream and a oneshot receiver that is sent a `LastResponse`
/// (response id plus the output items seen so far) when a `Completed` event arrives.
///
/// The task stops when the consumer-dropped token is cancelled, when sending to the consumer
/// fails, or when `api_stream` ends.
fn map_response_events<S>(
upstream_request_id: Option<String>,
api_stream: S,
session_telemetry: SessionTelemetry,
inference_trace_attempt: InferenceTraceAttempt,
) -> (ResponseStream, oneshot::Receiver<LastResponse>)
where
S: futures::Stream<Item = std::result::Result<ResponseEvent, ApiError>>
+ Unpin
+ Send
+ 'static,
{
let (tx_event, rx_event) =
mpsc::channel::<Result<ResponseEvent>>(RESPONSE_STREAM_CHANNEL_CAPACITY);
let (tx_last_response, rx_last_response) = oneshot::channel::<LastResponse>();
// One clone is moved into the task; the other is stored in the returned `ResponseStream`.
let consumer_dropped = CancellationToken::new();
let consumer_dropped_for_stream = consumer_dropped.clone();
tokio::spawn(async move {
// Only the first stream error is reported to session telemetry.
let mut logged_error = false;
// Wrapped in `Option` so the oneshot sender can be consumed at most once.
let mut tx_last_response = Some(tx_last_response);
// Output items accumulated from `OutputItemDone` events.
let mut items_added: Vec<ResponseItem> = Vec::new();
let mut api_stream = api_stream;
let upstream_request_id = upstream_request_id.as_deref();
if let Some(upstream_request_id) = upstream_request_id {
feedback_tags!(last_model_request_id = upstream_request_id);
}
loop {
// Wait for the next upstream event, or stop as soon as the token is cancelled.
let event = tokio::select! {
_ = consumer_dropped.cancelled() => {
inference_trace_attempt.record_cancelled(
STREAM_DROPPED_REASON,
upstream_request_id,
&items_added,
);
return;
}
event = api_stream.next() => event,
};
let Some(event) = event else {
break;
};
match event {
Ok(ResponseEvent::OutputItemDone(item)) => {
// Keep a copy for the trace and `LastResponse`, then forward the item.
items_added.push(item.clone());
if tx_event
.send(Ok(ResponseEvent::OutputItemDone(item)))
.await
.is_err()
{
inference_trace_attempt.record_cancelled(
STREAM_DROPPED_REASON,
upstream_request_id,
&items_added,
);
return;
}
}
Ok(ResponseEvent::Completed {
response_id,
token_usage,
end_turn,
}) => {
feedback_tags!(last_model_response_id = &response_id);
if let Some(usage) = &token_usage {
session_telemetry.sse_event_completed(
usage.input_tokens,
usage.output_tokens,
Some(usage.cached_input_tokens),
Some(usage.reasoning_output_tokens),
usage.total_tokens,
);
}
inference_trace_attempt.record_completed(
&response_id,
upstream_request_id,
&token_usage,
&items_added,
);
// Publish the last response before forwarding `Completed`; `items_added` is moved
// out and left empty.
if let Some(sender) = tx_last_response.take() {
let _ = sender.send(LastResponse {
response_id: response_id.clone(),
items_added: std::mem::take(&mut items_added),
});
}
if tx_event
.send(Ok(ResponseEvent::Completed {
response_id,
token_usage,
end_turn,
}))
.await
.is_err()
{
return;
}
}
Ok(event) => {
// Every other event is forwarded unchanged.
if tx_event.send(Ok(event)).await.is_err() {
inference_trace_attempt.record_cancelled(
STREAM_DROPPED_REASON,
upstream_request_id,
&items_added,
);
return;
}
}
Err(err) => {
// Prefer the request id known up front; otherwise use the one carried by the error.
let response_debug_context =
extract_response_debug_context_from_api_error(&err);
let upstream_request_id =
upstream_request_id.or(response_debug_context.request_id.as_deref());
if let Some(upstream_request_id) = upstream_request_id {
feedback_tags!(last_model_request_id = upstream_request_id);
}
let mapped = map_api_error(err);
inference_trace_attempt.record_failed(
&mapped,
upstream_request_id,
&items_added,
);
if !logged_error {
session_telemetry.see_event_completed_failed(&mapped);
logged_error = true;
}
if tx_event.send(Err(mapped)).await.is_err() {
return;
}
}
}
}
// Reached whenever the upstream stream ends, including after a `Completed` or error event
// was already recorded.
// NOTE(review): presumably `InferenceTraceAttempt` ignores outcomes recorded after the first
// terminal one — confirm.
inference_trace_attempt.record_failed(
"stream closed before response.completed",
upstream_request_id,
&items_added,
);
});
(
ResponseStream {
rx_event,
consumer_dropped: consumer_dropped_for_stream,
},
rx_last_response,
)
}
/// Names of the recovery mode and step that were run in response to a 401.
#[derive(Clone, Copy, Debug)]
struct UnauthorizedRecoveryExecution {
    mode: &'static str,
    phase: &'static str,
}

/// Retry context attached to the request that follows an unauthorized response.
///
/// The default value describes a request that is not a retry.
#[derive(Clone, Copy, Debug, Default)]
struct PendingUnauthorizedRetry {
    retry_after_unauthorized: bool,
    recovery_mode: Option<&'static str>,
    recovery_phase: Option<&'static str>,
}

impl PendingUnauthorizedRetry {
    /// Builds the retry context for a request that follows the given recovery step.
    fn from_recovery(recovery: UnauthorizedRecoveryExecution) -> Self {
        let UnauthorizedRecoveryExecution { mode, phase } = recovery;
        Self {
            retry_after_unauthorized: true,
            recovery_mode: Some(mode),
            recovery_phase: Some(phase),
        }
    }
}
/// Auth-related facts about a single request attempt, captured for telemetry.
///
/// Combines the auth mode label, what is known about the attached auth header, and whether
/// the attempt is a retry after an unauthorized response (with the recovery mode and phase
/// that preceded it).
#[derive(Clone, Copy, Debug, Default)]
struct AuthRequestTelemetryContext {
auth_mode: Option<&'static str>,
auth_header_attached: bool,
auth_header_name: Option<&'static str>,
retry_after_unauthorized: bool,
recovery_mode: Option<&'static str>,
recovery_phase: Option<&'static str>,
}
impl AuthRequestTelemetryContext {
fn new(
auth_mode: Option<AuthMode>,
api_auth: &dyn AuthProvider,
retry: PendingUnauthorizedRetry,
) -> Self {
let auth_telemetry = auth_header_telemetry(api_auth);
Self {
auth_mode: auth_mode.map(|mode| match mode {
AuthMode::ApiKey => "ApiKey",
AuthMode::Chatgpt | AuthMode::ChatgptAuthTokens | AuthMode::AgentIdentity => {
"Chatgpt"
}
}),
auth_header_attached: auth_telemetry.attached,
auth_header_name: auth_telemetry.name,
retry_after_unauthorized: retry.retry_after_unauthorized,
recovery_mode: retry.recovery_mode,
recovery_phase: retry.recovery_phase,
}
}
}
/// Arguments for `ModelClientSession::websocket_connection`, bundled into one value.
///
/// `options` supplies the turn state used when a new connection has to be opened; the other
/// fields are forwarded to the client's websocket connect call.
struct WebsocketConnectParams<'a> {
session_telemetry: &'a SessionTelemetry,
api_provider: codex_api::Provider,
api_auth: SharedAuthProvider,
turn_metadata_header: Option<&'a str>,
options: &'a ApiResponsesOptions,
auth_context: AuthRequestTelemetryContext,
request_route_telemetry: RequestRouteTelemetry,
}
/// Handles a 401 response by running the next available auth recovery step.
///
/// When recovery succeeds, the caller should retry the API call using the returned mode and
/// phase; otherwise an error is returned. Every outcome is reported through
/// `record_auth_recovery` and the feedback auth-recovery tags.
///
/// # Errors
///
/// - `CodexErr::RefreshTokenFailed` when the recovery step fails permanently.
/// - `CodexErr::Io` when the recovery step fails transiently.
/// - The mapped transport error when no recovery step is available.
async fn handle_unauthorized(
    transport: TransportError,
    auth_recovery: &mut Option<UnauthorizedRecovery>,
    session_telemetry: &SessionTelemetry,
) -> Result<UnauthorizedRecoveryExecution> {
    let debug = extract_response_debug_context(&transport);
    if let Some(recovery) = auth_recovery
        && recovery.has_next()
    {
        // Read the labels before running the step so they describe the step being executed.
        let mode = recovery.mode_name();
        let phase = recovery.step_name();
        let (outcome, auth_state_changed, result) = match recovery.next().await {
            Ok(step_result) => (
                "recovery_succeeded",
                step_result.auth_state_changed(),
                Ok(UnauthorizedRecoveryExecution { mode, phase }),
            ),
            Err(RefreshTokenError::Permanent(failed)) => (
                "recovery_failed_permanent",
                None,
                Err(CodexErr::RefreshTokenFailed(failed)),
            ),
            Err(RefreshTokenError::Transient(other)) => (
                "recovery_failed_transient",
                None,
                Err(CodexErr::Io(other)),
            ),
        };
        session_telemetry.record_auth_recovery(
            mode,
            phase,
            outcome,
            debug.request_id.as_deref(),
            debug.cf_ray.as_deref(),
            debug.auth_error.as_deref(),
            debug.auth_error_code.as_deref(),
            /*recovery_reason*/ None,
            auth_state_changed,
        );
        emit_feedback_auth_recovery_tags(
            mode,
            phase,
            outcome,
            debug.request_id.as_deref(),
            debug.cf_ray.as_deref(),
            debug.auth_error.as_deref(),
            debug.auth_error_code.as_deref(),
        );
        return result;
    }

    // No recovery step was run: report why, then surface the original 401.
    let (mode, phase, recovery_reason) = match auth_recovery.as_ref() {
        Some(recovery) => (
            recovery.mode_name(),
            recovery.step_name(),
            Some(recovery.unavailable_reason()),
        ),
        None => ("none", "none", Some("auth_manager_missing")),
    };
    session_telemetry.record_auth_recovery(
        mode,
        phase,
        "recovery_not_run",
        debug.request_id.as_deref(),
        debug.cf_ray.as_deref(),
        debug.auth_error.as_deref(),
        debug.auth_error_code.as_deref(),
        recovery_reason,
        /*auth_state_changed*/ None,
    );
    emit_feedback_auth_recovery_tags(
        mode,
        phase,
        "recovery_not_run",
        debug.request_id.as_deref(),
        debug.cf_ray.as_deref(),
        debug.auth_error.as_deref(),
        debug.auth_error_code.as_deref(),
    );
    Err(map_api_error(ApiError::Transport(transport)))
}
fn api_error_http_status(error: &ApiError) -> Option<u16> {
match error {
ApiError::Transport(TransportError::Http { status, .. }) => Some(status.as_u16()),
_ => None,
}
}
/// Telemetry sink for one API request attempt.
///
/// Holds the session telemetry handle together with the auth, route, and auth-environment
/// context that is attached to the events it reports.
struct ApiTelemetry {
session_telemetry: SessionTelemetry,
auth_context: AuthRequestTelemetryContext,
request_route_telemetry: RequestRouteTelemetry,
auth_env_telemetry: AuthEnvTelemetry,
}
impl ApiTelemetry {
fn new(
session_telemetry: SessionTelemetry,
auth_context: AuthRequestTelemetryContext,
request_route_telemetry: RequestRouteTelemetry,
auth_env_telemetry: AuthEnvTelemetry,
) -> Self {
Self {
session_telemetry,
auth_context,
request_route_telemetry,
auth_env_telemetry,
}
}
}
impl RequestTelemetry for ApiTelemetry {
    /// Reports one HTTP request attempt to session telemetry and to the feedback tags.
    ///
    /// The follow-up success and status tags are only populated when this attempt is a retry
    /// after an unauthorized response.
    fn on_request(
        &self,
        attempt: u64,
        status: Option<HttpStatusCode>,
        error: Option<&TransportError>,
        duration: Duration,
    ) {
        let auth = &self.auth_context;
        let endpoint = self.request_route_telemetry.endpoint;

        let error_message = error.map(telemetry_transport_error_message);
        let status = status.map(|code| code.as_u16());
        let debug = error
            .map(extract_response_debug_context)
            .unwrap_or_default();

        self.session_telemetry.record_api_request(
            attempt,
            status,
            error_message.as_deref(),
            duration,
            auth.auth_header_attached,
            auth.auth_header_name,
            auth.retry_after_unauthorized,
            auth.recovery_mode,
            auth.recovery_phase,
            endpoint,
            debug.request_id.as_deref(),
            debug.cf_ray.as_deref(),
            debug.auth_error.as_deref(),
            debug.auth_error_code.as_deref(),
        );

        let (followup_success, followup_status) = if auth.retry_after_unauthorized {
            (Some(error.is_none()), status)
        } else {
            (None, None)
        };
        let tags = FeedbackRequestTags {
            endpoint,
            auth_header_attached: auth.auth_header_attached,
            auth_header_name: auth.auth_header_name,
            auth_mode: auth.auth_mode,
            auth_retry_after_unauthorized: Some(auth.retry_after_unauthorized),
            auth_recovery_mode: auth.recovery_mode,
            auth_recovery_phase: auth.recovery_phase,
            auth_connection_reused: None,
            auth_request_id: debug.request_id.as_deref(),
            auth_cf_ray: debug.cf_ray.as_deref(),
            auth_error: debug.auth_error.as_deref(),
            auth_error_code: debug.auth_error_code.as_deref(),
            auth_recovery_followup_success: followup_success,
            auth_recovery_followup_status: followup_status,
        };
        emit_feedback_request_tags_with_auth_env(&tags, &self.auth_env_telemetry);
    }
}
impl SseTelemetry for ApiTelemetry {
/// Forwards the outcome of one SSE poll, with its duration, to session telemetry.
fn on_sse_poll(
&self,
result: &std::result::Result<
Option<std::result::Result<Event, EventStreamError<TransportError>>>,
tokio::time::error::Elapsed,
>,
duration: Duration,
) {
self.session_telemetry.log_sse_event(result, duration);
}
}
impl WebsocketTelemetry for ApiTelemetry {
    /// Reports one websocket request to session telemetry and to the feedback tags.
    ///
    /// The follow-up success and status tags are only populated when this request is a retry
    /// after an unauthorized response; the status is taken from `error` when it is an HTTP
    /// transport error.
    fn on_ws_request(&self, duration: Duration, error: Option<&ApiError>, connection_reused: bool) {
        let auth = &self.auth_context;

        let error_message = error.map(telemetry_api_error_message);
        let status = error.and_then(api_error_http_status);
        let debug = error
            .map(extract_response_debug_context_from_api_error)
            .unwrap_or_default();

        self.session_telemetry.record_websocket_request(
            duration,
            error_message.as_deref(),
            connection_reused,
        );

        let (followup_success, followup_status) = if auth.retry_after_unauthorized {
            (Some(error.is_none()), status)
        } else {
            (None, None)
        };
        let tags = FeedbackRequestTags {
            endpoint: self.request_route_telemetry.endpoint,
            auth_header_attached: auth.auth_header_attached,
            auth_header_name: auth.auth_header_name,
            auth_mode: auth.auth_mode,
            auth_retry_after_unauthorized: Some(auth.retry_after_unauthorized),
            auth_recovery_mode: auth.recovery_mode,
            auth_recovery_phase: auth.recovery_phase,
            auth_connection_reused: Some(connection_reused),
            auth_request_id: debug.request_id.as_deref(),
            auth_cf_ray: debug.cf_ray.as_deref(),
            auth_error: debug.auth_error.as_deref(),
            auth_error_code: debug.auth_error_code.as_deref(),
            auth_recovery_followup_success: followup_success,
            auth_recovery_followup_status: followup_status,
        };
        emit_feedback_request_tags_with_auth_env(&tags, &self.auth_env_telemetry);
    }

    /// Forwards the outcome of one websocket event poll, with its duration, to session
    /// telemetry.
    fn on_ws_event(
        &self,
        result: &std::result::Result<Option<std::result::Result<Message, Error>>, ApiError>,
        duration: Duration,
    ) {
        self.session_telemetry
            .record_websocket_event(result, duration);
    }
}
#[cfg(test)]
#[path = "client_tests.rs"]
mod tests;