Mirror of https://github.com/openai/codex.git
Synced 2026-02-01 22:47:52 +00:00

Compare commits (1 commit)

| Author | SHA1 | Date |
|---|---|---|
|  | 7c94e4cab9 |  |

codex-rs/Cargo.lock (generated; 89 lines changed)
@@ -1438,6 +1438,12 @@ version = "0.15.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"

[[package]]
name = "downcast-rs"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2"

[[package]]
name = "dupe"
version = "0.9.1"

@@ -1632,6 +1638,21 @@ dependencies = [
 "pin-project-lite",
]

[[package]]
name = "exec-command-mcp"
version = "0.0.0"
dependencies = [
 "anyhow",
 "mcp-types",
 "portable-pty",
 "schemars 0.8.22",
 "serde",
 "serde_json",
 "tokio",
 "tracing",
 "tracing-subscriber",
]

[[package]]
name = "exr"
version = "1.73.0"

@@ -1683,6 +1704,17 @@ dependencies = [
 "simd-adler32",
]

[[package]]
name = "filedescriptor"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e40758ed24c9b2eeb76c35fb0aebc66c626084edd827e07e1552279814c6682d"
dependencies = [
 "libc",
 "thiserror 1.0.69",
 "winapi",
]

[[package]]
name = "fixedbitset"
version = "0.4.2"

@@ -3340,6 +3372,27 @@ dependencies = [
 "portable-atomic",
]

[[package]]
name = "portable-pty"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4a596a2b3d2752d94f51fac2d4a96737b8705dddd311a32b9af47211f08671e"
dependencies = [
 "anyhow",
 "bitflags 1.3.2",
 "downcast-rs",
 "filedescriptor",
 "lazy_static",
 "libc",
 "log",
 "nix",
 "serial2",
 "shared_library",
 "shell-words",
 "winapi",
 "winreg",
]

[[package]]
name = "potential_utf"
version = "0.1.2"

@@ -4267,6 +4320,17 @@ dependencies = [
 "syn 2.0.104",
]

[[package]]
name = "serial2"
version = "0.2.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26e1e5956803a69ddd72ce2de337b577898801528749565def03515f82bad5bb"
dependencies = [
 "cfg-if",
 "libc",
 "winapi",
]

[[package]]
name = "sha1"
version = "0.10.6"

@@ -4298,6 +4362,22 @@ dependencies = [
 "lazy_static",
]

[[package]]
name = "shared_library"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a9e7e0f2bfae24d8a5b5a66c5b257a83c7412304311512a0c054cd5e619da11"
dependencies = [
 "lazy_static",
 "libc",
]

[[package]]
name = "shell-words"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde"

[[package]]
name = "shlex"
version = "1.3.0"

@@ -6008,6 +6088,15 @@ dependencies = [
 "memchr",
]

[[package]]
name = "winreg"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d"
dependencies = [
 "winapi",
]

[[package]]
name = "wiremock"
version = "0.6.4"

@@ -7,6 +7,7 @@ members = [
    "common",
    "core",
    "exec",
    "exec-command-mcp",
    "execpolicy",
    "file-search",
    "linux-sandbox",

@@ -265,10 +265,7 @@ For casual greetings, acknowledgements, or other one-off conversational messages

## Shell commands

When using the shell, you must adhere to the following guidelines:

- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
- Read files in chunks with a max chunk size of 250 lines. Do not use python scripts to attempt to output larger chunks of a file. Command line output will be truncated after 10 kilobytes or 256 lines of output, regardless of the command used.
Do NOT use `shell`. Use only `functions_exec_command` and `functions_write_stdin`.

## `apply_patch`

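The prompt change above steers the model away from the built-in `shell` tool and toward the two MCP tools introduced by this commit. As a rough illustration (not part of the diff), the JSON-RPC `tools/call` payloads a client would send look roughly like the following; the argument names mirror `ExecCommandParams` and `WriteStdinParams` defined later in this commit, and the concrete values are made up:

```rust
use serde_json::json;

fn main() {
    // Start a PTY-backed command; fields mirror ExecCommandParams (only `cmd` is required).
    let exec_call = json!({
        "jsonrpc": "2.0",
        "id": 1,
        "method": "tools/call",
        "params": {
            "name": "functions_exec_command",
            "arguments": {
                "cmd": "python3",           // run via `<shell> -lc <cmd>`
                "yield_time_ms": 2_000,     // return after ~2s even if still running
                "max_output_tokens": 10_000
            }
        }
    });

    // Feed input to the session created above; the session id comes back in the
    // first call's structured content.
    let stdin_call = json!({
        "jsonrpc": "2.0",
        "id": 2,
        "method": "tools/call",
        "params": {
            "name": "functions_write_stdin",
            "arguments": { "session_id": 0, "chars": "print(1 + 1)\n" }
        }
    });

    println!("{exec_call}\n{stdin_call}");
}
```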
@@ -480,6 +480,7 @@ impl Session {
                sandbox_policy.clone(),
                config.include_plan_tool,
                config.include_apply_patch_tool,
                config.experimental_disable_built_in_shell_tool,
            ),
            user_instructions,
            base_instructions,

@@ -1049,6 +1050,7 @@ async fn submission_loop(
                new_sandbox_policy.clone(),
                config.include_plan_tool,
                config.include_apply_patch_tool,
                config.experimental_disable_built_in_shell_tool,
            );

            let new_turn_context = TurnContext {

@@ -1125,6 +1127,7 @@ async fn submission_loop(
                sandbox_policy.clone(),
                config.include_plan_tool,
                config.include_apply_patch_tool,
                config.experimental_disable_built_in_shell_tool,
            ),
            user_instructions: turn_context.user_instructions.clone(),
            base_instructions: turn_context.base_instructions.clone(),

@@ -162,6 +162,8 @@ pub struct Config {
    /// model family's default preference.
    pub include_apply_patch_tool: bool,

    pub experimental_disable_built_in_shell_tool: bool,

    /// The value for the `originator` header included with Responses API requests.
    pub internal_originator: Option<String>,

@@ -409,6 +411,8 @@ pub struct ConfigToml {
    /// Experimental path to a file whose contents replace the built-in BASE_INSTRUCTIONS.
    pub experimental_instructions_file: Option<PathBuf>,

    pub experimental_disable_built_in_shell_tool: Option<bool>,

    /// The value for the `originator` header included with Responses API requests.
    pub internal_originator: Option<String>,

@@ -678,6 +682,9 @@ impl Config {
            experimental_resume,
            include_plan_tool: include_plan_tool.unwrap_or(false),
            include_apply_patch_tool: include_apply_patch_tool_val,
            experimental_disable_built_in_shell_tool: cfg
                .experimental_disable_built_in_shell_tool
                .unwrap_or(false),
            internal_originator: cfg.internal_originator,
            preferred_auth_method: cfg.preferred_auth_method.unwrap_or(AuthMode::ChatGPT),
        };

@@ -1043,6 +1050,7 @@ disable_response_storage = true
            base_instructions: None,
            include_plan_tool: false,
            include_apply_patch_tool: false,
            experimental_disable_built_in_shell_tool: false,
            internal_originator: None,
            preferred_auth_method: AuthMode::ChatGPT,
        },

@@ -1096,6 +1104,7 @@ disable_response_storage = true
            base_instructions: None,
            include_plan_tool: false,
            include_apply_patch_tool: false,
            experimental_disable_built_in_shell_tool: false,
            internal_originator: None,
            preferred_auth_method: AuthMode::ChatGPT,
        };

@@ -1164,6 +1173,7 @@ disable_response_storage = true
            base_instructions: None,
            include_plan_tool: false,
            include_apply_patch_tool: false,
            experimental_disable_built_in_shell_tool: false,
            internal_originator: None,
            preferred_auth_method: AuthMode::ChatGPT,
        };

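The config hunks add an optional `experimental_disable_built_in_shell_tool` key to `ConfigToml` and fold it into `Config` with a `false` default. A minimal, self-contained sketch of that defaulting pattern (toy structs, not the real `ConfigToml`/`Config`, which carry many more fields; the `toml` crate is used here only for illustration):

```rust
use serde::Deserialize;

// Toy stand-ins; only the Option<bool> -> bool defaulting matches the diff.
#[derive(Deserialize, Default)]
struct TomlCfg {
    experimental_disable_built_in_shell_tool: Option<bool>,
}

struct Cfg {
    experimental_disable_built_in_shell_tool: bool,
}

fn resolve(cfg: TomlCfg) -> Cfg {
    Cfg {
        // An absent key means "keep the built-in shell tool", mirroring `.unwrap_or(false)`.
        experimental_disable_built_in_shell_tool: cfg
            .experimental_disable_built_in_shell_tool
            .unwrap_or(false),
    }
}

fn main() {
    let parsed: TomlCfg =
        toml::from_str("experimental_disable_built_in_shell_tool = true").unwrap();
    assert!(resolve(parsed).experimental_disable_built_in_shell_tool);
    assert!(!resolve(TomlCfg::default()).experimental_disable_built_in_shell_tool);
}
```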
@@ -37,6 +37,7 @@ pub enum ConfigShellToolType {
    DefaultShell,
    ShellWithRequest { sandbox_policy: SandboxPolicy },
    LocalShell,
    NoBuiltInShellTool,
}

#[derive(Debug, Clone)]

@@ -53,8 +54,11 @@ impl ToolsConfig {
        sandbox_policy: SandboxPolicy,
        include_plan_tool: bool,
        include_apply_patch_tool: bool,
        experimental_disable_built_in_shell_tool: bool,
    ) -> Self {
        let mut shell_type = if model_family.uses_local_shell_tool {
        let mut shell_type = if experimental_disable_built_in_shell_tool {
            ConfigShellToolType::NoBuiltInShellTool
        } else if model_family.uses_local_shell_tool {
            ConfigShellToolType::LocalShell
        } else {
            ConfigShellToolType::DefaultShell

@@ -533,6 +537,9 @@ pub(crate) fn get_openai_tools(
        ConfigShellToolType::LocalShell => {
            tools.push(OpenAiTool::LocalShell {});
        }
        ConfigShellToolType::NoBuiltInShellTool => {
            // Do not add a shell tool
        }
    }

    if config.plan_tool {

@@ -597,6 +604,7 @@ mod tests {
            SandboxPolicy::ReadOnly,
            true,
            model_family.uses_apply_patch_tool,
            /*experimental_disable_built_in_shell_tool*/ false,
        );
        let tools = get_openai_tools(&config, Some(HashMap::new()));

@@ -612,6 +620,7 @@ mod tests {
            SandboxPolicy::ReadOnly,
            true,
            model_family.uses_apply_patch_tool,
            /*experimental_disable_built_in_shell_tool*/ false,
        );
        let tools = get_openai_tools(&config, Some(HashMap::new()));

@@ -627,6 +636,7 @@ mod tests {
            SandboxPolicy::ReadOnly,
            false,
            model_family.uses_apply_patch_tool,
            /*experimental_disable_built_in_shell_tool*/ false,
        );
        let tools = get_openai_tools(
            &config,

@@ -721,6 +731,7 @@ mod tests {
            SandboxPolicy::ReadOnly,
            false,
            model_family.uses_apply_patch_tool,
            /*experimental_disable_built_in_shell_tool*/ false,
        );

        let tools = get_openai_tools(

@@ -777,6 +788,7 @@ mod tests {
            SandboxPolicy::ReadOnly,
            false,
            model_family.uses_apply_patch_tool,
            /*experimental_disable_built_in_shell_tool*/ false,
        );

        let tools = get_openai_tools(

@@ -828,6 +840,7 @@ mod tests {
            SandboxPolicy::ReadOnly,
            false,
            model_family.uses_apply_patch_tool,
            /*experimental_disable_built_in_shell_tool*/ false,
        );

        let tools = get_openai_tools(

@@ -882,6 +895,7 @@ mod tests {
            SandboxPolicy::ReadOnly,
            false,
            model_family.uses_apply_patch_tool,
            /*experimental_disable_built_in_shell_tool*/ false,
        );

        let tools = get_openai_tools(

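The key behavioral change in the hunks above is the precedence used when choosing a shell tool: the new experimental flag wins, then the model family's local-shell preference, then the default shell tool. A toy reconstruction of that selection order (not the real `ToolsConfig` types):

```rust
#[derive(Debug, PartialEq)]
enum ShellToolType {
    DefaultShell,
    LocalShell,
    NoBuiltInShellTool,
}

fn pick_shell_tool(disable_built_in: bool, uses_local_shell_tool: bool) -> ShellToolType {
    if disable_built_in {
        // The experimental flag overrides everything: no shell tool is advertised at all.
        ShellToolType::NoBuiltInShellTool
    } else if uses_local_shell_tool {
        ShellToolType::LocalShell
    } else {
        ShellToolType::DefaultShell
    }
}

fn main() {
    assert_eq!(pick_shell_tool(true, true), ShellToolType::NoBuiltInShellTool);
    assert_eq!(pick_shell_tool(false, true), ShellToolType::LocalShell);
    assert_eq!(pick_shell_tool(false, false), ShellToolType::DefaultShell);
}
```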
codex-rs/exec-command-mcp/Cargo.toml (new file, 35 lines)

@@ -0,0 +1,35 @@
[package]
edition = "2024"
name = "exec-command-mcp"
version = { workspace = true }

[[bin]]
name = "exec-command-mcp"
path = "src/main.rs"

[lib]
name = "exec_command_mcp"
path = "src/lib.rs"

[lints]
workspace = true

[dependencies]
anyhow = "1"
mcp-types = { path = "../mcp-types" }
portable-pty = "0.9.0"
schemars = "0.8.22"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tokio = { version = "1", features = [
    "io-std",
    "io-util",
    "macros",
    "process",
    "rt-multi-thread",
    "time",
    "sync",
    "signal",
] }
tracing = { version = "0.1.41", features = ["log"] }
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] }

codex-rs/exec-command-mcp/src/error_code.rs (new file, 2 lines)

@@ -0,0 +1,2 @@
pub(crate) const INVALID_REQUEST_ERROR_CODE: i64 = -32600;
pub(crate) const INTERNAL_ERROR_CODE: i64 = -32603;

codex-rs/exec-command-mcp/src/exec_command.rs (new file, 59 lines)

@@ -0,0 +1,59 @@
use schemars::JsonSchema;
use serde::Deserialize;
use serde::Serialize;

use crate::session_id::SessionId;

#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize, JsonSchema)]
pub(crate) struct ExecCommandParams {
    pub(crate) cmd: String,

    #[serde(default = "default_yield_time")]
    pub(crate) yield_time_ms: u64,

    #[serde(default = "max_output_tokens")]
    pub(crate) max_output_tokens: u64,

    #[serde(default = "default_shell")]
    pub(crate) shell: String,

    #[serde(default = "default_login")]
    pub(crate) login: bool,
}

fn default_yield_time() -> u64 {
    10_000
}

fn max_output_tokens() -> u64 {
    10_000
}

fn default_login() -> bool {
    true
}

fn default_shell() -> String {
    "/bin/bash".to_string()
}

#[derive(Debug, Deserialize, Serialize, JsonSchema)]
pub(crate) struct WriteStdinParams {
    pub(crate) session_id: SessionId,
    pub(crate) chars: String,

    #[serde(default = "write_stdin_default_yield_time_ms")]
    pub(crate) yield_time_ms: u64,

    #[serde(default = "write_stdin_default_max_output_tokens")]
    pub(crate) max_output_tokens: u64,
}

fn write_stdin_default_yield_time_ms() -> u64 {
    250
}

fn write_stdin_default_max_output_tokens() -> u64 {
    10_000
}

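Because every field except `cmd` (and `session_id`/`chars` for stdin writes) carries a `#[serde(default = ...)]`, callers can send minimal arguments and get the documented defaults. A self-contained sketch of that behavior, using a trimmed local copy of the struct so it compiles on its own (the real `ExecCommandParams` also has `max_output_tokens`):

```rust
use serde::Deserialize;

// Trimmed local copy of ExecCommandParams' shape, for illustration only.
#[derive(Debug, Deserialize)]
struct Params {
    cmd: String,
    #[serde(default = "default_yield_time")]
    yield_time_ms: u64,
    #[serde(default = "default_shell")]
    shell: String,
    #[serde(default = "default_login")]
    login: bool,
}

fn default_yield_time() -> u64 { 10_000 }
fn default_shell() -> String { "/bin/bash".to_string() }
fn default_login() -> bool { true }

fn main() {
    // Only `cmd` is required; everything else falls back to the defaults above.
    let p: Params = serde_json::from_str(r#"{ "cmd": "echo hi" }"#).unwrap();
    assert_eq!(p.yield_time_ms, 10_000);
    assert_eq!(p.shell, "/bin/bash");
    assert!(p.login);
    println!("{p:?} ({})", p.cmd);
}
```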
codex-rs/exec-command-mcp/src/exec_command_session.rs (new file, 54 lines)

@@ -0,0 +1,54 @@
use std::sync::Arc;

use tokio::sync::Mutex;
use tokio::sync::mpsc;

use crate::session_id::SessionId;

#[allow(dead_code)]
#[derive(Debug)]
pub(crate) struct ExecCommandSession {
    pub(crate) id: SessionId,
    /// Queue for writing bytes to the process stdin (PTY master write side).
    writer_tx: mpsc::Sender<Vec<u8>>,
    /// Stream of output chunks read from the PTY. Wrapped in Mutex so callers can
    /// `await` receiving without needing `&mut self`.
    output_rx: Arc<Mutex<mpsc::Receiver<Vec<u8>>>>,
}

#[allow(dead_code)]
impl ExecCommandSession {
    pub(crate) fn new(
        id: SessionId,
        writer_tx: mpsc::Sender<Vec<u8>>,
        output_rx: mpsc::Receiver<Vec<u8>>,
    ) -> Self {
        Self {
            id,
            writer_tx,
            output_rx: Arc::new(Mutex::new(output_rx)),
        }
    }

    /// Enqueue bytes to be written to the process stdin (PTY master).
    pub(crate) async fn write_stdin(&self, bytes: impl AsRef<[u8]>) -> anyhow::Result<()> {
        self.writer_tx
            .send(bytes.as_ref().to_vec())
            .await
            .map_err(|e| anyhow::anyhow!("failed to send to writer: {e}"))
    }

    /// Receive the next chunk of output from the process. Returns `None` when the
    /// output stream is closed (process exited or reader finished).
    pub(crate) async fn recv_output_chunk(&self) -> Option<Vec<u8>> {
        self.output_rx.lock().await.recv().await
    }

    pub(crate) fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
        self.writer_tx.clone()
    }

    pub(crate) fn output_receiver(&self) -> Arc<Mutex<mpsc::Receiver<Vec<u8>>>> {
        self.output_rx.clone()
    }
}

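The `Arc<Mutex<mpsc::Receiver<_>>>` wrapper exists because `Receiver::recv` takes `&mut self`; parking the receiver behind an async `Mutex` lets `&self` methods (and cloned handles held by other tasks) await output one chunk at a time. A minimal sketch of that pattern, assuming toy names:

```rust
use std::sync::Arc;

use tokio::sync::{mpsc, Mutex};

// Shared-receiver pattern used by ExecCommandSession, reduced to the essentials.
#[derive(Clone)]
struct OutputHandle {
    rx: Arc<Mutex<mpsc::Receiver<Vec<u8>>>>,
}

impl OutputHandle {
    async fn next_chunk(&self) -> Option<Vec<u8>> {
        // Lock only for the duration of one recv; other callers wait their turn.
        self.rx.lock().await.recv().await
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel::<Vec<u8>>(16);
    let handle = OutputHandle { rx: Arc::new(Mutex::new(rx)) };

    tx.send(b"hello".to_vec()).await.unwrap();
    drop(tx); // closing the sender ends the stream

    assert_eq!(handle.next_chunk().await.as_deref(), Some(&b"hello"[..]));
    assert!(handle.next_chunk().await.is_none()); // stream closed
}
```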
codex-rs/exec-command-mcp/src/lib.rs (new file, 118 lines)

@@ -0,0 +1,118 @@
#![deny(clippy::print_stdout, clippy::print_stderr)]

use mcp_types::JSONRPCMessage;
use std::io::Result as IoResult;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
use tokio::sync::mpsc;
use tracing::debug;
use tracing::error;
use tracing::info;
use tracing_subscriber::EnvFilter;

use crate::message_processor::MessageProcessor;
use crate::outgoing_message::OutgoingMessage;
use crate::outgoing_message_sender::OutgoingMessageSender;

mod error_code;
mod exec_command;
mod exec_command_session;
mod message_processor;
mod outgoing_message;
mod outgoing_message_sender;
mod session_id;
mod session_manager;

/// Size of the bounded channels used to communicate between tasks. The value
/// is a balance between throughput and memory usage – 128 messages should be
/// plenty for an interactive CLI.
const CHANNEL_CAPACITY: usize = 128;

pub async fn run_main() -> IoResult<()> {
    // Honor `RUST_LOG`.
    tracing_subscriber::fmt()
        .with_writer(std::io::stderr)
        .with_env_filter(EnvFilter::from_default_env())
        .init();

    // Set up channels.
    let (incoming_tx, mut incoming_rx) = mpsc::channel::<JSONRPCMessage>(CHANNEL_CAPACITY);
    let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<OutgoingMessage>(CHANNEL_CAPACITY);

    // Task: read from stdin, push to `incoming_tx`.
    let stdin_reader_handle = tokio::spawn({
        let incoming_tx = incoming_tx.clone();
        async move {
            let stdin = tokio::io::stdin();
            let reader = BufReader::new(stdin);
            let mut lines = reader.lines();

            while let Some(line) = lines.next_line().await.unwrap_or_default() {
                match serde_json::from_str::<JSONRPCMessage>(&line) {
                    Ok(msg) => {
                        if incoming_tx.send(msg).await.is_err() {
                            // Receiver gone – nothing left to do.
                            break;
                        }
                    }
                    Err(e) => error!("Failed to deserialize JSONRPCMessage: {e}"),
                }
            }

            debug!("stdin reader finished (EOF)");
        }
    });

    // Task: process incoming messages.
    let processor_handle = tokio::spawn({
        let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx);
        let mut processor = MessageProcessor::new(outgoing_message_sender);
        async move {
            while let Some(msg) = incoming_rx.recv().await {
                match msg {
                    JSONRPCMessage::Request(request) => processor.process_request(request).await,
                    JSONRPCMessage::Response(_response) => {}
                    JSONRPCMessage::Notification(_notification) => {}
                    JSONRPCMessage::Error(_error) => {}
                }
            }

            info!("processor task exited (channel closed)");
        }
    });

    // Task: write outgoing messages to stdout.
    let stdout_writer_handle = tokio::spawn(async move {
        let mut stdout = tokio::io::stdout();
        while let Some(outgoing_message) = outgoing_rx.recv().await {
            let msg: JSONRPCMessage = outgoing_message.into();
            match serde_json::to_string(&msg) {
                Ok(json) => {
                    if let Err(e) = stdout.write_all(json.as_bytes()).await {
                        error!("Failed to write to stdout: {e}");
                        break;
                    }
                    if let Err(e) = stdout.write_all(b"\n").await {
                        error!("Failed to write newline to stdout: {e}");
                        break;
                    }
                    if let Err(e) = stdout.flush().await {
                        error!("Failed to flush stdout: {e}");
                        break;
                    }
                }
                Err(e) => error!("Failed to serialize JSONRPCMessage: {e}"),
            }
        }

        info!("stdout writer exited (channel closed)");
    });

    // Wait for all tasks to finish. The typical exit path is the stdin reader
    // hitting EOF which, once it drops `incoming_tx`, propagates shutdown to
    // the processor and then to the stdout task.
    let _ = tokio::join!(stdin_reader_handle, processor_handle, stdout_writer_handle);

    Ok(())
}

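The server speaks newline-delimited JSON-RPC over stdio: one JSON object per line in, one per line out, with stdin EOF driving shutdown. A rough client-side sketch of that framing (not part of the commit); the binary path is hypothetical, a real MCP client would perform the `initialize` handshake first, and the exact request body depends on the `mcp_types` schema, so the payload here is only illustrative:

```rust
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::Command;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Hypothetical path to the built server binary.
    let mut child = Command::new("target/debug/exec-command-mcp")
        .stdin(std::process::Stdio::piped())
        .stdout(std::process::Stdio::piped())
        .spawn()?;

    let mut stdin = child.stdin.take().expect("piped stdin");
    let stdout = child.stdout.take().expect("piped stdout");
    let mut lines = BufReader::new(stdout).lines();

    // One JSON-RPC request per line; the server answers one JSON object per line.
    let request =
        r#"{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}"#.to_string() + "\n";
    stdin.write_all(request.as_bytes()).await?;
    stdin.flush().await?;

    if let Some(line) = lines.next_line().await? {
        println!("server replied: {line}");
    }
    Ok(())
}
```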
codex-rs/exec-command-mcp/src/main.rs (new file, 7 lines)

@@ -0,0 +1,7 @@
use exec_command_mcp::run_main;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    run_main().await?;
    Ok(())
}

codex-rs/exec-command-mcp/src/message_processor.rs (new file, 289 lines)

@@ -0,0 +1,289 @@
use std::sync::Arc;

use mcp_types::CallToolRequestParams;
use mcp_types::CallToolResult;
use mcp_types::ClientRequest as McpClientRequest;
use mcp_types::ContentBlock;
use mcp_types::JSONRPCErrorError;
use mcp_types::JSONRPCRequest;
use mcp_types::ListToolsResult;
use mcp_types::ModelContextProtocolRequest;
use mcp_types::RequestId;
use mcp_types::ServerCapabilitiesTools;
use mcp_types::TextContent;
use mcp_types::Tool;
use mcp_types::ToolInputSchema;
use schemars::r#gen::SchemaSettings;

use crate::error_code::INVALID_REQUEST_ERROR_CODE;
use crate::error_code::{self};
use crate::exec_command::ExecCommandParams;
use crate::exec_command::WriteStdinParams;
use crate::outgoing_message_sender::OutgoingMessageSender;
use crate::session_manager::SessionManager;

#[derive(Debug)]
pub(crate) struct MessageProcessor {
    initialized: bool,
    outgoing: Arc<OutgoingMessageSender>,
    session_manager: Arc<SessionManager>,
}

impl MessageProcessor {
    pub(crate) fn new(outgoing: OutgoingMessageSender) -> Self {
        Self {
            initialized: false,
            outgoing: Arc::new(outgoing),
            session_manager: Arc::new(SessionManager::default()),
        }
    }

    pub(crate) async fn process_request(&mut self, request: JSONRPCRequest) {
        let request_id = request.id.clone();
        let client_request = match McpClientRequest::try_from(request) {
            Ok(client_request) => client_request,
            Err(e) => {
                self.outgoing
                    .send_error(
                        request_id,
                        JSONRPCErrorError {
                            code: error_code::INVALID_REQUEST_ERROR_CODE,
                            message: format!("Invalid request: {e}"),
                            data: None,
                        },
                    )
                    .await;
                return;
            }
        };

        match client_request {
            McpClientRequest::InitializeRequest(params) => {
                self.handle_initialize(request_id, params).await;
            }
            McpClientRequest::ListToolsRequest(params) => {
                self.handle_list_tools(request_id, params).await;
            }
            McpClientRequest::CallToolRequest(params) => {
                self.handle_call_tool(request_id, params).await;
            }
            _ => {
                tracing::warn!("Unhandled client request: {client_request:?}");
            }
        }
    }

    async fn handle_initialize(
        &mut self,
        id: RequestId,
        params: <mcp_types::InitializeRequest as ModelContextProtocolRequest>::Params,
    ) {
        tracing::info!("initialize -> params: {:?}", params);

        if self.initialized {
            // Already initialised: send JSON-RPC error response.
            let error = JSONRPCErrorError {
                code: INVALID_REQUEST_ERROR_CODE,
                message: "initialize called more than once".to_string(),
                data: None,
            };
            self.outgoing.send_error(id, error).await;
            return;
        }

        self.initialized = true;

        // Build a minimal InitializeResult. Fill with placeholders.
        let result = mcp_types::InitializeResult {
            capabilities: mcp_types::ServerCapabilities {
                completions: None,
                experimental: None,
                logging: None,
                prompts: None,
                resources: None,
                tools: Some(ServerCapabilitiesTools {
                    list_changed: Some(true),
                }),
            },
            instructions: None,
            protocol_version: params.protocol_version.clone(),
            server_info: mcp_types::Implementation {
                name: "exec-command-mcp".to_string(),
                version: env!("CARGO_PKG_VERSION").to_string(),
                title: Some("Codex exec_command".to_string()),
            },
        };

        self.send_response::<mcp_types::InitializeRequest>(id, result)
            .await;
    }

    async fn handle_list_tools(
        &self,
        request_id: RequestId,
        params: <mcp_types::ListToolsRequest as mcp_types::ModelContextProtocolRequest>::Params,
    ) {
        tracing::trace!("tools/list ({request_id:?}) -> {params:?}");

        // Generate tool schema eagerly in a short-lived scope to avoid holding
        // non-Send schemars generator across await.
        let result = {
            let generator = SchemaSettings::draft2019_09()
                .with(|s| {
                    s.inline_subschemas = true;
                    s.option_add_null_type = false;
                })
                .into_generator();

            let exec_schema = generator
                .clone()
                .into_root_schema_for::<ExecCommandParams>();
            let write_schema = generator.into_root_schema_for::<WriteStdinParams>();

            #[expect(clippy::expect_used)]
            let exec_schema_json =
                serde_json::to_value(&exec_schema).expect("exec_command schema should serialize");
            #[expect(clippy::expect_used)]
            let write_schema_json =
                serde_json::to_value(&write_schema).expect("write_stdin schema should serialize");

            let exec_input_schema = serde_json::from_value::<ToolInputSchema>(exec_schema_json)
                .unwrap_or_else(|e| {
                    panic!("failed to create Tool from schema: {e}");
                });
            let write_input_schema = serde_json::from_value::<ToolInputSchema>(write_schema_json)
                .unwrap_or_else(|e| {
                    panic!("failed to create Tool from schema: {e}");
                });

            let tools = vec![
                Tool {
                    name: "functions_exec_command".to_string(),
                    title: Some("Exec Command".to_string()),
                    description: Some("Start a PTY-backed shell command; returns early on timeout or completion.".to_string()),
                    input_schema: exec_input_schema,
                    output_schema: None,
                    annotations: None,
                },
                Tool {
                    name: "functions_write_stdin".to_string(),
                    title: Some("Write Stdin".to_string()),
                    description: Some("Write characters to a running exec session and collect output for a short window.".to_string()),
                    input_schema: write_input_schema,
                    output_schema: None,
                    annotations: None,
                },
            ];

            ListToolsResult {
                tools,
                next_cursor: None,
            }
        };

        self.send_response::<mcp_types::ListToolsRequest>(request_id, result)
            .await;
    }

    async fn handle_call_tool(
        &self,
        request_id: RequestId,
        params: <mcp_types::CallToolRequest as mcp_types::ModelContextProtocolRequest>::Params,
    ) {
        tracing::info!("tools/call -> params: {params:?}");
        let CallToolRequestParams { name, arguments } = params;

        match name.as_str() {
            "functions_exec_command" => match extract_exec_command_params(arguments).await {
                Ok(params) => {
                    tracing::info!("functions_exec_command -> params: {params:?}");
                    let session_manager = self.session_manager.clone();
                    let outgoing = self.outgoing.clone();
                    tokio::spawn(async move {
                        session_manager
                            .handle_exec_command_request(request_id, params, outgoing)
                            .await;
                    });
                }
                Err(jsonrpc_error) => {
                    self.outgoing.send_error(request_id, jsonrpc_error).await;
                }
            },
            "functions_write_stdin" => match extract_write_stdin_params(arguments).await {
                Ok(params) => {
                    tracing::info!("functions_write_stdin -> params: {params:?}");
                    let session_manager = self.session_manager.clone();
                    let outgoing = self.outgoing.clone();
                    tokio::spawn(async move {
                        session_manager
                            .handle_write_stdin_request(request_id, params, outgoing)
                            .await;
                    });
                }
                Err(jsonrpc_error) => {
                    self.outgoing.send_error(request_id, jsonrpc_error).await;
                }
            },
            _ => {
                let result = CallToolResult {
                    content: vec![ContentBlock::TextContent(TextContent {
                        r#type: "text".to_string(),
                        text: format!("Unknown tool '{name}'"),
                        annotations: None,
                    })],
                    is_error: Some(true),
                    structured_content: None,
                };
                self.send_response::<mcp_types::CallToolRequest>(request_id, result)
                    .await;
            }
        }
    }

    async fn send_response<T>(&self, id: RequestId, result: T::Result)
    where
        T: ModelContextProtocolRequest,
    {
        self.outgoing.send_response(id, result).await;
    }
}

async fn extract_exec_command_params(
    args: Option<serde_json::Value>,
) -> Result<ExecCommandParams, JSONRPCErrorError> {
    match args {
        Some(value) => match serde_json::from_value::<ExecCommandParams>(value) {
            Ok(params) => Ok(params),
            Err(e) => Err(JSONRPCErrorError {
                code: error_code::INVALID_REQUEST_ERROR_CODE,
                message: format!("Invalid request: {e}"),
                data: None,
            }),
        },
        None => Err(JSONRPCErrorError {
            code: error_code::INVALID_REQUEST_ERROR_CODE,
            message: "Missing arguments".to_string(),
            data: None,
        }),
    }
}

async fn extract_write_stdin_params(
    args: Option<serde_json::Value>,
) -> Result<WriteStdinParams, JSONRPCErrorError> {
    match args {
        Some(value) => match serde_json::from_value::<WriteStdinParams>(value) {
            Ok(params) => Ok(params),
            Err(e) => Err(JSONRPCErrorError {
                code: error_code::INVALID_REQUEST_ERROR_CODE,
                message: format!("Invalid request: {e}"),
                data: None,
            }),
        },
        None => Err(JSONRPCErrorError {
            code: error_code::INVALID_REQUEST_ERROR_CODE,
            message: "Missing arguments".to_string(),
            data: None,
        }),
    }
}

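`handle_list_tools` derives the tool input schemas from the param structs with schemars, using draft 2019-09 settings with inlined subschemas and without the implicit `null` type for `Option` fields, so each tool advertises a single self-contained schema object. A small standalone sketch of that generation step, using a stand-in struct rather than the real params:

```rust
use schemars::r#gen::SchemaSettings;
use schemars::JsonSchema;

// Stand-in for ExecCommandParams; any struct deriving JsonSchema works here.
#[derive(JsonSchema)]
#[allow(dead_code)]
struct DemoParams {
    cmd: String,
    yield_time_ms: Option<u64>,
}

fn main() {
    // Same settings as handle_list_tools: inline subschemas and skip the
    // implicit `null` type for Option fields.
    let generator = SchemaSettings::draft2019_09()
        .with(|s| {
            s.inline_subschemas = true;
            s.option_add_null_type = false;
        })
        .into_generator();

    let schema = generator.into_root_schema_for::<DemoParams>();
    println!("{}", serde_json::to_string_pretty(&schema).unwrap());
}
```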
codex-rs/exec-command-mcp/src/outgoing_message.rs (new file, 46 lines)

@@ -0,0 +1,46 @@
use mcp_types::JSONRPC_VERSION;
use mcp_types::JSONRPCError;
use mcp_types::JSONRPCErrorError;
use mcp_types::JSONRPCMessage;
use mcp_types::JSONRPCResponse;
use mcp_types::RequestId;
use mcp_types::Result;
use serde::Serialize;

/// Outgoing message from the server to the client.
pub(crate) enum OutgoingMessage {
    Response(OutgoingResponse),
    Error(OutgoingError),
}

impl From<OutgoingMessage> for JSONRPCMessage {
    fn from(val: OutgoingMessage) -> Self {
        use OutgoingMessage::*;
        match val {
            Response(OutgoingResponse { id, result }) => {
                JSONRPCMessage::Response(JSONRPCResponse {
                    jsonrpc: JSONRPC_VERSION.into(),
                    id,
                    result,
                })
            }
            Error(OutgoingError { id, error }) => JSONRPCMessage::Error(JSONRPCError {
                jsonrpc: JSONRPC_VERSION.into(),
                id,
                error,
            }),
        }
    }
}

#[derive(Debug, Clone, PartialEq, Serialize)]
pub(crate) struct OutgoingResponse {
    pub id: RequestId,
    pub result: Result,
}

#[derive(Debug, Clone, PartialEq, Serialize)]
pub(crate) struct OutgoingError {
    pub error: JSONRPCErrorError,
    pub id: RequestId,
}

codex-rs/exec-command-mcp/src/outgoing_message_sender.rs (new file, 47 lines)

@@ -0,0 +1,47 @@
use mcp_types::JSONRPCErrorError;
use mcp_types::RequestId;
use serde::Serialize;
use tokio::sync::mpsc;

use crate::outgoing_message::OutgoingError;
use crate::outgoing_message::OutgoingMessage;
use crate::outgoing_message::OutgoingResponse;

use crate::error_code::INTERNAL_ERROR_CODE;

/// Sends messages to the client and manages request callbacks.
#[derive(Debug)]
pub(crate) struct OutgoingMessageSender {
    sender: mpsc::Sender<OutgoingMessage>,
}

impl OutgoingMessageSender {
    pub(crate) fn new(sender: mpsc::Sender<OutgoingMessage>) -> Self {
        Self { sender }
    }

    pub(crate) async fn send_response<T: Serialize>(&self, id: RequestId, response: T) {
        match serde_json::to_value(response) {
            Ok(result) => {
                let outgoing_message = OutgoingMessage::Response(OutgoingResponse { id, result });
                let _ = self.sender.send(outgoing_message).await;
            }
            Err(err) => {
                self.send_error(
                    id,
                    JSONRPCErrorError {
                        code: INTERNAL_ERROR_CODE,
                        message: format!("failed to serialize response: {err}"),
                        data: None,
                    },
                )
                .await;
            }
        }
    }

    pub(crate) async fn send_error(&self, id: RequestId, error: JSONRPCErrorError) {
        let outgoing_message = OutgoingMessage::Error(OutgoingError { id, error });
        let _ = self.sender.send(outgoing_message).await;
    }
}

codex-rs/exec-command-mcp/src/session_id.rs (new file, 6 lines)

@@ -0,0 +1,6 @@
use schemars::JsonSchema;
use serde::Deserialize;
use serde::Serialize;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
pub(crate) struct SessionId(pub u32);

codex-rs/exec-command-mcp/src/session_manager.rs (new file, 324 lines)

@@ -0,0 +1,324 @@
use std::collections::HashMap;
use std::io::Read;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::atomic::AtomicU32;

use mcp_types::CallToolResult;
use mcp_types::ContentBlock;
use mcp_types::JSONRPCErrorError;
use mcp_types::RequestId;
use mcp_types::TextContent;
use portable_pty::CommandBuilder;
use portable_pty::PtySize;
use portable_pty::native_pty_system;
use serde_json::json;
use tokio::sync::Mutex;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::time::Duration;
use tokio::time::Instant;
use tokio::time::timeout;

use crate::error_code;
use crate::exec_command::ExecCommandParams;
use crate::exec_command::WriteStdinParams;
use crate::exec_command_session::ExecCommandSession;
use crate::outgoing_message_sender::OutgoingMessageSender;
use crate::session_id::SessionId;

#[derive(Debug, Default)]
pub(crate) struct SessionManager {
    next_session_id: AtomicU32,
    sessions: Mutex<HashMap<SessionId, ExecCommandSession>>,
}

impl SessionManager {
    /// Processes the request and is required to send a response via `outgoing`.
    pub(crate) async fn handle_exec_command_request(
        &self,
        request_id: RequestId,
        params: ExecCommandParams,
        outgoing: Arc<OutgoingMessageSender>,
    ) {
        // Allocate a session id.
        let session_id = SessionId(
            self.next_session_id
                .fetch_add(1, std::sync::atomic::Ordering::SeqCst),
        );

        let result = create_exec_command_session(session_id, params.clone()).await;

        match result {
            Ok((session, mut exit_rx)) => {
                // Insert into session map.
                let output_receiver = session.output_receiver();
                self.sessions.lock().await.insert(session_id, session);

                // Collect output until either timeout expires or process exits.
                // Cap by assuming 4 bytes per token (TODO: use a real tokenizer).
                let cap_bytes_u64 = params.max_output_tokens.saturating_mul(4);
                let cap_bytes: usize = cap_bytes_u64.min(usize::MAX as u64) as usize;
                let cap_hint = cap_bytes.clamp(1024, 8192);
                let mut collected: Vec<u8> = Vec::with_capacity(cap_hint);

                let deadline = Instant::now() + Duration::from_millis(params.yield_time_ms);
                let mut exit_code: Option<i32> = None;

                loop {
                    if Instant::now() >= deadline {
                        break;
                    }
                    let remaining = deadline.saturating_duration_since(Instant::now());
                    tokio::select! {
                        biased;
                        exit = &mut exit_rx => {
                            exit_code = exit.ok();
                            // Small grace period to pull remaining buffered output
                            let grace_deadline = Instant::now() + Duration::from_millis(25);
                            while Instant::now() < grace_deadline {
                                let recv_next = async {
                                    let mut rx = output_receiver.lock().await;
                                    rx.recv().await
                                };
                                if let Ok(Some(chunk)) = timeout(Duration::from_millis(1), recv_next).await {
                                    let available = cap_bytes.saturating_sub(collected.len());
                                    if available == 0 { break; }
                                    let take = available.min(chunk.len());
                                    collected.extend_from_slice(&chunk[..take]);
                                } else {
                                    break;
                                }
                            }
                            break;
                        }
                        chunk = timeout(remaining, async {
                            let mut rx = output_receiver.lock().await;
                            rx.recv().await
                        }) => {
                            match chunk {
                                Ok(Some(chunk)) => {
                                    let available = cap_bytes.saturating_sub(collected.len());
                                    if available == 0 { /* keep draining, but don't store */ }
                                    else {
                                        let take = available.min(chunk.len());
                                        collected.extend_from_slice(&chunk[..take]);
                                    }
                                }
                                Ok(None) => { break; }
                                Err(_) => { break; }
                            }
                        }
                    }
                }

                let text = String::from_utf8_lossy(&collected).to_string();
                let mut structured = json!({ "sessionId": session_id });
                if let Some(code) = exit_code {
                    structured["exitCode"] = json!(code);
                }
                let result = CallToolResult {
                    content: vec![ContentBlock::TextContent(TextContent {
                        r#type: "text".to_string(),
                        text,
                        annotations: None,
                    })],
                    is_error: None,
                    structured_content: Some(structured),
                };
                outgoing.send_response(request_id, result).await;
            }
            Err(err) => {
                outgoing
                    .send_error(
                        request_id,
                        JSONRPCErrorError {
                            code: error_code::INTERNAL_ERROR_CODE,
                            message: format!("failed to start exec session: {err}"),
                            data: None,
                        },
                    )
                    .await;
            }
        }
    }

    /// Write characters to a session's stdin and collect combined output for up to `yield_time_ms`.
    pub(crate) async fn handle_write_stdin_request(
        &self,
        request_id: RequestId,
        params: WriteStdinParams,
        outgoing: Arc<OutgoingMessageSender>,
    ) {
        let WriteStdinParams {
            session_id,
            chars,
            yield_time_ms,
            max_output_tokens,
        } = params;

        // Grab handles without holding the sessions lock across await points.
        let (writer_tx, output_rx) = {
            let sessions = self.sessions.lock().await;
            match sessions.get(&session_id) {
                Some(session) => (session.writer_sender(), session.output_receiver()),
                None => {
                    outgoing
                        .send_error(
                            request_id,
                            JSONRPCErrorError {
                                code: error_code::INVALID_REQUEST_ERROR_CODE,
                                message: format!("unknown session id {}", session_id.0),
                                data: None,
                            },
                        )
                        .await;
                    return;
                }
            }
        };

        // Write stdin if provided.
        if !chars.is_empty() && writer_tx.send(chars.into_bytes()).await.is_err() {
            outgoing
                .send_error(
                    request_id,
                    JSONRPCErrorError {
                        code: error_code::INTERNAL_ERROR_CODE,
                        message: "failed to write to stdin".to_string(),
                        data: None,
                    },
                )
                .await;
            return;
        }

        // Collect output up to yield_time_ms, truncating to max_output_tokens bytes.
        let mut collected: Vec<u8> = Vec::with_capacity(4096);
        let deadline = Instant::now() + Duration::from_millis(yield_time_ms);
        loop {
            let now = Instant::now();
            if now >= deadline {
                break;
            }
            let remaining = deadline - now;
            match timeout(remaining, output_rx.lock().await.recv()).await {
                Ok(Some(chunk)) => {
                    // Respect token/byte limit; keep draining but drop once full.
                    let available =
                        max_output_tokens.saturating_sub(collected.len() as u64) as usize;
                    if available > 0 {
                        let take = available.min(chunk.len());
                        collected.extend_from_slice(&chunk[..take]);
                    }
                    // Continue loop to drain further within time.
                }
                Ok(None) => break, // channel closed
                Err(_) => break,   // timeout
            }
        }

        // Return text output as a CallToolResult
        let text = String::from_utf8_lossy(&collected).to_string();
        let result = CallToolResult {
            content: vec![ContentBlock::TextContent(TextContent {
                r#type: "text".to_string(),
                text,
                annotations: None,
            })],
            is_error: None,
            structured_content: None,
        };
        outgoing.send_response(request_id, result).await;
    }
}

/// Spawn PTY and child process per spawn_exec_command_session logic.
async fn create_exec_command_session(
    session_id: SessionId,
    params: ExecCommandParams,
) -> anyhow::Result<(ExecCommandSession, oneshot::Receiver<i32>)> {
    let ExecCommandParams {
        cmd,
        yield_time_ms: _,
        max_output_tokens: _,
        shell,
        login,
    } = params;

    // Use the native pty implementation for the system
    let pty_system = native_pty_system();

    // Create a new pty
    let pair = pty_system.openpty(PtySize {
        rows: 24,
        cols: 80,
        pixel_width: 0,
        pixel_height: 0,
    })?;

    // Spawn a shell into the pty
    let mut command_builder = CommandBuilder::new(shell);
    let shell_mode_opt = if login { "-lc" } else { "-c" };
    command_builder.arg(shell_mode_opt);
    command_builder.arg(cmd);

    let mut child = pair.slave.spawn_command(command_builder)?;

    // Channel to forward write requests to the PTY writer.
    let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
    // Channel for streaming PTY output to readers.
    let (output_tx, output_rx) = mpsc::channel::<Vec<u8>>(256);

    // Reader task: drain PTY and forward chunks to output channel.
    let mut reader = pair.master.try_clone_reader()?;
    let output_tx_clone = output_tx.clone();
    tokio::task::spawn_blocking(move || {
        let mut buf = [0u8; 8192];
        loop {
            match reader.read(&mut buf) {
                Ok(0) => break, // EOF
                Ok(n) => {
                    // Forward; block if receiver is slow to avoid dropping output.
                    let _ = output_tx_clone.blocking_send(buf[..n].to_vec());
                }
                Err(_) => break,
            }
        }
    });

    // Writer task: apply stdin writes to the PTY writer.
    let writer = pair.master.take_writer()?;
    let writer = Arc::new(StdMutex::new(writer));
    tokio::spawn({
        let writer = writer.clone();
        async move {
            while let Some(bytes) = writer_rx.recv().await {
                let writer = writer.clone();
                // Perform blocking write on a blocking thread.
                let _ = tokio::task::spawn_blocking(move || {
                    if let Ok(mut guard) = writer.lock() {
                        use std::io::Write;
                        let _ = guard.write_all(&bytes);
                        let _ = guard.flush();
                    }
                })
                .await;
            }
        }
    });

    // Keep the child alive until it exits, then signal exit code.
    let (exit_tx, exit_rx) = oneshot::channel::<i32>();
    tokio::task::spawn_blocking(move || {
        let code = match child.wait() {
            Ok(status) => status.exit_code() as i32,
            Err(_) => -1,
        };
        let _ = exit_tx.send(code);
    });

    // Create and store the session with channels.
    let session = ExecCommandSession::new(session_id, writer_tx, output_rx);
    Ok((session, exit_rx))
}

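`create_exec_command_session` wraps the blocking portable-pty plumbing (open a PTY, spawn `<shell> -lc <cmd>`, clone a reader off the master, wait for the child) in channels and spawn_blocking tasks. A stripped-down, fully blocking sketch of that same flow, without the channel layer, using only the portable-pty calls already present in the file (the command string is made up):

```rust
use std::io::Read;

use portable_pty::{native_pty_system, CommandBuilder, PtySize};

fn main() -> anyhow::Result<()> {
    let pty_system = native_pty_system();
    let pair = pty_system.openpty(PtySize {
        rows: 24,
        cols: 80,
        pixel_width: 0,
        pixel_height: 0,
    })?;

    // Spawn a login shell running a one-off command, as the session manager does.
    let mut cmd = CommandBuilder::new("/bin/bash");
    cmd.arg("-lc");
    cmd.arg("echo hello from the pty");
    let mut child = pair.slave.spawn_command(cmd)?;
    drop(pair.slave); // close our slave handle so reads end once the child exits

    // Drain the master side until EOF or error, mirroring the reader task above.
    let mut reader = pair.master.try_clone_reader()?;
    let mut raw = Vec::new();
    let mut buf = [0u8; 8192];
    loop {
        match reader.read(&mut buf) {
            Ok(0) => break,                              // EOF
            Ok(n) => raw.extend_from_slice(&buf[..n]),   // echoed through the PTY (CRLF endings)
            Err(_) => break,                             // e.g. once the child side is gone
        }
    }

    let status = child.wait()?;
    println!("exit code {}: {}", status.exit_code(), String::from_utf8_lossy(&raw));
    Ok(())
}
```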