Compare commits

...

5 Commits

Author SHA1 Message Date
Anton Panasenko
da437fd1f9 fix(app-server-transport): schedule pairing token refresh 2026-05-29 04:19:29 -07:00
Anton Panasenko
84529b9f5f fix(app-server-transport): keep pairing auth current 2026-05-29 04:19:22 -07:00
Anton Panasenko
413ae91e53 fix(app-server-transport): clear stale pairing client 2026-05-29 04:13:39 -07:00
Anton Panasenko
19dd22aebf test(app-server-transport): cover expired pairing token 2026-05-29 04:12:53 -07:00
Anton Panasenko
f73c101e1e feat(app-server-transport): add remote control pairing transport 2026-05-29 01:31:50 -07:00
8 changed files with 1067 additions and 5 deletions

View File

@@ -44,6 +44,24 @@ pub struct RemoteControlStatusReadResponse {
pub environment_id: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq, Eq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
pub struct RemoteControlPairingStartParams {
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
pub manual_code: bool,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
pub struct RemoteControlPairingStartResponse {
pub pairing_code: String,
pub manual_pairing_code: Option<String>,
pub environment_id: String,
pub expires_at: i64,
}
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(rename_all = "camelCase", export_to = "v2/")]

View File

@@ -46,6 +46,13 @@ impl RemoteControlEnrollment {
})
}
pub(super) fn server_token_refresh_delay(&self) -> Option<std::time::Duration> {
let refresh_at = self.expires_at?
- time::Duration::seconds(REMOTE_CONTROL_SERVER_TOKEN_REFRESH_SKEW_SECS);
let refresh_delay = refresh_at - OffsetDateTime::now_utc();
Some(std::time::Duration::try_from(refresh_delay).unwrap_or_default())
}
pub(super) fn clear_server_token(&mut self) {
self.remote_control_token = None;
self.expires_at = None;
@@ -211,8 +218,14 @@ fn redact_remote_control_response_body(body: &str) -> String {
let Some(body_object) = body_json.as_object_mut() else {
return body.to_string();
};
if let Some(remote_control_token) = body_object.get_mut("remote_control_token") {
*remote_control_token = serde_json::Value::String("<redacted>".to_string());
for sensitive_field in [
"remote_control_token",
"pairing_code",
"manual_pairing_code",
] {
if let Some(value) = body_object.get_mut(sensitive_field) {
*value = serde_json::Value::String("<redacted>".to_string());
}
}
body_json.to_string()
}
@@ -429,6 +442,22 @@ mod tests {
assert!(!expires_later.should_refresh_server_token());
}
#[test]
fn remote_control_enrollment_schedules_server_token_refresh_before_expiry() {
let refresh_delay = RemoteControlEnrollment {
account_id: "account_id".to_string(),
environment_id: "environment_id".to_string(),
server_id: "server_id".to_string(),
server_name: "server_name".to_string(),
remote_control_token: Some("remote-control-token".to_string()),
expires_at: Some(OffsetDateTime::now_utc() + time::Duration::seconds(31)),
}
.server_token_refresh_delay()
.expect("server token refresh should be scheduled");
assert!(refresh_delay <= std::time::Duration::from_secs(1));
}
#[test]
fn preview_remote_control_response_body_redacts_server_token() {
assert_eq!(
@@ -443,6 +472,20 @@ mod tests {
);
}
#[test]
fn preview_remote_control_response_body_redacts_pairing_codes() {
assert_eq!(
serde_json::from_str::<serde_json::Value>(&preview_remote_control_response_body(
br#"{"pairing_code":"pairing-code","manual_pairing_code":"ABCD-EFGH"}"#
))
.expect("redacted response preview should stay valid json"),
json!({
"pairing_code": "<redacted>",
"manual_pairing_code": "<redacted>",
})
);
}
#[tokio::test]
async fn persisted_remote_control_enrollment_round_trips_by_target_and_account() {
let codex_home = TempDir::new().expect("temp dir should create");

View File

@@ -1,9 +1,11 @@
mod client_tracker;
mod enroll;
mod pairing;
mod protocol;
mod segment;
mod websocket;
use self::pairing::RemoteControlPairingClient;
use crate::transport::remote_control::websocket::RemoteControlChannels;
use crate::transport::remote_control::websocket::RemoteControlStatusPublisher;
use crate::transport::remote_control::websocket::RemoteControlWebsocket;
@@ -16,6 +18,8 @@ use super::CHANNEL_CAPACITY;
use super::TransportEvent;
use super::next_connection_id;
use codex_app_server_protocol::RemoteControlConnectionStatus;
use codex_app_server_protocol::RemoteControlPairingStartParams;
use codex_app_server_protocol::RemoteControlPairingStartResponse;
use codex_app_server_protocol::RemoteControlStatusChangedNotification;
use codex_login::AuthManager;
use codex_state::StateRuntime;
@@ -26,6 +30,7 @@ use std::fmt;
use std::io;
use std::panic::AssertUnwindSafe;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::watch;
@@ -52,6 +57,8 @@ pub struct RemoteControlHandle {
enabled_tx: Arc<watch::Sender<bool>>,
status_tx: Arc<watch::Sender<RemoteControlStatusChangedNotification>>,
state_db_available: bool,
pairing_client: Arc<StdMutex<Option<RemoteControlPairingClient>>>,
auth_change_rx: Arc<StdMutex<watch::Receiver<u64>>>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -108,6 +115,7 @@ impl RemoteControlHandle {
*state = false;
changed
});
clear_pairing_client(&self.pairing_client);
let status = self.status();
info!(
@@ -129,6 +137,61 @@ impl RemoteControlHandle {
self.status_tx.subscribe()
}
pub async fn start_pairing(
&self,
params: RemoteControlPairingStartParams,
) -> io::Result<RemoteControlPairingStartResponse> {
if !*self.enabled_tx.borrow() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"remote control pairing requires remote control to be enabled",
));
}
let auth_change_revision = self.auth_change_revision();
let pairing_client = {
let mut pairing_client = self
.pairing_client
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let current_pairing_client = pairing_client
.as_ref()
.filter(|pairing_client| {
pairing_client.matches_auth_change_revision(auth_change_revision)
})
.cloned();
if current_pairing_client.is_none() {
*pairing_client = None;
}
current_pairing_client
}
.ok_or_else(Self::pairing_unavailable_error)?;
let pairing_response = pairing_client
.start(protocol::StartRemoteControlPairingRequest {
manual_code: params.manual_code,
})
.await;
if self.auth_change_revision() != auth_change_revision {
return Err(Self::pairing_unavailable_error());
}
pairing_response
}
fn auth_change_revision(&self) -> u64 {
*self
.auth_change_rx
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.borrow()
}
fn pairing_unavailable_error() -> io::Error {
io::Error::new(
io::ErrorKind::InvalidInput,
"remote control pairing is unavailable until enrollment completes",
)
}
fn publish_status(
&self,
connection_status: RemoteControlConnectionStatus,
@@ -176,6 +239,12 @@ fn remote_control_status_with_connection_status(
}
}
fn clear_pairing_client(pairing_client: &Arc<StdMutex<Option<RemoteControlPairingClient>>>) {
*pairing_client
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner) = None;
}
pub async fn start_remote_control(
config: RemoteControlStartConfig,
state_db: Option<Arc<StateRuntime>>,
@@ -198,6 +267,9 @@ pub async fn start_remote_control(
};
let (enabled_tx, enabled_rx) = watch::channel(initial_enabled);
let pairing_client = Arc::new(StdMutex::new(None));
let websocket_pairing_client = Arc::clone(&pairing_client);
let auth_change_rx = Arc::new(StdMutex::new(auth_manager.auth_change_receiver()));
let server_name = gethostname().to_string_lossy().trim().to_string();
let remote_control_url = config.remote_control_url;
let installation_id = config.installation_id;
@@ -245,6 +317,7 @@ pub async fn start_remote_control(
RemoteControlChannels {
transport_event_tx,
status_publisher,
pairing_client: websocket_pairing_client,
},
shutdown_token,
enabled_rx,
@@ -289,10 +362,14 @@ pub async fn start_remote_control(
enabled_tx: Arc::new(enabled_tx),
status_tx: Arc::new(status_tx),
state_db_available,
pairing_client,
auth_change_rx,
},
))
}
#[cfg(test)]
mod pairing_tests;
#[cfg(test)]
mod segment_tests;
#[cfg(test)]

