use std::collections::HashMap; use std::collections::VecDeque; use std::future::Future; use std::path::PathBuf; use std::pin::Pin; use std::sync::Arc; use codex_app_server_protocol::ClientRequestSerializationScope; use tokio::sync::Mutex; use tracing::Instrument; use crate::connection_rpc_gate::ConnectionRpcGate; use crate::outgoing_message::ConnectionId; type BoxFutureUnit = Pin + Send + 'static>>; #[derive(Clone, Debug, Eq, Hash, PartialEq)] pub(crate) enum RequestSerializationQueueKey { Global(&'static str), Thread { thread_id: String, }, ThreadPath { path: PathBuf, }, CommandExecProcess { connection_id: ConnectionId, process_id: String, }, FuzzyFileSearchSession { session_id: String, }, FsWatch { connection_id: ConnectionId, watch_id: String, }, McpOauth { server_name: String, }, } impl RequestSerializationQueueKey { pub(crate) fn from_scope( connection_id: ConnectionId, scope: ClientRequestSerializationScope, ) -> Self { match scope { ClientRequestSerializationScope::Global(name) => Self::Global(name), ClientRequestSerializationScope::Thread { thread_id } => Self::Thread { thread_id }, ClientRequestSerializationScope::ThreadPath { path } => Self::ThreadPath { path }, ClientRequestSerializationScope::CommandExecProcess { process_id } => { Self::CommandExecProcess { connection_id, process_id, } } ClientRequestSerializationScope::FuzzyFileSearchSession { session_id } => { Self::FuzzyFileSearchSession { session_id } } ClientRequestSerializationScope::FsWatch { watch_id } => Self::FsWatch { connection_id, watch_id, }, ClientRequestSerializationScope::McpOauth { server_name } => { Self::McpOauth { server_name } } } } } pub(crate) struct QueuedInitializedRequest { gate: Arc, future: BoxFutureUnit, } impl QueuedInitializedRequest { pub(crate) fn new( gate: Arc, future: impl Future + Send + 'static, ) -> Self { Self { gate, future: Box::pin(future), } } pub(crate) async fn run(self) { let Self { gate, future } = self; gate.run(future).await; } } #[derive(Clone, Default)] pub(crate) struct RequestSerializationQueues { inner: Arc>>>, } impl RequestSerializationQueues { pub(crate) async fn enqueue( &self, key: RequestSerializationQueueKey, request: QueuedInitializedRequest, ) { let should_spawn = { let mut queues = self.inner.lock().await; match queues.get_mut(&key) { Some(queue) => { queue.push_back(request); false } None => { let mut queue = VecDeque::new(); queue.push_back(request); queues.insert(key.clone(), queue); true } } }; if should_spawn { let queues = self.clone(); let span = tracing::debug_span!("app_server.serialized_request_queue", ?key); tokio::spawn(async move { queues.drain(key).await }.instrument(span)); } } async fn drain(self, key: RequestSerializationQueueKey) { loop { let request = { let mut queues = self.inner.lock().await; let Some(queue) = queues.get_mut(&key) else { return; }; match queue.pop_front() { Some(request) => request, None => { queues.remove(&key); return; } } }; request.run().await; } } } #[cfg(test)] mod tests { use super::*; use pretty_assertions::assert_eq; use std::sync::Arc; use tokio::sync::mpsc; use tokio::sync::oneshot; use tokio::time::Duration; use tokio::time::timeout; const FIRST_REQUEST_VALUE: i32 = 1; const SECOND_REQUEST_VALUE: i32 = 2; const THIRD_REQUEST_VALUE: i32 = 3; fn gate() -> Arc { Arc::new(ConnectionRpcGate::new()) } fn queue_drain_timeout() -> Duration { Duration::from_secs(/*secs*/ 1) } fn shutdown_wait_timeout() -> Duration { Duration::from_millis(/*millis*/ 50) } #[tokio::test] async fn same_key_requests_run_fifo() { let queues = RequestSerializationQueues::default(); let key = RequestSerializationQueueKey::Global("test"); let gate = gate(); let (tx, mut rx) = mpsc::unbounded_channel(); for value in [ FIRST_REQUEST_VALUE, SECOND_REQUEST_VALUE, THIRD_REQUEST_VALUE, ] { let tx = tx.clone(); queues .enqueue( key.clone(), QueuedInitializedRequest::new(Arc::clone(&gate), async move { tx.send(value).expect("receiver should be open"); }), ) .await; } drop(tx); let mut values = Vec::new(); while let Some(value) = timeout(queue_drain_timeout(), rx.recv()) .await .expect("timed out waiting for queued request") { values.push(value); } assert_eq!( values, vec![ FIRST_REQUEST_VALUE, SECOND_REQUEST_VALUE, THIRD_REQUEST_VALUE ] ); } #[tokio::test] async fn different_keys_run_concurrently() { let queues = RequestSerializationQueues::default(); let (blocked_tx, blocked_rx) = oneshot::channel::<()>(); let (ran_tx, ran_rx) = oneshot::channel::<()>(); queues .enqueue( RequestSerializationQueueKey::Global("blocked"), QueuedInitializedRequest::new(gate(), async move { let _ = blocked_rx.await; }), ) .await; queues .enqueue( RequestSerializationQueueKey::Global("other"), QueuedInitializedRequest::new(gate(), async move { ran_tx.send(()).expect("receiver should be open"); }), ) .await; timeout(queue_drain_timeout(), ran_rx) .await .expect("other key should not be blocked") .expect("sender should be open"); blocked_tx .send(()) .expect("blocked request should be waiting"); } #[tokio::test] async fn closed_gate_request_is_skipped_and_following_requests_continue() { let queues = RequestSerializationQueues::default(); let key = RequestSerializationQueueKey::Global("test"); let live_gate = gate(); let closed_gate = gate(); closed_gate.shutdown().await; let (tx, mut rx) = mpsc::unbounded_channel(); let (blocked_tx, blocked_rx) = oneshot::channel::<()>(); { let tx = tx.clone(); queues .enqueue( key.clone(), QueuedInitializedRequest::new(Arc::clone(&live_gate), async move { tx.send(FIRST_REQUEST_VALUE) .expect("receiver should be open"); let _ = blocked_rx.await; }), ) .await; } { let tx = tx.clone(); queues .enqueue( key.clone(), QueuedInitializedRequest::new(closed_gate, async move { tx.send(SECOND_REQUEST_VALUE) .expect("receiver should be open"); }), ) .await; } { let tx = tx.clone(); queues .enqueue( key, QueuedInitializedRequest::new(live_gate, async move { tx.send(THIRD_REQUEST_VALUE) .expect("receiver should be open"); }), ) .await; } drop(tx); assert_eq!( timeout(queue_drain_timeout(), rx.recv()) .await .expect("timed out waiting for first request"), Some(FIRST_REQUEST_VALUE) ); blocked_tx .send(()) .expect("blocked request should be waiting"); let mut values = Vec::new(); while let Some(value) = timeout(queue_drain_timeout(), rx.recv()) .await .expect("timed out waiting for queue to drain") { values.push(value); } assert_eq!(values, vec![THIRD_REQUEST_VALUE]); } #[tokio::test] async fn shutdown_of_live_gate_skips_already_queued_requests() { let queues = RequestSerializationQueues::default(); let key = RequestSerializationQueueKey::Global("test"); let live_gate = gate(); let (tx, mut rx) = mpsc::unbounded_channel(); let (blocked_tx, blocked_rx) = oneshot::channel::<()>(); { let tx = tx.clone(); queues .enqueue( key.clone(), QueuedInitializedRequest::new(Arc::clone(&live_gate), async move { tx.send(FIRST_REQUEST_VALUE) .expect("receiver should be open"); let _ = blocked_rx.await; }), ) .await; } { let tx = tx.clone(); queues .enqueue( key, QueuedInitializedRequest::new(live_gate.clone(), async move { tx.send(SECOND_REQUEST_VALUE) .expect("receiver should be open"); }), ) .await; } drop(tx); assert_eq!( timeout(queue_drain_timeout(), rx.recv()) .await .expect("timed out waiting for first request"), Some(FIRST_REQUEST_VALUE) ); let gate_for_shutdown = Arc::clone(&live_gate); let shutdown_task = tokio::spawn(async move { gate_for_shutdown.shutdown().await; }); timeout(shutdown_wait_timeout(), shutdown_task) .await .expect_err("shutdown should wait for the running request"); blocked_tx .send(()) .expect("blocked request should still be waiting"); assert_eq!( timeout(queue_drain_timeout(), rx.recv()) .await .expect("timed out waiting for queue to drain"), None ); } }