Compare commits

...

1 Commits

Author SHA1 Message Date
Michael Bolin
7c94e4cab9 feat: exec-command-mcp 2025-08-20 13:59:42 -07:00
17 changed files with 1106 additions and 5 deletions

89
codex-rs/Cargo.lock generated
View File

@@ -1438,6 +1438,12 @@ version = "0.15.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
[[package]]
name = "downcast-rs"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2"
[[package]]
name = "dupe"
version = "0.9.1"
@@ -1632,6 +1638,21 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "exec-command-mcp"
version = "0.0.0"
dependencies = [
"anyhow",
"mcp-types",
"portable-pty",
"schemars 0.8.22",
"serde",
"serde_json",
"tokio",
"tracing",
"tracing-subscriber",
]
[[package]]
name = "exr"
version = "1.73.0"
@@ -1683,6 +1704,17 @@ dependencies = [
"simd-adler32",
]
[[package]]
name = "filedescriptor"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e40758ed24c9b2eeb76c35fb0aebc66c626084edd827e07e1552279814c6682d"
dependencies = [
"libc",
"thiserror 1.0.69",
"winapi",
]
[[package]]
name = "fixedbitset"
version = "0.4.2"
@@ -3340,6 +3372,27 @@ dependencies = [
"portable-atomic",
]
[[package]]
name = "portable-pty"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4a596a2b3d2752d94f51fac2d4a96737b8705dddd311a32b9af47211f08671e"
dependencies = [
"anyhow",
"bitflags 1.3.2",
"downcast-rs",
"filedescriptor",
"lazy_static",
"libc",
"log",
"nix",
"serial2",
"shared_library",
"shell-words",
"winapi",
"winreg",
]
[[package]]
name = "potential_utf"
version = "0.1.2"
@@ -4267,6 +4320,17 @@ dependencies = [
"syn 2.0.104",
]
[[package]]
name = "serial2"
version = "0.2.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26e1e5956803a69ddd72ce2de337b577898801528749565def03515f82bad5bb"
dependencies = [
"cfg-if",
"libc",
"winapi",
]
[[package]]
name = "sha1"
version = "0.10.6"
@@ -4298,6 +4362,22 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "shared_library"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a9e7e0f2bfae24d8a5b5a66c5b257a83c7412304311512a0c054cd5e619da11"
dependencies = [
"lazy_static",
"libc",
]
[[package]]
name = "shell-words"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde"
[[package]]
name = "shlex"
version = "1.3.0"
@@ -6008,6 +6088,15 @@ dependencies = [
"memchr",
]
[[package]]
name = "winreg"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d"
dependencies = [
"winapi",
]
[[package]]
name = "wiremock"
version = "0.6.4"

View File