View File

@@ -0,0 +1,130 @@
use super::enroll::format_headers;
use super::enroll::preview_remote_control_response_body;
use super::protocol::RemoteControlTarget;
use super::protocol::StartRemoteControlPairingRequest;
use super::protocol::StartRemoteControlPairingResponse;
use codex_app_server_protocol::RemoteControlPairingStartResponse;
use codex_login::default_client::build_reqwest_client;
use std::io;
use std::io::ErrorKind;
use time::OffsetDateTime;
use time::format_description::well_known::Rfc3339;
const REMOTE_CONTROL_PAIRING_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
#[derive(Debug, Clone)]
pub(super) struct RemoteControlPairingClient {
pairing_url: String,
remote_control_token: String,
server_id: String,
environment_id: String,
expires_at: OffsetDateTime,
auth_change_revision: u64,
}
impl RemoteControlPairingClient {
pub(super) fn new(
remote_control_target: &RemoteControlTarget,
remote_control_token: String,
server_id: String,
environment_id: String,
expires_at: OffsetDateTime,
auth_change_revision: u64,
) -> Self {
Self {
pairing_url: remote_control_target.pair_url.clone(),
remote_control_token,
server_id,
environment_id,
expires_at,
auth_change_revision,
}
}
pub(super) fn matches_auth_change_revision(&self, auth_change_revision: u64) -> bool {
self.auth_change_revision == auth_change_revision
}
pub(super) async fn start(
&self,
request: StartRemoteControlPairingRequest,
) -> io::Result<RemoteControlPairingStartResponse> {
if self.expires_at <= OffsetDateTime::now_utc() {
return Err(io::Error::new(
ErrorKind::InvalidInput,
"remote control pairing is unavailable because the server token expired",
));
}
let response = build_reqwest_client()
.post(&self.pairing_url)
.timeout(REMOTE_CONTROL_PAIRING_TIMEOUT)
.bearer_auth(&self.remote_control_token)
.json(&request)
.send()
.await
.map_err(|err| {
io::Error::other(format!(
"failed to start remote control pairing at `{}`: {err}",
self.pairing_url
))
})?;
let headers = response.headers().clone();
let status = response.status();
let body = response.bytes().await.map_err(|err| {
io::Error::other(format!(
"failed to read remote control pairing response from `{}`: {err}",
self.pairing_url
))
})?;
let body_preview = preview_remote_control_response_body(&body);
if !status.is_success() {
return Err(io::Error::other(format!(
"remote control pairing failed at `{}`: HTTP {status}, {}, body: {body_preview}",
self.pairing_url,
format_headers(&headers)
)));
}
let pairing = serde_json::from_slice::<StartRemoteControlPairingResponse>(&body).map_err(
|err| {
io::Error::other(format!(
"failed to parse remote control pairing response from `{}`: HTTP {status}, {}, body: {body_preview}, decode error: {err}",
self.pairing_url,
format_headers(&headers)
))
},
)?;
let StartRemoteControlPairingResponse {
pairing_code,
manual_pairing_code,
server_id,
environment_id,
expires_at,
} = pairing;
if server_id != self.server_id || environment_id != self.environment_id {
return Err(io::Error::new(
ErrorKind::InvalidData,
format!(
"remote control pairing returned mismatched enrollment: expected server_id={}, environment_id={}; got server_id={}, environment_id={}",
self.server_id, self.environment_id, server_id, environment_id
),
));
}
let expires_at = OffsetDateTime::parse(&expires_at, &Rfc3339)
.map_err(|err| {
io::Error::new(
ErrorKind::InvalidData,
format!("invalid remote control pairing expires_at: {err}"),
)
})?
.unix_timestamp();
Ok(RemoteControlPairingStartResponse {
pairing_code,
manual_pairing_code,
environment_id,
expires_at,
})
}
}

