mirror of
https://github.com/openai/codex.git
synced 2026-06-04 04:12:03 +00:00
Compare commits
10 Commits
codemode_i
...
codex/remo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7bf007ddd9 | ||
|
|
1f77ff722b | ||
|
|
df31943eb8 | ||
|
|
974cee60db | ||
|
|
5b7077c8bb | ||
|
|
53faa76492 | ||
|
|
84529b9f5f | ||
|
|
413ae91e53 | ||
|
|
19dd22aebf | ||
|
|
f73c101e1e |
@@ -44,6 +44,24 @@ pub struct RemoteControlStatusReadResponse {
|
||||
pub environment_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq, Eq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct RemoteControlPairingStartParams {
|
||||
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
|
||||
pub manual_code: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct RemoteControlPairingStartResponse {
|
||||
pub pairing_code: String,
|
||||
pub manual_pairing_code: Option<String>,
|
||||
pub environment_id: String,
|
||||
pub expires_at: i64,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(rename_all = "camelCase", export_to = "v2/")]
|
||||
|
||||
@@ -46,6 +46,13 @@ impl RemoteControlEnrollment {
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn server_token_refresh_delay(&self) -> Option<std::time::Duration> {
|
||||
let refresh_at = self.expires_at?
|
||||
- time::Duration::seconds(REMOTE_CONTROL_SERVER_TOKEN_REFRESH_SKEW_SECS);
|
||||
let refresh_delay = refresh_at - OffsetDateTime::now_utc();
|
||||
Some(std::time::Duration::try_from(refresh_delay).unwrap_or_default())
|
||||
}
|
||||
|
||||
pub(super) fn clear_server_token(&mut self) {
|
||||
self.remote_control_token = None;
|
||||
self.expires_at = None;
|
||||
@@ -211,8 +218,14 @@ fn redact_remote_control_response_body(body: &str) -> String {
|
||||
let Some(body_object) = body_json.as_object_mut() else {
|
||||
return body.to_string();
|
||||
};
|
||||
if let Some(remote_control_token) = body_object.get_mut("remote_control_token") {
|
||||
*remote_control_token = serde_json::Value::String("<redacted>".to_string());
|
||||
for sensitive_field in [
|
||||
"remote_control_token",
|
||||
"pairing_code",
|
||||
"manual_pairing_code",
|
||||
] {
|
||||
if let Some(value) = body_object.get_mut(sensitive_field) {
|
||||
*value = serde_json::Value::String("<redacted>".to_string());
|
||||
}
|
||||
}
|
||||
body_json.to_string()
|
||||
}
|
||||
@@ -429,6 +442,22 @@ mod tests {
|
||||
assert!(!expires_later.should_refresh_server_token());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remote_control_enrollment_schedules_server_token_refresh_before_expiry() {
|
||||
let refresh_delay = RemoteControlEnrollment {
|
||||
account_id: "account_id".to_string(),
|
||||
environment_id: "environment_id".to_string(),
|
||||
server_id: "server_id".to_string(),
|
||||
server_name: "server_name".to_string(),
|
||||
remote_control_token: Some("remote-control-token".to_string()),
|
||||
expires_at: Some(OffsetDateTime::now_utc() + time::Duration::seconds(31)),
|
||||
}
|
||||
.server_token_refresh_delay()
|
||||
.expect("server token refresh should be scheduled");
|
||||
|
||||
assert!(refresh_delay <= std::time::Duration::from_secs(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preview_remote_control_response_body_redacts_server_token() {
|
||||
assert_eq!(
|
||||
@@ -443,6 +472,20 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preview_remote_control_response_body_redacts_pairing_codes() {
|
||||
assert_eq!(
|
||||
serde_json::from_str::<serde_json::Value>(&preview_remote_control_response_body(
|
||||
br#"{"pairing_code":"pairing-code","manual_pairing_code":"ABCD-EFGH"}"#
|
||||
))
|
||||
.expect("redacted response preview should stay valid json"),
|
||||
json!({
|
||||
"pairing_code": "<redacted>",
|
||||
"manual_pairing_code": "<redacted>",
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn persisted_remote_control_enrollment_round_trips_by_target_and_account() {
|
||||
let codex_home = TempDir::new().expect("temp dir should create");
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
mod client_tracker;
|
||||
mod enroll;
|
||||
mod pairing;
|
||||
mod protocol;
|
||||
mod segment;
|
||||
mod websocket;
|
||||
|
||||
use self::pairing::RemoteControlPairingClient;
|
||||
use crate::transport::remote_control::websocket::RemoteControlChannels;
|
||||
use crate::transport::remote_control::websocket::RemoteControlStatusPublisher;
|
||||
use crate::transport::remote_control::websocket::RemoteControlWebsocket;
|
||||
@@ -16,6 +18,8 @@ use super::CHANNEL_CAPACITY;
|
||||
use super::TransportEvent;
|
||||
use super::next_connection_id;
|
||||
use codex_app_server_protocol::RemoteControlConnectionStatus;
|
||||
use codex_app_server_protocol::RemoteControlPairingStartParams;
|
||||
use codex_app_server_protocol::RemoteControlPairingStartResponse;
|
||||
use codex_app_server_protocol::RemoteControlStatusChangedNotification;
|
||||
use codex_login::AuthManager;
|
||||
use codex_state::StateRuntime;
|
||||
@@ -26,6 +30,9 @@ use std::fmt;
|
||||
use std::io;
|
||||
use std::panic::AssertUnwindSafe;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::sync::watch;
|
||||
@@ -52,6 +59,33 @@ pub struct RemoteControlHandle {
|
||||
enabled_tx: Arc<watch::Sender<bool>>,
|
||||
status_tx: Arc<watch::Sender<RemoteControlStatusChangedNotification>>,
|
||||
state_db_available: bool,
|
||||
pairing: PairingClientState,
|
||||
pairing_request_cancellation_revision: Arc<AtomicU64>,
|
||||
pairing_refresh_tx: Arc<watch::Sender<u64>>,
|
||||
auth_change_rx: Arc<StdMutex<watch::Receiver<u64>>>,
|
||||
}
|
||||
|
||||
// `/server/pair` runs outside the websocket task, so keep the auth and disable
|
||||
// revisions that make its eventual response safe to return to app-server.
|
||||
struct PairingRequestContext {
|
||||
auth_change_revision: u64,
|
||||
cancellation_revision: u64,
|
||||
pairing_client: RemoteControlPairingClient,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(super) struct PairingClientState {
|
||||
client: Arc<StdMutex<Option<RemoteControlPairingClient>>>,
|
||||
generation: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
impl PairingClientState {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
client: Arc::new(StdMutex::new(None)),
|
||||
generation: Arc::new(AtomicU64::new(0)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@@ -108,6 +142,8 @@ impl RemoteControlHandle {
|
||||
*state = false;
|
||||
changed
|
||||
});
|
||||
clear_pairing_client(&self.pairing);
|
||||
self.cancel_pending_pairing_requests();
|
||||
|
||||
let status = self.status();
|
||||
info!(
|
||||
@@ -121,6 +157,11 @@ impl RemoteControlHandle {
|
||||
self.publish_status(RemoteControlConnectionStatus::Disabled)
|
||||
}
|
||||
|
||||
pub fn cancel_pending_pairing_requests(&self) {
|
||||
self.pairing_request_cancellation_revision
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn status(&self) -> RemoteControlStatusChangedNotification {
|
||||
self.status_tx.borrow().clone()
|
||||
}
|
||||
@@ -129,6 +170,120 @@ impl RemoteControlHandle {
|
||||
self.status_tx.subscribe()
|
||||
}
|
||||
|
||||
pub async fn start_pairing(
|
||||
&self,
|
||||
params: RemoteControlPairingStartParams,
|
||||
) -> io::Result<RemoteControlPairingStartResponse> {
|
||||
let pairing_request = self.pairing_request_context()?;
|
||||
let pairing_response = pairing_request
|
||||
.pairing_client
|
||||
.start(protocol::StartRemoteControlPairingRequest {
|
||||
manual_code: params.manual_code,
|
||||
})
|
||||
.await;
|
||||
self.refresh_pairing_auth_after_rejection(
|
||||
&pairing_request.pairing_client,
|
||||
&pairing_response,
|
||||
);
|
||||
self.validate_pairing_response(&pairing_request, &pairing_response)?;
|
||||
pairing_response
|
||||
}
|
||||
|
||||
fn pairing_request_context(&self) -> io::Result<PairingRequestContext> {
|
||||
if !*self.enabled_tx.borrow() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"remote control pairing requires remote control to be enabled",
|
||||
));
|
||||
}
|
||||
if self.status().status != RemoteControlConnectionStatus::Connected {
|
||||
return Err(Self::pairing_unavailable_error());
|
||||
}
|
||||
|
||||
let auth_change_revision = self.auth_change_revision();
|
||||
let pairing_client = {
|
||||
let mut pairing_client = self
|
||||
.pairing
|
||||
.client
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let current_pairing_client = pairing_client
|
||||
.as_ref()
|
||||
.filter(|pairing_client| {
|
||||
pairing_client.matches_auth_change_revision(auth_change_revision)
|
||||
})
|
||||
.cloned();
|
||||
if current_pairing_client.is_none() {
|
||||
*pairing_client = None;
|
||||
}
|
||||
current_pairing_client
|
||||
}
|
||||
.ok_or_else(Self::pairing_unavailable_error)?;
|
||||
|
||||
Ok(PairingRequestContext {
|
||||
auth_change_revision,
|
||||
cancellation_revision: self
|
||||
.pairing_request_cancellation_revision
|
||||
.load(Ordering::Relaxed),
|
||||
pairing_client,
|
||||
})
|
||||
}
|
||||
|
||||
fn refresh_pairing_auth_after_rejection(
|
||||
&self,
|
||||
pairing_client: &RemoteControlPairingClient,
|
||||
pairing_response: &io::Result<RemoteControlPairingStartResponse>,
|
||||
) {
|
||||
if pairing_response.as_ref().is_err_and(|err| {
|
||||
matches!(
|
||||
err.kind(),
|
||||
io::ErrorKind::NotFound | io::ErrorKind::PermissionDenied
|
||||
)
|
||||
}) && clear_pairing_client_if_current(&self.pairing, pairing_client)
|
||||
{
|
||||
self.request_pairing_auth_refresh();
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_pairing_response(
|
||||
&self,
|
||||
pairing_request: &PairingRequestContext,
|
||||
pairing_response: &io::Result<RemoteControlPairingStartResponse>,
|
||||
) -> io::Result<()> {
|
||||
if self.auth_change_revision() != pairing_request.auth_change_revision {
|
||||
return Err(Self::pairing_unavailable_error());
|
||||
}
|
||||
if pairing_response.is_ok()
|
||||
&& self
|
||||
.pairing_request_cancellation_revision
|
||||
.load(Ordering::Relaxed)
|
||||
!= pairing_request.cancellation_revision
|
||||
{
|
||||
return Err(Self::pairing_unavailable_error());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn auth_change_revision(&self) -> u64 {
|
||||
*self
|
||||
.auth_change_rx
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.borrow()
|
||||
}
|
||||
|
||||
fn pairing_unavailable_error() -> io::Error {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"remote control pairing is unavailable until enrollment completes",
|
||||
)
|
||||
}
|
||||
|
||||
fn request_pairing_auth_refresh(&self) {
|
||||
self.pairing_refresh_tx
|
||||
.send_modify(|revision| *revision = revision.wrapping_add(1));
|
||||
}
|
||||
|
||||
fn publish_status(
|
||||
&self,
|
||||
connection_status: RemoteControlConnectionStatus,
|
||||
@@ -176,6 +331,33 @@ fn remote_control_status_with_connection_status(
|
||||
}
|
||||
}
|
||||
|
||||
fn clear_pairing_client(pairing: &PairingClientState) {
|
||||
*pairing
|
||||
.client
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner) = None;
|
||||
pairing.generation.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn clear_pairing_client_if_current(
|
||||
pairing: &PairingClientState,
|
||||
expected_pairing_client: &RemoteControlPairingClient,
|
||||
) -> bool {
|
||||
let mut pairing_client = pairing
|
||||
.client
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
if pairing_client
|
||||
.as_ref()
|
||||
.is_none_or(|pairing_client| !pairing_client.matches_pairing_auth(expected_pairing_client))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
*pairing_client = None;
|
||||
pairing.generation.fetch_add(1, Ordering::Relaxed);
|
||||
true
|
||||
}
|
||||
|
||||
pub async fn start_remote_control(
|
||||
config: RemoteControlStartConfig,
|
||||
state_db: Option<Arc<StateRuntime>>,
|
||||
@@ -198,6 +380,12 @@ pub async fn start_remote_control(
|
||||
};
|
||||
|
||||
let (enabled_tx, enabled_rx) = watch::channel(initial_enabled);
|
||||
let pairing = PairingClientState::new();
|
||||
let websocket_pairing = pairing.clone();
|
||||
let pairing_request_cancellation_revision = Arc::new(AtomicU64::new(0));
|
||||
let (pairing_refresh_tx, pairing_refresh_rx) = watch::channel(0u64);
|
||||
let websocket_pairing_refresh_tx = pairing_refresh_tx.clone();
|
||||
let auth_change_rx = Arc::new(StdMutex::new(auth_manager.auth_change_receiver()));
|
||||
let server_name = gethostname().to_string_lossy().trim().to_string();
|
||||
let remote_control_url = config.remote_control_url;
|
||||
let installation_id = config.installation_id;
|
||||
@@ -245,6 +433,9 @@ pub async fn start_remote_control(
|
||||
RemoteControlChannels {
|
||||
transport_event_tx,
|
||||
status_publisher,
|
||||
pairing: websocket_pairing,
|
||||
pairing_refresh_tx: websocket_pairing_refresh_tx,
|
||||
pairing_refresh_rx,
|
||||
},
|
||||
shutdown_token,
|
||||
enabled_rx,
|
||||
@@ -289,10 +480,16 @@ pub async fn start_remote_control(
|
||||
enabled_tx: Arc::new(enabled_tx),
|
||||
status_tx: Arc::new(status_tx),
|
||||
state_db_available,
|
||||
pairing,
|
||||
pairing_request_cancellation_revision,
|
||||
pairing_refresh_tx: Arc::new(pairing_refresh_tx),
|
||||
auth_change_rx,
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod pairing_tests;
|
||||
#[cfg(test)]
|
||||
mod segment_tests;
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
use super::enroll::format_headers;
|
||||
use super::enroll::preview_remote_control_response_body;
|
||||
use super::protocol::RemoteControlTarget;
|
||||
use super::protocol::StartRemoteControlPairingRequest;
|
||||
use super::protocol::StartRemoteControlPairingResponse;
|
||||
use codex_app_server_protocol::RemoteControlPairingStartResponse;
|
||||
use codex_login::default_client::build_reqwest_client;
|
||||
use std::io;
|
||||
use std::io::ErrorKind;
|
||||
use time::OffsetDateTime;
|
||||
use time::format_description::well_known::Rfc3339;
|
||||
|
||||
const REMOTE_CONTROL_PAIRING_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(super) struct RemoteControlPairingClient {
|
||||
pairing_url: String,
|
||||
remote_control_token: String,
|
||||
server_id: String,
|
||||
environment_id: String,
|
||||
expires_at: OffsetDateTime,
|
||||
auth_change_revision: u64,
|
||||
generation: u64,
|
||||
}
|
||||
|
||||
impl RemoteControlPairingClient {
|
||||
pub(super) fn new(
|
||||
remote_control_target: &RemoteControlTarget,
|
||||
remote_control_token: String,
|
||||
server_id: String,
|
||||
environment_id: String,
|
||||
expires_at: OffsetDateTime,
|
||||
auth_change_revision: u64,
|
||||
generation: u64,
|
||||
) -> Self {
|
||||
Self {
|
||||
pairing_url: remote_control_target.pair_url.clone(),
|
||||
remote_control_token,
|
||||
server_id,
|
||||
environment_id,
|
||||
expires_at,
|
||||
auth_change_revision,
|
||||
generation,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn matches_auth_change_revision(&self, auth_change_revision: u64) -> bool {
|
||||
self.auth_change_revision == auth_change_revision
|
||||
}
|
||||
|
||||
pub(super) fn matches_pairing_auth(&self, other: &Self) -> bool {
|
||||
self.pairing_url == other.pairing_url
|
||||
&& self.remote_control_token == other.remote_control_token
|
||||
&& self.server_id == other.server_id
|
||||
&& self.environment_id == other.environment_id
|
||||
&& self.auth_change_revision == other.auth_change_revision
|
||||
&& self.generation == other.generation
|
||||
}
|
||||
|
||||
pub(super) async fn start(
|
||||
&self,
|
||||
request: StartRemoteControlPairingRequest,
|
||||
) -> io::Result<RemoteControlPairingStartResponse> {
|
||||
if self.expires_at <= OffsetDateTime::now_utc() {
|
||||
return Err(io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
"remote control pairing is unavailable because the server token expired",
|
||||
));
|
||||
}
|
||||
|
||||
let response = build_reqwest_client()
|
||||
.post(&self.pairing_url)
|
||||
.timeout(REMOTE_CONTROL_PAIRING_TIMEOUT)
|
||||
.bearer_auth(&self.remote_control_token)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| {
|
||||
io::Error::other(format!(
|
||||
"failed to start remote control pairing at `{}`: {err}",
|
||||
self.pairing_url
|
||||
))
|
||||
})?;
|
||||
let headers = response.headers().clone();
|
||||
let status = response.status();
|
||||
let body = response.bytes().await.map_err(|err| {
|
||||
io::Error::other(format!(
|
||||
"failed to read remote control pairing response from `{}`: {err}",
|
||||
self.pairing_url
|
||||
))
|
||||
})?;
|
||||
let body_preview = preview_remote_control_response_body(&body);
|
||||
if !status.is_success() {
|
||||
let error_kind = match status.as_u16() {
|
||||
401 | 403 => ErrorKind::PermissionDenied,
|
||||
404 => ErrorKind::NotFound,
|
||||
_ => ErrorKind::Other,
|
||||
};
|
||||
return Err(io::Error::new(
|
||||
error_kind,
|
||||
format!(
|
||||
"remote control pairing failed at `{}`: HTTP {status}, {}, body: {body_preview}",
|
||||
self.pairing_url,
|
||||
format_headers(&headers)
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let pairing = serde_json::from_slice::<StartRemoteControlPairingResponse>(&body).map_err(
|
||||
|err| {
|
||||
io::Error::other(format!(
|
||||
"failed to parse remote control pairing response from `{}`: HTTP {status}, {}, body: {body_preview}, decode error: {err}",
|
||||
self.pairing_url,
|
||||
format_headers(&headers)
|
||||
))
|
||||
},
|
||||
)?;
|
||||
let StartRemoteControlPairingResponse {
|
||||
pairing_code,
|
||||
manual_pairing_code,
|
||||
server_id,
|
||||
environment_id,
|
||||
expires_at,
|
||||
} = pairing;
|
||||
if server_id != self.server_id || environment_id != self.environment_id {
|
||||
return Err(io::Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!(
|
||||
"remote control pairing returned mismatched enrollment: expected server_id={}, environment_id={}; got server_id={}, environment_id={}",
|
||||
self.server_id, self.environment_id, server_id, environment_id
|
||||
),
|
||||
));
|
||||
}
|
||||
let expires_at = OffsetDateTime::parse(&expires_at, &Rfc3339)
|
||||
.map_err(|err| {
|
||||
io::Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("invalid remote control pairing expires_at: {err}"),
|
||||
)
|
||||
})?
|
||||
.unix_timestamp();
|
||||
|
||||
Ok(RemoteControlPairingStartResponse {
|
||||
pairing_code,
|
||||
manual_pairing_code,
|
||||
environment_id,
|
||||
expires_at,
|
||||
})
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,273 @@
|
||||
use super::pairing::RemoteControlPairingClient;
|
||||
use super::protocol::RemoteControlTarget;
|
||||
use super::protocol::StartRemoteControlPairingRequest;
|
||||
use codex_app_server_protocol::RemoteControlPairingStartResponse;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use time::OffsetDateTime;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_remote_control_pairing_uses_server_token_and_maps_response() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("listener should bind");
|
||||
let pair_url = format!(
|
||||
"http://{}/backend-api/wham/remote/control/server/pair",
|
||||
listener.local_addr().expect("listener should have addr")
|
||||
);
|
||||
let server_task = tokio::spawn(async move {
|
||||
let (stream, _) = listener.accept().await.expect("request should arrive");
|
||||
let mut reader = BufReader::new(stream);
|
||||
|
||||
let mut request_line = String::new();
|
||||
reader
|
||||
.read_line(&mut request_line)
|
||||
.await
|
||||
.expect("request line should read");
|
||||
assert_eq!(
|
||||
request_line.trim_end(),
|
||||
"POST /backend-api/wham/remote/control/server/pair HTTP/1.1"
|
||||
);
|
||||
|
||||
let mut authorization = None;
|
||||
let mut content_length = None;
|
||||
loop {
|
||||
let mut line = String::new();
|
||||
reader
|
||||
.read_line(&mut line)
|
||||
.await
|
||||
.expect("header line should read");
|
||||
if line == "\r\n" {
|
||||
break;
|
||||
}
|
||||
let (name, value) = line
|
||||
.trim_end()
|
||||
.split_once(": ")
|
||||
.expect("header should split");
|
||||
match name.to_ascii_lowercase().as_str() {
|
||||
"authorization" => authorization = Some(value.to_string()),
|
||||
"content-length" => {
|
||||
content_length =
|
||||
Some(value.parse::<usize>().expect("content length should parse"))
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
assert_eq!(
|
||||
authorization,
|
||||
Some("Bearer remote-control-token".to_string())
|
||||
);
|
||||
|
||||
let mut body = vec![0; content_length.expect("request should have body")];
|
||||
reader
|
||||
.read_exact(&mut body)
|
||||
.await
|
||||
.expect("request body should read");
|
||||
assert_eq!(
|
||||
serde_json::from_slice::<serde_json::Value>(&body)
|
||||
.expect("request body should be json"),
|
||||
json!({ "manual_code": true })
|
||||
);
|
||||
|
||||
let response_body = json!({
|
||||
"pairing_code": "pairing-code",
|
||||
"manual_pairing_code": "ABCD-EFGH",
|
||||
"server_id": "server-id",
|
||||
"environment_id": "environment-id",
|
||||
"expires_at": "3026-05-22T12:34:56Z",
|
||||
})
|
||||
.to_string();
|
||||
reader
|
||||
.get_mut()
|
||||
.write_all(
|
||||
format!(
|
||||
"HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{response_body}",
|
||||
response_body.len()
|
||||
)
|
||||
.as_bytes(),
|
||||
)
|
||||
.await
|
||||
.expect("response should write");
|
||||
});
|
||||
let client = pairing_client(pair_url);
|
||||
|
||||
let response = client
|
||||
.start(StartRemoteControlPairingRequest { manual_code: true })
|
||||
.await
|
||||
.expect("pairing should succeed");
|
||||
server_task.await.expect("server task should finish");
|
||||
|
||||
assert_eq!(
|
||||
response,
|
||||
RemoteControlPairingStartResponse {
|
||||
pairing_code: "pairing-code".to_string(),
|
||||
manual_pairing_code: Some("ABCD-EFGH".to_string()),
|
||||
environment_id: "environment-id".to_string(),
|
||||
expires_at: 33_336_362_096,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_remote_control_pairing_preserves_backend_error_context() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("listener should bind");
|
||||
let pair_url = format!(
|
||||
"http://{}/backend-api/wham/remote/control/server/pair",
|
||||
listener.local_addr().expect("listener should have addr")
|
||||
);
|
||||
let expected_pair_url = pair_url.clone();
|
||||
let server_task = tokio::spawn(async move {
|
||||
let (stream, _) = listener.accept().await.expect("request should arrive");
|
||||
let mut reader = BufReader::new(stream);
|
||||
|
||||
loop {
|
||||
let mut line = String::new();
|
||||
reader
|
||||
.read_line(&mut line)
|
||||
.await
|
||||
.expect("request line should read");
|
||||
if line == "\r\n" {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let response_body = "pairing unavailable";
|
||||
reader
|
||||
.get_mut()
|
||||
.write_all(
|
||||
format!(
|
||||
"HTTP/1.1 503 Service Unavailable\r\nx-request-id: request-123\r\ncf-ray: ray-123\r\ncontent-length: {}\r\n\r\n{response_body}",
|
||||
response_body.len()
|
||||
)
|
||||
.as_bytes(),
|
||||
)
|
||||
.await
|
||||
.expect("response should write");
|
||||
});
|
||||
let client = pairing_client(pair_url);
|
||||
|
||||
let err = client
|
||||
.start(StartRemoteControlPairingRequest { manual_code: false })
|
||||
.await
|
||||
.expect_err("pairing should fail");
|
||||
server_task.await.expect("server task should finish");
|
||||
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
format!(
|
||||
"remote control pairing failed at `{expected_pair_url}`: HTTP 503 Service Unavailable, request-id: request-123, cf-ray: ray-123, body: pairing unavailable"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_remote_control_pairing_rejects_expired_server_token() {
|
||||
let client = RemoteControlPairingClient::new(
|
||||
&RemoteControlTarget {
|
||||
websocket_url: "ws://unused".to_string(),
|
||||
enroll_url: "http://unused".to_string(),
|
||||
refresh_url: "http://unused".to_string(),
|
||||
pair_url: "http://unused".to_string(),
|
||||
},
|
||||
"remote-control-token".to_string(),
|
||||
"server-id".to_string(),
|
||||
"environment-id".to_string(),
|
||||
OffsetDateTime::from_unix_timestamp(0).expect("expired timestamp should parse"),
|
||||
/*auth_change_revision*/ 0,
|
||||
/*generation*/ 0,
|
||||
);
|
||||
|
||||
let err = client
|
||||
.start(StartRemoteControlPairingRequest { manual_code: false })
|
||||
.await
|
||||
.expect_err("expired server token should fail pairing");
|
||||
|
||||
assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"remote control pairing is unavailable because the server token expired"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_remote_control_pairing_rejects_mismatched_enrollment() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("listener should bind");
|
||||
let pair_url = format!(
|
||||
"http://{}/backend-api/wham/remote/control/server/pair",
|
||||
listener.local_addr().expect("listener should have addr")
|
||||
);
|
||||
let server_task = tokio::spawn(async move {
|
||||
let (stream, _) = listener.accept().await.expect("request should arrive");
|
||||
let mut reader = BufReader::new(stream);
|
||||
|
||||
loop {
|
||||
let mut line = String::new();
|
||||
reader
|
||||
.read_line(&mut line)
|
||||
.await
|
||||
.expect("request line should read");
|
||||
if line == "\r\n" {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let response_body = json!({
|
||||
"pairing_code": "pairing-code",
|
||||
"manual_pairing_code": null,
|
||||
"server_id": "other-server-id",
|
||||
"environment_id": "other-environment-id",
|
||||
"expires_at": "3026-05-22T12:34:56Z",
|
||||
})
|
||||
.to_string();
|
||||
reader
|
||||
.get_mut()
|
||||
.write_all(
|
||||
format!(
|
||||
"HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{response_body}",
|
||||
response_body.len()
|
||||
)
|
||||
.as_bytes(),
|
||||
)
|
||||
.await
|
||||
.expect("response should write");
|
||||
});
|
||||
let client = pairing_client(pair_url);
|
||||
|
||||
let err = client
|
||||
.start(StartRemoteControlPairingRequest { manual_code: false })
|
||||
.await
|
||||
.expect_err("mismatched enrollment should fail pairing");
|
||||
server_task.await.expect("server task should finish");
|
||||
|
||||
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"remote control pairing returned mismatched enrollment: expected server_id=server-id, environment_id=environment-id; got server_id=other-server-id, environment_id=other-environment-id"
|
||||
);
|
||||
}
|
||||
|
||||
fn pairing_client(pair_url: String) -> RemoteControlPairingClient {
|
||||
RemoteControlPairingClient::new(
|
||||
&RemoteControlTarget {
|
||||
websocket_url: "ws://unused".to_string(),
|
||||
enroll_url: "http://unused".to_string(),
|
||||
refresh_url: "http://unused".to_string(),
|
||||
pair_url,
|
||||
},
|
||||
"remote-control-token".to_string(),
|
||||
"server-id".to_string(),
|
||||
"environment-id".to_string(),
|
||||
OffsetDateTime::from_unix_timestamp(33_336_362_096).expect("future timestamp should parse"),
|
||||
/*auth_change_revision*/ 0,
|
||||
/*generation*/ 0,
|
||||
)
|
||||
}
|
||||
@@ -12,6 +12,7 @@ pub(super) struct RemoteControlTarget {
|
||||
pub(super) websocket_url: String,
|
||||
pub(super) enroll_url: String,
|
||||
pub(super) refresh_url: String,
|
||||
pub(super) pair_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
@@ -37,6 +38,20 @@ pub(super) struct RefreshRemoteServerRequest {
|
||||
pub(super) installation_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub(super) struct StartRemoteControlPairingRequest {
|
||||
pub(super) manual_code: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(super) struct StartRemoteControlPairingResponse {
|
||||
pub(super) pairing_code: String,
|
||||
pub(super) manual_pairing_code: Option<String>,
|
||||
pub(super) server_id: String,
|
||||
pub(super) environment_id: String,
|
||||
pub(super) expires_at: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct ClientId(pub String);
|
||||
@@ -189,6 +204,9 @@ pub(super) fn normalize_remote_control_url(
|
||||
let refresh_url = remote_control_url
|
||||
.join("wham/remote/control/server/refresh")
|
||||
.map_err(map_url_parse_error)?;
|
||||
let pair_url = remote_control_url
|
||||
.join("wham/remote/control/server/pair")
|
||||
.map_err(map_url_parse_error)?;
|
||||
let mut websocket_url = remote_control_url
|
||||
.join("wham/remote/control/server")
|
||||
.map_err(map_url_parse_error)?;
|
||||
@@ -207,6 +225,7 @@ pub(super) fn normalize_remote_control_url(
|
||||
websocket_url: websocket_url.to_string(),
|
||||
enroll_url: enroll_url.to_string(),
|
||||
refresh_url: refresh_url.to_string(),
|
||||
pair_url: pair_url.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -227,6 +246,8 @@ mod tests {
|
||||
.to_string(),
|
||||
refresh_url: "https://chatgpt.com/backend-api/wham/remote/control/server/refresh"
|
||||
.to_string(),
|
||||
pair_url: "https://chatgpt.com/backend-api/wham/remote/control/server/pair"
|
||||
.to_string(),
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
@@ -242,6 +263,9 @@ mod tests {
|
||||
refresh_url:
|
||||
"https://api.chatgpt-staging.com/backend-api/wham/remote/control/server/refresh"
|
||||
.to_string(),
|
||||
pair_url:
|
||||
"https://api.chatgpt-staging.com/backend-api/wham/remote/control/server/pair"
|
||||
.to_string(),
|
||||
}
|
||||
);
|
||||
}
|
||||
@@ -258,6 +282,8 @@ mod tests {
|
||||
.to_string(),
|
||||
refresh_url: "http://localhost:8080/backend-api/wham/remote/control/server/refresh"
|
||||
.to_string(),
|
||||
pair_url: "http://localhost:8080/backend-api/wham/remote/control/server/pair"
|
||||
.to_string(),
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
@@ -271,6 +297,8 @@ mod tests {
|
||||
refresh_url:
|
||||
"https://localhost:8443/backend-api/wham/remote/control/server/refresh"
|
||||
.to_string(),
|
||||
pair_url: "https://localhost:8443/backend-api/wham/remote/control/server/pair"
|
||||
.to_string(),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ use codex_app_server_protocol::AuthMode;
|
||||
use codex_app_server_protocol::ConfigWarningNotification;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_app_server_protocol::RemoteControlConnectionStatus;
|
||||
use codex_app_server_protocol::RemoteControlPairingStartParams;
|
||||
use codex_app_server_protocol::RemoteControlStatusChangedNotification;
|
||||
use codex_app_server_protocol::ServerNotification;
|
||||
use codex_config::types::AuthCredentialsStoreMode;
|
||||
@@ -39,7 +40,9 @@ use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use tempfile::TempDir;
|
||||
use time::OffsetDateTime;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
@@ -58,9 +61,11 @@ use tokio_tungstenite::tungstenite;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
const TEST_INSTALLATION_ID: &str = "11111111-1111-4111-8111-111111111111";
|
||||
const TEST_REMOTE_CONTROL_URL: &str = "http://127.0.0.1:1/backend-api/wham/remote/control";
|
||||
const TEST_REMOTE_CONTROL_SERVER_TOKEN: &str = "Remote Control Token";
|
||||
const TEST_REFRESHED_REMOTE_CONTROL_SERVER_TOKEN: &str = "Refreshed Remote Control Token";
|
||||
const TEST_REMOTE_CONTROL_SERVER_TOKEN_EXPIRES_AT: &str = "2999-01-01T00:00:00Z";
|
||||
const TEST_HTTP_ACCEPT_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
|
||||
fn remote_control_auth_manager() -> Arc<AuthManager> {
|
||||
auth_manager_from_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing())
|
||||
@@ -128,16 +133,66 @@ fn test_server_name() -> String {
|
||||
gethostname().to_string_lossy().trim().to_string()
|
||||
}
|
||||
|
||||
fn remote_control_handle_with_pairing_client(
|
||||
remote_control_url: &str,
|
||||
auth_change_rx: watch::Receiver<u64>,
|
||||
) -> RemoteControlHandle {
|
||||
let (enabled_tx, _enabled_rx) = watch::channel(/*init*/ true);
|
||||
let (status_tx, _status_rx) = watch::channel(RemoteControlStatusChangedNotification {
|
||||
status: RemoteControlConnectionStatus::Connected,
|
||||
server_name: test_server_name(),
|
||||
installation_id: TEST_INSTALLATION_ID.to_string(),
|
||||
environment_id: Some("env_test".to_string()),
|
||||
});
|
||||
let (pairing_refresh_tx, _pairing_refresh_rx) = watch::channel(0u64);
|
||||
let pairing_client = Arc::new(StdMutex::new(Some(RemoteControlPairingClient::new(
|
||||
&normalize_remote_control_url(remote_control_url)
|
||||
.expect("remote control target should normalize"),
|
||||
TEST_REMOTE_CONTROL_SERVER_TOKEN.to_string(),
|
||||
"srv_e_test".to_string(),
|
||||
"env_test".to_string(),
|
||||
OffsetDateTime::from_unix_timestamp(33_336_362_096).expect("future timestamp should parse"),
|
||||
/*auth_change_revision*/ 0,
|
||||
/*generation*/ 0,
|
||||
))));
|
||||
RemoteControlHandle {
|
||||
enabled_tx: Arc::new(enabled_tx),
|
||||
status_tx: Arc::new(status_tx),
|
||||
state_db_available: true,
|
||||
pairing: PairingClientState {
|
||||
client: pairing_client,
|
||||
generation: Arc::new(std::sync::atomic::AtomicU64::new(0)),
|
||||
},
|
||||
pairing_request_cancellation_revision: Arc::new(std::sync::atomic::AtomicU64::new(0)),
|
||||
pairing_refresh_tx: Arc::new(pairing_refresh_tx),
|
||||
auth_change_rx: Arc::new(StdMutex::new(auth_change_rx)),
|
||||
}
|
||||
}
|
||||
|
||||
fn remote_control_server_token_response(
|
||||
server_id: &str,
|
||||
environment_id: &str,
|
||||
remote_control_token: &str,
|
||||
) -> serde_json::Value {
|
||||
remote_control_server_token_response_with_expires_at(
|
||||
server_id,
|
||||
environment_id,
|
||||
remote_control_token,
|
||||
TEST_REMOTE_CONTROL_SERVER_TOKEN_EXPIRES_AT,
|
||||
)
|
||||
}
|
||||
|
||||
fn remote_control_server_token_response_with_expires_at(
|
||||
server_id: &str,
|
||||
environment_id: &str,
|
||||
remote_control_token: &str,
|
||||
expires_at: &str,
|
||||
) -> serde_json::Value {
|
||||
json!({
|
||||
"server_id": server_id,
|
||||
"environment_id": environment_id,
|
||||
"remote_control_token": remote_control_token,
|
||||
"expires_at": TEST_REMOTE_CONTROL_SERVER_TOKEN_EXPIRES_AT,
|
||||
"expires_at": expires_at,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -895,6 +950,9 @@ async fn remote_control_handle_enable_disable_stops_and_restarts_connections() {
|
||||
let _ = remote_task.await;
|
||||
}
|
||||
|
||||
#[path = "pairing_integration_tests.rs"]
|
||||
mod pairing_integration_tests;
|
||||
|
||||
#[tokio::test]
|
||||
async fn remote_control_transport_clears_outgoing_buffer_when_backend_acks() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
@@ -1571,6 +1629,111 @@ async fn remote_control_waits_for_account_id_before_enrolling() {
|
||||
let _ = remote_task.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remote_control_pairs_after_auth_reloads_during_connect() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("listener should bind");
|
||||
let remote_control_url = remote_control_url_for_listener(&listener);
|
||||
let codex_home = TempDir::new().expect("temp dir should create");
|
||||
save_auth(
|
||||
codex_home.path(),
|
||||
&remote_control_auth_dot_json(/*account_id*/ None),
|
||||
AuthCredentialsStoreMode::File,
|
||||
)
|
||||
.expect("auth without account id should save");
|
||||
let state_db = remote_control_state_runtime(&codex_home).await;
|
||||
let auth_manager = AuthManager::shared(
|
||||
codex_home.path().to_path_buf(),
|
||||
/*enable_codex_api_key_env*/ false,
|
||||
AuthCredentialsStoreMode::File,
|
||||
/*chatgpt_base_url*/ None,
|
||||
)
|
||||
.await;
|
||||
save_auth(
|
||||
codex_home.path(),
|
||||
&remote_control_auth_dot_json(Some("account_id")),
|
||||
AuthCredentialsStoreMode::File,
|
||||
)
|
||||
.expect("auth with account id should save before connect");
|
||||
|
||||
let (transport_event_tx, _transport_event_rx) =
|
||||
mpsc::channel::<TransportEvent>(CHANNEL_CAPACITY);
|
||||
let shutdown_token = CancellationToken::new();
|
||||
let (remote_task, remote_handle) = start_remote_control(
|
||||
RemoteControlStartConfig {
|
||||
remote_control_url,
|
||||
installation_id: TEST_INSTALLATION_ID.to_string(),
|
||||
},
|
||||
Some(state_db),
|
||||
auth_manager,
|
||||
transport_event_tx,
|
||||
shutdown_token.clone(),
|
||||
/*app_server_client_name_rx*/ None,
|
||||
/*initial_enabled*/ true,
|
||||
)
|
||||
.await
|
||||
.expect("remote control should start");
|
||||
|
||||
let enroll_request = accept_http_request(&listener).await;
|
||||
assert_eq!(
|
||||
enroll_request.request_line,
|
||||
"POST /backend-api/wham/remote/control/server/enroll HTTP/1.1"
|
||||
);
|
||||
respond_with_json(
|
||||
enroll_request.stream,
|
||||
remote_control_server_token_response(
|
||||
"srv_e_ready",
|
||||
"env_ready",
|
||||
TEST_REMOTE_CONTROL_SERVER_TOKEN,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
let (_handshake_request, mut websocket) =
|
||||
accept_remote_control_backend_connection(&listener).await;
|
||||
timeout(Duration::from_secs(5), async {
|
||||
while remote_handle.status().status != RemoteControlConnectionStatus::Connected {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("remote control should publish connected before pairing");
|
||||
tokio::task::yield_now().await;
|
||||
|
||||
let pairing_task = tokio::spawn({
|
||||
let remote_handle = remote_handle.clone();
|
||||
async move {
|
||||
remote_handle
|
||||
.start_pairing(RemoteControlPairingStartParams::default())
|
||||
.await
|
||||
}
|
||||
});
|
||||
let pairing_request = accept_http_request(&listener).await;
|
||||
assert_eq!(
|
||||
pairing_request.request_line,
|
||||
"POST /backend-api/wham/remote/control/server/pair HTTP/1.1"
|
||||
);
|
||||
respond_with_json(
|
||||
pairing_request.stream,
|
||||
json!({
|
||||
"pairing_code": "pairing-code",
|
||||
"manual_pairing_code": "ABCD-EFGH",
|
||||
"server_id": "srv_e_ready",
|
||||
"environment_id": "env_ready",
|
||||
"expires_at": "3026-05-22T12:34:56Z",
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
pairing_task
|
||||
.await
|
||||
.expect("pairing task should join")
|
||||
.expect("pairing should use auth reloaded during connect");
|
||||
|
||||
websocket.close(None).await.expect("websocket should close");
|
||||
shutdown_token.cancel();
|
||||
let _ = remote_task.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remote_control_http_mode_reenrolls_when_refresh_reports_stale_enrollment() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
@@ -1848,8 +2011,36 @@ async fn accept_remote_control_connection(listener: &TcpListener) -> WebSocketSt
|
||||
.expect("websocket handshake should succeed")
|
||||
}
|
||||
|
||||
async fn expect_remote_control_connection_closed(
|
||||
websocket: &mut WebSocketStream<TcpStream>,
|
||||
timeout_message: &str,
|
||||
) {
|
||||
loop {
|
||||
let frame = timeout(Duration::from_secs(5), websocket.next())
|
||||
.await
|
||||
.expect(timeout_message);
|
||||
let Some(frame) = frame else {
|
||||
return;
|
||||
};
|
||||
let Ok(frame) = frame else {
|
||||
return;
|
||||
};
|
||||
match frame {
|
||||
tungstenite::Message::Close(_) => return,
|
||||
tungstenite::Message::Ping(payload) => {
|
||||
websocket
|
||||
.send(tungstenite::Message::Pong(payload))
|
||||
.await
|
||||
.expect("websocket pong should send");
|
||||
}
|
||||
tungstenite::Message::Pong(_) | tungstenite::Message::Frame(_) => {}
|
||||
frame => panic!("unexpected websocket frame while waiting for close: {frame:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn accept_http_request(listener: &TcpListener) -> CapturedHttpRequest {
|
||||
let (stream, _) = timeout(Duration::from_secs(5), listener.accept())
|
||||
let (stream, _) = timeout(TEST_HTTP_ACCEPT_TIMEOUT, listener.accept())
|
||||
.await
|
||||
.expect("HTTP request should arrive in time")
|
||||
.expect("listener accept should succeed");
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use super::PairingClientState;
|
||||
use crate::transport::TransportEvent;
|
||||
use crate::transport::remote_control::client_tracker::ClientTracker;
|
||||
use crate::transport::remote_control::client_tracker::REMOTE_CONTROL_IDLE_SWEEP_INTERVAL;
|
||||
@@ -9,7 +10,9 @@ use crate::transport::remote_control::enroll::load_persisted_remote_control_enro
|
||||
use crate::transport::remote_control::enroll::preview_remote_control_response_body;
|
||||
use crate::transport::remote_control::enroll::refresh_remote_control_server;
|
||||
use crate::transport::remote_control::enroll::update_persisted_remote_control_enrollment;
|
||||
use crate::transport::remote_control::pairing::RemoteControlPairingClient;
|
||||
|
||||
use super::clear_pairing_client;
|
||||
use super::protocol::ClientEnvelope;
|
||||
use super::protocol::ClientEvent;
|
||||
use super::protocol::ClientId;
|
||||
@@ -39,6 +42,7 @@ use std::collections::VecDeque;
|
||||
use std::io;
|
||||
use std::io::ErrorKind;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
@@ -251,6 +255,9 @@ pub(crate) struct RemoteControlWebsocket {
|
||||
enrollment: Option<RemoteControlEnrollment>,
|
||||
auth_recovery: UnauthorizedRecovery,
|
||||
auth_change_rx: watch::Receiver<u64>,
|
||||
pairing: PairingClientState,
|
||||
_pairing_refresh_tx: watch::Sender<u64>,
|
||||
pairing_refresh_rx: watch::Receiver<u64>,
|
||||
client_tracker: Arc<Mutex<ClientTracker>>,
|
||||
state: Arc<Mutex<WebsocketState>>,
|
||||
server_event_rx: Arc<Mutex<mpsc::Receiver<super::QueuedServerEnvelope>>>,
|
||||
@@ -282,12 +289,52 @@ enum ConnectionEndReason {
|
||||
Shutdown,
|
||||
Disabled,
|
||||
EnabledWatchClosed,
|
||||
AuthChanged,
|
||||
AuthWatchClosed,
|
||||
PairingRefreshWatchClosed,
|
||||
ServerTokenRefreshRejected,
|
||||
StaleEnrollment,
|
||||
ConnectionWorkerStopped,
|
||||
}
|
||||
|
||||
enum ConnectionLoopAction {
|
||||
End(ConnectionEndReason),
|
||||
RefreshServerToken,
|
||||
}
|
||||
|
||||
enum ConnectedServerTokenRefreshAction {
|
||||
Continue,
|
||||
Retry(io::Error),
|
||||
End(ConnectionEndReason),
|
||||
}
|
||||
|
||||
struct ConnectedServerTokenRefreshRequest {
|
||||
remote_control_target: RemoteControlTarget,
|
||||
auth: RemoteControlConnectionAuth,
|
||||
installation_id: String,
|
||||
enrollment: RemoteControlEnrollment,
|
||||
auth_change_revision: u64,
|
||||
}
|
||||
|
||||
struct PendingConnectedServerTokenRefreshRequest {
|
||||
remote_control_target: RemoteControlTarget,
|
||||
auth_manager: Arc<AuthManager>,
|
||||
auth_change_rx: watch::Receiver<u64>,
|
||||
installation_id: String,
|
||||
enrollment: RemoteControlEnrollment,
|
||||
}
|
||||
|
||||
struct ConnectedServerTokenRefreshResponse {
|
||||
enrollment: RemoteControlEnrollment,
|
||||
auth_change_revision: u64,
|
||||
}
|
||||
|
||||
pub(super) struct RemoteControlChannels {
|
||||
pub(super) transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
pub(super) status_publisher: RemoteControlStatusPublisher,
|
||||
pub(super) pairing: PairingClientState,
|
||||
pub(super) pairing_refresh_tx: watch::Sender<u64>,
|
||||
pub(super) pairing_refresh_rx: watch::Receiver<u64>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -404,6 +451,9 @@ impl RemoteControlWebsocket {
|
||||
enrollment: None,
|
||||
auth_recovery,
|
||||
auth_change_rx,
|
||||
pairing: channels.pairing,
|
||||
_pairing_refresh_tx: channels.pairing_refresh_tx,
|
||||
pairing_refresh_rx: channels.pairing_refresh_rx,
|
||||
client_tracker: Arc::new(Mutex::new(client_tracker)),
|
||||
state: Arc::new(Mutex::new(WebsocketState {
|
||||
outbound_buffer,
|
||||
@@ -489,7 +539,11 @@ impl RemoteControlWebsocket {
|
||||
};
|
||||
|
||||
let connection_end_reason = self
|
||||
.run_connection(websocket_connection, shutdown_token)
|
||||
.run_connection(
|
||||
websocket_connection,
|
||||
shutdown_token,
|
||||
app_server_client_name.as_deref(),
|
||||
)
|
||||
.await;
|
||||
let status = self.status_publisher.status();
|
||||
info!(
|
||||
@@ -611,12 +665,14 @@ impl RemoteControlWebsocket {
|
||||
&mut self.enrollment,
|
||||
connect_options,
|
||||
&self.status_publisher,
|
||||
&self.pairing,
|
||||
) => connect_result,
|
||||
};
|
||||
|
||||
match connect_result {
|
||||
Ok((websocket_connection, response)) => {
|
||||
if !*self.enabled_rx.borrow() {
|
||||
clear_pairing_client(&self.pairing);
|
||||
return ConnectOutcome::Disabled;
|
||||
}
|
||||
self.reconnect_attempt = 0;
|
||||
@@ -696,9 +752,10 @@ impl RemoteControlWebsocket {
|
||||
}
|
||||
|
||||
async fn run_connection(
|
||||
&self,
|
||||
&mut self,
|
||||
websocket_connection: WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
shutdown_token: CancellationToken,
|
||||
app_server_client_name: Option<&str>,
|
||||
) -> ConnectionEndReason {
|
||||
let (websocket_writer, websocket_reader) = websocket_connection.split();
|
||||
let mut join_set = tokio::task::JoinSet::new();
|
||||
@@ -720,19 +777,81 @@ impl RemoteControlWebsocket {
|
||||
));
|
||||
|
||||
let mut enabled_rx = self.enabled_rx.clone();
|
||||
let connection_end_reason = tokio::select! {
|
||||
_ = shutdown_token.cancelled() => ConnectionEndReason::Shutdown,
|
||||
changed = enabled_rx.wait_for(|enabled| !*enabled) => {
|
||||
if changed.is_ok() {
|
||||
self.status_publisher
|
||||
.publish_status(RemoteControlConnectionStatus::Disabled);
|
||||
ConnectionEndReason::Disabled
|
||||
} else {
|
||||
ConnectionEndReason::EnabledWatchClosed
|
||||
let mut server_token_refresh_retry_delay = None;
|
||||
let connection_end_reason = loop {
|
||||
let server_token_refresh_delay =
|
||||
server_token_refresh_retry_delay.take().or_else(|| {
|
||||
self.enrollment
|
||||
.as_ref()
|
||||
.and_then(RemoteControlEnrollment::server_token_refresh_delay)
|
||||
});
|
||||
let server_token_refresh = async move {
|
||||
match server_token_refresh_delay {
|
||||
Some(delay) => tokio::time::sleep(delay).await,
|
||||
None => std::future::pending().await,
|
||||
}
|
||||
};
|
||||
tokio::pin!(server_token_refresh);
|
||||
let connection_loop_action = tokio::select! {
|
||||
_ = shutdown_token.cancelled() => ConnectionLoopAction::End(ConnectionEndReason::Shutdown),
|
||||
changed = enabled_rx.wait_for(|enabled| !*enabled) => {
|
||||
if changed.is_ok() {
|
||||
self.status_publisher
|
||||
.publish_status(RemoteControlConnectionStatus::Disabled);
|
||||
ConnectionLoopAction::End(ConnectionEndReason::Disabled)
|
||||
} else {
|
||||
ConnectionLoopAction::End(ConnectionEndReason::EnabledWatchClosed)
|
||||
}
|
||||
}
|
||||
changed = self.auth_change_rx.changed() => {
|
||||
if changed.is_ok() {
|
||||
self.auth_recovery = self.auth_manager.unauthorized_recovery();
|
||||
ConnectionLoopAction::End(ConnectionEndReason::AuthChanged)
|
||||
} else {
|
||||
ConnectionLoopAction::End(ConnectionEndReason::AuthWatchClosed)
|
||||
}
|
||||
}
|
||||
changed = self.pairing_refresh_rx.changed() => {
|
||||
if changed.is_ok() {
|
||||
ConnectionLoopAction::RefreshServerToken
|
||||
} else {
|
||||
ConnectionLoopAction::End(ConnectionEndReason::PairingRefreshWatchClosed)
|
||||
}
|
||||
}
|
||||
_ = &mut server_token_refresh => ConnectionLoopAction::RefreshServerToken,
|
||||
_ = join_set.join_next() => ConnectionLoopAction::End(ConnectionEndReason::ConnectionWorkerStopped),
|
||||
};
|
||||
match connection_loop_action {
|
||||
ConnectionLoopAction::End(connection_end_reason) => break connection_end_reason,
|
||||
ConnectionLoopAction::RefreshServerToken => {
|
||||
match self
|
||||
.refresh_connected_server_token_while_connected(
|
||||
&mut enabled_rx,
|
||||
&mut join_set,
|
||||
&shutdown_token,
|
||||
app_server_client_name,
|
||||
)
|
||||
.await
|
||||
{
|
||||
ConnectedServerTokenRefreshAction::Continue => {}
|
||||
ConnectedServerTokenRefreshAction::Retry(err) => {
|
||||
warn!(
|
||||
error = %err,
|
||||
error_kind = ?err.kind(),
|
||||
retry_delay = ?REMOTE_CONTROL_ACCOUNT_ID_RETRY_INTERVAL,
|
||||
"failed to refresh connected app-server remote control server token"
|
||||
);
|
||||
server_token_refresh_retry_delay =
|
||||
Some(REMOTE_CONTROL_ACCOUNT_ID_RETRY_INTERVAL);
|
||||
}
|
||||
ConnectedServerTokenRefreshAction::End(connection_end_reason) => {
|
||||
break connection_end_reason;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = join_set.join_next() => ConnectionEndReason::ConnectionWorkerStopped,
|
||||
};
|
||||
clear_pairing_client(&self.pairing);
|
||||
shutdown_token.cancel();
|
||||
|
||||
Self::join_connection_workers(&mut join_set, REMOTE_CONTROL_CONNECTION_SHUTDOWN_TIMEOUT)
|
||||
@@ -740,6 +859,209 @@ impl RemoteControlWebsocket {
|
||||
connection_end_reason
|
||||
}
|
||||
|
||||
async fn refresh_connected_server_token_while_connected(
|
||||
&mut self,
|
||||
enabled_rx: &mut watch::Receiver<bool>,
|
||||
join_set: &mut tokio::task::JoinSet<()>,
|
||||
shutdown_token: &CancellationToken,
|
||||
app_server_client_name: Option<&str>,
|
||||
) -> ConnectedServerTokenRefreshAction {
|
||||
let refresh_request = match self.connected_server_token_refresh_request() {
|
||||
Ok(refresh_request) => refresh_request,
|
||||
Err(err) => return ConnectedServerTokenRefreshAction::Retry(err),
|
||||
};
|
||||
let refresh_request = tokio::select! {
|
||||
_ = shutdown_token.cancelled() => {
|
||||
return ConnectedServerTokenRefreshAction::End(ConnectionEndReason::Shutdown);
|
||||
}
|
||||
changed = enabled_rx.wait_for(|enabled| !*enabled) => {
|
||||
if changed.is_ok() {
|
||||
self.status_publisher
|
||||
.publish_status(RemoteControlConnectionStatus::Disabled);
|
||||
return ConnectedServerTokenRefreshAction::End(ConnectionEndReason::Disabled);
|
||||
}
|
||||
return ConnectedServerTokenRefreshAction::End(ConnectionEndReason::EnabledWatchClosed);
|
||||
}
|
||||
_ = join_set.join_next() => {
|
||||
return ConnectedServerTokenRefreshAction::End(ConnectionEndReason::ConnectionWorkerStopped);
|
||||
}
|
||||
refresh_request = prepare_connected_server_token_refresh(refresh_request) => refresh_request,
|
||||
};
|
||||
let refresh_request = match refresh_request {
|
||||
Ok(refresh_request) => refresh_request,
|
||||
Err(err) => {
|
||||
return self
|
||||
.apply_connected_server_token_refresh(Err(err), app_server_client_name)
|
||||
.await;
|
||||
}
|
||||
};
|
||||
if !mark_connected_refresh_auth_change_seen(
|
||||
&mut self.auth_change_rx,
|
||||
refresh_request.auth_change_revision,
|
||||
) {
|
||||
clear_pairing_client(&self.pairing);
|
||||
self.auth_recovery = self.auth_manager.unauthorized_recovery();
|
||||
return ConnectedServerTokenRefreshAction::End(ConnectionEndReason::AuthChanged);
|
||||
}
|
||||
let refresh_result = tokio::select! {
|
||||
_ = shutdown_token.cancelled() => {
|
||||
return ConnectedServerTokenRefreshAction::End(ConnectionEndReason::Shutdown);
|
||||
}
|
||||
changed = enabled_rx.wait_for(|enabled| !*enabled) => {
|
||||
if changed.is_ok() {
|
||||
self.status_publisher
|
||||
.publish_status(RemoteControlConnectionStatus::Disabled);
|
||||
return ConnectedServerTokenRefreshAction::End(ConnectionEndReason::Disabled);
|
||||
}
|
||||
return ConnectedServerTokenRefreshAction::End(ConnectionEndReason::EnabledWatchClosed);
|
||||
}
|
||||
changed = self.auth_change_rx.changed() => {
|
||||
if changed.is_ok() {
|
||||
self.auth_recovery = self.auth_manager.unauthorized_recovery();
|
||||
return ConnectedServerTokenRefreshAction::End(ConnectionEndReason::AuthChanged);
|
||||
}
|
||||
return ConnectedServerTokenRefreshAction::End(ConnectionEndReason::AuthWatchClosed);
|
||||
}
|
||||
_ = join_set.join_next() => {
|
||||
return ConnectedServerTokenRefreshAction::End(ConnectionEndReason::ConnectionWorkerStopped);
|
||||
}
|
||||
refresh_result = refresh_connected_server_token(refresh_request) => refresh_result,
|
||||
};
|
||||
self.apply_connected_server_token_refresh(refresh_result, app_server_client_name)
|
||||
.await
|
||||
}
|
||||
|
||||
fn connected_server_token_refresh_request(
|
||||
&self,
|
||||
) -> io::Result<PendingConnectedServerTokenRefreshRequest> {
|
||||
let remote_control_target = self.remote_control_target.clone().ok_or_else(|| {
|
||||
io::Error::other("missing remote control target while refreshing server token")
|
||||
})?;
|
||||
let enrollment = self.enrollment.clone().ok_or_else(|| {
|
||||
io::Error::other("missing remote control enrollment while refreshing server token")
|
||||
})?;
|
||||
Ok(PendingConnectedServerTokenRefreshRequest {
|
||||
remote_control_target,
|
||||
auth_manager: self.auth_manager.clone(),
|
||||
auth_change_rx: self.auth_change_rx.clone(),
|
||||
installation_id: self.installation_id.clone(),
|
||||
enrollment,
|
||||
})
|
||||
}
|
||||
|
||||
async fn apply_connected_server_token_refresh(
|
||||
&mut self,
|
||||
refresh_result: io::Result<ConnectedServerTokenRefreshResponse>,
|
||||
app_server_client_name: Option<&str>,
|
||||
) -> ConnectedServerTokenRefreshAction {
|
||||
let refresh_response = match refresh_result {
|
||||
Ok(refresh_response) => refresh_response,
|
||||
Err(err) if err.kind() == ErrorKind::NotFound => {
|
||||
self.clear_stale_connected_enrollment(app_server_client_name)
|
||||
.await;
|
||||
return ConnectedServerTokenRefreshAction::End(
|
||||
ConnectionEndReason::StaleEnrollment,
|
||||
);
|
||||
}
|
||||
Err(err) if err.kind() == ErrorKind::PermissionDenied => {
|
||||
if recover_remote_control_auth(&mut self.auth_recovery, &mut self.auth_change_rx)
|
||||
.await
|
||||
{
|
||||
return ConnectedServerTokenRefreshAction::Retry(io::Error::other(format!(
|
||||
"{err}; retrying after auth recovery"
|
||||
)));
|
||||
}
|
||||
if let Some(enrollment) = self.enrollment.as_mut() {
|
||||
enrollment.clear_server_token();
|
||||
}
|
||||
clear_pairing_client(&self.pairing);
|
||||
return ConnectedServerTokenRefreshAction::End(
|
||||
ConnectionEndReason::ServerTokenRefreshRejected,
|
||||
);
|
||||
}
|
||||
Err(err) if err.kind() == ErrorKind::InvalidInput => {
|
||||
clear_pairing_client(&self.pairing);
|
||||
self.auth_recovery = self.auth_manager.unauthorized_recovery();
|
||||
return ConnectedServerTokenRefreshAction::End(ConnectionEndReason::AuthChanged);
|
||||
}
|
||||
Err(err) => return ConnectedServerTokenRefreshAction::Retry(err),
|
||||
};
|
||||
if *self.auth_change_rx.borrow() != refresh_response.auth_change_revision {
|
||||
clear_pairing_client(&self.pairing);
|
||||
self.auth_recovery = self.auth_manager.unauthorized_recovery();
|
||||
return ConnectedServerTokenRefreshAction::End(ConnectionEndReason::AuthChanged);
|
||||
}
|
||||
let current_enrollment = match self.enrollment.as_ref() {
|
||||
Some(current_enrollment) => current_enrollment,
|
||||
None => return ConnectedServerTokenRefreshAction::Continue,
|
||||
};
|
||||
if !same_remote_control_enrollment_identity(
|
||||
current_enrollment,
|
||||
&refresh_response.enrollment,
|
||||
) {
|
||||
return ConnectedServerTokenRefreshAction::Continue;
|
||||
}
|
||||
self.enrollment = Some(refresh_response.enrollment);
|
||||
let Some(remote_control_target) = self.remote_control_target.as_ref() else {
|
||||
return ConnectedServerTokenRefreshAction::Retry(io::Error::other(
|
||||
"missing remote control target after refreshing server token",
|
||||
));
|
||||
};
|
||||
let Some(enrollment) = self.enrollment.as_ref() else {
|
||||
return ConnectedServerTokenRefreshAction::Retry(io::Error::other(
|
||||
"missing remote control enrollment after refreshing server token",
|
||||
));
|
||||
};
|
||||
match set_pairing_client(
|
||||
&self.pairing,
|
||||
remote_control_target,
|
||||
enrollment,
|
||||
refresh_response.auth_change_revision,
|
||||
) {
|
||||
Ok(()) => ConnectedServerTokenRefreshAction::Continue,
|
||||
Err(err) => ConnectedServerTokenRefreshAction::Retry(err),
|
||||
}
|
||||
}
|
||||
|
||||
async fn clear_stale_connected_enrollment(&mut self, app_server_client_name: Option<&str>) {
|
||||
let Some(remote_control_target) = self.remote_control_target.as_ref() else {
|
||||
clear_pairing_client(&self.pairing);
|
||||
self.enrollment = None;
|
||||
self.status_publisher
|
||||
.publish_environment_id(/*environment_id*/ None);
|
||||
return;
|
||||
};
|
||||
let Some(account_id) = self
|
||||
.enrollment
|
||||
.as_ref()
|
||||
.map(|enrollment| enrollment.account_id.clone())
|
||||
else {
|
||||
clear_pairing_client(&self.pairing);
|
||||
return;
|
||||
};
|
||||
info!(
|
||||
"connected remote control server refresh returned HTTP 404; clearing stale enrollment before re-enrolling: websocket_url={}, account_id={}",
|
||||
remote_control_target.websocket_url, account_id
|
||||
);
|
||||
if let Some(state_db) = self.state_db.as_deref() {
|
||||
clear_remote_control_enrollment(
|
||||
state_db,
|
||||
remote_control_target,
|
||||
&account_id,
|
||||
app_server_client_name,
|
||||
&mut self.enrollment,
|
||||
&self.status_publisher,
|
||||
&self.pairing,
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
self.enrollment = None;
|
||||
self.status_publisher
|
||||
.publish_environment_id(/*environment_id*/ None);
|
||||
clear_pairing_client(&self.pairing);
|
||||
}
|
||||
|
||||
async fn join_connection_workers(
|
||||
join_set: &mut tokio::task::JoinSet<()>,
|
||||
shutdown_timeout: std::time::Duration,
|
||||
@@ -1166,6 +1488,73 @@ fn build_remote_control_websocket_request(
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
async fn refresh_connected_server_token(
|
||||
mut refresh_request: ConnectedServerTokenRefreshRequest,
|
||||
) -> io::Result<ConnectedServerTokenRefreshResponse> {
|
||||
info!(
|
||||
"refreshing connected remote control server token: websocket_url={}, refresh_url={}, account_id={}, server_id={}, environment_id={}",
|
||||
refresh_request.remote_control_target.websocket_url,
|
||||
refresh_request.remote_control_target.refresh_url,
|
||||
refresh_request.auth.account_id,
|
||||
refresh_request.enrollment.server_id,
|
||||
refresh_request.enrollment.environment_id
|
||||
);
|
||||
refresh_remote_control_server(
|
||||
&refresh_request.remote_control_target,
|
||||
&refresh_request.auth,
|
||||
&refresh_request.installation_id,
|
||||
&mut refresh_request.enrollment,
|
||||
)
|
||||
.await?;
|
||||
Ok(ConnectedServerTokenRefreshResponse {
|
||||
enrollment: refresh_request.enrollment,
|
||||
auth_change_revision: refresh_request.auth_change_revision,
|
||||
})
|
||||
}
|
||||
|
||||
async fn prepare_connected_server_token_refresh(
|
||||
mut refresh_request: PendingConnectedServerTokenRefreshRequest,
|
||||
) -> io::Result<ConnectedServerTokenRefreshRequest> {
|
||||
let auth = load_remote_control_auth(&refresh_request.auth_manager).await?;
|
||||
// Loading auth may reload or proactively refresh through the same watch
|
||||
// receiver. Treat the auth used for this refresh as the current revision
|
||||
// before waiting for later auth changes to cancel the HTTP request.
|
||||
let auth_change_revision = *refresh_request.auth_change_rx.borrow_and_update();
|
||||
if refresh_request.enrollment.account_id != auth.account_id {
|
||||
return Err(io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
"remote control auth changed while refreshing server token",
|
||||
));
|
||||
}
|
||||
Ok(ConnectedServerTokenRefreshRequest {
|
||||
remote_control_target: refresh_request.remote_control_target,
|
||||
auth,
|
||||
installation_id: refresh_request.installation_id,
|
||||
enrollment: refresh_request.enrollment,
|
||||
auth_change_revision,
|
||||
})
|
||||
}
|
||||
|
||||
fn mark_connected_refresh_auth_change_seen(
|
||||
auth_change_rx: &mut watch::Receiver<u64>,
|
||||
auth_change_revision: u64,
|
||||
) -> bool {
|
||||
if *auth_change_rx.borrow() != auth_change_revision {
|
||||
return false;
|
||||
}
|
||||
auth_change_rx.borrow_and_update();
|
||||
true
|
||||
}
|
||||
|
||||
fn same_remote_control_enrollment_identity(
|
||||
left: &RemoteControlEnrollment,
|
||||
right: &RemoteControlEnrollment,
|
||||
) -> bool {
|
||||
left.account_id == right.account_id
|
||||
&& left.server_id == right.server_id
|
||||
&& left.environment_id == right.environment_id
|
||||
}
|
||||
|
||||
pub(crate) async fn load_remote_control_auth(
|
||||
auth_manager: &Arc<AuthManager>,
|
||||
) -> io::Result<RemoteControlConnectionAuth> {
|
||||
@@ -1229,6 +1618,7 @@ pub(super) async fn connect_remote_control_websocket(
|
||||
enrollment: &mut Option<RemoteControlEnrollment>,
|
||||
connect_options: RemoteControlConnectOptions<'_>,
|
||||
status_publisher: &RemoteControlStatusPublisher,
|
||||
pairing: &PairingClientState,
|
||||
) -> io::Result<(
|
||||
WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
tungstenite::http::Response<()>,
|
||||
@@ -1237,6 +1627,7 @@ pub(super) async fn connect_remote_control_websocket(
|
||||
|
||||
let Some(state_db) = state_db else {
|
||||
*enrollment = None;
|
||||
clear_pairing_client(pairing);
|
||||
return Err(io::Error::new(
|
||||
ErrorKind::NotFound,
|
||||
"remote control requires sqlite state db",
|
||||
@@ -1249,10 +1640,15 @@ pub(super) async fn connect_remote_control_websocket(
|
||||
if err.kind() == ErrorKind::PermissionDenied {
|
||||
*enrollment = None;
|
||||
status_publisher.publish_environment_id(/*environment_id*/ None);
|
||||
clear_pairing_client(pairing);
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
// Loading auth may reload or proactively refresh through the same watch
|
||||
// receiver. Treat the auth used for this connection as the current
|
||||
// revision before publishing pairing auth for it.
|
||||
let auth_change_revision = *auth_context.auth_change_rx.borrow_and_update();
|
||||
let enrollment_account_id = enrollment.as_ref().map(|enrollment| &enrollment.account_id);
|
||||
if enrollment_account_id.is_some_and(|account_id| account_id != &auth.account_id) {
|
||||
info!(
|
||||
@@ -1265,6 +1661,7 @@ pub(super) async fn connect_remote_control_websocket(
|
||||
);
|
||||
*enrollment = None;
|
||||
status_publisher.publish_environment_id(/*environment_id*/ None);
|
||||
clear_pairing_client(pairing);
|
||||
}
|
||||
|
||||
if let Some(enrollment) = enrollment.as_ref() {
|
||||
@@ -1342,6 +1739,7 @@ pub(super) async fn connect_remote_control_websocket(
|
||||
connect_options.app_server_client_name,
|
||||
enrollment,
|
||||
status_publisher,
|
||||
pairing,
|
||||
)
|
||||
.await;
|
||||
enroll_remote_control_server_if_missing(
|
||||
@@ -1397,8 +1795,17 @@ pub(super) async fn connect_remote_control_websocket(
|
||||
})?;
|
||||
|
||||
match websocket_connect_result {
|
||||
Ok((websocket_stream, response)) => Ok((websocket_stream, response.map(|_| ()))),
|
||||
Ok((websocket_stream, response)) => {
|
||||
set_pairing_client(
|
||||
pairing,
|
||||
remote_control_target,
|
||||
enrollment_ref,
|
||||
auth_change_revision,
|
||||
)?;
|
||||
Ok((websocket_stream, response.map(|_| ())))
|
||||
}
|
||||
Err(err) => {
|
||||
clear_pairing_client(pairing);
|
||||
match &err {
|
||||
tungstenite::Error::Http(response) if response.status().as_u16() == 404 => {
|
||||
info!(
|
||||
@@ -1415,6 +1822,7 @@ pub(super) async fn connect_remote_control_websocket(
|
||||
connect_options.app_server_client_name,
|
||||
enrollment,
|
||||
status_publisher,
|
||||
pairing,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
@@ -1429,6 +1837,7 @@ pub(super) async fn connect_remote_control_websocket(
|
||||
)
|
||||
})?
|
||||
.clear_server_token();
|
||||
clear_pairing_client(pairing);
|
||||
return Err(io::Error::other(format!(
|
||||
"remote control websocket auth failed with HTTP {}; refreshing server token before reconnect",
|
||||
response.status()
|
||||
@@ -1453,6 +1862,7 @@ async fn clear_remote_control_enrollment(
|
||||
app_server_client_name: Option<&str>,
|
||||
enrollment: &mut Option<RemoteControlEnrollment>,
|
||||
status_publisher: &RemoteControlStatusPublisher,
|
||||
pairing: &PairingClientState,
|
||||
) {
|
||||
if let Err(clear_err) = update_persisted_remote_control_enrollment(
|
||||
Some(state_db),
|
||||
@@ -1467,6 +1877,42 @@ async fn clear_remote_control_enrollment(
|
||||
}
|
||||
*enrollment = None;
|
||||
status_publisher.publish_environment_id(/*environment_id*/ None);
|
||||
clear_pairing_client(pairing);
|
||||
}
|
||||
|
||||
fn set_pairing_client(
|
||||
pairing: &PairingClientState,
|
||||
remote_control_target: &RemoteControlTarget,
|
||||
enrollment: &RemoteControlEnrollment,
|
||||
auth_change_revision: u64,
|
||||
) -> io::Result<()> {
|
||||
let remote_control_token = enrollment.remote_control_token.clone().ok_or_else(|| {
|
||||
io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
"remote control pairing is unavailable until enrollment completes",
|
||||
)
|
||||
})?;
|
||||
let expires_at = enrollment.expires_at.ok_or_else(|| {
|
||||
io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
"remote control pairing is unavailable until enrollment completes",
|
||||
)
|
||||
})?;
|
||||
let generation = pairing.generation.load(Ordering::Relaxed);
|
||||
*pairing
|
||||
.client
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner) =
|
||||
Some(RemoteControlPairingClient::new(
|
||||
remote_control_target,
|
||||
remote_control_token,
|
||||
enrollment.server_id.clone(),
|
||||
enrollment.environment_id.clone(),
|
||||
expires_at,
|
||||
auth_change_revision,
|
||||
generation,
|
||||
));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn enroll_remote_control_server_if_missing(
|
||||
@@ -1657,6 +2103,10 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn test_pairing() -> PairingClientState {
|
||||
PairingClientState::new()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn next_reconnect_delay_resets_after_cap() {
|
||||
let mut reconnect_attempt = 9;
|
||||
@@ -1721,6 +2171,41 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mark_connected_refresh_auth_change_seen_marks_loaded_auth_revision_seen() {
|
||||
let (auth_change_tx, mut auth_change_rx) = watch::channel(0u64);
|
||||
auth_change_tx.send_modify(|revision| *revision += 1);
|
||||
let auth_change_revision = *auth_change_rx.borrow();
|
||||
|
||||
assert!(mark_connected_refresh_auth_change_seen(
|
||||
&mut auth_change_rx,
|
||||
auth_change_revision
|
||||
));
|
||||
assert!(
|
||||
!auth_change_rx
|
||||
.has_changed()
|
||||
.expect("auth change watch should remain open")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mark_connected_refresh_auth_change_seen_preserves_racing_auth_change() {
|
||||
let (auth_change_tx, mut auth_change_rx) = watch::channel(0u64);
|
||||
auth_change_tx.send_modify(|revision| *revision += 1);
|
||||
let auth_change_revision = *auth_change_rx.borrow();
|
||||
auth_change_tx.send_modify(|revision| *revision += 1);
|
||||
|
||||
assert!(!mark_connected_refresh_auth_change_seen(
|
||||
&mut auth_change_rx,
|
||||
auth_change_revision
|
||||
));
|
||||
assert!(
|
||||
auth_change_rx
|
||||
.has_changed()
|
||||
.expect("auth change watch should remain open")
|
||||
);
|
||||
}
|
||||
|
||||
async fn remote_control_state_runtime(codex_home: &TempDir) -> Arc<StateRuntime> {
|
||||
StateRuntime::init(codex_home.path().to_path_buf(), "test-provider".to_string())
|
||||
.await
|
||||
@@ -1810,6 +2295,7 @@ mod tests {
|
||||
let mut enrollment = Some(remote_control_enrollment(Some(
|
||||
TEST_REMOTE_CONTROL_SERVER_TOKEN,
|
||||
)));
|
||||
let pairing = test_pairing();
|
||||
let (status_publisher, status_rx) = remote_control_status_channel();
|
||||
|
||||
let err = match connect_remote_control_websocket(
|
||||
@@ -1828,6 +2314,7 @@ mod tests {
|
||||
app_server_client_name: None,
|
||||
},
|
||||
&status_publisher,
|
||||
&pairing,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -1837,6 +2324,13 @@ mod tests {
|
||||
|
||||
server_task.await.expect("server task should succeed");
|
||||
assert_eq!(err.to_string(), expected_error);
|
||||
assert!(
|
||||
pairing
|
||||
.client
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.is_none()
|
||||
);
|
||||
assert_eq!(
|
||||
status_rx.borrow().clone(),
|
||||
RemoteControlStatusChangedNotification {
|
||||
@@ -1864,6 +2358,7 @@ mod tests {
|
||||
let mut enrollment = Some(remote_control_enrollment(Some(
|
||||
TEST_REMOTE_CONTROL_SERVER_TOKEN,
|
||||
)));
|
||||
let pairing = test_pairing();
|
||||
let (status_publisher, status_rx) = remote_control_status_channel();
|
||||
|
||||
let server_task = tokio::spawn(async move {
|
||||
@@ -1891,6 +2386,7 @@ mod tests {
|
||||
app_server_client_name: None,
|
||||
},
|
||||
&status_publisher,
|
||||
&pairing,
|
||||
)
|
||||
.await
|
||||
.expect_err("unauthorized response should fail the websocket connect");
|
||||
@@ -1915,6 +2411,14 @@ mod tests {
|
||||
/*remote_control_token*/ None
|
||||
))
|
||||
);
|
||||
assert_eq!(
|
||||
pairing
|
||||
.client
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.is_none(),
|
||||
true
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -1976,6 +2480,7 @@ mod tests {
|
||||
app_server_client_name: None,
|
||||
},
|
||||
&status_publisher,
|
||||
&test_pairing(),
|
||||
)
|
||||
.await
|
||||
.expect_err("unauthorized enrollment should fail the websocket connect");
|
||||
@@ -2070,6 +2575,7 @@ mod tests {
|
||||
app_server_client_name: None,
|
||||
},
|
||||
&status_publisher,
|
||||
&test_pairing(),
|
||||
)
|
||||
.await
|
||||
.expect_err("unauthorized refresh should fail the websocket connect");
|
||||
@@ -2135,6 +2641,7 @@ mod tests {
|
||||
app_server_client_name: None,
|
||||
},
|
||||
&status_publisher,
|
||||
&test_pairing(),
|
||||
)
|
||||
.await
|
||||
.expect_err("missing sqlite state db should fail remote control");
|
||||
@@ -2185,6 +2692,7 @@ mod tests {
|
||||
app_server_client_name: None,
|
||||
},
|
||||
&status_publisher,
|
||||
&test_pairing(),
|
||||
)
|
||||
.await
|
||||
.expect_err("missing auth should fail remote control");
|
||||
@@ -2221,6 +2729,7 @@ mod tests {
|
||||
let (status_publisher, _status_rx) = remote_control_status_channel();
|
||||
let shutdown_token = CancellationToken::new();
|
||||
let (_enabled_tx, enabled_rx) = watch::channel(true);
|
||||
let (pairing_refresh_tx, pairing_refresh_rx) = watch::channel(0u64);
|
||||
let websocket_task = tokio::spawn({
|
||||
let shutdown_token = shutdown_token.clone();
|
||||
async move {
|
||||
@@ -2236,6 +2745,9 @@ mod tests {
|
||||
RemoteControlChannels {
|
||||
transport_event_tx,
|
||||
status_publisher,
|
||||
pairing: test_pairing(),
|
||||
pairing_refresh_tx,
|
||||
pairing_refresh_rx,
|
||||
},
|
||||
shutdown_token,
|
||||
enabled_rx,
|
||||
|
||||
Reference in New Issue
Block a user