From 0029bf63bee53eacad5191c4d72f0b3d28b7027d Mon Sep 17 00:00:00 2001 From: xli-oai Date: Wed, 6 May 2026 03:25:34 -0700 Subject: [PATCH] app-server: allow shared config reads --- .../src/protocol/common.rs | 30 +- codex-rs/app-server/src/message_processor.rs | 4 +- .../app-server/src/request_serialization.rs | 344 ++++++++++++++++-- 3 files changed, 349 insertions(+), 29 deletions(-) diff --git a/codex-rs/app-server-protocol/src/protocol/common.rs b/codex-rs/app-server-protocol/src/protocol/common.rs index 8fd267d014..f7faae2f4a 100644 --- a/codex-rs/app-server-protocol/src/protocol/common.rs +++ b/codex-rs/app-server-protocol/src/protocol/common.rs @@ -77,6 +77,7 @@ macro_rules! experimental_type_entry { #[derive(Debug, Clone, PartialEq, Eq)] pub enum ClientRequestSerializationScope { Global(&'static str), + GlobalSharedRead(&'static str), Thread { thread_id: String }, ThreadPath { path: PathBuf }, CommandExecProcess { process_id: String }, @@ -93,6 +94,9 @@ macro_rules! serialization_scope_expr { ($actual_params:ident, global($key:literal)) => { Some(ClientRequestSerializationScope::Global($key)) }; + ($actual_params:ident, global_shared_read($key:literal)) => { + Some(ClientRequestSerializationScope::GlobalSharedRead($key)) + }; ($actual_params:ident, thread_id($params:ident . $field:ident)) => { Some(ClientRequestSerializationScope::Thread { thread_id: $actual_params.$field.clone(), @@ -585,7 +589,7 @@ client_request_definitions! { }, SkillsList => "skills/list" { params: v2::SkillsListParams, - serialization: global("config"), + serialization: global_shared_read("config"), response: v2::SkillsListResponse, }, HooksList => "hooks/list" { @@ -610,7 +614,7 @@ client_request_definitions! { }, PluginList => "plugin/list" { params: v2::PluginListParams, - serialization: global("config"), + serialization: global_shared_read("config"), response: v2::PluginListResponse, }, PluginRead => "plugin/read" { @@ -1655,6 +1659,28 @@ mod tests { Some(ClientRequestSerializationScope::Global("config")) ); + let skills_list = ClientRequest::SkillsList { + request_id: request_id(), + params: v2::SkillsListParams { + cwds: Vec::new(), + force_reload: false, + per_cwd_extra_user_roots: None, + }, + }; + assert_eq!( + skills_list.serialization_scope(), + Some(ClientRequestSerializationScope::GlobalSharedRead("config")) + ); + + let plugin_list = ClientRequest::PluginList { + request_id: request_id(), + params: v2::PluginListParams { cwds: None }, + }; + assert_eq!( + plugin_list.serialization_scope(), + Some(ClientRequestSerializationScope::GlobalSharedRead("config")) + ); + let plugin_uninstall = ClientRequest::PluginUninstall { request_id: request_id(), params: v2::PluginUninstallParams { diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index e0cc3bd176..f9036af000 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -798,9 +798,9 @@ impl MessageProcessor { ); if let Some(scope) = serialization_scope { - let key = RequestSerializationQueueKey::from_scope(connection_id, scope); + let (key, access) = RequestSerializationQueueKey::from_scope(connection_id, scope); self.request_serialization_queues - .enqueue(key, request) + .enqueue(key, access, request) .await; } else { tokio::spawn(async move { diff --git a/codex-rs/app-server/src/request_serialization.rs b/codex-rs/app-server/src/request_serialization.rs index 0eb509e098..0dd167b74d 100644 --- a/codex-rs/app-server/src/request_serialization.rs +++ b/codex-rs/app-server/src/request_serialization.rs @@ -6,6 +6,7 @@ use std::pin::Pin; use std::sync::Arc; use codex_app_server_protocol::ClientRequestSerializationScope; +use futures::future::join_all; use tokio::sync::Mutex; use tracing::Instrument; @@ -43,35 +44,61 @@ pub(crate) enum RequestSerializationQueueKey { }, } +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum RequestSerializationAccess { + Exclusive, + SharedRead, +} + impl RequestSerializationQueueKey { pub(crate) fn from_scope( connection_id: ConnectionId, scope: ClientRequestSerializationScope, - ) -> Self { + ) -> (Self, RequestSerializationAccess) { match scope { - ClientRequestSerializationScope::Global(name) => Self::Global(name), - ClientRequestSerializationScope::Thread { thread_id } => Self::Thread { thread_id }, - ClientRequestSerializationScope::ThreadPath { path } => Self::ThreadPath { path }, - ClientRequestSerializationScope::CommandExecProcess { process_id } => { + ClientRequestSerializationScope::Global(name) => { + (Self::Global(name), RequestSerializationAccess::Exclusive) + } + ClientRequestSerializationScope::GlobalSharedRead(name) => { + (Self::Global(name), RequestSerializationAccess::SharedRead) + } + ClientRequestSerializationScope::Thread { thread_id } => ( + Self::Thread { thread_id }, + RequestSerializationAccess::Exclusive, + ), + ClientRequestSerializationScope::ThreadPath { path } => ( + Self::ThreadPath { path }, + RequestSerializationAccess::Exclusive, + ), + ClientRequestSerializationScope::CommandExecProcess { process_id } => ( Self::CommandExecProcess { connection_id, process_id, - } - } - ClientRequestSerializationScope::Process { process_handle } => Self::Process { - connection_id, - process_handle, - }, - ClientRequestSerializationScope::FuzzyFileSearchSession { session_id } => { - Self::FuzzyFileSearchSession { session_id } - } - ClientRequestSerializationScope::FsWatch { watch_id } => Self::FsWatch { - connection_id, - watch_id, - }, - ClientRequestSerializationScope::McpOauth { server_name } => { - Self::McpOauth { server_name } - } + }, + RequestSerializationAccess::Exclusive, + ), + ClientRequestSerializationScope::Process { process_handle } => ( + Self::Process { + connection_id, + process_handle, + }, + RequestSerializationAccess::Exclusive, + ), + ClientRequestSerializationScope::FuzzyFileSearchSession { session_id } => ( + Self::FuzzyFileSearchSession { session_id }, + RequestSerializationAccess::Exclusive, + ), + ClientRequestSerializationScope::FsWatch { watch_id } => ( + Self::FsWatch { + connection_id, + watch_id, + }, + RequestSerializationAccess::Exclusive, + ), + ClientRequestSerializationScope::McpOauth { server_name } => ( + Self::McpOauth { server_name }, + RequestSerializationAccess::Exclusive, + ), } } } @@ -98,17 +125,24 @@ impl QueuedInitializedRequest { } } +struct QueuedSerializedRequest { + access: RequestSerializationAccess, + request: QueuedInitializedRequest, +} + #[derive(Clone, Default)] pub(crate) struct RequestSerializationQueues { - inner: Arc>>>, + inner: Arc>>>, } impl RequestSerializationQueues { pub(crate) async fn enqueue( &self, key: RequestSerializationQueueKey, + access: RequestSerializationAccess, request: QueuedInitializedRequest, ) { + let request = QueuedSerializedRequest { access, request }; let should_spawn = { let mut queues = self.inner.lock().await; match queues.get_mut(&key) { @@ -134,13 +168,27 @@ impl RequestSerializationQueues { async fn drain(self, key: RequestSerializationQueueKey) { loop { - let request = { + let requests = { let mut queues = self.inner.lock().await; let Some(queue) = queues.get_mut(&key) else { return; }; match queue.pop_front() { - Some(request) => request, + Some(request) => { + let access = request.access; + let mut requests = vec![request]; + if access == RequestSerializationAccess::SharedRead { + while queue.front().is_some_and(|request| { + request.access == RequestSerializationAccess::SharedRead + }) { + let Some(request) = queue.pop_front() else { + break; + }; + requests.push(request); + } + } + requests + } None => { queues.remove(&key); return; @@ -148,7 +196,7 @@ impl RequestSerializationQueues { } }; - request.run().await; + join_all(requests.into_iter().map(|request| request.request.run())).await; } } } @@ -158,6 +206,7 @@ mod tests { use super::*; use pretty_assertions::assert_eq; use std::sync::Arc; + use tokio::sync::broadcast; use tokio::sync::mpsc; use tokio::sync::oneshot; use tokio::time::Duration; @@ -195,6 +244,7 @@ mod tests { queues .enqueue( key.clone(), + RequestSerializationAccess::Exclusive, QueuedInitializedRequest::new(Arc::clone(&gate), async move { tx.send(value).expect("receiver should be open"); }), @@ -230,6 +280,7 @@ mod tests { queues .enqueue( RequestSerializationQueueKey::Global("blocked"), + RequestSerializationAccess::Exclusive, QueuedInitializedRequest::new(gate(), async move { let _ = blocked_rx.await; }), @@ -238,6 +289,7 @@ mod tests { queues .enqueue( RequestSerializationQueueKey::Global("other"), + RequestSerializationAccess::Exclusive, QueuedInitializedRequest::new(gate(), async move { ran_tx.send(()).expect("receiver should be open"); }), @@ -268,6 +320,7 @@ mod tests { queues .enqueue( key.clone(), + RequestSerializationAccess::Exclusive, QueuedInitializedRequest::new(Arc::clone(&live_gate), async move { tx.send(FIRST_REQUEST_VALUE) .expect("receiver should be open"); @@ -281,6 +334,7 @@ mod tests { queues .enqueue( key.clone(), + RequestSerializationAccess::Exclusive, QueuedInitializedRequest::new(closed_gate, async move { tx.send(SECOND_REQUEST_VALUE) .expect("receiver should be open"); @@ -293,6 +347,7 @@ mod tests { queues .enqueue( key, + RequestSerializationAccess::Exclusive, QueuedInitializedRequest::new(live_gate, async move { tx.send(THIRD_REQUEST_VALUE) .expect("receiver should be open"); @@ -336,6 +391,7 @@ mod tests { queues .enqueue( key.clone(), + RequestSerializationAccess::Exclusive, QueuedInitializedRequest::new(Arc::clone(&live_gate), async move { tx.send(FIRST_REQUEST_VALUE) .expect("receiver should be open"); @@ -349,6 +405,7 @@ mod tests { queues .enqueue( key, + RequestSerializationAccess::Exclusive, QueuedInitializedRequest::new(live_gate.clone(), async move { tx.send(SECOND_REQUEST_VALUE) .expect("receiver should be open"); @@ -385,4 +442,241 @@ mod tests { None ); } + + #[tokio::test] + async fn same_key_shared_reads_run_concurrently() { + let queues = RequestSerializationQueues::default(); + let key = RequestSerializationQueueKey::Global("test"); + let (blocker_started_tx, blocker_started_rx) = oneshot::channel::<()>(); + let (blocker_release_tx, blocker_release_rx) = oneshot::channel::<()>(); + let (started_tx, mut started_rx) = mpsc::unbounded_channel(); + let (release_tx, _) = broadcast::channel::<()>(/*capacity*/ 1); + + queues + .enqueue( + key.clone(), + RequestSerializationAccess::Exclusive, + QueuedInitializedRequest::new(gate(), async move { + blocker_started_tx + .send(()) + .expect("receiver should be open"); + let _ = blocker_release_rx.await; + }), + ) + .await; + timeout(queue_drain_timeout(), blocker_started_rx) + .await + .expect("blocker should start") + .expect("sender should be open"); + + for value in [FIRST_REQUEST_VALUE, SECOND_REQUEST_VALUE] { + let started_tx = started_tx.clone(); + let mut release_rx = release_tx.subscribe(); + queues + .enqueue( + key.clone(), + RequestSerializationAccess::SharedRead, + QueuedInitializedRequest::new(gate(), async move { + started_tx.send(value).expect("receiver should be open"); + let _ = release_rx.recv().await; + }), + ) + .await; + } + drop(started_tx); + blocker_release_tx + .send(()) + .expect("blocker should still be waiting"); + + let mut started = Vec::new(); + for _ in 0..2 { + started.push( + timeout(queue_drain_timeout(), started_rx.recv()) + .await + .expect("timed out waiting for shared read") + .expect("sender should be open"), + ); + } + assert_eq!(started, vec![FIRST_REQUEST_VALUE, SECOND_REQUEST_VALUE]); + + release_tx + .send(()) + .expect("shared reads should still be waiting"); + } + + #[tokio::test] + async fn exclusive_write_waits_for_running_shared_reads() { + let queues = RequestSerializationQueues::default(); + let key = RequestSerializationQueueKey::Global("test"); + let (blocker_started_tx, blocker_started_rx) = oneshot::channel::<()>(); + let (blocker_release_tx, blocker_release_rx) = oneshot::channel::<()>(); + let (read_started_tx, mut read_started_rx) = mpsc::unbounded_channel(); + let (read_release_tx, _) = broadcast::channel::<()>(/*capacity*/ 1); + let (write_started_tx, write_started_rx) = oneshot::channel::<()>(); + + queues + .enqueue( + key.clone(), + RequestSerializationAccess::Exclusive, + QueuedInitializedRequest::new(gate(), async move { + blocker_started_tx + .send(()) + .expect("receiver should be open"); + let _ = blocker_release_rx.await; + }), + ) + .await; + timeout(queue_drain_timeout(), blocker_started_rx) + .await + .expect("blocker should start") + .expect("sender should be open"); + + for value in [FIRST_REQUEST_VALUE, SECOND_REQUEST_VALUE] { + let read_started_tx = read_started_tx.clone(); + let mut read_release_rx = read_release_tx.subscribe(); + queues + .enqueue( + key.clone(), + RequestSerializationAccess::SharedRead, + QueuedInitializedRequest::new(gate(), async move { + read_started_tx + .send(value) + .expect("receiver should be open"); + let _ = read_release_rx.recv().await; + }), + ) + .await; + } + queues + .enqueue( + key.clone(), + RequestSerializationAccess::Exclusive, + QueuedInitializedRequest::new(gate(), async move { + write_started_tx.send(()).expect("receiver should be open"); + }), + ) + .await; + drop(read_started_tx); + blocker_release_tx + .send(()) + .expect("blocker should still be waiting"); + + for _ in 0..2 { + timeout(queue_drain_timeout(), read_started_rx.recv()) + .await + .expect("timed out waiting for shared read") + .expect("sender should be open"); + } + let mut write_started_rx = Box::pin(write_started_rx); + timeout(shutdown_wait_timeout(), &mut write_started_rx) + .await + .expect_err("write should wait for running shared reads"); + + read_release_tx + .send(()) + .expect("shared reads should still be waiting"); + timeout(queue_drain_timeout(), &mut write_started_rx) + .await + .expect("write should start after shared reads finish") + .expect("sender should be open"); + } + + #[tokio::test] + async fn later_shared_reads_do_not_jump_ahead_of_queued_write() { + let queues = RequestSerializationQueues::default(); + let key = RequestSerializationQueueKey::Global("test"); + let (blocker_started_tx, blocker_started_rx) = oneshot::channel::<()>(); + let (blocker_release_tx, blocker_release_rx) = oneshot::channel::<()>(); + let (first_read_started_tx, first_read_started_rx) = oneshot::channel::<()>(); + let (first_read_release_tx, first_read_release_rx) = oneshot::channel::<()>(); + let (write_started_tx, write_started_rx) = oneshot::channel::<()>(); + let (write_release_tx, write_release_rx) = oneshot::channel::<()>(); + let (later_read_started_tx, later_read_started_rx) = oneshot::channel::<()>(); + + queues + .enqueue( + key.clone(), + RequestSerializationAccess::Exclusive, + QueuedInitializedRequest::new(gate(), async move { + blocker_started_tx + .send(()) + .expect("receiver should be open"); + let _ = blocker_release_rx.await; + }), + ) + .await; + timeout(queue_drain_timeout(), blocker_started_rx) + .await + .expect("blocker should start") + .expect("sender should be open"); + + queues + .enqueue( + key.clone(), + RequestSerializationAccess::SharedRead, + QueuedInitializedRequest::new(gate(), async move { + first_read_started_tx + .send(()) + .expect("receiver should be open"); + let _ = first_read_release_rx.await; + }), + ) + .await; + queues + .enqueue( + key.clone(), + RequestSerializationAccess::Exclusive, + QueuedInitializedRequest::new(gate(), async move { + write_started_tx.send(()).expect("receiver should be open"); + let _ = write_release_rx.await; + }), + ) + .await; + queues + .enqueue( + key.clone(), + RequestSerializationAccess::SharedRead, + QueuedInitializedRequest::new(gate(), async move { + later_read_started_tx + .send(()) + .expect("receiver should be open"); + }), + ) + .await; + blocker_release_tx + .send(()) + .expect("blocker should still be waiting"); + + timeout(queue_drain_timeout(), first_read_started_rx) + .await + .expect("first read should start") + .expect("sender should be open"); + let mut write_started_rx = Box::pin(write_started_rx); + timeout(shutdown_wait_timeout(), &mut write_started_rx) + .await + .expect_err("write should wait for the first read"); + let mut later_read_started_rx = Box::pin(later_read_started_rx); + timeout(shutdown_wait_timeout(), &mut later_read_started_rx) + .await + .expect_err("later read should wait behind the queued write"); + + first_read_release_tx + .send(()) + .expect("first read should still be waiting"); + timeout(queue_drain_timeout(), &mut write_started_rx) + .await + .expect("write should start after the first read") + .expect("sender should be open"); + timeout(shutdown_wait_timeout(), &mut later_read_started_rx) + .await + .expect_err("later read should still wait while the write is running"); + + write_release_tx + .send(()) + .expect("write should still be waiting"); + timeout(queue_drain_timeout(), &mut later_read_started_rx) + .await + .expect("later read should start after the write") + .expect("sender should be open"); + } }