app-server: allow shared config reads

This commit is contained in:
xli-oai
2026-05-06 03:25:34 -07:00
parent 06e5dfa4dd
commit 0029bf63be
3 changed files with 349 additions and 29 deletions

View File

@@ -77,6 +77,7 @@ macro_rules! experimental_type_entry {
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientRequestSerializationScope {
Global(&'static str),
GlobalSharedRead(&'static str),
Thread { thread_id: String },
ThreadPath { path: PathBuf },
CommandExecProcess { process_id: String },
@@ -93,6 +94,9 @@ macro_rules! serialization_scope_expr {
($actual_params:ident, global($key:literal)) => {
Some(ClientRequestSerializationScope::Global($key))
};
($actual_params:ident, global_shared_read($key:literal)) => {
Some(ClientRequestSerializationScope::GlobalSharedRead($key))
};
($actual_params:ident, thread_id($params:ident . $field:ident)) => {
Some(ClientRequestSerializationScope::Thread {
thread_id: $actual_params.$field.clone(),
@@ -585,7 +589,7 @@ client_request_definitions! {
},
SkillsList => "skills/list" {
params: v2::SkillsListParams,
serialization: global("config"),
serialization: global_shared_read("config"),
response: v2::SkillsListResponse,
},
HooksList => "hooks/list" {
@@ -610,7 +614,7 @@ client_request_definitions! {
},
PluginList => "plugin/list" {
params: v2::PluginListParams,
serialization: global("config"),
serialization: global_shared_read("config"),
response: v2::PluginListResponse,
},
PluginRead => "plugin/read" {
@@ -1655,6 +1659,28 @@ mod tests {
Some(ClientRequestSerializationScope::Global("config"))
);
let skills_list = ClientRequest::SkillsList {
request_id: request_id(),
params: v2::SkillsListParams {
cwds: Vec::new(),
force_reload: false,
per_cwd_extra_user_roots: None,
},
};
assert_eq!(
skills_list.serialization_scope(),
Some(ClientRequestSerializationScope::GlobalSharedRead("config"))
);
let plugin_list = ClientRequest::PluginList {
request_id: request_id(),
params: v2::PluginListParams { cwds: None },
};
assert_eq!(
plugin_list.serialization_scope(),
Some(ClientRequestSerializationScope::GlobalSharedRead("config"))
);
let plugin_uninstall = ClientRequest::PluginUninstall {
request_id: request_id(),
params: v2::PluginUninstallParams {

View File

@@ -798,9 +798,9 @@ impl MessageProcessor {
);
if let Some(scope) = serialization_scope {
let key = RequestSerializationQueueKey::from_scope(connection_id, scope);
let (key, access) = RequestSerializationQueueKey::from_scope(connection_id, scope);
self.request_serialization_queues
.enqueue(key, request)
.enqueue(key, access, request)
.await;
} else {
tokio::spawn(async move {

View File

@@ -6,6 +6,7 @@ use std::pin::Pin;
use std::sync::Arc;
use codex_app_server_protocol::ClientRequestSerializationScope;
use futures::future::join_all;
use tokio::sync::Mutex;
use tracing::Instrument;
@@ -43,35 +44,61 @@ pub(crate) enum RequestSerializationQueueKey {
},
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum RequestSerializationAccess {
Exclusive,
SharedRead,
}
impl RequestSerializationQueueKey {
pub(crate) fn from_scope(
connection_id: ConnectionId,
scope: ClientRequestSerializationScope,
) -> Self {
) -> (Self, RequestSerializationAccess) {
match scope {
ClientRequestSerializationScope::Global(name) => Self::Global(name),
ClientRequestSerializationScope::Thread { thread_id } => Self::Thread { thread_id },
ClientRequestSerializationScope::ThreadPath { path } => Self::ThreadPath { path },
ClientRequestSerializationScope::CommandExecProcess { process_id } => {
ClientRequestSerializationScope::Global(name) => {
(Self::Global(name), RequestSerializationAccess::Exclusive)
}
ClientRequestSerializationScope::GlobalSharedRead(name) => {
(Self::Global(name), RequestSerializationAccess::SharedRead)
}
ClientRequestSerializationScope::Thread { thread_id } => (
Self::Thread { thread_id },
RequestSerializationAccess::Exclusive,
),
ClientRequestSerializationScope::ThreadPath { path } => (
Self::ThreadPath { path },
RequestSerializationAccess::Exclusive,
),
ClientRequestSerializationScope::CommandExecProcess { process_id } => (
Self::CommandExecProcess {
connection_id,
process_id,
}
}
ClientRequestSerializationScope::Process { process_handle } => Self::Process {
connection_id,
process_handle,
},
ClientRequestSerializationScope::FuzzyFileSearchSession { session_id } => {
Self::FuzzyFileSearchSession { session_id }
}
ClientRequestSerializationScope::FsWatch { watch_id } => Self::FsWatch {
connection_id,
watch_id,
},
ClientRequestSerializationScope::McpOauth { server_name } => {
Self::McpOauth { server_name }
}
},
RequestSerializationAccess::Exclusive,
),
ClientRequestSerializationScope::Process { process_handle } => (
Self::Process {
connection_id,
process_handle,
},
RequestSerializationAccess::Exclusive,
),
ClientRequestSerializationScope::FuzzyFileSearchSession { session_id } => (
Self::FuzzyFileSearchSession { session_id },
RequestSerializationAccess::Exclusive,
),
ClientRequestSerializationScope::FsWatch { watch_id } => (
Self::FsWatch {
connection_id,
watch_id,
},
RequestSerializationAccess::Exclusive,
),
ClientRequestSerializationScope::McpOauth { server_name } => (
Self::McpOauth { server_name },
RequestSerializationAccess::Exclusive,
),
}
}
}
@@ -98,17 +125,24 @@ impl QueuedInitializedRequest {
}
}
struct QueuedSerializedRequest {
access: RequestSerializationAccess,
request: QueuedInitializedRequest,
}
#[derive(Clone, Default)]
pub(crate) struct RequestSerializationQueues {
inner: Arc<Mutex<HashMap<RequestSerializationQueueKey, VecDeque<QueuedInitializedRequest>>>>,
inner: Arc<Mutex<HashMap<RequestSerializationQueueKey, VecDeque<QueuedSerializedRequest>>>>,
}
impl RequestSerializationQueues {
pub(crate) async fn enqueue(
&self,
key: RequestSerializationQueueKey,
access: RequestSerializationAccess,
request: QueuedInitializedRequest,
) {
let request = QueuedSerializedRequest { access, request };
let should_spawn = {
let mut queues = self.inner.lock().await;
match queues.get_mut(&key) {
@@ -134,13 +168,27 @@ impl RequestSerializationQueues {
async fn drain(self, key: RequestSerializationQueueKey) {
loop {
let request = {
let requests = {
let mut queues = self.inner.lock().await;
let Some(queue) = queues.get_mut(&key) else {
return;
};
match queue.pop_front() {
Some(request) => request,
Some(request) => {
let access = request.access;
let mut requests = vec![request];
if access == RequestSerializationAccess::SharedRead {
while queue.front().is_some_and(|request| {
request.access == RequestSerializationAccess::SharedRead
}) {
let Some(request) = queue.pop_front() else {
break;
};
requests.push(request);
}
}
requests
}
None => {
queues.remove(&key);
return;
@@ -148,7 +196,7 @@ impl RequestSerializationQueues {
}
};
request.run().await;
join_all(requests.into_iter().map(|request| request.request.run())).await;
}
}
}
@@ -158,6 +206,7 @@ mod tests {
use super::*;
use pretty_assertions::assert_eq;
use std::sync::Arc;
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::time::Duration;
@@ -195,6 +244,7 @@ mod tests {
queues
.enqueue(
key.clone(),
RequestSerializationAccess::Exclusive,
QueuedInitializedRequest::new(Arc::clone(&gate), async move {
tx.send(value).expect("receiver should be open");
}),
@@ -230,6 +280,7 @@ mod tests {
queues
.enqueue(
RequestSerializationQueueKey::Global("blocked"),
RequestSerializationAccess::Exclusive,
QueuedInitializedRequest::new(gate(), async move {
let _ = blocked_rx.await;
}),
@@ -238,6 +289,7 @@ mod tests {
queues
.enqueue(
RequestSerializationQueueKey::Global("other"),
RequestSerializationAccess::Exclusive,
QueuedInitializedRequest::new(gate(), async move {
ran_tx.send(()).expect("receiver should be open");
}),
@@ -268,6 +320,7 @@ mod tests {
queues
.enqueue(
key.clone(),
RequestSerializationAccess::Exclusive,
QueuedInitializedRequest::new(Arc::clone(&live_gate), async move {
tx.send(FIRST_REQUEST_VALUE)
.expect("receiver should be open");
@@ -281,6 +334,7 @@ mod tests {
queues
.enqueue(
key.clone(),
RequestSerializationAccess::Exclusive,
QueuedInitializedRequest::new(closed_gate, async move {
tx.send(SECOND_REQUEST_VALUE)
.expect("receiver should be open");
@@ -293,6 +347,7 @@ mod tests {
queues
.enqueue(
key,
RequestSerializationAccess::Exclusive,
QueuedInitializedRequest::new(live_gate, async move {
tx.send(THIRD_REQUEST_VALUE)
.expect("receiver should be open");
@@ -336,6 +391,7 @@ mod tests {
queues
.enqueue(
key.clone(),
RequestSerializationAccess::Exclusive,
QueuedInitializedRequest::new(Arc::clone(&live_gate), async move {
tx.send(FIRST_REQUEST_VALUE)
.expect("receiver should be open");
@@ -349,6 +405,7 @@ mod tests {
queues
.enqueue(
key,
RequestSerializationAccess::Exclusive,
QueuedInitializedRequest::new(live_gate.clone(), async move {
tx.send(SECOND_REQUEST_VALUE)
.expect("receiver should be open");
@@ -385,4 +442,241 @@ mod tests {
None
);
}
#[tokio::test]
async fn same_key_shared_reads_run_concurrently() {
let queues = RequestSerializationQueues::default();
let key = RequestSerializationQueueKey::Global("test");
let (blocker_started_tx, blocker_started_rx) = oneshot::channel::<()>();
let (blocker_release_tx, blocker_release_rx) = oneshot::channel::<()>();
let (started_tx, mut started_rx) = mpsc::unbounded_channel();
let (release_tx, _) = broadcast::channel::<()>(/*capacity*/ 1);
queues
.enqueue(
key.clone(),
RequestSerializationAccess::Exclusive,
QueuedInitializedRequest::new(gate(), async move {
blocker_started_tx
.send(())
.expect("receiver should be open");
let _ = blocker_release_rx.await;
}),
)
.await;
timeout(queue_drain_timeout(), blocker_started_rx)
.await
.expect("blocker should start")
.expect("sender should be open");
for value in [FIRST_REQUEST_VALUE, SECOND_REQUEST_VALUE] {
let started_tx = started_tx.clone();
let mut release_rx = release_tx.subscribe();
queues
.enqueue(
key.clone(),
RequestSerializationAccess::SharedRead,
QueuedInitializedRequest::new(gate(), async move {
started_tx.send(value).expect("receiver should be open");
let _ = release_rx.recv().await;
}),
)
.await;
}
drop(started_tx);
blocker_release_tx
.send(())
.expect("blocker should still be waiting");
let mut started = Vec::new();
for _ in 0..2 {
started.push(
timeout(queue_drain_timeout(), started_rx.recv())
.await
.expect("timed out waiting for shared read")
.expect("sender should be open"),
);
}
assert_eq!(started, vec![FIRST_REQUEST_VALUE, SECOND_REQUEST_VALUE]);
release_tx
.send(())
.expect("shared reads should still be waiting");
}
#[tokio::test]
async fn exclusive_write_waits_for_running_shared_reads() {
let queues = RequestSerializationQueues::default();
let key = RequestSerializationQueueKey::Global("test");
let (blocker_started_tx, blocker_started_rx) = oneshot::channel::<()>();
let (blocker_release_tx, blocker_release_rx) = oneshot::channel::<()>();
let (read_started_tx, mut read_started_rx) = mpsc::unbounded_channel();
let (read_release_tx, _) = broadcast::channel::<()>(/*capacity*/ 1);
let (write_started_tx, write_started_rx) = oneshot::channel::<()>();
queues
.enqueue(
key.clone(),
RequestSerializationAccess::Exclusive,
QueuedInitializedRequest::new(gate(), async move {
blocker_started_tx
.send(())
.expect("receiver should be open");
let _ = blocker_release_rx.await;
}),
)
.await;
timeout(queue_drain_timeout(), blocker_started_rx)
.await
.expect("blocker should start")
.expect("sender should be open");
for value in [FIRST_REQUEST_VALUE, SECOND_REQUEST_VALUE] {
let read_started_tx = read_started_tx.clone();
let mut read_release_rx = read_release_tx.subscribe();
queues
.enqueue(
key.clone(),
RequestSerializationAccess::SharedRead,
QueuedInitializedRequest::new(gate(), async move {
read_started_tx
.send(value)
.expect("receiver should be open");
let _ = read_release_rx.recv().await;
}),
)
.await;
}
queues
.enqueue(
key.clone(),
RequestSerializationAccess::Exclusive,
QueuedInitializedRequest::new(gate(), async move {
write_started_tx.send(()).expect("receiver should be open");
}),
)
.await;
drop(read_started_tx);
blocker_release_tx
.send(())
.expect("blocker should still be waiting");
for _ in 0..2 {
timeout(queue_drain_timeout(), read_started_rx.recv())
.await
.expect("timed out waiting for shared read")
.expect("sender should be open");
}
let mut write_started_rx = Box::pin(write_started_rx);
timeout(shutdown_wait_timeout(), &mut write_started_rx)
.await
.expect_err("write should wait for running shared reads");
read_release_tx
.send(())
.expect("shared reads should still be waiting");
timeout(queue_drain_timeout(), &mut write_started_rx)
.await
.expect("write should start after shared reads finish")
.expect("sender should be open");
}
#[tokio::test]
async fn later_shared_reads_do_not_jump_ahead_of_queued_write() {
let queues = RequestSerializationQueues::default();
let key = RequestSerializationQueueKey::Global("test");
let (blocker_started_tx, blocker_started_rx) = oneshot::channel::<()>();
let (blocker_release_tx, blocker_release_rx) = oneshot::channel::<()>();
let (first_read_started_tx, first_read_started_rx) = oneshot::channel::<()>();
let (first_read_release_tx, first_read_release_rx) = oneshot::channel::<()>();
let (write_started_tx, write_started_rx) = oneshot::channel::<()>();
let (write_release_tx, write_release_rx) = oneshot::channel::<()>();
let (later_read_started_tx, later_read_started_rx) = oneshot::channel::<()>();
queues
.enqueue(
key.clone(),
RequestSerializationAccess::Exclusive,
QueuedInitializedRequest::new(gate(), async move {
blocker_started_tx
.send(())
.expect("receiver should be open");
let _ = blocker_release_rx.await;
}),
)
.await;
timeout(queue_drain_timeout(), blocker_started_rx)
.await
.expect("blocker should start")
.expect("sender should be open");
queues
.enqueue(
key.clone(),
RequestSerializationAccess::SharedRead,
QueuedInitializedRequest::new(gate(), async move {
first_read_started_tx
.send(())
.expect("receiver should be open");
let _ = first_read_release_rx.await;
}),
)
.await;
queues
.enqueue(
key.clone(),
RequestSerializationAccess::Exclusive,
QueuedInitializedRequest::new(gate(), async move {
write_started_tx.send(()).expect("receiver should be open");
let _ = write_release_rx.await;
}),
)
.await;
queues
.enqueue(
key.clone(),
RequestSerializationAccess::SharedRead,
QueuedInitializedRequest::new(gate(), async move {
later_read_started_tx
.send(())
.expect("receiver should be open");
}),
)
.await;
blocker_release_tx
.send(())
.expect("blocker should still be waiting");
timeout(queue_drain_timeout(), first_read_started_rx)
.await
.expect("first read should start")
.expect("sender should be open");
let mut write_started_rx = Box::pin(write_started_rx);
timeout(shutdown_wait_timeout(), &mut write_started_rx)
.await
.expect_err("write should wait for the first read");
let mut later_read_started_rx = Box::pin(later_read_started_rx);
timeout(shutdown_wait_timeout(), &mut later_read_started_rx)
.await
.expect_err("later read should wait behind the queued write");
first_read_release_tx
.send(())
.expect("first read should still be waiting");
timeout(queue_drain_timeout(), &mut write_started_rx)
.await
.expect("write should start after the first read")
.expect("sender should be open");
timeout(shutdown_wait_timeout(), &mut later_read_started_rx)
.await
.expect_err("later read should still wait while the write is running");
write_release_tx
.send(())
.expect("write should still be waiting");
timeout(queue_drain_timeout(), &mut later_read_started_rx)
.await
.expect("later read should start after the write")
.expect("sender should be open");
}
}