Compare commits

...

3 Commits

Author SHA1 Message Date
viyatb-oai
d7390cbf52 fix: require remote control token bundle
Co-authored-by: Codex noreply@openai.com
2026-04-20 20:03:11 -07:00
viyatb-oai
6539a56c15 fix: preserve remote control enrollment during token renewal
Co-authored-by: Codex noreply@openai.com
2026-04-20 17:14:02 -07:00
viyatb-oai
148c6feca3 Use scoped remote control server tokens 2026-04-20 17:14:02 -07:00
4 changed files with 562 additions and 152 deletions

View File

@@ -2,6 +2,9 @@ use super::protocol::EnrollRemoteServerRequest;
use super::protocol::EnrollRemoteServerResponse;
use super::protocol::RemoteControlTarget;
use axum::http::HeaderMap;
use chrono::DateTime;
use chrono::Duration;
use chrono::Utc;
use codex_login::default_client::build_reqwest_client;
use codex_state::RemoteControlEnrollmentRecord;
use codex_state::StateRuntime;
@@ -13,6 +16,8 @@ use tracing::warn;
const REMOTE_CONTROL_ENROLL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
const REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES: usize = 4096;
const REMOTE_CONTROL_SERVER_WEBSOCKET_SCOPE: &str = "remote_control_server_websocket";
const REMOTE_CONTROL_SERVER_TOKEN_REFRESH_SKEW_SECONDS: i64 = 60;
const REQUEST_ID_HEADER: &str = "x-request-id";
const OAI_REQUEST_ID_HEADER: &str = "x-oai-request-id";
@@ -28,6 +33,24 @@ pub(super) struct RemoteControlEnrollment {
pub(super) server_name: String,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) struct RemoteControlServerToken {
pub(super) bearer_token: String,
pub(super) expires_at: DateTime<Utc>,
}
impl RemoteControlServerToken {
pub(super) fn expires_soon(&self, now: DateTime<Utc>) -> bool {
self.expires_at <= now + Duration::seconds(REMOTE_CONTROL_SERVER_TOKEN_REFRESH_SKEW_SECONDS)
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) struct RemoteControlEnrollmentResult {
pub(super) enrollment: RemoteControlEnrollment,
pub(super) server_token: RemoteControlServerToken,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) struct RemoteControlConnectionAuth {
pub(super) authorization_header_value: String,
@@ -191,7 +214,8 @@ pub(crate) fn format_headers(headers: &HeaderMap) -> String {
pub(super) async fn enroll_remote_control_server(
remote_control_target: &RemoteControlTarget,
auth: &RemoteControlConnectionAuth,
) -> io::Result<RemoteControlEnrollment> {
existing_enrollment: Option<&RemoteControlEnrollment>,
) -> io::Result<RemoteControlEnrollmentResult> {
let enroll_url = &remote_control_target.enroll_url;
let server_name = gethostname().to_string_lossy().trim().to_string();
let request = EnrollRemoteServerRequest {
@@ -199,6 +223,8 @@ pub(super) async fn enroll_remote_control_server(
os: std::env::consts::OS,
arch: std::env::consts::ARCH,
app_server_version: env!("CARGO_PKG_VERSION"),
server_id: existing_enrollment.map(|enrollment| enrollment.server_id.clone()),
environment_id: existing_enrollment.map(|enrollment| enrollment.environment_id.clone()),
};
let client = build_reqwest_client();
let mut http_request = client
@@ -246,11 +272,43 @@ pub(super) async fn enroll_remote_control_server(
))
})?;
Ok(RemoteControlEnrollment {
account_id: auth.account_id.clone(),
environment_id: enrollment.environment_id,
server_id: enrollment.server_id,
server_name,
let server_token = remote_control_server_token_from_response(&enrollment)?;
Ok(RemoteControlEnrollmentResult {
enrollment: RemoteControlEnrollment {
account_id: auth.account_id.clone(),
environment_id: enrollment.environment_id,
server_id: enrollment.server_id,
server_name,
},
server_token,
})
}
fn remote_control_server_token_from_response(
enrollment: &EnrollRemoteServerResponse,
) -> io::Result<RemoteControlServerToken> {
if !enrollment
.scopes
.iter()
.any(|scope| scope == REMOTE_CONTROL_SERVER_WEBSOCKET_SCOPE)
{
return Err(io::Error::new(
ErrorKind::InvalidData,
"remote control enrollment response token is missing server websocket scope",
));
}
let expires_at = DateTime::parse_from_rfc3339(&enrollment.expires_at)
.map_err(|err| {
io::Error::new(
ErrorKind::InvalidData,
format!("invalid remote control token expires_at: {err}"),
)
})?
.with_timezone(&Utc);
Ok(RemoteControlServerToken {
bearer_token: enrollment.remote_control_token.clone(),
expires_at,
})
}
@@ -258,6 +316,8 @@ pub(super) async fn enroll_remote_control_server(
mod tests {
use super::*;
use crate::transport::remote_control::protocol::normalize_remote_control_url;
use chrono::DateTime;
use chrono::Utc;
use codex_state::StateRuntime;
use pretty_assertions::assert_eq;
use serde_json::json;
@@ -423,6 +483,34 @@ mod tests {
);
}
#[test]
fn remote_control_server_token_from_response_parses_scoped_token() {
let server_token = remote_control_server_token_from_response(&EnrollRemoteServerResponse {
server_id: "srv_e_test".to_string(),
environment_id: "env_test".to_string(),
remote_control_token: "remote-control-token".to_string(),
expires_at: "2026-04-09T12:00:00Z".to_string(),
scopes: vec!["remote_control_server_websocket".to_string()],
})
.expect("token response should parse");
assert_eq!(server_token.bearer_token, "remote-control-token");
assert!(
!server_token.expires_soon(
DateTime::parse_from_rfc3339("2026-04-09T11:58:30Z")
.expect("timestamp should parse")
.with_timezone(&Utc)
)
);
assert!(
server_token.expires_soon(
DateTime::parse_from_rfc3339("2026-04-09T11:59:30Z")
.expect("timestamp should parse")
.with_timezone(&Utc)
)
);
}
#[tokio::test]
async fn enroll_remote_control_server_parse_failure_includes_response_body() {
let listener = TcpListener::bind("127.0.0.1:0")
@@ -454,6 +542,7 @@ mod tests {
account_id: "account_id".to_string(),
is_fedramp_account: false,
},
/*existing_enrollment*/ None,
)
.await
.expect_err("invalid response should fail to parse");

View File

@@ -19,12 +19,19 @@ pub(super) struct EnrollRemoteServerRequest {
pub(super) os: &'static str,
pub(super) arch: &'static str,
pub(super) app_server_version: &'static str,
#[serde(skip_serializing_if = "Option::is_none")]
pub(super) server_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(super) environment_id: Option<String>,
}
#[derive(Debug, Deserialize)]
pub(super) struct EnrollRemoteServerResponse {
pub(super) server_id: String,
pub(super) environment_id: String,
pub(super) remote_control_token: String,
pub(super) expires_at: String,
pub(super) scopes: Vec<String>,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]

View File

@@ -114,6 +114,20 @@ fn remote_control_url_for_listener(listener: &TcpListener) -> String {
format!("http://{addr}/backend-api/")
}
fn remote_control_enroll_response(
server_id: impl Into<String>,
environment_id: impl Into<String>,
remote_control_token: impl Into<String>,
) -> serde_json::Value {
json!({
"server_id": server_id.into(),
"environment_id": environment_id.into(),
"remote_control_token": remote_control_token.into(),
"expires_at": (chrono::Utc::now() + chrono::Duration::minutes(10)).to_rfc3339(),
"scopes": ["remote_control_server_websocket"],
})
}
#[tokio::test]
async fn remote_control_transport_manages_virtual_clients_and_routes_messages() {
let listener = TcpListener::bind("127.0.0.1:0")
@@ -142,7 +156,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages()
);
respond_with_json(
enroll_request.stream,
json!({ "server_id": "srv_e_test", "environment_id": "env_test" }),
remote_control_enroll_response("srv_e_test", "env_test", "remote-control-token"),
)
.await;
let mut websocket = accept_remote_control_connection(&listener).await;
@@ -408,7 +422,7 @@ async fn remote_control_transport_reconnects_after_disconnect() {
);
respond_with_json(
enroll_request.stream,
json!({ "server_id": "srv_e_test", "environment_id": "env_test" }),
remote_control_enroll_response("srv_e_test", "env_test", "remote-control-token"),
)
.await;
let mut first_websocket = accept_remote_control_connection(&listener).await;
@@ -547,7 +561,7 @@ async fn remote_control_handle_set_enabled_stops_and_restarts_connections() {
);
respond_with_json(
enroll_request.stream,
json!({ "server_id": "srv_e_test", "environment_id": "env_test" }),
remote_control_enroll_response("srv_e_test", "env_test", "remote-control-token"),
)
.await;
let mut first_websocket = accept_remote_control_connection(&listener).await;
@@ -596,7 +610,7 @@ async fn remote_control_transport_clears_outgoing_buffer_when_backend_acks() {
let enroll_request = accept_http_request(&listener).await;
respond_with_json(
enroll_request.stream,
json!({ "server_id": "srv_e_test", "environment_id": "env_test" }),
remote_control_enroll_response("srv_e_test", "env_test", "remote-control-token"),
)
.await;
let mut first_websocket = accept_remote_control_connection(&listener).await;
@@ -786,7 +800,7 @@ async fn remote_control_http_mode_enrolls_before_connecting() {
);
respond_with_json(
enroll_request.stream,
json!({ "server_id": "srv_e_test", "environment_id": "env_test" }),
remote_control_enroll_response("srv_e_test", "env_test", "remote-control-token"),
)
.await;
@@ -798,7 +812,7 @@ async fn remote_control_http_mode_enrolls_before_connecting() {
);
assert_eq!(
handshake_request.headers.get("authorization"),
Some(&"Bearer Access Token".to_string())
Some(&"Bearer remote-control-token".to_string())
);
assert_eq!(
handshake_request
@@ -938,7 +952,113 @@ async fn remote_control_http_mode_enrolls_before_connecting() {
}
#[tokio::test]
async fn remote_control_http_mode_reuses_persisted_enrollment_before_reenrolling() {
async fn remote_control_renews_server_token_with_existing_enrollment_ids() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let remote_control_url = remote_control_url_for_listener(&listener);
let codex_home = TempDir::new().expect("temp dir should create");
let (transport_event_tx, _transport_event_rx) =
mpsc::channel::<TransportEvent>(CHANNEL_CAPACITY);
let expected_server_name = gethostname().to_string_lossy().trim().to_string();
let shutdown_token = CancellationToken::new();
let (remote_task, _remote_handle) = start_remote_control(
remote_control_url,
Some(remote_control_state_runtime(&codex_home).await),
remote_control_auth_manager(),
transport_event_tx,
shutdown_token.clone(),
/*app_server_client_name_rx*/ None,
/*initial_enabled*/ true,
)
.await
.expect("remote control should start");
let enroll_request = accept_http_request(&listener).await;
assert_eq!(
serde_json::from_str::<serde_json::Value>(&enroll_request.body)
.expect("initial enroll body should deserialize"),
json!({
"name": expected_server_name,
"os": std::env::consts::OS,
"arch": std::env::consts::ARCH,
"app_server_version": env!("CARGO_PKG_VERSION"),
})
);
respond_with_json(
enroll_request.stream,
json!({
"server_id": "srv_e_test",
"environment_id": "env_test",
"remote_control_token": "remote-control-token-1",
"expires_at": (chrono::Utc::now() + chrono::Duration::seconds(30)).to_rfc3339(),
"scopes": ["remote_control_server_websocket"],
}),
)
.await;
let (first_handshake_request, mut first_websocket) =
accept_remote_control_backend_connection(&listener).await;
assert_eq!(
first_handshake_request.headers.get("authorization"),
Some(&"Bearer remote-control-token-1".to_string())
);
assert_eq!(
first_handshake_request.headers.get("x-codex-server-id"),
Some(&"srv_e_test".to_string())
);
first_websocket
.close(None)
.await
.expect("first websocket should close");
drop(first_websocket);
let renew_request = accept_http_request(&listener).await;
assert_eq!(
renew_request.request_line,
"POST /backend-api/wham/remote/control/server/enroll HTTP/1.1"
);
assert_eq!(
serde_json::from_str::<serde_json::Value>(&renew_request.body)
.expect("renew enroll body should deserialize"),
json!({
"name": expected_server_name,
"os": std::env::consts::OS,
"arch": std::env::consts::ARCH,
"app_server_version": env!("CARGO_PKG_VERSION"),
"server_id": "srv_e_test",
"environment_id": "env_test",
})
);
respond_with_json(
renew_request.stream,
json!({
"server_id": "srv_e_test",
"environment_id": "env_test",
"remote_control_token": "remote-control-token-2",
"expires_at": (chrono::Utc::now() + chrono::Duration::minutes(10)).to_rfc3339(),
"scopes": ["remote_control_server_websocket"],
}),
)
.await;
let (second_handshake_request, _second_websocket) =
accept_remote_control_backend_connection(&listener).await;
assert_eq!(
second_handshake_request.headers.get("authorization"),
Some(&"Bearer remote-control-token-2".to_string())
);
assert_eq!(
second_handshake_request.headers.get("x-codex-server-id"),
Some(&"srv_e_test".to_string())
);
shutdown_token.cancel();
let _ = remote_task.await;
}
#[tokio::test]
async fn remote_control_http_mode_renews_token_for_persisted_enrollment_before_connecting() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
@@ -947,6 +1067,7 @@ async fn remote_control_http_mode_reuses_persisted_enrollment_before_reenrolling
let state_db = remote_control_state_runtime(&codex_home).await;
let remote_control_target =
normalize_remote_control_url(&remote_control_url).expect("target should parse");
let expected_server_name = gethostname().to_string_lossy().trim().to_string();
let persisted_enrollment = RemoteControlEnrollment {
account_id: "account_id".to_string(),
environment_id: "env_persisted".to_string(),
@@ -978,6 +1099,33 @@ async fn remote_control_http_mode_reuses_persisted_enrollment_before_reenrolling
.await
.expect("remote control should start");
let renew_request = accept_http_request(&listener).await;
assert_eq!(
renew_request.request_line,
"POST /backend-api/wham/remote/control/server/enroll HTTP/1.1"
);
assert_eq!(
serde_json::from_str::<serde_json::Value>(&renew_request.body)
.expect("renew enroll body should deserialize"),
json!({
"name": expected_server_name,
"os": std::env::consts::OS,
"arch": std::env::consts::ARCH,
"app_server_version": env!("CARGO_PKG_VERSION"),
"server_id": persisted_enrollment.server_id.clone(),
"environment_id": persisted_enrollment.environment_id.clone(),
})
);
respond_with_json(
renew_request.stream,
remote_control_enroll_response(
persisted_enrollment.server_id.clone(),
persisted_enrollment.environment_id.clone(),
"remote-control-token",
),
)
.await;
let (handshake_request, _websocket) = accept_remote_control_backend_connection(&listener).await;
assert_eq!(
handshake_request.path,
@@ -987,6 +1135,10 @@ async fn remote_control_http_mode_reuses_persisted_enrollment_before_reenrolling
handshake_request.headers.get("x-codex-server-id"),
Some(&persisted_enrollment.server_id)
);
assert_eq!(
handshake_request.headers.get("authorization"),
Some(&"Bearer remote-control-token".to_string())
);
assert_eq!(
load_persisted_remote_control_enrollment(
Some(state_db.as_ref()),
@@ -995,7 +1147,10 @@ async fn remote_control_http_mode_reuses_persisted_enrollment_before_reenrolling
/*app_server_client_name*/ None,
)
.await,
Some(persisted_enrollment)
Some(RemoteControlEnrollment {
server_name: expected_server_name,
..persisted_enrollment
})
);
shutdown_token.cancel();
@@ -1050,11 +1205,30 @@ async fn remote_control_stdio_mode_waits_for_client_name_before_connecting() {
.expect_err("remote control should wait for the stdio client name");
let _ = app_server_client_name_tx.send(app_server_client_name.to_string());
let renew_request = accept_http_request(&listener).await;
assert_eq!(
renew_request.request_line,
"POST /backend-api/wham/remote/control/server/enroll HTTP/1.1"
);
respond_with_json(
renew_request.stream,
remote_control_enroll_response(
persisted_enrollment.server_id.clone(),
persisted_enrollment.environment_id.clone(),
"remote-control-token",
),
)
.await;
let (handshake_request, _websocket) = accept_remote_control_backend_connection(&listener).await;
assert_eq!(
handshake_request.headers.get("x-codex-server-id"),
Some(&persisted_enrollment.server_id)
);
assert_eq!(
handshake_request.headers.get("authorization"),
Some(&"Bearer remote-control-token".to_string())
);
shutdown_token.cancel();
let _ = remote_task.await;
@@ -1120,10 +1294,11 @@ async fn remote_control_waits_for_account_id_before_enrolling() {
);
respond_with_json(
enroll_request.stream,
json!({
"server_id": expected_enrollment.server_id,
"environment_id": expected_enrollment.environment_id,
}),
remote_control_enroll_response(
expected_enrollment.server_id.clone(),
expected_enrollment.environment_id.clone(),
"remote-control-token",
),
)
.await;
@@ -1158,7 +1333,7 @@ async fn remote_control_http_mode_clears_stale_persisted_enrollment_after_404()
account_id: "account_id".to_string(),
environment_id: "env_refreshed".to_string(),
server_id: "srv_e_refreshed".to_string(),
server_name: expected_server_name,
server_name: expected_server_name.clone(),
};
update_persisted_remote_control_enrollment(
Some(state_db.as_ref()),
@@ -1185,6 +1360,33 @@ async fn remote_control_http_mode_clears_stale_persisted_enrollment_after_404()
.await
.expect("remote control should start");
let stale_token_request = accept_http_request(&listener).await;
assert_eq!(
stale_token_request.request_line,
"POST /backend-api/wham/remote/control/server/enroll HTTP/1.1"
);
assert_eq!(
serde_json::from_str::<serde_json::Value>(&stale_token_request.body)
.expect("stale token enroll body should deserialize"),
json!({
"name": expected_server_name,
"os": std::env::consts::OS,
"arch": std::env::consts::ARCH,
"app_server_version": env!("CARGO_PKG_VERSION"),
"server_id": stale_enrollment.server_id.clone(),
"environment_id": stale_enrollment.environment_id.clone(),
})
);
respond_with_json(
stale_token_request.stream,
remote_control_enroll_response(
stale_enrollment.server_id.clone(),
stale_enrollment.environment_id.clone(),
"stale-remote-control-token",
),
)
.await;
let websocket_request = accept_http_request(&listener).await;
assert_eq!(
websocket_request.request_line,
@@ -1194,6 +1396,10 @@ async fn remote_control_http_mode_clears_stale_persisted_enrollment_after_404()
websocket_request.headers.get("x-codex-server-id"),
Some(&stale_enrollment.server_id)
);
assert_eq!(
websocket_request.headers.get("authorization"),
Some(&"Bearer stale-remote-control-token".to_string())
);
respond_with_status(websocket_request.stream, "404 Not Found", "").await;
let enroll_request = accept_http_request(&listener).await;
@@ -1201,12 +1407,23 @@ async fn remote_control_http_mode_clears_stale_persisted_enrollment_after_404()
enroll_request.request_line,
"POST /backend-api/wham/remote/control/server/enroll HTTP/1.1"
);
assert_eq!(
serde_json::from_str::<serde_json::Value>(&enroll_request.body)
.expect("refreshed enroll body should deserialize"),
json!({
"name": refreshed_enrollment.server_name.clone(),
"os": std::env::consts::OS,
"arch": std::env::consts::ARCH,
"app_server_version": env!("CARGO_PKG_VERSION"),
})
);
respond_with_json(
enroll_request.stream,
json!({
"server_id": refreshed_enrollment.server_id,
"environment_id": refreshed_enrollment.environment_id,
}),
remote_control_enroll_response(
refreshed_enrollment.server_id.clone(),
refreshed_enrollment.environment_id.clone(),
"refreshed-remote-control-token",
),
)
.await;
@@ -1215,6 +1432,10 @@ async fn remote_control_http_mode_clears_stale_persisted_enrollment_after_404()
handshake_request.headers.get("x-codex-server-id"),
Some(&refreshed_enrollment.server_id)
);
assert_eq!(
handshake_request.headers.get("authorization"),
Some(&"Bearer refreshed-remote-control-token".to_string())
);
assert_eq!(
load_persisted_remote_control_enrollment(
Some(state_db.as_ref()),

View File

@@ -3,6 +3,7 @@ use crate::transport::remote_control::client_tracker::ClientTracker;
use crate::transport::remote_control::client_tracker::REMOTE_CONTROL_IDLE_SWEEP_INTERVAL;
use crate::transport::remote_control::enroll::RemoteControlConnectionAuth;
use crate::transport::remote_control::enroll::RemoteControlEnrollment;
use crate::transport::remote_control::enroll::RemoteControlServerToken;
use crate::transport::remote_control::enroll::enroll_remote_control_server;
use crate::transport::remote_control::enroll::format_headers;
use crate::transport::remote_control::enroll::load_persisted_remote_control_enrollment;
@@ -17,6 +18,7 @@ use super::protocol::ServerEnvelope;
use super::protocol::StreamId;
use axum::http::HeaderValue;
use base64::Engine;
use chrono::Utc;
use codex_core::util::backoff;
use codex_login::AuthManager;
use codex_login::UnauthorizedRecovery;
@@ -113,6 +115,13 @@ struct WebsocketState {
next_seq_id_by_stream: HashMap<(ClientId, StreamId), u64>,
}
#[derive(Default)]
struct RemoteControlEnrollmentState {
enrollment: Option<RemoteControlEnrollment>,
server_token: Option<RemoteControlServerToken>,
server_token_refresh_required: bool,
}
pub(crate) struct RemoteControlWebsocket {
remote_control_url: String,
remote_control_target: Option<RemoteControlTarget>,
@@ -120,7 +129,7 @@ pub(crate) struct RemoteControlWebsocket {
auth_manager: Arc<AuthManager>,
shutdown_token: CancellationToken,
reconnect_attempt: u64,
enrollment: Option<RemoteControlEnrollment>,
enrollment_state: RemoteControlEnrollmentState,
auth_recovery: UnauthorizedRecovery,
client_tracker: Arc<Mutex<ClientTracker>>,
state: Arc<Mutex<WebsocketState>>,
@@ -191,7 +200,7 @@ impl RemoteControlWebsocket {
auth_manager,
shutdown_token,
reconnect_attempt: 0,
enrollment: None,
enrollment_state: RemoteControlEnrollmentState::default(),
auth_recovery,
client_tracker: Arc::new(Mutex::new(client_tracker)),
state: Arc::new(Mutex::new(WebsocketState {
@@ -304,16 +313,14 @@ impl RemoteControlWebsocket {
}
return ConnectOutcome::Disabled;
}
connect_result = connect_remote_control_websocket_with_options(
ConnectRemoteControlWebsocketOptions {
remote_control_target: &remote_control_target,
state_db: self.state_db.as_deref(),
auth_manager: &self.auth_manager,
auth_recovery: &mut self.auth_recovery,
enrollment: &mut self.enrollment,
subscribe_cursor: subscribe_cursor.as_deref(),
app_server_client_name,
},
connect_result = connect_remote_control_websocket(
&remote_control_target,
self.state_db.as_deref(),
&self.auth_manager,
&mut self.auth_recovery,
&mut self.enrollment_state,
subscribe_cursor.as_deref(),
app_server_client_name,
) => connect_result,
};
@@ -683,6 +690,7 @@ fn build_remote_control_websocket_request(
websocket_url: &str,
enrollment: &RemoteControlEnrollment,
auth: &RemoteControlConnectionAuth,
server_token: &RemoteControlServerToken,
subscribe_cursor: Option<&str>,
) -> io::Result<tungstenite::http::Request<()>> {
let mut request = websocket_url.into_client_request().map_err(|err| {
@@ -703,7 +711,8 @@ fn build_remote_control_websocket_request(
"x-codex-protocol-version",
REMOTE_CONTROL_PROTOCOL_VERSION,
)?;
set_remote_control_header(headers, "authorization", &auth.authorization_header_value)?;
let authorization_header_value = format!("Bearer {}", server_token.bearer_token);
set_remote_control_header(headers, "authorization", &authorization_header_value)?;
set_remote_control_header(headers, REMOTE_CONTROL_ACCOUNT_ID_HEADER, &auth.account_id)?;
if auth.is_fedramp_account {
set_remote_control_header(headers, REMOTE_CONTROL_FEDRAMP_HEADER, "true")?;
@@ -774,13 +783,12 @@ pub(crate) async fn load_remote_control_auth(
})
}
#[cfg(test)]
pub(super) async fn connect_remote_control_websocket(
async fn connect_remote_control_websocket(
remote_control_target: &RemoteControlTarget,
state_db: Option<&StateRuntime>,
auth_manager: &Arc<AuthManager>,
auth_recovery: &mut UnauthorizedRecovery,
enrollment: &mut Option<RemoteControlEnrollment>,
enrollment_state: &mut RemoteControlEnrollmentState,
subscribe_cursor: Option<&str>,
app_server_client_name: Option<&str>,
) -> io::Result<(
@@ -792,7 +800,7 @@ pub(super) async fn connect_remote_control_websocket(
state_db,
auth_manager,
auth_recovery,
enrollment,
enrollment_state,
subscribe_cursor,
app_server_client_name,
})
@@ -804,7 +812,7 @@ struct ConnectRemoteControlWebsocketOptions<'a> {
state_db: Option<&'a StateRuntime>,
auth_manager: &'a Arc<AuthManager>,
auth_recovery: &'a mut UnauthorizedRecovery,
enrollment: &'a mut Option<RemoteControlEnrollment>,
enrollment_state: &'a mut RemoteControlEnrollmentState,
subscribe_cursor: Option<&'a str>,
app_server_client_name: Option<&'a str>,
}
@@ -820,7 +828,7 @@ async fn connect_remote_control_websocket_with_options(
state_db,
auth_manager,
auth_recovery,
enrollment,
enrollment_state,
subscribe_cursor,
app_server_client_name,
} = options;
@@ -828,21 +836,27 @@ async fn connect_remote_control_websocket_with_options(
ensure_rustls_crypto_provider();
let auth = load_remote_control_auth(auth_manager).await?;
let enrollment_account_id = enrollment.as_ref().map(|enrollment| &enrollment.account_id);
let enrollment_account_id = enrollment_state
.enrollment
.as_ref()
.map(|enrollment| &enrollment.account_id);
if enrollment_account_id.is_some_and(|account_id| account_id != &auth.account_id) {
info!(
"clearing in-memory remote control enrollment because account id changed: websocket_url={}, previous_account_id={:?}, current_account_id={:?}",
remote_control_target.websocket_url,
enrollment
enrollment_state
.enrollment
.as_ref()
.map(|enrollment| enrollment.account_id.as_str()),
auth.account_id
);
*enrollment = None;
enrollment_state.enrollment = None;
enrollment_state.server_token = None;
enrollment_state.server_token_refresh_required = false;
}
if enrollment.is_none() {
*enrollment = load_persisted_remote_control_enrollment(
if enrollment_state.enrollment.is_none() {
enrollment_state.enrollment = load_persisted_remote_control_enrollment(
state_db,
remote_control_target,
&auth.account_id,
@@ -851,24 +865,56 @@ async fn connect_remote_control_websocket_with_options(
.await;
}
if enrollment.is_none() {
let should_refresh_server_token = enrollment_state.enrollment.is_some()
&& (enrollment_state.server_token_refresh_required
|| enrollment_state
.server_token
.as_ref()
.is_none_or(|server_token| server_token.expires_soon(Utc::now())));
if should_refresh_server_token {
info!(
"creating new remote control enrollment: websocket_url={}, enroll_url={}, account_id={}",
remote_control_target.websocket_url, remote_control_target.enroll_url, auth.account_id
"remote control server token needs refresh; renewing enrollment token: websocket_url={}, account_id={}",
remote_control_target.websocket_url, auth.account_id
);
let new_enrollment = match enroll_remote_control_server(remote_control_target, &auth).await
{
Ok(new_enrollment) => new_enrollment,
Err(err)
if err.kind() == ErrorKind::PermissionDenied
&& recover_remote_control_auth(auth_recovery).await =>
enrollment_state.server_token = None;
enrollment_state.server_token_refresh_required = true;
}
if enrollment_state.enrollment.is_none() || should_refresh_server_token {
let existing_enrollment = enrollment_state.enrollment.as_ref();
if let Some(existing_enrollment) = existing_enrollment {
info!(
"renewing remote control server token for existing enrollment: websocket_url={}, enroll_url={}, account_id={}, server_id={}, environment_id={}",
remote_control_target.websocket_url,
remote_control_target.enroll_url,
auth.account_id,
existing_enrollment.server_id,
existing_enrollment.environment_id
);
} else {
info!(
"creating new remote control enrollment: websocket_url={}, enroll_url={}, account_id={}",
remote_control_target.websocket_url,
remote_control_target.enroll_url,
auth.account_id
);
}
let enrollment_result =
match enroll_remote_control_server(remote_control_target, &auth, existing_enrollment)
.await
{
return Err(io::Error::other(format!(
"{err}; retrying after auth recovery"
)));
}
Err(err) => return Err(err),
};
Ok(enrollment_result) => enrollment_result,
Err(err)
if err.kind() == ErrorKind::PermissionDenied
&& recover_remote_control_auth(auth_recovery).await =>
{
return Err(io::Error::other(format!(
"{err}; retrying after auth recovery"
)));
}
Err(err) => return Err(err),
};
let new_enrollment = enrollment_result.enrollment;
if let Err(err) = update_persisted_remote_control_enrollment(
state_db,
remote_control_target,
@@ -887,16 +933,22 @@ async fn connect_remote_control_websocket_with_options(
new_enrollment.server_id,
new_enrollment.environment_id
);
*enrollment = Some(new_enrollment);
enrollment_state.enrollment = Some(new_enrollment);
enrollment_state.server_token = Some(enrollment_result.server_token);
enrollment_state.server_token_refresh_required = false;
}
let enrollment_ref = enrollment.as_ref().ok_or_else(|| {
let enrollment_ref = enrollment_state.enrollment.as_ref().ok_or_else(|| {
io::Error::other("missing remote control enrollment after enrollment step")
})?;
let server_token = enrollment_state.server_token.as_ref().ok_or_else(|| {
io::Error::other("missing remote control server token after enrollment step")
})?;
let request = build_remote_control_websocket_request(
&remote_control_target.websocket_url,
enrollment_ref,
&auth,
server_token,
subscribe_cursor,
)?;
@@ -925,17 +977,23 @@ async fn connect_remote_control_websocket_with_options(
"failed to clear stale remote control enrollment in sqlite state db: {clear_err}"
);
}
*enrollment = None;
enrollment_state.enrollment = None;
enrollment_state.server_token = None;
enrollment_state.server_token_refresh_required = false;
}
tungstenite::Error::Http(response)
if matches!(response.status().as_u16(), 401 | 403) =>
{
if recover_remote_control_auth(auth_recovery).await {
return Err(io::Error::other(format!(
"remote control websocket auth failed with HTTP {}; retrying after auth recovery",
response.status()
)));
}
info!(
"remote control websocket token auth failed with HTTP {}; renewing token before reconnecting",
response.status()
);
enrollment_state.server_token = None;
enrollment_state.server_token_refresh_required = true;
return Err(io::Error::other(format!(
"remote control websocket token auth failed with HTTP {}; re-enrolling",
response.status()
)));
}
_ => {}
}
@@ -1087,8 +1145,56 @@ mod tests {
}
}
fn remote_control_server_token(bearer_token: &str) -> RemoteControlServerToken {
RemoteControlServerToken {
bearer_token: bearer_token.to_string(),
expires_at: Utc::now() + chrono::Duration::minutes(10),
}
}
#[test]
fn build_remote_control_websocket_request_uses_server_token() {
let enrollment = RemoteControlEnrollment {
account_id: "account_id".to_string(),
environment_id: "env_test".to_string(),
server_id: "srv_e_test".to_string(),
server_name: "test-server".to_string(),
};
let auth = RemoteControlConnectionAuth {
authorization_header_value: "AgentAssertion assertion".to_string(),
account_id: "account_id".to_string(),
is_fedramp_account: false,
};
let server_token = remote_control_server_token("remote-control-token");
let request = build_remote_control_websocket_request(
"ws://localhost:8080/backend-api/wham/remote/control/server",
&enrollment,
&auth,
&server_token,
/*subscribe_cursor*/ None,
)
.expect("websocket request should build");
assert_eq!(
request
.headers()
.get("authorization")
.expect("authorization header should exist"),
"Bearer remote-control-token"
);
assert_eq!(
request
.headers()
.get(REMOTE_CONTROL_ACCOUNT_ID_HEADER)
.expect("account id header should exist"),
"account_id"
);
}
#[test]
fn build_remote_control_websocket_request_includes_fedramp_header() {
let server_token = remote_control_server_token("remote-control-token");
let request = build_remote_control_websocket_request(
"ws://127.0.0.1/backend-api/wham/remote/control/server",
&RemoteControlEnrollment {
@@ -1102,6 +1208,7 @@ mod tests {
account_id: "account_id".to_string(),
is_fedramp_account: true,
},
&server_token,
/*subscribe_cursor*/ None,
)
.expect("request should build");
@@ -1115,6 +1222,60 @@ mod tests {
);
}
#[tokio::test]
async fn connect_remote_control_websocket_keeps_enrollment_after_server_token_auth_failure() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let remote_control_url = remote_control_url_for_listener(&listener);
let remote_control_target =
normalize_remote_control_url(&remote_control_url).expect("target should parse");
let server_task = tokio::spawn(async move {
let (stream, request_line) = accept_http_request(&listener).await;
assert_eq!(
request_line,
"GET /backend-api/wham/remote/control/server HTTP/1.1"
);
respond_with_status_and_headers(stream, "401 Unauthorized", &[], "unauthorized").await;
});
let codex_home = TempDir::new().expect("temp dir should create");
let state_db = remote_control_state_runtime(&codex_home).await;
let auth_manager = remote_control_auth_manager();
let mut auth_recovery = auth_manager.unauthorized_recovery();
let enrollment = RemoteControlEnrollment {
account_id: "account_id".to_string(),
environment_id: "env_test".to_string(),
server_id: "srv_e_test".to_string(),
server_name: "test-server".to_string(),
};
let mut enrollment_state = RemoteControlEnrollmentState {
enrollment: Some(enrollment.clone()),
server_token: Some(remote_control_server_token("remote-control-token")),
..Default::default()
};
let err = connect_remote_control_websocket(
&remote_control_target,
Some(state_db.as_ref()),
&auth_manager,
&mut auth_recovery,
&mut enrollment_state,
/*subscribe_cursor*/ None,
/*app_server_client_name*/ None,
)
.await
.expect_err("server token auth failure should fail this connect attempt");
server_task.await.expect("server task should succeed");
assert_eq!(
err.to_string(),
"remote control websocket token auth failed with HTTP 401 Unauthorized; re-enrolling"
);
assert_eq!(enrollment_state.enrollment, Some(enrollment));
assert_eq!(enrollment_state.server_token, None);
assert!(enrollment_state.server_token_refresh_required);
}
#[tokio::test]
async fn connect_remote_control_websocket_includes_http_error_details() {
let listener = TcpListener::bind("127.0.0.1:0")
@@ -1145,19 +1306,23 @@ mod tests {
let state_db = remote_control_state_runtime(&codex_home).await;
let auth_manager = remote_control_auth_manager();
let mut auth_recovery = auth_manager.unauthorized_recovery();
let mut enrollment = Some(RemoteControlEnrollment {
account_id: "account_id".to_string(),
environment_id: "env_test".to_string(),
server_id: "srv_e_test".to_string(),
server_name: "test-server".to_string(),
});
let mut enrollment_state = RemoteControlEnrollmentState {
enrollment: Some(RemoteControlEnrollment {
account_id: "account_id".to_string(),
environment_id: "env_test".to_string(),
server_id: "srv_e_test".to_string(),
server_name: "test-server".to_string(),
}),
server_token: Some(remote_control_server_token("remote-control-token")),
..Default::default()
};
let err = match connect_remote_control_websocket(
&remote_control_target,
Some(state_db.as_ref()),
&auth_manager,
&mut auth_recovery,
&mut enrollment,
&mut enrollment_state,
/*subscribe_cursor*/ None,
/*app_server_client_name*/ None,
)
@@ -1171,78 +1336,6 @@ mod tests {
assert_eq!(err.to_string(), expected_error);
}
#[tokio::test]
async fn connect_remote_control_websocket_recovers_after_unauthorized_reload() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let remote_control_url = remote_control_url_for_listener(&listener);
let remote_control_target =
normalize_remote_control_url(&remote_control_url).expect("target should parse");
let codex_home = TempDir::new().expect("temp dir should create");
save_auth(
codex_home.path(),
&remote_control_auth_dot_json("stale-token"),
AuthCredentialsStoreMode::File,
)
.expect("stale auth should save");
let state_db = remote_control_state_runtime(&codex_home).await;
let auth_manager = AuthManager::shared(
codex_home.path().to_path_buf(),
/*enable_codex_api_key_env*/ false,
AuthCredentialsStoreMode::File,
);
let mut auth_recovery = auth_manager.unauthorized_recovery();
let mut enrollment = Some(RemoteControlEnrollment {
account_id: "account_id".to_string(),
environment_id: "env_test".to_string(),
server_id: "srv_e_test".to_string(),
server_name: "test-server".to_string(),
});
save_auth(
codex_home.path(),
&remote_control_auth_dot_json("fresh-token"),
AuthCredentialsStoreMode::File,
)
.expect("fresh auth should save");
let server_task = tokio::spawn(async move {
let (stream, request_line) = accept_http_request(&listener).await;
assert_eq!(
request_line,
"GET /backend-api/wham/remote/control/server HTTP/1.1"
);
respond_with_status_and_headers(stream, "401 Unauthorized", &[], "unauthorized").await;
});
let err = connect_remote_control_websocket(
&remote_control_target,
Some(state_db.as_ref()),
&auth_manager,
&mut auth_recovery,
&mut enrollment,
/*subscribe_cursor*/ None,
/*app_server_client_name*/ None,
)
.await
.expect_err("unauthorized response should fail the websocket connect");
server_task.await.expect("server task should succeed");
assert_eq!(
err.to_string(),
"remote control websocket auth failed with HTTP 401 Unauthorized; retrying after auth recovery"
);
assert_eq!(
auth_manager
.auth()
.await
.expect("auth should remain available")
.get_token()
.expect("token should be readable"),
"fresh-token"
);
}
#[tokio::test]
async fn connect_remote_control_websocket_recovers_after_unauthorized_enrollment() {
let listener = TcpListener::bind("127.0.0.1:0")
@@ -1274,7 +1367,7 @@ mod tests {
AuthCredentialsStoreMode::File,
);
let mut auth_recovery = auth_manager.unauthorized_recovery();
let mut enrollment = None;
let mut enrollment_state = RemoteControlEnrollmentState::default();
save_auth(
codex_home.path(),
&remote_control_auth_dot_json("fresh-token"),
@@ -1287,7 +1380,7 @@ mod tests {
Some(state_db.as_ref()),
&auth_manager,
&mut auth_recovery,
&mut enrollment,
&mut enrollment_state,
/*subscribe_cursor*/ None,
/*app_server_client_name*/ None,
)