mirror of
https://github.com/openai/codex.git
synced 2026-05-09 13:52:41 +00:00
Compare commits
4 Commits
eric/codex
...
starr/exec
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2877fa623a | ||
|
|
edab28997a | ||
|
|
7c2c8edc61 | ||
|
|
62204429a8 |
1
codex-rs/Cargo.lock
generated
1
codex-rs/Cargo.lock
generated
@@ -2672,6 +2672,7 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tokio-util",
|
||||
"toml 0.9.11+spec-1.1.0",
|
||||
"tracing",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
@@ -450,9 +450,13 @@ struct ExecServerCommand {
|
||||
#[arg(
|
||||
long = "listen",
|
||||
value_name = "URL",
|
||||
default_value = "ws://127.0.0.1:0"
|
||||
default_value = codex_exec_server::DEFAULT_LISTEN_URL
|
||||
)]
|
||||
listen: String,
|
||||
|
||||
/// Path to exec-server configuration. Defaults to `$CODEX_HOME/exec-server.toml`.
|
||||
#[arg(long = "config-path", value_name = "PATH")]
|
||||
config: Option<PathBuf>,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Subcommand)]
|
||||
@@ -1257,6 +1261,15 @@ async fn run_exec_server_command(
|
||||
cmd: ExecServerCommand,
|
||||
arg0_paths: &Arg0DispatchPaths,
|
||||
) -> anyhow::Result<()> {
|
||||
let config_path = match cmd.config {
|
||||
Some(path) => path,
|
||||
None => find_codex_home()?
|
||||
.join(codex_exec_server::EXEC_SERVER_CONFIG_FILE)
|
||||
.to_path_buf(),
|
||||
};
|
||||
let options = codex_exec_server::ExecServerConfig::load_from_path(&config_path)
|
||||
.await?
|
||||
.into_run_options(&config_path)?;
|
||||
let codex_self_exe = arg0_paths
|
||||
.codex_self_exe
|
||||
.clone()
|
||||
@@ -1265,7 +1278,7 @@ async fn run_exec_server_command(
|
||||
codex_self_exe,
|
||||
arg0_paths.codex_linux_sandbox_exe.clone(),
|
||||
)?;
|
||||
codex_exec_server::run_main(&cmd.listen, runtime_paths)
|
||||
codex_exec_server::run_main_with_options(&cmd.listen, runtime_paths, options)
|
||||
.await
|
||||
.map_err(anyhow::Error::from_boxed)
|
||||
}
|
||||
@@ -1812,12 +1825,38 @@ mod tests {
|
||||
app_server
|
||||
}
|
||||
|
||||
fn exec_server_from_args(args: &[&str]) -> ExecServerCommand {
|
||||
let cli = MultitoolCli::try_parse_from(args).expect("parse");
|
||||
let Subcommand::ExecServer(exec_server) = cli.subcommand.expect("exec-server present")
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
exec_server
|
||||
}
|
||||
|
||||
fn default_app_server_socket_path() -> AbsolutePathBuf {
|
||||
let codex_home = find_codex_home().expect("codex home");
|
||||
codex_app_server::app_server_control_socket_path(&codex_home)
|
||||
.expect("default app-server socket path")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exec_server_defaults_config_path_to_none() {
|
||||
let exec_server = exec_server_from_args(["codex", "exec-server"].as_ref());
|
||||
|
||||
assert_eq!(exec_server.listen, "ws://127.0.0.1:0");
|
||||
assert_eq!(exec_server.config, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exec_server_parses_config_path() {
|
||||
let exec_server = exec_server_from_args(
|
||||
["codex", "exec-server", "--config-path", "/tmp/exec.toml"].as_ref(),
|
||||
);
|
||||
|
||||
assert_eq!(exec_server.config, Some(PathBuf::from("/tmp/exec.toml")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn debug_prompt_input_parses_prompt_and_images() {
|
||||
let cli = MultitoolCli::try_parse_from([
|
||||
|
||||
@@ -35,11 +35,13 @@ tokio = { workspace = true, features = [
|
||||
"net",
|
||||
"process",
|
||||
"rt-multi-thread",
|
||||
"signal",
|
||||
"sync",
|
||||
"time",
|
||||
] }
|
||||
tokio-util = { workspace = true, features = ["rt"] }
|
||||
tokio-tungstenite = { workspace = true }
|
||||
toml = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
uuid = { workspace = true, features = ["v4"] }
|
||||
|
||||
|
||||
@@ -23,6 +23,14 @@ The CLI entrypoint supports:
|
||||
|
||||
- `ws://IP:PORT` (default)
|
||||
|
||||
The CLI also accepts `--config-path PATH`. When omitted, the server reads
|
||||
`$CODEX_HOME/exec-server.toml`. Missing config files are ignored. The supported
|
||||
setting is:
|
||||
|
||||
```toml
|
||||
graceful_shutdown_timeout_ms = 30000
|
||||
```
|
||||
|
||||
Wire framing:
|
||||
|
||||
- websocket: one JSON-RPC message per websocket text frame
|
||||
@@ -39,8 +47,16 @@ Each connection follows this sequence:
|
||||
If the server receives any notification other than `initialized`, it replies
|
||||
with an error using request id `-1`.
|
||||
|
||||
If the websocket connection closes, the server terminates any remaining managed
|
||||
processes for that client connection.
|
||||
If the websocket connection closes, the server detaches from its session. A
|
||||
later connection may resume the session by passing the returned `sessionId` to
|
||||
`initialize`.
|
||||
|
||||
On the first SIGINT or SIGTERM, the server stops accepting new websocket
|
||||
connections and begins a graceful drain. Existing connections stay open, but
|
||||
new `process/start` and `http/request` calls are rejected. The server exits
|
||||
after active processes and HTTP body streams finish, or after
|
||||
`graceful_shutdown_timeout_ms` elapses. A second SIGINT or SIGTERM skips the
|
||||
remaining drain and forces all sessions to stop.
|
||||
|
||||
## API
|
||||
|
||||
|
||||
169
codex-rs/exec-server/src/config.rs
Normal file
169
codex-rs/exec-server/src/config.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
|
||||
use serde::Deserialize;
|
||||
|
||||
pub const EXEC_SERVER_CONFIG_FILE: &str = "exec-server.toml";
|
||||
pub const DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ExecServerRunOptions {
|
||||
pub graceful_shutdown_timeout: Duration,
|
||||
}
|
||||
|
||||
impl Default for ExecServerRunOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
graceful_shutdown_timeout: DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct ExecServerConfig {
|
||||
pub graceful_shutdown_timeout_ms: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ExecServerConfigError {
|
||||
#[error("failed to read exec-server config `{path}`: {source}")]
|
||||
Read {
|
||||
path: String,
|
||||
#[source]
|
||||
source: std::io::Error,
|
||||
},
|
||||
#[error("failed to parse exec-server config `{path}`: {source}")]
|
||||
Parse {
|
||||
path: String,
|
||||
#[source]
|
||||
source: toml::de::Error,
|
||||
},
|
||||
#[error(
|
||||
"invalid exec-server config `{path}`: graceful_shutdown_timeout_ms must be greater than 0"
|
||||
)]
|
||||
InvalidTimeout { path: String },
|
||||
}
|
||||
|
||||
impl ExecServerConfig {
|
||||
pub async fn load_from_path(path: impl AsRef<Path>) -> Result<Self, ExecServerConfigError> {
|
||||
let path = path.as_ref();
|
||||
let contents = match tokio::fs::read_to_string(path).await {
|
||||
Ok(contents) => contents,
|
||||
Err(source) if source.kind() == std::io::ErrorKind::NotFound => {
|
||||
return Ok(Self::default());
|
||||
}
|
||||
Err(source) => {
|
||||
return Err(ExecServerConfigError::Read {
|
||||
path: path.display().to_string(),
|
||||
source,
|
||||
});
|
||||
}
|
||||
};
|
||||
toml::from_str(&contents).map_err(|source| ExecServerConfigError::Parse {
|
||||
path: path.display().to_string(),
|
||||
source,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn into_run_options(
|
||||
self,
|
||||
path: impl AsRef<Path>,
|
||||
) -> Result<ExecServerRunOptions, ExecServerConfigError> {
|
||||
let Some(timeout_ms) = self.graceful_shutdown_timeout_ms else {
|
||||
return Ok(ExecServerRunOptions::default());
|
||||
};
|
||||
if timeout_ms == 0 {
|
||||
return Err(ExecServerConfigError::InvalidTimeout {
|
||||
path: path.as_ref().display().to_string(),
|
||||
});
|
||||
}
|
||||
Ok(ExecServerRunOptions {
|
||||
graceful_shutdown_timeout: Duration::from_millis(timeout_ms),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::TempDir;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn missing_config_uses_defaults() {
|
||||
let temp = TempDir::new().expect("tempdir");
|
||||
let path = temp.path().join(EXEC_SERVER_CONFIG_FILE);
|
||||
|
||||
let config = ExecServerConfig::load_from_path(&path)
|
||||
.await
|
||||
.expect("missing config should load");
|
||||
let options = config
|
||||
.into_run_options(&path)
|
||||
.expect("default options should validate");
|
||||
|
||||
assert_eq!(options, ExecServerRunOptions::default());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parses_graceful_shutdown_timeout() {
|
||||
let temp = TempDir::new().expect("tempdir");
|
||||
let path = temp.path().join(EXEC_SERVER_CONFIG_FILE);
|
||||
tokio::fs::write(&path, "graceful_shutdown_timeout_ms = 125\n")
|
||||
.await
|
||||
.expect("write config");
|
||||
|
||||
let config = ExecServerConfig::load_from_path(&path)
|
||||
.await
|
||||
.expect("config should load");
|
||||
let options = config
|
||||
.into_run_options(&path)
|
||||
.expect("config should validate");
|
||||
|
||||
assert_eq!(
|
||||
options,
|
||||
ExecServerRunOptions {
|
||||
graceful_shutdown_timeout: Duration::from_millis(125),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn malformed_config_reports_path() {
|
||||
let temp = TempDir::new().expect("tempdir");
|
||||
let path = temp.path().join(EXEC_SERVER_CONFIG_FILE);
|
||||
tokio::fs::write(&path, "graceful_shutdown_timeout_ms = ")
|
||||
.await
|
||||
.expect("write config");
|
||||
|
||||
let err = ExecServerConfig::load_from_path(&path)
|
||||
.await
|
||||
.expect_err("malformed config should fail");
|
||||
|
||||
assert!(
|
||||
err.to_string().contains(path.to_string_lossy().as_ref()),
|
||||
"error should mention path: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn zero_timeout_is_invalid() {
|
||||
let temp = TempDir::new().expect("tempdir");
|
||||
let path = temp.path().join(EXEC_SERVER_CONFIG_FILE);
|
||||
|
||||
let err = ExecServerConfig {
|
||||
graceful_shutdown_timeout_ms: Some(0),
|
||||
}
|
||||
.into_run_options(&path)
|
||||
.expect_err("zero timeout should fail");
|
||||
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
format!(
|
||||
"invalid exec-server config `{}`: graceful_shutdown_timeout_ms must be greater than 0",
|
||||
path.display()
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
mod client;
|
||||
mod client_api;
|
||||
mod config;
|
||||
mod connection;
|
||||
mod environment;
|
||||
mod environment_provider;
|
||||
@@ -33,6 +34,11 @@ pub use codex_file_system::FileSystemResult;
|
||||
pub use codex_file_system::FileSystemSandboxContext;
|
||||
pub use codex_file_system::ReadDirectoryEntry;
|
||||
pub use codex_file_system::RemoveOptions;
|
||||
pub use config::DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT;
|
||||
pub use config::EXEC_SERVER_CONFIG_FILE;
|
||||
pub use config::ExecServerConfig;
|
||||
pub use config::ExecServerConfigError;
|
||||
pub use config::ExecServerRunOptions;
|
||||
pub use environment::CODEX_EXEC_SERVER_URL_ENV_VAR;
|
||||
pub use environment::Environment;
|
||||
pub use environment::EnvironmentManager;
|
||||
@@ -91,3 +97,4 @@ pub use runtime_paths::ExecServerRuntimePaths;
|
||||
pub use server::DEFAULT_LISTEN_URL;
|
||||
pub use server::ExecServerListenUrlParseError;
|
||||
pub use server::run_main;
|
||||
pub use server::run_main_with_options;
|
||||
|
||||
@@ -142,6 +142,17 @@ impl LocalProcess {
|
||||
*notification_sender = notifications;
|
||||
}
|
||||
|
||||
pub(crate) async fn active_process_count(&self) -> usize {
|
||||
let processes = self.inner.processes.lock().await;
|
||||
processes
|
||||
.values()
|
||||
.filter(|process| match process {
|
||||
ProcessEntry::Starting => true,
|
||||
ProcessEntry::Running(process) => !process.closed,
|
||||
})
|
||||
.count()
|
||||
}
|
||||
|
||||
async fn start_process(
|
||||
&self,
|
||||
params: ExecParams,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
mod drain;
|
||||
mod file_system_handler;
|
||||
mod handler;
|
||||
mod process_handler;
|
||||
@@ -10,11 +11,20 @@ pub(crate) use handler::ExecServerHandler;
|
||||
pub use transport::DEFAULT_LISTEN_URL;
|
||||
pub use transport::ExecServerListenUrlParseError;
|
||||
|
||||
use crate::ExecServerRunOptions;
|
||||
use crate::ExecServerRuntimePaths;
|
||||
|
||||
pub async fn run_main(
|
||||
listen_url: &str,
|
||||
runtime_paths: ExecServerRuntimePaths,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
transport::run_transport(listen_url, runtime_paths).await
|
||||
run_main_with_options(listen_url, runtime_paths, ExecServerRunOptions::default()).await
|
||||
}
|
||||
|
||||
pub async fn run_main_with_options(
|
||||
listen_url: &str,
|
||||
runtime_paths: ExecServerRuntimePaths,
|
||||
options: ExecServerRunOptions,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
transport::run_transport(listen_url, runtime_paths, options).await
|
||||
}
|
||||
|
||||
116
codex-rs/exec-server/src/server/drain.rs
Normal file
116
codex-rs/exec-server/src/server/drain.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use crate::rpc::invalid_request;
|
||||
use codex_app_server_protocol::JSONRPCErrorError;
|
||||
|
||||
pub(crate) struct DrainState {
|
||||
draining: AtomicBool,
|
||||
active_rpc_requests: AtomicUsize,
|
||||
active_http_requests: AtomicUsize,
|
||||
}
|
||||
|
||||
pub(crate) struct ActiveRpcRequest {
|
||||
state: Arc<DrainState>,
|
||||
}
|
||||
|
||||
pub(crate) struct ActiveHttpRequest {
|
||||
state: Arc<DrainState>,
|
||||
}
|
||||
|
||||
impl DrainState {
|
||||
pub(crate) fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
draining: AtomicBool::new(false),
|
||||
active_rpc_requests: AtomicUsize::new(0),
|
||||
active_http_requests: AtomicUsize::new(0),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn begin(&self) {
|
||||
self.draining.store(true, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
pub(crate) fn is_draining(&self) -> bool {
|
||||
self.draining.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
pub(crate) fn try_start_process(&self) -> Result<(), JSONRPCErrorError> {
|
||||
if self.is_draining() {
|
||||
return Err(invalid_request(
|
||||
"exec-server is draining; new processes are not accepted".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn start_rpc_request(self: &Arc<Self>) -> ActiveRpcRequest {
|
||||
self.active_rpc_requests.fetch_add(1, Ordering::SeqCst);
|
||||
ActiveRpcRequest {
|
||||
state: Arc::clone(self),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn try_start_http_request(
|
||||
self: &Arc<Self>,
|
||||
) -> Result<ActiveHttpRequest, JSONRPCErrorError> {
|
||||
if self.is_draining() {
|
||||
return Err(invalid_request(
|
||||
"exec-server is draining; new HTTP requests are not accepted".to_string(),
|
||||
));
|
||||
}
|
||||
self.active_http_requests.fetch_add(1, Ordering::SeqCst);
|
||||
if self.is_draining() {
|
||||
self.active_http_requests.fetch_sub(1, Ordering::SeqCst);
|
||||
return Err(invalid_request(
|
||||
"exec-server is draining; new HTTP requests are not accepted".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(ActiveHttpRequest {
|
||||
state: Arc::clone(self),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn active_http_request_count(&self) -> usize {
|
||||
self.active_http_requests.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
pub(crate) fn active_rpc_request_count(&self) -> usize {
|
||||
self.active_rpc_requests.load(Ordering::SeqCst)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ActiveRpcRequest {
|
||||
fn drop(&mut self) {
|
||||
self.state
|
||||
.active_rpc_requests
|
||||
.fetch_sub(1, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ActiveHttpRequest {
|
||||
fn drop(&mut self) {
|
||||
self.state
|
||||
.active_http_requests
|
||||
.fetch_sub(1, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::DrainState;
|
||||
|
||||
#[test]
|
||||
fn rpc_request_guard_tracks_active_count_until_drop() {
|
||||
let state = DrainState::new();
|
||||
assert_eq!(state.active_rpc_request_count(), 0);
|
||||
|
||||
let guard = state.start_rpc_request();
|
||||
assert_eq!(state.active_rpc_request_count(), 1);
|
||||
|
||||
drop(guard);
|
||||
assert_eq!(state.active_rpc_request_count(), 0);
|
||||
}
|
||||
}
|
||||
@@ -43,12 +43,15 @@ use crate::rpc::RpcNotificationSender;
|
||||
use crate::rpc::internal_error;
|
||||
use crate::rpc::invalid_params;
|
||||
use crate::rpc::invalid_request;
|
||||
use crate::server::drain::ActiveHttpRequest;
|
||||
use crate::server::drain::DrainState;
|
||||
use crate::server::file_system_handler::FileSystemHandler;
|
||||
use crate::server::session_registry::SessionHandle;
|
||||
use crate::server::session_registry::SessionRegistry;
|
||||
|
||||
pub(crate) struct ExecServerHandler {
|
||||
session_registry: Arc<SessionRegistry>,
|
||||
drain_state: Arc<DrainState>,
|
||||
notifications: RpcNotificationSender,
|
||||
session: StdMutex<Option<SessionHandle>>,
|
||||
active_body_stream_ids: Mutex<HashSet<String>>,
|
||||
@@ -62,11 +65,13 @@ pub(crate) struct ExecServerHandler {
|
||||
impl ExecServerHandler {
|
||||
pub(crate) fn new(
|
||||
session_registry: Arc<SessionRegistry>,
|
||||
drain_state: Arc<DrainState>,
|
||||
notifications: RpcNotificationSender,
|
||||
runtime_paths: ExecServerRuntimePaths,
|
||||
) -> Self {
|
||||
Self {
|
||||
session_registry,
|
||||
drain_state,
|
||||
notifications,
|
||||
session: StdMutex::new(None),
|
||||
active_body_stream_ids: Mutex::new(HashSet::new()),
|
||||
@@ -138,6 +143,7 @@ impl ExecServerHandler {
|
||||
|
||||
pub(crate) async fn exec(&self, params: ExecParams) -> Result<ExecResponse, JSONRPCErrorError> {
|
||||
let session = self.require_initialized_for("exec")?;
|
||||
self.drain_state.try_start_process()?;
|
||||
session.process().exec(params).await
|
||||
}
|
||||
|
||||
@@ -173,6 +179,7 @@ impl ExecServerHandler {
|
||||
params: HttpRequestParams,
|
||||
) -> Result<(), JSONRPCErrorError> {
|
||||
self.require_initialized_for("http")?;
|
||||
let http_guard = self.drain_state.try_start_http_request()?;
|
||||
let stream_response = params.stream_response;
|
||||
let http_request_id = params.request_id.clone();
|
||||
if stream_response {
|
||||
@@ -203,7 +210,8 @@ impl ExecServerHandler {
|
||||
return Err(error);
|
||||
}
|
||||
if let Some(pending_stream) = pending_stream {
|
||||
self.start_http_body_stream(pending_stream).await;
|
||||
self.start_http_body_stream(pending_stream, http_guard)
|
||||
.await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -307,6 +315,7 @@ impl ExecServerHandler {
|
||||
async fn start_http_body_stream(
|
||||
self: &Arc<Self>,
|
||||
pending_stream: PendingReqwestHttpBodyStream,
|
||||
http_guard: ActiveHttpRequest,
|
||||
) {
|
||||
let request_id = pending_stream.request_id.clone();
|
||||
if self.background_task_shutdown.is_cancelled() {
|
||||
@@ -318,6 +327,7 @@ impl ExecServerHandler {
|
||||
let notifications = self.notifications.clone();
|
||||
let shutdown = self.background_task_shutdown.clone();
|
||||
self.background_tasks.spawn(async move {
|
||||
let _http_guard = http_guard;
|
||||
tokio::select! {
|
||||
_ = shutdown.cancelled() => {}
|
||||
_ = ReqwestHttpRequestRunner::stream_body(pending_stream, notifications) => {}
|
||||
|
||||
@@ -16,6 +16,7 @@ use crate::protocol::ReadResponse;
|
||||
use crate::protocol::TerminateParams;
|
||||
use crate::protocol::TerminateResponse;
|
||||
use crate::rpc::RpcNotificationSender;
|
||||
use crate::server::drain::DrainState;
|
||||
use crate::server::session_registry::SessionRegistry;
|
||||
|
||||
fn exec_params(process_id: &str) -> ExecParams {
|
||||
@@ -80,6 +81,7 @@ async fn initialized_handler() -> Arc<ExecServerHandler> {
|
||||
let registry = SessionRegistry::new();
|
||||
let handler = Arc::new(ExecServerHandler::new(
|
||||
registry,
|
||||
DrainState::new(),
|
||||
RpcNotificationSender::new(outgoing_tx),
|
||||
test_runtime_paths(),
|
||||
));
|
||||
@@ -158,6 +160,7 @@ async fn long_poll_read_fails_after_session_resume() {
|
||||
let registry = SessionRegistry::new();
|
||||
let first_handler = Arc::new(ExecServerHandler::new(
|
||||
Arc::clone(®istry),
|
||||
DrainState::new(),
|
||||
RpcNotificationSender::new(first_tx),
|
||||
test_runtime_paths(),
|
||||
));
|
||||
@@ -198,6 +201,7 @@ async fn long_poll_read_fails_after_session_resume() {
|
||||
let (second_tx, _second_rx) = mpsc::channel(16);
|
||||
let second_handler = Arc::new(ExecServerHandler::new(
|
||||
registry,
|
||||
DrainState::new(),
|
||||
RpcNotificationSender::new(second_tx),
|
||||
test_runtime_paths(),
|
||||
));
|
||||
@@ -231,6 +235,7 @@ async fn active_session_resume_is_rejected() {
|
||||
let registry = SessionRegistry::new();
|
||||
let first_handler = Arc::new(ExecServerHandler::new(
|
||||
Arc::clone(®istry),
|
||||
DrainState::new(),
|
||||
RpcNotificationSender::new(first_tx),
|
||||
test_runtime_paths(),
|
||||
));
|
||||
@@ -245,6 +250,7 @@ async fn active_session_resume_is_rejected() {
|
||||
let (second_tx, _second_rx) = mpsc::channel(16);
|
||||
let second_handler = Arc::new(ExecServerHandler::new(
|
||||
registry,
|
||||
DrainState::new(),
|
||||
RpcNotificationSender::new(second_tx),
|
||||
test_runtime_paths(),
|
||||
));
|
||||
@@ -273,6 +279,7 @@ async fn output_and_exit_are_retained_after_notification_receiver_closes() {
|
||||
let (outgoing_tx, outgoing_rx) = mpsc::channel(16);
|
||||
let handler = Arc::new(ExecServerHandler::new(
|
||||
SessionRegistry::new(),
|
||||
DrainState::new(),
|
||||
RpcNotificationSender::new(outgoing_tx),
|
||||
test_runtime_paths(),
|
||||
));
|
||||
|
||||
@@ -31,6 +31,10 @@ impl ProcessHandler {
|
||||
self.process.set_notification_sender(notifications);
|
||||
}
|
||||
|
||||
pub(crate) async fn active_process_count(&self) -> usize {
|
||||
self.process.active_process_count().await
|
||||
}
|
||||
|
||||
pub(crate) async fn exec(&self, params: ExecParams) -> Result<ExecResponse, JSONRPCErrorError> {
|
||||
self.process.exec(params).await
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::sleep;
|
||||
use tracing::debug;
|
||||
use tracing::warn;
|
||||
|
||||
@@ -14,12 +16,14 @@ use crate::rpc::encode_server_message;
|
||||
use crate::rpc::invalid_request;
|
||||
use crate::rpc::method_not_found;
|
||||
use crate::server::ExecServerHandler;
|
||||
use crate::server::drain::DrainState;
|
||||
use crate::server::registry::build_router;
|
||||
use crate::server::session_registry::SessionRegistry;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ConnectionProcessor {
|
||||
session_registry: Arc<SessionRegistry>,
|
||||
drain_state: Arc<DrainState>,
|
||||
runtime_paths: ExecServerRuntimePaths,
|
||||
}
|
||||
|
||||
@@ -27,14 +31,36 @@ impl ConnectionProcessor {
|
||||
pub(crate) fn new(runtime_paths: ExecServerRuntimePaths) -> Self {
|
||||
Self {
|
||||
session_registry: SessionRegistry::new(),
|
||||
drain_state: DrainState::new(),
|
||||
runtime_paths,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn begin_drain(&self) {
|
||||
self.drain_state.begin();
|
||||
}
|
||||
|
||||
pub(crate) async fn wait_until_idle(&self) {
|
||||
while !self.is_idle().await {
|
||||
sleep(Duration::from_millis(25)).await;
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn shutdown_all_sessions(&self) {
|
||||
self.session_registry.shutdown_all().await;
|
||||
}
|
||||
|
||||
async fn is_idle(&self) -> bool {
|
||||
self.drain_state.active_rpc_request_count() == 0
|
||||
&& self.drain_state.active_http_request_count() == 0
|
||||
&& self.session_registry.active_process_count().await == 0
|
||||
}
|
||||
|
||||
pub(crate) async fn run_connection(&self, connection: JsonRpcConnection) {
|
||||
run_connection(
|
||||
connection,
|
||||
Arc::clone(&self.session_registry),
|
||||
Arc::clone(&self.drain_state),
|
||||
self.runtime_paths.clone(),
|
||||
)
|
||||
.await;
|
||||
@@ -44,6 +70,7 @@ impl ConnectionProcessor {
|
||||
async fn run_connection(
|
||||
connection: JsonRpcConnection,
|
||||
session_registry: Arc<SessionRegistry>,
|
||||
drain_state: Arc<DrainState>,
|
||||
runtime_paths: ExecServerRuntimePaths,
|
||||
) {
|
||||
let router = Arc::new(build_router());
|
||||
@@ -54,6 +81,7 @@ async fn run_connection(
|
||||
let notifications = RpcNotificationSender::new(outgoing_tx.clone());
|
||||
let handler = Arc::new(ExecServerHandler::new(
|
||||
session_registry,
|
||||
Arc::clone(&drain_state),
|
||||
notifications,
|
||||
runtime_paths,
|
||||
));
|
||||
@@ -96,6 +124,7 @@ async fn run_connection(
|
||||
JsonRpcConnectionEvent::Message(message) => match message {
|
||||
codex_app_server_protocol::JSONRPCMessage::Request(request) => {
|
||||
if let Some(route) = router.request_route(request.method.as_str()) {
|
||||
let _request_guard = drain_state.start_rpc_request();
|
||||
let message = tokio::select! {
|
||||
message = route(Arc::clone(&handler), request) => message,
|
||||
_ = disconnected_rx.changed() => {
|
||||
@@ -217,6 +246,7 @@ mod tests {
|
||||
use crate::protocol::ReadParams;
|
||||
use crate::protocol::TerminateParams;
|
||||
use crate::protocol::TerminateResponse;
|
||||
use crate::server::drain::DrainState;
|
||||
use crate::server::session_registry::SessionRegistry;
|
||||
|
||||
#[tokio::test]
|
||||
@@ -317,7 +347,12 @@ mod tests {
|
||||
let (server_writer, client_reader) = duplex(1 << 20);
|
||||
let connection =
|
||||
JsonRpcConnection::from_stdio(server_reader, server_writer, label.to_string());
|
||||
let task = tokio::spawn(run_connection(connection, registry, test_runtime_paths()));
|
||||
let task = tokio::spawn(run_connection(
|
||||
connection,
|
||||
registry,
|
||||
DrainState::new(),
|
||||
test_runtime_paths(),
|
||||
));
|
||||
(client_writer, BufReader::new(client_reader).lines(), task)
|
||||
}
|
||||
|
||||
|
||||
@@ -134,6 +134,28 @@ impl SessionRegistry {
|
||||
entry.process.shutdown().await;
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn active_process_count(&self) -> usize {
|
||||
let entries = {
|
||||
let sessions = self.sessions.lock().await;
|
||||
sessions.values().cloned().collect::<Vec<_>>()
|
||||
};
|
||||
let mut count = 0;
|
||||
for entry in entries {
|
||||
count += entry.process.active_process_count().await;
|
||||
}
|
||||
count
|
||||
}
|
||||
|
||||
pub(crate) async fn shutdown_all(&self) {
|
||||
let entries = {
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
sessions.drain().map(|(_, entry)| entry).collect::<Vec<_>>()
|
||||
};
|
||||
for entry in entries {
|
||||
entry.process.shutdown().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SessionRegistry {
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
use std::io::Write as _;
|
||||
use std::net::SocketAddr;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio::time::sleep;
|
||||
use tokio_tungstenite::accept_async;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::ExecServerRunOptions;
|
||||
use crate::ExecServerRuntimePaths;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::server::processor::ConnectionProcessor;
|
||||
@@ -50,42 +54,110 @@ pub(crate) fn parse_listen_url(
|
||||
pub(crate) async fn run_transport(
|
||||
listen_url: &str,
|
||||
runtime_paths: ExecServerRuntimePaths,
|
||||
options: ExecServerRunOptions,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let bind_address = parse_listen_url(listen_url)?;
|
||||
run_websocket_listener(bind_address, runtime_paths).await
|
||||
run_websocket_listener(bind_address, runtime_paths, options).await
|
||||
}
|
||||
|
||||
async fn run_websocket_listener(
|
||||
bind_address: SocketAddr,
|
||||
runtime_paths: ExecServerRuntimePaths,
|
||||
options: ExecServerRunOptions,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let listener = TcpListener::bind(bind_address).await?;
|
||||
let local_addr = listener.local_addr()?;
|
||||
let processor = ConnectionProcessor::new(runtime_paths);
|
||||
let mut connection_tasks = JoinSet::new();
|
||||
tracing::info!("codex-exec-server listening on ws://{local_addr}");
|
||||
println!("ws://{local_addr}");
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
loop {
|
||||
let (stream, peer_addr) = listener.accept().await?;
|
||||
let processor = processor.clone();
|
||||
tokio::spawn(async move {
|
||||
match accept_async(stream).await {
|
||||
Ok(websocket) => {
|
||||
processor
|
||||
.run_connection(JsonRpcConnection::from_websocket(
|
||||
websocket,
|
||||
format!("exec-server websocket {peer_addr}"),
|
||||
))
|
||||
.await;
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"failed to accept exec-server websocket connection from {peer_addr}: {err}"
|
||||
);
|
||||
}
|
||||
reap_finished_connections(&mut connection_tasks);
|
||||
tokio::select! {
|
||||
accept_result = listener.accept() => {
|
||||
let (stream, peer_addr) = accept_result?;
|
||||
let processor = processor.clone();
|
||||
connection_tasks.spawn(async move {
|
||||
match accept_async(stream).await {
|
||||
Ok(websocket) => {
|
||||
processor
|
||||
.run_connection(JsonRpcConnection::from_websocket(
|
||||
websocket,
|
||||
format!("exec-server websocket {peer_addr}"),
|
||||
))
|
||||
.await;
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"failed to accept exec-server websocket connection from {peer_addr}: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
signal_result = shutdown_signal() => {
|
||||
if let Err(err) = signal_result {
|
||||
warn!("failed while waiting for exec-server shutdown signal: {err}");
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
drop(listener);
|
||||
processor.begin_drain();
|
||||
info!(
|
||||
timeout_ms = options.graceful_shutdown_timeout.as_millis(),
|
||||
"exec-server graceful shutdown started"
|
||||
);
|
||||
|
||||
tokio::select! {
|
||||
_ = processor.wait_until_idle() => {
|
||||
info!("exec-server graceful shutdown drained active work");
|
||||
}
|
||||
_ = sleep(options.graceful_shutdown_timeout) => {
|
||||
warn!("exec-server graceful shutdown timed out; forcing remaining sessions to stop");
|
||||
processor.shutdown_all_sessions().await;
|
||||
}
|
||||
signal_result = shutdown_signal() => {
|
||||
if let Err(err) = signal_result {
|
||||
warn!("failed while waiting for second exec-server shutdown signal: {err}");
|
||||
}
|
||||
warn!("exec-server received second shutdown signal; forcing remaining sessions to stop");
|
||||
processor.shutdown_all_sessions().await;
|
||||
}
|
||||
}
|
||||
|
||||
connection_tasks.abort_all();
|
||||
while connection_tasks.join_next().await.is_some() {}
|
||||
processor.shutdown_all_sessions().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn reap_finished_connections(connection_tasks: &mut JoinSet<()>) {
|
||||
while let Some(result) = connection_tasks.try_join_next() {
|
||||
if let Err(err) = result {
|
||||
warn!("exec-server websocket connection task failed: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn shutdown_signal() -> std::io::Result<()> {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
let mut terminate =
|
||||
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
|
||||
tokio::select! {
|
||||
result = tokio::signal::ctrl_c() => result,
|
||||
_ = terminate.recv() => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
tokio::signal::ctrl_c().await
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
use super::DEFAULT_LISTEN_URL;
|
||||
use super::parse_listen_url;
|
||||
use super::reap_finished_connections;
|
||||
|
||||
#[test]
|
||||
fn parse_listen_url_accepts_default_websocket_url() {
|
||||
@@ -48,3 +50,14 @@ fn parse_listen_url_rejects_unsupported_url() {
|
||||
"unsupported --listen URL `http://127.0.0.1:1234`; expected `ws://IP:PORT`"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_finished_connections_drains_completed_join_set_tasks() {
|
||||
let mut tasks = JoinSet::new();
|
||||
tasks.spawn(async {});
|
||||
tokio::task::yield_now().await;
|
||||
|
||||
reap_finished_connections(&mut tasks);
|
||||
|
||||
assert!(tasks.is_empty());
|
||||
}
|
||||
|
||||
@@ -60,7 +60,29 @@ pub(crate) async fn exec_server() -> anyhow::Result<ExecServerHarness> {
|
||||
exec_server_with_env(std::iter::empty::<(&str, &str)>()).await
|
||||
}
|
||||
|
||||
pub(crate) async fn exec_server_with_config(
|
||||
config_contents: &str,
|
||||
) -> anyhow::Result<ExecServerHarness> {
|
||||
spawn_exec_server(
|
||||
std::iter::empty::<(&str, &str)>(),
|
||||
Some(config_contents.as_bytes()),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn exec_server_with_env<I, K, V>(env: I) -> anyhow::Result<ExecServerHarness>
|
||||
where
|
||||
I: IntoIterator<Item = (K, V)>,
|
||||
K: AsRef<std::ffi::OsStr>,
|
||||
V: AsRef<std::ffi::OsStr>,
|
||||
{
|
||||
spawn_exec_server(env, /*config_contents*/ None).await
|
||||
}
|
||||
|
||||
async fn spawn_exec_server<I, K, V>(
|
||||
env: I,
|
||||
config_contents: Option<&[u8]>,
|
||||
) -> anyhow::Result<ExecServerHarness>
|
||||
where
|
||||
I: IntoIterator<Item = (K, V)>,
|
||||
K: AsRef<std::ffi::OsStr>,
|
||||
@@ -68,6 +90,15 @@ where
|
||||
{
|
||||
let helper_paths = test_codex_helper_paths()?;
|
||||
let codex_home = TempDir::new()?;
|
||||
if let Some(config_contents) = config_contents {
|
||||
tokio::fs::write(
|
||||
codex_home
|
||||
.path()
|
||||
.join(codex_exec_server::EXEC_SERVER_CONFIG_FILE),
|
||||
config_contents,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
let mut child = Command::new(&helper_paths.codex_exe);
|
||||
child.args(["exec-server", "--listen", "ws://127.0.0.1:0"]);
|
||||
child.stdin(Stdio::null());
|
||||
@@ -177,6 +208,55 @@ impl ExecServerHarness {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
pub(crate) fn send_sigint(&mut self) -> anyhow::Result<()> {
|
||||
self.send_signal("INT")
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
pub(crate) fn send_sigterm(&mut self) -> anyhow::Result<()> {
|
||||
self.send_signal("TERM")
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn send_signal(&mut self, signal: &str) -> anyhow::Result<()> {
|
||||
let pid = self
|
||||
.child
|
||||
.id()
|
||||
.ok_or_else(|| anyhow!("exec-server process has no pid"))?;
|
||||
let status = std::process::Command::new("kill")
|
||||
.arg(format!("-{signal}"))
|
||||
.arg(pid.to_string())
|
||||
.status()?;
|
||||
if !status.success() {
|
||||
return Err(anyhow!(
|
||||
"failed to send SIG{signal} to exec-server pid {pid}"
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn wait_for_exit(
|
||||
&mut self,
|
||||
timeout_duration: Duration,
|
||||
) -> anyhow::Result<std::process::ExitStatus> {
|
||||
timeout(timeout_duration, self.child.wait())
|
||||
.await
|
||||
.map_err(|_| anyhow!("timed out waiting for exec-server shutdown"))?
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
pub(crate) async fn assert_still_running_for(
|
||||
&mut self,
|
||||
duration: Duration,
|
||||
) -> anyhow::Result<()> {
|
||||
sleep(duration).await;
|
||||
if let Some(status) = self.child.try_wait()? {
|
||||
return Err(anyhow!("exec-server exited early with status {status}"));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_message(&mut self, message: JSONRPCMessage) -> anyhow::Result<()> {
|
||||
let encoded = serde_json::to_string(&message)?;
|
||||
self.websocket.send(Message::Text(encoded.into())).await?;
|
||||
|
||||
@@ -136,22 +136,34 @@ fn maybe_run_exec_server_from_test_binary(guard: Option<&TestBinaryDispatchGuard
|
||||
return;
|
||||
}
|
||||
|
||||
let Some(flag) = args.next() else {
|
||||
let mut listen_url = None;
|
||||
let mut config_path = None;
|
||||
while let Some(flag) = args.next() {
|
||||
match flag.as_str() {
|
||||
"--listen" => {
|
||||
let Some(value) = args.next() else {
|
||||
eprintln!("expected listen URL");
|
||||
std::process::exit(1);
|
||||
};
|
||||
listen_url = Some(value);
|
||||
}
|
||||
"--config-path" => {
|
||||
let Some(value) = args.next() else {
|
||||
eprintln!("expected config path");
|
||||
std::process::exit(1);
|
||||
};
|
||||
config_path = Some(PathBuf::from(value));
|
||||
}
|
||||
_ => {
|
||||
eprintln!("unexpected exec-server argument `{flag}`");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
let Some(listen_url) = listen_url else {
|
||||
eprintln!("expected --listen");
|
||||
std::process::exit(1);
|
||||
};
|
||||
if flag != "--listen" {
|
||||
eprintln!("expected --listen, got `{flag}`");
|
||||
std::process::exit(1);
|
||||
}
|
||||
let Some(listen_url) = args.next() else {
|
||||
eprintln!("expected listen URL");
|
||||
std::process::exit(1);
|
||||
};
|
||||
if args.next().is_some() {
|
||||
eprintln!("unexpected extra arguments");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let current_exe = match env::current_exe() {
|
||||
Ok(current_exe) => current_exe,
|
||||
@@ -180,8 +192,17 @@ fn maybe_run_exec_server_from_test_binary(guard: Option<&TestBinaryDispatchGuard
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
let exit_code = match runtime.block_on(codex_exec_server::run_main(&listen_url, runtime_paths))
|
||||
{
|
||||
let config_path = config_path.unwrap_or_else(|| {
|
||||
let codex_home = env::var_os("CODEX_HOME")
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|| PathBuf::from("."));
|
||||
codex_home.join(codex_exec_server::EXEC_SERVER_CONFIG_FILE)
|
||||
});
|
||||
let exit_code = match runtime.block_on(run_test_exec_server(
|
||||
&listen_url,
|
||||
runtime_paths,
|
||||
&config_path,
|
||||
)) {
|
||||
Ok(()) => 0,
|
||||
Err(error) => {
|
||||
eprintln!("exec-server failed: {error}");
|
||||
@@ -191,6 +212,17 @@ fn maybe_run_exec_server_from_test_binary(guard: Option<&TestBinaryDispatchGuard
|
||||
std::process::exit(exit_code);
|
||||
}
|
||||
|
||||
async fn run_test_exec_server(
|
||||
listen_url: &str,
|
||||
runtime_paths: ExecServerRuntimePaths,
|
||||
config_path: &Path,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let options = codex_exec_server::ExecServerConfig::load_from_path(config_path)
|
||||
.await?
|
||||
.into_run_options(config_path)?;
|
||||
codex_exec_server::run_main_with_options(listen_url, runtime_paths, options).await
|
||||
}
|
||||
|
||||
fn linux_sandbox_exe(
|
||||
guard: Option<&TestBinaryDispatchGuard>,
|
||||
current_exe: &std::path::Path,
|
||||
|
||||
201
codex-rs/exec-server/tests/shutdown.rs
Normal file
201
codex-rs/exec-server/tests/shutdown.rs
Normal file
@@ -0,0 +1,201 @@
|
||||
#![cfg(unix)]
|
||||
|
||||
mod common;
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_app_server_protocol::JSONRPCError;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_app_server_protocol::JSONRPCResponse;
|
||||
use codex_exec_server::ExecResponse;
|
||||
use codex_exec_server::InitializeParams;
|
||||
use codex_exec_server::ProcessId;
|
||||
use common::exec_server::ExecServerHarness;
|
||||
use common::exec_server::exec_server_with_config;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tokio::time::Instant;
|
||||
use tokio_tungstenite::connect_async;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn sigterm_drains_active_process_before_exit() -> anyhow::Result<()> {
|
||||
let mut server = exec_server_with_config("graceful_shutdown_timeout_ms = 2000\n").await?;
|
||||
initialize_exec_server(&mut server).await?;
|
||||
start_sleep_process(&mut server, "proc-drain", "0.4").await?;
|
||||
|
||||
server.send_sigterm()?;
|
||||
server
|
||||
.assert_still_running_for(Duration::from_millis(100))
|
||||
.await?;
|
||||
let status = server.wait_for_exit(Duration::from_secs(3)).await?;
|
||||
|
||||
assert!(status.success(), "exec-server exited with {status}");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn drain_timeout_forces_active_process_shutdown() -> anyhow::Result<()> {
|
||||
let mut server = exec_server_with_config("graceful_shutdown_timeout_ms = 100\n").await?;
|
||||
initialize_exec_server(&mut server).await?;
|
||||
start_sleep_process(&mut server, "proc-timeout", "5").await?;
|
||||
|
||||
server.send_sigterm()?;
|
||||
let status = server.wait_for_exit(Duration::from_secs(2)).await?;
|
||||
|
||||
assert!(status.success(), "exec-server exited with {status}");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn second_signal_forces_shutdown_during_drain() -> anyhow::Result<()> {
|
||||
let mut server = exec_server_with_config("graceful_shutdown_timeout_ms = 5000\n").await?;
|
||||
initialize_exec_server(&mut server).await?;
|
||||
start_sleep_process(&mut server, "proc-second-signal", "5").await?;
|
||||
|
||||
server.send_sigint()?;
|
||||
server
|
||||
.assert_still_running_for(Duration::from_millis(100))
|
||||
.await?;
|
||||
server.send_sigint()?;
|
||||
let status = server.wait_for_exit(Duration::from_secs(2)).await?;
|
||||
|
||||
assert!(status.success(), "exec-server exited with {status}");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn draining_rejects_new_process_starts_on_existing_connection() -> anyhow::Result<()> {
|
||||
let mut server = exec_server_with_config("graceful_shutdown_timeout_ms = 2000\n").await?;
|
||||
initialize_exec_server(&mut server).await?;
|
||||
start_sleep_process(&mut server, "proc-existing", "0.8").await?;
|
||||
|
||||
server.send_sigterm()?;
|
||||
wait_until_new_connections_are_refused(server.websocket_url()).await?;
|
||||
let request_id = server
|
||||
.send_request(
|
||||
"process/start",
|
||||
serde_json::json!({
|
||||
"processId": "proc-rejected",
|
||||
"argv": ["true"],
|
||||
"cwd": std::env::current_dir()?,
|
||||
"env": {},
|
||||
"tty": false,
|
||||
"pipeStdin": false,
|
||||
"arg0": null
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
let response = server
|
||||
.wait_for_event(|event| {
|
||||
matches!(
|
||||
event,
|
||||
JSONRPCMessage::Error(JSONRPCError { id, .. }) if id == &request_id
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
|
||||
let JSONRPCMessage::Error(JSONRPCError { error, .. }) = response else {
|
||||
panic!("expected process/start to fail while draining");
|
||||
};
|
||||
assert_eq!(error.code, -32600);
|
||||
assert_eq!(
|
||||
error.message,
|
||||
"exec-server is draining; new processes are not accepted"
|
||||
);
|
||||
|
||||
let status = server.wait_for_exit(Duration::from_secs(3)).await?;
|
||||
assert!(status.success(), "exec-server exited with {status}");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn signal_stops_accepting_new_websocket_connections() -> anyhow::Result<()> {
|
||||
let mut server = exec_server_with_config("graceful_shutdown_timeout_ms = 2000\n").await?;
|
||||
initialize_exec_server(&mut server).await?;
|
||||
start_sleep_process(&mut server, "proc-connection-refused", "0.8").await?;
|
||||
|
||||
server.send_sigterm()?;
|
||||
wait_until_new_connections_are_refused(server.websocket_url()).await?;
|
||||
let status = server.wait_for_exit(Duration::from_secs(3)).await?;
|
||||
assert!(status.success(), "exec-server exited with {status}");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn initialize_exec_server(server: &mut ExecServerHarness) -> anyhow::Result<()> {
|
||||
let initialize_id = server
|
||||
.send_request(
|
||||
"initialize",
|
||||
serde_json::to_value(InitializeParams {
|
||||
client_name: "exec-server-test".to_string(),
|
||||
resume_session_id: None,
|
||||
})?,
|
||||
)
|
||||
.await?;
|
||||
let _ = wait_for_response(server, initialize_id).await?;
|
||||
server
|
||||
.send_notification("initialized", serde_json::json!({}))
|
||||
.await
|
||||
}
|
||||
|
||||
async fn start_sleep_process(
|
||||
server: &mut ExecServerHarness,
|
||||
process_id: &str,
|
||||
seconds: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let request_id = server
|
||||
.send_request(
|
||||
"process/start",
|
||||
serde_json::json!({
|
||||
"processId": process_id,
|
||||
"argv": ["/bin/sh", "-c", format!("sleep {seconds}")],
|
||||
"cwd": std::env::current_dir()?,
|
||||
"env": {},
|
||||
"tty": false,
|
||||
"pipeStdin": false,
|
||||
"arg0": null
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
let result = wait_for_response(server, request_id).await?;
|
||||
let response: ExecResponse = serde_json::from_value(result)?;
|
||||
assert_eq!(
|
||||
response,
|
||||
ExecResponse {
|
||||
process_id: ProcessId::from(process_id)
|
||||
}
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn wait_for_response(
|
||||
server: &mut ExecServerHarness,
|
||||
expected_id: codex_app_server_protocol::RequestId,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let response = server
|
||||
.wait_for_event(|event| {
|
||||
matches!(
|
||||
event,
|
||||
JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &expected_id
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
let JSONRPCMessage::Response(JSONRPCResponse { result, .. }) = response else {
|
||||
panic!("expected JSON-RPC response");
|
||||
};
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn wait_until_new_connections_are_refused(websocket_url: &str) -> anyhow::Result<()> {
|
||||
let deadline = Instant::now() + Duration::from_secs(1);
|
||||
loop {
|
||||
match connect_async(websocket_url).await {
|
||||
Ok((websocket, _)) => {
|
||||
drop(websocket);
|
||||
if Instant::now() >= deadline {
|
||||
anyhow::bail!("exec-server kept accepting websocket connections after signal");
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
}
|
||||
Err(_) => return Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user