Compare commits

...

1 Commits

Author SHA1 Message Date
starr-openai
7e5d6a8b37 Add exec-server websocket connection token 2026-05-28 19:01:50 -07:00
7 changed files with 160 additions and 11 deletions

View File

@@ -24,6 +24,12 @@ The CLI entrypoint supports:
- `ws://IP:PORT` (default)
- `--remote URL --environment-id ID [--name NAME]`
For direct websocket listeners, setting
`CODEX_EXEC_SERVER_CONNECTION_TOKEN` requires clients to send the same value
as a bearer `Authorization` header. The printed listen URL and `/readyz`
remain secret-free. Direct websocket `ExecServerClient` connections use the
same env var.
Remote mode registers the local exec-server with the environment registry,
then reconnects to the service-provided rendezvous websocket as the environment.
It uses the standard Codex ChatGPT sign-in state; run `codex login` first when

View File

@@ -4,6 +4,8 @@ use tokio::io::BufReader;
use tokio::process::Command;
use tokio::time::timeout;
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::http::header::AUTHORIZATION;
use tracing::debug;
use tracing::warn;
@@ -15,6 +17,7 @@ use crate::client_api::RemoteExecServerConnectArgs;
use crate::client_api::StdioExecServerCommand;
use crate::client_api::StdioExecServerConnectArgs;
use crate::connection::JsonRpcConnection;
use crate::connection_token::connection_token_from_env;
use crate::relay::harness_connection_from_websocket;
const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment";
@@ -59,7 +62,23 @@ impl ExecServerClient {
ensure_rustls_crypto_provider();
let websocket_url = args.websocket_url.clone();
let connect_timeout = args.connect_timeout;
let (stream, _) = timeout(connect_timeout, connect_async(websocket_url.as_str()))
let mut request = websocket_url
.as_str()
.into_client_request()
.map_err(|err| {
ExecServerError::Protocol(format!(
"invalid exec-server websocket URL `{websocket_url}`: {err}"
))
})?;
if !is_rendezvous_harness_url(&websocket_url)
&& let Some(connection_token) =
connection_token_from_env().map_err(ExecServerError::Protocol)?
{
request
.headers_mut()
.insert(AUTHORIZATION, connection_token);
}
let (stream, _) = timeout(connect_timeout, connect_async(request))
.await
.map_err(|_| ExecServerError::WebSocketConnectTimeout {
url: websocket_url.clone(),

View File

@@ -0,0 +1,20 @@
use axum::http::HeaderValue;
const CONNECTION_TOKEN_ENV_VAR: &str = "CODEX_EXEC_SERVER_CONNECTION_TOKEN";
pub(crate) fn connection_token_from_env() -> Result<Option<HeaderValue>, String> {
let token = match std::env::var(CONNECTION_TOKEN_ENV_VAR) {
Ok(token) => token,
Err(std::env::VarError::NotPresent) => return Ok(None),
Err(std::env::VarError::NotUnicode(_)) => {
return Err(format!("{CONNECTION_TOKEN_ENV_VAR} must be valid Unicode"));
}
};
if token.is_empty() {
return Err(format!("{CONNECTION_TOKEN_ENV_VAR} must not be empty"));
}
let mut header = HeaderValue::from_str(&format!("Bearer {token}"))
.map_err(|_| format!("{CONNECTION_TOKEN_ENV_VAR} must be a valid HTTP header value"))?;
header.set_sensitive(true);
Ok(Some(header))
}

View File

@@ -2,6 +2,7 @@ mod client;
mod client_api;
mod client_transport;
mod connection;
mod connection_token;
mod environment;
mod environment_provider;
mod environment_toml;

View File

@@ -3,8 +3,11 @@ use axum::body::Body;
use axum::extract::ConnectInfo;
use axum::extract::State;
use axum::extract::ws::WebSocketUpgrade;
use axum::http::HeaderMap;
use axum::http::HeaderValue;
use axum::http::Request;
use axum::http::StatusCode;
use axum::http::header::AUTHORIZATION;
use axum::http::header::ORIGIN;
use axum::middleware;
use axum::middleware::Next;
@@ -23,6 +26,7 @@ use tracing::warn;
use crate::ExecServerRuntimePaths;
use crate::connection::JsonRpcConnection;
use crate::connection_token::connection_token_from_env;
use crate::server::processor::ConnectionProcessor;
pub const DEFAULT_LISTEN_URL: &str = "ws://127.0.0.1:0";
@@ -83,7 +87,12 @@ pub(crate) async fn run_transport(
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
match parse_listen_url(listen_url)? {
ExecServerListenTransport::WebSocket(bind_address) => {
run_websocket_listener(bind_address, runtime_paths).await
run_websocket_listener(
bind_address,
connection_token_from_env().map_err(std::io::Error::other)?,
runtime_paths,
)
.await
}
ExecServerListenTransport::Stdio => run_stdio_connection(runtime_paths).await,
}
@@ -118,6 +127,7 @@ where
async fn run_websocket_listener(
bind_address: SocketAddr,
connection_token: Option<HeaderValue>,
runtime_paths: ExecServerRuntimePaths,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let listener = TcpListener::bind(bind_address).await?;
@@ -131,7 +141,10 @@ async fn run_websocket_listener(
.route("/", any(websocket_upgrade_handler))
.route("/readyz", get(readiness_handler))
.layer(middleware::from_fn(reject_requests_with_origin_header))
.with_state(ExecServerWebSocketState { processor });
.with_state(ExecServerWebSocketState {
processor,
connection_token,
});
axum::serve(
listener,
router.into_make_service_with_connect_info::<SocketAddr>(),
@@ -143,6 +156,7 @@ async fn run_websocket_listener(
#[derive(Clone)]
struct ExecServerWebSocketState {
processor: ConnectionProcessor,
connection_token: Option<HeaderValue>,
}
async fn readiness_handler() -> StatusCode {
@@ -168,10 +182,19 @@ async fn reject_requests_with_origin_header(
async fn websocket_upgrade_handler(
websocket: WebSocketUpgrade,
ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
headers: HeaderMap,
State(state): State<ExecServerWebSocketState>,
) -> impl IntoResponse {
) -> Result<impl IntoResponse, StatusCode> {
if state
.connection_token
.as_ref()
.is_some_and(|token| headers.get(AUTHORIZATION) != Some(token))
{
return Err(StatusCode::UNAUTHORIZED);
}
info!(%peer_addr, "exec-server websocket client connected");
websocket.on_upgrade(move |stream| async move {
Ok(websocket.on_upgrade(move |stream| async move {
state
.processor
.run_connection(JsonRpcConnection::from_axum_websocket(
@@ -179,7 +202,7 @@ async fn websocket_upgrade_handler(
format!("exec-server websocket {peer_addr}"),
))
.await;
})
}))
}
#[cfg(test)]

View File

@@ -21,6 +21,8 @@ use tokio::time::sleep;
use tokio::time::timeout;
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::http::HeaderValue;
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
const CONNECT_RETRY_INTERVAL: Duration = Duration::from_millis(25);
@@ -60,7 +62,29 @@ pub(crate) async fn exec_server() -> anyhow::Result<ExecServerHarness> {
exec_server_with_env(std::iter::empty::<(&str, &str)>()).await
}
pub(crate) async fn exec_server_with_connection_token(
connection_token: &str,
) -> anyhow::Result<ExecServerHarness> {
exec_server_with_env_and_connection_token(
[("CODEX_EXEC_SERVER_CONNECTION_TOKEN", connection_token)],
Some(connection_token),
)
.await
}
pub(crate) async fn exec_server_with_env<I, K, V>(env: I) -> anyhow::Result<ExecServerHarness>
where
I: IntoIterator<Item = (K, V)>,
K: AsRef<std::ffi::OsStr>,
V: AsRef<std::ffi::OsStr>,
{
exec_server_with_env_and_connection_token(env, /*connection_token*/ None).await
}
async fn exec_server_with_env_and_connection_token<I, K, V>(
env: I,
connection_token: Option<&str>,
) -> anyhow::Result<ExecServerHarness>
where
I: IntoIterator<Item = (K, V)>,
K: AsRef<std::ffi::OsStr>,
@@ -79,7 +103,13 @@ where
let mut child = child.spawn()?;
let websocket_url = read_listen_url_from_stdout(&mut child).await?;
let (websocket, _) = connect_websocket_when_ready(&websocket_url).await?;
let (websocket, _) = match connection_token {
Some(connection_token) => {
connect_websocket_with_connection_token_when_ready(&websocket_url, connection_token)
.await?
}
None => connect_websocket_when_ready(&websocket_url).await?,
};
Ok(ExecServerHarness {
_codex_home: codex_home,
_helper_paths: helper_paths,
@@ -213,15 +243,18 @@ impl ExecServerHarness {
}
}
async fn connect_websocket_when_ready(
websocket_url: &str,
pub(crate) async fn connect_websocket_when_ready<R>(
request: R,
) -> anyhow::Result<(
tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
tokio_tungstenite::tungstenite::handshake::client::Response,
)> {
)>
where
R: IntoClientRequest + Clone + Unpin,
{
let deadline = Instant::now() + CONNECT_TIMEOUT;
loop {
match connect_async(websocket_url).await {
match connect_async(request.clone()).await {
Ok(websocket) => return Ok(websocket),
Err(err)
if Instant::now() < deadline
@@ -238,6 +271,21 @@ async fn connect_websocket_when_ready(
}
}
async fn connect_websocket_with_connection_token_when_ready(
websocket_url: &str,
connection_token: &str,
) -> anyhow::Result<(
tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
tokio_tungstenite::tungstenite::handshake::client::Response,
)> {
let mut request = websocket_url.into_client_request()?;
request.headers_mut().insert(
tokio_tungstenite::tungstenite::http::header::AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {connection_token}"))?,
);
connect_websocket_when_ready(request).await
}
async fn read_listen_url_from_stdout(child: &mut Child) -> anyhow::Result<String> {
let stdout = child
.stdout

View File

@@ -2,8 +2,11 @@
mod common;
use common::exec_server::connect_websocket_when_ready;
use common::exec_server::exec_server;
use common::exec_server::exec_server_with_connection_token;
use pretty_assertions::assert_eq;
use tokio_tungstenite::tungstenite::Error;
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn exec_server_serves_readyz_alongside_websocket_endpoint() -> anyhow::Result<()> {
@@ -19,3 +22,32 @@ async fn exec_server_serves_readyz_alongside_websocket_endpoint() -> anyhow::Res
server.shutdown().await?;
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn exec_server_connection_token_gates_websocket_only() -> anyhow::Result<()> {
let mut server = exec_server_with_connection_token("secret").await?;
let http_base_url = server
.websocket_url()
.strip_prefix("ws://")
.expect("websocket URL should use ws://");
let response = reqwest::get(format!("http://{http_base_url}/readyz")).await?;
assert_eq!(response.status(), reqwest::StatusCode::OK);
let err = connect_websocket_when_ready(server.websocket_url())
.await
.expect_err("missing connection token should reject websocket upgrade");
assert_unauthorized_websocket_error(err);
server.shutdown().await?;
Ok(())
}
fn assert_unauthorized_websocket_error(err: anyhow::Error) {
let Some(websocket_error) = err.downcast_ref::<Error>() else {
panic!("websocket rejection should be a tungstenite error");
};
let Error::Http(response) = websocket_error else {
panic!("expected websocket HTTP rejection, got {websocket_error:?}");
};
assert_eq!(response.status(), reqwest::StatusCode::UNAUTHORIZED);
}