@@ -7,6 +7,7 @@ members = [
"common",
"core",
"exec",
"exec-command-mcp",
"execpolicy",
"file-search",
"linux-sandbox",

View File

@@ -265,10 +265,7 @@ For casual greetings, acknowledgements, or other one-off conversational messages
## Shell commands
When using the shell, you must adhere to the following guidelines:
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
- Read files in chunks with a max chunk size of 250 lines. Do not use python scripts to attempt to output larger chunks of a file. Command line output will be truncated after 10 kilobytes or 256 lines of output, regardless of the command used.
Do NOT use `shell`. Use only `functions_exec_command` and `functions_write_stdin`.
## `apply_patch`

View File

@@ -480,6 +480,7 @@ impl Session {
sandbox_policy.clone(),
config.include_plan_tool,
config.include_apply_patch_tool,
config.experimental_disable_built_in_shell_tool,
),
user_instructions,
base_instructions,
@@ -1049,6 +1050,7 @@ async fn submission_loop(
new_sandbox_policy.clone(),
config.include_plan_tool,
config.include_apply_patch_tool,
config.experimental_disable_built_in_shell_tool,
);
let new_turn_context = TurnContext {
@@ -1125,6 +1127,7 @@ async fn submission_loop(
sandbox_policy.clone(),
config.include_plan_tool,
config.include_apply_patch_tool,
config.experimental_disable_built_in_shell_tool,
),
user_instructions: turn_context.user_instructions.clone(),
base_instructions: turn_context.base_instructions.clone(),

View File

@@ -162,6 +162,8 @@ pub struct Config {
/// model family's default preference.
pub include_apply_patch_tool: bool,
pub experimental_disable_built_in_shell_tool: bool,
/// The value for the `originator` header included with Responses API requests.
pub internal_originator: Option<String>,
@@ -409,6 +411,8 @@ pub struct ConfigToml {
/// Experimental path to a file whose contents replace the built-in BASE_INSTRUCTIONS.
pub experimental_instructions_file: Option<PathBuf>,
pub experimental_disable_built_in_shell_tool: Option<bool>,
/// The value for the `originator` header included with Responses API requests.
pub internal_originator: Option<String>,
@@ -678,6 +682,9 @@ impl Config {
experimental_resume,
include_plan_tool: include_plan_tool.unwrap_or(false),
include_apply_patch_tool: include_apply_patch_tool_val,
experimental_disable_built_in_shell_tool: cfg
.experimental_disable_built_in_shell_tool
.unwrap_or(false),
internal_originator: cfg.internal_originator,
preferred_auth_method: cfg.preferred_auth_method.unwrap_or(AuthMode::ChatGPT),
};
@@ -1043,6 +1050,7 @@ disable_response_storage = true
base_instructions: None,
include_plan_tool: false,
include_apply_patch_tool: false,
experimental_disable_built_in_shell_tool: false,
internal_originator: None,
preferred_auth_method: AuthMode::ChatGPT,
},
@@ -1096,6 +1104,7 @@ disable_response_storage = true
base_instructions: None,
include_plan_tool: false,
include_apply_patch_tool: false,
experimental_disable_built_in_shell_tool: false,
internal_originator: None,
preferred_auth_method: AuthMode::ChatGPT,
};
@@ -1164,6 +1173,7 @@ disable_response_storage = true
base_instructions: None,
include_plan_tool: false,
include_apply_patch_tool: false,
experimental_disable_built_in_shell_tool: false,
internal_originator: None,
preferred_auth_method: AuthMode::ChatGPT,
};

View File

@@ -37,6 +37,7 @@ pub enum ConfigShellToolType {
DefaultShell,
ShellWithRequest { sandbox_policy: SandboxPolicy },
LocalShell,
NoBuiltInShellTool,
}
#[derive(Debug, Clone)]
@@ -53,8 +54,11 @@ impl ToolsConfig {
sandbox_policy: SandboxPolicy,
include_plan_tool: bool,
include_apply_patch_tool: bool,
experimental_disable_built_in_shell_tool: bool,
) -> Self {
let mut shell_type = if model_family.uses_local_shell_tool {
let mut shell_type = if experimental_disable_built_in_shell_tool {
ConfigShellToolType::NoBuiltInShellTool
} else if model_family.uses_local_shell_tool {
ConfigShellToolType::LocalShell
} else {
ConfigShellToolType::DefaultShell
@@ -533,6 +537,9 @@ pub(crate) fn get_openai_tools(
ConfigShellToolType::LocalShell => {
tools.push(OpenAiTool::LocalShell {});
}
ConfigShellToolType::NoBuiltInShellTool => {
// Do not add a shell tool
}
}
if config.plan_tool {
@@ -597,6 +604,7 @@ mod tests {
SandboxPolicy::ReadOnly,
true,
model_family.uses_apply_patch_tool,
/*experimental_disable_built_in_shell_tool*/ false,
);
let tools = get_openai_tools(&config, Some(HashMap::new()));
@@ -612,6 +620,7 @@ mod tests {
SandboxPolicy::ReadOnly,
true,
model_family.uses_apply_patch_tool,
/*experimental_disable_built_in_shell_tool*/ false,
);
let tools = get_openai_tools(&config, Some(HashMap::new()));
@@ -627,6 +636,7 @@ mod tests {
SandboxPolicy::ReadOnly,
false,
model_family.uses_apply_patch_tool,
/*experimental_disable_built_in_shell_tool*/ false,
);
let tools = get_openai_tools(
&config,
@@ -721,6 +731,7 @@ mod tests {
SandboxPolicy::ReadOnly,
false,
model_family.uses_apply_patch_tool,
/*experimental_disable_built_in_shell_tool*/ false,
);
let tools = get_openai_tools(
@@ -777,6 +788,7 @@ mod tests {
SandboxPolicy::ReadOnly,
false,
model_family.uses_apply_patch_tool,
/*experimental_disable_built_in_shell_tool*/ false,
);
let tools = get_openai_tools(
@@ -828,6 +840,7 @@ mod tests {
SandboxPolicy::ReadOnly,
false,
model_family.uses_apply_patch_tool,
/*experimental_disable_built_in_shell_tool*/ false,
);
let tools = get_openai_tools(
@@ -882,6 +895,7 @@ mod tests {
SandboxPolicy::ReadOnly,
false,
model_family.uses_apply_patch_tool,
/*experimental_disable_built_in_shell_tool*/ false,
);
let tools = get_openai_tools(

View File

@@ -0,0 +1,35 @@
[package]
edition = "2024"
name = "exec-command-mcp"
version = { workspace = true }
[[bin]]
name = "exec-command-mcp"
path = "src/main.rs"
[lib]
name = "exec_command_mcp"
path = "src/lib.rs"
[lints]
workspace = true
[dependencies]
anyhow = "1"
mcp-types = { path = "../mcp-types" }
portable-pty = "0.9.0"
schemars = "0.8.22"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tokio = { version = "1", features = [
"io-std",
"io-util",
"macros",
"process",
"rt-multi-thread",
"time",
"sync",
"signal",
] }
tracing = { version = "0.1.41", features = ["log"] }
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] }

View File

@@ -0,0 +1,2 @@
pub(crate) const INVALID_REQUEST_ERROR_CODE: i64 = -32600;
pub(crate) const INTERNAL_ERROR_CODE: i64 = -32603;

View File

@@ -0,0 +1,59 @@
use schemars::JsonSchema;
use serde::Deserialize;
use serde::Serialize;
use crate::session_id::SessionId;
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize, JsonSchema)]
pub(crate) struct ExecCommandParams {
pub(crate) cmd: String,
#[serde(default = "default_yield_time")]
pub(crate) yield_time_ms: u64,
#[serde(default = "max_output_tokens")]
pub(crate) max_output_tokens: u64,
#[serde(default = "default_shell")]
pub(crate) shell: String,
#[serde(default = "default_login")]
pub(crate) login: bool,
}
fn default_yield_time() -> u64 {
10_000
}
fn max_output_tokens() -> u64 {
10_000
}
fn default_login() -> bool {
true
}
fn default_shell() -> String {
"/bin/bash".to_string()
}
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
pub(crate) struct WriteStdinParams {
pub(crate) session_id: SessionId,
pub(crate) chars: String,
#[serde(default = "write_stdin_default_yield_time_ms")]
pub(crate) yield_time_ms: u64,
#[serde(default = "write_stdin_default_max_output_tokens")]
pub(crate) max_output_tokens: u64,
}
fn write_stdin_default_yield_time_ms() -> u64 {
250
}
fn write_stdin_default_max_output_tokens() -> u64 {
10_000
}

View File

@@ -0,0 +1,54 @@
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::mpsc;
use crate::session_id::SessionId;
#[allow(dead_code)]
#[derive(Debug)]
pub(crate) struct ExecCommandSession {
pub(crate) id: SessionId,
/// Queue for writing bytes to the process stdin (PTY master write side).
writer_tx: mpsc::Sender<Vec<u8>>,
/// Stream of output chunks read from the PTY. Wrapped in Mutex so callers can
/// `await` receiving without needing `&mut self`.
output_rx: Arc<Mutex<mpsc::Receiver<Vec<u8>>>>,
}
#[allow(dead_code)]
impl ExecCommandSession {
pub(crate) fn new(
id: SessionId,
writer_tx: mpsc::Sender<Vec<u8>>,
output_rx: mpsc::Receiver<Vec<u8>>,
) -> Self {
Self {
id,
writer_tx,
output_rx: Arc::new(Mutex::new(output_rx)),
}
}
/// Enqueue bytes to be written to the process stdin (PTY master).
pub(crate) async fn write_stdin(&self, bytes: impl AsRef<[u8]>) -> anyhow::Result<()> {
self.writer_tx
.send(bytes.as_ref().to_vec())
.await
.map_err(|e| anyhow::anyhow!("failed to send to writer: {e}"))
}
/// Receive the next chunk of output from the process. Returns `None` when the
/// output stream is closed (process exited or reader finished).
pub(crate) async fn recv_output_chunk(&self) -> Option<Vec<u8>> {
self.output_rx.lock().await.recv().await
}
pub(crate) fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
self.writer_tx.clone()
}
pub(crate) fn output_receiver(&self) -> Arc<Mutex<mpsc::Receiver<Vec<u8>>>> {
self.output_rx.clone()
}
}

View File

@@ -0,0 +1,118 @@
#![deny(clippy::print_stdout, clippy::print_stderr)]
use mcp_types::JSONRPCMessage;
use std::io::Result as IoResult;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
use tokio::sync::mpsc;
use tracing::debug;
use tracing::error;
use tracing::info;
use tracing_subscriber::EnvFilter;
use crate::message_processor::MessageProcessor;
use crate::outgoing_message::OutgoingMessage;
use crate::outgoing_message_sender::OutgoingMessageSender;
mod error_code;
mod exec_command;
mod exec_command_session;
mod message_processor;
mod outgoing_message;
mod outgoing_message_sender;
mod session_id;
mod session_manager;
/// Size of the bounded channels used to communicate between tasks. The value
/// is a balance between throughput and memory usage 128 messages should be
/// plenty for an interactive CLI.
const CHANNEL_CAPACITY: usize = 128;
pub async fn run_main() -> IoResult<()> {
// Honor `RUST_LOG`.
tracing_subscriber::fmt()
.with_writer(std::io::stderr)
.with_env_filter(EnvFilter::from_default_env())
.init();
// Set up channels.
let (incoming_tx, mut incoming_rx) = mpsc::channel::<JSONRPCMessage>(CHANNEL_CAPACITY);
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<OutgoingMessage>(CHANNEL_CAPACITY);
// Task: read from stdin, push to `incoming_tx`.
let stdin_reader_handle = tokio::spawn({
let incoming_tx = incoming_tx.clone();
async move {
let stdin = tokio::io::stdin();
let reader = BufReader::new(stdin);
let mut lines = reader.lines();
while let Some(line) = lines.next_line().await.unwrap_or_default() {
match serde_json::from_str::<JSONRPCMessage>(&line) {
Ok(msg) => {
if incoming_tx.send(msg).await.is_err() {
// Receiver gone nothing left to do.
break;
}
}
Err(e) => error!("Failed to deserialize JSONRPCMessage: {e}"),
}
}
debug!("stdin reader finished (EOF)");
}
});
// Task: process incoming messages.
let processor_handle = tokio::spawn({
let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx);
let mut processor = MessageProcessor::new(outgoing_message_sender);
async move {
while let Some(msg) = incoming_rx.recv().await {
match msg {
JSONRPCMessage::Request(request) => processor.process_request(request).await,
JSONRPCMessage::Response(_response) => {}
JSONRPCMessage::Notification(_notification) => {}
JSONRPCMessage::Error(_error) => {}
}
}
info!("processor task exited (channel closed)");
}
});
// Task: write outgoing messages to stdout.
let stdout_writer_handle = tokio::spawn(async move {
let mut stdout = tokio::io::stdout();
while let Some(outgoing_message) = outgoing_rx.recv().await {
let msg: JSONRPCMessage = outgoing_message.into();
match serde_json::to_string(&msg) {
Ok(json) => {
if let Err(e) = stdout.write_all(json.as_bytes()).await {
error!("Failed to write to stdout: {e}");
break;
}
if let Err(e) = stdout.write_all(b"\n").await {
error!("Failed to write newline to stdout: {e}");
break;
}
if let Err(e) = stdout.flush().await {
error!("Failed to flush stdout: {e}");
break;
}
}
Err(e) => error!("Failed to serialize JSONRPCMessage: {e}"),
}
}
info!("stdout writer exited (channel closed)");
});
// Wait for all tasks to finish. The typical exit path is the stdin reader
// hitting EOF which, once it drops `incoming_tx`, propagates shutdown to
// the processor and then to the stdout task.
let _ = tokio::join!(stdin_reader_handle, processor_handle, stdout_writer_handle);
Ok(())
}

View File

@@ -0,0 +1,7 @@
use exec_command_mcp::run_main;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
run_main().await?;
Ok(())
}

View File

@@ -0,0 +1,289 @@
use std::sync::Arc;
use mcp_types::CallToolRequestParams;
use mcp_types::CallToolResult;
use mcp_types::ClientRequest as McpClientRequest;
use mcp_types::ContentBlock;
use mcp_types::JSONRPCErrorError;
use mcp_types::JSONRPCRequest;
use mcp_types::ListToolsResult;
use mcp_types::ModelContextProtocolRequest;
use mcp_types::RequestId;
use mcp_types::ServerCapabilitiesTools;
use mcp_types::TextContent;
use mcp_types::Tool;
use mcp_types::ToolInputSchema;
use schemars::r#gen::SchemaSettings;
use crate::error_code::INVALID_REQUEST_ERROR_CODE;
use crate::error_code::{self};
use crate::exec_command::ExecCommandParams;
use crate::exec_command::WriteStdinParams;
use crate::outgoing_message_sender::OutgoingMessageSender;
use crate::session_manager::SessionManager;
#[derive(Debug)]
pub(crate) struct MessageProcessor {
initialized: bool,
outgoing: Arc<OutgoingMessageSender>,
session_manager: Arc<SessionManager>,
}
impl MessageProcessor {
pub(crate) fn new(outgoing: OutgoingMessageSender) -> Self {
Self {
initialized: false,
outgoing: Arc::new(outgoing),
session_manager: Arc::new(SessionManager::default()),
}
}
pub(crate) async fn process_request(&mut self, request: JSONRPCRequest) {
let request_id = request.id.clone();
let client_request = match McpClientRequest::try_from(request) {
Ok(client_request) => client_request,
Err(e) => {
self.outgoing
.send_error(
request_id,
JSONRPCErrorError {
code: error_code::INVALID_REQUEST_ERROR_CODE,
message: format!("Invalid request: {e}"),
data: None,
},
)
.await;
return;
}
};
match client_request {
McpClientRequest::InitializeRequest(params) => {
self.handle_initialize(request_id, params).await;
}
McpClientRequest::ListToolsRequest(params) => {
self.handle_list_tools(request_id, params).await;
}
McpClientRequest::CallToolRequest(params) => {
self.handle_call_tool(request_id, params).await;
}
_ => {
tracing::warn!("Unhandled client request: {client_request:?}");
}
}
}
async fn handle_initialize(
&mut self,
id: RequestId,
params: <mcp_types::InitializeRequest as ModelContextProtocolRequest>::Params,
) {
tracing::info!("initialize -> params: {:?}", params);
if self.initialized {
// Already initialised: send JSON-RPC error response.
let error = JSONRPCErrorError {
code: INVALID_REQUEST_ERROR_CODE,
message: "initialize called more than once".to_string(),
data: None,
};
self.outgoing.send_error(id, error).await;
return;
}
self.initialized = true;
// Build a minimal InitializeResult. Fill with placeholders.
let result = mcp_types::InitializeResult {
capabilities: mcp_types::ServerCapabilities {
completions: None,
experimental: None,
logging: None,
prompts: None,
resources: None,
tools: Some(ServerCapabilitiesTools {
list_changed: Some(true),
}),
},
instructions: None,
protocol_version: params.protocol_version.clone(),
server_info: mcp_types::Implementation {
name: "exec-command-mcp".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
title: Some("Codex exec_command".to_string()),
},
};
self.send_response::<mcp_types::InitializeRequest>(id, result)
.await;
}
async fn handle_list_tools(
&self,
request_id: RequestId,
params: <mcp_types::ListToolsRequest as mcp_types::ModelContextProtocolRequest>::Params,
) {
tracing::trace!("tools/list ({request_id:?}) -> {params:?}");
// Generate tool schema eagerly in a short-lived scope to avoid holding
// non-Send schemars generator across await.
let result = {
let generator = SchemaSettings::draft2019_09()
.with(|s| {
s.inline_subschemas = true;
s.option_add_null_type = false;
})
.into_generator();
let exec_schema = generator
.clone()
.into_root_schema_for::<ExecCommandParams>();
let write_schema = generator.into_root_schema_for::<WriteStdinParams>();
#[expect(clippy::expect_used)]
let exec_schema_json =
serde_json::to_value(&exec_schema).expect("exec_command schema should serialize");
#[expect(clippy::expect_used)]
let write_schema_json =
serde_json::to_value(&write_schema).expect("write_stdin schema should serialize");
let exec_input_schema = serde_json::from_value::<ToolInputSchema>(exec_schema_json)
.unwrap_or_else(|e| {
panic!("failed to create Tool from schema: {e}");
});
let write_input_schema = serde_json::from_value::<ToolInputSchema>(write_schema_json)
.unwrap_or_else(|e| {
panic!("failed to create Tool from schema: {e}");
});
let tools = vec![
Tool {
name: "functions_exec_command".to_string(),
title: Some("Exec Command".to_string()),
description: Some("Start a PTY-backed shell command; returns early on timeout or completion.".to_string()),
input_schema: exec_input_schema,
output_schema: None,
annotations: None,
},
Tool {
name: "functions_write_stdin".to_string(),
title: Some("Write Stdin".to_string()),
description: Some("Write characters to a running exec session and collect output for a short window.".to_string()),
input_schema: write_input_schema,
output_schema: None,
annotations: None,
},
];
ListToolsResult {
tools,
next_cursor: None,
}
};
self.send_response::<mcp_types::ListToolsRequest>(request_id, result)
.await;
}
async fn handle_call_tool(
&self,
request_id: RequestId,
params: <mcp_types::CallToolRequest as mcp_types::ModelContextProtocolRequest>::Params,
) {
tracing::info!("tools/call -> params: {params:?}");
let CallToolRequestParams { name, arguments } = params;
match name.as_str() {
"functions_exec_command" => match extract_exec_command_params(arguments).await {
Ok(params) => {
tracing::info!("functions_exec_command -> params: {params:?}");
let session_manager = self.session_manager.clone();
let outgoing = self.outgoing.clone();
tokio::spawn(async move {
session_manager
.handle_exec_command_request(request_id, params, outgoing)
.await;
});
}
Err(jsonrpc_error) => {
self.outgoing.send_error(request_id, jsonrpc_error).await;
}
},
"functions_write_stdin" => match extract_write_stdin_params(arguments).await {
Ok(params) => {
tracing::info!("functions_write_stdin -> params: {params:?}");
let session_manager = self.session_manager.clone();
let outgoing = self.outgoing.clone();
tokio::spawn(async move {
session_manager
.handle_write_stdin_request(request_id, params, outgoing)
.await;
});
}
Err(jsonrpc_error) => {
self.outgoing.send_error(request_id, jsonrpc_error).await;
}
},
_ => {
let result = CallToolResult {
content: vec![ContentBlock::TextContent(TextContent {
r#type: "text".to_string(),
text: format!("Unknown tool '{name}'"),
annotations: None,
})],
is_error: Some(true),
structured_content: None,
};
self.send_response::<mcp_types::CallToolRequest>(request_id, result)
.await;
}
}
}
async fn send_response<T>(&self, id: RequestId, result: T::Result)
where
T: ModelContextProtocolRequest,
{
self.outgoing.send_response(id, result).await;
}
}
async fn extract_exec_command_params(
args: Option<serde_json::Value>,
) -> Result<ExecCommandParams, JSONRPCErrorError> {
match args {
Some(value) => match serde_json::from_value::<ExecCommandParams>(value) {
Ok(params) => Ok(params),
Err(e) => Err(JSONRPCErrorError {
code: error_code::INVALID_REQUEST_ERROR_CODE,
message: format!("Invalid request: {e}"),
data: None,
}),
},
None => Err(JSONRPCErrorError {
code: error_code::INVALID_REQUEST_ERROR_CODE,
message: "Missing arguments".to_string(),
data: None,
}),
}
}
async fn extract_write_stdin_params(
args: Option<serde_json::Value>,
) -> Result<WriteStdinParams, JSONRPCErrorError> {
match args {
Some(value) => match serde_json::from_value::<WriteStdinParams>(value) {
Ok(params) => Ok(params),
Err(e) => Err(JSONRPCErrorError {
code: error_code::INVALID_REQUEST_ERROR_CODE,
message: format!("Invalid request: {e}"),
data: None,
}),
},
None => Err(JSONRPCErrorError {
code: error_code::INVALID_REQUEST_ERROR_CODE,
message: "Missing arguments".to_string(),
data: None,
}),
}
}

View File

@@ -0,0 +1,46 @@
use mcp_types::JSONRPC_VERSION;
use mcp_types::JSONRPCError;
use mcp_types::JSONRPCErrorError;
use mcp_types::JSONRPCMessage;
use mcp_types::JSONRPCResponse;
use mcp_types::RequestId;
use mcp_types::Result;
use serde::Serialize;
/// Outgoing message from the server to the client.
pub(crate) enum OutgoingMessage {
Response(OutgoingResponse),
Error(OutgoingError),
}
impl From<OutgoingMessage> for JSONRPCMessage {
fn from(val: OutgoingMessage) -> Self {
use OutgoingMessage::*;
match val {
Response(OutgoingResponse { id, result }) => {
JSONRPCMessage::Response(JSONRPCResponse {
jsonrpc: JSONRPC_VERSION.into(),
id,
result,
})
}
Error(OutgoingError { id, error }) => JSONRPCMessage::Error(JSONRPCError {
jsonrpc: JSONRPC_VERSION.into(),
id,
error,
}),
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize)]
pub(crate) struct OutgoingResponse {
pub id: RequestId,
pub result: Result,
}
#[derive(Debug, Clone, PartialEq, Serialize)]
pub(crate) struct OutgoingError {
pub error: JSONRPCErrorError,
pub id: RequestId,
}

View File

@@ -0,0 +1,47 @@
use mcp_types::JSONRPCErrorError;
use mcp_types::RequestId;
use serde::Serialize;
use tokio::sync::mpsc;
use crate::outgoing_message::OutgoingError;
use crate::outgoing_message::OutgoingMessage;
use crate::outgoing_message::OutgoingResponse;
use crate::error_code::INTERNAL_ERROR_CODE;
/// Sends messages to the client and manages request callbacks.
#[derive(Debug)]
pub(crate) struct OutgoingMessageSender {
sender: mpsc::Sender<OutgoingMessage>,
}
impl OutgoingMessageSender {
pub(crate) fn new(sender: mpsc::Sender<OutgoingMessage>) -> Self {
Self { sender }
}
pub(crate) async fn send_response<T: Serialize>(&self, id: RequestId, response: T) {
match serde_json::to_value(response) {
Ok(result) => {
let outgoing_message = OutgoingMessage::Response(OutgoingResponse { id, result });
let _ = self.sender.send(outgoing_message).await;
}
Err(err) => {
self.send_error(
id,
JSONRPCErrorError {
code: INTERNAL_ERROR_CODE,
message: format!("failed to serialize response: {err}"),
data: None,
},
)
.await;
}
}
}
pub(crate) async fn send_error(&self, id: RequestId, error: JSONRPCErrorError) {
let outgoing_message = OutgoingMessage::Error(OutgoingError { id, error });
let _ = self.sender.send(outgoing_message).await;
}
}

View File

@@ -0,0 +1,6 @@
use schemars::JsonSchema;
use serde::Deserialize;
use serde::Serialize;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
pub(crate) struct SessionId(pub u32);

View File

@@ -0,0 +1,324 @@
use std::collections::HashMap;
use std::io::Read;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::atomic::AtomicU32;
use mcp_types::CallToolResult;
use mcp_types::ContentBlock;
use mcp_types::JSONRPCErrorError;
use mcp_types::RequestId;
use mcp_types::TextContent;
use portable_pty::CommandBuilder;
use portable_pty::PtySize;
use portable_pty::native_pty_system;
use serde_json::json;
use tokio::sync::Mutex;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::time::Duration;
use tokio::time::Instant;
use tokio::time::timeout;
use crate::error_code;
use crate::exec_command::ExecCommandParams;
use crate::exec_command::WriteStdinParams;
use crate::exec_command_session::ExecCommandSession;
use crate::outgoing_message_sender::OutgoingMessageSender;
use crate::session_id::SessionId;
#[derive(Debug, Default)]
pub(crate) struct SessionManager {
next_session_id: AtomicU32,
sessions: Mutex<HashMap<SessionId, ExecCommandSession>>,
}
impl SessionManager {
/// Processes the request and is required to send a response via `outgoing`.
pub(crate) async fn handle_exec_command_request(
&self,
request_id: RequestId,
params: ExecCommandParams,
outgoing: Arc<OutgoingMessageSender>,
) {
// Allocate a session id.
let session_id = SessionId(
self.next_session_id
.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
);
let result = create_exec_command_session(session_id, params.clone()).await;
match result {
Ok((session, mut exit_rx)) => {
// Insert into session map.
let output_receiver = session.output_receiver();
self.sessions.lock().await.insert(session_id, session);
// Collect output until either timeout expires or process exits.
// Cap by assuming 4 bytes per token (TODO: use a real tokenizer).
let cap_bytes_u64 = params.max_output_tokens.saturating_mul(4);
let cap_bytes: usize = cap_bytes_u64.min(usize::MAX as u64) as usize;
let cap_hint = cap_bytes.clamp(1024, 8192);
let mut collected: Vec<u8> = Vec::with_capacity(cap_hint);
let deadline = Instant::now() + Duration::from_millis(params.yield_time_ms);
let mut exit_code: Option<i32> = None;
loop {
if Instant::now() >= deadline {
break;
}
let remaining = deadline.saturating_duration_since(Instant::now());
tokio::select! {
biased;
exit = &mut exit_rx => {
exit_code = exit.ok();
// Small grace period to pull remaining buffered output
let grace_deadline = Instant::now() + Duration::from_millis(25);
while Instant::now() < grace_deadline {
let recv_next = async {
let mut rx = output_receiver.lock().await;
rx.recv().await
};
if let Ok(Some(chunk)) = timeout(Duration::from_millis(1), recv_next).await {
let available = cap_bytes.saturating_sub(collected.len());
if available == 0 { break; }
let take = available.min(chunk.len());
collected.extend_from_slice(&chunk[..take]);
} else {
break;
}
}
break;
}
chunk = timeout(remaining, async {
let mut rx = output_receiver.lock().await;
rx.recv().await
}) => {
match chunk {
Ok(Some(chunk)) => {
let available = cap_bytes.saturating_sub(collected.len());
if available == 0 { /* keep draining, but don't store */ }
else {
let take = available.min(chunk.len());
collected.extend_from_slice(&chunk[..take]);
}
}
Ok(None) => { break; }
Err(_) => { break; }
}
}
}
}
let text = String::from_utf8_lossy(&collected).to_string();
let mut structured = json!({ "sessionId": session_id });
if let Some(code) = exit_code {
structured["exitCode"] = json!(code);
}
let result = CallToolResult {
content: vec![ContentBlock::TextContent(TextContent {
r#type: "text".to_string(),
text,
annotations: None,
})],
is_error: None,
structured_content: Some(structured),
};
outgoing.send_response(request_id, result).await;
}
Err(err) => {
outgoing
.send_error(
request_id,
JSONRPCErrorError {
code: error_code::INTERNAL_ERROR_CODE,
message: format!("failed to start exec session: {err}"),
data: None,
},
)
.await;
}
}
}
/// Write characters to a session's stdin and collect combined output for up to `yield_time_ms`.
pub(crate) async fn handle_write_stdin_request(
&self,
request_id: RequestId,
params: WriteStdinParams,
outgoing: Arc<OutgoingMessageSender>,
) {
let WriteStdinParams {
session_id,
chars,
yield_time_ms,
max_output_tokens,
} = params;
// Grab handles without holding the sessions lock across await points.
let (writer_tx, output_rx) = {
let sessions = self.sessions.lock().await;
match sessions.get(&session_id) {
Some(session) => (session.writer_sender(), session.output_receiver()),
None => {
outgoing
.send_error(
request_id,
JSONRPCErrorError {
code: error_code::INVALID_REQUEST_ERROR_CODE,
message: format!("unknown session id {}", session_id.0),
data: None,
},
)
.await;
return;
}
}
};
// Write stdin if provided.
if !chars.is_empty() && writer_tx.send(chars.into_bytes()).await.is_err() {
outgoing
.send_error(
request_id,
JSONRPCErrorError {
code: error_code::INTERNAL_ERROR_CODE,
message: "failed to write to stdin".to_string(),
data: None,
},
)
.await;
return;
}
// Collect output up to yield_time_ms, truncating to max_output_tokens bytes.
let mut collected: Vec<u8> = Vec::with_capacity(4096);
let deadline = Instant::now() + Duration::from_millis(yield_time_ms);
loop {
let now = Instant::now();
if now >= deadline {
break;
}
let remaining = deadline - now;
match timeout(remaining, output_rx.lock().await.recv()).await {
Ok(Some(chunk)) => {
// Respect token/byte limit; keep draining but drop once full.
let available =
max_output_tokens.saturating_sub(collected.len() as u64) as usize;
if available > 0 {
let take = available.min(chunk.len());
collected.extend_from_slice(&chunk[..take]);
}
// Continue loop to drain further within time.
}
Ok(None) => break, // channel closed
Err(_) => break, // timeout
}
}
// Return text output as a CallToolResult
let text = String::from_utf8_lossy(&collected).to_string();
let result = CallToolResult {
content: vec![ContentBlock::TextContent(TextContent {
r#type: "text".to_string(),
text,
annotations: None,
})],
is_error: None,
structured_content: None,
};
outgoing.send_response(request_id, result).await;
}
}
/// Spawn PTY and child process per spawn_exec_command_session logic.
async fn create_exec_command_session(
session_id: SessionId,
params: ExecCommandParams,
) -> anyhow::Result<(ExecCommandSession, oneshot::Receiver<i32>)> {
let ExecCommandParams {
cmd,
yield_time_ms: _,
max_output_tokens: _,
shell,
login,
} = params;
// Use the native pty implementation for the system
let pty_system = native_pty_system();
// Create a new pty
let pair = pty_system.openpty(PtySize {
rows: 24,
cols: 80,
pixel_width: 0,
pixel_height: 0,
})?;
// Spawn a shell into the pty
let mut command_builder = CommandBuilder::new(shell);
let shell_mode_opt = if login { "-lc" } else { "-c" };
command_builder.arg(shell_mode_opt);
command_builder.arg(cmd);
let mut child = pair.slave.spawn_command(command_builder)?;
// Channel to forward write requests to the PTY writer.
let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
// Channel for streaming PTY output to readers.
let (output_tx, output_rx) = mpsc::channel::<Vec<u8>>(256);
// Reader task: drain PTY and forward chunks to output channel.
let mut reader = pair.master.try_clone_reader()?;
let output_tx_clone = output_tx.clone();
tokio::task::spawn_blocking(move || {
let mut buf = [0u8; 8192];
loop {
match reader.read(&mut buf) {
Ok(0) => break, // EOF
Ok(n) => {
// Forward; block if receiver is slow to avoid dropping output.
let _ = output_tx_clone.blocking_send(buf[..n].to_vec());
}
Err(_) => break,
}
}
});
// Writer task: apply stdin writes to the PTY writer.
let writer = pair.master.take_writer()?;
let writer = Arc::new(StdMutex::new(writer));
tokio::spawn({
let writer = writer.clone();
async move {
while let Some(bytes) = writer_rx.recv().await {
let writer = writer.clone();
// Perform blocking write on a blocking thread.
let _ = tokio::task::spawn_blocking(move || {
if let Ok(mut guard) = writer.lock() {
use std::io::Write;
let _ = guard.write_all(&bytes);
let _ = guard.flush();
}
})
.await;
}
}
});
// Keep the child alive until it exits, then signal exit code.
let (exit_tx, exit_rx) = oneshot::channel::<i32>();
tokio::task::spawn_blocking(move || {
let code = match child.wait() {
Ok(status) => status.exit_code() as i32,
Err(_) => -1,
};
let _ = exit_tx.send(code);
});
// Create and store the session with channels.
let session = ExecCommandSession::new(session_id, writer_tx, output_rx);
Ok((session, exit_rx))
}