View File

@@ -0,0 +1,289 @@
use super::pairing::RemoteControlPairingClient;
use super::protocol::RemoteControlTarget;
use super::protocol::StartRemoteControlPairingRequest;
use codex_app_server_protocol::RemoteControlPairingStartResponse;
use pretty_assertions::assert_eq;
use serde_json::json;
use time::OffsetDateTime;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
use tokio::net::TcpListener;
#[tokio::test]
async fn start_remote_control_pairing_uses_server_token_and_maps_response() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let pair_url = format!(
"http://{}/backend-api/wham/remote/control/server/pair",
listener.local_addr().expect("listener should have addr")
);
let server_task = tokio::spawn(async move {
let (stream, _) = listener.accept().await.expect("request should arrive");
let mut reader = BufReader::new(stream);
let mut request_line = String::new();
reader
.read_line(&mut request_line)
.await
.expect("request line should read");
assert_eq!(
request_line.trim_end(),
"POST /backend-api/wham/remote/control/server/pair HTTP/1.1"
);
let mut authorization = None;
let mut content_length = None;
loop {
let mut line = String::new();
reader
.read_line(&mut line)
.await
.expect("header line should read");
if line == "\r\n" {
break;
}
let (name, value) = line
.trim_end()
.split_once(": ")
.expect("header should split");
match name.to_ascii_lowercase().as_str() {
"authorization" => authorization = Some(value.to_string()),
"content-length" => {
content_length =
Some(value.parse::<usize>().expect("content length should parse"))
}
_ => {}
}
}
assert_eq!(
authorization,
Some("Bearer remote-control-token".to_string())
);
let mut body = vec![0; content_length.expect("request should have body")];
reader
.read_exact(&mut body)
.await
.expect("request body should read");
assert_eq!(
serde_json::from_slice::<serde_json::Value>(&body)
.expect("request body should be json"),
json!({ "manual_code": true })
);
let response_body = json!({
"pairing_code": "pairing-code",
"manual_pairing_code": "ABCD-EFGH",
"server_id": "server-id",
"environment_id": "environment-id",
"expires_at": "3026-05-22T12:34:56Z",
})
.to_string();
reader
.get_mut()
.write_all(
format!(
"HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{response_body}",
response_body.len()
)
.as_bytes(),
)
.await
.expect("response should write");
});
let client = RemoteControlPairingClient::new(
&RemoteControlTarget {
websocket_url: "ws://unused".to_string(),
enroll_url: "http://unused".to_string(),
refresh_url: "http://unused".to_string(),
pair_url,
},
"remote-control-token".to_string(),
"server-id".to_string(),
"environment-id".to_string(),
OffsetDateTime::from_unix_timestamp(33_336_362_096).expect("future timestamp should parse"),
/*auth_change_revision*/ 0,
);
let response = client
.start(StartRemoteControlPairingRequest { manual_code: true })
.await
.expect("pairing should succeed");
server_task.await.expect("server task should finish");
assert_eq!(
response,
RemoteControlPairingStartResponse {
pairing_code: "pairing-code".to_string(),
manual_pairing_code: Some("ABCD-EFGH".to_string()),
environment_id: "environment-id".to_string(),
expires_at: 33_336_362_096,
}
);
}
#[tokio::test]
async fn start_remote_control_pairing_preserves_backend_error_context() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let pair_url = format!(
"http://{}/backend-api/wham/remote/control/server/pair",
listener.local_addr().expect("listener should have addr")
);
let expected_pair_url = pair_url.clone();
let server_task = tokio::spawn(async move {
let (stream, _) = listener.accept().await.expect("request should arrive");
let mut reader = BufReader::new(stream);
loop {
let mut line = String::new();
reader
.read_line(&mut line)
.await
.expect("request line should read");
if line == "\r\n" {
break;
}
}
let response_body = "pairing unavailable";
reader
.get_mut()
.write_all(
format!(
"HTTP/1.1 503 Service Unavailable\r\nx-request-id: request-123\r\ncf-ray: ray-123\r\ncontent-length: {}\r\n\r\n{response_body}",
response_body.len()
)
.as_bytes(),
)
.await
.expect("response should write");
});
let client = RemoteControlPairingClient::new(
&RemoteControlTarget {
websocket_url: "ws://unused".to_string(),
enroll_url: "http://unused".to_string(),
refresh_url: "http://unused".to_string(),
pair_url,
},
"remote-control-token".to_string(),
"server-id".to_string(),
"environment-id".to_string(),
OffsetDateTime::from_unix_timestamp(33_336_362_096).expect("future timestamp should parse"),
/*auth_change_revision*/ 0,
);
let err = client
.start(StartRemoteControlPairingRequest { manual_code: false })
.await
.expect_err("pairing should fail");
server_task.await.expect("server task should finish");
assert_eq!(
err.to_string(),
format!(
"remote control pairing failed at `{expected_pair_url}`: HTTP 503 Service Unavailable, request-id: request-123, cf-ray: ray-123, body: pairing unavailable"
)
);
}
#[tokio::test]
async fn start_remote_control_pairing_rejects_expired_server_token() {
let client = RemoteControlPairingClient::new(
&RemoteControlTarget {
websocket_url: "ws://unused".to_string(),
enroll_url: "http://unused".to_string(),
refresh_url: "http://unused".to_string(),
pair_url: "http://unused".to_string(),
},
"remote-control-token".to_string(),
"server-id".to_string(),
"environment-id".to_string(),
OffsetDateTime::from_unix_timestamp(0).expect("expired timestamp should parse"),
);
let err = client
.start(StartRemoteControlPairingRequest { manual_code: false })
.await
.expect_err("expired server token should fail pairing");
assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
assert_eq!(
err.to_string(),
"remote control pairing is unavailable because the server token expired"
);
}
#[tokio::test]
async fn start_remote_control_pairing_rejects_mismatched_enrollment() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let pair_url = format!(
"http://{}/backend-api/wham/remote/control/server/pair",
listener.local_addr().expect("listener should have addr")
);
let server_task = tokio::spawn(async move {
let (stream, _) = listener.accept().await.expect("request should arrive");
let mut reader = BufReader::new(stream);
loop {
let mut line = String::new();
reader
.read_line(&mut line)
.await
.expect("request line should read");
if line == "\r\n" {
break;
}
}
let response_body = json!({
"pairing_code": "pairing-code",
"manual_pairing_code": null,
"server_id": "other-server-id",
"environment_id": "other-environment-id",
"expires_at": "3026-05-22T12:34:56Z",
})
.to_string();
reader
.get_mut()
.write_all(
format!(
"HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{response_body}",
response_body.len()
)
.as_bytes(),
)
.await
.expect("response should write");
});
let client = RemoteControlPairingClient::new(
&RemoteControlTarget {
websocket_url: "ws://unused".to_string(),
enroll_url: "http://unused".to_string(),
refresh_url: "http://unused".to_string(),
pair_url,
},
"remote-control-token".to_string(),
"server-id".to_string(),
"environment-id".to_string(),
OffsetDateTime::from_unix_timestamp(33_336_362_096).expect("future timestamp should parse"),
);
let err = client
.start(StartRemoteControlPairingRequest { manual_code: false })
.await
.expect_err("mismatched enrollment should fail pairing");
server_task.await.expect("server task should finish");
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
assert_eq!(
err.to_string(),
"remote control pairing returned mismatched enrollment: expected server_id=server-id, environment_id=environment-id; got server_id=other-server-id, environment_id=other-environment-id"
);
}

