From 7bf007ddd986b7190ed5020a2fe7ad776f130de4 Mon Sep 17 00:00:00 2001 From: Anton Panasenko Date: Fri, 29 May 2026 04:48:25 -0700 Subject: [PATCH] fix(app-server-transport): recover rejected pairing auth --- .../src/transport/remote_control/mod.rs | 116 +++++-- .../src/transport/remote_control/pairing.rs | 27 +- .../pairing_integration_tests.rs | 306 ++++++++++++++++++ .../src/transport/remote_control/tests.rs | 1 + 4 files changed, 418 insertions(+), 32 deletions(-) diff --git a/codex-rs/app-server-transport/src/transport/remote_control/mod.rs b/codex-rs/app-server-transport/src/transport/remote_control/mod.rs index 5f6dfda15f..b7d683210f 100644 --- a/codex-rs/app-server-transport/src/transport/remote_control/mod.rs +++ b/codex-rs/app-server-transport/src/transport/remote_control/mod.rs @@ -60,11 +60,19 @@ pub struct RemoteControlHandle { status_tx: Arc>, state_db_available: bool, pairing: PairingClientState, - #[cfg(test)] + pairing_request_cancellation_revision: Arc, pairing_refresh_tx: Arc>, auth_change_rx: Arc>>, } +// `/server/pair` runs outside the websocket task, so keep the auth and disable +// revisions that make its eventual response safe to return to app-server. +struct PairingRequestContext { + auth_change_revision: u64, + cancellation_revision: u64, + pairing_client: RemoteControlPairingClient, +} + #[derive(Clone)] pub(super) struct PairingClientState { client: Arc>>, @@ -135,6 +143,7 @@ impl RemoteControlHandle { changed }); clear_pairing_client(&self.pairing); + self.cancel_pending_pairing_requests(); let status = self.status(); info!( @@ -148,6 +157,11 @@ impl RemoteControlHandle { self.publish_status(RemoteControlConnectionStatus::Disabled) } + pub fn cancel_pending_pairing_requests(&self) { + self.pairing_request_cancellation_revision + .fetch_add(1, Ordering::Relaxed); + } + pub fn status(&self) -> RemoteControlStatusChangedNotification { self.status_tx.borrow().clone() } @@ -160,6 +174,22 @@ impl RemoteControlHandle { &self, params: RemoteControlPairingStartParams, ) -> io::Result { + let pairing_request = self.pairing_request_context()?; + let pairing_response = pairing_request + .pairing_client + .start(protocol::StartRemoteControlPairingRequest { + manual_code: params.manual_code, + }) + .await; + self.refresh_pairing_auth_after_rejection( + &pairing_request.pairing_client, + &pairing_response, + ); + self.validate_pairing_response(&pairing_request, &pairing_response)?; + pairing_response + } + + fn pairing_request_context(&self) -> io::Result { if !*self.enabled_tx.borrow() { return Err(io::Error::new( io::ErrorKind::InvalidInput, @@ -189,18 +219,49 @@ impl RemoteControlHandle { current_pairing_client } .ok_or_else(Self::pairing_unavailable_error)?; - let pairing_response = pairing_client - .start(protocol::StartRemoteControlPairingRequest { - manual_code: params.manual_code, - }) - .await; - if self.auth_change_revision() != auth_change_revision { + + Ok(PairingRequestContext { + auth_change_revision, + cancellation_revision: self + .pairing_request_cancellation_revision + .load(Ordering::Relaxed), + pairing_client, + }) + } + + fn refresh_pairing_auth_after_rejection( + &self, + pairing_client: &RemoteControlPairingClient, + pairing_response: &io::Result, + ) { + if pairing_response.as_ref().is_err_and(|err| { + matches!( + err.kind(), + io::ErrorKind::NotFound | io::ErrorKind::PermissionDenied + ) + }) && clear_pairing_client_if_current(&self.pairing, pairing_client) + { + self.request_pairing_auth_refresh(); + } + } + + fn validate_pairing_response( + &self, + pairing_request: &PairingRequestContext, + pairing_response: &io::Result, + ) -> io::Result<()> { + if self.auth_change_revision() != pairing_request.auth_change_revision { return Err(Self::pairing_unavailable_error()); } - if pairing_response.is_ok() && !self.pairing_client_is_current(&pairing_client) { + if pairing_response.is_ok() + && self + .pairing_request_cancellation_revision + .load(Ordering::Relaxed) + != pairing_request.cancellation_revision + { return Err(Self::pairing_unavailable_error()); } - pairing_response + Ok(()) } fn auth_change_revision(&self) -> u64 { @@ -218,21 +279,6 @@ impl RemoteControlHandle { ) } - fn pairing_client_is_current(&self, pairing_client: &RemoteControlPairingClient) -> bool { - *self.enabled_tx.borrow() - && self.status().status == RemoteControlConnectionStatus::Connected - && self - .pairing - .client - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .as_ref() - .is_some_and(|current_pairing_client| { - current_pairing_client.generation() == pairing_client.generation() - }) - } - - #[cfg(test)] fn request_pairing_auth_refresh(&self) { self.pairing_refresh_tx .send_modify(|revision| *revision = revision.wrapping_add(1)); @@ -293,6 +339,25 @@ fn clear_pairing_client(pairing: &PairingClientState) { pairing.generation.fetch_add(1, Ordering::Relaxed); } +fn clear_pairing_client_if_current( + pairing: &PairingClientState, + expected_pairing_client: &RemoteControlPairingClient, +) -> bool { + let mut pairing_client = pairing + .client + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + if pairing_client + .as_ref() + .is_none_or(|pairing_client| !pairing_client.matches_pairing_auth(expected_pairing_client)) + { + return false; + } + *pairing_client = None; + pairing.generation.fetch_add(1, Ordering::Relaxed); + true +} + pub async fn start_remote_control( config: RemoteControlStartConfig, state_db: Option>, @@ -317,6 +382,7 @@ pub async fn start_remote_control( let (enabled_tx, enabled_rx) = watch::channel(initial_enabled); let pairing = PairingClientState::new(); let websocket_pairing = pairing.clone(); + let pairing_request_cancellation_revision = Arc::new(AtomicU64::new(0)); let (pairing_refresh_tx, pairing_refresh_rx) = watch::channel(0u64); let websocket_pairing_refresh_tx = pairing_refresh_tx.clone(); let auth_change_rx = Arc::new(StdMutex::new(auth_manager.auth_change_receiver())); @@ -415,7 +481,7 @@ pub async fn start_remote_control( status_tx: Arc::new(status_tx), state_db_available, pairing, - #[cfg(test)] + pairing_request_cancellation_revision, pairing_refresh_tx: Arc::new(pairing_refresh_tx), auth_change_rx, }, diff --git a/codex-rs/app-server-transport/src/transport/remote_control/pairing.rs b/codex-rs/app-server-transport/src/transport/remote_control/pairing.rs index 4f3b1d2412..298479ef06 100644 --- a/codex-rs/app-server-transport/src/transport/remote_control/pairing.rs +++ b/codex-rs/app-server-transport/src/transport/remote_control/pairing.rs @@ -48,8 +48,13 @@ impl RemoteControlPairingClient { self.auth_change_revision == auth_change_revision } - pub(super) fn generation(&self) -> u64 { - self.generation + pub(super) fn matches_pairing_auth(&self, other: &Self) -> bool { + self.pairing_url == other.pairing_url + && self.remote_control_token == other.remote_control_token + && self.server_id == other.server_id + && self.environment_id == other.environment_id + && self.auth_change_revision == other.auth_change_revision + && self.generation == other.generation } pub(super) async fn start( @@ -86,11 +91,19 @@ impl RemoteControlPairingClient { })?; let body_preview = preview_remote_control_response_body(&body); if !status.is_success() { - return Err(io::Error::other(format!( - "remote control pairing failed at `{}`: HTTP {status}, {}, body: {body_preview}", - self.pairing_url, - format_headers(&headers) - ))); + let error_kind = match status.as_u16() { + 401 | 403 => ErrorKind::PermissionDenied, + 404 => ErrorKind::NotFound, + _ => ErrorKind::Other, + }; + return Err(io::Error::new( + error_kind, + format!( + "remote control pairing failed at `{}`: HTTP {status}, {}, body: {body_preview}", + self.pairing_url, + format_headers(&headers) + ), + )); } let pairing = serde_json::from_slice::(&body).map_err( diff --git a/codex-rs/app-server-transport/src/transport/remote_control/pairing_integration_tests.rs b/codex-rs/app-server-transport/src/transport/remote_control/pairing_integration_tests.rs index 532cabedb4..cb48796846 100644 --- a/codex-rs/app-server-transport/src/transport/remote_control/pairing_integration_tests.rs +++ b/codex-rs/app-server-transport/src/transport/remote_control/pairing_integration_tests.rs @@ -300,6 +300,162 @@ async fn remote_control_handle_keeps_pairing_response_after_pairing_auth_refresh ); } +#[tokio::test] +async fn remote_control_handle_keeps_pairing_response_after_connection_cycle_end() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let remote_handle = remote_control_handle_with_pairing_client( + &remote_control_url, + watch::channel(/*init*/ 0u64).1, + ); + let pairing_task = tokio::spawn({ + let remote_handle = remote_handle.clone(); + async move { + remote_handle + .start_pairing(RemoteControlPairingStartParams::default()) + .await + } + }); + + let pairing_request = accept_http_request(&listener).await; + clear_pairing_client(&remote_handle.pairing); + respond_with_json( + pairing_request.stream, + json!({ + "pairing_code": "fresh-pairing-code", + "manual_pairing_code": "ABCD-EFGH", + "server_id": "srv_e_test", + "environment_id": "env_test", + "expires_at": "3026-05-22T12:34:56Z", + }), + ) + .await; + + assert_eq!( + pairing_task + .await + .expect("pairing task should join") + .expect("pairing response should be kept") + .pairing_code, + "fresh-pairing-code" + ); +} + +#[tokio::test] +async fn remote_control_handle_discards_pairing_response_after_disable_and_reenable() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let remote_handle = remote_control_handle_with_pairing_client( + &remote_control_url, + watch::channel(/*init*/ 0u64).1, + ); + let pairing_task = tokio::spawn({ + let remote_handle = remote_handle.clone(); + async move { + remote_handle + .start_pairing(RemoteControlPairingStartParams::default()) + .await + } + }); + + let pairing_request = accept_http_request(&listener).await; + remote_handle.disable(); + remote_handle.enable().expect("enable should succeed"); + respond_with_json( + pairing_request.stream, + json!({ + "pairing_code": "stale-pairing-code", + "manual_pairing_code": "ABCD-EFGH", + "server_id": "srv_e_test", + "environment_id": "env_test", + "expires_at": "3026-05-22T12:34:56Z", + }), + ) + .await; + + assert_eq!( + pairing_task + .await + .expect("pairing task should join") + .expect_err("disabled remote control should discard pairing response") + .to_string(), + "remote control pairing is unavailable until enrollment completes" + ); +} + +#[tokio::test] +async fn remote_control_handle_keeps_refreshed_pairing_auth_after_stale_rejection() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let remote_handle = remote_control_handle_with_pairing_client( + &remote_control_url, + watch::channel(/*init*/ 0u64).1, + ); + let pairing_task = tokio::spawn({ + let remote_handle = remote_handle.clone(); + async move { + remote_handle + .start_pairing(RemoteControlPairingStartParams::default()) + .await + } + }); + + let pairing_request = accept_http_request(&listener).await; + let generation = remote_handle + .pairing + .generation + .load(std::sync::atomic::Ordering::Relaxed); + let refreshed_pairing_client = RemoteControlPairingClient::new( + &normalize_remote_control_url(&remote_control_url) + .expect("remote control url should normalize"), + TEST_REFRESHED_REMOTE_CONTROL_SERVER_TOKEN.to_string(), + "srv_e_test".to_string(), + "env_test".to_string(), + OffsetDateTime::parse( + TEST_REMOTE_CONTROL_SERVER_TOKEN_EXPIRES_AT, + &time::format_description::well_known::Rfc3339, + ) + .expect("server token expiry should parse"), + /*auth_change_revision*/ 0, + generation, + ); + *remote_handle + .pairing + .client + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) = + Some(refreshed_pairing_client.clone()); + respond_with_status(pairing_request.stream, "401 Unauthorized", "stale token").await; + + assert_eq!( + pairing_task + .await + .expect("pairing task should join") + .expect_err("stale pairing token should be rejected") + .kind(), + std::io::ErrorKind::PermissionDenied + ); + assert!( + remote_handle + .pairing + .client + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .as_ref() + .is_some_and(|pairing_client| { + pairing_client.matches_pairing_auth(&refreshed_pairing_client) + }), + "stale pairing rejection should keep refreshed pairing auth" + ); + assert_eq!(*remote_handle.pairing_refresh_tx.borrow(), 0); +} + #[tokio::test] async fn remote_control_handle_clears_pairing_client_after_auth_change() { let listener = TcpListener::bind("127.0.0.1:0") @@ -809,3 +965,153 @@ async fn remote_control_connected_refresh_404_reenrolls() { shutdown_token.cancel(); let _ = remote_task.await; } + +#[tokio::test] +async fn remote_control_pairing_rejection_refreshes_server_token_while_connected() { + remote_control_pairing_rejection_recovers_while_connected("401 Unauthorized").await; +} + +#[tokio::test] +async fn remote_control_pairing_not_found_refreshes_server_token_while_connected() { + remote_control_pairing_rejection_recovers_while_connected("404 Not Found").await; +} + +async fn remote_control_pairing_rejection_recovers_while_connected(pair_status: &str) { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let codex_home = TempDir::new().expect("temp dir should create"); + let (transport_event_tx, _transport_event_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let (remote_task, remote_handle) = start_remote_control( + RemoteControlStartConfig { + remote_control_url, + installation_id: TEST_INSTALLATION_ID.to_string(), + }, + Some(remote_control_state_runtime(&codex_home).await), + remote_control_auth_manager(), + transport_event_tx, + shutdown_token.clone(), + /*app_server_client_name_rx*/ None, + /*initial_enabled*/ true, + ) + .await + .expect("remote control should start"); + + let enroll_request = accept_http_request(&listener).await; + respond_with_json( + enroll_request.stream, + remote_control_server_token_response( + "srv_e_test", + "env_test", + TEST_REMOTE_CONTROL_SERVER_TOKEN, + ), + ) + .await; + let mut first_websocket = accept_remote_control_connection(&listener).await; + timeout(Duration::from_secs(5), async { + while remote_handle.status().status != RemoteControlConnectionStatus::Connected { + tokio::task::yield_now().await; + } + }) + .await + .expect("remote control should publish connected before pairing"); + + let stale_pairing_task = tokio::spawn({ + let remote_handle = remote_handle.clone(); + async move { + remote_handle + .start_pairing(RemoteControlPairingStartParams::default()) + .await + } + }); + let pairing_request = accept_http_request(&listener).await; + assert_eq!( + pairing_request.headers.get("authorization"), + Some(&format!("Bearer {TEST_REMOTE_CONTROL_SERVER_TOKEN}")) + ); + respond_with_status(pairing_request.stream, pair_status, "stale token").await; + assert_eq!( + stale_pairing_task + .await + .expect("stale pairing task should join") + .expect_err("stale pairing token should be rejected") + .kind(), + if pair_status == "404 Not Found" { + std::io::ErrorKind::NotFound + } else { + std::io::ErrorKind::PermissionDenied + } + ); + assert!( + remote_handle + .pairing + .client + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .is_none(), + "pairing rejection should clear cached pairing auth" + ); + assert_eq!(*remote_handle.pairing_refresh_tx.borrow(), 1); + + let refresh_request = accept_http_request(&listener).await; + assert_eq!( + refresh_request.request_line, + "POST /backend-api/wham/remote/control/server/refresh HTTP/1.1" + ); + respond_with_json( + refresh_request.stream, + remote_control_server_token_response( + "srv_e_test", + "env_test", + TEST_REFRESHED_REMOTE_CONTROL_SERVER_TOKEN, + ), + ) + .await; + assert!( + timeout(Duration::from_millis(100), first_websocket.next()) + .await + .is_err(), + "pairing auth refresh should keep the websocket open" + ); + + let refreshed_pairing_task = tokio::spawn({ + let remote_handle = remote_handle.clone(); + async move { + remote_handle + .start_pairing(RemoteControlPairingStartParams::default()) + .await + } + }); + let pairing_request = accept_http_request(&listener).await; + assert_eq!( + pairing_request.headers.get("authorization"), + Some(&format!( + "Bearer {TEST_REFRESHED_REMOTE_CONTROL_SERVER_TOKEN}" + )) + ); + respond_with_json( + pairing_request.stream, + json!({ + "pairing_code": "pairing-code", + "manual_pairing_code": "ABCD-EFGH", + "server_id": "srv_e_test", + "environment_id": "env_test", + "expires_at": "3026-05-22T12:34:56Z", + }), + ) + .await; + refreshed_pairing_task + .await + .expect("refreshed pairing task should join") + .expect("refreshed pairing token should pair"); + + first_websocket + .close(None) + .await + .expect("websocket should close"); + shutdown_token.cancel(); + let _ = remote_task.await; +} diff --git a/codex-rs/app-server-transport/src/transport/remote_control/tests.rs b/codex-rs/app-server-transport/src/transport/remote_control/tests.rs index 0e43e6d92e..024fc61fa0 100644 --- a/codex-rs/app-server-transport/src/transport/remote_control/tests.rs +++ b/codex-rs/app-server-transport/src/transport/remote_control/tests.rs @@ -163,6 +163,7 @@ fn remote_control_handle_with_pairing_client( client: pairing_client, generation: Arc::new(std::sync::atomic::AtomicU64::new(0)), }, + pairing_request_cancellation_revision: Arc::new(std::sync::atomic::AtomicU64::new(0)), pairing_refresh_tx: Arc::new(pairing_refresh_tx), auth_change_rx: Arc::new(StdMutex::new(auth_change_rx)), }