use std::sync::Arc; use std::sync::Mutex as StdMutex; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; use codex_app_server_protocol::JSONRPCErrorError; use codex_app_server_protocol::RequestId; use serde_json::to_value; use std::collections::HashSet; use tokio::sync::Mutex; use tokio_util::sync::CancellationToken; use tokio_util::task::TaskTracker; use crate::ExecServerRuntimePaths; use crate::client::http_client::PendingReqwestHttpBodyStream; use crate::client::http_client::ReqwestHttpRequestRunner; use crate::protocol::ExecParams; use crate::protocol::ExecResponse; use crate::protocol::FsCopyParams; use crate::protocol::FsCopyResponse; use crate::protocol::FsCreateDirectoryParams; use crate::protocol::FsCreateDirectoryResponse; use crate::protocol::FsGetMetadataParams; use crate::protocol::FsGetMetadataResponse; use crate::protocol::FsReadDirectoryParams; use crate::protocol::FsReadDirectoryResponse; use crate::protocol::FsReadFileParams; use crate::protocol::FsReadFileResponse; use crate::protocol::FsRemoveParams; use crate::protocol::FsRemoveResponse; use crate::protocol::FsWriteFileParams; use crate::protocol::FsWriteFileResponse; use crate::protocol::HttpRequestParams; use crate::protocol::InitializeParams; use crate::protocol::InitializeResponse; use crate::protocol::ReadParams; use crate::protocol::ReadResponse; use crate::protocol::TerminateParams; use crate::protocol::TerminateResponse; use crate::protocol::WriteParams; use crate::protocol::WriteResponse; use crate::rpc::RpcNotificationSender; use crate::rpc::internal_error; use crate::rpc::invalid_params; use crate::rpc::invalid_request; use crate::server::file_system_handler::FileSystemHandler; use crate::server::session_registry::SessionHandle; use crate::server::session_registry::SessionRegistry; pub(crate) struct ExecServerHandler { session_registry: Arc, notifications: RpcNotificationSender, session: StdMutex>, active_body_stream_ids: Mutex>, background_task_shutdown: CancellationToken, background_tasks: TaskTracker, file_system: FileSystemHandler, initialize_requested: AtomicBool, initialized: AtomicBool, } impl ExecServerHandler { pub(crate) fn new( session_registry: Arc, notifications: RpcNotificationSender, runtime_paths: ExecServerRuntimePaths, ) -> Self { Self { session_registry, notifications, session: StdMutex::new(None), active_body_stream_ids: Mutex::new(HashSet::new()), background_task_shutdown: CancellationToken::new(), background_tasks: TaskTracker::new(), file_system: FileSystemHandler::new(runtime_paths), initialize_requested: AtomicBool::new(false), initialized: AtomicBool::new(false), } } pub(crate) async fn shutdown(&self) { self.background_task_shutdown.cancel(); self.background_tasks.close(); self.background_tasks.wait().await; if let Some(session) = self.session() { session.detach().await; } } pub(crate) fn is_session_attached(&self) -> bool { self.session() .is_none_or(|session| session.is_session_attached()) } pub(crate) async fn initialize( &self, params: InitializeParams, ) -> Result { if self.initialize_requested.swap(true, Ordering::SeqCst) { return Err(invalid_request( "initialize may only be sent once per connection".to_string(), )); } let session = match self .session_registry .attach(params.resume_session_id.clone(), self.notifications.clone()) .await { Ok(session) => session, Err(error) => { self.initialize_requested.store(false, Ordering::SeqCst); return Err(error); } }; let session_id = session.session_id().to_string(); tracing::debug!( session_id, connection_id = %session.connection_id(), "exec-server session attached" ); *self .session .lock() .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(session); Ok(InitializeResponse { session_id }) } pub(crate) fn initialized(&self) -> Result<(), String> { if !self.initialize_requested.load(Ordering::SeqCst) { return Err("received `initialized` notification before `initialize`".into()); } self.require_session_attached() .map_err(|error| error.message)?; self.initialized.store(true, Ordering::SeqCst); Ok(()) } pub(crate) async fn exec(&self, params: ExecParams) -> Result { let session = self.require_initialized_for("exec")?; session.process().exec(params).await } pub(crate) async fn exec_read( &self, params: ReadParams, ) -> Result { let session = self.require_initialized_for("exec")?; let response = session.process().exec_read(params).await?; self.require_session_attached()?; Ok(response) } pub(crate) async fn exec_write( &self, params: WriteParams, ) -> Result { let session = self.require_initialized_for("exec")?; session.process().exec_write(params).await } pub(crate) async fn terminate( &self, params: TerminateParams, ) -> Result { let session = self.require_initialized_for("exec")?; session.process().terminate(params).await } pub(crate) async fn http_request( self: &Arc, request_id: RequestId, params: HttpRequestParams, ) -> Result<(), JSONRPCErrorError> { self.require_initialized_for("http")?; let stream_response = params.stream_response; let http_request_id = params.request_id.clone(); if stream_response { self.reserve_http_body_stream(&http_request_id).await?; } let response = ReqwestHttpRequestRunner::new(params.timeout_ms)? .run(params) .await; if response.is_err() && stream_response { self.release_http_body_stream(&http_request_id).await; } let (response, mut pending_stream) = response?; let result = match to_value(response) { Ok(result) => result, Err(err) => { if let Some(pending_stream) = pending_stream.take() { self.release_http_body_stream(&pending_stream.request_id) .await; } return Err(internal_error(err.to_string())); } }; if let Err(error) = self.notifications.response(request_id, result).await { if let Some(pending_stream) = pending_stream.take() { self.release_http_body_stream(&pending_stream.request_id) .await; } return Err(error); } if let Some(pending_stream) = pending_stream { self.start_http_body_stream(pending_stream).await; } Ok(()) } pub(crate) async fn fs_read_file( &self, params: FsReadFileParams, ) -> Result { self.require_initialized_for("filesystem")?; self.file_system.read_file(params).await } pub(crate) async fn fs_write_file( &self, params: FsWriteFileParams, ) -> Result { self.require_initialized_for("filesystem")?; self.file_system.write_file(params).await } pub(crate) async fn fs_create_directory( &self, params: FsCreateDirectoryParams, ) -> Result { self.require_initialized_for("filesystem")?; self.file_system.create_directory(params).await } pub(crate) async fn fs_get_metadata( &self, params: FsGetMetadataParams, ) -> Result { self.require_initialized_for("filesystem")?; self.file_system.get_metadata(params).await } pub(crate) async fn fs_read_directory( &self, params: FsReadDirectoryParams, ) -> Result { self.require_initialized_for("filesystem")?; self.file_system.read_directory(params).await } pub(crate) async fn fs_remove( &self, params: FsRemoveParams, ) -> Result { self.require_initialized_for("filesystem")?; self.file_system.remove(params).await } pub(crate) async fn fs_copy( &self, params: FsCopyParams, ) -> Result { self.require_initialized_for("filesystem")?; self.file_system.copy(params).await } fn require_initialized_for( &self, method_family: &str, ) -> Result { if !self.initialize_requested.load(Ordering::SeqCst) { return Err(invalid_request(format!( "client must call initialize before using {method_family} methods" ))); } let session = self.require_session_attached()?; if !self.initialized.load(Ordering::SeqCst) { return Err(invalid_request(format!( "client must send initialized before using {method_family} methods" ))); } Ok(session) } fn require_session_attached(&self) -> Result { let Some(session) = self.session() else { return Err(invalid_request( "client must call initialize before using methods".to_string(), )); }; if session.is_session_attached() { return Ok(session); } Err(invalid_request( "session has been resumed by another connection".to_string(), )) } fn session(&self) -> Option { self.session .lock() .unwrap_or_else(std::sync::PoisonError::into_inner) .clone() } async fn start_http_body_stream( self: &Arc, pending_stream: PendingReqwestHttpBodyStream, ) { let request_id = pending_stream.request_id.clone(); if self.background_task_shutdown.is_cancelled() { self.release_http_body_stream(&request_id).await; return; } let finished_request_id = request_id.clone(); let handler = Arc::clone(self); let notifications = self.notifications.clone(); let shutdown = self.background_task_shutdown.clone(); self.background_tasks.spawn(async move { tokio::select! { _ = shutdown.cancelled() => {} _ = ReqwestHttpRequestRunner::stream_body(pending_stream, notifications) => {} } handler.release_http_body_stream(&finished_request_id).await; }); } async fn release_http_body_stream(&self, request_id: &str) { let mut active_body_stream_ids = self.active_body_stream_ids.lock().await; active_body_stream_ids.remove(request_id); } async fn reserve_http_body_stream(&self, request_id: &str) -> Result<(), JSONRPCErrorError> { let mut active_body_stream_ids = self.active_body_stream_ids.lock().await; if active_body_stream_ids.contains(request_id) { return Err(invalid_params(format!( "http/request streamResponse requestId `{request_id}` is already active" ))); } active_body_stream_ids.insert(request_id.to_string()); Ok(()) } } #[cfg(test)] mod tests;