View File

@@ -12,6 +12,7 @@ pub(super) struct RemoteControlTarget {
pub(super) websocket_url: String,
pub(super) enroll_url: String,
pub(super) refresh_url: String,
pub(super) pair_url: String,
}
#[derive(Debug, Serialize)]
@@ -37,6 +38,20 @@ pub(super) struct RefreshRemoteServerRequest {
pub(super) installation_id: String,
}
#[derive(Debug, Serialize)]
pub(super) struct StartRemoteControlPairingRequest {
pub(super) manual_code: bool,
}
#[derive(Debug, Deserialize)]
pub(super) struct StartRemoteControlPairingResponse {
pub(super) pairing_code: String,
pub(super) manual_pairing_code: Option<String>,
pub(super) server_id: String,
pub(super) environment_id: String,
pub(super) expires_at: String,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ClientId(pub String);
@@ -189,6 +204,9 @@ pub(super) fn normalize_remote_control_url(
let refresh_url = remote_control_url
.join("wham/remote/control/server/refresh")
.map_err(map_url_parse_error)?;
let pair_url = remote_control_url
.join("wham/remote/control/server/pair")
.map_err(map_url_parse_error)?;
let mut websocket_url = remote_control_url
.join("wham/remote/control/server")
.map_err(map_url_parse_error)?;
@@ -207,6 +225,7 @@ pub(super) fn normalize_remote_control_url(
websocket_url: websocket_url.to_string(),
enroll_url: enroll_url.to_string(),
refresh_url: refresh_url.to_string(),
pair_url: pair_url.to_string(),
})
}
@@ -227,6 +246,8 @@ mod tests {
.to_string(),
refresh_url: "https://chatgpt.com/backend-api/wham/remote/control/server/refresh"
.to_string(),
pair_url: "https://chatgpt.com/backend-api/wham/remote/control/server/pair"
.to_string(),
}
);
assert_eq!(
@@ -242,6 +263,9 @@ mod tests {
refresh_url:
"https://api.chatgpt-staging.com/backend-api/wham/remote/control/server/refresh"
.to_string(),
pair_url:
"https://api.chatgpt-staging.com/backend-api/wham/remote/control/server/pair"
.to_string(),
}
);
}
@@ -258,6 +282,8 @@ mod tests {
.to_string(),
refresh_url: "http://localhost:8080/backend-api/wham/remote/control/server/refresh"
.to_string(),
pair_url: "http://localhost:8080/backend-api/wham/remote/control/server/pair"
.to_string(),
}
);
assert_eq!(
@@ -271,6 +297,8 @@ mod tests {
refresh_url:
"https://localhost:8443/backend-api/wham/remote/control/server/refresh"
.to_string(),
pair_url: "https://localhost:8443/backend-api/wham/remote/control/server/pair"
.to_string(),
}
);
}

View File

