diff --git a/codex-rs/app-server-protocol/src/protocol/common.rs b/codex-rs/app-server-protocol/src/protocol/common.rs index 167435ac88..99f4d7d3e1 100644 --- a/codex-rs/app-server-protocol/src/protocol/common.rs +++ b/codex-rs/app-server-protocol/src/protocol/common.rs @@ -1,4 +1,5 @@ use std::path::Path; +use std::path::PathBuf; use crate::JSONRPCNotification; use crate::JSONRPCRequest; @@ -73,6 +74,76 @@ macro_rules! experimental_type_entry { }; } +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ClientRequestSerializationScope { + Global(&'static str), + Thread { thread_id: String }, + ThreadPath { path: PathBuf }, + CommandExecProcess { process_id: String }, + FuzzyFileSearchSession { session_id: String }, + FsWatch { watch_id: String }, + McpOauth { server_name: String }, +} + +macro_rules! serialization_scope_expr { + ($actual_params:ident, None) => { + None + }; + ($actual_params:ident, global($key:literal)) => { + Some(ClientRequestSerializationScope::Global($key)) + }; + ($actual_params:ident, thread_id($params:ident . $field:ident)) => { + Some(ClientRequestSerializationScope::Thread { + thread_id: $actual_params.$field.clone(), + }) + }; + ($actual_params:ident, optional_thread_id($params:ident . $field:ident)) => { + $actual_params + .$field + .clone() + .map(|thread_id| ClientRequestSerializationScope::Thread { thread_id }) + }; + ($actual_params:ident, thread_or_path($params:ident . $thread_field:ident, $params2:ident . $path_field:ident)) => { + if !$actual_params.$thread_field.is_empty() { + Some(ClientRequestSerializationScope::Thread { + thread_id: $actual_params.$thread_field.clone(), + }) + } else if let Some(path) = $actual_params.$path_field.clone() { + Some(ClientRequestSerializationScope::ThreadPath { path }) + } else { + Some(ClientRequestSerializationScope::Thread { + thread_id: $actual_params.$thread_field.clone(), + }) + } + }; + ($actual_params:ident, optional_command_process_id($params:ident . $field:ident)) => { + $actual_params + .$field + .clone() + .map(|process_id| ClientRequestSerializationScope::CommandExecProcess { process_id }) + }; + ($actual_params:ident, command_process_id($params:ident . $field:ident)) => { + Some(ClientRequestSerializationScope::CommandExecProcess { + process_id: $actual_params.$field.clone(), + }) + }; + ($actual_params:ident, fuzzy_session_id($params:ident . $field:ident)) => { + Some(ClientRequestSerializationScope::FuzzyFileSearchSession { + session_id: $actual_params.$field.clone(), + }) + }; + ($actual_params:ident, fs_watch_id($params:ident . $field:ident)) => { + Some(ClientRequestSerializationScope::FsWatch { + watch_id: $actual_params.$field.clone(), + }) + }; + ($actual_params:ident, mcp_oauth_server($params:ident . $field:ident)) => { + Some(ClientRequestSerializationScope::McpOauth { + server_name: $actual_params.$field.clone(), + }) + }; +} + /// Generates an `enum ClientRequest` where each variant is a request that the /// client can send to the server. Each variant has associated `params` and /// `response` types. Also generates a `export_client_responses()` function to @@ -85,6 +156,7 @@ macro_rules! client_request_definitions { $variant:ident $(=> $wire:literal)? { params: $(#[$params_meta:meta])* $params:ty, $(inspect_params: $inspect_params:tt,)? + serialization: $serialization:ident $( ( $($serialization_args:tt)* ) )?, response: $response:ty, } ),* $(,)? @@ -123,6 +195,19 @@ macro_rules! client_request_definitions { }) .unwrap_or_else(|| "".to_string()) } + + pub fn serialization_scope(&self) -> Option { + match self { + $( + Self::$variant { params, .. } => { + let _ = params; + serialization_scope_expr!( + params, $serialization $( ( $($serialization_args)* ) )? + ) + } + )* + } + } } /// Typed response from the server to the client. @@ -235,6 +320,7 @@ macro_rules! client_request_definitions { client_request_definitions! { Initialize { params: v1::InitializeParams, + serialization: None, response: v1::InitializeResponse, }, @@ -244,24 +330,29 @@ client_request_definitions! { ThreadStart => "thread/start" { params: v2::ThreadStartParams, inspect_params: true, + serialization: None, response: v2::ThreadStartResponse, }, ThreadResume => "thread/resume" { params: v2::ThreadResumeParams, inspect_params: true, + serialization: thread_or_path(params.thread_id, params.path), response: v2::ThreadResumeResponse, }, ThreadFork => "thread/fork" { params: v2::ThreadForkParams, inspect_params: true, + serialization: thread_or_path(params.thread_id, params.path), response: v2::ThreadForkResponse, }, ThreadArchive => "thread/archive" { params: v2::ThreadArchiveParams, + serialization: thread_id(params.thread_id), response: v2::ThreadArchiveResponse, }, ThreadUnsubscribe => "thread/unsubscribe" { params: v2::ThreadUnsubscribeParams, + serialization: thread_id(params.thread_id), response: v2::ThreadUnsubscribeResponse, }, #[experimental("thread/increment_elicitation")] @@ -271,6 +362,7 @@ client_request_definitions! { /// approval or other elicitation is pending outside the app-server request flow. ThreadIncrementElicitation => "thread/increment_elicitation" { params: v2::ThreadIncrementElicitationParams, + serialization: thread_id(params.thread_id), response: v2::ThreadIncrementElicitationResponse, }, #[experimental("thread/decrement_elicitation")] @@ -279,302 +371,372 @@ client_request_definitions! { /// When the count reaches zero, timeout accounting resumes for the thread. ThreadDecrementElicitation => "thread/decrement_elicitation" { params: v2::ThreadDecrementElicitationParams, + serialization: thread_id(params.thread_id), response: v2::ThreadDecrementElicitationResponse, }, ThreadSetName => "thread/name/set" { params: v2::ThreadSetNameParams, + serialization: thread_id(params.thread_id), response: v2::ThreadSetNameResponse, }, #[experimental("thread/goal/set")] ThreadGoalSet => "thread/goal/set" { params: v2::ThreadGoalSetParams, + serialization: thread_id(params.thread_id), response: v2::ThreadGoalSetResponse, }, #[experimental("thread/goal/get")] ThreadGoalGet => "thread/goal/get" { params: v2::ThreadGoalGetParams, + serialization: thread_id(params.thread_id), response: v2::ThreadGoalGetResponse, }, #[experimental("thread/goal/clear")] ThreadGoalClear => "thread/goal/clear" { params: v2::ThreadGoalClearParams, + serialization: thread_id(params.thread_id), response: v2::ThreadGoalClearResponse, }, ThreadMetadataUpdate => "thread/metadata/update" { params: v2::ThreadMetadataUpdateParams, + serialization: thread_id(params.thread_id), response: v2::ThreadMetadataUpdateResponse, }, #[experimental("thread/memoryMode/set")] ThreadMemoryModeSet => "thread/memoryMode/set" { params: v2::ThreadMemoryModeSetParams, + serialization: thread_id(params.thread_id), response: v2::ThreadMemoryModeSetResponse, }, #[experimental("memory/reset")] MemoryReset => "memory/reset" { params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>, + serialization: global("memory"), response: v2::MemoryResetResponse, }, ThreadUnarchive => "thread/unarchive" { params: v2::ThreadUnarchiveParams, + serialization: thread_id(params.thread_id), response: v2::ThreadUnarchiveResponse, }, ThreadCompactStart => "thread/compact/start" { params: v2::ThreadCompactStartParams, + serialization: thread_id(params.thread_id), response: v2::ThreadCompactStartResponse, }, ThreadShellCommand => "thread/shellCommand" { params: v2::ThreadShellCommandParams, + serialization: thread_id(params.thread_id), response: v2::ThreadShellCommandResponse, }, ThreadApproveGuardianDeniedAction => "thread/approveGuardianDeniedAction" { params: v2::ThreadApproveGuardianDeniedActionParams, + serialization: thread_id(params.thread_id), response: v2::ThreadApproveGuardianDeniedActionResponse, }, #[experimental("thread/backgroundTerminals/clean")] ThreadBackgroundTerminalsClean => "thread/backgroundTerminals/clean" { params: v2::ThreadBackgroundTerminalsCleanParams, + serialization: thread_id(params.thread_id), response: v2::ThreadBackgroundTerminalsCleanResponse, }, ThreadRollback => "thread/rollback" { params: v2::ThreadRollbackParams, + serialization: thread_id(params.thread_id), response: v2::ThreadRollbackResponse, }, ThreadList => "thread/list" { params: v2::ThreadListParams, + serialization: None, response: v2::ThreadListResponse, }, ThreadLoadedList => "thread/loaded/list" { params: v2::ThreadLoadedListParams, + serialization: None, response: v2::ThreadLoadedListResponse, }, ThreadRead => "thread/read" { params: v2::ThreadReadParams, + serialization: thread_id(params.thread_id), response: v2::ThreadReadResponse, }, ThreadTurnsList => "thread/turns/list" { params: v2::ThreadTurnsListParams, + // Explicitly concurrent: this primarily reads append-only rollout storage. + serialization: None, response: v2::ThreadTurnsListResponse, }, /// Append raw Responses API items to the thread history without starting a user turn. ThreadInjectItems => "thread/inject_items" { params: v2::ThreadInjectItemsParams, + serialization: thread_id(params.thread_id), response: v2::ThreadInjectItemsResponse, }, SkillsList => "skills/list" { params: v2::SkillsListParams, + serialization: global("config"), response: v2::SkillsListResponse, }, MarketplaceAdd => "marketplace/add" { params: v2::MarketplaceAddParams, + serialization: global("config"), response: v2::MarketplaceAddResponse, }, MarketplaceRemove => "marketplace/remove" { params: v2::MarketplaceRemoveParams, + serialization: global("config"), response: v2::MarketplaceRemoveResponse, }, MarketplaceUpgrade => "marketplace/upgrade" { params: v2::MarketplaceUpgradeParams, + serialization: global("config"), response: v2::MarketplaceUpgradeResponse, }, PluginList => "plugin/list" { params: v2::PluginListParams, + serialization: global("config"), response: v2::PluginListResponse, }, PluginRead => "plugin/read" { params: v2::PluginReadParams, + serialization: global("config"), response: v2::PluginReadResponse, }, AppsList => "app/list" { params: v2::AppsListParams, + serialization: None, response: v2::AppsListResponse, }, DeviceKeyCreate => "device/key/create" { params: v2::DeviceKeyCreateParams, + serialization: global("device-key"), response: v2::DeviceKeyCreateResponse, }, DeviceKeyPublic => "device/key/public" { params: v2::DeviceKeyPublicParams, + serialization: global("device-key"), response: v2::DeviceKeyPublicResponse, }, DeviceKeySign => "device/key/sign" { params: v2::DeviceKeySignParams, + serialization: global("device-key"), response: v2::DeviceKeySignResponse, }, + // File system requests are intentionally concurrent. Desktop already treats local + // file system operations as concurrent, and app-server remote fs mirrors that model. FsReadFile => "fs/readFile" { params: v2::FsReadFileParams, + serialization: None, response: v2::FsReadFileResponse, }, FsWriteFile => "fs/writeFile" { params: v2::FsWriteFileParams, + serialization: None, response: v2::FsWriteFileResponse, }, FsCreateDirectory => "fs/createDirectory" { params: v2::FsCreateDirectoryParams, + serialization: None, response: v2::FsCreateDirectoryResponse, }, FsGetMetadata => "fs/getMetadata" { params: v2::FsGetMetadataParams, + serialization: None, response: v2::FsGetMetadataResponse, }, FsReadDirectory => "fs/readDirectory" { params: v2::FsReadDirectoryParams, + serialization: None, response: v2::FsReadDirectoryResponse, }, FsRemove => "fs/remove" { params: v2::FsRemoveParams, + serialization: None, response: v2::FsRemoveResponse, }, FsCopy => "fs/copy" { params: v2::FsCopyParams, + serialization: None, response: v2::FsCopyResponse, }, FsWatch => "fs/watch" { params: v2::FsWatchParams, + serialization: fs_watch_id(params.watch_id), response: v2::FsWatchResponse, }, FsUnwatch => "fs/unwatch" { params: v2::FsUnwatchParams, + serialization: fs_watch_id(params.watch_id), response: v2::FsUnwatchResponse, }, SkillsConfigWrite => "skills/config/write" { params: v2::SkillsConfigWriteParams, + serialization: global("config"), response: v2::SkillsConfigWriteResponse, }, PluginInstall => "plugin/install" { params: v2::PluginInstallParams, + serialization: global("config"), response: v2::PluginInstallResponse, }, PluginUninstall => "plugin/uninstall" { params: v2::PluginUninstallParams, + serialization: global("config"), response: v2::PluginUninstallResponse, }, TurnStart => "turn/start" { params: v2::TurnStartParams, inspect_params: true, + serialization: thread_id(params.thread_id), response: v2::TurnStartResponse, }, TurnSteer => "turn/steer" { params: v2::TurnSteerParams, inspect_params: true, + serialization: thread_id(params.thread_id), response: v2::TurnSteerResponse, }, TurnInterrupt => "turn/interrupt" { params: v2::TurnInterruptParams, + serialization: thread_id(params.thread_id), response: v2::TurnInterruptResponse, }, #[experimental("thread/realtime/start")] ThreadRealtimeStart => "thread/realtime/start" { params: v2::ThreadRealtimeStartParams, + serialization: thread_id(params.thread_id), response: v2::ThreadRealtimeStartResponse, }, #[experimental("thread/realtime/appendAudio")] ThreadRealtimeAppendAudio => "thread/realtime/appendAudio" { params: v2::ThreadRealtimeAppendAudioParams, + serialization: thread_id(params.thread_id), response: v2::ThreadRealtimeAppendAudioResponse, }, #[experimental("thread/realtime/appendText")] ThreadRealtimeAppendText => "thread/realtime/appendText" { params: v2::ThreadRealtimeAppendTextParams, + serialization: thread_id(params.thread_id), response: v2::ThreadRealtimeAppendTextResponse, }, #[experimental("thread/realtime/stop")] ThreadRealtimeStop => "thread/realtime/stop" { params: v2::ThreadRealtimeStopParams, + serialization: thread_id(params.thread_id), response: v2::ThreadRealtimeStopResponse, }, #[experimental("thread/realtime/listVoices")] ThreadRealtimeListVoices => "thread/realtime/listVoices" { params: v2::ThreadRealtimeListVoicesParams, + serialization: None, response: v2::ThreadRealtimeListVoicesResponse, }, ReviewStart => "review/start" { params: v2::ReviewStartParams, + serialization: thread_id(params.thread_id), response: v2::ReviewStartResponse, }, ModelList => "model/list" { params: v2::ModelListParams, + serialization: None, response: v2::ModelListResponse, }, ExperimentalFeatureList => "experimentalFeature/list" { params: v2::ExperimentalFeatureListParams, + serialization: global("config"), response: v2::ExperimentalFeatureListResponse, }, ExperimentalFeatureEnablementSet => "experimentalFeature/enablement/set" { params: v2::ExperimentalFeatureEnablementSetParams, + serialization: global("config"), response: v2::ExperimentalFeatureEnablementSetResponse, }, #[experimental("collaborationMode/list")] /// Lists collaboration mode presets. CollaborationModeList => "collaborationMode/list" { params: v2::CollaborationModeListParams, + serialization: None, response: v2::CollaborationModeListResponse, }, #[experimental("mock/experimentalMethod")] /// Test-only method used to validate experimental gating. MockExperimentalMethod => "mock/experimentalMethod" { params: v2::MockExperimentalMethodParams, + serialization: None, response: v2::MockExperimentalMethodResponse, }, McpServerOauthLogin => "mcpServer/oauth/login" { params: v2::McpServerOauthLoginParams, + serialization: mcp_oauth_server(params.name), response: v2::McpServerOauthLoginResponse, }, McpServerRefresh => "config/mcpServer/reload" { params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>, + serialization: global("mcp-registry"), response: v2::McpServerRefreshResponse, }, McpServerStatusList => "mcpServerStatus/list" { params: v2::ListMcpServerStatusParams, + serialization: global("mcp-registry"), response: v2::ListMcpServerStatusResponse, }, McpResourceRead => "mcpServer/resource/read" { params: v2::McpResourceReadParams, + serialization: optional_thread_id(params.thread_id), response: v2::McpResourceReadResponse, }, McpServerToolCall => "mcpServer/tool/call" { params: v2::McpServerToolCallParams, + serialization: thread_id(params.thread_id), response: v2::McpServerToolCallResponse, }, WindowsSandboxSetupStart => "windowsSandbox/setupStart" { params: v2::WindowsSandboxSetupStartParams, + serialization: global("windows-sandbox-setup"), response: v2::WindowsSandboxSetupStartResponse, }, LoginAccount => "account/login/start" { params: v2::LoginAccountParams, inspect_params: true, + serialization: global("account-auth"), response: v2::LoginAccountResponse, }, CancelLoginAccount => "account/login/cancel" { params: v2::CancelLoginAccountParams, + serialization: global("account-auth"), response: v2::CancelLoginAccountResponse, }, LogoutAccount => "account/logout" { params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>, + serialization: global("account-auth"), response: v2::LogoutAccountResponse, }, GetAccountRateLimits => "account/rateLimits/read" { params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>, + serialization: None, response: v2::GetAccountRateLimitsResponse, }, SendAddCreditsNudgeEmail => "account/sendAddCreditsNudgeEmail" { params: v2::SendAddCreditsNudgeEmailParams, + serialization: global("account-auth"), response: v2::SendAddCreditsNudgeEmailResponse, }, FeedbackUpload => "feedback/upload" { params: v2::FeedbackUploadParams, + serialization: None, response: v2::FeedbackUploadResponse, }, @@ -582,86 +744,106 @@ client_request_definitions! { OneOffCommandExec => "command/exec" { params: v2::CommandExecParams, inspect_params: true, + serialization: optional_command_process_id(params.process_id), response: v2::CommandExecResponse, }, /// Write stdin bytes to a running `command/exec` session or close stdin. CommandExecWrite => "command/exec/write" { params: v2::CommandExecWriteParams, + serialization: command_process_id(params.process_id), response: v2::CommandExecWriteResponse, }, /// Terminate a running `command/exec` session by client-supplied `processId`. CommandExecTerminate => "command/exec/terminate" { params: v2::CommandExecTerminateParams, + serialization: command_process_id(params.process_id), response: v2::CommandExecTerminateResponse, }, /// Resize a running PTY-backed `command/exec` session by client-supplied `processId`. CommandExecResize => "command/exec/resize" { params: v2::CommandExecResizeParams, + serialization: command_process_id(params.process_id), response: v2::CommandExecResizeResponse, }, ConfigRead => "config/read" { params: v2::ConfigReadParams, + serialization: global("config"), response: v2::ConfigReadResponse, }, ExternalAgentConfigDetect => "externalAgentConfig/detect" { params: v2::ExternalAgentConfigDetectParams, + serialization: global("config"), response: v2::ExternalAgentConfigDetectResponse, }, ExternalAgentConfigImport => "externalAgentConfig/import" { params: v2::ExternalAgentConfigImportParams, + serialization: global("config"), response: v2::ExternalAgentConfigImportResponse, }, ConfigValueWrite => "config/value/write" { params: v2::ConfigValueWriteParams, + serialization: global("config"), response: v2::ConfigWriteResponse, }, ConfigBatchWrite => "config/batchWrite" { params: v2::ConfigBatchWriteParams, + serialization: global("config"), response: v2::ConfigWriteResponse, }, ConfigRequirementsRead => "configRequirements/read" { params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>, + serialization: global("config"), response: v2::ConfigRequirementsReadResponse, }, GetAccount => "account/read" { params: v2::GetAccountParams, + serialization: global("account-auth"), response: v2::GetAccountResponse, }, /// DEPRECATED APIs below GetConversationSummary { params: v1::GetConversationSummaryParams, + serialization: None, response: v1::GetConversationSummaryResponse, }, GitDiffToRemote { params: v1::GitDiffToRemoteParams, + serialization: None, response: v1::GitDiffToRemoteResponse, }, /// DEPRECATED in favor of GetAccount GetAuthStatus { params: v1::GetAuthStatusParams, + serialization: global("account-auth"), response: v1::GetAuthStatusResponse, }, + // Legacy fuzzy search cancellation is intentionally concurrent: clients reuse a + // cancellation token so a newer request can cancel an older in-flight search. FuzzyFileSearch { params: FuzzyFileSearchParams, + serialization: None, response: FuzzyFileSearchResponse, }, #[experimental("fuzzyFileSearch/sessionStart")] FuzzyFileSearchSessionStart => "fuzzyFileSearch/sessionStart" { params: FuzzyFileSearchSessionStartParams, + serialization: fuzzy_session_id(params.session_id), response: FuzzyFileSearchSessionStartResponse, }, #[experimental("fuzzyFileSearch/sessionUpdate")] FuzzyFileSearchSessionUpdate => "fuzzyFileSearch/sessionUpdate" { params: FuzzyFileSearchSessionUpdateParams, + serialization: fuzzy_session_id(params.session_id), response: FuzzyFileSearchSessionUpdateResponse, }, #[experimental("fuzzyFileSearch/sessionStop")] FuzzyFileSearchSessionStop => "fuzzyFileSearch/sessionStop" { params: FuzzyFileSearchSessionStopParams, + serialization: fuzzy_session_id(params.session_id), response: FuzzyFileSearchSessionStopResponse, }, } @@ -1150,6 +1332,325 @@ mod tests { test_path_buf(&path).abs() } + fn request_id() -> RequestId { + const REQUEST_ID: i64 = 1; + RequestId::Integer(REQUEST_ID) + } + + #[test] + fn client_request_serialization_scope_covers_keyed_families() { + let thread_id = "thread-1".to_string(); + let thread_resume = ClientRequest::ThreadResume { + request_id: request_id(), + params: v2::ThreadResumeParams { + thread_id: thread_id.clone(), + ..Default::default() + }, + }; + assert_eq!( + thread_resume.serialization_scope(), + Some(ClientRequestSerializationScope::Thread { + thread_id: thread_id.clone() + }) + ); + + let thread_resume_with_path = ClientRequest::ThreadResume { + request_id: request_id(), + params: v2::ThreadResumeParams { + thread_id: thread_id.clone(), + path: Some(PathBuf::from("/tmp/resume-thread.jsonl")), + ..Default::default() + }, + }; + assert_eq!( + thread_resume_with_path.serialization_scope(), + Some(ClientRequestSerializationScope::Thread { + thread_id: thread_id.clone() + }) + ); + + let thread_fork = ClientRequest::ThreadFork { + request_id: request_id(), + params: v2::ThreadForkParams { + thread_id: thread_id.clone(), + path: Some(PathBuf::from("/tmp/source-thread.jsonl")), + ..Default::default() + }, + }; + assert_eq!( + thread_fork.serialization_scope(), + Some(ClientRequestSerializationScope::Thread { thread_id }) + ); + + let command_exec = ClientRequest::OneOffCommandExec { + request_id: request_id(), + params: v2::CommandExecParams { + command: vec!["sleep".to_string(), "10".to_string()], + process_id: Some("proc-1".to_string()), + tty: false, + stream_stdin: false, + stream_stdout_stderr: false, + output_bytes_cap: None, + disable_output_cap: false, + disable_timeout: false, + timeout_ms: None, + cwd: None, + env: None, + size: None, + sandbox_policy: None, + permission_profile: None, + }, + }; + assert_eq!( + command_exec.serialization_scope(), + Some(ClientRequestSerializationScope::CommandExecProcess { + process_id: "proc-1".to_string() + }) + ); + + let fuzzy_update = ClientRequest::FuzzyFileSearchSessionUpdate { + request_id: request_id(), + params: FuzzyFileSearchSessionUpdateParams { + session_id: "search-1".to_string(), + query: "lib".to_string(), + }, + }; + assert_eq!( + fuzzy_update.serialization_scope(), + Some(ClientRequestSerializationScope::FuzzyFileSearchSession { + session_id: "search-1".to_string() + }) + ); + + let fs_watch = ClientRequest::FsWatch { + request_id: request_id(), + params: v2::FsWatchParams { + watch_id: "watch-1".to_string(), + path: absolute_path("/tmp/repo"), + }, + }; + assert_eq!( + fs_watch.serialization_scope(), + Some(ClientRequestSerializationScope::FsWatch { + watch_id: "watch-1".to_string() + }) + ); + + let plugin_install = ClientRequest::PluginInstall { + request_id: request_id(), + params: v2::PluginInstallParams { + marketplace_path: Some(absolute_path("/tmp/marketplace")), + remote_marketplace_name: None, + plugin_name: "plugin-a".to_string(), + }, + }; + assert_eq!( + plugin_install.serialization_scope(), + Some(ClientRequestSerializationScope::Global("config")) + ); + + let plugin_uninstall = ClientRequest::PluginUninstall { + request_id: request_id(), + params: v2::PluginUninstallParams { + plugin_id: "plugin-a".to_string(), + }, + }; + assert_eq!( + plugin_uninstall.serialization_scope(), + Some(ClientRequestSerializationScope::Global("config")) + ); + + let mcp_oauth = ClientRequest::McpServerOauthLogin { + request_id: request_id(), + params: v2::McpServerOauthLoginParams { + name: "server-a".to_string(), + scopes: None, + timeout_secs: None, + }, + }; + assert_eq!( + mcp_oauth.serialization_scope(), + Some(ClientRequestSerializationScope::McpOauth { + server_name: "server-a".to_string() + }) + ); + + let mcp_resource_read = ClientRequest::McpResourceRead { + request_id: request_id(), + params: v2::McpResourceReadParams { + thread_id: Some("thread-1".to_string()), + server: "server-a".to_string(), + uri: "file:///tmp/resource".to_string(), + }, + }; + assert_eq!( + mcp_resource_read.serialization_scope(), + Some(ClientRequestSerializationScope::Thread { + thread_id: "thread-1".to_string() + }) + ); + + let config_read = ClientRequest::ConfigRead { + request_id: request_id(), + params: v2::ConfigReadParams { + include_layers: false, + cwd: None, + }, + }; + assert_eq!( + config_read.serialization_scope(), + Some(ClientRequestSerializationScope::Global("config")) + ); + + let account_read = ClientRequest::GetAccount { + request_id: request_id(), + params: v2::GetAccountParams { + refresh_token: false, + }, + }; + assert_eq!( + account_read.serialization_scope(), + Some(ClientRequestSerializationScope::Global("account-auth")) + ); + + let thread_goal_set = ClientRequest::ThreadGoalSet { + request_id: request_id(), + params: v2::ThreadGoalSetParams { + thread_id: "goal-thread".to_string(), + objective: Some("ship it".to_string()), + status: None, + token_budget: None, + }, + }; + assert_eq!( + thread_goal_set.serialization_scope(), + Some(ClientRequestSerializationScope::Thread { + thread_id: "goal-thread".to_string() + }) + ); + + let guardian_approval = ClientRequest::ThreadApproveGuardianDeniedAction { + request_id: request_id(), + params: v2::ThreadApproveGuardianDeniedActionParams { + thread_id: "guardian-thread".to_string(), + event: json!({ "type": "guardian" }), + }, + }; + assert_eq!( + guardian_approval.serialization_scope(), + Some(ClientRequestSerializationScope::Thread { + thread_id: "guardian-thread".to_string() + }) + ); + + let marketplace_remove = ClientRequest::MarketplaceRemove { + request_id: request_id(), + params: v2::MarketplaceRemoveParams { + marketplace_name: "marketplace".to_string(), + }, + }; + assert_eq!( + marketplace_remove.serialization_scope(), + Some(ClientRequestSerializationScope::Global("config")) + ); + + let device_key_create = ClientRequest::DeviceKeyCreate { + request_id: request_id(), + params: v2::DeviceKeyCreateParams { + protection_policy: None, + account_user_id: "user".to_string(), + client_id: "client".to_string(), + }, + }; + assert_eq!( + device_key_create.serialization_scope(), + Some(ClientRequestSerializationScope::Global("device-key")) + ); + + let add_credits_nudge = ClientRequest::SendAddCreditsNudgeEmail { + request_id: request_id(), + params: v2::SendAddCreditsNudgeEmailParams { + credit_type: v2::AddCreditsNudgeCreditType::Credits, + }, + }; + assert_eq!( + add_credits_nudge.serialization_scope(), + Some(ClientRequestSerializationScope::Global("account-auth")) + ); + } + + #[test] + fn client_request_serialization_scope_covers_unkeyed_representatives() { + let initialize = ClientRequest::Initialize { + request_id: request_id(), + params: v1::InitializeParams { + client_info: v1::ClientInfo { + name: "test".to_string(), + title: None, + version: "0.1.0".to_string(), + }, + capabilities: None, + }, + }; + assert_eq!(initialize.serialization_scope(), None); + + let thread_start = ClientRequest::ThreadStart { + request_id: request_id(), + params: v2::ThreadStartParams::default(), + }; + assert_eq!(thread_start.serialization_scope(), None); + + let command_exec = ClientRequest::OneOffCommandExec { + request_id: request_id(), + params: v2::CommandExecParams { + command: vec!["true".to_string()], + process_id: None, + tty: false, + stream_stdin: false, + stream_stdout_stderr: false, + output_bytes_cap: None, + disable_output_cap: false, + disable_timeout: false, + timeout_ms: None, + cwd: None, + env: None, + size: None, + sandbox_policy: None, + permission_profile: None, + }, + }; + assert_eq!(command_exec.serialization_scope(), None); + + let fs_read = ClientRequest::FsReadFile { + request_id: request_id(), + params: v2::FsReadFileParams { + path: absolute_path("/tmp/file.txt"), + }, + }; + assert_eq!(fs_read.serialization_scope(), None); + + let thread_turns_list = ClientRequest::ThreadTurnsList { + request_id: request_id(), + params: v2::ThreadTurnsListParams { + thread_id: "thread-1".to_string(), + cursor: None, + limit: None, + sort_direction: None, + }, + }; + assert_eq!(thread_turns_list.serialization_scope(), None); + + let mcp_resource_read = ClientRequest::McpResourceRead { + request_id: request_id(), + params: v2::McpResourceReadParams { + thread_id: None, + server: "server-a".to_string(), + uri: "file:///tmp/resource".to_string(), + }, + }; + assert_eq!(mcp_resource_read.serialization_scope(), None); + } + #[test] fn serialize_get_conversation_summary() -> Result<()> { let request = ClientRequest::GetConversationSummary { diff --git a/codex-rs/app-server/src/bespoke_event_handling.rs b/codex-rs/app-server/src/bespoke_event_handling.rs index 04de06f469..d67efed164 100644 --- a/codex-rs/app-server/src/bespoke_event_handling.rs +++ b/codex-rs/app-server/src/bespoke_event_handling.rs @@ -178,6 +178,7 @@ pub(crate) async fn apply_bespoke_event_handling( outgoing: ThreadScopedOutgoingMessageSender, thread_state: Arc>, thread_watch_manager: ThreadWatchManager, + thread_list_state_permit: Arc, api_version: ApiVersion, fallback_model_provider: String, codex_home: &Path, @@ -1874,6 +1875,20 @@ pub(crate) async fn apply_bespoke_event_handling( }; if let Some(request_id) = pending { + let _thread_list_state_permit = match thread_list_state_permit.acquire().await { + Ok(permit) => permit, + Err(err) => { + outgoing + .send_error( + request_id, + internal_error(format!( + "failed to acquire thread list state permit: {err}" + )), + ) + .await; + return; + } + }; let Some(rollout_path) = conversation.rollout_path() else { outgoing .send_error( @@ -3271,6 +3286,7 @@ mod tests { self.outgoing.clone(), self.thread_state.clone(), self.thread_watch_manager.clone(), + Arc::new(tokio::sync::Semaphore::new(/*permits*/ 1)), ApiVersion::V2, "test-provider".to_string(), &self.codex_home, diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index 1ab9c0c789..362f8692b6 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -393,6 +393,8 @@ use std::sync::atomic::Ordering; use std::time::Duration; use std::time::Instant; use tokio::sync::Mutex; +use tokio::sync::Semaphore; +use tokio::sync::SemaphorePermit; use tokio::sync::broadcast; use tokio::sync::oneshot; use tokio::sync::watch; @@ -524,6 +526,10 @@ pub(crate) struct CodexMessageProcessor { pending_thread_unloads: Arc>>, thread_state_manager: ThreadStateManager, thread_watch_manager: ThreadWatchManager, + /// Serializes mutations of list membership or fields rendered from list + /// results. `thread/list` is intentionally not serialized so it can run + /// concurrently against mostly append-only storage. + thread_list_state_permit: Arc, command_exec_manager: CommandExecManager, workspace_settings_cache: Arc, pending_fuzzy_searches: Arc>>>, @@ -549,6 +555,7 @@ struct ListenerTaskContext { pending_thread_unloads: Arc>>, analytics_events_client: AnalyticsEventsClient, thread_watch_manager: ThreadWatchManager, + thread_list_state_permit: Arc, fallback_model_provider: String, codex_home: PathBuf, } @@ -790,6 +797,7 @@ impl CodexMessageProcessor { pending_thread_unloads: Arc::new(Mutex::new(HashSet::new())), thread_state_manager: ThreadStateManager::new(), thread_watch_manager: ThreadWatchManager::new_with_outgoing(outgoing), + thread_list_state_permit: Arc::new(Semaphore::new(/*permits*/ 1)), command_exec_manager: CommandExecManager::default(), workspace_settings_cache: Arc::new( workspace_settings::WorkspaceSettingsCache::default(), @@ -1326,6 +1334,17 @@ impl CodexMessageProcessor { } } + async fn acquire_thread_list_state_permit( + &self, + ) -> Result, JSONRPCErrorError> { + self.thread_list_state_permit + .acquire() + .await + .map_err(|err| { + internal_error(format!("failed to acquire thread list state permit: {err}")) + }) + } + async fn login_api_key_common( &self, params: &LoginApiKeyParams, @@ -2421,6 +2440,7 @@ impl CodexMessageProcessor { pending_thread_unloads: Arc::clone(&self.pending_thread_unloads), analytics_events_client: self.analytics_events_client.clone(), thread_watch_manager: self.thread_watch_manager.clone(), + thread_list_state_permit: self.thread_list_state_permit.clone(), fallback_model_provider: self.config.model_provider_id.clone(), codex_home: self.config.codex_home.to_path_buf(), }; @@ -2871,6 +2891,13 @@ impl CodexMessageProcessor { } async fn thread_archive(&self, request_id: ConnectionRequestId, params: ThreadArchiveParams) { + let _thread_list_state_permit = match self.acquire_thread_list_state_permit().await { + Ok(permit) => permit, + Err(error) => { + self.outgoing.send_error(request_id, error).await; + return; + } + }; let result = self.thread_archive_response(params).await; let archived_thread_ids = result .as_ref() @@ -3077,6 +3104,7 @@ impl CodexMessageProcessor { return Err(invalid_request("thread name must not be empty")); }; + let _thread_list_state_permit = self.acquire_thread_list_state_permit().await?; if let Ok(thread) = self.thread_manager.get_thread(thread_id).await { self.submit_core_op(request_id, thread.as_ref(), Op::SetThreadName { name }) .await @@ -3218,6 +3246,7 @@ impl CodexMessageProcessor { return Err(invalid_request("gitInfo must include at least one field")); } + let _thread_list_state_permit = self.acquire_thread_list_state_permit().await?; let loaded_thread = self.thread_manager.get_thread(thread_uuid).await.ok(); let mut state_db_ctx = loaded_thread.as_ref().and_then(|thread| thread.state_db()); if state_db_ctx.is_none() { @@ -3414,6 +3443,13 @@ impl CodexMessageProcessor { request_id: ConnectionRequestId, params: ThreadUnarchiveParams, ) { + let _thread_list_state_permit = match self.acquire_thread_list_state_permit().await { + Ok(permit) => permit, + Err(error) => { + self.outgoing.send_error(request_id, error).await; + return; + } + }; let result = self.thread_unarchive_response(params).await; let notification = result @@ -4143,6 +4179,13 @@ impl CodexMessageProcessor { return; } + let _thread_list_state_permit = match self.acquire_thread_list_state_permit().await { + Ok(permit) => permit, + Err(error) => { + self.outgoing.send_error(request_id, error).await; + return; + } + }; match self.resume_running_thread(&request_id, ¶ms).await { Ok(true) => return, Ok(false) => {} @@ -7237,6 +7280,7 @@ impl CodexMessageProcessor { pending_thread_unloads: Arc::clone(&self.pending_thread_unloads), analytics_events_client: self.analytics_events_client.clone(), thread_watch_manager: self.thread_watch_manager.clone(), + thread_list_state_permit: self.thread_list_state_permit.clone(), fallback_model_provider: self.config.model_provider_id.clone(), codex_home: self.config.codex_home.to_path_buf(), }, @@ -7354,6 +7398,7 @@ impl CodexMessageProcessor { pending_thread_unloads: Arc::clone(&self.pending_thread_unloads), analytics_events_client: self.analytics_events_client.clone(), thread_watch_manager: self.thread_watch_manager.clone(), + thread_list_state_permit: self.thread_list_state_permit.clone(), fallback_model_provider: self.config.model_provider_id.clone(), codex_home: self.config.codex_home.to_path_buf(), }, @@ -7402,6 +7447,7 @@ impl CodexMessageProcessor { pending_thread_unloads, analytics_events_client: _, thread_watch_manager, + thread_list_state_permit, fallback_model_provider, codex_home, } = listener_task_context; @@ -7480,6 +7526,7 @@ impl CodexMessageProcessor { thread_outgoing, thread_state.clone(), thread_watch_manager.clone(), + thread_list_state_permit.clone(), api_version, fallback_model_provider.clone(), codex_home.as_path(), diff --git a/codex-rs/app-server/src/connection_rpc_gate.rs b/codex-rs/app-server/src/connection_rpc_gate.rs new file mode 100644 index 0000000000..12fed79b36 --- /dev/null +++ b/codex-rs/app-server/src/connection_rpc_gate.rs @@ -0,0 +1,209 @@ +use std::future::Future; + +use tokio::sync::Mutex; +use tokio_util::task::TaskTracker; + +/// Per-connection gate for initialized RPC handler execution. +/// +/// Closing the gate prevents queued handlers from starting while allowing +/// handlers that already acquired a token to finish. +#[derive(Debug)] +pub(crate) struct ConnectionRpcGate { + accepting: Mutex, + tasks: TaskTracker, +} + +impl ConnectionRpcGate { + pub(crate) fn new() -> Self { + let accepting = true; + Self { + accepting: Mutex::new(accepting), + tasks: TaskTracker::new(), + } + } + + pub(crate) async fn run(&self, future: F) + where + F: Future, + { + let token = { + let accepting = self.accepting.lock().await; + if !*accepting { + return; + } + self.tasks.token() + }; + + future.await; + drop(token); + } + + pub(crate) async fn shutdown(&self) { + { + let mut accepting = self.accepting.lock().await; + *accepting = false; + self.tasks.close(); + } + self.tasks.wait().await; + } + + #[cfg(test)] + async fn is_accepting(&self) -> bool { + *self.accepting.lock().await + } + + #[cfg(test)] + fn inflight_count(&self) -> usize { + self.tasks.len() + } +} + +impl Default for ConnectionRpcGate { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + use std::sync::Arc; + use std::sync::atomic::AtomicBool; + use std::sync::atomic::Ordering; + use tokio::sync::oneshot; + use tokio::time::Duration; + use tokio::time::timeout; + + #[tokio::test] + async fn run_executes_while_open() { + let gate = ConnectionRpcGate::new(); + let ran = Arc::new(AtomicBool::new(/*v*/ false)); + let ran_clone = Arc::clone(&ran); + + gate.run(async move { + ran_clone.store(/*val*/ true, Ordering::Release); + }) + .await; + + assert!(ran.load(Ordering::Acquire)); + } + + #[tokio::test] + async fn run_drops_future_without_polling_after_shutdown() { + let gate = ConnectionRpcGate::new(); + gate.shutdown().await; + let polled = Arc::new(AtomicBool::new(/*v*/ false)); + let polled_clone = Arc::clone(&polled); + + gate.run(async move { + polled_clone.store(/*val*/ true, Ordering::Release); + }) + .await; + + assert!(!polled.load(Ordering::Acquire)); + assert!(!gate.is_accepting().await); + } + + #[tokio::test] + async fn shutdown_waits_for_started_run_to_finish() { + let gate = Arc::new(ConnectionRpcGate::new()); + let (started_tx, started_rx) = oneshot::channel(); + let (finish_tx, finish_rx) = oneshot::channel(); + let gate_for_run = Arc::clone(&gate); + let run_task = tokio::spawn(async move { + gate_for_run + .run(async move { + started_tx.send(()).expect("receiver should be open"); + let _ = finish_rx.await; + }) + .await; + }); + + started_rx.await.expect("run should start"); + assert_eq!(gate.inflight_count(), 1); + + let gate_for_shutdown = Arc::clone(&gate); + let shutdown_task = tokio::spawn(async move { + gate_for_shutdown.shutdown().await; + }); + + timeout(Duration::from_millis(/*millis*/ 50), shutdown_task) + .await + .expect_err("shutdown should wait for the running future"); + + finish_tx + .send(()) + .expect("running future should be waiting"); + run_task.await.expect("run task should complete"); + gate.shutdown().await; + assert_eq!(gate.inflight_count(), 0); + } + + #[tokio::test] + async fn shutdown_drops_late_runs_while_waiting_for_inflight_work() { + let gate = Arc::new(ConnectionRpcGate::new()); + let (started_tx, started_rx) = oneshot::channel(); + let (finish_tx, finish_rx) = oneshot::channel(); + let gate_for_run = Arc::clone(&gate); + let run_task = tokio::spawn(async move { + gate_for_run + .run(async move { + started_tx.send(()).expect("receiver should be open"); + let _ = finish_rx.await; + }) + .await; + }); + + started_rx.await.expect("run should start"); + let gate_for_shutdown = Arc::clone(&gate); + let shutdown_task = tokio::spawn(async move { + gate_for_shutdown.shutdown().await; + }); + + timeout(Duration::from_millis(/*millis*/ 50), shutdown_task) + .await + .expect_err("shutdown should wait for the running future"); + + let late_polled = Arc::new(AtomicBool::new(/*v*/ false)); + let late_polled_clone = Arc::clone(&late_polled); + gate.run(async move { + late_polled_clone.store(/*val*/ true, Ordering::Release); + }) + .await; + + assert!(!late_polled.load(Ordering::Acquire)); + + finish_tx + .send(()) + .expect("running future should still be waiting"); + run_task.await.expect("run task should complete"); + gate.shutdown().await; + assert_eq!(gate.inflight_count(), 0); + } + + #[tokio::test] + async fn run_is_counted_before_handler_body_continues() { + let gate = Arc::new(ConnectionRpcGate::new()); + let (entered_tx, entered_rx) = oneshot::channel(); + let (continue_tx, continue_rx) = oneshot::channel(); + let gate_for_run = Arc::clone(&gate); + let run_task = tokio::spawn(async move { + gate_for_run + .run(async move { + entered_tx.send(()).expect("receiver should be open"); + let _ = continue_rx.await; + }) + .await; + }); + + entered_rx.await.expect("handler body should be entered"); + assert_eq!(gate.inflight_count(), 1); + + continue_tx + .send(()) + .expect("handler body should still be waiting"); + run_task.await.expect("run task should complete"); + assert_eq!(gate.inflight_count(), 0); + } +} diff --git a/codex-rs/app-server/src/in_process.rs b/codex-rs/app-server/src/in_process.rs index 73332394f4..5ee044b14a 100644 --- a/codex-rs/app-server/src/in_process.rs +++ b/codex-rs/app-server/src/in_process.rs @@ -490,7 +490,9 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle { processor.clear_runtime_references(); processor.cancel_active_login().await; - processor.connection_closed(IN_PROCESS_CONNECTION_ID).await; + processor + .connection_closed(IN_PROCESS_CONNECTION_ID, &session) + .await; processor.clear_all_thread_listeners().await; processor.drain_background_tasks().await; processor.shutdown_threads().await; diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index de89608101..4369bbee31 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -75,6 +75,7 @@ mod config; mod config_api; mod config_manager; mod config_manager_service; +mod connection_rpc_gate; mod device_key_api; mod dynamic_tools; mod error_code; @@ -87,6 +88,7 @@ pub mod in_process; mod message_processor; mod models; mod outgoing_message; +mod request_serialization; mod server_request_error; mod thread_state; mod thread_status; @@ -809,9 +811,9 @@ pub async fn run_main_with_transport_options( ); } TransportEvent::ConnectionClosed { connection_id } => { - if connections.remove(&connection_id).is_none() { + let Some(connection_state) = connections.remove(&connection_id) else { continue; - } + }; if outbound_control_tx .send(OutboundControlEvent::Closed { connection_id }) .await @@ -819,7 +821,7 @@ pub async fn run_main_with_transport_options( { break; } - processor.connection_closed(connection_id).await; + processor.connection_closed(connection_id, &connection_state.session).await; if shutdown_when_no_connections && connections.is_empty() { break; } @@ -933,6 +935,12 @@ pub async fn run_main_with_transport_options( } if !shutdown_state.forced() { + futures::future::join_all( + connections + .values() + .map(|connection_state| connection_state.session.rpc_gate.shutdown()), + ) + .await; processor.drain_background_tasks().await; processor.shutdown_threads().await; } diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index 5cb1c1da36..df77e8ca2b 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -9,6 +9,7 @@ use crate::codex_message_processor::CodexMessageProcessor; use crate::codex_message_processor::CodexMessageProcessorArgs; use crate::config_api::ConfigApi; use crate::config_manager::ConfigManager; +use crate::connection_rpc_gate::ConnectionRpcGate; use crate::device_key_api::DeviceKeyApi; use crate::error_code::invalid_request; use crate::external_agent_config_api::ExternalAgentConfigApi; @@ -18,6 +19,9 @@ use crate::outgoing_message::ConnectionId; use crate::outgoing_message::ConnectionRequestId; use crate::outgoing_message::OutgoingMessageSender; use crate::outgoing_message::RequestContext; +use crate::request_serialization::QueuedInitializedRequest; +use crate::request_serialization::RequestSerializationQueueKey; +use crate::request_serialization::RequestSerializationQueues; use crate::transport::AppServerTransport; use crate::transport::ConnectionOrigin; use crate::transport::RemoteControlHandle; @@ -167,11 +171,13 @@ pub(crate) struct MessageProcessor { config_warnings: Arc>, rpc_transport: AppServerRpcTransport, remote_control_handle: Option, + request_serialization_queues: RequestSerializationQueues, } #[derive(Debug)] pub(crate) struct ConnectionSessionState { origin: ConnectionOrigin, + pub(crate) rpc_gate: Arc, initialized: OnceLock, } @@ -193,6 +199,7 @@ impl ConnectionSessionState { pub(crate) fn new(origin: ConnectionOrigin) -> Self { Self { origin, + rpc_gate: Arc::new(ConnectionRpcGate::new()), initialized: OnceLock::new(), } } @@ -344,6 +351,7 @@ impl MessageProcessor { config_warnings: Arc::new(config_warnings), rpc_transport, remote_control_handle, + request_serialization_queues: RequestSerializationQueues::default(), } } @@ -540,7 +548,12 @@ impl MessageProcessor { self.codex_message_processor.shutdown_threads().await; } - pub(crate) async fn connection_closed(&self, connection_id: ConnectionId) { + pub(crate) async fn connection_closed( + &self, + connection_id: ConnectionId, + session_state: &ConnectionSessionState, + ) { + session_state.rpc_gate.shutdown().await; self.outgoing.connection_closed(connection_id).await; self.fs_watch_manager.connection_closed(connection_id).await; self.codex_message_processor @@ -724,19 +737,46 @@ impl MessageProcessor { ); } + let serialization_scope = codex_request.serialization_scope(); let app_server_client_name = session.app_server_client_name().map(str::to_string); let client_version = session.client_version().map(str::to_string); let device_key_requests_allowed = session.allows_device_key_requests(); - Arc::clone(self) - .handle_initialized_client_request( - connection_request_id, - codex_request, - request_context, - app_server_client_name, - client_version, - device_key_requests_allowed, - ) - .await + let error_request_id = connection_request_id.clone(); + let rpc_gate = Arc::clone(&session.rpc_gate); + let processor = Arc::clone(self); + let span = request_context.span(); + let request = QueuedInitializedRequest::new( + rpc_gate, + async move { + let processor_for_request = Arc::clone(&processor); + let result = processor_for_request + .handle_initialized_client_request( + connection_request_id, + codex_request, + request_context, + app_server_client_name, + client_version, + device_key_requests_allowed, + ) + .await; + if let Err(error) = result { + processor.outgoing.send_error(error_request_id, error).await; + } + } + .instrument(span), + ); + + if let Some(scope) = serialization_scope { + let key = RequestSerializationQueueKey::from_scope(connection_id, scope); + self.request_serialization_queues + .enqueue(key, request) + .await; + } else { + tokio::spawn(async move { + request.run().await; + }); + } + Ok(()) } async fn handle_initialized_client_request( diff --git a/codex-rs/app-server/src/request_serialization.rs b/codex-rs/app-server/src/request_serialization.rs new file mode 100644 index 0000000000..c3e21d134e --- /dev/null +++ b/codex-rs/app-server/src/request_serialization.rs @@ -0,0 +1,380 @@ +use std::collections::HashMap; +use std::collections::VecDeque; +use std::future::Future; +use std::path::PathBuf; +use std::pin::Pin; +use std::sync::Arc; + +use codex_app_server_protocol::ClientRequestSerializationScope; +use tokio::sync::Mutex; +use tracing::Instrument; + +use crate::connection_rpc_gate::ConnectionRpcGate; +use crate::outgoing_message::ConnectionId; + +type BoxFutureUnit = Pin + Send + 'static>>; + +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub(crate) enum RequestSerializationQueueKey { + Global(&'static str), + Thread { + thread_id: String, + }, + ThreadPath { + path: PathBuf, + }, + CommandExecProcess { + connection_id: ConnectionId, + process_id: String, + }, + FuzzyFileSearchSession { + session_id: String, + }, + FsWatch { + connection_id: ConnectionId, + watch_id: String, + }, + McpOauth { + server_name: String, + }, +} + +impl RequestSerializationQueueKey { + pub(crate) fn from_scope( + connection_id: ConnectionId, + scope: ClientRequestSerializationScope, + ) -> Self { + match scope { + ClientRequestSerializationScope::Global(name) => Self::Global(name), + ClientRequestSerializationScope::Thread { thread_id } => Self::Thread { thread_id }, + ClientRequestSerializationScope::ThreadPath { path } => Self::ThreadPath { path }, + ClientRequestSerializationScope::CommandExecProcess { process_id } => { + Self::CommandExecProcess { + connection_id, + process_id, + } + } + ClientRequestSerializationScope::FuzzyFileSearchSession { session_id } => { + Self::FuzzyFileSearchSession { session_id } + } + ClientRequestSerializationScope::FsWatch { watch_id } => Self::FsWatch { + connection_id, + watch_id, + }, + ClientRequestSerializationScope::McpOauth { server_name } => { + Self::McpOauth { server_name } + } + } + } +} + +pub(crate) struct QueuedInitializedRequest { + gate: Arc, + future: BoxFutureUnit, +} + +impl QueuedInitializedRequest { + pub(crate) fn new( + gate: Arc, + future: impl Future + Send + 'static, + ) -> Self { + Self { + gate, + future: Box::pin(future), + } + } + + pub(crate) async fn run(self) { + let Self { gate, future } = self; + gate.run(future).await; + } +} + +#[derive(Clone, Default)] +pub(crate) struct RequestSerializationQueues { + inner: Arc>>>, +} + +impl RequestSerializationQueues { + pub(crate) async fn enqueue( + &self, + key: RequestSerializationQueueKey, + request: QueuedInitializedRequest, + ) { + let should_spawn = { + let mut queues = self.inner.lock().await; + match queues.get_mut(&key) { + Some(queue) => { + queue.push_back(request); + false + } + None => { + let mut queue = VecDeque::new(); + queue.push_back(request); + queues.insert(key.clone(), queue); + true + } + } + }; + + if should_spawn { + let queues = self.clone(); + let span = tracing::debug_span!("app_server.serialized_request_queue", ?key); + tokio::spawn(async move { queues.drain(key).await }.instrument(span)); + } + } + + async fn drain(self, key: RequestSerializationQueueKey) { + loop { + let request = { + let mut queues = self.inner.lock().await; + let Some(queue) = queues.get_mut(&key) else { + return; + }; + match queue.pop_front() { + Some(request) => request, + None => { + queues.remove(&key); + return; + } + } + }; + + request.run().await; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + use std::sync::Arc; + use tokio::sync::mpsc; + use tokio::sync::oneshot; + use tokio::time::Duration; + use tokio::time::timeout; + + const FIRST_REQUEST_VALUE: i32 = 1; + const SECOND_REQUEST_VALUE: i32 = 2; + const THIRD_REQUEST_VALUE: i32 = 3; + + fn gate() -> Arc { + Arc::new(ConnectionRpcGate::new()) + } + + fn queue_drain_timeout() -> Duration { + Duration::from_secs(/*secs*/ 1) + } + + fn shutdown_wait_timeout() -> Duration { + Duration::from_millis(/*millis*/ 50) + } + + #[tokio::test] + async fn same_key_requests_run_fifo() { + let queues = RequestSerializationQueues::default(); + let key = RequestSerializationQueueKey::Global("test"); + let gate = gate(); + let (tx, mut rx) = mpsc::unbounded_channel(); + + for value in [ + FIRST_REQUEST_VALUE, + SECOND_REQUEST_VALUE, + THIRD_REQUEST_VALUE, + ] { + let tx = tx.clone(); + queues + .enqueue( + key.clone(), + QueuedInitializedRequest::new(Arc::clone(&gate), async move { + tx.send(value).expect("receiver should be open"); + }), + ) + .await; + } + drop(tx); + + let mut values = Vec::new(); + while let Some(value) = timeout(queue_drain_timeout(), rx.recv()) + .await + .expect("timed out waiting for queued request") + { + values.push(value); + } + + assert_eq!( + values, + vec![ + FIRST_REQUEST_VALUE, + SECOND_REQUEST_VALUE, + THIRD_REQUEST_VALUE + ] + ); + } + + #[tokio::test] + async fn different_keys_run_concurrently() { + let queues = RequestSerializationQueues::default(); + let (blocked_tx, blocked_rx) = oneshot::channel::<()>(); + let (ran_tx, ran_rx) = oneshot::channel::<()>(); + + queues + .enqueue( + RequestSerializationQueueKey::Global("blocked"), + QueuedInitializedRequest::new(gate(), async move { + let _ = blocked_rx.await; + }), + ) + .await; + queues + .enqueue( + RequestSerializationQueueKey::Global("other"), + QueuedInitializedRequest::new(gate(), async move { + ran_tx.send(()).expect("receiver should be open"); + }), + ) + .await; + + timeout(queue_drain_timeout(), ran_rx) + .await + .expect("other key should not be blocked") + .expect("sender should be open"); + blocked_tx + .send(()) + .expect("blocked request should be waiting"); + } + + #[tokio::test] + async fn closed_gate_request_is_skipped_and_following_requests_continue() { + let queues = RequestSerializationQueues::default(); + let key = RequestSerializationQueueKey::Global("test"); + let live_gate = gate(); + let closed_gate = gate(); + closed_gate.shutdown().await; + let (tx, mut rx) = mpsc::unbounded_channel(); + let (blocked_tx, blocked_rx) = oneshot::channel::<()>(); + + { + let tx = tx.clone(); + queues + .enqueue( + key.clone(), + QueuedInitializedRequest::new(Arc::clone(&live_gate), async move { + tx.send(FIRST_REQUEST_VALUE) + .expect("receiver should be open"); + let _ = blocked_rx.await; + }), + ) + .await; + } + { + let tx = tx.clone(); + queues + .enqueue( + key.clone(), + QueuedInitializedRequest::new(closed_gate, async move { + tx.send(SECOND_REQUEST_VALUE) + .expect("receiver should be open"); + }), + ) + .await; + } + { + let tx = tx.clone(); + queues + .enqueue( + key, + QueuedInitializedRequest::new(live_gate, async move { + tx.send(THIRD_REQUEST_VALUE) + .expect("receiver should be open"); + }), + ) + .await; + } + drop(tx); + + assert_eq!( + timeout(queue_drain_timeout(), rx.recv()) + .await + .expect("timed out waiting for first request"), + Some(FIRST_REQUEST_VALUE) + ); + blocked_tx + .send(()) + .expect("blocked request should be waiting"); + + let mut values = Vec::new(); + while let Some(value) = timeout(queue_drain_timeout(), rx.recv()) + .await + .expect("timed out waiting for queue to drain") + { + values.push(value); + } + + assert_eq!(values, vec![THIRD_REQUEST_VALUE]); + } + + #[tokio::test] + async fn shutdown_of_live_gate_skips_already_queued_requests() { + let queues = RequestSerializationQueues::default(); + let key = RequestSerializationQueueKey::Global("test"); + let live_gate = gate(); + let (tx, mut rx) = mpsc::unbounded_channel(); + let (blocked_tx, blocked_rx) = oneshot::channel::<()>(); + + { + let tx = tx.clone(); + queues + .enqueue( + key.clone(), + QueuedInitializedRequest::new(Arc::clone(&live_gate), async move { + tx.send(FIRST_REQUEST_VALUE) + .expect("receiver should be open"); + let _ = blocked_rx.await; + }), + ) + .await; + } + { + let tx = tx.clone(); + queues + .enqueue( + key, + QueuedInitializedRequest::new(live_gate.clone(), async move { + tx.send(SECOND_REQUEST_VALUE) + .expect("receiver should be open"); + }), + ) + .await; + } + drop(tx); + + assert_eq!( + timeout(queue_drain_timeout(), rx.recv()) + .await + .expect("timed out waiting for first request"), + Some(FIRST_REQUEST_VALUE) + ); + + let gate_for_shutdown = Arc::clone(&live_gate); + let shutdown_task = tokio::spawn(async move { + gate_for_shutdown.shutdown().await; + }); + + timeout(shutdown_wait_timeout(), shutdown_task) + .await + .expect_err("shutdown should wait for the running request"); + + blocked_tx + .send(()) + .expect("blocked request should still be waiting"); + + assert_eq!( + timeout(queue_drain_timeout(), rx.recv()) + .await + .expect("timed out waiting for queue to drain"), + None + ); + } +}