Compare commits

...

1 Commits

Author SHA1 Message Date
Richard Lee
1276cfc9c4 add exec-server health check 2026-05-05 23:02:16 -07:00
3 changed files with 168 additions and 14 deletions

View File

@@ -2,8 +2,11 @@ use std::io::Write as _;
use std::net::SocketAddr;
use tokio::io;
use tokio::io::AsyncRead;
use tokio::io::AsyncReadExt as _;
use tokio::io::AsyncWrite;
use tokio::io::AsyncWriteExt as _;
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tokio_tungstenite::accept_async;
use tracing::warn;
@@ -12,6 +15,15 @@ use crate::connection::JsonRpcConnection;
use crate::server::processor::ConnectionProcessor;
pub const DEFAULT_LISTEN_URL: &str = "ws://127.0.0.1:0";
const HTTP_REQUEST_PEEK_BYTES: usize = 64;
const HEALTH_REQUEST_LINE_PREFIX: &[u8] = b"GET /health HTTP/";
const HEALTH_REQUEST_MAX_BYTES: usize = 8 * 1024;
const HEALTH_RESPONSE: &[u8] = b"HTTP/1.1 200 OK\r\n\
content-type: text/plain; charset=utf-8\r\n\
content-length: 3\r\n\
connection: close\r\n\
\r\n\
ok\n";
#[derive(Debug, Clone, Eq, PartialEq)]
pub(crate) enum ExecServerListenTransport {
@@ -19,6 +31,12 @@ pub(crate) enum ExecServerListenTransport {
Stdio,
}
#[derive(Debug, Clone, Eq, PartialEq)]
enum ConnectionPreflightRoute {
Health,
WebSocket,
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ExecServerListenUrlParseError {
UnsupportedListenUrl(String),
@@ -117,25 +135,69 @@ async fn run_websocket_listener(
let (stream, peer_addr) = listener.accept().await?;
let processor = processor.clone();
tokio::spawn(async move {
match accept_async(stream).await {
Ok(websocket) => {
processor
.run_connection(JsonRpcConnection::from_websocket(
websocket,
format!("exec-server websocket {peer_addr}"),
))
.await;
}
Err(err) => {
warn!(
"failed to accept exec-server websocket connection from {peer_addr}: {err}"
);
}
if let Err(err) = serve_connection(stream, peer_addr, processor).await {
warn!("failed to serve exec-server connection from {peer_addr}: {err}");
}
});
}
}
async fn serve_connection(
mut stream: TcpStream,
peer_addr: SocketAddr,
processor: ConnectionProcessor,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let mut request_prefix = [0; HTTP_REQUEST_PEEK_BYTES];
let bytes_read = stream.peek(&mut request_prefix).await?;
match connection_preflight_route(&request_prefix[..bytes_read]) {
ConnectionPreflightRoute::Health => {
read_health_check_request(&mut stream).await?;
stream.write_all(HEALTH_RESPONSE).await?;
}
ConnectionPreflightRoute::WebSocket => match accept_async(stream).await {
Ok(websocket) => {
processor
.run_connection(JsonRpcConnection::from_websocket(
websocket,
format!("exec-server websocket {peer_addr}"),
))
.await;
}
Err(err) => {
warn!("failed to accept exec-server websocket connection from {peer_addr}: {err}");
}
},
};
Ok(())
}
fn connection_preflight_route(request_prefix: &[u8]) -> ConnectionPreflightRoute {
if request_prefix.starts_with(HEALTH_REQUEST_LINE_PREFIX) {
return ConnectionPreflightRoute::Health;
}
ConnectionPreflightRoute::WebSocket
}
async fn read_health_check_request(stream: &mut TcpStream) -> io::Result<()> {
let mut request = Vec::new();
let mut buffer = [0; 512];
loop {
let bytes_read = stream.read(&mut buffer).await?;
if bytes_read == 0 {
return Ok(());
}
request.extend_from_slice(&buffer[..bytes_read]);
if request.windows(4).any(|window| window == b"\r\n\r\n")
|| request.len() >= HEALTH_REQUEST_MAX_BYTES
{
return Ok(());
}
}
}
#[cfg(test)]
#[path = "transport_tests.rs"]
mod transport_tests;

View File

@@ -13,8 +13,11 @@ use tokio::io::BufReader;
use tokio::io::duplex;
use tokio::time::timeout;
use super::ConnectionPreflightRoute;
use super::DEFAULT_LISTEN_URL;
use super::ExecServerListenTransport;
use super::HEALTH_RESPONSE;
use super::connection_preflight_route;
use super::parse_listen_url;
use super::run_stdio_connection_with_io;
use crate::ExecServerRuntimePaths;
@@ -125,6 +128,39 @@ fn parse_listen_url_accepts_websocket_url() {
);
}
#[test]
fn connection_preflight_route_detects_health_path() {
assert_eq!(
connection_preflight_route(b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n"),
ConnectionPreflightRoute::Health
);
}
#[test]
fn connection_preflight_route_falls_back_to_websocket() {
assert_eq!(
connection_preflight_route(b"GET /healthz HTTP/1.1\r\nHost: localhost\r\n\r\n"),
ConnectionPreflightRoute::WebSocket
);
assert_eq!(
connection_preflight_route(b"POST /health HTTP/1.1\r\nHost: localhost\r\n\r\n"),
ConnectionPreflightRoute::WebSocket
);
}
#[test]
fn health_check_response_is_plain_ok() {
assert_eq!(
HEALTH_RESPONSE,
b"HTTP/1.1 200 OK\r\n\
content-type: text/plain; charset=utf-8\r\n\
content-length: 3\r\n\
connection: close\r\n\
\r\n\
ok\n"
);
}
#[test]
fn parse_listen_url_rejects_invalid_websocket_url() {
let err = parse_listen_url("ws://localhost:1234")

View File

@@ -9,8 +9,64 @@ use codex_exec_server::InitializeParams;
use codex_exec_server::InitializeResponse;
use common::exec_server::exec_server;
use pretty_assertions::assert_eq;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use uuid::Uuid;
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn exec_server_serves_health_check_and_keeps_websocket_running() -> anyhow::Result<()> {
let mut server = exec_server().await?;
let socket_addr = server
.websocket_url()
.strip_prefix("ws://")
.ok_or_else(|| anyhow::anyhow!("websocket URL should use ws://"))?;
let mut stream = TcpStream::connect(socket_addr).await?;
stream
.write_all(b"GET /health HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
.await?;
let mut response = String::new();
stream.read_to_string(&mut response).await?;
assert_eq!(
response,
"HTTP/1.1 200 OK\r\n\
content-type: text/plain; charset=utf-8\r\n\
content-length: 3\r\n\
connection: close\r\n\
\r\n\
ok\n"
);
let initialize_id = server
.send_request(
"initialize",
serde_json::to_value(InitializeParams {
client_name: "exec-server-test".to_string(),
resume_session_id: None,
})?,
)
.await?;
let response = server
.wait_for_event(|event| {
matches!(
event,
JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &initialize_id
)
})
.await?;
let JSONRPCMessage::Response(JSONRPCResponse { id, result }) = response else {
panic!("expected initialize response after health check");
};
assert_eq!(id, initialize_id);
let initialize_response: InitializeResponse = serde_json::from_value(result)?;
Uuid::parse_str(&initialize_response.session_id)?;
server.shutdown().await?;
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn exec_server_reports_malformed_websocket_json_and_keeps_running() -> anyhow::Result<()> {
let mut server = exec_server().await?;