@@ -20,6 +20,7 @@ use codex_app_server_protocol::AuthMode;
use codex_app_server_protocol::ConfigWarningNotification;
use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::RemoteControlConnectionStatus;
use codex_app_server_protocol::RemoteControlPairingStartParams;
use codex_app_server_protocol::RemoteControlStatusChangedNotification;
use codex_app_server_protocol::ServerNotification;
use codex_config::types::AuthCredentialsStoreMode;
@@ -39,7 +40,10 @@ use pretty_assertions::assert_eq;
use serde_json::json;
use std::collections::BTreeMap;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use tempfile::TempDir;
use time::OffsetDateTime;
use time::format_description::well_known::Rfc3339;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
@@ -58,6 +62,7 @@ use tokio_tungstenite::tungstenite;
use tokio_util::sync::CancellationToken;
const TEST_INSTALLATION_ID: &str = "11111111-1111-4111-8111-111111111111";
const TEST_REMOTE_CONTROL_URL: &str = "http://127.0.0.1:1/backend-api/wham/remote/control";
const TEST_REMOTE_CONTROL_SERVER_TOKEN: &str = "Remote Control Token";
const TEST_REFRESHED_REMOTE_CONTROL_SERVER_TOKEN: &str = "Refreshed Remote Control Token";
const TEST_REMOTE_CONTROL_SERVER_TOKEN_EXPIRES_AT: &str = "2999-01-01T00:00:00Z";
@@ -128,16 +133,59 @@ fn test_server_name() -> String {
gethostname().to_string_lossy().trim().to_string()
}
fn remote_control_handle_with_pairing_client(
remote_control_url: &str,
auth_change_rx: watch::Receiver<u64>,
) -> RemoteControlHandle {
let (enabled_tx, _enabled_rx) = watch::channel(/*init*/ true);
let (status_tx, _status_rx) = watch::channel(RemoteControlStatusChangedNotification {
status: RemoteControlConnectionStatus::Connected,
server_name: test_server_name(),
installation_id: TEST_INSTALLATION_ID.to_string(),
environment_id: Some("env_test".to_string()),
});
let pairing_client = Arc::new(StdMutex::new(Some(RemoteControlPairingClient::new(
&normalize_remote_control_url(remote_control_url)
.expect("remote control target should normalize"),
TEST_REMOTE_CONTROL_SERVER_TOKEN.to_string(),
"srv_e_test".to_string(),
"env_test".to_string(),
OffsetDateTime::from_unix_timestamp(33_336_362_096).expect("future timestamp should parse"),
/*auth_change_revision*/ 0,
))));
RemoteControlHandle {
enabled_tx: Arc::new(enabled_tx),
status_tx: Arc::new(status_tx),
state_db_available: true,
pairing_client,
auth_change_rx: Arc::new(StdMutex::new(auth_change_rx)),
}
}
fn remote_control_server_token_response(
server_id: &str,
environment_id: &str,
remote_control_token: &str,
) -> serde_json::Value {
remote_control_server_token_response_with_expiry(
server_id,
environment_id,
remote_control_token,
TEST_REMOTE_CONTROL_SERVER_TOKEN_EXPIRES_AT,
)
}
fn remote_control_server_token_response_with_expiry(
server_id: &str,
environment_id: &str,
remote_control_token: &str,
expires_at: &str,
) -> serde_json::Value {
json!({
"server_id": server_id,
"environment_id": environment_id,
"remote_control_token": remote_control_token,
"expires_at": TEST_REMOTE_CONTROL_SERVER_TOKEN_EXPIRES_AT,
"expires_at": expires_at,
})
}
@@ -895,6 +943,299 @@ async fn remote_control_handle_enable_disable_stops_and_restarts_connections() {
let _ = remote_task.await;
}
#[tokio::test]
async fn remote_control_handle_disable_clears_stale_pairing_client() {
let remote_handle = remote_control_handle_with_pairing_client(
TEST_REMOTE_CONTROL_URL,
watch::channel(/*init*/ 0u64).1,
);
assert_eq!(
remote_handle.disable(),
RemoteControlStatusChangedNotification {
status: RemoteControlConnectionStatus::Disabled,
server_name: test_server_name(),
installation_id: TEST_INSTALLATION_ID.to_string(),
environment_id: None,
}
);
remote_handle.enable().expect("enable should succeed");
assert_eq!(
remote_handle
.start_pairing(RemoteControlPairingStartParams { manual_code: false })
.await
.expect_err("re-enabled remote control should wait for refreshed pairing auth")
.to_string(),
"remote control pairing is unavailable until enrollment completes"
);
}
#[tokio::test]
async fn remote_control_handle_rejects_pairing_client_after_auth_change() {
let (auth_change_tx, auth_change_rx) = watch::channel(/*init*/ 0u64);
let remote_handle =
remote_control_handle_with_pairing_client(TEST_REMOTE_CONTROL_URL, auth_change_rx);
auth_change_tx.send_modify(|revision| *revision += 1);
assert_eq!(
remote_handle
.start_pairing(RemoteControlPairingStartParams::default())
.await
.expect_err("pairing should wait for current-account enrollment")
.to_string(),
"remote control pairing is unavailable until enrollment completes"
);
}
#[tokio::test]
async fn remote_control_handle_discards_pairing_response_after_auth_change() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let remote_control_url = remote_control_url_for_listener(&listener);
let codex_home = TempDir::new().expect("temp dir should create");
save_auth(
codex_home.path(),
&remote_control_auth_dot_json(Some("account_id")),
AuthCredentialsStoreMode::File,
)
.expect("initial auth should save");
let auth_manager = AuthManager::shared(
codex_home.path().to_path_buf(),
/*enable_codex_api_key_env*/ false,
AuthCredentialsStoreMode::File,
/*chatgpt_base_url*/ None,
)
.await;
let remote_handle = remote_control_handle_with_pairing_client(
&remote_control_url,
auth_manager.auth_change_receiver(),
);
let pairing_task = tokio::spawn({
let remote_handle = remote_handle.clone();
async move {
remote_handle
.start_pairing(RemoteControlPairingStartParams::default())
.await
}
});
let pairing_request = accept_http_request(&listener).await;
assert_eq!(
pairing_request.request_line,
"POST /backend-api/wham/remote/control/server/pair HTTP/1.1"
);
assert_eq!(
pairing_request.headers.get("authorization"),
Some(&format!("Bearer {TEST_REMOTE_CONTROL_SERVER_TOKEN}"))
);
assert_eq!(
serde_json::from_str::<serde_json::Value>(&pairing_request.body)
.expect("pairing request body should deserialize"),
json!({ "manual_code": false })
);
save_auth(
codex_home.path(),
&remote_control_auth_dot_json(Some("next_account_id")),
AuthCredentialsStoreMode::File,
)
.expect("next auth should save");
auth_manager.reload().await;
respond_with_json(
pairing_request.stream,
json!({
"pairing_code": "stale-pairing-code",
"manual_pairing_code": "ABCD-EFGH",
"server_id": "srv_e_test",
"environment_id": "env_test",
"expires_at": "3026-05-22T12:34:56Z",
}),
)
.await;
assert_eq!(
pairing_task
.await
.expect("pairing task should join")
.expect_err("stale pairing response should be discarded")
.to_string(),
"remote control pairing is unavailable until enrollment completes"
);
}
#[tokio::test]
async fn remote_control_handle_clears_pairing_client_after_auth_change() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let remote_control_url = remote_control_url_for_listener(&listener);
let codex_home = TempDir::new().expect("temp dir should create");
save_auth(
codex_home.path(),
&remote_control_auth_dot_json(Some("account_id")),
AuthCredentialsStoreMode::File,
)
.expect("initial auth should save");
let auth_manager = AuthManager::shared(
codex_home.path().to_path_buf(),
/*enable_codex_api_key_env*/ false,
AuthCredentialsStoreMode::File,
/*chatgpt_base_url*/ None,
)
.await;
let (transport_event_tx, _transport_event_rx) =
mpsc::channel::<TransportEvent>(CHANNEL_CAPACITY);
let shutdown_token = CancellationToken::new();
let (remote_task, remote_handle) = start_remote_control(
RemoteControlStartConfig {
remote_control_url,
installation_id: TEST_INSTALLATION_ID.to_string(),
},
Some(remote_control_state_runtime(&codex_home).await),
auth_manager.clone(),
transport_event_tx,
shutdown_token.clone(),
/*app_server_client_name_rx*/ None,
/*initial_enabled*/ true,
)
.await
.expect("remote control should start");
let enroll_request = accept_http_request(&listener).await;
respond_with_json(
enroll_request.stream,
remote_control_server_token_response(
"srv_e_initial",
"env_initial",
TEST_REMOTE_CONTROL_SERVER_TOKEN,
),
)
.await;
let mut first_websocket = accept_remote_control_connection(&listener).await;
save_auth(
codex_home.path(),
&remote_control_auth_dot_json(Some("next_account_id")),
AuthCredentialsStoreMode::File,
)
.expect("next auth should save");
auth_manager.reload().await;
expect_remote_control_connection_closed(
&mut first_websocket,
"auth change should close the stale websocket",
)
.await;
assert_eq!(
remote_handle
.start_pairing(RemoteControlPairingStartParams::default())
.await
.expect_err("pairing should wait for current-account enrollment")
.to_string(),
"remote control pairing is unavailable until enrollment completes"
);
let enroll_request = accept_http_request(&listener).await;
assert_eq!(
enroll_request.request_line,
"POST /backend-api/wham/remote/control/server/enroll HTTP/1.1"
);
respond_with_json(
enroll_request.stream,
remote_control_server_token_response(
"srv_e_next",
"env_next",
TEST_REFRESHED_REMOTE_CONTROL_SERVER_TOKEN,
),
)
.await;
let mut second_websocket = accept_remote_control_connection(&listener).await;
second_websocket
.close(None)
.await
.expect("second websocket should close");
shutdown_token.cancel();
let _ = remote_task.await;
}
#[tokio::test]
async fn remote_control_refreshes_server_token_while_connected() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let remote_control_url = remote_control_url_for_listener(&listener);
let codex_home = TempDir::new().expect("temp dir should create");
let expiring_token_expires_at = (OffsetDateTime::now_utc() + time::Duration::seconds(31))
.format(&Rfc3339)
.expect("token expiry should format");
let (transport_event_tx, _transport_event_rx) =
mpsc::channel::<TransportEvent>(CHANNEL_CAPACITY);
let shutdown_token = CancellationToken::new();
let (remote_task, _remote_handle) = start_remote_control(
RemoteControlStartConfig {
remote_control_url,
installation_id: TEST_INSTALLATION_ID.to_string(),
},
Some(remote_control_state_runtime(&codex_home).await),
remote_control_auth_manager(),
transport_event_tx,
shutdown_token.clone(),
/*app_server_client_name_rx*/ None,
/*initial_enabled*/ true,
)
.await
.expect("remote control should start");
let enroll_request = accept_http_request(&listener).await;
respond_with_json(
enroll_request.stream,
remote_control_server_token_response_with_expiry(
"srv_e_test",
"env_test",
TEST_REMOTE_CONTROL_SERVER_TOKEN,
&expiring_token_expires_at,
),
)
.await;
let mut first_websocket = accept_remote_control_connection(&listener).await;
expect_remote_control_connection_closed(
&mut first_websocket,
"server token refresh should close the stale websocket",
)
.await;
let refresh_request = accept_http_request(&listener).await;
assert_eq!(
refresh_request.request_line,
"POST /backend-api/wham/remote/control/server/refresh HTTP/1.1"
);
respond_with_json(
refresh_request.stream,
remote_control_server_token_response(
"srv_e_test",
"env_test",
TEST_REFRESHED_REMOTE_CONTROL_SERVER_TOKEN,
),
)
.await;
let (handshake_request, mut second_websocket) =
accept_remote_control_backend_connection(&listener).await;
assert_eq!(
handshake_request.headers.get("authorization"),
Some(&format!(
"Bearer {TEST_REFRESHED_REMOTE_CONTROL_SERVER_TOKEN}"
))
);
second_websocket
.close(None)
.await
.expect("second websocket should close");
shutdown_token.cancel();
let _ = remote_task.await;
}
#[tokio::test]
async fn remote_control_transport_clears_outgoing_buffer_when_backend_acks() {
let listener = TcpListener::bind("127.0.0.1:0")
@@ -1848,6 +2189,34 @@ async fn accept_remote_control_connection(listener: &TcpListener) -> WebSocketSt
.expect("websocket handshake should succeed")
}
async fn expect_remote_control_connection_closed(
websocket: &mut WebSocketStream<TcpStream>,
timeout_message: &str,
) {
loop {
let frame = timeout(Duration::from_secs(5), websocket.next())
.await
.expect(timeout_message);
let Some(frame) = frame else {
return;
};
let Ok(frame) = frame else {
return;
};
match frame {
tungstenite::Message::Close(_) => return,
tungstenite::Message::Ping(payload) => {
websocket
.send(tungstenite::Message::Pong(payload))
.await
.expect("websocket pong should send");
}
tungstenite::Message::Pong(_) | tungstenite::Message::Frame(_) => {}
frame => panic!("unexpected websocket frame while waiting for close: {frame:?}"),
}
}
}
async fn accept_http_request(listener: &TcpListener) -> CapturedHttpRequest {
let (stream, _) = timeout(Duration::from_secs(5), listener.accept())
.await

View File

@@ -9,7 +9,9 @@ use crate::transport::remote_control::enroll::load_persisted_remote_control_enro
use crate::transport::remote_control::enroll::preview_remote_control_response_body;
use crate::transport::remote_control::enroll::refresh_remote_control_server;
use crate::transport::remote_control::enroll::update_persisted_remote_control_enrollment;
use crate::transport::remote_control::pairing::RemoteControlPairingClient;
use super::clear_pairing_client;
use super::protocol::ClientEnvelope;
use super::protocol::ClientEvent;
use super::protocol::ClientId;
@@ -39,6 +41,7 @@ use std::collections::VecDeque;
use std::io;
use std::io::ErrorKind;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::sync::mpsc;
@@ -251,6 +254,7 @@ pub(crate) struct RemoteControlWebsocket {
enrollment: Option<RemoteControlEnrollment>,
auth_recovery: UnauthorizedRecovery,
auth_change_rx: watch::Receiver<u64>,
pairing_client: Arc<StdMutex<Option<RemoteControlPairingClient>>>,
client_tracker: Arc<Mutex<ClientTracker>>,
state: Arc<Mutex<WebsocketState>>,
server_event_rx: Arc<Mutex<mpsc::Receiver<super::QueuedServerEnvelope>>>,
@@ -282,12 +286,16 @@ enum ConnectionEndReason {
Shutdown,
Disabled,
EnabledWatchClosed,
AuthChanged,
AuthWatchClosed,
ServerTokenRefreshRequired,
ConnectionWorkerStopped,
}
pub(super) struct RemoteControlChannels {
pub(super) transport_event_tx: mpsc::Sender<TransportEvent>,
pub(super) status_publisher: RemoteControlStatusPublisher,
pub(super) pairing_client: Arc<StdMutex<Option<RemoteControlPairingClient>>>,
}
#[derive(Clone)]
@@ -404,6 +412,7 @@ impl RemoteControlWebsocket {
enrollment: None,
auth_recovery,
auth_change_rx,
pairing_client: channels.pairing_client,
client_tracker: Arc::new(Mutex::new(client_tracker)),
state: Arc::new(Mutex::new(WebsocketState {
outbound_buffer,
@@ -611,6 +620,7 @@ impl RemoteControlWebsocket {
&mut self.enrollment,
connect_options,
&self.status_publisher,
&self.pairing_client,
) => connect_result,
};
@@ -696,7 +706,7 @@ impl RemoteControlWebsocket {
}
async fn run_connection(
&self,
&mut self,
websocket_connection: WebSocketStream<MaybeTlsStream<TcpStream>>,
shutdown_token: CancellationToken,
) -> ConnectionEndReason {
@@ -720,6 +730,17 @@ impl RemoteControlWebsocket {
));
let mut enabled_rx = self.enabled_rx.clone();
let server_token_refresh_delay = self
.enrollment
.as_ref()
.and_then(RemoteControlEnrollment::server_token_refresh_delay);
let server_token_refresh = async move {
match server_token_refresh_delay {
Some(delay) => tokio::time::sleep(delay).await,
None => std::future::pending().await,
}
};
tokio::pin!(server_token_refresh);
let connection_end_reason = tokio::select! {
_ = shutdown_token.cancelled() => ConnectionEndReason::Shutdown,
changed = enabled_rx.wait_for(|enabled| !*enabled) => {
@@ -731,8 +752,18 @@ impl RemoteControlWebsocket {
ConnectionEndReason::EnabledWatchClosed
}
}
changed = self.auth_change_rx.changed() => {
if changed.is_ok() {
self.auth_recovery = self.auth_manager.unauthorized_recovery();
ConnectionEndReason::AuthChanged
} else {
ConnectionEndReason::AuthWatchClosed
}
}
_ = &mut server_token_refresh => ConnectionEndReason::ServerTokenRefreshRequired,
_ = join_set.join_next() => ConnectionEndReason::ConnectionWorkerStopped,
};
clear_pairing_client(&self.pairing_client);
shutdown_token.cancel();
Self::join_connection_workers(&mut join_set, REMOTE_CONTROL_CONNECTION_SHUTDOWN_TIMEOUT)
@@ -1229,6 +1260,7 @@ pub(super) async fn connect_remote_control_websocket(
enrollment: &mut Option<RemoteControlEnrollment>,
connect_options: RemoteControlConnectOptions<'_>,
status_publisher: &RemoteControlStatusPublisher,
pairing_client: &Arc<StdMutex<Option<RemoteControlPairingClient>>>,
) -> io::Result<(
WebSocketStream<MaybeTlsStream<TcpStream>>,
tungstenite::http::Response<()>,
@@ -1237,18 +1269,21 @@ pub(super) async fn connect_remote_control_websocket(
let Some(state_db) = state_db else {
*enrollment = None;
clear_pairing_client(pairing_client);
return Err(io::Error::new(
ErrorKind::NotFound,
"remote control requires sqlite state db",
));
};
let auth_change_revision = *auth_context.auth_change_rx.borrow();
let auth = match load_remote_control_auth(auth_context.auth_manager).await {
Ok(auth) => auth,
Err(err) => {
if err.kind() == ErrorKind::PermissionDenied {
*enrollment = None;
status_publisher.publish_environment_id(/*environment_id*/ None);
clear_pairing_client(pairing_client);
}
return Err(err);
}
@@ -1265,6 +1300,7 @@ pub(super) async fn connect_remote_control_websocket(
);
*enrollment = None;
status_publisher.publish_environment_id(/*environment_id*/ None);
clear_pairing_client(pairing_client);
}
if let Some(enrollment) = enrollment.as_ref() {
@@ -1342,6 +1378,7 @@ pub(super) async fn connect_remote_control_websocket(
connect_options.app_server_client_name,
enrollment,
status_publisher,
pairing_client,
)
.await;
enroll_remote_control_server_if_missing(
@@ -1397,8 +1434,17 @@ pub(super) async fn connect_remote_control_websocket(
})?;
match websocket_connect_result {
Ok((websocket_stream, response)) => Ok((websocket_stream, response.map(|_| ()))),
Ok((websocket_stream, response)) => {
set_pairing_client(
pairing_client,
remote_control_target,
enrollment_ref,
auth_change_revision,
)?;
Ok((websocket_stream, response.map(|_| ())))
}
Err(err) => {
clear_pairing_client(pairing_client);
match &err {
tungstenite::Error::Http(response) if response.status().as_u16() == 404 => {
info!(
@@ -1415,6 +1461,7 @@ pub(super) async fn connect_remote_control_websocket(
connect_options.app_server_client_name,
enrollment,
status_publisher,
pairing_client,
)
.await;
}
@@ -1429,6 +1476,7 @@ pub(super) async fn connect_remote_control_websocket(
)
})?
.clear_server_token();
clear_pairing_client(pairing_client);
return Err(io::Error::other(format!(
"remote control websocket auth failed with HTTP {}; refreshing server token before reconnect",
response.status()
@@ -1453,6 +1501,7 @@ async fn clear_remote_control_enrollment(
app_server_client_name: Option<&str>,
enrollment: &mut Option<RemoteControlEnrollment>,
status_publisher: &RemoteControlStatusPublisher,
pairing_client: &Arc<StdMutex<Option<RemoteControlPairingClient>>>,
) {
if let Err(clear_err) = update_persisted_remote_control_enrollment(
Some(state_db),
@@ -1467,6 +1516,39 @@ async fn clear_remote_control_enrollment(
}
*enrollment = None;
status_publisher.publish_environment_id(/*environment_id*/ None);
clear_pairing_client(pairing_client);
}
fn set_pairing_client(
pairing_client: &Arc<StdMutex<Option<RemoteControlPairingClient>>>,
remote_control_target: &RemoteControlTarget,
enrollment: &RemoteControlEnrollment,
auth_change_revision: u64,
) -> io::Result<()> {
let remote_control_token = enrollment.remote_control_token.clone().ok_or_else(|| {
io::Error::new(
ErrorKind::InvalidInput,
"remote control pairing is unavailable until enrollment completes",
)
})?;
let expires_at = enrollment.expires_at.ok_or_else(|| {
io::Error::new(
ErrorKind::InvalidInput,
"remote control pairing is unavailable until enrollment completes",
)
})?;
*pairing_client
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner) =
Some(RemoteControlPairingClient::new(
remote_control_target,
remote_control_token,
enrollment.server_id.clone(),
enrollment.environment_id.clone(),
expires_at,
auth_change_revision,
));
Ok(())
}
async fn enroll_remote_control_server_if_missing(
@@ -1657,6 +1739,10 @@ mod tests {
}
}
fn test_pairing_client() -> Arc<StdMutex<Option<RemoteControlPairingClient>>> {
Arc::new(StdMutex::new(None))
}
#[test]
fn next_reconnect_delay_resets_after_cap() {
let mut reconnect_attempt = 9;
@@ -1810,6 +1896,7 @@ mod tests {
let mut enrollment = Some(remote_control_enrollment(Some(
TEST_REMOTE_CONTROL_SERVER_TOKEN,
)));
let pairing_client = test_pairing_client();
let (status_publisher, status_rx) = remote_control_status_channel();
let err = match connect_remote_control_websocket(
@@ -1828,6 +1915,7 @@ mod tests {
app_server_client_name: None,
},
&status_publisher,
&pairing_client,
)
.await
{
@@ -1837,6 +1925,12 @@ mod tests {
server_task.await.expect("server task should succeed");
assert_eq!(err.to_string(), expected_error);
assert!(
pairing_client
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.is_none()
);
assert_eq!(
status_rx.borrow().clone(),
RemoteControlStatusChangedNotification {
@@ -1864,6 +1958,7 @@ mod tests {
let mut enrollment = Some(remote_control_enrollment(Some(
TEST_REMOTE_CONTROL_SERVER_TOKEN,
)));
let pairing_client = test_pairing_client();
let (status_publisher, status_rx) = remote_control_status_channel();
let server_task = tokio::spawn(async move {
@@ -1891,6 +1986,7 @@ mod tests {
app_server_client_name: None,
},
&status_publisher,
&pairing_client,
)
.await
.expect_err("unauthorized response should fail the websocket connect");
@@ -1915,6 +2011,13 @@ mod tests {
/*remote_control_token*/ None
))
);
assert_eq!(
pairing_client
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.is_none(),
true
);
}
#[tokio::test]
@@ -1976,6 +2079,7 @@ mod tests {
app_server_client_name: None,
},
&status_publisher,
&test_pairing_client(),
)
.await
.expect_err("unauthorized enrollment should fail the websocket connect");
@@ -2070,6 +2174,7 @@ mod tests {
app_server_client_name: None,
},
&status_publisher,
&test_pairing_client(),
)
.await
.expect_err("unauthorized refresh should fail the websocket connect");
@@ -2135,6 +2240,7 @@ mod tests {
app_server_client_name: None,
},
&status_publisher,
&test_pairing_client(),
)
.await
.expect_err("missing sqlite state db should fail remote control");
@@ -2185,6 +2291,7 @@ mod tests {
app_server_client_name: None,
},
&status_publisher,
&test_pairing_client(),
)
.await
.expect_err("missing auth should fail remote control");
@@ -2236,6 +2343,7 @@ mod tests {
RemoteControlChannels {
transport_event_tx,
status_publisher,
pairing_client: test_pairing_client(),
},
shutdown_token,
enabled_rx,