mirror of
https://github.com/openai/codex.git
synced 2026-05-16 17:23:57 +00:00
[exec-server] restore HTTP upgrade websocket listener
This commit is contained in:
@@ -5,8 +5,12 @@ use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::extract::ws::Message as AxumWebSocketMessage;
|
||||
use axum::extract::ws::WebSocket as AxumWebSocket;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use futures::Sink;
|
||||
use futures::SinkExt;
|
||||
use futures::Stream;
|
||||
use futures::StreamExt;
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::io::AsyncWrite;
|
||||
@@ -314,11 +318,30 @@ impl JsonRpcConnection {
|
||||
pub(crate) fn from_websocket<S>(stream: WebSocketStream<S>, connection_label: String) -> Self
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let (websocket_writer, websocket_reader) = stream.split();
|
||||
Self::from_websocket_parts(websocket_writer, websocket_reader, connection_label)
|
||||
}
|
||||
|
||||
pub(crate) fn from_axum_websocket(stream: AxumWebSocket, connection_label: String) -> Self {
|
||||
let (websocket_writer, websocket_reader) = stream.split();
|
||||
Self::from_websocket_parts(websocket_writer, websocket_reader, connection_label)
|
||||
}
|
||||
|
||||
fn from_websocket_parts<W, R, M, E>(
|
||||
mut websocket_writer: W,
|
||||
mut websocket_reader: R,
|
||||
connection_label: String,
|
||||
) -> Self
|
||||
where
|
||||
W: Sink<M, Error = E> + Unpin + Send + 'static,
|
||||
R: Stream<Item = Result<M, E>> + Unpin + Send + 'static,
|
||||
M: JsonRpcWebSocketMessage,
|
||||
E: std::fmt::Display + Send + 'static,
|
||||
{
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (disconnected_tx, disconnected_rx) = watch::channel(false);
|
||||
let (mut websocket_writer, mut websocket_reader) = stream.split();
|
||||
|
||||
let reader_label = connection_label.clone();
|
||||
let incoming_tx_for_reader = incoming_tx.clone();
|
||||
@@ -326,41 +349,36 @@ impl JsonRpcConnection {
|
||||
let reader_task = tokio::spawn(async move {
|
||||
loop {
|
||||
match websocket_reader.next().await {
|
||||
Some(Ok(Message::Text(text))) => {
|
||||
match serde_json::from_str::<JSONRPCMessage>(text.as_ref()) {
|
||||
Ok(message) => {
|
||||
if incoming_tx_for_reader
|
||||
.send(JsonRpcConnectionEvent::Message(message))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
send_malformed_message(
|
||||
&incoming_tx_for_reader,
|
||||
Some(format!(
|
||||
"failed to parse websocket JSON-RPC message from {reader_label}: {err}"
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
Some(Ok(message)) => match message.parse_jsonrpc_frame() {
|
||||
Ok(JsonRpcWebSocketFrame::Message(message)) => {
|
||||
if incoming_tx_for_reader
|
||||
.send(JsonRpcConnectionEvent::Message(message))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) => {
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
/*reason*/ None,
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Binary(_)))
|
||||
| Some(Ok(Message::Ping(_)))
|
||||
| Some(Ok(Message::Pong(_)))
|
||||
| Some(Ok(Message::Frame(_))) => {}
|
||||
Err(err) => {
|
||||
send_malformed_message(
|
||||
&incoming_tx_for_reader,
|
||||
Some(format!(
|
||||
"failed to parse websocket JSON-RPC message from {reader_label}: {err}"
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Ok(JsonRpcWebSocketFrame::Close) => {
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
/*reason*/ None,
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
Ok(JsonRpcWebSocketFrame::Ignore) => {}
|
||||
},
|
||||
Some(Err(err)) => {
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
@@ -389,8 +407,7 @@ impl JsonRpcConnection {
|
||||
while let Some(message) = outgoing_rx.recv().await {
|
||||
match serialize_jsonrpc_message(&message) {
|
||||
Ok(encoded) => {
|
||||
if let Err(err) = websocket_writer.send(Message::Text(encoded.into())).await
|
||||
{
|
||||
if let Err(err) = websocket_writer.send(M::from_text(encoded)).await {
|
||||
send_disconnected(
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
@@ -432,6 +449,53 @@ impl JsonRpcConnection {
|
||||
}
|
||||
}
|
||||
|
||||
enum JsonRpcWebSocketFrame {
|
||||
Message(JSONRPCMessage),
|
||||
Close,
|
||||
Ignore,
|
||||
}
|
||||
|
||||
trait JsonRpcWebSocketMessage: Send + 'static {
|
||||
fn parse_jsonrpc_frame(self) -> Result<JsonRpcWebSocketFrame, serde_json::Error>;
|
||||
fn from_text(text: String) -> Self;
|
||||
}
|
||||
|
||||
impl JsonRpcWebSocketMessage for Message {
|
||||
fn parse_jsonrpc_frame(self) -> Result<JsonRpcWebSocketFrame, serde_json::Error> {
|
||||
match self {
|
||||
Message::Text(text) => {
|
||||
serde_json::from_str(text.as_ref()).map(JsonRpcWebSocketFrame::Message)
|
||||
}
|
||||
Message::Close(_) => Ok(JsonRpcWebSocketFrame::Close),
|
||||
Message::Binary(_) | Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => {
|
||||
Ok(JsonRpcWebSocketFrame::Ignore)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn from_text(text: String) -> Self {
|
||||
Self::Text(text.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl JsonRpcWebSocketMessage for AxumWebSocketMessage {
|
||||
fn parse_jsonrpc_frame(self) -> Result<JsonRpcWebSocketFrame, serde_json::Error> {
|
||||
match self {
|
||||
AxumWebSocketMessage::Text(text) => {
|
||||
serde_json::from_str(text.as_ref()).map(JsonRpcWebSocketFrame::Message)
|
||||
}
|
||||
AxumWebSocketMessage::Close(_) => Ok(JsonRpcWebSocketFrame::Close),
|
||||
AxumWebSocketMessage::Binary(_)
|
||||
| AxumWebSocketMessage::Ping(_)
|
||||
| AxumWebSocketMessage::Pong(_) => Ok(JsonRpcWebSocketFrame::Ignore),
|
||||
}
|
||||
}
|
||||
|
||||
fn from_text(text: String) -> Self {
|
||||
Self::Text(text.into())
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_disconnected(
|
||||
incoming_tx: &mpsc::Sender<JsonRpcConnectionEvent>,
|
||||
disconnected_tx: &watch::Sender<bool>,
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
use axum::Router;
|
||||
use axum::extract::ConnectInfo;
|
||||
use axum::extract::State;
|
||||
use axum::extract::ws::WebSocketUpgrade;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::any;
|
||||
use axum::routing::get;
|
||||
use std::io::Write as _;
|
||||
use std::net::SocketAddr;
|
||||
use tokio::io;
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_tungstenite::accept_async;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::ExecServerRuntimePaths;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
@@ -114,27 +120,42 @@ async fn run_websocket_listener(
|
||||
println!("ws://{local_addr}");
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
loop {
|
||||
let (stream, peer_addr) = listener.accept().await?;
|
||||
let processor = processor.clone();
|
||||
tokio::spawn(async move {
|
||||
match accept_async(stream).await {
|
||||
Ok(websocket) => {
|
||||
processor
|
||||
.run_connection(JsonRpcConnection::from_websocket(
|
||||
websocket,
|
||||
format!("exec-server websocket {peer_addr}"),
|
||||
))
|
||||
.await;
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"failed to accept exec-server websocket connection from {peer_addr}: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
let router = Router::new()
|
||||
.route("/", any(websocket_upgrade_handler))
|
||||
.route("/readyz", get(readiness_handler))
|
||||
.with_state(ExecServerWebSocketState { processor });
|
||||
axum::serve(
|
||||
listener,
|
||||
router.into_make_service_with_connect_info::<SocketAddr>(),
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ExecServerWebSocketState {
|
||||
processor: ConnectionProcessor,
|
||||
}
|
||||
|
||||
async fn readiness_handler() -> StatusCode {
|
||||
StatusCode::OK
|
||||
}
|
||||
|
||||
async fn websocket_upgrade_handler(
|
||||
websocket: WebSocketUpgrade,
|
||||
ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
|
||||
State(state): State<ExecServerWebSocketState>,
|
||||
) -> impl IntoResponse {
|
||||
info!(%peer_addr, "exec-server websocket client connected");
|
||||
websocket.on_upgrade(move |stream| async move {
|
||||
state
|
||||
.processor
|
||||
.run_connection(JsonRpcConnection::from_axum_websocket(
|
||||
stream,
|
||||
format!("exec-server websocket {peer_addr}"),
|
||||
))
|
||||
.await;
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
21
codex-rs/exec-server/tests/health.rs
Normal file
21
codex-rs/exec-server/tests/health.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
#![cfg(unix)]
|
||||
|
||||
mod common;
|
||||
|
||||
use common::exec_server::exec_server;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn exec_server_serves_readyz_alongside_websocket_endpoint() -> anyhow::Result<()> {
|
||||
let mut server = exec_server().await?;
|
||||
let http_base_url = server
|
||||
.websocket_url()
|
||||
.strip_prefix("ws://")
|
||||
.expect("websocket URL should use ws://");
|
||||
|
||||
let response = reqwest::get(format!("http://{http_base_url}/readyz")).await?;
|
||||
assert_eq!(response.status(), reqwest::StatusCode::OK);
|
||||
|
||||
server.shutdown().await?;
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user