mirror of
https://github.com/openai/codex.git
synced 2026-06-02 11:22:01 +00:00
test(core): cover rollback websocket continuation behavior
This commit is contained in:
@@ -180,7 +180,7 @@ struct ModelClientState {
|
||||
include_attestation: bool,
|
||||
attestation_provider: Option<Arc<dyn AttestationProvider>>,
|
||||
disable_websockets: AtomicBool,
|
||||
cached_websocket_session: StdMutex<CachedWebsocketSession>,
|
||||
cached_websocket_session: StdMutex<WebsocketSession>,
|
||||
}
|
||||
|
||||
/// Resolved API client setup for a single request attempt.
|
||||
@@ -237,7 +237,6 @@ pub struct ModelClient {
|
||||
pub struct ModelClientSession {
|
||||
client: ModelClient,
|
||||
websocket_session: WebsocketSession,
|
||||
websocket_cache_generation: u64,
|
||||
/// Turn state for sticky routing.
|
||||
///
|
||||
/// This is an `OnceLock` that stores the turn state value received from the server
|
||||
@@ -266,12 +265,6 @@ struct WebsocketSession {
|
||||
connection_reused: StdMutex<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct CachedWebsocketSession {
|
||||
generation: u64,
|
||||
session: WebsocketSession,
|
||||
}
|
||||
|
||||
impl WebsocketSession {
|
||||
fn set_connection_reused(&self, connection_reused: bool) {
|
||||
*self
|
||||
@@ -358,7 +351,7 @@ impl ModelClient {
|
||||
include_attestation,
|
||||
attestation_provider,
|
||||
disable_websockets: AtomicBool::new(false),
|
||||
cached_websocket_session: StdMutex::new(CachedWebsocketSession::default()),
|
||||
cached_websocket_session: StdMutex::new(WebsocketSession::default()),
|
||||
}),
|
||||
prompt_cache_key_override: None,
|
||||
}
|
||||
@@ -383,11 +376,9 @@ impl ModelClient {
|
||||
/// This constructor does not perform network I/O itself; the session opens a websocket lazily
|
||||
/// when the first stream request is issued.
|
||||
pub fn new_session(&self) -> ModelClientSession {
|
||||
let (websocket_cache_generation, websocket_session) = self.take_cached_websocket_session();
|
||||
ModelClientSession {
|
||||
client: self.clone(),
|
||||
websocket_session,
|
||||
websocket_cache_generation,
|
||||
websocket_session: self.take_cached_websocket_session(),
|
||||
turn_state: Arc::new(OnceLock::new()),
|
||||
}
|
||||
}
|
||||
@@ -400,12 +391,12 @@ impl ModelClient {
|
||||
self.state
|
||||
.window_generation
|
||||
.store(window_generation, Ordering::Relaxed);
|
||||
self.invalidate_cached_websocket_session();
|
||||
self.store_cached_websocket_session(WebsocketSession::default());
|
||||
}
|
||||
|
||||
pub(crate) fn advance_window_generation(&self) {
|
||||
self.state.window_generation.fetch_add(1, Ordering::Relaxed);
|
||||
self.invalidate_cached_websocket_session();
|
||||
self.store_cached_websocket_session(WebsocketSession::default());
|
||||
}
|
||||
|
||||
pub(crate) fn current_window_id(&self) -> String {
|
||||
@@ -414,36 +405,21 @@ impl ModelClient {
|
||||
format!("{thread_id}:{window_generation}")
|
||||
}
|
||||
|
||||
pub(crate) fn invalidate_cached_websocket_session(&self) {
|
||||
fn take_cached_websocket_session(&self) -> WebsocketSession {
|
||||
let mut cached_websocket_session = self
|
||||
.state
|
||||
.cached_websocket_session
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
cached_websocket_session.generation = cached_websocket_session.generation.saturating_add(1);
|
||||
cached_websocket_session.session = WebsocketSession::default();
|
||||
std::mem::take(&mut *cached_websocket_session)
|
||||
}
|
||||
|
||||
fn take_cached_websocket_session(&self) -> (u64, WebsocketSession) {
|
||||
let mut cached_websocket_session = self
|
||||
fn store_cached_websocket_session(&self, websocket_session: WebsocketSession) {
|
||||
*self
|
||||
.state
|
||||
.cached_websocket_session
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let generation = cached_websocket_session.generation;
|
||||
let session = std::mem::take(&mut cached_websocket_session.session);
|
||||
(generation, session)
|
||||
}
|
||||
|
||||
fn store_cached_websocket_session(&self, generation: u64, websocket_session: WebsocketSession) {
|
||||
let mut cached_websocket_session = self
|
||||
.state
|
||||
.cached_websocket_session
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
if cached_websocket_session.generation == generation {
|
||||
cached_websocket_session.session = websocket_session;
|
||||
}
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner) = websocket_session;
|
||||
}
|
||||
|
||||
pub(crate) fn force_http_fallback(
|
||||
@@ -463,7 +439,7 @@ impl ModelClient {
|
||||
);
|
||||
}
|
||||
|
||||
self.invalidate_cached_websocket_session();
|
||||
self.store_cached_websocket_session(WebsocketSession::default());
|
||||
activated
|
||||
}
|
||||
|
||||
@@ -974,7 +950,7 @@ impl Drop for ModelClientSession {
|
||||
fn drop(&mut self) {
|
||||
let websocket_session = std::mem::take(&mut self.websocket_session);
|
||||
self.client
|
||||
.store_cached_websocket_session(self.websocket_cache_generation, websocket_session);
|
||||
.store_cached_websocket_session(websocket_session);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -569,9 +569,6 @@ pub async fn thread_rollback(sess: &Arc<Session>, sub_id: String, num_turns: u32
|
||||
sess.apply_rollout_reconstruction(turn_context.as_ref(), replay_items.as_slice())
|
||||
.await;
|
||||
sess.recompute_token_usage(turn_context.as_ref()).await;
|
||||
sess.services
|
||||
.model_client
|
||||
.invalidate_cached_websocket_session();
|
||||
|
||||
sess.persist_rollout_items(&[RolloutItem::EventMsg(rollback_msg.clone())])
|
||||
.await;
|
||||
|
||||
@@ -259,33 +259,27 @@ async fn websocket_v2_test_codex_shell_chain() -> Result<()> {
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn websocket_v2_rollback_opens_new_connection_for_rewritten_history() -> Result<()> {
|
||||
async fn websocket_v2_rollback_reuses_connection_without_previous_response_id() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_websocket_server(vec![
|
||||
let server = start_websocket_server(vec![vec![
|
||||
vec![ev_response_created("warm-1"), ev_completed("warm-1")],
|
||||
vec![
|
||||
vec![ev_response_created("warm-1"), ev_completed("warm-1")],
|
||||
vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_assistant_message("msg-1", "kept"),
|
||||
ev_completed("resp-1"),
|
||||
],
|
||||
vec![
|
||||
ev_response_created("resp-2"),
|
||||
ev_assistant_message("msg-2", "discarded"),
|
||||
ev_completed("resp-2"),
|
||||
],
|
||||
vec![
|
||||
ev_response_created("should-not-be-used"),
|
||||
ev_completed("should-not-be-used"),
|
||||
],
|
||||
ev_response_created("resp-1"),
|
||||
ev_assistant_message("msg-1", "kept"),
|
||||
ev_completed("resp-1"),
|
||||
],
|
||||
vec![vec![
|
||||
vec![
|
||||
ev_response_created("resp-2"),
|
||||
ev_assistant_message("msg-2", "discarded"),
|
||||
ev_completed("resp-2"),
|
||||
],
|
||||
vec![
|
||||
ev_response_created("resp-3"),
|
||||
ev_assistant_message("msg-3", "after rollback"),
|
||||
ev_completed("resp-3"),
|
||||
]],
|
||||
])
|
||||
],
|
||||
]])
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
@@ -307,13 +301,12 @@ async fn websocket_v2_rollback_opens_new_connection_for_rewritten_history() -> R
|
||||
.await;
|
||||
test.submit_turn("after rollback").await?;
|
||||
|
||||
assert_eq!(server.handshakes().len(), 2);
|
||||
assert_eq!(server.handshakes().len(), 1);
|
||||
let connections = server.connections();
|
||||
assert_eq!(connections.len(), 2);
|
||||
assert_eq!(connections[0].len(), 3);
|
||||
assert_eq!(connections[1].len(), 1);
|
||||
assert_eq!(connections.len(), 1);
|
||||
assert_eq!(connections[0].len(), 4);
|
||||
|
||||
let after_rollback = connections[1][0].body_json();
|
||||
let after_rollback = connections[0][3].body_json();
|
||||
assert_eq!(after_rollback["type"].as_str(), Some("response.create"));
|
||||
assert_eq!(after_rollback.get("previous_response_id"), None);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user