mirror of
https://github.com/openai/codex.git
synced 2026-05-07 21:06:39 +00:00
Compare commits
27 Commits
rust-v0.12
...
starr/exec
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9f125d25cb | ||
|
|
256760d6b9 | ||
|
|
e58b331d8f | ||
|
|
dd1c9ff41a | ||
|
|
62bd368d38 | ||
|
|
28b23c5cd3 | ||
|
|
3ff901257a | ||
|
|
c72f484068 | ||
|
|
7557a7307a | ||
|
|
08795f1b65 | ||
|
|
f47954caef | ||
|
|
c317a66c61 | ||
|
|
d4b347176a | ||
|
|
6a7112ad21 | ||
|
|
b4269e85ff | ||
|
|
29f8812a83 | ||
|
|
942a674042 | ||
|
|
6ed49d62d7 | ||
|
|
045c740618 | ||
|
|
21297834ed | ||
|
|
c00a36e727 | ||
|
|
74e96987b8 | ||
|
|
caea51d3b7 | ||
|
|
c956939cc6 | ||
|
|
0bb3f728e1 | ||
|
|
995a669971 | ||
|
|
9face2bcbf |
@@ -17,13 +17,14 @@ use tokio::sync::mpsc;
|
||||
use tokio::sync::watch;
|
||||
|
||||
use tokio::time::timeout;
|
||||
use tokio_tungstenite::connect_async;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::ProcessId;
|
||||
use crate::client_api::ExecServerClientConnectOptions;
|
||||
use crate::client_api::ExecServerTransportParams;
|
||||
use crate::client_api::HttpClient;
|
||||
use crate::client_api::RemoteExecServerConnectArgs;
|
||||
use crate::client_api::StdioExecServerConnectArgs;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::process::ExecProcessEvent;
|
||||
use crate::process::ExecProcessEventLog;
|
||||
@@ -105,6 +106,16 @@ impl From<RemoteExecServerConnectArgs> for ExecServerClientConnectOptions {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StdioExecServerConnectArgs> for ExecServerClientConnectOptions {
|
||||
fn from(value: StdioExecServerConnectArgs) -> Self {
|
||||
Self {
|
||||
client_name: value.client_name,
|
||||
initialize_timeout: value.initialize_timeout,
|
||||
resume_session_id: value.resume_session_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RemoteExecServerConnectArgs {
|
||||
pub fn new(websocket_url: String, client_name: String) -> Self {
|
||||
Self {
|
||||
@@ -180,29 +191,25 @@ pub struct ExecServerClient {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct LazyRemoteExecServerClient {
|
||||
websocket_url: String,
|
||||
transport_params: ExecServerTransportParams,
|
||||
client: Arc<OnceCell<ExecServerClient>>,
|
||||
}
|
||||
|
||||
impl LazyRemoteExecServerClient {
|
||||
pub(crate) fn new(websocket_url: String) -> Self {
|
||||
pub(crate) fn new(transport_params: ExecServerTransportParams) -> Self {
|
||||
Self {
|
||||
websocket_url,
|
||||
transport_params,
|
||||
client: Arc::new(OnceCell::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn get(&self) -> Result<ExecServerClient, ExecServerError> {
|
||||
self.client
|
||||
.get_or_try_init(|| async {
|
||||
ExecServerClient::connect_websocket(RemoteExecServerConnectArgs {
|
||||
websocket_url: self.websocket_url.clone(),
|
||||
client_name: "codex-environment".to_string(),
|
||||
connect_timeout: Duration::from_secs(5),
|
||||
initialize_timeout: Duration::from_secs(5),
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await
|
||||
// TODO: Add reconnect/disconnect handling here instead of reusing
|
||||
// the first successfully initialized connection forever.
|
||||
.get_or_try_init(|| {
|
||||
let transport_params = self.transport_params.clone();
|
||||
async move { ExecServerClient::connect_for_transport(transport_params).await }
|
||||
})
|
||||
.await
|
||||
.cloned()
|
||||
@@ -269,32 +276,6 @@ pub enum ExecServerError {
|
||||
}
|
||||
|
||||
impl ExecServerClient {
|
||||
pub async fn connect_websocket(
|
||||
args: RemoteExecServerConnectArgs,
|
||||
) -> Result<Self, ExecServerError> {
|
||||
let websocket_url = args.websocket_url.clone();
|
||||
let connect_timeout = args.connect_timeout;
|
||||
let (stream, _) = timeout(connect_timeout, connect_async(websocket_url.as_str()))
|
||||
.await
|
||||
.map_err(|_| ExecServerError::WebSocketConnectTimeout {
|
||||
url: websocket_url.clone(),
|
||||
timeout: connect_timeout,
|
||||
})?
|
||||
.map_err(|source| ExecServerError::WebSocketConnect {
|
||||
url: websocket_url.clone(),
|
||||
source,
|
||||
})?;
|
||||
|
||||
Self::connect(
|
||||
JsonRpcConnection::from_websocket(
|
||||
stream,
|
||||
format!("exec-server websocket {websocket_url}"),
|
||||
),
|
||||
args.into(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn initialize(
|
||||
&self,
|
||||
options: ExecServerClientConnectOptions,
|
||||
@@ -443,7 +424,7 @@ impl ExecServerClient {
|
||||
.clone()
|
||||
}
|
||||
|
||||
async fn connect(
|
||||
pub(crate) async fn connect(
|
||||
connection: JsonRpcConnection,
|
||||
options: ExecServerClientConnectOptions,
|
||||
) -> Result<Self, ExecServerError> {
|
||||
@@ -893,18 +874,28 @@ mod tests {
|
||||
use codex_app_server_protocol::JSONRPCNotification;
|
||||
use codex_app_server_protocol::JSONRPCResponse;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::collections::HashMap;
|
||||
#[cfg(unix)]
|
||||
use std::path::Path;
|
||||
#[cfg(unix)]
|
||||
use std::process::Command;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::io::duplex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::Duration;
|
||||
#[cfg(unix)]
|
||||
use tokio::time::sleep;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use super::ExecServerClient;
|
||||
use super::ExecServerClientConnectOptions;
|
||||
use crate::ProcessId;
|
||||
use crate::client_api::StdioExecServerCommand;
|
||||
use crate::client_api::StdioExecServerConnectArgs;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::process::ExecProcessEvent;
|
||||
use crate::protocol::EXEC_CLOSED_METHOD;
|
||||
@@ -942,6 +933,162 @@ mod tests {
|
||||
.expect("json-rpc line should write");
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
#[tokio::test]
|
||||
async fn connect_stdio_command_initializes_json_rpc_client() {
|
||||
let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
|
||||
command: StdioExecServerCommand {
|
||||
program: "sh".to_string(),
|
||||
args: vec![
|
||||
"-c".to_string(),
|
||||
"read _line; printf '%s\\n' '{\"id\":1,\"result\":{\"sessionId\":\"stdio-test\"}}'; read _line; sleep 60".to_string(),
|
||||
],
|
||||
env: HashMap::new(),
|
||||
cwd: None,
|
||||
},
|
||||
client_name: "stdio-test-client".to_string(),
|
||||
initialize_timeout: Duration::from_secs(1),
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await
|
||||
.expect("stdio client should connect");
|
||||
|
||||
assert_eq!(client.session_id().as_deref(), Some("stdio-test"));
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
#[tokio::test]
|
||||
async fn connect_stdio_command_initializes_json_rpc_client_on_windows() {
|
||||
let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
|
||||
command: StdioExecServerCommand {
|
||||
program: "powershell".to_string(),
|
||||
args: vec![
|
||||
"-NoProfile".to_string(),
|
||||
"-Command".to_string(),
|
||||
"$null = [Console]::In.ReadLine(); [Console]::Out.WriteLine('{\"id\":1,\"result\":{\"sessionId\":\"stdio-test\"}}'); $null = [Console]::In.ReadLine(); Start-Sleep -Seconds 60".to_string(),
|
||||
],
|
||||
env: HashMap::new(),
|
||||
cwd: None,
|
||||
},
|
||||
client_name: "stdio-test-client".to_string(),
|
||||
initialize_timeout: Duration::from_secs(1),
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await
|
||||
.expect("stdio client should connect");
|
||||
|
||||
assert_eq!(client.session_id().as_deref(), Some("stdio-test"));
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test]
|
||||
async fn dropping_stdio_client_terminates_spawned_process() {
|
||||
let tempdir = tempfile::tempdir().expect("tempdir should be created");
|
||||
let pid_file = tempdir.path().join("server.pid");
|
||||
let stdio_script = format!(
|
||||
"read _line; \
|
||||
echo \"$$\" > {}; \
|
||||
printf '%s\\n' '{{\"id\":1,\"result\":{{\"sessionId\":\"stdio-test\"}}}}'; \
|
||||
read _line; \
|
||||
sleep 60",
|
||||
shell_quote(pid_file.as_path()),
|
||||
);
|
||||
|
||||
let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
|
||||
command: StdioExecServerCommand {
|
||||
program: "sh".to_string(),
|
||||
args: vec!["-c".to_string(), stdio_script],
|
||||
env: HashMap::new(),
|
||||
cwd: None,
|
||||
},
|
||||
client_name: "stdio-test-client".to_string(),
|
||||
initialize_timeout: Duration::from_secs(1),
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await
|
||||
.expect("stdio client should connect");
|
||||
let server_pid = read_pid_file(pid_file.as_path()).await;
|
||||
assert!(
|
||||
process_exists(server_pid),
|
||||
"spawned stdio process should be running before client drop"
|
||||
);
|
||||
|
||||
drop(client);
|
||||
|
||||
wait_for_process_exit(server_pid).await;
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test]
|
||||
async fn malformed_stdio_message_terminates_spawned_process() {
|
||||
let tempdir = tempfile::tempdir().expect("tempdir should be created");
|
||||
let pid_file = tempdir.path().join("server.pid");
|
||||
let stdio_script = format!(
|
||||
"read _line; \
|
||||
echo \"$$\" > {}; \
|
||||
printf '%s\\n' 'not-json'; \
|
||||
sleep 60",
|
||||
shell_quote(pid_file.as_path()),
|
||||
);
|
||||
|
||||
let result = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
|
||||
command: StdioExecServerCommand {
|
||||
program: "sh".to_string(),
|
||||
args: vec!["-c".to_string(), stdio_script],
|
||||
env: HashMap::new(),
|
||||
cwd: None,
|
||||
},
|
||||
client_name: "stdio-test-client".to_string(),
|
||||
initialize_timeout: Duration::from_secs(1),
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await;
|
||||
assert!(result.is_err(), "malformed stdio server should not connect");
|
||||
|
||||
let server_pid = read_pid_file(pid_file.as_path()).await;
|
||||
wait_for_process_exit(server_pid).await;
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
async fn read_pid_file(path: &Path) -> u32 {
|
||||
for _ in 0..20 {
|
||||
if let Ok(contents) = std::fs::read_to_string(path) {
|
||||
return contents
|
||||
.trim()
|
||||
.parse()
|
||||
.expect("pid file should contain a pid");
|
||||
}
|
||||
sleep(Duration::from_millis(50)).await;
|
||||
}
|
||||
panic!("pid file {} should be written", path.display());
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
async fn wait_for_process_exit(pid: u32) {
|
||||
for _ in 0..20 {
|
||||
if !process_exists(pid) {
|
||||
return;
|
||||
}
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
panic!("process {pid} should exit");
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn process_exists(pid: u32) -> bool {
|
||||
Command::new("kill")
|
||||
.arg("-0")
|
||||
.arg(pid.to_string())
|
||||
.status()
|
||||
.is_ok_and(|status| status.success())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn shell_quote(path: &Path) -> String {
|
||||
let value = path.to_string_lossy();
|
||||
format!("'{}'", value.replace('\'', "'\\''"))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_events_are_delivered_in_seq_order_when_notifications_are_reordered() {
|
||||
let (client_stdin, server_reader) = duplex(1 << 20);
|
||||
@@ -1085,6 +1232,92 @@ mod tests {
|
||||
server.await.expect("server task should finish");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_disconnect_fails_sessions_and_rejects_new_sessions() {
|
||||
let (client_stdin, server_reader) = duplex(1 << 20);
|
||||
let (mut server_writer, client_stdout) = duplex(1 << 20);
|
||||
let (disconnect_tx, disconnect_rx) = oneshot::channel();
|
||||
let server = tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(server_reader).lines();
|
||||
let initialize = read_jsonrpc_line(&mut lines).await;
|
||||
let request = match initialize {
|
||||
JSONRPCMessage::Request(request) if request.method == INITIALIZE_METHOD => request,
|
||||
other => panic!("expected initialize request, got {other:?}"),
|
||||
};
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id: request.id,
|
||||
result: serde_json::to_value(InitializeResponse {
|
||||
session_id: "session-1".to_string(),
|
||||
})
|
||||
.expect("initialize response should serialize"),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let initialized = read_jsonrpc_line(&mut lines).await;
|
||||
match initialized {
|
||||
JSONRPCMessage::Notification(notification)
|
||||
if notification.method == INITIALIZED_METHOD => {}
|
||||
other => panic!("expected initialized notification, got {other:?}"),
|
||||
}
|
||||
|
||||
let _ = disconnect_rx.await;
|
||||
drop(server_writer);
|
||||
});
|
||||
|
||||
let client = ExecServerClient::connect(
|
||||
JsonRpcConnection::from_stdio(
|
||||
client_stdout,
|
||||
client_stdin,
|
||||
"test-exec-server-client".to_string(),
|
||||
),
|
||||
ExecServerClientConnectOptions::default(),
|
||||
)
|
||||
.await
|
||||
.expect("client should connect");
|
||||
|
||||
let process_id = ProcessId::from("disconnect");
|
||||
let session = client
|
||||
.register_session(&process_id)
|
||||
.await
|
||||
.expect("session should register");
|
||||
let mut events = session.subscribe_events();
|
||||
|
||||
disconnect_tx.send(()).expect("disconnect should signal");
|
||||
|
||||
let event = timeout(Duration::from_secs(1), events.recv())
|
||||
.await
|
||||
.expect("session failure should not time out")
|
||||
.expect("session event stream should stay open");
|
||||
let ExecProcessEvent::Failed(message) = event else {
|
||||
panic!("expected session failure after disconnect, got {event:?}");
|
||||
};
|
||||
assert_eq!(message, "exec-server transport disconnected");
|
||||
|
||||
let response = session
|
||||
.read(
|
||||
/*after_seq*/ None, /*max_bytes*/ None, /*wait_ms*/ None,
|
||||
)
|
||||
.await
|
||||
.expect("disconnected session read should synthesize a response");
|
||||
assert_eq!(
|
||||
response.failure.as_deref(),
|
||||
Some("exec-server transport disconnected")
|
||||
);
|
||||
assert!(response.closed);
|
||||
|
||||
let new_session = client.register_session(&ProcessId::from("new")).await;
|
||||
assert!(matches!(
|
||||
new_session,
|
||||
Err(super::ExecServerError::Disconnected(_))
|
||||
));
|
||||
|
||||
drop(client);
|
||||
server.await.expect("server task should finish");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wake_notifications_do_not_block_other_sessions() {
|
||||
let (client_stdin, server_reader) = duplex(1 << 20);
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::future::BoxFuture;
|
||||
@@ -25,6 +27,32 @@ pub struct RemoteExecServerConnectArgs {
|
||||
pub resume_session_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Stdio connection arguments for a command-backed exec-server.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) struct StdioExecServerConnectArgs {
|
||||
pub command: StdioExecServerCommand,
|
||||
pub client_name: String,
|
||||
pub initialize_timeout: Duration,
|
||||
pub resume_session_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Structured process command used to start an exec-server over stdio.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) struct StdioExecServerCommand {
|
||||
pub program: String,
|
||||
pub args: Vec<String>,
|
||||
pub env: HashMap<String, String>,
|
||||
pub cwd: Option<PathBuf>,
|
||||
}
|
||||
|
||||
/// Parameters used to connect to a remote exec-server environment.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) enum ExecServerTransportParams {
|
||||
WebSocketUrl(String),
|
||||
#[allow(dead_code)]
|
||||
StdioCommand(StdioExecServerCommand),
|
||||
}
|
||||
|
||||
/// Sends HTTP requests through a runtime-selected transport.
|
||||
///
|
||||
/// This is the HTTP capability counterpart to [`crate::ExecBackend`]. Callers
|
||||
|
||||
125
codex-rs/exec-server/src/client_transport.rs
Normal file
125
codex-rs/exec-server/src/client_transport.rs
Normal file
@@ -0,0 +1,125 @@
|
||||
use std::process::Stdio;
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::process::Command;
|
||||
use tokio::time::timeout;
|
||||
use tokio_tungstenite::connect_async;
|
||||
use tracing::debug;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::ExecServerClient;
|
||||
use crate::ExecServerError;
|
||||
use crate::client_api::RemoteExecServerConnectArgs;
|
||||
use crate::client_api::StdioExecServerCommand;
|
||||
use crate::client_api::StdioExecServerConnectArgs;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
|
||||
const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment";
|
||||
const ENVIRONMENT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
const ENVIRONMENT_INITIALIZE_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
impl ExecServerClient {
|
||||
pub(crate) async fn connect_for_transport(
|
||||
transport_params: crate::client_api::ExecServerTransportParams,
|
||||
) -> Result<Self, ExecServerError> {
|
||||
match transport_params {
|
||||
crate::client_api::ExecServerTransportParams::WebSocketUrl(websocket_url) => {
|
||||
Self::connect_websocket(RemoteExecServerConnectArgs {
|
||||
websocket_url,
|
||||
client_name: ENVIRONMENT_CLIENT_NAME.to_string(),
|
||||
connect_timeout: ENVIRONMENT_CONNECT_TIMEOUT,
|
||||
initialize_timeout: ENVIRONMENT_INITIALIZE_TIMEOUT,
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await
|
||||
}
|
||||
crate::client_api::ExecServerTransportParams::StdioCommand(command) => {
|
||||
Self::connect_stdio_command(StdioExecServerConnectArgs {
|
||||
command,
|
||||
client_name: ENVIRONMENT_CLIENT_NAME.to_string(),
|
||||
initialize_timeout: ENVIRONMENT_INITIALIZE_TIMEOUT,
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn connect_websocket(
|
||||
args: RemoteExecServerConnectArgs,
|
||||
) -> Result<Self, ExecServerError> {
|
||||
let websocket_url = args.websocket_url.clone();
|
||||
let connect_timeout = args.connect_timeout;
|
||||
let (stream, _) = timeout(connect_timeout, connect_async(websocket_url.as_str()))
|
||||
.await
|
||||
.map_err(|_| ExecServerError::WebSocketConnectTimeout {
|
||||
url: websocket_url.clone(),
|
||||
timeout: connect_timeout,
|
||||
})?
|
||||
.map_err(|source| ExecServerError::WebSocketConnect {
|
||||
url: websocket_url.clone(),
|
||||
source,
|
||||
})?;
|
||||
|
||||
Self::connect(
|
||||
JsonRpcConnection::from_websocket(
|
||||
stream,
|
||||
format!("exec-server websocket {websocket_url}"),
|
||||
),
|
||||
args.into(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn connect_stdio_command(
|
||||
args: StdioExecServerConnectArgs,
|
||||
) -> Result<Self, ExecServerError> {
|
||||
let mut child = stdio_command_process(&args.command)
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(ExecServerError::Spawn)?;
|
||||
|
||||
let stdin = child.stdin.take().ok_or_else(|| {
|
||||
ExecServerError::Protocol("spawned exec-server command has no stdin".to_string())
|
||||
})?;
|
||||
let stdout = child.stdout.take().ok_or_else(|| {
|
||||
ExecServerError::Protocol("spawned exec-server command has no stdout".to_string())
|
||||
})?;
|
||||
if let Some(stderr) = child.stderr.take() {
|
||||
tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(stderr).lines();
|
||||
loop {
|
||||
match lines.next_line().await {
|
||||
Ok(Some(line)) => debug!("exec-server stdio stderr: {line}"),
|
||||
Ok(None) => break,
|
||||
Err(err) => {
|
||||
warn!("failed to read exec-server stdio stderr: {err}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Self::connect(
|
||||
JsonRpcConnection::from_stdio(stdout, stdin, "exec-server stdio command".to_string())
|
||||
.with_child_process(child),
|
||||
args.into(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
fn stdio_command_process(stdio_command: &StdioExecServerCommand) -> Command {
|
||||
let mut command = Command::new(&stdio_command.program);
|
||||
command.args(&stdio_command.args);
|
||||
command.envs(&stdio_command.env);
|
||||
if let Some(cwd) = &stdio_command.cwd {
|
||||
command.current_dir(cwd);
|
||||
}
|
||||
command
|
||||
}
|
||||
@@ -3,10 +3,12 @@ use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio::process::Child;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::watch;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::debug;
|
||||
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
@@ -22,11 +24,55 @@ pub(crate) enum JsonRpcConnectionEvent {
|
||||
Disconnected { reason: Option<String> },
|
||||
}
|
||||
|
||||
pub(crate) enum JsonRpcTransport {
|
||||
Plain,
|
||||
Stdio { _transport: Box<StdioTransport> },
|
||||
}
|
||||
|
||||
impl JsonRpcTransport {
|
||||
fn from_child_process(child_process: Child) -> Self {
|
||||
Self::Stdio {
|
||||
_transport: Box::new(StdioTransport {
|
||||
child_process: Some(child_process),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct StdioTransport {
|
||||
child_process: Option<Child>,
|
||||
}
|
||||
|
||||
impl Drop for StdioTransport {
|
||||
fn drop(&mut self) {
|
||||
let Some(mut child_process) = self.child_process.take() else {
|
||||
return;
|
||||
};
|
||||
|
||||
if let Err(err) = child_process.start_kill() {
|
||||
debug!("failed to terminate exec-server stdio child: {err}");
|
||||
}
|
||||
match tokio::runtime::Handle::try_current() {
|
||||
Ok(handle) => {
|
||||
handle.spawn(async move {
|
||||
if let Err(err) = child_process.wait().await {
|
||||
debug!("failed to wait for exec-server stdio child: {err}");
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(err) => {
|
||||
debug!("failed to wait for exec-server stdio child without a Tokio runtime: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct JsonRpcConnection {
|
||||
outgoing_tx: mpsc::Sender<JSONRPCMessage>,
|
||||
incoming_rx: mpsc::Receiver<JsonRpcConnectionEvent>,
|
||||
disconnected_rx: watch::Receiver<bool>,
|
||||
task_handles: Vec<tokio::task::JoinHandle<()>>,
|
||||
pub(crate) outgoing_tx: mpsc::Sender<JSONRPCMessage>,
|
||||
pub(crate) incoming_rx: mpsc::Receiver<JsonRpcConnectionEvent>,
|
||||
pub(crate) disconnected_rx: watch::Receiver<bool>,
|
||||
pub(crate) task_handles: Vec<tokio::task::JoinHandle<()>>,
|
||||
pub(crate) transport: JsonRpcTransport,
|
||||
}
|
||||
|
||||
impl JsonRpcConnection {
|
||||
@@ -117,6 +163,7 @@ impl JsonRpcConnection {
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
task_handles: vec![reader_task, writer_task],
|
||||
transport: JsonRpcTransport::Plain,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -251,23 +298,13 @@ impl JsonRpcConnection {
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
task_handles: vec![reader_task, writer_task],
|
||||
transport: JsonRpcTransport::Plain,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn into_parts(
|
||||
self,
|
||||
) -> (
|
||||
mpsc::Sender<JSONRPCMessage>,
|
||||
mpsc::Receiver<JsonRpcConnectionEvent>,
|
||||
watch::Receiver<bool>,
|
||||
Vec<tokio::task::JoinHandle<()>>,
|
||||
) {
|
||||
(
|
||||
self.outgoing_tx,
|
||||
self.incoming_rx,
|
||||
self.disconnected_rx,
|
||||
self.task_handles,
|
||||
)
|
||||
pub(crate) fn with_child_process(mut self, child_process: Child) -> Self {
|
||||
self.transport = JsonRpcTransport::from_child_process(child_process);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ use crate::ExecutorFileSystem;
|
||||
use crate::HttpClient;
|
||||
use crate::client::LazyRemoteExecServerClient;
|
||||
use crate::client::http_client::ReqwestHttpClient;
|
||||
use crate::client_api::ExecServerTransportParams;
|
||||
use crate::environment_provider::DefaultEnvironmentProvider;
|
||||
use crate::environment_provider::EnvironmentProvider;
|
||||
use crate::environment_provider::normalize_exec_server_url;
|
||||
@@ -274,7 +275,9 @@ impl Environment {
|
||||
exec_server_url: String,
|
||||
local_runtime_paths: Option<ExecServerRuntimePaths>,
|
||||
) -> Self {
|
||||
let client = LazyRemoteExecServerClient::new(exec_server_url.clone());
|
||||
let client = LazyRemoteExecServerClient::new(ExecServerTransportParams::WebSocketUrl(
|
||||
exec_server_url.clone(),
|
||||
));
|
||||
let exec_backend: Arc<dyn ExecBackend> = Arc::new(RemoteProcess::new(client.clone()));
|
||||
let filesystem: Arc<dyn ExecutorFileSystem> =
|
||||
Arc::new(RemoteFileSystem::new(client.clone()));
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
mod client;
|
||||
mod client_api;
|
||||
mod client_transport;
|
||||
mod connection;
|
||||
mod environment;
|
||||
mod environment_provider;
|
||||
|
||||
@@ -23,6 +23,7 @@ use tokio::task::JoinHandle;
|
||||
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::connection::JsonRpcConnectionEvent;
|
||||
use crate::connection::JsonRpcTransport;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum RpcCallError {
|
||||
@@ -58,11 +59,9 @@ pub(crate) enum RpcServerOutboundMessage {
|
||||
request_id: RequestId,
|
||||
error: JSONRPCErrorError,
|
||||
},
|
||||
#[allow(dead_code)]
|
||||
Notification(JSONRPCNotification),
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct RpcNotificationSender {
|
||||
outgoing_tx: mpsc::Sender<RpcServerOutboundMessage>,
|
||||
@@ -84,7 +83,6 @@ impl RpcNotificationSender {
|
||||
.map_err(|_| internal_error("RPC connection closed while sending response".into()))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) async fn notify<P: Serialize>(
|
||||
&self,
|
||||
method: &str,
|
||||
@@ -229,12 +227,19 @@ pub(crate) struct RpcClient {
|
||||
disconnected_rx: watch::Receiver<bool>,
|
||||
next_request_id: AtomicI64,
|
||||
transport_tasks: Vec<JoinHandle<()>>,
|
||||
_transport: JsonRpcTransport,
|
||||
reader_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl RpcClient {
|
||||
pub(crate) fn new(connection: JsonRpcConnection) -> (Self, mpsc::Receiver<RpcClientEvent>) {
|
||||
let (write_tx, mut incoming_rx, disconnected_rx, transport_tasks) = connection.into_parts();
|
||||
let JsonRpcConnection {
|
||||
outgoing_tx: write_tx,
|
||||
mut incoming_rx,
|
||||
disconnected_rx,
|
||||
task_handles: transport_tasks,
|
||||
transport,
|
||||
} = connection;
|
||||
let pending = Arc::new(Mutex::new(HashMap::<RequestId, PendingRequest>::new()));
|
||||
let (event_tx, event_rx) = mpsc::channel(128);
|
||||
|
||||
@@ -275,6 +280,7 @@ impl RpcClient {
|
||||
disconnected_rx,
|
||||
next_request_id: AtomicI64::new(1),
|
||||
transport_tasks,
|
||||
_transport: transport,
|
||||
reader_task,
|
||||
},
|
||||
event_rx,
|
||||
@@ -357,7 +363,6 @@ impl RpcClient {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(dead_code)]
|
||||
pub(crate) async fn pending_request_count(&self) -> usize {
|
||||
self.pending.lock().await.len()
|
||||
}
|
||||
@@ -565,11 +570,9 @@ mod tests {
|
||||
async fn rpc_client_matches_out_of_order_responses_by_request_id() {
|
||||
let (client_stdin, server_reader) = tokio::io::duplex(4096);
|
||||
let (mut server_writer, client_stdout) = tokio::io::duplex(4096);
|
||||
let (client, _events_rx) = RpcClient::new(JsonRpcConnection::from_stdio(
|
||||
client_stdout,
|
||||
client_stdin,
|
||||
"test-rpc".to_string(),
|
||||
));
|
||||
let connection =
|
||||
JsonRpcConnection::from_stdio(client_stdout, client_stdin, "test-rpc".to_string());
|
||||
let (client, _events_rx) = RpcClient::new(connection);
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(server_reader).lines();
|
||||
|
||||
@@ -47,8 +47,13 @@ async fn run_connection(
|
||||
runtime_paths: ExecServerRuntimePaths,
|
||||
) {
|
||||
let router = Arc::new(build_router());
|
||||
let (json_outgoing_tx, mut incoming_rx, mut disconnected_rx, connection_tasks) =
|
||||
connection.into_parts();
|
||||
let JsonRpcConnection {
|
||||
outgoing_tx: json_outgoing_tx,
|
||||
mut incoming_rx,
|
||||
mut disconnected_rx,
|
||||
task_handles: connection_tasks,
|
||||
transport: _transport,
|
||||
} = connection;
|
||||
let (outgoing_tx, mut outgoing_rx) =
|
||||
mpsc::channel::<RpcServerOutboundMessage>(CHANNEL_CAPACITY);
|
||||
let notifications = RpcNotificationSender::new(outgoing_tx.clone());
|
||||
|
||||
Reference in New Issue
Block a user