fix prewarm tier reset

This commit is contained in:
pash
2026-03-01 23:14:49 -08:00
parent 04a8e3d6d3
commit ca548ea800
3 changed files with 61 additions and 0 deletions

View File

@@ -502,6 +502,23 @@ impl ModelClient {
}
}
impl ModelClientSession {
/// Reuse the startup-prewarmed websocket connection when the first real turn changes service
/// tier, but invalidate request-level reuse state so the next request becomes a fresh
/// `response.create`.
///
/// This keeps the handshake win from prewarm without leaking stale warmup request state into
/// the first user turn.
pub fn reset_prewarm_for_service_tier(&mut self, service_tier: ServiceTier) {
if self.service_tier == service_tier {
return;
}
self.service_tier = service_tier;
self.websocket_session.last_request = None;
self.websocket_session.last_response_rx = None;
}
}
impl Drop for ModelClientSession {
fn drop(&mut self) {
let websocket_session = std::mem::take(&mut self.websocket_session);

View File

@@ -4894,6 +4894,7 @@ pub(crate) async fn run_turn(
.model_client
.new_session_with_service_tier(turn_context.config.service_tier)
});
client_session.reset_prewarm_for_service_tier(turn_context.config.service_tier);
loop {
// Note that pending_input would be something like a message the user

View File

@@ -196,6 +196,49 @@ async fn responses_websocket_request_prewarm_reuses_connection() {
server.shutdown().await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_request_prewarm_resets_request_state_when_service_tier_changes() {
skip_if_no_network!();
let server = start_websocket_server(vec![vec![
vec![ev_response_created("warm-1"), ev_done_with_id("warm-1")],
vec![ev_response_created("resp-1"), ev_completed("resp-1")],
]])
.await;
let harness = websocket_harness_with_options(&server, false, false, true, true).await;
let mut client_session = harness.client.new_session();
let prompt = prompt_with_input(vec![message_item("hello")]);
client_session
.prewarm_websocket(
&prompt,
&harness.model_info,
&harness.otel_manager,
harness.effort,
harness.summary,
None,
)
.await
.expect("websocket prewarm failed");
client_session.reset_prewarm_for_service_tier(ServiceTier::Fast);
stream_until_complete(&mut client_session, &harness, &prompt).await;
assert_eq!(server.handshakes().len(), 1);
let connection = server.single_connection();
assert_eq!(connection.len(), 2);
let follow_up = connection
.get(1)
.expect("missing follow-up request")
.body_json();
assert_eq!(follow_up["type"].as_str(), Some("response.create"));
assert_eq!(follow_up["service_tier"].as_str(), Some("priority"));
assert_eq!(follow_up.get("previous_response_id"), None);
server.shutdown().await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_reuses_connection_after_session_drop() {
skip_if_no_network!();