Mirror of https://github.com/openai/codex.git (synced 2026-04-28 00:25:56 +00:00)
CXC-392

[With 401](https://openai.sentry.io/issues/7333870443/?project=4510195390611458&query=019ce8f8-560c-7f10-a00a-c59553740674&referrer=issue-stream)

<img width="1909" height="555" alt="401 auth tags in Sentry" src="https://github.com/user-attachments/assets/412ea950-61c4-4780-9697-15c270971ee3" />

- `auth_401_*`: preserved facts from the latest unauthorized response snapshot
- `auth_*`: latest auth-related facts from the latest request attempt
- `auth_recovery_*`: unauthorized recovery state and follow-up result

Without 401

<img width="1917" height="522" alt="happy-path auth tags in Sentry" src="https://github.com/user-attachments/assets/3381ed28-8022-43b0-b6c0-623a630e679f" />

###### Summary
- Add client-visible 401 diagnostics for auth attachment, upstream auth classification, and 401 request id / cf-ray correlation.
- Record unauthorized recovery mode, phase, outcome, and retry/follow-up status without changing auth behavior.
- Surface the highest-signal auth and recovery fields on uploaded client bug reports so they are usable in Sentry.
- Preserve original unauthorized evidence under `auth_401_*` while keeping follow-up result tags separate.

###### Rationale (from spec findings)
- The dominant bucket needed proof of whether the client attached auth before send or upstream still classified the request as missing auth.
- Client uploads needed to show whether unauthorized recovery ran and what the client tried next.
- Request id and cf-ray needed to be preserved on the unauthorized response so server-side correlation is immediate.
- The bug-report path needed the same auth evidence as the request telemetry path, otherwise the observability would not be operationally useful.

###### Scope
- Add auth 401 and unauthorized-recovery observability in `codex-rs/core`, `codex-rs/codex-api`, and `codex-rs/otel`, including feedback-tag surfacing.
- Keep auth semantics, refresh behavior, retry behavior, endpoint classification, and geo-denial follow-up work out of this PR.

###### Trade-offs
- This exports only safe auth evidence: header presence/name, upstream auth classification, request ids, and recovery state. It does not export token values or raw upstream bodies.
- This keeps websocket connection reuse as a transport clue because it can help distinguish stale reused sessions from fresh reconnects.
- Misroute/base-url classification and geo-denial are intentionally deferred to a separate follow-up PR so this review stays focused on the dominant auth 401 bucket.

###### Client follow-up
- PR 2 will add misroute/provider and geo-denial observability plus the matching feedback-tag surfacing.
- A separate host/app-server PR should log auth-decision inputs so pre-send host auth state can be correlated with client request evidence.
- `device_id` remains intentionally separate until there is a safe existing source on the feedback upload path.

###### Testing
- `cargo test -p codex-core refresh_available_models_sorts_by_priority`
- `cargo test -p codex-core emit_feedback_request_tags_`
- `cargo test -p codex-core emit_feedback_auth_recovery_tags_`
- `cargo test -p codex-core auth_request_telemetry_context_tracks_attached_auth_and_retry_phase`
- `cargo test -p codex-core extract_response_debug_context_decodes_identity_headers`
- `cargo test -p codex-core identity_auth_details`
- `cargo test -p codex-core telemetry_error_messages_preserve_non_http_details`
- `cargo test -p codex-core --all-features --no-run`
- `cargo test -p codex-otel otel_export_routing_policy_routes_api_request_auth_observability`
- `cargo test -p codex-otel otel_export_routing_policy_routes_websocket_connect_auth_observability`
- `cargo test -p codex-otel otel_export_routing_policy_routes_websocket_request_transport_observability`
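###### Illustrative sketch (not part of the change)

The split between `auth_401_*` and `auth_*` tags described in the Summary can be pictured with a small standalone sketch. Everything in it is an assumption made for illustration: `AuthFacts`, `AuthTags`, `record_attempt`, `to_tags`, and the individual tag suffixes are hypothetical names and are not taken from `codex-rs`. It only shows the idea of keeping the last unauthorized snapshot while later attempts overwrite the "latest" facts, and of exporting header presence and ids rather than token values.

```rust
use std::collections::BTreeMap;

/// Hypothetical set of "safe" auth facts for one request attempt.
/// It holds header presence/name and correlation ids, never token values.
#[derive(Clone, Debug, Default)]
struct AuthFacts {
    header_attached: bool,
    header_name: Option<String>,
    request_id: Option<String>,
    cf_ray: Option<String>,
}

/// Hypothetical holder that keeps the latest attempt and, separately,
/// the most recent attempt that came back as 401.
#[derive(Debug, Default)]
struct AuthTags {
    latest: AuthFacts,
    last_401: Option<AuthFacts>,
}

impl AuthTags {
    /// Record one attempt. A 401 is additionally preserved as a snapshot so
    /// that a later successful retry does not erase the original evidence.
    fn record_attempt(&mut self, facts: AuthFacts, status: u16) {
        if status == 401 {
            self.last_401 = Some(facts.clone());
        }
        self.latest = facts;
    }

    /// Flatten into string tags: `auth_*` for the latest attempt and
    /// `auth_401_*` for the preserved unauthorized snapshot, if any.
    fn to_tags(&self) -> BTreeMap<String, String> {
        let mut tags = BTreeMap::new();
        push_facts(&mut tags, "auth_", &self.latest);
        if let Some(facts) = &self.last_401 {
            push_facts(&mut tags, "auth_401_", facts);
        }
        tags
    }
}

/// Write one snapshot under the given prefix, skipping absent optional fields.
fn push_facts(tags: &mut BTreeMap<String, String>, prefix: &str, facts: &AuthFacts) {
    tags.insert(
        format!("{prefix}header_attached"),
        facts.header_attached.to_string(),
    );
    if let Some(name) = &facts.header_name {
        tags.insert(format!("{prefix}header_name"), name.clone());
    }
    if let Some(id) = &facts.request_id {
        tags.insert(format!("{prefix}request_id"), id.clone());
    }
    if let Some(ray) = &facts.cf_ray {
        tags.insert(format!("{prefix}cf_ray"), ray.clone());
    }
}
```

In this sketch, recording a 401 followed by a successful retry leaves the first request id under the `auth_401_` prefix and the retry's request id under `auth_`, which is the separation the Summary asks for.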
833 lines
28 KiB
Rust
use crate::auth::AuthProvider;
use crate::auth::add_auth_headers_to_header_map;
use crate::common::ResponseEvent;
use crate::common::ResponseStream;
use crate::common::ResponsesWsRequest;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::rate_limits::parse_rate_limit_event;
use crate::sse::responses::ResponsesStreamEvent;
use crate::sse::responses::process_responses_event;
use crate::telemetry::WebsocketTelemetry;
use codex_client::TransportError;
use codex_client::maybe_build_rustls_client_config_with_custom_ca;
use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
use futures::SinkExt;
use futures::StreamExt;
use http::HeaderMap;
use http::HeaderName;
use http::HeaderValue;
use http::StatusCode;
use serde::Deserialize;
use serde_json::Value;
use serde_json::map::Map as JsonMap;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::time::Instant;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::connect_async_tls_with_config;
use tokio_tungstenite::tungstenite::Error as WsError;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tracing::Instrument;
use tracing::Span;
use tracing::debug;
use tracing::error;
use tracing::info;
use tracing::instrument;
use tracing::trace;
use tungstenite::extensions::ExtensionsConfig;
use tungstenite::extensions::compression::deflate::DeflateConfig;
use tungstenite::protocol::WebSocketConfig;
use url::Url;

struct WsStream {
    tx_command: mpsc::Sender<WsCommand>,
    rx_message: mpsc::UnboundedReceiver<Result<Message, WsError>>,
    pump_task: tokio::task::JoinHandle<()>,
}

enum WsCommand {
    Send {
        message: Message,
        tx_result: oneshot::Sender<Result<(), WsError>>,
    },
}

impl WsStream {
    fn new(inner: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
        let (tx_command, mut rx_command) = mpsc::channel::<WsCommand>(32);
        let (tx_message, rx_message) = mpsc::unbounded_channel::<Result<Message, WsError>>();

        let pump_task = tokio::spawn(async move {
            let mut inner = inner;
            loop {
                tokio::select! {
                    command = rx_command.recv() => {
                        let Some(command) = command else {
                            break;
                        };
                        match command {
                            WsCommand::Send { message, tx_result } => {
                                let result = inner.send(message).await;
                                let should_break = result.is_err();
                                let _ = tx_result.send(result);
                                if should_break {
                                    break;
                                }
                            }
                        }
                    }
                    message = inner.next() => {
                        let Some(message) = message else {
                            break;
                        };
                        match message {
                            Ok(Message::Ping(payload)) => {
                                if let Err(err) = inner.send(Message::Pong(payload)).await {
                                    let _ = tx_message.send(Err(err));
                                    break;
                                }
                            }
                            Ok(Message::Pong(_)) => {}
                            Ok(message @ (Message::Text(_)
                                | Message::Binary(_)
                                | Message::Close(_)
                                | Message::Frame(_))) => {
                                let is_close = matches!(message, Message::Close(_));
                                if tx_message.send(Ok(message)).is_err() {
                                    break;
                                }
                                if is_close {
                                    break;
                                }
                            }
                            Err(err) => {
                                let _ = tx_message.send(Err(err));
                                break;
                            }
                        }
                    }
                }
            }
        });

        Self {
            tx_command,
            rx_message,
            pump_task,
        }
    }

    async fn request(
        &self,
        make_command: impl FnOnce(oneshot::Sender<Result<(), WsError>>) -> WsCommand,
    ) -> Result<(), WsError> {
        let (tx_result, rx_result) = oneshot::channel();
        if self.tx_command.send(make_command(tx_result)).await.is_err() {
            return Err(WsError::ConnectionClosed);
        }
        rx_result.await.unwrap_or(Err(WsError::ConnectionClosed))
    }

    async fn send(&self, message: Message) -> Result<(), WsError> {
        self.request(|tx_result| WsCommand::Send { message, tx_result })
            .await
    }

    async fn next(&mut self) -> Option<Result<Message, WsError>> {
        self.rx_message.recv().await
    }
}

impl Drop for WsStream {
    fn drop(&mut self) {
        self.pump_task.abort();
    }
}

const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
const X_MODELS_ETAG_HEADER: &str = "x-models-etag";
const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included";
const OPENAI_MODEL_HEADER: &str = "openai-model";
const WEBSOCKET_CONNECTION_LIMIT_REACHED_CODE: &str = "websocket_connection_limit_reached";
const WEBSOCKET_CONNECTION_LIMIT_REACHED_MESSAGE: &str = "Responses websocket connection limit reached (60 minutes). Create a new websocket connection to continue.";

pub struct ResponsesWebsocketConnection {
    stream: Arc<Mutex<Option<WsStream>>>,
    // TODO (pakrym): is this the right place for timeout?
    idle_timeout: Duration,
    server_reasoning_included: bool,
    models_etag: Option<String>,
    server_model: Option<String>,
    telemetry: Option<Arc<dyn WebsocketTelemetry>>,
}

impl std::fmt::Debug for ResponsesWebsocketConnection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResponsesWebsocketConnection")
            .field("stream", &"<ws-stream>")
            .field("idle_timeout", &self.idle_timeout)
            .field("server_reasoning_included", &self.server_reasoning_included)
            .field("models_etag", &self.models_etag)
            .field("server_model", &self.server_model)
            .field("telemetry", &self.telemetry.as_ref().map(|_| "<telemetry>"))
            .finish()
    }
}

impl ResponsesWebsocketConnection {
    fn new(
        stream: WsStream,
        idle_timeout: Duration,
        server_reasoning_included: bool,
        models_etag: Option<String>,
        server_model: Option<String>,
        telemetry: Option<Arc<dyn WebsocketTelemetry>>,
    ) -> Self {
        Self {
            stream: Arc::new(Mutex::new(Some(stream))),
            idle_timeout,
            server_reasoning_included,
            models_etag,
            server_model,
            telemetry,
        }
    }

    pub async fn is_closed(&self) -> bool {
        self.stream.lock().await.is_none()
    }

    #[instrument(
        name = "responses_websocket.stream_request",
        level = "info",
        skip_all,
        fields(transport = "responses_websocket", api.path = "responses")
    )]
    pub async fn stream_request(
        &self,
        request: ResponsesWsRequest,
        connection_reused: bool,
    ) -> Result<ResponseStream, ApiError> {
        let (tx_event, rx_event) =
            mpsc::channel::<std::result::Result<ResponseEvent, ApiError>>(1600);
        let stream = Arc::clone(&self.stream);
        let idle_timeout = self.idle_timeout;
        let server_reasoning_included = self.server_reasoning_included;
        let models_etag = self.models_etag.clone();
        let server_model = self.server_model.clone();
        let telemetry = self.telemetry.clone();
        let request_body = serde_json::to_value(&request).map_err(|err| {
            ApiError::Stream(format!("failed to encode websocket request: {err}"))
        })?;

        let current_span = Span::current();
        tokio::spawn(
            async move {
                if let Some(model) = server_model {
                    let _ = tx_event.send(Ok(ResponseEvent::ServerModel(model))).await;
                }
                if let Some(etag) = models_etag {
                    let _ = tx_event.send(Ok(ResponseEvent::ModelsEtag(etag))).await;
                }
                if server_reasoning_included {
                    let _ = tx_event
                        .send(Ok(ResponseEvent::ServerReasoningIncluded(true)))
                        .await;
                }
                let mut guard = stream.lock().await;
                let result = {
                    let Some(ws_stream) = guard.as_mut() else {
                        let _ = tx_event
                            .send(Err(ApiError::Stream(
                                "websocket connection is closed".to_string(),
                            )))
                            .await;
                        return;
                    };

                    run_websocket_response_stream(
                        ws_stream,
                        tx_event.clone(),
                        request_body,
                        idle_timeout,
                        telemetry,
                        connection_reused,
                    )
                    .await
                };

                if let Err(err) = result {
                    // A terminal stream error should reach the caller immediately. Waiting for a
                    // graceful close handshake here can stall indefinitely and mask the error.
                    let failed_stream = guard.take();
                    drop(guard);
                    drop(failed_stream);
                    let _ = tx_event.send(Err(err)).await;
                }
            }
            .instrument(current_span),
        );

        Ok(ResponseStream { rx_event })
    }
}

pub struct ResponsesWebsocketClient<A: AuthProvider> {
    provider: Provider,
    auth: A,
}

impl<A: AuthProvider> ResponsesWebsocketClient<A> {
    pub fn new(provider: Provider, auth: A) -> Self {
        Self { provider, auth }
    }

    #[instrument(
        name = "responses_websocket.connect",
        level = "info",
        skip_all,
        fields(transport = "responses_websocket", api.path = "responses")
    )]
    pub async fn connect(
        &self,
        extra_headers: HeaderMap,
        default_headers: HeaderMap,
        turn_state: Option<Arc<OnceLock<String>>>,
        telemetry: Option<Arc<dyn WebsocketTelemetry>>,
    ) -> Result<ResponsesWebsocketConnection, ApiError> {
        let ws_url = self
            .provider
            .websocket_url_for_path("responses")
            .map_err(|err| ApiError::Stream(format!("failed to build websocket URL: {err}")))?;

        let mut headers =
            merge_request_headers(&self.provider.headers, extra_headers, default_headers);
        add_auth_headers_to_header_map(&self.auth, &mut headers);

        let (stream, server_reasoning_included, models_etag, server_model) =
            connect_websocket(ws_url, headers, turn_state.clone()).await?;
        Ok(ResponsesWebsocketConnection::new(
            stream,
            self.provider.stream_idle_timeout,
            server_reasoning_included,
            models_etag,
            server_model,
            telemetry,
        ))
    }
}

fn merge_request_headers(
    provider_headers: &HeaderMap,
    extra_headers: HeaderMap,
    default_headers: HeaderMap,
) -> HeaderMap {
    let mut headers = provider_headers.clone();
    headers.extend(extra_headers);
    for (name, value) in &default_headers {
        if let http::header::Entry::Vacant(entry) = headers.entry(name) {
            entry.insert(value.clone());
        }
    }
    headers
}

async fn connect_websocket(
    url: Url,
    headers: HeaderMap,
    turn_state: Option<Arc<OnceLock<String>>>,
) -> Result<(WsStream, bool, Option<String>, Option<String>), ApiError> {
    ensure_rustls_crypto_provider();
    info!("connecting to websocket: {url}");

    let mut request = url
        .as_str()
        .into_client_request()
        .map_err(|err| ApiError::Stream(format!("failed to build websocket request: {err}")))?;
    request.headers_mut().extend(headers);

    // Secure websocket traffic needs the same custom-CA policy as reqwest-based HTTPS traffic.
    // If a Codex-specific CA bundle is configured, build an explicit rustls connector so this
    // websocket path does not fall back to tungstenite's default native-roots-only behavior.
    let connector = maybe_build_rustls_client_config_with_custom_ca()
        .map_err(|err| ApiError::Stream(format!("failed to configure websocket TLS: {err}")))?
        .map(tokio_tungstenite::Connector::Rustls);

    let response = connect_async_tls_with_config(
        request,
        Some(websocket_config()),
        false, // `false` means "do not disable Nagle", which is tungstenite's recommended default.
        connector,
    )
    .await;

    let (stream, response) = match response {
        Ok((stream, response)) => {
            info!(
                "successfully connected to websocket: {url}, headers: {:?}",
                response.headers()
            );
            (stream, response)
        }
        Err(err) => {
            error!("failed to connect to websocket: {err}, url: {url}");
            return Err(map_ws_error(err, &url));
        }
    };

    let reasoning_included = response.headers().contains_key(X_REASONING_INCLUDED_HEADER);
    let models_etag = response
        .headers()
        .get(X_MODELS_ETAG_HEADER)
        .and_then(|value| value.to_str().ok())
        .map(ToString::to_string);
    let server_model = response
        .headers()
        .get(OPENAI_MODEL_HEADER)
        .and_then(|value| value.to_str().ok())
        .map(ToString::to_string);
    if let Some(turn_state) = turn_state
        && let Some(header_value) = response
            .headers()
            .get(X_CODEX_TURN_STATE_HEADER)
            .and_then(|value| value.to_str().ok())
    {
        let _ = turn_state.set(header_value.to_string());
    }
    Ok((
        WsStream::new(stream),
        reasoning_included,
        models_etag,
        server_model,
    ))
}

fn websocket_config() -> WebSocketConfig {
    let mut extensions = ExtensionsConfig::default();
    extensions.permessage_deflate = Some(DeflateConfig::default());

    let mut config = WebSocketConfig::default();
    config.extensions = extensions;
    config
}

fn map_ws_error(err: WsError, url: &Url) -> ApiError {
    match err {
        WsError::Http(response) => {
            let status = response.status();
            let headers = response.headers().clone();
            let body = response
                .body()
                .as_ref()
                .and_then(|bytes| String::from_utf8(bytes.clone()).ok());
            ApiError::Transport(TransportError::Http {
                status,
                url: Some(url.to_string()),
                headers: Some(headers),
                body,
            })
        }
        WsError::ConnectionClosed | WsError::AlreadyClosed => {
            ApiError::Stream("websocket closed".to_string())
        }
        WsError::Io(err) => ApiError::Transport(TransportError::Network(err.to_string())),
        other => ApiError::Transport(TransportError::Network(other.to_string())),
    }
}

#[derive(Debug, Deserialize)]
struct WrappedWebsocketError {
    code: Option<String>,
    message: Option<String>,
}

#[derive(Debug, Deserialize)]
struct WrappedWebsocketErrorEvent {
    #[serde(rename = "type")]
    kind: String,
    #[serde(alias = "status_code")]
    status: Option<u16>,
    #[serde(default)]
    error: Option<WrappedWebsocketError>,
    #[serde(default)]
    headers: Option<JsonMap<String, Value>>,
}

fn parse_wrapped_websocket_error_event(payload: &str) -> Option<WrappedWebsocketErrorEvent> {
    let event: WrappedWebsocketErrorEvent = serde_json::from_str(payload).ok()?;
    if event.kind != "error" {
        return None;
    }
    Some(event)
}

fn map_wrapped_websocket_error_event(
    event: WrappedWebsocketErrorEvent,
    original_payload: String,
) -> Option<ApiError> {
    let WrappedWebsocketErrorEvent {
        status,
        error,
        headers,
        ..
    } = event;

    if let Some(error) = error.as_ref()
        && let Some(code) = error.code.as_deref()
        && code == WEBSOCKET_CONNECTION_LIMIT_REACHED_CODE
    {
        return Some(ApiError::Retryable {
            message: error
                .message
                .clone()
                .unwrap_or_else(|| WEBSOCKET_CONNECTION_LIMIT_REACHED_MESSAGE.to_string()),
            delay: None,
        });
    }

    let status = StatusCode::from_u16(status?).ok()?;
    if status.is_success() {
        return None;
    }

    Some(ApiError::Transport(TransportError::Http {
        status,
        url: None,
        headers: headers.map(json_headers_to_http_headers),
        body: Some(original_payload),
    }))
}

fn json_headers_to_http_headers(headers: JsonMap<String, Value>) -> HeaderMap {
    let mut mapped = HeaderMap::new();
    for (name, value) in headers {
        let Ok(header_name) = HeaderName::from_bytes(name.as_bytes()) else {
            continue;
        };
        let Some(header_value) = json_header_value(value) else {
            continue;
        };
        mapped.insert(header_name, header_value);
    }
    mapped
}

fn json_header_value(value: Value) -> Option<HeaderValue> {
    let value = match value {
        Value::String(value) => value,
        Value::Number(value) => value.to_string(),
        Value::Bool(value) => value.to_string(),
        _ => return None,
    };
    HeaderValue::from_str(&value).ok()
}

async fn run_websocket_response_stream(
    ws_stream: &mut WsStream,
    tx_event: mpsc::Sender<std::result::Result<ResponseEvent, ApiError>>,
    request_body: Value,
    idle_timeout: Duration,
    telemetry: Option<Arc<dyn WebsocketTelemetry>>,
    connection_reused: bool,
) -> Result<(), ApiError> {
    let mut last_server_model: Option<String> = None;
    let request_text = match serde_json::to_string(&request_body) {
        Ok(text) => text,
        Err(err) => {
            return Err(ApiError::Stream(format!(
                "failed to encode websocket request: {err}"
            )));
        }
    };
    trace!("websocket request: {request_text}");

    let request_start = Instant::now();
    let result = ws_stream
        .send(Message::Text(request_text.into()))
        .await
        .map_err(|err| ApiError::Stream(format!("failed to send websocket request: {err}")));

    if let Some(t) = telemetry.as_ref() {
        t.on_ws_request(
            request_start.elapsed(),
            result.as_ref().err(),
            connection_reused,
        );
    }

    result?;

    loop {
        let poll_start = Instant::now();
        let response = tokio::time::timeout(idle_timeout, ws_stream.next())
            .await
            .map_err(|_| ApiError::Stream("idle timeout waiting for websocket".into()));
        if let Some(t) = telemetry.as_ref() {
            t.on_ws_event(&response, poll_start.elapsed());
        }
        let message = match response {
            Ok(Some(Ok(msg))) => msg,
            Ok(Some(Err(err))) => {
                return Err(ApiError::Stream(err.to_string()));
            }
            Ok(None) => {
                return Err(ApiError::Stream(
                    "stream closed before response.completed".into(),
                ));
            }
            Err(err) => {
                return Err(err);
            }
        };

        match message {
            Message::Text(text) => {
                trace!("websocket event: {text}");
                if let Some(wrapped_error) = parse_wrapped_websocket_error_event(&text)
                    && let Some(error) =
                        map_wrapped_websocket_error_event(wrapped_error, text.to_string())
                {
                    return Err(error);
                }

                let event = match serde_json::from_str::<ResponsesStreamEvent>(&text) {
                    Ok(event) => event,
                    Err(err) => {
                        debug!("failed to parse websocket event: {err}, data: {text}");
                        continue;
                    }
                };
                if event.kind() == "codex.rate_limits" {
                    if let Some(snapshot) = parse_rate_limit_event(&text) {
                        let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await;
                    }
                    continue;
                }
                if let Some(model) = event.response_model()
                    && last_server_model.as_deref() != Some(model.as_str())
                {
                    let _ = tx_event
                        .send(Ok(ResponseEvent::ServerModel(model.clone())))
                        .await;
                    last_server_model = Some(model);
                }
                match process_responses_event(event) {
                    Ok(Some(event)) => {
                        let is_completed = matches!(event, ResponseEvent::Completed { .. });
                        let _ = tx_event.send(Ok(event)).await;
                        if is_completed {
                            break;
                        }
                    }
                    Ok(None) => {}
                    Err(error) => {
                        return Err(error.into_api_error());
                    }
                }
            }
            Message::Binary(_) => {
                return Err(ApiError::Stream("unexpected binary websocket event".into()));
            }
            Message::Close(_) => {
                return Err(ApiError::Stream(
                    "websocket closed by server before response.completed".into(),
                ));
            }
            Message::Frame(_) => {}
            Message::Ping(_) | Message::Pong(_) => {}
        }
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use serde_json::json;

    #[test]
    fn websocket_config_enables_permessage_deflate() {
        let config = websocket_config();
        assert!(config.extensions.permessage_deflate.is_some());
    }

    #[test]
    fn parse_wrapped_websocket_error_event_maps_to_transport_http() {
        let payload = json!({
            "type": "error",
            "status": 429,
            "error": {
                "type": "usage_limit_reached",
                "message": "The usage limit has been reached",
                "plan_type": "pro",
                "resets_at": 1738888888
            },
            "headers": {
                "x-codex-primary-used-percent": "100.0",
                "x-codex-primary-window-minutes": 15
            }
        })
        .to_string();

        let wrapped_error = parse_wrapped_websocket_error_event(&payload)
            .expect("expected websocket error payload to be parsed");
        let api_error = map_wrapped_websocket_error_event(wrapped_error, payload)
            .expect("expected websocket error payload to map to ApiError");

        let ApiError::Transport(TransportError::Http {
            status,
            headers,
            body,
            ..
        }) = api_error
        else {
            panic!("expected ApiError::Transport(Http)");
        };

        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
        let headers = headers.expect("expected headers");
        assert_eq!(
            headers
                .get("x-codex-primary-used-percent")
                .and_then(|value| value.to_str().ok()),
            Some("100.0")
        );
        assert_eq!(
            headers
                .get("x-codex-primary-window-minutes")
                .and_then(|value| value.to_str().ok()),
            Some("15")
        );
        let body = body.expect("expected body");
        assert!(body.contains("usage_limit_reached"));
        assert!(body.contains("The usage limit has been reached"));
    }

    #[test]
    fn parse_wrapped_websocket_error_event_ignores_non_error_payloads() {
        let payload = json!({
            "type": "response.created",
            "response": {
                "id": "resp-1"
            }
        })
        .to_string();

        let wrapped_error = parse_wrapped_websocket_error_event(&payload);
        assert!(wrapped_error.is_none());
    }

    #[test]
    fn parse_wrapped_websocket_error_event_with_status_maps_invalid_request() {
        let payload = json!({
            "type": "error",
            "status": 400,
            "error": {
                "type": "invalid_request_error",
                "message": "Model does not support image inputs"
            }
        })
        .to_string();

        let wrapped_error = parse_wrapped_websocket_error_event(&payload)
            .expect("expected websocket error payload to be parsed");
        let api_error = map_wrapped_websocket_error_event(wrapped_error, payload)
            .expect("expected websocket error payload to map to ApiError");
        let ApiError::Transport(TransportError::Http { status, body, .. }) = api_error else {
            panic!("expected ApiError::Transport(Http)");
        };
        assert_eq!(status, StatusCode::BAD_REQUEST);
        let body = body.expect("expected body");
        assert!(body.contains("invalid_request_error"));
        assert!(body.contains("Model does not support image inputs"));
    }

    #[test]
    fn parse_wrapped_websocket_error_event_with_connection_limit_maps_retryable() {
        let payload = json!({
            "type": "error",
            "status": 400,
            "error": {
                "type": "invalid_request_error",
                "code": "websocket_connection_limit_reached",
                "message": "Responses websocket connection limit reached (60 minutes). Create a new websocket connection to continue."
            }
        })
        .to_string();

        let wrapped_error = parse_wrapped_websocket_error_event(&payload)
            .expect("expected websocket error payload to be parsed");
        let api_error = map_wrapped_websocket_error_event(wrapped_error, payload)
            .expect("expected websocket error payload to map to ApiError");
        let ApiError::Retryable { message, delay } = api_error else {
            panic!("expected ApiError::Retryable");
        };
        assert_eq!(message, WEBSOCKET_CONNECTION_LIMIT_REACHED_MESSAGE);
        assert_eq!(delay, None);
    }

    #[test]
    fn parse_wrapped_websocket_error_event_without_status_is_not_mapped() {
        let payload = json!({
            "type": "error",
            "error": {
                "type": "usage_limit_reached",
                "message": "The usage limit has been reached"
            },
            "headers": {
                "x-codex-primary-used-percent": "100.0",
                "x-codex-primary-window-minutes": 15
            }
        })
        .to_string();

        let wrapped_error = parse_wrapped_websocket_error_event(&payload)
            .expect("expected websocket error payload to be parsed");
        let api_error = map_wrapped_websocket_error_event(wrapped_error, payload);
        assert!(api_error.is_none());
    }

    #[test]
    fn merge_request_headers_matches_http_precedence() {
        let mut provider_headers = HeaderMap::new();
        provider_headers.insert(
            "originator",
            HeaderValue::from_static("provider-originator"),
        );
        provider_headers.insert("x-priority", HeaderValue::from_static("provider"));

        let mut extra_headers = HeaderMap::new();
        extra_headers.insert("x-priority", HeaderValue::from_static("extra"));

        let mut default_headers = HeaderMap::new();
        default_headers.insert("originator", HeaderValue::from_static("default-originator"));
        default_headers.insert("x-priority", HeaderValue::from_static("default"));
        default_headers.insert("x-default-only", HeaderValue::from_static("default-only"));

        let merged = merge_request_headers(&provider_headers, extra_headers, default_headers);

        assert_eq!(
            merged.get("originator"),
            Some(&HeaderValue::from_static("provider-originator"))
        );
        assert_eq!(
            merged.get("x-priority"),
            Some(&HeaderValue::from_static("extra"))
        );
        assert_eq!(
            merged.get("x-default-only"),
            Some(&HeaderValue::from_static("default-only"))
        );
    }
}