diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index cbffab7b33..07a1c2adf5 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -260,6 +260,7 @@ struct WebsocketSession { connection: Option, last_request: Option, last_response_rx: Option>, + last_response_from_untraced_warmup: bool, connection_reused: StdMutex, } @@ -941,6 +942,7 @@ impl ModelClientSession { self.websocket_session.connection = None; self.websocket_session.last_request = None; self.websocket_session.last_response_rx = None; + self.websocket_session.last_response_from_untraced_warmup = false; self.websocket_session .set_connection_reused(/*connection_reused*/ false); } @@ -1044,28 +1046,33 @@ impl ModelClientSession { &mut self, payload: ResponseCreateWsRequest, request: &ResponsesApiRequest, - ) -> ResponsesWsRequest { + ) -> (ResponsesWsRequest, bool) { let Some(last_response) = self.get_last_response() else { - return ResponsesWsRequest::ResponseCreate(payload); + return (ResponsesWsRequest::ResponseCreate(payload), false); }; + let previous_response_id_from_untraced_warmup = + self.websocket_session.last_response_from_untraced_warmup; let Some(incremental_items) = self.get_incremental_items( request, Some(&last_response), /*allow_empty_delta*/ true, ) else { - return ResponsesWsRequest::ResponseCreate(payload); + return (ResponsesWsRequest::ResponseCreate(payload), false); }; if last_response.response_id.is_empty() { trace!("incremental request failed, no previous response id"); - return ResponsesWsRequest::ResponseCreate(payload); + return (ResponsesWsRequest::ResponseCreate(payload), false); } - ResponsesWsRequest::ResponseCreate(ResponseCreateWsRequest { - previous_response_id: Some(last_response.response_id), - input: incremental_items, - ..payload - }) + ( + ResponsesWsRequest::ResponseCreate(ResponseCreateWsRequest { + previous_response_id: Some(last_response.response_id), + input: incremental_items, + ..payload + }), + previous_response_id_from_untraced_warmup, + ) } /// Opportunistically preconnects a websocket for this turn-scoped client session. @@ -1144,6 +1151,7 @@ impl ModelClientSession { if needs_new { self.websocket_session.last_request = None; self.websocket_session.last_response_rx = None; + self.websocket_session.last_response_from_untraced_warmup = false; let turn_state = options .turn_state .clone() @@ -1412,8 +1420,8 @@ impl ModelClientSession { Err(err) => return Err(map_api_error(err)), } - let mut ws_request = self.prepare_websocket_request(ws_payload, &request); - self.websocket_session.last_request = Some(request); + let (mut ws_request, previous_response_id_from_untraced_warmup) = + self.prepare_websocket_request(ws_payload, &request); let inference_trace_attempt = if warmup { // Prewarm sends `generate=false`; it is connection setup, not a // model inference attempt that should appear in rollout traces. @@ -1422,7 +1430,16 @@ impl ModelClientSession { inference_trace.start_attempt() }; stamp_ws_stream_request_start_ms(&mut ws_request); - inference_trace_attempt.record_started(&ws_request); + if previous_response_id_from_untraced_warmup { + // The transport can reuse an untraced warmup response id and omit the + // already-sent input, but rollout replay needs the logical model-visible + // request rather than the compressed websocket delta. + inference_trace_attempt.record_started(&request); + } else { + inference_trace_attempt.record_started(&ws_request); + } + self.websocket_session.last_request = Some(request); + self.websocket_session.last_response_from_untraced_warmup = warmup; let websocket_connection = self.websocket_session.connection.as_ref().ok_or_else(|| { map_api_error(ApiError::Stream( diff --git a/codex-rs/core/tests/suite/client_websockets.rs b/codex-rs/core/tests/suite/client_websockets.rs index feee9d4227..74eb053325 100755 --- a/codex-rs/core/tests/suite/client_websockets.rs +++ b/codex-rs/core/tests/suite/client_websockets.rs @@ -30,6 +30,11 @@ use codex_protocol::protocol::Op; use codex_protocol::protocol::SessionSource; use codex_protocol::protocol::W3cTraceContext; use codex_protocol::user_input::UserInput; +use codex_rollout_trace::ConversationPart; +use codex_rollout_trace::InferenceTraceContext; +use codex_rollout_trace::RawTraceEventPayload; +use codex_rollout_trace::TraceWriter; +use codex_rollout_trace::replay_bundle; use core_test_support::load_default_config_for_test; use core_test_support::responses::WebSocketConnectionConfig; use core_test_support::responses::WebSocketTestServer; @@ -535,6 +540,112 @@ async fn responses_websocket_request_prewarm_reuses_connection() { server.shutdown().await; } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn responses_websocket_request_prewarm_traces_logical_request() { + skip_if_no_network!(); + + let server = start_websocket_server(vec![vec![ + vec![ev_response_created("warm-1"), ev_completed("warm-1")], + vec![ev_response_created("resp-1"), ev_completed("resp-1")], + ]]) + .await; + + let harness = websocket_harness_with_options(&server, /*runtime_metrics_enabled*/ true).await; + let mut client_session = harness.client.new_session(); + let prompt = prompt_with_input(vec![message_item("hello")]); + + client_session + .prewarm_websocket( + &prompt, + &harness.model_info, + &harness.session_telemetry, + harness.effort, + harness.summary, + /*service_tier*/ None, + /*turn_metadata_header*/ None, + ) + .await + .expect("websocket prewarm failed"); + + let trace_dir = TempDir::new().expect("trace dir"); + let writer = Arc::new( + TraceWriter::create( + trace_dir.path(), + "trace-1".to_string(), + harness.session_id.to_string(), + harness.thread_id.to_string(), + ) + .expect("trace writer"), + ); + writer + .append(RawTraceEventPayload::ThreadStarted { + thread_id: harness.thread_id.to_string(), + agent_path: "/root".to_string(), + metadata_payload: None, + }) + .expect("thread started"); + writer + .append(RawTraceEventPayload::CodexTurnStarted { + codex_turn_id: "turn-1".to_string(), + thread_id: harness.thread_id.to_string(), + }) + .expect("turn started"); + + let inference_trace = InferenceTraceContext::enabled( + writer, + harness.thread_id.to_string(), + "turn-1".to_string(), + harness.model_info.slug.clone(), + "test-provider".to_string(), + ); + + let mut stream = client_session + .stream( + &prompt, + &harness.model_info, + &harness.session_telemetry, + harness.effort, + harness.summary, + /*service_tier*/ None, + /*turn_metadata_header*/ None, + &inference_trace, + ) + .await + .expect("websocket stream failed"); + + while let Some(event) = stream.next().await { + if matches!(event, Ok(ResponseEvent::Completed { .. })) { + break; + } + } + + let connection = server.single_connection(); + let follow_up = connection + .get(1) + .expect("missing follow-up request") + .body_json(); + assert_eq!(follow_up["previous_response_id"].as_str(), Some("warm-1")); + assert_eq!(follow_up["input"], serde_json::json!([])); + + let rollout = replay_bundle(trace_dir.path()).expect("replay trace"); + let inference = rollout + .inference_calls + .values() + .next() + .expect("inference should be present"); + assert_eq!(inference.request_item_ids.len(), 1); + assert_eq!( + rollout.conversation_items[&inference.request_item_ids[0]] + .body + .parts, + vec![ConversationPart::Text { + text: "hello".to_string(), + }], + ); + + server.shutdown().await; +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn responses_websocket_reuses_connection_after_session_drop() { skip_if_no_network!(); diff --git a/codex-rs/rollout-trace/src/inference.rs b/codex-rs/rollout-trace/src/inference.rs index 20366c2c70..a4977b0dc5 100644 --- a/codex-rs/rollout-trace/src/inference.rs +++ b/codex-rs/rollout-trace/src/inference.rs @@ -166,7 +166,11 @@ impl InferenceTraceAttempt { headers.insert(INFERENCE_CALL_ID_HEADER, inference_call_id); } - /// Records the exact request object about to be sent to the model provider. + /// Records the request payload replay should treat as the model-visible inference input. + /// + /// This is usually the exact provider request. Callers may instead pass a + /// logical request when the transport omits already-sent input, such as + /// websocket reuse after an untraced warmup response. pub fn record_started(&self, request: &impl Serialize) { let InferenceTraceAttemptState::Enabled(attempt) = &self.state else { return;