Compare commits

...

1 Commit

Author SHA1 Message Date
xli-oai
fe7044a2aa allow later shared reads to join running read window 2026-05-07 03:21:22 -07:00

View File

@@ -6,7 +6,6 @@ use std::pin::Pin;
use std::sync::Arc;
use codex_app_server_protocol::ClientRequestSerializationScope;
use futures::future::join_all;
use tokio::sync::Mutex;
use tracing::Instrument;
@@ -130,9 +129,71 @@ struct QueuedSerializedRequest {
request: QueuedInitializedRequest,
}
#[derive(Default)]
struct RequestSerializationQueueState {
pending: VecDeque<QueuedSerializedRequest>,
running_shared_reads: usize,
exclusive_running: bool,
}
impl RequestSerializationQueueState {
fn enqueue(&mut self, request: QueuedSerializedRequest) {
self.pending.push_back(request);
}
fn take_ready_requests(&mut self) -> Vec<QueuedSerializedRequest> {
if self.exclusive_running {
return Vec::new();
}
match self.pending.front().map(|request| request.access) {
Some(RequestSerializationAccess::Exclusive) if self.running_shared_reads == 0 => {
let Some(request) = self.pending.pop_front() else {
return Vec::new();
};
self.exclusive_running = true;
vec![request]
}
Some(RequestSerializationAccess::SharedRead) => {
let mut requests = Vec::new();
while self
.pending
.front()
.is_some_and(|request| request.access == RequestSerializationAccess::SharedRead)
{
let Some(request) = self.pending.pop_front() else {
break;
};
self.running_shared_reads += 1;
requests.push(request);
}
requests
}
Some(RequestSerializationAccess::Exclusive) | None => Vec::new(),
}
}
fn complete(&mut self, access: RequestSerializationAccess) {
match access {
RequestSerializationAccess::Exclusive => {
debug_assert!(self.exclusive_running);
self.exclusive_running = false;
}
RequestSerializationAccess::SharedRead => {
debug_assert!(self.running_shared_reads > 0);
self.running_shared_reads -= 1;
}
}
}
fn is_idle(&self) -> bool {
self.pending.is_empty() && self.running_shared_reads == 0 && !self.exclusive_running
}
}
#[derive(Clone, Default)]
pub(crate) struct RequestSerializationQueues {
inner: Arc<Mutex<HashMap<RequestSerializationQueueKey, VecDeque<QueuedSerializedRequest>>>>,
inner: Arc<Mutex<HashMap<RequestSerializationQueueKey, RequestSerializationQueueState>>>,
}
impl RequestSerializationQueues {
@@ -143,61 +204,56 @@ impl RequestSerializationQueues {
request: QueuedInitializedRequest,
) {
let request = QueuedSerializedRequest { access, request };
let should_spawn = {
let ready_requests = {
let mut queues = self.inner.lock().await;
match queues.get_mut(&key) {
Some(queue) => {
queue.push_back(request);
false
}
None => {
let mut queue = VecDeque::new();
queue.push_back(request);
queues.insert(key.clone(), queue);
true
}
}
let queue = queues.entry(key.clone()).or_default();
queue.enqueue(request);
queue.take_ready_requests()
};
if should_spawn {
self.spawn_ready_requests(key, ready_requests);
}
fn spawn_ready_requests(
&self,
key: RequestSerializationQueueKey,
requests: Vec<QueuedSerializedRequest>,
) {
for request in requests {
let queues = self.clone();
let span = tracing::debug_span!("app_server.serialized_request_queue", ?key);
tokio::spawn(async move { queues.drain(key).await }.instrument(span));
let request_key = key.clone();
let span = tracing::debug_span!("app_server.serialized_request_queue", ?request_key);
tokio::spawn(
async move {
let access = request.access;
request.request.run().await;
queues.complete(request_key, access).await;
}
.instrument(span),
);
}
}
async fn drain(self, key: RequestSerializationQueueKey) {
loop {
let requests = {
let mut queues = self.inner.lock().await;
let Some(queue) = queues.get_mut(&key) else {
return;
};
match queue.pop_front() {
Some(request) => {
let access = request.access;
let mut requests = vec![request];
if access == RequestSerializationAccess::SharedRead {
while queue.front().is_some_and(|request| {
request.access == RequestSerializationAccess::SharedRead
}) {
let Some(request) = queue.pop_front() else {
break;
};
requests.push(request);
}
}
requests
}
None => {
queues.remove(&key);
return;
}
}
async fn complete(
&self,
key: RequestSerializationQueueKey,
access: RequestSerializationAccess,
) {
let ready_requests = {
let mut queues = self.inner.lock().await;
let Some(queue) = queues.get_mut(&key) else {
return;
};
queue.complete(access);
let ready_requests = queue.take_ready_requests();
let should_remove = queue.is_idle();
if should_remove {
queues.remove(&key);
}
ready_requests
};
join_all(requests.into_iter().map(|request| request.request.run())).await;
}
self.spawn_ready_requests(key, ready_requests);
}
}
@@ -504,6 +560,52 @@ mod tests {
.expect("shared reads should still be waiting");
}
#[tokio::test]
async fn later_shared_reads_join_running_shared_reads_without_queued_write() {
let queues = RequestSerializationQueues::default();
let key = RequestSerializationQueueKey::Global("test");
let (first_read_started_tx, first_read_started_rx) = oneshot::channel::<()>();
let (first_read_release_tx, first_read_release_rx) = oneshot::channel::<()>();
let (later_read_started_tx, later_read_started_rx) = oneshot::channel::<()>();
queues
.enqueue(
key.clone(),
RequestSerializationAccess::SharedRead,
QueuedInitializedRequest::new(gate(), async move {
first_read_started_tx
.send(())
.expect("receiver should be open");
let _ = first_read_release_rx.await;
}),
)
.await;
timeout(queue_drain_timeout(), first_read_started_rx)
.await
.expect("first read should start")
.expect("sender should be open");
queues
.enqueue(
key,
RequestSerializationAccess::SharedRead,
QueuedInitializedRequest::new(gate(), async move {
later_read_started_tx
.send(())
.expect("receiver should be open");
}),
)
.await;
timeout(queue_drain_timeout(), later_read_started_rx)
.await
.expect("later read should join running reads")
.expect("sender should be open");
first_read_release_tx
.send(())
.expect("first read should still be waiting");
}
#[tokio::test]
async fn exclusive_write_waits_for_running_shared_reads() {
let queues = RequestSerializationQueues::default();