mirror of
https://github.com/openai/codex.git
synced 2026-05-18 02:02:30 +00:00
## Why Added support for UDS connections in `codex --remote`. TUI also now connects to local app-server using UDS by default if it is running and set to listen to UDS connection. ## What Changed - Introduced `RemoteAppServerEndpoint` with `WebSocket` and `UnixSocket` variants. - Reused the existing JSON-RPC-over-WebSocket protocol over either a TCP WebSocket stream or a UDS stream. - Updated `codex --remote` to accept `ws://host:port`, `wss://host:port`, `unix://`, and `unix://PATH`. - Kept `--remote-auth-token-env` restricted to `wss://` and loopback `ws://` remotes. - Added a fast TUI startup probe for the default daemon socket, falling back to the embedded app server when the daemon is absent or unresponsive. ## Verification - Manually verified that the updated remote flow works. - Added coverage for UDS remote round trips, WebSocket auth headers, auth-token transport policy, remote address parsing, and missing-daemon fallback. - Ran focused remote test coverage locally.
1019 lines
40 KiB
Rust
1019 lines
40 KiB
Rust
/*
|
|
This module implements the remote app-server client transport.
|
|
|
|
It owns the remote connection lifecycle, including the initialize/initialized
|
|
handshake, JSON-RPC request/response routing, server-request resolution, and
|
|
notification streaming. Remote connections always carry WebSocket frames, over
|
|
either TCP WebSocket URLs or local Unix sockets. The rest of the crate uses the
|
|
same `AppServerEvent` surface for both in-process and remote transports, so
|
|
callers such as the TUI can switch between them without changing their
|
|
higher-level session logic.
|
|
*/
|
|
|
|
use std::collections::HashMap;
|
|
use std::collections::VecDeque;
|
|
use std::io::Error as IoError;
|
|
use std::io::ErrorKind;
|
|
use std::io::Result as IoResult;
|
|
use std::time::Duration;
|
|
|
|
use crate::AppServerEvent;
|
|
use crate::RequestResult;
|
|
use crate::SHUTDOWN_TIMEOUT;
|
|
use crate::TypedRequestError;
|
|
use crate::request_method_name;
|
|
use codex_app_server_protocol::ClientInfo;
|
|
use codex_app_server_protocol::ClientNotification;
|
|
use codex_app_server_protocol::ClientRequest;
|
|
use codex_app_server_protocol::InitializeCapabilities;
|
|
use codex_app_server_protocol::InitializeParams;
|
|
use codex_app_server_protocol::JSONRPCError;
|
|
use codex_app_server_protocol::JSONRPCErrorError;
|
|
use codex_app_server_protocol::JSONRPCMessage;
|
|
use codex_app_server_protocol::JSONRPCNotification;
|
|
use codex_app_server_protocol::JSONRPCRequest;
|
|
use codex_app_server_protocol::JSONRPCResponse;
|
|
use codex_app_server_protocol::RequestId;
|
|
use codex_app_server_protocol::Result as JsonRpcResult;
|
|
use codex_app_server_protocol::ServerNotification;
|
|
use codex_app_server_protocol::ServerRequest;
|
|
use codex_uds::UnixStream;
|
|
use codex_utils_absolute_path::AbsolutePathBuf;
|
|
use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
|
|
use futures::SinkExt;
|
|
use futures::StreamExt;
|
|
use serde::de::DeserializeOwned;
|
|
use tokio::io::AsyncRead;
|
|
use tokio::io::AsyncWrite;
|
|
use tokio::net::TcpStream;
|
|
use tokio::sync::mpsc;
|
|
use tokio::sync::oneshot;
|
|
use tokio::time::timeout;
|
|
use tokio_tungstenite::MaybeTlsStream;
|
|
use tokio_tungstenite::WebSocketStream;
|
|
use tokio_tungstenite::client_async_with_config;
|
|
use tokio_tungstenite::connect_async_with_config;
|
|
use tokio_tungstenite::tungstenite::Error as TungsteniteError;
|
|
use tokio_tungstenite::tungstenite::Message;
|
|
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
|
use tokio_tungstenite::tungstenite::http::HeaderValue;
|
|
use tokio_tungstenite::tungstenite::http::header::AUTHORIZATION;
|
|
use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
|
|
use tracing::warn;
|
|
use url::Url;
|
|
|
|
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
|
|
const INITIALIZE_TIMEOUT: Duration = Duration::from_secs(10);
|
|
const REMOTE_APP_SERVER_MAX_WEBSOCKET_MESSAGE_SIZE: usize = 128 << 20;
|
|
// Tungstenite still needs an HTTP request URI for the WebSocket handshake;
|
|
// the bytes travel over the Unix socket, not TCP.
|
|
const UDS_WEBSOCKET_HANDSHAKE_URL: &str = "ws://localhost/rpc";
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub enum RemoteAppServerEndpoint {
|
|
WebSocket {
|
|
websocket_url: String,
|
|
auth_token: Option<String>,
|
|
},
|
|
UnixSocket {
|
|
socket_path: AbsolutePathBuf,
|
|
},
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct RemoteAppServerConnectArgs {
|
|
pub endpoint: RemoteAppServerEndpoint,
|
|
pub client_name: String,
|
|
pub client_version: String,
|
|
pub experimental_api: bool,
|
|
pub opt_out_notification_methods: Vec<String>,
|
|
pub channel_capacity: usize,
|
|
}
|
|
impl RemoteAppServerConnectArgs {
|
|
fn initialize_params(&self) -> InitializeParams {
|
|
let capabilities = InitializeCapabilities {
|
|
experimental_api: self.experimental_api,
|
|
request_attestation: false,
|
|
opt_out_notification_methods: if self.opt_out_notification_methods.is_empty() {
|
|
None
|
|
} else {
|
|
Some(self.opt_out_notification_methods.clone())
|
|
},
|
|
};
|
|
|
|
InitializeParams {
|
|
client_info: ClientInfo {
|
|
name: self.client_name.clone(),
|
|
title: None,
|
|
version: self.client_version.clone(),
|
|
},
|
|
capabilities: Some(capabilities),
|
|
}
|
|
}
|
|
}
|
|
|
|
pub(crate) fn websocket_url_supports_auth_token(url: &Url) -> bool {
|
|
match (url.scheme(), url.host()) {
|
|
("wss", Some(_)) => true,
|
|
("ws", Some(url::Host::Domain(domain))) => domain.eq_ignore_ascii_case("localhost"),
|
|
("ws", Some(url::Host::Ipv4(addr))) => addr.is_loopback(),
|
|
("ws", Some(url::Host::Ipv6(addr))) => addr.is_loopback(),
|
|
_ => false,
|
|
}
|
|
}
|
|
|
|
enum RemoteClientCommand {
|
|
Request {
|
|
request: Box<ClientRequest>,
|
|
response_tx: oneshot::Sender<IoResult<RequestResult>>,
|
|
},
|
|
Notify {
|
|
notification: ClientNotification,
|
|
response_tx: oneshot::Sender<IoResult<()>>,
|
|
},
|
|
ResolveServerRequest {
|
|
request_id: RequestId,
|
|
result: JsonRpcResult,
|
|
response_tx: oneshot::Sender<IoResult<()>>,
|
|
},
|
|
RejectServerRequest {
|
|
request_id: RequestId,
|
|
error: JSONRPCErrorError,
|
|
response_tx: oneshot::Sender<IoResult<()>>,
|
|
},
|
|
Shutdown {
|
|
response_tx: oneshot::Sender<IoResult<()>>,
|
|
},
|
|
}
|
|
|
|
pub struct RemoteAppServerClient {
|
|
command_tx: mpsc::Sender<RemoteClientCommand>,
|
|
event_rx: mpsc::UnboundedReceiver<AppServerEvent>,
|
|
pending_events: VecDeque<AppServerEvent>,
|
|
worker_handle: tokio::task::JoinHandle<()>,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct RemoteAppServerRequestHandle {
|
|
command_tx: mpsc::Sender<RemoteClientCommand>,
|
|
}
|
|
|
|
impl RemoteAppServerClient {
|
|
pub async fn connect(args: RemoteAppServerConnectArgs) -> IoResult<Self> {
|
|
let channel_capacity = args.channel_capacity.max(1);
|
|
let initialize_params = args.initialize_params();
|
|
match args.endpoint {
|
|
RemoteAppServerEndpoint::WebSocket {
|
|
websocket_url,
|
|
auth_token,
|
|
} => {
|
|
let (endpoint, stream) =
|
|
connect_websocket_endpoint(websocket_url, auth_token).await?;
|
|
Self::connect_with_stream(channel_capacity, endpoint, stream, initialize_params)
|
|
.await
|
|
}
|
|
RemoteAppServerEndpoint::UnixSocket { socket_path } => {
|
|
let (endpoint, stream) = connect_unix_socket_endpoint(socket_path).await?;
|
|
Self::connect_with_stream(channel_capacity, endpoint, stream, initialize_params)
|
|
.await
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn connect_with_stream<S>(
|
|
channel_capacity: usize,
|
|
endpoint: String,
|
|
stream: WebSocketStream<S>,
|
|
initialize_params: InitializeParams,
|
|
) -> IoResult<Self>
|
|
where
|
|
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
|
{
|
|
let mut stream = stream;
|
|
let pending_events = initialize_remote_connection(
|
|
&mut stream,
|
|
&endpoint,
|
|
initialize_params,
|
|
INITIALIZE_TIMEOUT,
|
|
)
|
|
.await?;
|
|
|
|
let (command_tx, mut command_rx) = mpsc::channel::<RemoteClientCommand>(channel_capacity);
|
|
let (event_tx, event_rx) = mpsc::unbounded_channel::<AppServerEvent>();
|
|
let worker_handle = tokio::spawn(async move {
|
|
let mut pending_requests =
|
|
HashMap::<RequestId, oneshot::Sender<IoResult<RequestResult>>>::new();
|
|
let mut worker_exit_error: Option<(ErrorKind, String)> = None;
|
|
loop {
|
|
tokio::select! {
|
|
command = command_rx.recv() => {
|
|
let Some(command) = command else {
|
|
let _ = stream.close(None).await;
|
|
break;
|
|
};
|
|
match command {
|
|
RemoteClientCommand::Request { request, response_tx } => {
|
|
let request_id = request_id_from_client_request(&request);
|
|
if pending_requests.contains_key(&request_id) {
|
|
let _ = response_tx.send(Err(IoError::new(
|
|
ErrorKind::InvalidInput,
|
|
format!("duplicate remote app-server request id `{request_id}`"),
|
|
)));
|
|
continue;
|
|
}
|
|
pending_requests.insert(request_id.clone(), response_tx);
|
|
if let Err(err) = write_jsonrpc_message(
|
|
&mut stream,
|
|
JSONRPCMessage::Request(jsonrpc_request_from_client_request(*request)),
|
|
&endpoint,
|
|
)
|
|
.await
|
|
{
|
|
let err_message = err.to_string();
|
|
let message = format!(
|
|
"remote app server at `{endpoint}` write failed: {err_message}"
|
|
);
|
|
if let Some(response_tx) = pending_requests.remove(&request_id) {
|
|
let _ = response_tx.send(Err(err));
|
|
}
|
|
let _ = deliver_event(
|
|
&event_tx,
|
|
AppServerEvent::Disconnected {
|
|
message: message.clone(),
|
|
},
|
|
);
|
|
worker_exit_error = Some((ErrorKind::BrokenPipe, message));
|
|
break;
|
|
}
|
|
}
|
|
RemoteClientCommand::Notify { notification, response_tx } => {
|
|
let result = write_jsonrpc_message(
|
|
&mut stream,
|
|
JSONRPCMessage::Notification(
|
|
jsonrpc_notification_from_client_notification(notification),
|
|
),
|
|
&endpoint,
|
|
)
|
|
.await;
|
|
let _ = response_tx.send(result);
|
|
}
|
|
RemoteClientCommand::ResolveServerRequest {
|
|
request_id,
|
|
result,
|
|
response_tx,
|
|
} => {
|
|
let result = write_jsonrpc_message(
|
|
&mut stream,
|
|
JSONRPCMessage::Response(JSONRPCResponse {
|
|
id: request_id,
|
|
result,
|
|
}),
|
|
&endpoint,
|
|
)
|
|
.await;
|
|
let _ = response_tx.send(result);
|
|
}
|
|
RemoteClientCommand::RejectServerRequest {
|
|
request_id,
|
|
error,
|
|
response_tx,
|
|
} => {
|
|
let result = write_jsonrpc_message(
|
|
&mut stream,
|
|
JSONRPCMessage::Error(JSONRPCError {
|
|
error,
|
|
id: request_id,
|
|
}),
|
|
&endpoint,
|
|
)
|
|
.await;
|
|
let _ = response_tx.send(result);
|
|
}
|
|
RemoteClientCommand::Shutdown { response_tx } => {
|
|
let close_result = stream.close(None).await.or_else(|err| {
|
|
if websocket_close_error_is_already_closed(&err) {
|
|
Ok(())
|
|
} else {
|
|
Err(IoError::other(format!(
|
|
"failed to close websocket app server `{endpoint}`: {err}"
|
|
)))
|
|
}
|
|
});
|
|
let _ = response_tx.send(close_result);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
message = stream.next() => {
|
|
match message {
|
|
Some(Ok(Message::Text(text))) => {
|
|
match serde_json::from_str::<JSONRPCMessage>(&text) {
|
|
Ok(JSONRPCMessage::Response(response)) => {
|
|
if let Some(response_tx) = pending_requests.remove(&response.id) {
|
|
let _ = response_tx.send(Ok(Ok(response.result)));
|
|
}
|
|
}
|
|
Ok(JSONRPCMessage::Error(error)) => {
|
|
if let Some(response_tx) = pending_requests.remove(&error.id) {
|
|
let _ = response_tx.send(Ok(Err(error.error)));
|
|
}
|
|
}
|
|
Ok(JSONRPCMessage::Notification(notification)) => {
|
|
if let Some(event) =
|
|
app_server_event_from_notification(notification)
|
|
&& let Err(err) = deliver_event(
|
|
&event_tx,
|
|
event,
|
|
)
|
|
{
|
|
warn!(%err, "failed to deliver remote app-server event");
|
|
break;
|
|
}
|
|
}
|
|
Ok(JSONRPCMessage::Request(request)) => {
|
|
let request_id = request.id.clone();
|
|
let method = request.method.clone();
|
|
match ServerRequest::try_from(request) {
|
|
Ok(request) => {
|
|
if let Err(err) = deliver_event(
|
|
&event_tx,
|
|
AppServerEvent::ServerRequest(request),
|
|
)
|
|
{
|
|
warn!(%err, "failed to deliver remote app-server server request");
|
|
break;
|
|
}
|
|
}
|
|
Err(err) => {
|
|
warn!(%err, method, "rejecting unknown remote app-server request");
|
|
if let Err(reject_err) = write_jsonrpc_message(
|
|
&mut stream,
|
|
JSONRPCMessage::Error(JSONRPCError {
|
|
error: JSONRPCErrorError {
|
|
code: -32601,
|
|
message: format!(
|
|
"unsupported remote app-server request `{method}`"
|
|
),
|
|
data: None,
|
|
},
|
|
id: request_id,
|
|
}),
|
|
&endpoint,
|
|
)
|
|
.await
|
|
{
|
|
let err_message = reject_err.to_string();
|
|
let message = format!(
|
|
"remote app server at `{endpoint}` write failed: {err_message}"
|
|
);
|
|
let _ = deliver_event(
|
|
&event_tx,
|
|
AppServerEvent::Disconnected {
|
|
message: message.clone(),
|
|
},
|
|
);
|
|
worker_exit_error =
|
|
Some((ErrorKind::BrokenPipe, message));
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Err(err) => {
|
|
let message = format!(
|
|
"remote app server at `{endpoint}` sent invalid JSON-RPC: {err}"
|
|
);
|
|
let _ = deliver_event(
|
|
&event_tx,
|
|
AppServerEvent::Disconnected {
|
|
message: message.clone(),
|
|
},
|
|
);
|
|
worker_exit_error =
|
|
Some((ErrorKind::InvalidData, message));
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
Some(Ok(Message::Close(frame))) => {
|
|
let reason = frame
|
|
.as_ref()
|
|
.map(|frame| frame.reason.to_string())
|
|
.filter(|reason| !reason.is_empty())
|
|
.unwrap_or_else(|| "connection closed".to_string());
|
|
let message = format!(
|
|
"remote app server at `{endpoint}` disconnected: {reason}"
|
|
);
|
|
let _ = deliver_event(
|
|
&event_tx,
|
|
AppServerEvent::Disconnected {
|
|
message: message.clone(),
|
|
},
|
|
);
|
|
worker_exit_error = Some((
|
|
ErrorKind::ConnectionAborted,
|
|
message,
|
|
));
|
|
break;
|
|
}
|
|
Some(Ok(Message::Binary(_)))
|
|
| Some(Ok(Message::Ping(_)))
|
|
| Some(Ok(Message::Pong(_)))
|
|
| Some(Ok(Message::Frame(_))) => {}
|
|
Some(Err(err)) => {
|
|
let message = format!(
|
|
"remote app server at `{endpoint}` transport failed: {err}"
|
|
);
|
|
let _ = deliver_event(
|
|
&event_tx,
|
|
AppServerEvent::Disconnected {
|
|
message: message.clone(),
|
|
},
|
|
);
|
|
worker_exit_error = Some((ErrorKind::InvalidData, message));
|
|
break;
|
|
}
|
|
None => {
|
|
let message = format!(
|
|
"remote app server at `{endpoint}` closed the connection"
|
|
);
|
|
let _ = deliver_event(
|
|
&event_tx,
|
|
AppServerEvent::Disconnected {
|
|
message: message.clone(),
|
|
},
|
|
);
|
|
worker_exit_error = Some((ErrorKind::UnexpectedEof, message));
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let (err_kind, err_message) = worker_exit_error.unwrap_or_else(|| {
|
|
(
|
|
ErrorKind::BrokenPipe,
|
|
"remote app-server worker channel is closed".to_string(),
|
|
)
|
|
});
|
|
for (_, response_tx) in pending_requests {
|
|
let _ = response_tx.send(Err(IoError::new(err_kind, err_message.clone())));
|
|
}
|
|
});
|
|
|
|
Ok(Self {
|
|
command_tx,
|
|
event_rx,
|
|
pending_events: pending_events.into(),
|
|
worker_handle,
|
|
})
|
|
}
|
|
|
|
pub fn request_handle(&self) -> RemoteAppServerRequestHandle {
|
|
RemoteAppServerRequestHandle {
|
|
command_tx: self.command_tx.clone(),
|
|
}
|
|
}
|
|
|
|
pub async fn request(&self, request: ClientRequest) -> IoResult<RequestResult> {
|
|
let (response_tx, response_rx) = oneshot::channel();
|
|
self.command_tx
|
|
.send(RemoteClientCommand::Request {
|
|
request: Box::new(request),
|
|
response_tx,
|
|
})
|
|
.await
|
|
.map_err(|_| {
|
|
IoError::new(
|
|
ErrorKind::BrokenPipe,
|
|
"remote app-server worker channel is closed",
|
|
)
|
|
})?;
|
|
response_rx.await.map_err(|_| {
|
|
IoError::new(
|
|
ErrorKind::BrokenPipe,
|
|
"remote app-server request channel is closed",
|
|
)
|
|
})?
|
|
}
|
|
|
|
pub async fn request_typed<T>(&self, request: ClientRequest) -> Result<T, TypedRequestError>
|
|
where
|
|
T: DeserializeOwned,
|
|
{
|
|
let method = request_method_name(&request);
|
|
let response =
|
|
self.request(request)
|
|
.await
|
|
.map_err(|source| TypedRequestError::Transport {
|
|
method: method.clone(),
|
|
source,
|
|
})?;
|
|
let result = response.map_err(|source| TypedRequestError::Server {
|
|
method: method.clone(),
|
|
source,
|
|
})?;
|
|
serde_json::from_value(result)
|
|
.map_err(|source| TypedRequestError::Deserialize { method, source })
|
|
}
|
|
|
|
pub async fn notify(&self, notification: ClientNotification) -> IoResult<()> {
|
|
let (response_tx, response_rx) = oneshot::channel();
|
|
self.command_tx
|
|
.send(RemoteClientCommand::Notify {
|
|
notification,
|
|
response_tx,
|
|
})
|
|
.await
|
|
.map_err(|_| {
|
|
IoError::new(
|
|
ErrorKind::BrokenPipe,
|
|
"remote app-server worker channel is closed",
|
|
)
|
|
})?;
|
|
response_rx.await.map_err(|_| {
|
|
IoError::new(
|
|
ErrorKind::BrokenPipe,
|
|
"remote app-server notify channel is closed",
|
|
)
|
|
})?
|
|
}
|
|
|
|
pub async fn resolve_server_request(
|
|
&self,
|
|
request_id: RequestId,
|
|
result: JsonRpcResult,
|
|
) -> IoResult<()> {
|
|
let (response_tx, response_rx) = oneshot::channel();
|
|
self.command_tx
|
|
.send(RemoteClientCommand::ResolveServerRequest {
|
|
request_id,
|
|
result,
|
|
response_tx,
|
|
})
|
|
.await
|
|
.map_err(|_| {
|
|
IoError::new(
|
|
ErrorKind::BrokenPipe,
|
|
"remote app-server worker channel is closed",
|
|
)
|
|
})?;
|
|
response_rx.await.map_err(|_| {
|
|
IoError::new(
|
|
ErrorKind::BrokenPipe,
|
|
"remote app-server resolve channel is closed",
|
|
)
|
|
})?
|
|
}
|
|
|
|
pub async fn reject_server_request(
|
|
&self,
|
|
request_id: RequestId,
|
|
error: JSONRPCErrorError,
|
|
) -> IoResult<()> {
|
|
let (response_tx, response_rx) = oneshot::channel();
|
|
self.command_tx
|
|
.send(RemoteClientCommand::RejectServerRequest {
|
|
request_id,
|
|
error,
|
|
response_tx,
|
|
})
|
|
.await
|
|
.map_err(|_| {
|
|
IoError::new(
|
|
ErrorKind::BrokenPipe,
|
|
"remote app-server worker channel is closed",
|
|
)
|
|
})?;
|
|
response_rx.await.map_err(|_| {
|
|
IoError::new(
|
|
ErrorKind::BrokenPipe,
|
|
"remote app-server reject channel is closed",
|
|
)
|
|
})?
|
|
}
|
|
|
|
pub async fn next_event(&mut self) -> Option<AppServerEvent> {
|
|
if let Some(event) = self.pending_events.pop_front() {
|
|
return Some(event);
|
|
}
|
|
self.event_rx.recv().await
|
|
}
|
|
|
|
pub async fn shutdown(self) -> IoResult<()> {
|
|
let Self {
|
|
command_tx,
|
|
event_rx,
|
|
pending_events: _pending_events,
|
|
worker_handle,
|
|
} = self;
|
|
let mut worker_handle = worker_handle;
|
|
drop(event_rx);
|
|
let (response_tx, response_rx) = oneshot::channel();
|
|
if command_tx
|
|
.send(RemoteClientCommand::Shutdown { response_tx })
|
|
.await
|
|
.is_ok()
|
|
&& let Ok(Ok(close_result)) = timeout(SHUTDOWN_TIMEOUT, response_rx).await
|
|
{
|
|
close_result?;
|
|
}
|
|
|
|
if let Err(_elapsed) = timeout(SHUTDOWN_TIMEOUT, &mut worker_handle).await {
|
|
worker_handle.abort();
|
|
let _ = worker_handle.await;
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
impl RemoteAppServerRequestHandle {
|
|
pub async fn request(&self, request: ClientRequest) -> IoResult<RequestResult> {
|
|
let (response_tx, response_rx) = oneshot::channel();
|
|
self.command_tx
|
|
.send(RemoteClientCommand::Request {
|
|
request: Box::new(request),
|
|
response_tx,
|
|
})
|
|
.await
|
|
.map_err(|_| {
|
|
IoError::new(
|
|
ErrorKind::BrokenPipe,
|
|
"remote app-server worker channel is closed",
|
|
)
|
|
})?;
|
|
response_rx.await.map_err(|_| {
|
|
IoError::new(
|
|
ErrorKind::BrokenPipe,
|
|
"remote app-server request channel is closed",
|
|
)
|
|
})?
|
|
}
|
|
|
|
pub async fn request_typed<T>(&self, request: ClientRequest) -> Result<T, TypedRequestError>
|
|
where
|
|
T: DeserializeOwned,
|
|
{
|
|
let method = request_method_name(&request);
|
|
let response =
|
|
self.request(request)
|
|
.await
|
|
.map_err(|source| TypedRequestError::Transport {
|
|
method: method.clone(),
|
|
source,
|
|
})?;
|
|
let result = response.map_err(|source| TypedRequestError::Server {
|
|
method: method.clone(),
|
|
source,
|
|
})?;
|
|
serde_json::from_value(result)
|
|
.map_err(|source| TypedRequestError::Deserialize { method, source })
|
|
}
|
|
}
|
|
|
|
async fn connect_websocket_endpoint(
|
|
websocket_url: String,
|
|
auth_token: Option<String>,
|
|
) -> IoResult<(String, WebSocketStream<MaybeTlsStream<TcpStream>>)> {
|
|
let url = Url::parse(&websocket_url).map_err(|err| {
|
|
IoError::new(
|
|
ErrorKind::InvalidInput,
|
|
format!("invalid websocket URL `{websocket_url}`: {err}"),
|
|
)
|
|
})?;
|
|
if auth_token.is_some() && !websocket_url_supports_auth_token(&url) {
|
|
return Err(IoError::new(
|
|
ErrorKind::InvalidInput,
|
|
format!(
|
|
"remote auth tokens require `wss://` or loopback `ws://` URLs; got `{websocket_url}`"
|
|
),
|
|
));
|
|
}
|
|
|
|
let mut request = url.as_str().into_client_request().map_err(|err| {
|
|
IoError::new(
|
|
ErrorKind::InvalidInput,
|
|
format!("invalid websocket URL `{websocket_url}`: {err}"),
|
|
)
|
|
})?;
|
|
if let Some(auth_token) = auth_token.as_deref() {
|
|
let header_value =
|
|
HeaderValue::from_str(&format!("Bearer {auth_token}")).map_err(|err| {
|
|
IoError::new(
|
|
ErrorKind::InvalidInput,
|
|
format!("invalid remote authorization header value: {err}"),
|
|
)
|
|
})?;
|
|
request.headers_mut().insert(AUTHORIZATION, header_value);
|
|
}
|
|
|
|
ensure_rustls_crypto_provider();
|
|
let websocket_config = remote_websocket_config();
|
|
let stream = timeout(
|
|
CONNECT_TIMEOUT,
|
|
connect_async_with_config(
|
|
request,
|
|
Some(websocket_config),
|
|
/*disable_nagle*/ false,
|
|
),
|
|
)
|
|
.await
|
|
.map_err(|_| {
|
|
IoError::new(
|
|
ErrorKind::TimedOut,
|
|
format!("timed out connecting to remote app server at `{websocket_url}`"),
|
|
)
|
|
})?
|
|
.map(|(stream, _response)| stream)
|
|
.map_err(|err| {
|
|
IoError::other(format!(
|
|
"failed to connect to remote app server at `{websocket_url}`: {err}"
|
|
))
|
|
})?;
|
|
|
|
Ok((websocket_url, stream))
|
|
}
|
|
|
|
async fn connect_unix_socket_endpoint(
|
|
socket_path: AbsolutePathBuf,
|
|
) -> IoResult<(String, WebSocketStream<UnixStream>)> {
|
|
let endpoint = format!("unix://{}", socket_path.display());
|
|
let request = UDS_WEBSOCKET_HANDSHAKE_URL
|
|
.into_client_request()
|
|
.map_err(|err| {
|
|
IoError::new(
|
|
ErrorKind::InvalidInput,
|
|
format!("invalid UDS websocket handshake URL: {err}"),
|
|
)
|
|
})?;
|
|
let stream = timeout(CONNECT_TIMEOUT, UnixStream::connect(socket_path.as_path()))
|
|
.await
|
|
.map_err(|_| {
|
|
IoError::new(
|
|
ErrorKind::TimedOut,
|
|
format!("timed out connecting to remote app server at `{endpoint}`"),
|
|
)
|
|
})?
|
|
.map_err(|err| {
|
|
IoError::other(format!(
|
|
"failed to connect to remote app server at `{endpoint}`: {err}"
|
|
))
|
|
})?;
|
|
let websocket_config = remote_websocket_config();
|
|
let stream = timeout(
|
|
CONNECT_TIMEOUT,
|
|
client_async_with_config(request, stream, Some(websocket_config)),
|
|
)
|
|
.await
|
|
.map_err(|_| {
|
|
IoError::new(
|
|
ErrorKind::TimedOut,
|
|
format!("timed out upgrading remote app server at `{endpoint}`"),
|
|
)
|
|
})?
|
|
.map(|(stream, _response)| stream)
|
|
.map_err(|err| {
|
|
IoError::other(format!(
|
|
"failed to upgrade remote app server at `{endpoint}`: {err}"
|
|
))
|
|
})?;
|
|
|
|
Ok((endpoint, stream))
|
|
}
|
|
|
|
fn remote_websocket_config() -> WebSocketConfig {
|
|
WebSocketConfig::default()
|
|
.max_frame_size(Some(REMOTE_APP_SERVER_MAX_WEBSOCKET_MESSAGE_SIZE))
|
|
.max_message_size(Some(REMOTE_APP_SERVER_MAX_WEBSOCKET_MESSAGE_SIZE))
|
|
}
|
|
|
|
async fn initialize_remote_connection<S>(
|
|
stream: &mut WebSocketStream<S>,
|
|
endpoint: &str,
|
|
params: InitializeParams,
|
|
initialize_timeout: Duration,
|
|
) -> IoResult<Vec<AppServerEvent>>
|
|
where
|
|
S: AsyncRead + AsyncWrite + Unpin,
|
|
{
|
|
let initialize_request_id = RequestId::String("initialize".to_string());
|
|
let mut pending_events = Vec::new();
|
|
write_jsonrpc_message(
|
|
stream,
|
|
JSONRPCMessage::Request(jsonrpc_request_from_client_request(
|
|
ClientRequest::Initialize {
|
|
request_id: initialize_request_id.clone(),
|
|
params,
|
|
},
|
|
)),
|
|
endpoint,
|
|
)
|
|
.await?;
|
|
|
|
timeout(initialize_timeout, async {
|
|
loop {
|
|
match stream.next().await {
|
|
Some(Ok(Message::Text(text))) => {
|
|
let message = serde_json::from_str::<JSONRPCMessage>(&text).map_err(|err| {
|
|
IoError::other(format!(
|
|
"remote app server at `{endpoint}` sent invalid initialize response: {err}"
|
|
))
|
|
})?;
|
|
match message {
|
|
JSONRPCMessage::Response(response) if response.id == initialize_request_id => {
|
|
break Ok(());
|
|
}
|
|
JSONRPCMessage::Error(error) if error.id == initialize_request_id => {
|
|
break Err(IoError::other(format!(
|
|
"remote app server at `{endpoint}` rejected initialize: {}",
|
|
error.error.message
|
|
)));
|
|
}
|
|
JSONRPCMessage::Notification(notification) => {
|
|
if let Some(event) = app_server_event_from_notification(notification) {
|
|
pending_events.push(event);
|
|
}
|
|
}
|
|
JSONRPCMessage::Request(request) => {
|
|
let request_id = request.id.clone();
|
|
let method = request.method.clone();
|
|
match ServerRequest::try_from(request) {
|
|
Ok(request) => {
|
|
pending_events.push(AppServerEvent::ServerRequest(request));
|
|
}
|
|
Err(err) => {
|
|
warn!(%err, method, "rejecting unknown remote app-server request during initialize");
|
|
write_jsonrpc_message(
|
|
stream,
|
|
JSONRPCMessage::Error(JSONRPCError {
|
|
error: JSONRPCErrorError {
|
|
code: -32601,
|
|
message: format!(
|
|
"unsupported remote app-server request `{method}`"
|
|
),
|
|
data: None,
|
|
},
|
|
id: request_id,
|
|
}),
|
|
endpoint,
|
|
)
|
|
.await?;
|
|
}
|
|
}
|
|
}
|
|
JSONRPCMessage::Response(_) | JSONRPCMessage::Error(_) => {}
|
|
}
|
|
}
|
|
Some(Ok(Message::Binary(_)))
|
|
| Some(Ok(Message::Ping(_)))
|
|
| Some(Ok(Message::Pong(_)))
|
|
| Some(Ok(Message::Frame(_))) => {}
|
|
Some(Ok(Message::Close(frame))) => {
|
|
let reason = frame
|
|
.as_ref()
|
|
.map(|frame| frame.reason.to_string())
|
|
.filter(|reason| !reason.is_empty())
|
|
.unwrap_or_else(|| "connection closed during initialize".to_string());
|
|
break Err(IoError::new(
|
|
ErrorKind::ConnectionAborted,
|
|
format!(
|
|
"remote app server at `{endpoint}` closed during initialize: {reason}"
|
|
),
|
|
));
|
|
}
|
|
Some(Err(err)) => {
|
|
break Err(IoError::other(format!(
|
|
"remote app server at `{endpoint}` transport failed during initialize: {err}"
|
|
)));
|
|
}
|
|
None => {
|
|
break Err(IoError::new(
|
|
ErrorKind::UnexpectedEof,
|
|
format!("remote app server at `{endpoint}` closed during initialize"),
|
|
));
|
|
}
|
|
}
|
|
}
|
|
})
|
|
.await
|
|
.map_err(|_| {
|
|
IoError::new(
|
|
ErrorKind::TimedOut,
|
|
format!("timed out waiting for initialize response from `{endpoint}`"),
|
|
)
|
|
})??;
|
|
|
|
write_jsonrpc_message(
|
|
stream,
|
|
JSONRPCMessage::Notification(jsonrpc_notification_from_client_notification(
|
|
ClientNotification::Initialized,
|
|
)),
|
|
endpoint,
|
|
)
|
|
.await?;
|
|
|
|
Ok(pending_events)
|
|
}
|
|
|
|
fn app_server_event_from_notification(notification: JSONRPCNotification) -> Option<AppServerEvent> {
|
|
match ServerNotification::try_from(notification) {
|
|
Ok(notification) => Some(AppServerEvent::ServerNotification(notification)),
|
|
Err(_) => None,
|
|
}
|
|
}
|
|
|
|
fn deliver_event(
|
|
event_tx: &mpsc::UnboundedSender<AppServerEvent>,
|
|
event: AppServerEvent,
|
|
) -> IoResult<()> {
|
|
event_tx.send(event).map_err(|_| {
|
|
IoError::new(
|
|
ErrorKind::BrokenPipe,
|
|
"remote app-server event consumer channel is closed",
|
|
)
|
|
})
|
|
}
|
|
|
|
fn request_id_from_client_request(request: &ClientRequest) -> RequestId {
|
|
jsonrpc_request_from_client_request(request.clone()).id
|
|
}
|
|
|
|
fn jsonrpc_request_from_client_request(request: ClientRequest) -> JSONRPCRequest {
|
|
let value = match serde_json::to_value(request) {
|
|
Ok(value) => value,
|
|
Err(err) => panic!("client request should serialize: {err}"),
|
|
};
|
|
match serde_json::from_value(value) {
|
|
Ok(request) => request,
|
|
Err(err) => panic!("client request should encode as JSON-RPC request: {err}"),
|
|
}
|
|
}
|
|
|
|
fn jsonrpc_notification_from_client_notification(
|
|
notification: ClientNotification,
|
|
) -> JSONRPCNotification {
|
|
let value = match serde_json::to_value(notification) {
|
|
Ok(value) => value,
|
|
Err(err) => panic!("client notification should serialize: {err}"),
|
|
};
|
|
match serde_json::from_value(value) {
|
|
Ok(notification) => notification,
|
|
Err(err) => panic!("client notification should encode as JSON-RPC notification: {err}"),
|
|
}
|
|
}
|
|
|
|
async fn write_jsonrpc_message<S>(
|
|
stream: &mut WebSocketStream<S>,
|
|
message: JSONRPCMessage,
|
|
endpoint: &str,
|
|
) -> IoResult<()>
|
|
where
|
|
S: AsyncRead + AsyncWrite + Unpin,
|
|
{
|
|
let payload = serde_json::to_string(&message).map_err(IoError::other)?;
|
|
stream
|
|
.send(Message::Text(payload.into()))
|
|
.await
|
|
.map_err(|err| {
|
|
IoError::other(format!(
|
|
"failed to write websocket message to `{endpoint}`: {err}"
|
|
))
|
|
})
|
|
}
|
|
|
|
fn websocket_close_error_is_already_closed(err: &TungsteniteError) -> bool {
|
|
match err {
|
|
TungsteniteError::ConnectionClosed | TungsteniteError::AlreadyClosed => true,
|
|
TungsteniteError::Io(err) => matches!(
|
|
err.kind(),
|
|
ErrorKind::BrokenPipe | ErrorKind::ConnectionReset | ErrorKind::NotConnected
|
|
),
|
|
_ => false,
|
|
}
|
|
}
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[tokio::test]
|
|
async fn shutdown_tolerates_worker_exit_after_command_is_queued() {
|
|
let (command_tx, mut command_rx) = mpsc::channel(1);
|
|
let (_event_tx, event_rx) = mpsc::unbounded_channel::<AppServerEvent>();
|
|
let worker_handle = tokio::spawn(async move {
|
|
let _ = command_rx.recv().await;
|
|
});
|
|
let client = RemoteAppServerClient {
|
|
command_tx,
|
|
event_rx,
|
|
pending_events: VecDeque::new(),
|
|
worker_handle,
|
|
};
|
|
|
|
client
|
|
.shutdown()
|
|
.await
|
|
.expect("shutdown should complete when worker exits first");
|
|
}
|
|
}
|