From 973c5c823ed1b8ae76ab5a677c5864f624a2fafb Mon Sep 17 00:00:00 2001 From: rhan-oai Date: Wed, 29 Apr 2026 13:50:47 -0700 Subject: [PATCH] [app-server] type client response payloads (#20050) ## Why `pr17088` adds typed server-originated request/response plumbing, but successful client responses are still erased into bare JSON-RPC `result` values before app-server can make any typed decision about them. This precursor PR keeps successful client responses typed until the outgoing response seam. It is intentionally limited to protocol/app-server plumbing so the analytics behavior change can review separately on top. ## What changed - Add `ClientResponsePayload` as the pre-serialization client response body type. - Route app-server successful response paths through the typed payload seam while preserving existing handler-local analytics behavior. - Keep `InterruptConversation` JSON-RPC-only because it has no `ClientResponse` variant. - Move the new payload conversion tests into a dedicated protocol test module. ## Verification - `cargo check -p codex-app-server` - `cargo test -p codex-app-server-protocol` --- codex-rs/analytics/src/client.rs | 19 ++- codex-rs/analytics/src/client_tests.rs | 79 +++++++++++++ .../src/protocol/common.rs | 110 ++++++++++++++++++ .../src/protocol/common_tests.rs | 44 +++++++ .../app-server/src/codex_message_processor.rs | 11 +- codex-rs/app-server/src/message_processor.rs | 45 ++++--- codex-rs/app-server/src/outgoing_message.rs | 57 ++++++--- 7 files changed, 323 insertions(+), 42 deletions(-) create mode 100644 codex-rs/analytics/src/client_tests.rs create mode 100644 codex-rs/app-server-protocol/src/protocol/common_tests.rs diff --git a/codex-rs/analytics/src/client.rs b/codex-rs/analytics/src/client.rs index b149614353..2ff4c794ed 100644 --- a/codex-rs/analytics/src/client.rs +++ b/codex-rs/analytics/src/client.rs @@ -186,11 +186,22 @@ impl AnalyticsEventsClient { ))); } - pub fn track_request(&self, connection_id: u64, request_id: RequestId, request: ClientRequest) { + pub fn track_request( + &self, + connection_id: u64, + request_id: RequestId, + request: &ClientRequest, + ) { + if !matches!( + request, + ClientRequest::TurnStart { .. } | ClientRequest::TurnSteer { .. } + ) { + return; + } self.record_fact(AnalyticsFact::ClientRequest { connection_id, request_id, - request: Box::new(request), + request: Box::new(request.clone()), }); } @@ -324,6 +335,10 @@ impl AnalyticsEventsClient { } } +#[cfg(test)] +#[path = "client_tests.rs"] +mod tests; + async fn send_track_events( auth_manager: &AuthManager, base_url: &str, diff --git a/codex-rs/analytics/src/client_tests.rs b/codex-rs/analytics/src/client_tests.rs new file mode 100644 index 0000000000..57679f24a1 --- /dev/null +++ b/codex-rs/analytics/src/client_tests.rs @@ -0,0 +1,79 @@ +use super::AnalyticsEventsClient; +use super::AnalyticsEventsQueue; +use crate::facts::AnalyticsFact; +use codex_app_server_protocol::ClientRequest; +use codex_app_server_protocol::RequestId; +use codex_app_server_protocol::ThreadArchiveParams; +use codex_app_server_protocol::TurnStartParams; +use codex_app_server_protocol::TurnSteerParams; +use std::collections::HashSet; +use std::sync::Arc; +use std::sync::Mutex; +use tokio::sync::mpsc; +use tokio::sync::mpsc::error::TryRecvError; + +fn client_with_receiver() -> (AnalyticsEventsClient, mpsc::Receiver) { + let (sender, receiver) = mpsc::channel(4); + let queue = AnalyticsEventsQueue { + sender, + app_used_emitted_keys: Arc::new(Mutex::new(HashSet::new())), + plugin_used_emitted_keys: Arc::new(Mutex::new(HashSet::new())), + }; + (AnalyticsEventsClient { queue: Some(queue) }, receiver) +} + +fn sample_turn_start_request() -> ClientRequest { + ClientRequest::TurnStart { + request_id: RequestId::Integer(1), + params: TurnStartParams { + thread_id: "thread-1".to_string(), + input: Vec::new(), + ..Default::default() + }, + } +} + +fn sample_turn_steer_request() -> ClientRequest { + ClientRequest::TurnSteer { + request_id: RequestId::Integer(2), + params: TurnSteerParams { + thread_id: "thread-1".to_string(), + expected_turn_id: "turn-1".to_string(), + input: Vec::new(), + responsesapi_client_metadata: None, + }, + } +} + +fn sample_thread_archive_request() -> ClientRequest { + ClientRequest::ThreadArchive { + request_id: RequestId::Integer(3), + params: ThreadArchiveParams { + thread_id: "thread-1".to_string(), + }, + } +} + +#[test] +fn track_request_only_enqueues_analytics_relevant_requests() { + let (client, mut receiver) = client_with_receiver(); + + for (request_id, request) in [ + (RequestId::Integer(1), sample_turn_start_request()), + (RequestId::Integer(2), sample_turn_steer_request()), + ] { + client.track_request(/*connection_id*/ 7, request_id, &request); + assert!(matches!( + receiver.try_recv(), + Ok(AnalyticsFact::ClientRequest { .. }) + )); + } + + let ignored_request = sample_thread_archive_request(); + client.track_request( + /*connection_id*/ 7, + RequestId::Integer(3), + &ignored_request, + ); + assert!(matches!(receiver.try_recv(), Err(TryRecvError::Empty))); +} diff --git a/codex-rs/app-server-protocol/src/protocol/common.rs b/codex-rs/app-server-protocol/src/protocol/common.rs index e89af03d85..94659ee738 100644 --- a/codex-rs/app-server-protocol/src/protocol/common.rs +++ b/codex-rs/app-server-protocol/src/protocol/common.rs @@ -157,6 +157,7 @@ macro_rules! client_request_definitions { params: $(#[$params_meta:meta])* $params:ty, $(inspect_params: $inspect_params:tt,)? serialization: $serialization:ident $( ( $($serialization_args:tt)* ) )?, + $(manual_payload_conversion: $manual_payload_conversion:ident,)? response: $response:ty, } ),* $(,)? @@ -243,8 +244,100 @@ macro_rules! client_request_definitions { }) .unwrap_or_else(|| "".to_string()) } + + pub fn into_jsonrpc_parts( + self, + ) -> std::result::Result<(RequestId, crate::Result), serde_json::Error> { + match self { + $( + Self::$variant { request_id, response } => { + serde_json::to_value(response).map(|result| (request_id, result)) + } + )* + } + } } + #[derive(Debug, Clone)] + #[allow(clippy::large_enum_variant)] + pub enum ClientResponsePayload { + $( $variant($response), )* + InterruptConversation(v1::InterruptConversationResponse), + } + + impl ClientResponsePayload { + pub fn into_jsonrpc_parts_and_payload( + self, + request_id: RequestId, + ) -> std::result::Result< + (RequestId, crate::Result, Option), + serde_json::Error, + > { + match self { + $( + Self::$variant(response) => { + let result = serde_json::to_value(&response)?; + Ok((request_id, result, Some(Self::$variant(response)))) + } + )* + Self::InterruptConversation(response) => { + serde_json::to_value(response).map(|result| (request_id, result, None)) + } + } + } + + pub fn into_client_response(self, request_id: RequestId) -> Option { + match self { + $( + Self::$variant(response) => { + Some(ClientResponse::$variant { + request_id, + response, + }) + } + )* + Self::InterruptConversation(_) => None, + } + } + + pub fn into_jsonrpc_parts( + self, + request_id: RequestId, + ) -> std::result::Result<(RequestId, crate::Result), serde_json::Error> { + self.to_jsonrpc_parts(request_id) + } + + pub fn to_jsonrpc_parts( + &self, + request_id: RequestId, + ) -> std::result::Result<(RequestId, crate::Result), serde_json::Error> { + match self { + $( + Self::$variant(response) => { + serde_json::to_value(response).map(|result| (request_id, result)) + } + )* + Self::InterruptConversation(response) => { + serde_json::to_value(response).map(|result| (request_id, result)) + } + } + } + } + + impl From for ClientResponsePayload { + fn from(response: v1::InterruptConversationResponse) -> Self { + Self::InterruptConversation(response) + } + } + + $( + client_response_payload_from_impl!( + $variant, + $response + $(, $manual_payload_conversion)? + ); + )* + impl crate::experimental_api::ExperimentalApi for ClientRequest { fn experimental_reason(&self) -> Option<&'static str> { match self { @@ -317,6 +410,17 @@ macro_rules! client_request_definitions { }; } +macro_rules! client_response_payload_from_impl { + ($variant:ident, $response:ty) => { + impl From<$response> for ClientResponsePayload { + fn from(response: $response) -> Self { + Self::$variant(response) + } + } + }; + ($variant:ident, $response:ty, manual) => {}; +} + client_request_definitions! { Initialize { params: v1::InitializeParams, @@ -789,11 +893,13 @@ client_request_definitions! { ConfigValueWrite => "config/value/write" { params: v2::ConfigValueWriteParams, serialization: global("config"), + manual_payload_conversion: manual, response: v2::ConfigWriteResponse, }, ConfigBatchWrite => "config/batchWrite" { params: v2::ConfigBatchWriteParams, serialization: global("config"), + manual_payload_conversion: manual, response: v2::ConfigWriteResponse, }, @@ -2766,3 +2872,7 @@ mod tests { ); } } + +#[cfg(test)] +#[path = "common_tests.rs"] +mod common_tests; diff --git a/codex-rs/app-server-protocol/src/protocol/common_tests.rs b/codex-rs/app-server-protocol/src/protocol/common_tests.rs new file mode 100644 index 0000000000..83e5d37117 --- /dev/null +++ b/codex-rs/app-server-protocol/src/protocol/common_tests.rs @@ -0,0 +1,44 @@ +use super::*; +use anyhow::Result; +use codex_protocol::protocol::TurnAbortReason; +use pretty_assertions::assert_eq; +use serde_json::json; + +#[test] +fn client_response_payload_returns_jsonrpc_parts_and_client_response() -> Result<()> { + let (request_id, result, payload) = + ClientResponsePayload::ThreadArchive(v2::ThreadArchiveResponse {}) + .into_jsonrpc_parts_and_payload(RequestId::Integer(7))?; + + assert_eq!(request_id, RequestId::Integer(7)); + assert_eq!(result, json!({})); + + let Some(ClientResponse::ThreadArchive { + request_id, + response: _, + }) = payload.and_then(|payload| payload.into_client_response(RequestId::Integer(7))) + else { + panic!("expected thread/archive client response"); + }; + assert_eq!(request_id, RequestId::Integer(7)); + Ok(()) +} + +#[test] +fn interrupt_conversation_payload_stays_jsonrpc_only() -> Result<()> { + let (request_id, result, payload) = + ClientResponsePayload::InterruptConversation(v1::InterruptConversationResponse { + abort_reason: TurnAbortReason::Interrupted, + }) + .into_jsonrpc_parts_and_payload(RequestId::Integer(8))?; + + assert_eq!(request_id, RequestId::Integer(8)); + assert_eq!( + result, + json!({ + "abortReason": "interrupted", + }) + ); + assert!(payload.is_none()); + Ok(()) +} diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index e53b5e2c20..d250959d8d 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -43,6 +43,7 @@ use codex_app_server_protocol::CancelLoginAccountResponse; use codex_app_server_protocol::CancelLoginAccountStatus; use codex_app_server_protocol::ClientRequest; use codex_app_server_protocol::ClientResponse; +use codex_app_server_protocol::ClientResponsePayload; use codex_app_server_protocol::CodexErrorInfo; use codex_app_server_protocol::CollaborationModeListParams; use codex_app_server_protocol::CollaborationModeListResponse; @@ -2118,7 +2119,7 @@ impl CodexMessageProcessor { let result = self .exec_one_off_command_inner(request_id.clone(), params) .await - .map(|()| None::); + .map(|()| None::); self.send_optional_result(request_id, result).await; } @@ -2864,7 +2865,6 @@ impl CodexMessageProcessor { response: response.clone(), }, ); - listener_task_context .outgoing .send_response(request_id, response) @@ -3544,7 +3544,7 @@ impl CodexMessageProcessor { let result = self .thread_rollback_start(&request_id, params) .await - .map(|()| None::); + .map(|()| None::); self.send_optional_result(request_id, result).await; } @@ -4401,6 +4401,7 @@ impl CodexMessageProcessor { permission_profile, reasoning_effort: session_configured.reasoning_effort, }; + self.analytics_events_client.track_response( request_id.connection_id.0, ClientResponse::ThreadResume { @@ -4408,7 +4409,6 @@ impl CodexMessageProcessor { response: response.clone(), }, ); - let connection_id = request_id.connection_id; let token_usage_thread = include_turns.then(|| response.thread.clone()); self.outgoing.send_response(request_id, response).await; @@ -5027,7 +5027,6 @@ impl CodexMessageProcessor { response: response.clone(), }, ); - let connection_id = request_id.connection_id; let token_usage_thread = include_turns.then(|| response.thread.clone()); self.outgoing.send_response(request_id, response).await; @@ -5811,7 +5810,7 @@ impl CodexMessageProcessor { request_id: ConnectionRequestId, result: Result, JSONRPCErrorError>, ) where - T: serde::Serialize, + T: Into, { match result { Ok(Some(response)) => self.outgoing.send_response(request_id, response).await, diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index 56d5ff183a..1a28f8e278 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -37,6 +37,7 @@ use codex_app_server_protocol::ChatgptAuthTokensRefreshResponse; use codex_app_server_protocol::ClientInfo; use codex_app_server_protocol::ClientNotification; use codex_app_server_protocol::ClientRequest; +use codex_app_server_protocol::ClientResponsePayload; use codex_app_server_protocol::ConfigBatchWriteParams; use codex_app_server_protocol::ConfigValueWriteParams; use codex_app_server_protocol::ConfigWarningNotification; @@ -732,15 +733,11 @@ impl MessageProcessor { return Err(invalid_request(experimental_required_message(reason))); } let connection_id = connection_request_id.connection_id; - if let ClientRequest::TurnStart { request_id, .. } - | ClientRequest::TurnSteer { request_id, .. } = &codex_request - { - self.analytics_events_client.track_request( - connection_id.0, - request_id.clone(), - codex_request.clone(), - ); - } + self.analytics_events_client.track_request( + connection_id.0, + connection_request_id.request_id.clone(), + &codex_request, + ); let serialization_scope = codex_request.serialization_scope(); let app_server_client_name = session.app_server_client_name().map(str::to_string); @@ -992,7 +989,12 @@ impl MessageProcessor { params: ConfigValueWriteParams, ) { let result = self.config_api.write_value(params).await; - self.handle_config_mutation_result(request_id, result).await + self.handle_config_mutation_result( + request_id, + result, + ClientResponsePayload::ConfigValueWrite, + ) + .await } async fn handle_config_batch_write( @@ -1001,7 +1003,12 @@ impl MessageProcessor { params: ConfigBatchWriteParams, ) { let result = self.config_api.batch_write(params).await; - self.handle_config_mutation_result(request_id, result).await; + self.handle_config_mutation_result( + request_id, + result, + ClientResponsePayload::ConfigBatchWrite, + ) + .await; } async fn handle_experimental_feature_enablement_set( @@ -1015,7 +1022,12 @@ impl MessageProcessor { .set_experimental_feature_enablement(params) .await; let is_ok = result.is_ok(); - self.handle_config_mutation_result(request_id, result).await; + self.handle_config_mutation_result( + request_id, + result, + ClientResponsePayload::ExperimentalFeatureEnablementSet, + ) + .await; if should_refresh_apps_list && is_ok { self.refresh_apps_list_after_experimental_feature_enablement_set() .await; @@ -1091,15 +1103,18 @@ impl MessageProcessor { }); } - async fn handle_config_mutation_result( + async fn handle_config_mutation_result( &self, request_id: ConnectionRequestId, result: std::result::Result, + wrap_success: impl FnOnce(T) -> ClientResponsePayload, ) { match result { Ok(response) => { self.handle_config_mutation().await; - self.outgoing.send_response(request_id, response).await; + self.outgoing + .send_response_as(request_id, wrap_success(response)) + .await; } Err(error) => self.outgoing.send_error(request_id, error).await, } @@ -1177,7 +1192,7 @@ impl MessageProcessor { device_key_requests_allowed: bool, run_request: F, ) where - R: serde::Serialize + Send + 'static, + R: Into + Send + 'static, F: FnOnce(DeviceKeyApi) -> Fut + Send + 'static, Fut: Future> + Send + 'static, { diff --git a/codex-rs/app-server/src/outgoing_message.rs b/codex-rs/app-server/src/outgoing_message.rs index 300a7044e5..588b4b53eb 100644 --- a/codex-rs/app-server/src/outgoing_message.rs +++ b/codex-rs/app-server/src/outgoing_message.rs @@ -5,6 +5,7 @@ use std::sync::atomic::AtomicI64; use std::sync::atomic::Ordering; use codex_analytics::AnalyticsEventsClient; +use codex_app_server_protocol::ClientResponsePayload; use codex_app_server_protocol::JSONRPCErrorError; use codex_app_server_protocol::RequestId; use codex_app_server_protocol::Result; @@ -188,11 +189,10 @@ impl ThreadScopedOutgoingMessageSender { .await } - pub(crate) async fn send_response( - &self, - request_id: ConnectionRequestId, - response: T, - ) { + pub(crate) async fn send_response(&self, request_id: ConnectionRequestId, response: T) + where + T: Into, + { self.outgoing.send_response(request_id, response).await; } @@ -482,21 +482,28 @@ impl OutgoingMessageSender { } } - pub(crate) async fn send_response( + pub(crate) async fn send_response(&self, request_id: ConnectionRequestId, response: T) + where + T: Into, + { + self.send_response_as(request_id, response.into()).await; + } + + pub(crate) async fn send_response_as( &self, request_id: ConnectionRequestId, - response: T, + response: ClientResponsePayload, ) { + let connection_id = request_id.connection_id; + let serialized_response = response.into_jsonrpc_parts(request_id.request_id.clone()); let request_context = self.take_request_context(&request_id).await; - match serde_json::to_value(response) { - Ok(result) => { - let outgoing_message = OutgoingMessage::Response(OutgoingResponse { - id: request_id.request_id.clone(), - result, - }); + + match serialized_response { + Ok((id, result)) => { + let outgoing_message = OutgoingMessage::Response(OutgoingResponse { id, result }); self.send_outgoing_message_to_connection( request_context, - request_id.connection_id, + connection_id, outgoing_message, "response", ) @@ -592,11 +599,13 @@ impl OutgoingMessageSender { request_id: ConnectionRequestId, result: std::result::Result, ) where - T: Serialize, + T: Into, E: Into, { match result { - Ok(response) => self.send_response(request_id, response).await, + Ok(response) => { + self.send_response(request_id, response).await; + } Err(error) => self.send_error(request_id, error).await, } } @@ -966,7 +975,12 @@ mod tests { }; outgoing - .send_response(request_id.clone(), json!({ "ok": true })) + .send_response( + request_id.clone(), + ClientResponsePayload::ThreadArchive( + codex_app_server_protocol::ThreadArchiveResponse {}, + ), + ) .await; let envelope = timeout(Duration::from_secs(1), rx.recv()) @@ -985,7 +999,7 @@ mod tests { panic!("expected response message"); }; assert_eq!(response.id, request_id.request_id); - assert_eq!(response.result, json!({ "ok": true })); + assert_eq!(response.result, json!({})); } other => panic!("expected targeted response envelope, got: {other:?}"), } @@ -1011,7 +1025,12 @@ mod tests { assert_eq!(outgoing.request_context_count().await, 1); outgoing - .send_response(request_id, json!({ "ok": true })) + .send_response( + request_id, + ClientResponsePayload::ThreadArchive( + codex_app_server_protocol::ThreadArchiveResponse {}, + ), + ) .await; assert_eq!(outgoing.request_context_count().await, 0);