fix(app-server-transport): recover rejected pairing auth

This commit is contained in:
Anton Panasenko
2026-05-29 04:48:25 -07:00
parent 1f77ff722b
commit 7bf007ddd9
4 changed files with 418 additions and 32 deletions

View File

@@ -60,11 +60,19 @@ pub struct RemoteControlHandle {
status_tx: Arc<watch::Sender<RemoteControlStatusChangedNotification>>,
state_db_available: bool,
pairing: PairingClientState,
#[cfg(test)]
pairing_request_cancellation_revision: Arc<AtomicU64>,
pairing_refresh_tx: Arc<watch::Sender<u64>>,
auth_change_rx: Arc<StdMutex<watch::Receiver<u64>>>,
}
// `/server/pair` runs outside the websocket task, so keep the auth and disable
// revisions that make its eventual response safe to return to app-server.
struct PairingRequestContext {
auth_change_revision: u64,
cancellation_revision: u64,
pairing_client: RemoteControlPairingClient,
}
#[derive(Clone)]
pub(super) struct PairingClientState {
client: Arc<StdMutex<Option<RemoteControlPairingClient>>>,
@@ -135,6 +143,7 @@ impl RemoteControlHandle {
changed
});
clear_pairing_client(&self.pairing);
self.cancel_pending_pairing_requests();
let status = self.status();
info!(
@@ -148,6 +157,11 @@ impl RemoteControlHandle {
self.publish_status(RemoteControlConnectionStatus::Disabled)
}
pub fn cancel_pending_pairing_requests(&self) {
self.pairing_request_cancellation_revision
.fetch_add(1, Ordering::Relaxed);
}
pub fn status(&self) -> RemoteControlStatusChangedNotification {
self.status_tx.borrow().clone()
}
@@ -160,6 +174,22 @@ impl RemoteControlHandle {
&self,
params: RemoteControlPairingStartParams,
) -> io::Result<RemoteControlPairingStartResponse> {
let pairing_request = self.pairing_request_context()?;
let pairing_response = pairing_request
.pairing_client
.start(protocol::StartRemoteControlPairingRequest {
manual_code: params.manual_code,
})
.await;
self.refresh_pairing_auth_after_rejection(
&pairing_request.pairing_client,
&pairing_response,
);
self.validate_pairing_response(&pairing_request, &pairing_response)?;
pairing_response
}
fn pairing_request_context(&self) -> io::Result<PairingRequestContext> {
if !*self.enabled_tx.borrow() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
@@ -189,18 +219,49 @@ impl RemoteControlHandle {
current_pairing_client
}
.ok_or_else(Self::pairing_unavailable_error)?;
let pairing_response = pairing_client
.start(protocol::StartRemoteControlPairingRequest {
manual_code: params.manual_code,
})
.await;
if self.auth_change_revision() != auth_change_revision {
Ok(PairingRequestContext {
auth_change_revision,
cancellation_revision: self
.pairing_request_cancellation_revision
.load(Ordering::Relaxed),
pairing_client,
})
}
fn refresh_pairing_auth_after_rejection(
&self,
pairing_client: &RemoteControlPairingClient,
pairing_response: &io::Result<RemoteControlPairingStartResponse>,
) {
if pairing_response.as_ref().is_err_and(|err| {
matches!(
err.kind(),
io::ErrorKind::NotFound | io::ErrorKind::PermissionDenied
)
}) && clear_pairing_client_if_current(&self.pairing, pairing_client)
{
self.request_pairing_auth_refresh();
}
}
fn validate_pairing_response(
&self,
pairing_request: &PairingRequestContext,
pairing_response: &io::Result<RemoteControlPairingStartResponse>,
) -> io::Result<()> {
if self.auth_change_revision() != pairing_request.auth_change_revision {
return Err(Self::pairing_unavailable_error());
}
if pairing_response.is_ok() && !self.pairing_client_is_current(&pairing_client) {
if pairing_response.is_ok()
&& self
.pairing_request_cancellation_revision
.load(Ordering::Relaxed)
!= pairing_request.cancellation_revision
{
return Err(Self::pairing_unavailable_error());
}
pairing_response
Ok(())
}
fn auth_change_revision(&self) -> u64 {
@@ -218,21 +279,6 @@ impl RemoteControlHandle {
)
}
fn pairing_client_is_current(&self, pairing_client: &RemoteControlPairingClient) -> bool {
*self.enabled_tx.borrow()
&& self.status().status == RemoteControlConnectionStatus::Connected
&& self
.pairing
.client
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.as_ref()
.is_some_and(|current_pairing_client| {
current_pairing_client.generation() == pairing_client.generation()
})
}
#[cfg(test)]
fn request_pairing_auth_refresh(&self) {
self.pairing_refresh_tx
.send_modify(|revision| *revision = revision.wrapping_add(1));
@@ -293,6 +339,25 @@ fn clear_pairing_client(pairing: &PairingClientState) {
pairing.generation.fetch_add(1, Ordering::Relaxed);
}
fn clear_pairing_client_if_current(
pairing: &PairingClientState,
expected_pairing_client: &RemoteControlPairingClient,
) -> bool {
let mut pairing_client = pairing
.client
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
if pairing_client
.as_ref()
.is_none_or(|pairing_client| !pairing_client.matches_pairing_auth(expected_pairing_client))
{
return false;
}
*pairing_client = None;
pairing.generation.fetch_add(1, Ordering::Relaxed);
true
}
pub async fn start_remote_control(
config: RemoteControlStartConfig,
state_db: Option<Arc<StateRuntime>>,
@@ -317,6 +382,7 @@ pub async fn start_remote_control(
let (enabled_tx, enabled_rx) = watch::channel(initial_enabled);
let pairing = PairingClientState::new();
let websocket_pairing = pairing.clone();
let pairing_request_cancellation_revision = Arc::new(AtomicU64::new(0));
let (pairing_refresh_tx, pairing_refresh_rx) = watch::channel(0u64);
let websocket_pairing_refresh_tx = pairing_refresh_tx.clone();
let auth_change_rx = Arc::new(StdMutex::new(auth_manager.auth_change_receiver()));
@@ -415,7 +481,7 @@ pub async fn start_remote_control(
status_tx: Arc::new(status_tx),
state_db_available,
pairing,
#[cfg(test)]
pairing_request_cancellation_revision,
pairing_refresh_tx: Arc::new(pairing_refresh_tx),
auth_change_rx,
},

View File

@@ -48,8 +48,13 @@ impl RemoteControlPairingClient {
self.auth_change_revision == auth_change_revision
}
pub(super) fn generation(&self) -> u64 {
self.generation
pub(super) fn matches_pairing_auth(&self, other: &Self) -> bool {
self.pairing_url == other.pairing_url
&& self.remote_control_token == other.remote_control_token
&& self.server_id == other.server_id
&& self.environment_id == other.environment_id
&& self.auth_change_revision == other.auth_change_revision
&& self.generation == other.generation
}
pub(super) async fn start(
@@ -86,11 +91,19 @@ impl RemoteControlPairingClient {
})?;
let body_preview = preview_remote_control_response_body(&body);
if !status.is_success() {
return Err(io::Error::other(format!(
"remote control pairing failed at `{}`: HTTP {status}, {}, body: {body_preview}",
self.pairing_url,
format_headers(&headers)
)));
let error_kind = match status.as_u16() {
401 | 403 => ErrorKind::PermissionDenied,
404 => ErrorKind::NotFound,
_ => ErrorKind::Other,
};
return Err(io::Error::new(
error_kind,
format!(
"remote control pairing failed at `{}`: HTTP {status}, {}, body: {body_preview}",
self.pairing_url,
format_headers(&headers)
),
));
}
let pairing = serde_json::from_slice::<StartRemoteControlPairingResponse>(&body).map_err(

View File

@@ -300,6 +300,162 @@ async fn remote_control_handle_keeps_pairing_response_after_pairing_auth_refresh
);
}
#[tokio::test]
async fn remote_control_handle_keeps_pairing_response_after_connection_cycle_end() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let remote_control_url = remote_control_url_for_listener(&listener);
let remote_handle = remote_control_handle_with_pairing_client(
&remote_control_url,
watch::channel(/*init*/ 0u64).1,
);
let pairing_task = tokio::spawn({
let remote_handle = remote_handle.clone();
async move {
remote_handle
.start_pairing(RemoteControlPairingStartParams::default())
.await
}
});
let pairing_request = accept_http_request(&listener).await;
clear_pairing_client(&remote_handle.pairing);
respond_with_json(
pairing_request.stream,
json!({
"pairing_code": "fresh-pairing-code",
"manual_pairing_code": "ABCD-EFGH",
"server_id": "srv_e_test",
"environment_id": "env_test",
"expires_at": "3026-05-22T12:34:56Z",
}),
)
.await;
assert_eq!(
pairing_task
.await
.expect("pairing task should join")
.expect("pairing response should be kept")
.pairing_code,
"fresh-pairing-code"
);
}
#[tokio::test]
async fn remote_control_handle_discards_pairing_response_after_disable_and_reenable() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let remote_control_url = remote_control_url_for_listener(&listener);
let remote_handle = remote_control_handle_with_pairing_client(
&remote_control_url,
watch::channel(/*init*/ 0u64).1,
);
let pairing_task = tokio::spawn({
let remote_handle = remote_handle.clone();
async move {
remote_handle
.start_pairing(RemoteControlPairingStartParams::default())
.await
}
});
let pairing_request = accept_http_request(&listener).await;
remote_handle.disable();
remote_handle.enable().expect("enable should succeed");
respond_with_json(
pairing_request.stream,
json!({
"pairing_code": "stale-pairing-code",
"manual_pairing_code": "ABCD-EFGH",
"server_id": "srv_e_test",
"environment_id": "env_test",
"expires_at": "3026-05-22T12:34:56Z",
}),
)
.await;
assert_eq!(
pairing_task
.await
.expect("pairing task should join")
.expect_err("disabled remote control should discard pairing response")
.to_string(),
"remote control pairing is unavailable until enrollment completes"
);
}
#[tokio::test]
async fn remote_control_handle_keeps_refreshed_pairing_auth_after_stale_rejection() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let remote_control_url = remote_control_url_for_listener(&listener);
let remote_handle = remote_control_handle_with_pairing_client(
&remote_control_url,
watch::channel(/*init*/ 0u64).1,
);
let pairing_task = tokio::spawn({
let remote_handle = remote_handle.clone();
async move {
remote_handle
.start_pairing(RemoteControlPairingStartParams::default())
.await
}
});
let pairing_request = accept_http_request(&listener).await;
let generation = remote_handle
.pairing
.generation
.load(std::sync::atomic::Ordering::Relaxed);
let refreshed_pairing_client = RemoteControlPairingClient::new(
&normalize_remote_control_url(&remote_control_url)
.expect("remote control url should normalize"),
TEST_REFRESHED_REMOTE_CONTROL_SERVER_TOKEN.to_string(),
"srv_e_test".to_string(),
"env_test".to_string(),
OffsetDateTime::parse(
TEST_REMOTE_CONTROL_SERVER_TOKEN_EXPIRES_AT,
&time::format_description::well_known::Rfc3339,
)
.expect("server token expiry should parse"),
/*auth_change_revision*/ 0,
generation,
);
*remote_handle
.pairing
.client
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner) =
Some(refreshed_pairing_client.clone());
respond_with_status(pairing_request.stream, "401 Unauthorized", "stale token").await;
assert_eq!(
pairing_task
.await
.expect("pairing task should join")
.expect_err("stale pairing token should be rejected")
.kind(),
std::io::ErrorKind::PermissionDenied
);
assert!(
remote_handle
.pairing
.client
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.as_ref()
.is_some_and(|pairing_client| {
pairing_client.matches_pairing_auth(&refreshed_pairing_client)
}),
"stale pairing rejection should keep refreshed pairing auth"
);
assert_eq!(*remote_handle.pairing_refresh_tx.borrow(), 0);
}
#[tokio::test]
async fn remote_control_handle_clears_pairing_client_after_auth_change() {
let listener = TcpListener::bind("127.0.0.1:0")
@@ -809,3 +965,153 @@ async fn remote_control_connected_refresh_404_reenrolls() {
shutdown_token.cancel();
let _ = remote_task.await;
}
#[tokio::test]
async fn remote_control_pairing_rejection_refreshes_server_token_while_connected() {
remote_control_pairing_rejection_recovers_while_connected("401 Unauthorized").await;
}
#[tokio::test]
async fn remote_control_pairing_not_found_refreshes_server_token_while_connected() {
remote_control_pairing_rejection_recovers_while_connected("404 Not Found").await;
}
async fn remote_control_pairing_rejection_recovers_while_connected(pair_status: &str) {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let remote_control_url = remote_control_url_for_listener(&listener);
let codex_home = TempDir::new().expect("temp dir should create");
let (transport_event_tx, _transport_event_rx) =
mpsc::channel::<TransportEvent>(CHANNEL_CAPACITY);
let shutdown_token = CancellationToken::new();
let (remote_task, remote_handle) = start_remote_control(
RemoteControlStartConfig {
remote_control_url,
installation_id: TEST_INSTALLATION_ID.to_string(),
},
Some(remote_control_state_runtime(&codex_home).await),
remote_control_auth_manager(),
transport_event_tx,
shutdown_token.clone(),
/*app_server_client_name_rx*/ None,
/*initial_enabled*/ true,
)
.await
.expect("remote control should start");
let enroll_request = accept_http_request(&listener).await;
respond_with_json(
enroll_request.stream,
remote_control_server_token_response(
"srv_e_test",
"env_test",
TEST_REMOTE_CONTROL_SERVER_TOKEN,
),
)
.await;
let mut first_websocket = accept_remote_control_connection(&listener).await;
timeout(Duration::from_secs(5), async {
while remote_handle.status().status != RemoteControlConnectionStatus::Connected {
tokio::task::yield_now().await;
}
})
.await
.expect("remote control should publish connected before pairing");
let stale_pairing_task = tokio::spawn({
let remote_handle = remote_handle.clone();
async move {
remote_handle
.start_pairing(RemoteControlPairingStartParams::default())
.await
}
});
let pairing_request = accept_http_request(&listener).await;
assert_eq!(
pairing_request.headers.get("authorization"),
Some(&format!("Bearer {TEST_REMOTE_CONTROL_SERVER_TOKEN}"))
);
respond_with_status(pairing_request.stream, pair_status, "stale token").await;
assert_eq!(
stale_pairing_task
.await
.expect("stale pairing task should join")
.expect_err("stale pairing token should be rejected")
.kind(),
if pair_status == "404 Not Found" {
std::io::ErrorKind::NotFound
} else {
std::io::ErrorKind::PermissionDenied
}
);
assert!(
remote_handle
.pairing
.client
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.is_none(),
"pairing rejection should clear cached pairing auth"
);
assert_eq!(*remote_handle.pairing_refresh_tx.borrow(), 1);
let refresh_request = accept_http_request(&listener).await;
assert_eq!(
refresh_request.request_line,
"POST /backend-api/wham/remote/control/server/refresh HTTP/1.1"
);
respond_with_json(
refresh_request.stream,
remote_control_server_token_response(
"srv_e_test",
"env_test",
TEST_REFRESHED_REMOTE_CONTROL_SERVER_TOKEN,
),
)
.await;
assert!(
timeout(Duration::from_millis(100), first_websocket.next())
.await
.is_err(),
"pairing auth refresh should keep the websocket open"
);
let refreshed_pairing_task = tokio::spawn({
let remote_handle = remote_handle.clone();
async move {
remote_handle
.start_pairing(RemoteControlPairingStartParams::default())
.await
}
});
let pairing_request = accept_http_request(&listener).await;
assert_eq!(
pairing_request.headers.get("authorization"),
Some(&format!(
"Bearer {TEST_REFRESHED_REMOTE_CONTROL_SERVER_TOKEN}"
))
);
respond_with_json(
pairing_request.stream,
json!({
"pairing_code": "pairing-code",
"manual_pairing_code": "ABCD-EFGH",
"server_id": "srv_e_test",
"environment_id": "env_test",
"expires_at": "3026-05-22T12:34:56Z",
}),
)
.await;
refreshed_pairing_task
.await
.expect("refreshed pairing task should join")
.expect("refreshed pairing token should pair");
first_websocket
.close(None)
.await
.expect("websocket should close");
shutdown_token.cancel();
let _ = remote_task.await;
}

View File

@@ -163,6 +163,7 @@ fn remote_control_handle_with_pairing_client(
client: pairing_client,
generation: Arc::new(std::sync::atomic::AtomicU64::new(0)),
},
pairing_request_cancellation_revision: Arc::new(std::sync::atomic::AtomicU64::new(0)),
pairing_refresh_tx: Arc::new(pairing_refresh_tx),
auth_change_rx: Arc::new(StdMutex::new(auth_change_rx)),
}