mirror of
https://github.com/openai/codex.git
synced 2026-06-02 11:22:01 +00:00
fix(app-server-transport): recover rejected pairing auth
This commit is contained in:
@@ -60,11 +60,19 @@ pub struct RemoteControlHandle {
|
||||
status_tx: Arc<watch::Sender<RemoteControlStatusChangedNotification>>,
|
||||
state_db_available: bool,
|
||||
pairing: PairingClientState,
|
||||
#[cfg(test)]
|
||||
pairing_request_cancellation_revision: Arc<AtomicU64>,
|
||||
pairing_refresh_tx: Arc<watch::Sender<u64>>,
|
||||
auth_change_rx: Arc<StdMutex<watch::Receiver<u64>>>,
|
||||
}
|
||||
|
||||
// `/server/pair` runs outside the websocket task, so keep the auth and disable
|
||||
// revisions that make its eventual response safe to return to app-server.
|
||||
struct PairingRequestContext {
|
||||
auth_change_revision: u64,
|
||||
cancellation_revision: u64,
|
||||
pairing_client: RemoteControlPairingClient,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(super) struct PairingClientState {
|
||||
client: Arc<StdMutex<Option<RemoteControlPairingClient>>>,
|
||||
@@ -135,6 +143,7 @@ impl RemoteControlHandle {
|
||||
changed
|
||||
});
|
||||
clear_pairing_client(&self.pairing);
|
||||
self.cancel_pending_pairing_requests();
|
||||
|
||||
let status = self.status();
|
||||
info!(
|
||||
@@ -148,6 +157,11 @@ impl RemoteControlHandle {
|
||||
self.publish_status(RemoteControlConnectionStatus::Disabled)
|
||||
}
|
||||
|
||||
pub fn cancel_pending_pairing_requests(&self) {
|
||||
self.pairing_request_cancellation_revision
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn status(&self) -> RemoteControlStatusChangedNotification {
|
||||
self.status_tx.borrow().clone()
|
||||
}
|
||||
@@ -160,6 +174,22 @@ impl RemoteControlHandle {
|
||||
&self,
|
||||
params: RemoteControlPairingStartParams,
|
||||
) -> io::Result<RemoteControlPairingStartResponse> {
|
||||
let pairing_request = self.pairing_request_context()?;
|
||||
let pairing_response = pairing_request
|
||||
.pairing_client
|
||||
.start(protocol::StartRemoteControlPairingRequest {
|
||||
manual_code: params.manual_code,
|
||||
})
|
||||
.await;
|
||||
self.refresh_pairing_auth_after_rejection(
|
||||
&pairing_request.pairing_client,
|
||||
&pairing_response,
|
||||
);
|
||||
self.validate_pairing_response(&pairing_request, &pairing_response)?;
|
||||
pairing_response
|
||||
}
|
||||
|
||||
fn pairing_request_context(&self) -> io::Result<PairingRequestContext> {
|
||||
if !*self.enabled_tx.borrow() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
@@ -189,18 +219,49 @@ impl RemoteControlHandle {
|
||||
current_pairing_client
|
||||
}
|
||||
.ok_or_else(Self::pairing_unavailable_error)?;
|
||||
let pairing_response = pairing_client
|
||||
.start(protocol::StartRemoteControlPairingRequest {
|
||||
manual_code: params.manual_code,
|
||||
})
|
||||
.await;
|
||||
if self.auth_change_revision() != auth_change_revision {
|
||||
|
||||
Ok(PairingRequestContext {
|
||||
auth_change_revision,
|
||||
cancellation_revision: self
|
||||
.pairing_request_cancellation_revision
|
||||
.load(Ordering::Relaxed),
|
||||
pairing_client,
|
||||
})
|
||||
}
|
||||
|
||||
fn refresh_pairing_auth_after_rejection(
|
||||
&self,
|
||||
pairing_client: &RemoteControlPairingClient,
|
||||
pairing_response: &io::Result<RemoteControlPairingStartResponse>,
|
||||
) {
|
||||
if pairing_response.as_ref().is_err_and(|err| {
|
||||
matches!(
|
||||
err.kind(),
|
||||
io::ErrorKind::NotFound | io::ErrorKind::PermissionDenied
|
||||
)
|
||||
}) && clear_pairing_client_if_current(&self.pairing, pairing_client)
|
||||
{
|
||||
self.request_pairing_auth_refresh();
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_pairing_response(
|
||||
&self,
|
||||
pairing_request: &PairingRequestContext,
|
||||
pairing_response: &io::Result<RemoteControlPairingStartResponse>,
|
||||
) -> io::Result<()> {
|
||||
if self.auth_change_revision() != pairing_request.auth_change_revision {
|
||||
return Err(Self::pairing_unavailable_error());
|
||||
}
|
||||
if pairing_response.is_ok() && !self.pairing_client_is_current(&pairing_client) {
|
||||
if pairing_response.is_ok()
|
||||
&& self
|
||||
.pairing_request_cancellation_revision
|
||||
.load(Ordering::Relaxed)
|
||||
!= pairing_request.cancellation_revision
|
||||
{
|
||||
return Err(Self::pairing_unavailable_error());
|
||||
}
|
||||
pairing_response
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn auth_change_revision(&self) -> u64 {
|
||||
@@ -218,21 +279,6 @@ impl RemoteControlHandle {
|
||||
)
|
||||
}
|
||||
|
||||
fn pairing_client_is_current(&self, pairing_client: &RemoteControlPairingClient) -> bool {
|
||||
*self.enabled_tx.borrow()
|
||||
&& self.status().status == RemoteControlConnectionStatus::Connected
|
||||
&& self
|
||||
.pairing
|
||||
.client
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.as_ref()
|
||||
.is_some_and(|current_pairing_client| {
|
||||
current_pairing_client.generation() == pairing_client.generation()
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn request_pairing_auth_refresh(&self) {
|
||||
self.pairing_refresh_tx
|
||||
.send_modify(|revision| *revision = revision.wrapping_add(1));
|
||||
@@ -293,6 +339,25 @@ fn clear_pairing_client(pairing: &PairingClientState) {
|
||||
pairing.generation.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn clear_pairing_client_if_current(
|
||||
pairing: &PairingClientState,
|
||||
expected_pairing_client: &RemoteControlPairingClient,
|
||||
) -> bool {
|
||||
let mut pairing_client = pairing
|
||||
.client
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
if pairing_client
|
||||
.as_ref()
|
||||
.is_none_or(|pairing_client| !pairing_client.matches_pairing_auth(expected_pairing_client))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
*pairing_client = None;
|
||||
pairing.generation.fetch_add(1, Ordering::Relaxed);
|
||||
true
|
||||
}
|
||||
|
||||
pub async fn start_remote_control(
|
||||
config: RemoteControlStartConfig,
|
||||
state_db: Option<Arc<StateRuntime>>,
|
||||
@@ -317,6 +382,7 @@ pub async fn start_remote_control(
|
||||
let (enabled_tx, enabled_rx) = watch::channel(initial_enabled);
|
||||
let pairing = PairingClientState::new();
|
||||
let websocket_pairing = pairing.clone();
|
||||
let pairing_request_cancellation_revision = Arc::new(AtomicU64::new(0));
|
||||
let (pairing_refresh_tx, pairing_refresh_rx) = watch::channel(0u64);
|
||||
let websocket_pairing_refresh_tx = pairing_refresh_tx.clone();
|
||||
let auth_change_rx = Arc::new(StdMutex::new(auth_manager.auth_change_receiver()));
|
||||
@@ -415,7 +481,7 @@ pub async fn start_remote_control(
|
||||
status_tx: Arc::new(status_tx),
|
||||
state_db_available,
|
||||
pairing,
|
||||
#[cfg(test)]
|
||||
pairing_request_cancellation_revision,
|
||||
pairing_refresh_tx: Arc::new(pairing_refresh_tx),
|
||||
auth_change_rx,
|
||||
},
|
||||
|
||||
@@ -48,8 +48,13 @@ impl RemoteControlPairingClient {
|
||||
self.auth_change_revision == auth_change_revision
|
||||
}
|
||||
|
||||
pub(super) fn generation(&self) -> u64 {
|
||||
self.generation
|
||||
pub(super) fn matches_pairing_auth(&self, other: &Self) -> bool {
|
||||
self.pairing_url == other.pairing_url
|
||||
&& self.remote_control_token == other.remote_control_token
|
||||
&& self.server_id == other.server_id
|
||||
&& self.environment_id == other.environment_id
|
||||
&& self.auth_change_revision == other.auth_change_revision
|
||||
&& self.generation == other.generation
|
||||
}
|
||||
|
||||
pub(super) async fn start(
|
||||
@@ -86,11 +91,19 @@ impl RemoteControlPairingClient {
|
||||
})?;
|
||||
let body_preview = preview_remote_control_response_body(&body);
|
||||
if !status.is_success() {
|
||||
return Err(io::Error::other(format!(
|
||||
"remote control pairing failed at `{}`: HTTP {status}, {}, body: {body_preview}",
|
||||
self.pairing_url,
|
||||
format_headers(&headers)
|
||||
)));
|
||||
let error_kind = match status.as_u16() {
|
||||
401 | 403 => ErrorKind::PermissionDenied,
|
||||
404 => ErrorKind::NotFound,
|
||||
_ => ErrorKind::Other,
|
||||
};
|
||||
return Err(io::Error::new(
|
||||
error_kind,
|
||||
format!(
|
||||
"remote control pairing failed at `{}`: HTTP {status}, {}, body: {body_preview}",
|
||||
self.pairing_url,
|
||||
format_headers(&headers)
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let pairing = serde_json::from_slice::<StartRemoteControlPairingResponse>(&body).map_err(
|
||||
|
||||
@@ -300,6 +300,162 @@ async fn remote_control_handle_keeps_pairing_response_after_pairing_auth_refresh
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remote_control_handle_keeps_pairing_response_after_connection_cycle_end() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("listener should bind");
|
||||
let remote_control_url = remote_control_url_for_listener(&listener);
|
||||
let remote_handle = remote_control_handle_with_pairing_client(
|
||||
&remote_control_url,
|
||||
watch::channel(/*init*/ 0u64).1,
|
||||
);
|
||||
let pairing_task = tokio::spawn({
|
||||
let remote_handle = remote_handle.clone();
|
||||
async move {
|
||||
remote_handle
|
||||
.start_pairing(RemoteControlPairingStartParams::default())
|
||||
.await
|
||||
}
|
||||
});
|
||||
|
||||
let pairing_request = accept_http_request(&listener).await;
|
||||
clear_pairing_client(&remote_handle.pairing);
|
||||
respond_with_json(
|
||||
pairing_request.stream,
|
||||
json!({
|
||||
"pairing_code": "fresh-pairing-code",
|
||||
"manual_pairing_code": "ABCD-EFGH",
|
||||
"server_id": "srv_e_test",
|
||||
"environment_id": "env_test",
|
||||
"expires_at": "3026-05-22T12:34:56Z",
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
pairing_task
|
||||
.await
|
||||
.expect("pairing task should join")
|
||||
.expect("pairing response should be kept")
|
||||
.pairing_code,
|
||||
"fresh-pairing-code"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remote_control_handle_discards_pairing_response_after_disable_and_reenable() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("listener should bind");
|
||||
let remote_control_url = remote_control_url_for_listener(&listener);
|
||||
let remote_handle = remote_control_handle_with_pairing_client(
|
||||
&remote_control_url,
|
||||
watch::channel(/*init*/ 0u64).1,
|
||||
);
|
||||
let pairing_task = tokio::spawn({
|
||||
let remote_handle = remote_handle.clone();
|
||||
async move {
|
||||
remote_handle
|
||||
.start_pairing(RemoteControlPairingStartParams::default())
|
||||
.await
|
||||
}
|
||||
});
|
||||
|
||||
let pairing_request = accept_http_request(&listener).await;
|
||||
remote_handle.disable();
|
||||
remote_handle.enable().expect("enable should succeed");
|
||||
respond_with_json(
|
||||
pairing_request.stream,
|
||||
json!({
|
||||
"pairing_code": "stale-pairing-code",
|
||||
"manual_pairing_code": "ABCD-EFGH",
|
||||
"server_id": "srv_e_test",
|
||||
"environment_id": "env_test",
|
||||
"expires_at": "3026-05-22T12:34:56Z",
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
pairing_task
|
||||
.await
|
||||
.expect("pairing task should join")
|
||||
.expect_err("disabled remote control should discard pairing response")
|
||||
.to_string(),
|
||||
"remote control pairing is unavailable until enrollment completes"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remote_control_handle_keeps_refreshed_pairing_auth_after_stale_rejection() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("listener should bind");
|
||||
let remote_control_url = remote_control_url_for_listener(&listener);
|
||||
let remote_handle = remote_control_handle_with_pairing_client(
|
||||
&remote_control_url,
|
||||
watch::channel(/*init*/ 0u64).1,
|
||||
);
|
||||
let pairing_task = tokio::spawn({
|
||||
let remote_handle = remote_handle.clone();
|
||||
async move {
|
||||
remote_handle
|
||||
.start_pairing(RemoteControlPairingStartParams::default())
|
||||
.await
|
||||
}
|
||||
});
|
||||
|
||||
let pairing_request = accept_http_request(&listener).await;
|
||||
let generation = remote_handle
|
||||
.pairing
|
||||
.generation
|
||||
.load(std::sync::atomic::Ordering::Relaxed);
|
||||
let refreshed_pairing_client = RemoteControlPairingClient::new(
|
||||
&normalize_remote_control_url(&remote_control_url)
|
||||
.expect("remote control url should normalize"),
|
||||
TEST_REFRESHED_REMOTE_CONTROL_SERVER_TOKEN.to_string(),
|
||||
"srv_e_test".to_string(),
|
||||
"env_test".to_string(),
|
||||
OffsetDateTime::parse(
|
||||
TEST_REMOTE_CONTROL_SERVER_TOKEN_EXPIRES_AT,
|
||||
&time::format_description::well_known::Rfc3339,
|
||||
)
|
||||
.expect("server token expiry should parse"),
|
||||
/*auth_change_revision*/ 0,
|
||||
generation,
|
||||
);
|
||||
*remote_handle
|
||||
.pairing
|
||||
.client
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner) =
|
||||
Some(refreshed_pairing_client.clone());
|
||||
respond_with_status(pairing_request.stream, "401 Unauthorized", "stale token").await;
|
||||
|
||||
assert_eq!(
|
||||
pairing_task
|
||||
.await
|
||||
.expect("pairing task should join")
|
||||
.expect_err("stale pairing token should be rejected")
|
||||
.kind(),
|
||||
std::io::ErrorKind::PermissionDenied
|
||||
);
|
||||
assert!(
|
||||
remote_handle
|
||||
.pairing
|
||||
.client
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.as_ref()
|
||||
.is_some_and(|pairing_client| {
|
||||
pairing_client.matches_pairing_auth(&refreshed_pairing_client)
|
||||
}),
|
||||
"stale pairing rejection should keep refreshed pairing auth"
|
||||
);
|
||||
assert_eq!(*remote_handle.pairing_refresh_tx.borrow(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remote_control_handle_clears_pairing_client_after_auth_change() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
@@ -809,3 +965,153 @@ async fn remote_control_connected_refresh_404_reenrolls() {
|
||||
shutdown_token.cancel();
|
||||
let _ = remote_task.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remote_control_pairing_rejection_refreshes_server_token_while_connected() {
|
||||
remote_control_pairing_rejection_recovers_while_connected("401 Unauthorized").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remote_control_pairing_not_found_refreshes_server_token_while_connected() {
|
||||
remote_control_pairing_rejection_recovers_while_connected("404 Not Found").await;
|
||||
}
|
||||
|
||||
async fn remote_control_pairing_rejection_recovers_while_connected(pair_status: &str) {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("listener should bind");
|
||||
let remote_control_url = remote_control_url_for_listener(&listener);
|
||||
let codex_home = TempDir::new().expect("temp dir should create");
|
||||
let (transport_event_tx, _transport_event_rx) =
|
||||
mpsc::channel::<TransportEvent>(CHANNEL_CAPACITY);
|
||||
let shutdown_token = CancellationToken::new();
|
||||
let (remote_task, remote_handle) = start_remote_control(
|
||||
RemoteControlStartConfig {
|
||||
remote_control_url,
|
||||
installation_id: TEST_INSTALLATION_ID.to_string(),
|
||||
},
|
||||
Some(remote_control_state_runtime(&codex_home).await),
|
||||
remote_control_auth_manager(),
|
||||
transport_event_tx,
|
||||
shutdown_token.clone(),
|
||||
/*app_server_client_name_rx*/ None,
|
||||
/*initial_enabled*/ true,
|
||||
)
|
||||
.await
|
||||
.expect("remote control should start");
|
||||
|
||||
let enroll_request = accept_http_request(&listener).await;
|
||||
respond_with_json(
|
||||
enroll_request.stream,
|
||||
remote_control_server_token_response(
|
||||
"srv_e_test",
|
||||
"env_test",
|
||||
TEST_REMOTE_CONTROL_SERVER_TOKEN,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
let mut first_websocket = accept_remote_control_connection(&listener).await;
|
||||
timeout(Duration::from_secs(5), async {
|
||||
while remote_handle.status().status != RemoteControlConnectionStatus::Connected {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("remote control should publish connected before pairing");
|
||||
|
||||
let stale_pairing_task = tokio::spawn({
|
||||
let remote_handle = remote_handle.clone();
|
||||
async move {
|
||||
remote_handle
|
||||
.start_pairing(RemoteControlPairingStartParams::default())
|
||||
.await
|
||||
}
|
||||
});
|
||||
let pairing_request = accept_http_request(&listener).await;
|
||||
assert_eq!(
|
||||
pairing_request.headers.get("authorization"),
|
||||
Some(&format!("Bearer {TEST_REMOTE_CONTROL_SERVER_TOKEN}"))
|
||||
);
|
||||
respond_with_status(pairing_request.stream, pair_status, "stale token").await;
|
||||
assert_eq!(
|
||||
stale_pairing_task
|
||||
.await
|
||||
.expect("stale pairing task should join")
|
||||
.expect_err("stale pairing token should be rejected")
|
||||
.kind(),
|
||||
if pair_status == "404 Not Found" {
|
||||
std::io::ErrorKind::NotFound
|
||||
} else {
|
||||
std::io::ErrorKind::PermissionDenied
|
||||
}
|
||||
);
|
||||
assert!(
|
||||
remote_handle
|
||||
.pairing
|
||||
.client
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.is_none(),
|
||||
"pairing rejection should clear cached pairing auth"
|
||||
);
|
||||
assert_eq!(*remote_handle.pairing_refresh_tx.borrow(), 1);
|
||||
|
||||
let refresh_request = accept_http_request(&listener).await;
|
||||
assert_eq!(
|
||||
refresh_request.request_line,
|
||||
"POST /backend-api/wham/remote/control/server/refresh HTTP/1.1"
|
||||
);
|
||||
respond_with_json(
|
||||
refresh_request.stream,
|
||||
remote_control_server_token_response(
|
||||
"srv_e_test",
|
||||
"env_test",
|
||||
TEST_REFRESHED_REMOTE_CONTROL_SERVER_TOKEN,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
assert!(
|
||||
timeout(Duration::from_millis(100), first_websocket.next())
|
||||
.await
|
||||
.is_err(),
|
||||
"pairing auth refresh should keep the websocket open"
|
||||
);
|
||||
|
||||
let refreshed_pairing_task = tokio::spawn({
|
||||
let remote_handle = remote_handle.clone();
|
||||
async move {
|
||||
remote_handle
|
||||
.start_pairing(RemoteControlPairingStartParams::default())
|
||||
.await
|
||||
}
|
||||
});
|
||||
let pairing_request = accept_http_request(&listener).await;
|
||||
assert_eq!(
|
||||
pairing_request.headers.get("authorization"),
|
||||
Some(&format!(
|
||||
"Bearer {TEST_REFRESHED_REMOTE_CONTROL_SERVER_TOKEN}"
|
||||
))
|
||||
);
|
||||
respond_with_json(
|
||||
pairing_request.stream,
|
||||
json!({
|
||||
"pairing_code": "pairing-code",
|
||||
"manual_pairing_code": "ABCD-EFGH",
|
||||
"server_id": "srv_e_test",
|
||||
"environment_id": "env_test",
|
||||
"expires_at": "3026-05-22T12:34:56Z",
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
refreshed_pairing_task
|
||||
.await
|
||||
.expect("refreshed pairing task should join")
|
||||
.expect("refreshed pairing token should pair");
|
||||
|
||||
first_websocket
|
||||
.close(None)
|
||||
.await
|
||||
.expect("websocket should close");
|
||||
shutdown_token.cancel();
|
||||
let _ = remote_task.await;
|
||||
}
|
||||
|
||||
@@ -163,6 +163,7 @@ fn remote_control_handle_with_pairing_client(
|
||||
client: pairing_client,
|
||||
generation: Arc::new(std::sync::atomic::AtomicU64::new(0)),
|
||||
},
|
||||
pairing_request_cancellation_revision: Arc::new(std::sync::atomic::AtomicU64::new(0)),
|
||||
pairing_refresh_tx: Arc::new(pairing_refresh_tx),
|
||||
auth_change_rx: Arc::new(StdMutex::new(auth_change_rx)),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user