Compare commits

...

1 Commits

Author SHA1 Message Date
Rasmus Rygaard
74d0570cd7 Surface error on WS close, only retry retryable errors 2026-02-27 20:07:19 -08:00
8 changed files with 201 additions and 5 deletions

View File

@@ -33,6 +33,7 @@ use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite::Error as WsError;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::protocol::CloseFrame;
use tracing::debug;
use tracing::error;
use tracing::info;
@@ -40,6 +41,7 @@ use tracing::trace;
use tungstenite::extensions::ExtensionsConfig;
use tungstenite::extensions::compression::deflate::DeflateConfig;
use tungstenite::protocol::WebSocketConfig;
use tungstenite::protocol::frame::coding::CloseCode;
use url::Url;
struct WsStream {
@@ -419,6 +421,41 @@ fn map_ws_error(err: WsError, url: &Url) -> ApiError {
}
}
fn map_websocket_close(close_frame: Option<&CloseFrame>) -> ApiError {
let message = format_websocket_close_message(close_frame);
match close_frame {
Some(frame) if should_retry_websocket_close_code(frame.code) => ApiError::Retryable {
message,
delay: None,
},
Some(_) => ApiError::NonRetryableStream(message),
None => ApiError::Stream(message),
}
}
fn format_websocket_close_message(close_frame: Option<&CloseFrame>) -> String {
let Some(frame) = close_frame else {
return "websocket closed by server before response.completed".to_string();
};
let code = u16::from(frame.code);
if frame.reason.is_empty() {
format!("websocket closed by server before response.completed (code {code})")
} else {
format!(
"websocket closed by server before response.completed (code {code}: {})",
frame.reason
)
}
}
fn should_retry_websocket_close_code(code: CloseCode) -> bool {
matches!(
code,
CloseCode::Away | CloseCode::Error | CloseCode::Restart | CloseCode::Again
)
}
#[derive(Debug, Deserialize)]
struct WrappedWebsocketError {
code: Option<String>,
@@ -607,10 +644,8 @@ async fn run_websocket_response_stream(
Message::Binary(_) => {
return Err(ApiError::Stream("unexpected binary websocket event".into()));
}
Message::Close(_) => {
return Err(ApiError::Stream(
"websocket closed by server before response.completed".into(),
));
Message::Close(close_frame) => {
return Err(map_websocket_close(close_frame.as_ref()));
}
Message::Frame(_) => {}
Message::Ping(_) | Message::Pong(_) => {}
@@ -768,6 +803,41 @@ mod tests {
assert!(api_error.is_none());
}
#[test]
fn websocket_close_bad_code_is_non_retryable_and_surfaces_reason() {
let close_frame = CloseFrame {
code: CloseCode::Bad(108),
reason: "server-side validation failed".into(),
};
let api_error = map_websocket_close(Some(&close_frame));
let ApiError::NonRetryableStream(message) = api_error else {
panic!("expected ApiError::NonRetryableStream");
};
assert_eq!(
message,
"websocket closed by server before response.completed (code 108: server-side validation failed)"
);
}
#[test]
fn websocket_close_again_is_retryable_and_surfaces_reason() {
let close_frame = CloseFrame {
code: CloseCode::Again,
reason: "retry after rebalance".into(),
};
let api_error = map_websocket_close(Some(&close_frame));
let ApiError::Retryable { message, delay } = api_error else {
panic!("expected ApiError::Retryable");
};
assert_eq!(
message,
"websocket closed by server before response.completed (code 1013: retry after rebalance)"
);
assert_eq!(delay, None);
}
#[test]
fn merge_request_headers_matches_http_precedence() {
let mut provider_headers = HeaderMap::new();

View File

@@ -12,6 +12,8 @@ pub enum ApiError {
Api { status: StatusCode, message: String },
#[error("stream error: {0}")]
Stream(String),
#[error("stream error: {0}")]
NonRetryableStream(String),
#[error("context window exceeded")]
ContextWindowExceeded,
#[error("quota exceeded")]

View File

@@ -23,6 +23,7 @@ pub(crate) fn map_api_error(err: ApiError) -> CodexErr {
ApiError::UsageNotIncluded => CodexErr::UsageNotIncluded,
ApiError::Retryable { message, delay } => CodexErr::Stream(message, delay),
ApiError::Stream(msg) => CodexErr::Stream(msg, None),
ApiError::NonRetryableStream(msg) => CodexErr::NonRetryableStream(msg),
ApiError::ServerOverloaded => CodexErr::ServerOverloaded,
ApiError::Api { status, message } => CodexErr::UnexpectedStatus(UnexpectedResponseError {
status,

View File

@@ -74,6 +74,11 @@ pub enum CodexErr {
#[error("stream disconnected before completion: {0}")]
Stream(String, Option<Duration>),
/// Returned when the stream terminated in a way that should be surfaced immediately instead
/// of retried automatically.
#[error("stream disconnected before completion: {0}")]
NonRetryableStream(String),
#[error(
"Codex ran out of room in the model's context window. Start a new thread or clear earlier history before retrying."
)]
@@ -208,6 +213,7 @@ impl CodexErr {
| CodexErr::LandlockSandboxExecutableNotProvided
| CodexErr::RetryLimit(_)
| CodexErr::ContextWindowExceeded
| CodexErr::NonRetryableStream(_)
| CodexErr::ThreadNotFound(_)
| CodexErr::AgentLimitReached { .. }
| CodexErr::Spawn

View File

@@ -19,6 +19,7 @@ use tokio_tungstenite::tungstenite::extensions::ExtensionsConfig;
use tokio_tungstenite::tungstenite::extensions::compression::deflate::DeflateConfig;
use tokio_tungstenite::tungstenite::handshake::server::Request;
use tokio_tungstenite::tungstenite::handshake::server::Response;
use tokio_tungstenite::tungstenite::protocol::CloseFrame;
use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
use wiremock::BodyPrintLimit;
use wiremock::Match;
@@ -294,6 +295,12 @@ impl WebSocketHandshake {
}
}
#[derive(Debug, Clone)]
pub struct WebSocketCloseFrame {
pub code: u16,
pub reason: String,
}
#[derive(Debug, Clone)]
pub struct WebSocketConnectionConfig {
pub requests: Vec<Vec<Value>>,
@@ -303,6 +310,8 @@ pub struct WebSocketConnectionConfig {
/// Tests use this to force websocket setup into an in-flight state so first-turn warmup paths
/// can be exercised deterministically.
pub accept_delay: Option<Duration>,
/// Optional close frame sent after all configured request events are emitted.
pub close_frame: Option<WebSocketCloseFrame>,
}
pub struct WebSocketTestServer {
@@ -1035,6 +1044,7 @@ pub async fn start_websocket_server(connections: Vec<Vec<Vec<Value>>>) -> WebSoc
requests,
response_headers: Vec::new(),
accept_delay: None,
close_frame: None,
})
.collect();
start_websocket_server_with_headers(connections).await
@@ -1146,7 +1156,11 @@ pub async fn start_websocket_server_with_headers(
}
}
let _ = ws_stream.close(None).await;
let close_frame = connection.close_frame.map(|frame| CloseFrame {
code: frame.code.into(),
reason: frame.reason.into(),
});
let _ = ws_stream.close(close_frame).await;
if connections.lock().unwrap().is_empty() {
return;

View File

@@ -130,6 +130,7 @@ async fn websocket_first_turn_handles_handshake_delay_with_preconnect() -> Resul
response_headers: Vec::new(),
// Delay handshake so turn processing must tolerate websocket startup latency.
accept_delay: Some(Duration::from_millis(150)),
close_frame: None,
}])
.await;

View File

@@ -26,6 +26,7 @@ use codex_protocol::protocol::Op;
use codex_protocol::protocol::SessionSource;
use codex_protocol::user_input::UserInput;
use core_test_support::load_default_config_for_test;
use core_test_support::responses::WebSocketCloseFrame;
use core_test_support::responses::WebSocketConnectionConfig;
use core_test_support::responses::WebSocketTestServer;
use core_test_support::responses::ev_assistant_message;
@@ -658,6 +659,7 @@ async fn responses_websocket_emits_reasoning_included_event() {
requests: vec![vec![ev_response_created("resp-1"), ev_completed("resp-1")]],
response_headers: vec![("X-Reasoning-Included".to_string(), "true".to_string())],
accept_delay: None,
close_frame: None,
}])
.await;
@@ -729,6 +731,7 @@ async fn responses_websocket_emits_rate_limit_events() {
("X-Reasoning-Included".to_string(), "true".to_string()),
],
accept_delay: None,
close_frame: None,
}])
.await;
@@ -961,6 +964,102 @@ async fn responses_websocket_connection_limit_error_reconnects_and_completes() {
server.shutdown().await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_policy_close_with_reason_does_not_retry() {
skip_if_no_network!();
let server = start_websocket_server_with_headers(vec![WebSocketConnectionConfig {
requests: vec![vec![ev_response_created("resp-1")]],
response_headers: Vec::new(),
accept_delay: None,
close_frame: Some(WebSocketCloseFrame {
code: 1008,
reason: "policy violation".to_string(),
}),
}])
.await;
let harness = websocket_harness(&server).await;
let mut client_session = harness.client.new_session();
let prompt = prompt_with_input(vec![message_item("hello")]);
let mut stream = client_session
.stream(
&prompt,
&harness.model_info,
&harness.otel_manager,
harness.effort,
harness.summary,
None,
)
.await
.expect("websocket stream should start");
let err = loop {
let event = stream
.next()
.await
.expect("expected websocket event before stream termination");
match event {
Ok(_) => {}
Err(err) => break err,
}
};
let message = err.to_string();
assert!(message.contains("code 1008"), "unexpected error: {message}");
assert!(
message.contains("policy violation"),
"unexpected error: {message}"
);
let total_websocket_requests: usize = server.connections().iter().map(Vec::len).sum();
assert_eq!(total_websocket_requests, 1);
server.shutdown().await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_again_close_with_reason_retries_and_completes() {
skip_if_no_network!();
let server = start_websocket_server_with_headers(vec![
WebSocketConnectionConfig {
requests: vec![vec![ev_response_created("resp-1")]],
response_headers: Vec::new(),
accept_delay: None,
close_frame: Some(WebSocketCloseFrame {
code: 1013,
reason: "retry after rebalance".to_string(),
}),
},
WebSocketConnectionConfig {
requests: vec![vec![ev_response_created("resp-2"), ev_completed("resp-2")]],
response_headers: Vec::new(),
accept_delay: None,
close_frame: None,
},
])
.await;
let mut builder = test_codex().with_config(|config| {
config.model_provider.request_max_retries = Some(0);
config.model_provider.stream_max_retries = Some(1);
});
let test = builder
.build_with_websocket_server(&server)
.await
.expect("build websocket codex");
test.submit_turn("hello")
.await
.expect("submission should retry after websocket again close");
let total_websocket_requests: usize = server.connections().iter().map(Vec::len).sum();
assert_eq!(total_websocket_requests, 2);
server.shutdown().await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_appends_on_prefix() {
skip_if_no_network!();

View File

@@ -104,6 +104,7 @@ async fn websocket_turn_state_persists_within_turn_and_resets_after() -> Result<
]],
response_headers: vec![(TURN_STATE_HEADER.to_string(), "ts-1".to_string())],
accept_delay: None,
close_frame: None,
},
WebSocketConnectionConfig {
requests: vec![vec![
@@ -113,6 +114,7 @@ async fn websocket_turn_state_persists_within_turn_and_resets_after() -> Result<
]],
response_headers: Vec::new(),
accept_delay: None,
close_frame: None,
},
WebSocketConnectionConfig {
requests: vec![vec![
@@ -122,6 +124,7 @@ async fn websocket_turn_state_persists_within_turn_and_resets_after() -> Result<
]],
response_headers: Vec::new(),
accept_delay: None,
close_frame: None,
},
])
.await;