use anyhow::Context; use anyhow::Result; use anyhow::bail; use app_test_support::create_mock_responses_server_sequence_unchecked; use codex_app_server_protocol::ClientInfo; use codex_app_server_protocol::InitializeParams; use codex_app_server_protocol::JSONRPCError; use codex_app_server_protocol::JSONRPCMessage; use codex_app_server_protocol::JSONRPCNotification; use codex_app_server_protocol::JSONRPCRequest; use codex_app_server_protocol::JSONRPCResponse; use codex_app_server_protocol::RequestId; use futures::SinkExt; use futures::StreamExt; use reqwest::StatusCode; use serde_json::json; use std::net::SocketAddr; use std::path::Path; use std::process::Stdio; use tempfile::TempDir; use tokio::io::AsyncBufReadExt; use tokio::io::BufReader; use tokio::process::Child; use tokio::process::Command; use tokio::time::Duration; use tokio::time::Instant; use tokio::time::sleep; use tokio::time::timeout; use tokio_tungstenite::MaybeTlsStream; use tokio_tungstenite::WebSocketStream; use tokio_tungstenite::connect_async; use tokio_tungstenite::tungstenite::Error as WebSocketError; use tokio_tungstenite::tungstenite::Message as WebSocketMessage; use tokio_tungstenite::tungstenite::client::IntoClientRequest; use tokio_tungstenite::tungstenite::http::HeaderValue; use tokio_tungstenite::tungstenite::http::header::ORIGIN; pub(super) const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(5); pub(super) type WsClient = WebSocketStream>; #[tokio::test] async fn websocket_transport_routes_per_connection_handshake_and_responses() -> Result<()> { let server = create_mock_responses_server_sequence_unchecked(Vec::new()).await; let codex_home = TempDir::new()?; create_config_toml(codex_home.path(), &server.uri(), "never")?; let (mut process, bind_addr) = spawn_websocket_server(codex_home.path()).await?; let mut ws1 = connect_websocket(bind_addr).await?; let mut ws2 = connect_websocket(bind_addr).await?; send_initialize_request(&mut ws1, 1, "ws_client_one").await?; let first_init = read_response_for_id(&mut ws1, 1).await?; assert_eq!(first_init.id, RequestId::Integer(1)); // Initialize responses are request-scoped and must not leak to other // connections. assert_no_message(&mut ws2, Duration::from_millis(250)).await?; send_config_read_request(&mut ws2, 2).await?; let not_initialized = read_error_for_id(&mut ws2, 2).await?; assert_eq!(not_initialized.error.message, "Not initialized"); send_initialize_request(&mut ws2, 3, "ws_client_two").await?; let second_init = read_response_for_id(&mut ws2, 3).await?; assert_eq!(second_init.id, RequestId::Integer(3)); // Same request-id on different connections must route independently. send_config_read_request(&mut ws1, 77).await?; send_config_read_request(&mut ws2, 77).await?; let ws1_config = read_response_for_id(&mut ws1, 77).await?; let ws2_config = read_response_for_id(&mut ws2, 77).await?; assert_eq!(ws1_config.id, RequestId::Integer(77)); assert_eq!(ws2_config.id, RequestId::Integer(77)); assert!(ws1_config.result.get("config").is_some()); assert!(ws2_config.result.get("config").is_some()); process .kill() .await .context("failed to stop websocket app-server process")?; Ok(()) } #[tokio::test] async fn websocket_transport_serves_health_endpoints_on_same_listener() -> Result<()> { let server = create_mock_responses_server_sequence_unchecked(Vec::new()).await; let codex_home = TempDir::new()?; create_config_toml(codex_home.path(), &server.uri(), "never")?; let (mut process, bind_addr) = spawn_websocket_server(codex_home.path()).await?; let client = reqwest::Client::new(); let readyz = http_get(&client, bind_addr, "/readyz").await?; assert_eq!(readyz.status(), StatusCode::OK); let healthz = http_get(&client, bind_addr, "/healthz").await?; assert_eq!(healthz.status(), StatusCode::OK); let mut ws = connect_websocket(bind_addr).await?; send_initialize_request(&mut ws, 1, "ws_health_client").await?; let init = read_response_for_id(&mut ws, 1).await?; assert_eq!(init.id, RequestId::Integer(1)); process .kill() .await .context("failed to stop websocket app-server process")?; Ok(()) } #[tokio::test] async fn websocket_transport_rejects_requests_with_origin_header() -> Result<()> { let server = create_mock_responses_server_sequence_unchecked(Vec::new()).await; let codex_home = TempDir::new()?; create_config_toml(codex_home.path(), &server.uri(), "never")?; let (mut process, bind_addr) = spawn_websocket_server(codex_home.path()).await?; let client = reqwest::Client::new(); let deadline = Instant::now() + Duration::from_secs(10); let healthz = loop { match client .get(format!("http://{bind_addr}/healthz")) .header(ORIGIN.as_str(), "https://example.com") .send() .await .with_context(|| format!("failed to GET http://{bind_addr}/healthz with Origin header")) { Ok(response) => break response, Err(err) => { if Instant::now() >= deadline { bail!("failed to GET http://{bind_addr}/healthz with Origin header: {err}"); } sleep(Duration::from_millis(50)).await; } } }; assert_eq!(healthz.status(), StatusCode::FORBIDDEN); let url = format!("ws://{bind_addr}"); let mut request = url.into_client_request()?; request .headers_mut() .insert(ORIGIN, HeaderValue::from_static("https://example.com")); match connect_async(request).await { Err(WebSocketError::Http(response)) => { assert_eq!(response.status(), StatusCode::FORBIDDEN); } Ok(_) => bail!("expected websocket handshake with Origin header to be rejected"), Err(err) => bail!("expected HTTP rejection for Origin header, got {err}"), } process .kill() .await .context("failed to stop websocket app-server process")?; Ok(()) } pub(super) async fn spawn_websocket_server(codex_home: &Path) -> Result<(Child, SocketAddr)> { let program = codex_utils_cargo_bin::cargo_bin("codex-app-server") .context("should find app-server binary")?; let mut cmd = Command::new(program); cmd.arg("--listen") .arg("ws://127.0.0.1:0") .stdin(Stdio::null()) .stdout(Stdio::null()) .stderr(Stdio::piped()) .env("CODEX_HOME", codex_home) .env("RUST_LOG", "debug"); let mut process = cmd .kill_on_drop(true) .spawn() .context("failed to spawn websocket app-server process")?; let stderr = process .stderr .take() .context("failed to capture websocket app-server stderr")?; let mut stderr_reader = BufReader::new(stderr).lines(); let deadline = Instant::now() + Duration::from_secs(10); let bind_addr = loop { let line = timeout( deadline.saturating_duration_since(Instant::now()), stderr_reader.next_line(), ) .await .context("timed out waiting for websocket app-server to report bound websocket address")? .context("failed to read websocket app-server stderr")? .context("websocket app-server exited before reporting bound websocket address")?; eprintln!("[websocket app-server stderr] {line}"); let stripped_line = { let mut stripped = String::with_capacity(line.len()); let mut chars = line.chars().peekable(); while let Some(ch) = chars.next() { if ch == '\u{1b}' && matches!(chars.peek(), Some(&'[')) { chars.next(); for next in chars.by_ref() { if ('@'..='~').contains(&next) { break; } } continue; } stripped.push(ch); } stripped }; if let Some(bind_addr) = stripped_line .split_whitespace() .find_map(|token| token.strip_prefix("ws://")) .and_then(|addr| addr.parse::().ok()) { break bind_addr; } }; tokio::spawn(async move { while let Ok(Some(line)) = stderr_reader.next_line().await { eprintln!("[websocket app-server stderr] {line}"); } }); Ok((process, bind_addr)) } pub(super) async fn connect_websocket(bind_addr: SocketAddr) -> Result { let url = format!("ws://{bind_addr}"); let deadline = Instant::now() + Duration::from_secs(10); loop { match connect_async(&url).await { Ok((stream, _response)) => return Ok(stream), Err(err) => { if Instant::now() >= deadline { bail!("failed to connect websocket to {url}: {err}"); } sleep(Duration::from_millis(50)).await; } } } } async fn http_get( client: &reqwest::Client, bind_addr: SocketAddr, path: &str, ) -> Result { let deadline = Instant::now() + Duration::from_secs(10); loop { match client .get(format!("http://{bind_addr}{path}")) .send() .await .with_context(|| format!("failed to GET http://{bind_addr}{path}")) { Ok(response) => return Ok(response), Err(err) => { if Instant::now() >= deadline { bail!("failed to GET http://{bind_addr}{path}: {err}"); } sleep(Duration::from_millis(50)).await; } } } } pub(super) async fn send_initialize_request( stream: &mut WsClient, id: i64, client_name: &str, ) -> Result<()> { let params = InitializeParams { client_info: ClientInfo { name: client_name.to_string(), title: Some("WebSocket Test Client".to_string()), version: "0.1.0".to_string(), }, capabilities: None, }; send_request( stream, "initialize", id, Some(serde_json::to_value(params)?), ) .await } async fn send_config_read_request(stream: &mut WsClient, id: i64) -> Result<()> { send_request( stream, "config/read", id, Some(json!({ "includeLayers": false })), ) .await } pub(super) async fn send_request( stream: &mut WsClient, method: &str, id: i64, params: Option, ) -> Result<()> { let message = JSONRPCMessage::Request(JSONRPCRequest { id: RequestId::Integer(id), method: method.to_string(), params, trace: None, }); send_jsonrpc(stream, message).await } async fn send_jsonrpc(stream: &mut WsClient, message: JSONRPCMessage) -> Result<()> { let payload = serde_json::to_string(&message)?; stream .send(WebSocketMessage::Text(payload.into())) .await .context("failed to send websocket frame") } pub(super) async fn read_response_for_id( stream: &mut WsClient, id: i64, ) -> Result { let target_id = RequestId::Integer(id); loop { let message = read_jsonrpc_message(stream).await?; if let JSONRPCMessage::Response(response) = message && response.id == target_id { return Ok(response); } } } pub(super) async fn read_notification_for_method( stream: &mut WsClient, method: &str, ) -> Result { loop { let message = read_jsonrpc_message(stream).await?; if let JSONRPCMessage::Notification(notification) = message && notification.method == method { return Ok(notification); } } } pub(super) async fn read_response_and_notification_for_method( stream: &mut WsClient, id: i64, method: &str, ) -> Result<(JSONRPCResponse, JSONRPCNotification)> { let target_id = RequestId::Integer(id); let mut response = None; let mut notification = None; while response.is_none() || notification.is_none() { let message = read_jsonrpc_message(stream).await?; match message { JSONRPCMessage::Response(candidate) if candidate.id == target_id => { response = Some(candidate); } JSONRPCMessage::Notification(candidate) if candidate.method == method => { if notification.replace(candidate).is_some() { bail!( "received duplicate notification for method `{method}` before completing paired read" ); } } _ => {} } } let Some(response) = response else { bail!("response must be set before returning"); }; let Some(notification) = notification else { bail!("notification must be set before returning"); }; Ok((response, notification)) } async fn read_error_for_id(stream: &mut WsClient, id: i64) -> Result { let target_id = RequestId::Integer(id); loop { let message = read_jsonrpc_message(stream).await?; if let JSONRPCMessage::Error(err) = message && err.id == target_id { return Ok(err); } } } pub(super) async fn read_jsonrpc_message(stream: &mut WsClient) -> Result { loop { let frame = timeout(DEFAULT_READ_TIMEOUT, stream.next()) .await .context("timed out waiting for websocket frame")? .context("websocket stream ended unexpectedly")? .context("failed to read websocket frame")?; match frame { WebSocketMessage::Text(text) => return Ok(serde_json::from_str(text.as_ref())?), WebSocketMessage::Ping(payload) => { stream.send(WebSocketMessage::Pong(payload)).await?; } WebSocketMessage::Pong(_) => {} WebSocketMessage::Close(frame) => { bail!("websocket closed unexpectedly: {frame:?}") } WebSocketMessage::Binary(_) => bail!("unexpected binary websocket frame"), WebSocketMessage::Frame(_) => {} } } } pub(super) async fn assert_no_message(stream: &mut WsClient, wait_for: Duration) -> Result<()> { match timeout(wait_for, stream.next()).await { Ok(Some(Ok(frame))) => bail!("unexpected frame while waiting for silence: {frame:?}"), Ok(Some(Err(err))) => bail!("unexpected websocket read error: {err}"), Ok(None) => bail!("websocket closed unexpectedly while waiting for silence"), Err(_) => Ok(()), } } pub(super) fn create_config_toml( codex_home: &Path, server_uri: &str, approval_policy: &str, ) -> std::io::Result<()> { let config_toml = codex_home.join("config.toml"); std::fs::write( config_toml, format!( r#" model = "mock-model" approval_policy = "{approval_policy}" sandbox_mode = "read-only" model_provider = "mock_provider" [model_providers.mock_provider] name = "Mock provider for test" base_url = "{server_uri}/v1" wire_api = "responses" request_max_retries = 0 stream_max_retries = 0 "# ), ) }