test(core): cover rollback websocket continuation behavior

This commit is contained in:
Ningyi Xie
2026-06-01 18:40:18 -07:00
parent 8b29608673
commit 59d4a3cb49
3 changed files with 30 additions and 64 deletions

View File

@@ -180,7 +180,7 @@ struct ModelClientState {
include_attestation: bool,
attestation_provider: Option<Arc<dyn AttestationProvider>>,
disable_websockets: AtomicBool,
cached_websocket_session: StdMutex<CachedWebsocketSession>,
cached_websocket_session: StdMutex<WebsocketSession>,
}
/// Resolved API client setup for a single request attempt.
@@ -237,7 +237,6 @@ pub struct ModelClient {
pub struct ModelClientSession {
client: ModelClient,
websocket_session: WebsocketSession,
websocket_cache_generation: u64,
/// Turn state for sticky routing.
///
/// This is an `OnceLock` that stores the turn state value received from the server
@@ -266,12 +265,6 @@ struct WebsocketSession {
connection_reused: StdMutex<bool>,
}
#[derive(Debug, Default)]
struct CachedWebsocketSession {
generation: u64,
session: WebsocketSession,
}
impl WebsocketSession {
fn set_connection_reused(&self, connection_reused: bool) {
*self
@@ -358,7 +351,7 @@ impl ModelClient {
include_attestation,
attestation_provider,
disable_websockets: AtomicBool::new(false),
cached_websocket_session: StdMutex::new(CachedWebsocketSession::default()),
cached_websocket_session: StdMutex::new(WebsocketSession::default()),
}),
prompt_cache_key_override: None,
}
@@ -383,11 +376,9 @@ impl ModelClient {
/// This constructor does not perform network I/O itself; the session opens a websocket lazily
/// when the first stream request is issued.
pub fn new_session(&self) -> ModelClientSession {
let (websocket_cache_generation, websocket_session) = self.take_cached_websocket_session();
ModelClientSession {
client: self.clone(),
websocket_session,
websocket_cache_generation,
websocket_session: self.take_cached_websocket_session(),
turn_state: Arc::new(OnceLock::new()),
}
}
@@ -400,12 +391,12 @@ impl ModelClient {
self.state
.window_generation
.store(window_generation, Ordering::Relaxed);
self.invalidate_cached_websocket_session();
self.store_cached_websocket_session(WebsocketSession::default());
}
pub(crate) fn advance_window_generation(&self) {
self.state.window_generation.fetch_add(1, Ordering::Relaxed);
self.invalidate_cached_websocket_session();
self.store_cached_websocket_session(WebsocketSession::default());
}
pub(crate) fn current_window_id(&self) -> String {
@@ -414,36 +405,21 @@ impl ModelClient {
format!("{thread_id}:{window_generation}")
}
pub(crate) fn invalidate_cached_websocket_session(&self) {
fn take_cached_websocket_session(&self) -> WebsocketSession {
let mut cached_websocket_session = self
.state
.cached_websocket_session
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
cached_websocket_session.generation = cached_websocket_session.generation.saturating_add(1);
cached_websocket_session.session = WebsocketSession::default();
std::mem::take(&mut *cached_websocket_session)
}
fn take_cached_websocket_session(&self) -> (u64, WebsocketSession) {
let mut cached_websocket_session = self
fn store_cached_websocket_session(&self, websocket_session: WebsocketSession) {
*self
.state
.cached_websocket_session
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let generation = cached_websocket_session.generation;
let session = std::mem::take(&mut cached_websocket_session.session);
(generation, session)
}
fn store_cached_websocket_session(&self, generation: u64, websocket_session: WebsocketSession) {
let mut cached_websocket_session = self
.state
.cached_websocket_session
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
if cached_websocket_session.generation == generation {
cached_websocket_session.session = websocket_session;
}
.unwrap_or_else(std::sync::PoisonError::into_inner) = websocket_session;
}
pub(crate) fn force_http_fallback(
@@ -463,7 +439,7 @@ impl ModelClient {
);
}
self.invalidate_cached_websocket_session();
self.store_cached_websocket_session(WebsocketSession::default());
activated
}
@@ -974,7 +950,7 @@ impl Drop for ModelClientSession {
fn drop(&mut self) {
let websocket_session = std::mem::take(&mut self.websocket_session);
self.client
.store_cached_websocket_session(self.websocket_cache_generation, websocket_session);
.store_cached_websocket_session(websocket_session);
}
}

View File

@@ -569,9 +569,6 @@ pub async fn thread_rollback(sess: &Arc<Session>, sub_id: String, num_turns: u32
sess.apply_rollout_reconstruction(turn_context.as_ref(), replay_items.as_slice())
.await;
sess.recompute_token_usage(turn_context.as_ref()).await;
sess.services
.model_client
.invalidate_cached_websocket_session();
sess.persist_rollout_items(&[RolloutItem::EventMsg(rollback_msg.clone())])
.await;

View File

@@ -259,33 +259,27 @@ async fn websocket_v2_test_codex_shell_chain() -> Result<()> {
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_v2_rollback_opens_new_connection_for_rewritten_history() -> Result<()> {
async fn websocket_v2_rollback_reuses_connection_without_previous_response_id() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = start_websocket_server(vec![
let server = start_websocket_server(vec![vec![
vec![ev_response_created("warm-1"), ev_completed("warm-1")],
vec![
vec![ev_response_created("warm-1"), ev_completed("warm-1")],
vec![
ev_response_created("resp-1"),
ev_assistant_message("msg-1", "kept"),
ev_completed("resp-1"),
],
vec![
ev_response_created("resp-2"),
ev_assistant_message("msg-2", "discarded"),
ev_completed("resp-2"),
],
vec![
ev_response_created("should-not-be-used"),
ev_completed("should-not-be-used"),
],
ev_response_created("resp-1"),
ev_assistant_message("msg-1", "kept"),
ev_completed("resp-1"),
],
vec![vec![
vec![
ev_response_created("resp-2"),
ev_assistant_message("msg-2", "discarded"),
ev_completed("resp-2"),
],
vec![
ev_response_created("resp-3"),
ev_assistant_message("msg-3", "after rollback"),
ev_completed("resp-3"),
]],
])
],
]])
.await;
let mut builder = test_codex().with_config(|config| {
@@ -307,13 +301,12 @@ async fn websocket_v2_rollback_opens_new_connection_for_rewritten_history() -> R
.await;
test.submit_turn("after rollback").await?;
assert_eq!(server.handshakes().len(), 2);
assert_eq!(server.handshakes().len(), 1);
let connections = server.connections();
assert_eq!(connections.len(), 2);
assert_eq!(connections[0].len(), 3);
assert_eq!(connections[1].len(), 1);
assert_eq!(connections.len(), 1);
assert_eq!(connections[0].len(), 4);
let after_rollback = connections[1][0].body_json();
let after_rollback = connections[0][3].body_json();
assert_eq!(after_rollback["type"].as_str(), Some("response.create"));
assert_eq!(after_rollback.get("previous_response_id"), None);