mirror of
https://github.com/openai/codex.git
synced 2026-04-28 08:34:54 +00:00
60 KiB
PR #2574: feat: StreamableShell with exec_command and write_stdin tools
- URL: https://github.com/openai/codex/pull/2574
- Author: bolinfest
- Created: 2025-08-22 02:53:20 UTC
- Updated: 2025-08-23 01:11:02 UTC
- Changes: +1096/-2, Files changed: 12, Commits: 1
Description
This introduces a complementary set of tools, exec_command and write_stdin, which are designed to facilitate working with long-running processes in a token-efficient manner.
To test:
codex-rs$ just codex --cd .. --config experimental_use_exec_command_tool=true
However, the above alone is unlikely to convince the model to use the new tools, because of this section of the base instructions:
e4c275d615/codex-rs/core/prompt.md (L266-L271)
So you probably need to add:
--config experimental_instructions_file=/some/other_instructions.md
where /some/other_instructions.md suggests using exec_command and write_stdin instead of shell.
Full Diff
diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock
index 94702a83ed..dbccbd863e 100644
--- a/codex-rs/Cargo.lock
+++ b/codex-rs/Cargo.lock
@@ -731,6 +731,7 @@ dependencies = [
"mime_guess",
"openssl-sys",
"os_info",
+ "portable-pty",
"predicates",
"pretty_assertions",
"rand 0.9.2",
@@ -1479,6 +1480,12 @@ version = "0.15.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
+[[package]]
+name = "downcast-rs"
+version = "1.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2"
+
[[package]]
name = "dupe"
version = "0.9.1"
@@ -1724,6 +1731,17 @@ dependencies = [
"simd-adler32",
]
+[[package]]
+name = "filedescriptor"
+version = "0.8.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e40758ed24c9b2eeb76c35fb0aebc66c626084edd827e07e1552279814c6682d"
+dependencies = [
+ "libc",
+ "thiserror 1.0.69",
+ "winapi",
+]
+
[[package]]
name = "fixedbitset"
version = "0.4.2"
@@ -3439,6 +3457,27 @@ dependencies = [
"portable-atomic",
]
+[[package]]
+name = "portable-pty"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b4a596a2b3d2752d94f51fac2d4a96737b8705dddd311a32b9af47211f08671e"
+dependencies = [
+ "anyhow",
+ "bitflags 1.3.2",
+ "downcast-rs",
+ "filedescriptor",
+ "lazy_static",
+ "libc",
+ "log",
+ "nix",
+ "serial2",
+ "shared_library",
+ "shell-words",
+ "winapi",
+ "winreg",
+]
+
[[package]]
name = "potential_utf"
version = "0.1.2"
@@ -4366,6 +4405,17 @@ dependencies = [
"syn 2.0.104",
]
+[[package]]
+name = "serial2"
+version = "0.2.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "26e1e5956803a69ddd72ce2de337b577898801528749565def03515f82bad5bb"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "winapi",
+]
+
[[package]]
name = "sha1"
version = "0.10.6"
@@ -4397,6 +4447,22 @@ dependencies = [
"lazy_static",
]
+[[package]]
+name = "shared_library"
+version = "0.1.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5a9e7e0f2bfae24d8a5b5a66c5b257a83c7412304311512a0c054cd5e619da11"
+dependencies = [
+ "lazy_static",
+ "libc",
+]
+
+[[package]]
+name = "shell-words"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde"
+
[[package]]
name = "shlex"
version = "1.3.0"
@@ -6176,6 +6242,15 @@ dependencies = [
"memchr",
]
+[[package]]
+name = "winreg"
+version = "0.10.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d"
+dependencies = [
+ "winapi",
+]
+
[[package]]
name = "winsafe"
version = "0.0.19"
diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml
index 56815ba03c..2f2fa7cbad 100644
--- a/codex-rs/core/Cargo.toml
+++ b/codex-rs/core/Cargo.toml
@@ -28,6 +28,7 @@ libc = "0.2.175"
mcp-types = { path = "../mcp-types" }
mime_guess = "2.0"
os_info = "3.12.0"
+portable-pty = "0.9.0"
rand = "0.9"
regex-lite = "0.1.6"
reqwest = { version = "0.12", features = ["json", "stream"] }
diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs
index 79b73a3335..e7be2ba997 100644
--- a/codex-rs/core/src/codex.rs
+++ b/codex-rs/core/src/codex.rs
@@ -52,6 +52,11 @@ use crate::exec::SandboxType;
use crate::exec::StdoutStream;
use crate::exec::StreamOutput;
use crate::exec::process_exec_tool_call;
+use crate::exec_command::EXEC_COMMAND_TOOL_NAME;
+use crate::exec_command::ExecCommandParams;
+use crate::exec_command::SESSION_MANAGER;
+use crate::exec_command::WRITE_STDIN_TOOL_NAME;
+use crate::exec_command::WriteStdinParams;
use crate::exec_env::create_env;
use crate::mcp_connection_manager::McpConnectionManager;
use crate::mcp_tool_call::handle_mcp_tool_call;
@@ -488,6 +493,7 @@ impl Session {
sandbox_policy.clone(),
config.include_plan_tool,
config.include_apply_patch_tool,
+ config.use_experimental_streamable_shell_tool,
),
user_instructions,
base_instructions,
@@ -1069,6 +1075,7 @@ async fn submission_loop(
new_sandbox_policy.clone(),
config.include_plan_tool,
config.include_apply_patch_tool,
+ config.use_experimental_streamable_shell_tool,
);
let new_turn_context = TurnContext {
@@ -1147,6 +1154,7 @@ async fn submission_loop(
sandbox_policy.clone(),
config.include_plan_tool,
config.include_apply_patch_tool,
+ config.use_experimental_streamable_shell_tool,
),
user_instructions: turn_context.user_instructions.clone(),
base_instructions: turn_context.base_instructions.clone(),
@@ -2037,6 +2045,52 @@ async fn handle_function_call(
.await
}
"update_plan" => handle_update_plan(sess, arguments, sub_id, call_id).await,
+ EXEC_COMMAND_TOOL_NAME => {
+ // TODO(mbolin): Sandbox check.
+ let exec_params = match serde_json::from_str::<ExecCommandParams>(&arguments) {
+ Ok(params) => params,
+ Err(e) => {
+ return ResponseInputItem::FunctionCallOutput {
+ call_id,
+ output: FunctionCallOutputPayload {
+ content: format!("failed to parse function arguments: {e}"),
+ success: Some(false),
+ },
+ };
+ }
+ };
+ let result = SESSION_MANAGER
+ .handle_exec_command_request(exec_params)
+ .await;
+ let function_call_output = crate::exec_command::result_into_payload(result);
+ ResponseInputItem::FunctionCallOutput {
+ call_id,
+ output: function_call_output,
+ }
+ }
+ WRITE_STDIN_TOOL_NAME => {
+ let write_stdin_params = match serde_json::from_str::<WriteStdinParams>(&arguments) {
+ Ok(params) => params,
+ Err(e) => {
+ return ResponseInputItem::FunctionCallOutput {
+ call_id,
+ output: FunctionCallOutputPayload {
+ content: format!("failed to parse function arguments: {e}"),
+ success: Some(false),
+ },
+ };
+ }
+ };
+ let result = SESSION_MANAGER
+ .handle_write_stdin_request(write_stdin_params)
+ .await;
+ let function_call_output: FunctionCallOutputPayload =
+ crate::exec_command::result_into_payload(result);
+ ResponseInputItem::FunctionCallOutput {
+ call_id,
+ output: function_call_output,
+ }
+ }
_ => {
match sess.mcp_connection_manager.parse_tool_name(&name) {
Some((server, tool_name)) => {
diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs
index 67a54eb1d7..fbf0387a01 100644
--- a/codex-rs/core/src/config.rs
+++ b/codex-rs/core/src/config.rs
@@ -174,6 +174,8 @@ pub struct Config {
/// If set to `true`, the API key will be signed with the `originator` header.
pub preferred_auth_method: AuthMode,
+
+ pub use_experimental_streamable_shell_tool: bool,
}
impl Config {
@@ -469,6 +471,8 @@ pub struct ConfigToml {
/// Experimental path to a file whose contents replace the built-in BASE_INSTRUCTIONS.
pub experimental_instructions_file: Option<PathBuf>,
+ pub experimental_use_exec_command_tool: Option<bool>,
+
/// The value for the `originator` header included with Responses API requests.
pub responses_originator_header_internal_override: Option<String>,
@@ -758,6 +762,9 @@ impl Config {
include_apply_patch_tool: include_apply_patch_tool.unwrap_or(false),
responses_originator_header,
preferred_auth_method: cfg.preferred_auth_method.unwrap_or(AuthMode::ChatGPT),
+ use_experimental_streamable_shell_tool: cfg
+ .experimental_use_exec_command_tool
+ .unwrap_or(false),
};
Ok(config)
}
@@ -1124,6 +1131,7 @@ disable_response_storage = true
include_apply_patch_tool: false,
responses_originator_header: "codex_cli_rs".to_string(),
preferred_auth_method: AuthMode::ChatGPT,
+ use_experimental_streamable_shell_tool: false,
},
o3_profile_config
);
@@ -1178,6 +1186,7 @@ disable_response_storage = true
include_apply_patch_tool: false,
responses_originator_header: "codex_cli_rs".to_string(),
preferred_auth_method: AuthMode::ChatGPT,
+ use_experimental_streamable_shell_tool: false,
};
assert_eq!(expected_gpt3_profile_config, gpt3_profile_config);
@@ -1247,6 +1256,7 @@ disable_response_storage = true
include_apply_patch_tool: false,
responses_originator_header: "codex_cli_rs".to_string(),
preferred_auth_method: AuthMode::ChatGPT,
+ use_experimental_streamable_shell_tool: false,
};
assert_eq!(expected_zdr_profile_config, zdr_profile_config);
diff --git a/codex-rs/core/src/exec_command/exec_command_params.rs b/codex-rs/core/src/exec_command/exec_command_params.rs
new file mode 100644
index 0000000000..11a3fd4596
--- /dev/null
+++ b/codex-rs/core/src/exec_command/exec_command_params.rs
@@ -0,0 +1,57 @@
+use serde::Deserialize;
+use serde::Serialize;
+
+use crate::exec_command::session_id::SessionId;
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct ExecCommandParams {
+ pub(crate) cmd: String,
+
+ #[serde(default = "default_yield_time")]
+ pub(crate) yield_time_ms: u64,
+
+ #[serde(default = "max_output_tokens")]
+ pub(crate) max_output_tokens: u64,
+
+ #[serde(default = "default_shell")]
+ pub(crate) shell: String,
+
+ #[serde(default = "default_login")]
+ pub(crate) login: bool,
+}
+
+fn default_yield_time() -> u64 {
+ 10_000
+}
+
+fn max_output_tokens() -> u64 {
+ 10_000
+}
+
+fn default_login() -> bool {
+ true
+}
+
+fn default_shell() -> String {
+ "/bin/bash".to_string()
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct WriteStdinParams {
+ pub(crate) session_id: SessionId,
+ pub(crate) chars: String,
+
+ #[serde(default = "write_stdin_default_yield_time_ms")]
+ pub(crate) yield_time_ms: u64,
+
+ #[serde(default = "write_stdin_default_max_output_tokens")]
+ pub(crate) max_output_tokens: u64,
+}
+
+fn write_stdin_default_yield_time_ms() -> u64 {
+ 250
+}
+
+fn write_stdin_default_max_output_tokens() -> u64 {
+ 10_000
+}
diff --git a/codex-rs/core/src/exec_command/exec_command_session.rs b/codex-rs/core/src/exec_command/exec_command_session.rs
new file mode 100644
index 0000000000..7503150c9a
--- /dev/null
+++ b/codex-rs/core/src/exec_command/exec_command_session.rs
@@ -0,0 +1,83 @@
+use std::sync::Mutex as StdMutex;
+
+use tokio::sync::broadcast;
+use tokio::sync::mpsc;
+use tokio::task::JoinHandle;
+
+#[derive(Debug)]
+pub(crate) struct ExecCommandSession {
+ /// Queue for writing bytes to the process stdin (PTY master write side).
+ writer_tx: mpsc::Sender<Vec<u8>>,
+ /// Broadcast stream of output chunks read from the PTY. New subscribers
+ /// receive only chunks emitted after they subscribe.
+ output_tx: broadcast::Sender<Vec<u8>>,
+
+ /// Child killer handle for termination on drop (can signal independently
+ /// of a thread blocked in `.wait()`).
+ killer: StdMutex<Option<Box<dyn portable_pty::ChildKiller + Send + Sync>>>,
+
+ /// JoinHandle for the blocking PTY reader task.
+ reader_handle: StdMutex<Option<JoinHandle<()>>>,
+
+ /// JoinHandle for the stdin writer task.
+ writer_handle: StdMutex<Option<JoinHandle<()>>>,
+
+ /// JoinHandle for the child wait task.
+ wait_handle: StdMutex<Option<JoinHandle<()>>>,
+}
+
+impl ExecCommandSession {
+ pub(crate) fn new(
+ writer_tx: mpsc::Sender<Vec<u8>>,
+ output_tx: broadcast::Sender<Vec<u8>>,
+ killer: Box<dyn portable_pty::ChildKiller + Send + Sync>,
+ reader_handle: JoinHandle<()>,
+ writer_handle: JoinHandle<()>,
+ wait_handle: JoinHandle<()>,
+ ) -> Self {
+ Self {
+ writer_tx,
+ output_tx,
+ killer: StdMutex::new(Some(killer)),
+ reader_handle: StdMutex::new(Some(reader_handle)),
+ writer_handle: StdMutex::new(Some(writer_handle)),
+ wait_handle: StdMutex::new(Some(wait_handle)),
+ }
+ }
+
+ pub(crate) fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
+ self.writer_tx.clone()
+ }
+
+ pub(crate) fn output_receiver(&self) -> broadcast::Receiver<Vec<u8>> {
+ self.output_tx.subscribe()
+ }
+}
+
+impl Drop for ExecCommandSession {
+ fn drop(&mut self) {
+ // Best-effort: terminate child first so blocking tasks can complete.
+ if let Ok(mut killer_opt) = self.killer.lock()
+ && let Some(mut killer) = killer_opt.take()
+ {
+ let _ = killer.kill();
+ }
+
+ // Abort background tasks; they may already have exited after kill.
+ if let Ok(mut h) = self.reader_handle.lock()
+ && let Some(handle) = h.take()
+ {
+ handle.abort();
+ }
+ if let Ok(mut h) = self.writer_handle.lock()
+ && let Some(handle) = h.take()
+ {
+ handle.abort();
+ }
+ if let Ok(mut h) = self.wait_handle.lock()
+ && let Some(handle) = h.take()
+ {
+ handle.abort();
+ }
+ }
+}
diff --git a/codex-rs/core/src/exec_command/mod.rs b/codex-rs/core/src/exec_command/mod.rs
new file mode 100644
index 0000000000..2fd88d4ec5
--- /dev/null
+++ b/codex-rs/core/src/exec_command/mod.rs
@@ -0,0 +1,14 @@
+mod exec_command_params;
+mod exec_command_session;
+mod responses_api;
+mod session_id;
+mod session_manager;
+
+pub use exec_command_params::ExecCommandParams;
+pub use exec_command_params::WriteStdinParams;
+pub use responses_api::EXEC_COMMAND_TOOL_NAME;
+pub use responses_api::WRITE_STDIN_TOOL_NAME;
+pub use responses_api::create_exec_command_tool_for_responses_api;
+pub use responses_api::create_write_stdin_tool_for_responses_api;
+pub use session_manager::SESSION_MANAGER;
+pub use session_manager::result_into_payload;
diff --git a/codex-rs/core/src/exec_command/responses_api.rs b/codex-rs/core/src/exec_command/responses_api.rs
new file mode 100644
index 0000000000..70b90dd425
--- /dev/null
+++ b/codex-rs/core/src/exec_command/responses_api.rs
@@ -0,0 +1,98 @@
+use std::collections::BTreeMap;
+
+use crate::openai_tools::JsonSchema;
+use crate::openai_tools::ResponsesApiTool;
+
+pub const EXEC_COMMAND_TOOL_NAME: &str = "exec_command";
+pub const WRITE_STDIN_TOOL_NAME: &str = "write_stdin";
+
+pub fn create_exec_command_tool_for_responses_api() -> ResponsesApiTool {
+ let mut properties = BTreeMap::<String, JsonSchema>::new();
+ properties.insert(
+ "cmd".to_string(),
+ JsonSchema::String {
+ description: Some("The shell command to execute.".to_string()),
+ },
+ );
+ properties.insert(
+ "yield_time_ms".to_string(),
+ JsonSchema::Number {
+ description: Some("The maximum time in milliseconds to wait for output.".to_string()),
+ },
+ );
+ properties.insert(
+ "max_output_tokens".to_string(),
+ JsonSchema::Number {
+ description: Some("The maximum number of tokens to output.".to_string()),
+ },
+ );
+ properties.insert(
+ "shell".to_string(),
+ JsonSchema::String {
+ description: Some("The shell to use. Defaults to \"/bin/bash\".".to_string()),
+ },
+ );
+ properties.insert(
+ "login".to_string(),
+ JsonSchema::Boolean {
+ description: Some(
+ "Whether to run the command as a login shell. Defaults to true.".to_string(),
+ ),
+ },
+ );
+
+ ResponsesApiTool {
+ name: EXEC_COMMAND_TOOL_NAME.to_owned(),
+ description: r#"Execute shell commands on the local machine with streaming output."#
+ .to_string(),
+ strict: false,
+ parameters: JsonSchema::Object {
+ properties,
+ required: Some(vec!["cmd".to_string()]),
+ additional_properties: Some(false),
+ },
+ }
+}
+
+pub fn create_write_stdin_tool_for_responses_api() -> ResponsesApiTool {
+ let mut properties = BTreeMap::<String, JsonSchema>::new();
+ properties.insert(
+ "session_id".to_string(),
+ JsonSchema::Number {
+ description: Some("The ID of the exec_command session.".to_string()),
+ },
+ );
+ properties.insert(
+ "chars".to_string(),
+ JsonSchema::String {
+ description: Some("The characters to write to stdin.".to_string()),
+ },
+ );
+ properties.insert(
+ "yield_time_ms".to_string(),
+ JsonSchema::Number {
+ description: Some(
+ "The maximum time in milliseconds to wait for output after writing.".to_string(),
+ ),
+ },
+ );
+ properties.insert(
+ "max_output_tokens".to_string(),
+ JsonSchema::Number {
+ description: Some("The maximum number of tokens to output.".to_string()),
+ },
+ );
+
+ ResponsesApiTool {
+ name: WRITE_STDIN_TOOL_NAME.to_owned(),
+ description: r#"Write characters to an exec session's stdin. Returns all stdout+stderr received within yield_time_ms.
+Can write control characters (\u0003 for Ctrl-C), or an empty string to just poll stdout+stderr."#
+ .to_string(),
+ strict: false,
+ parameters: JsonSchema::Object {
+ properties,
+ required: Some(vec!["session_id".to_string(), "chars".to_string()]),
+ additional_properties: Some(false),
+ },
+ }
+}
diff --git a/codex-rs/core/src/exec_command/session_id.rs b/codex-rs/core/src/exec_command/session_id.rs
new file mode 100644
index 0000000000..c97c5d5440
--- /dev/null
+++ b/codex-rs/core/src/exec_command/session_id.rs
@@ -0,0 +1,5 @@
+use serde::Deserialize;
+use serde::Serialize;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
+pub(crate) struct SessionId(pub u32);
diff --git a/codex-rs/core/src/exec_command/session_manager.rs b/codex-rs/core/src/exec_command/session_manager.rs
new file mode 100644
index 0000000000..213b874bfa
--- /dev/null
+++ b/codex-rs/core/src/exec_command/session_manager.rs
@@ -0,0 +1,677 @@
+use std::collections::HashMap;
+use std::io::ErrorKind;
+use std::io::Read;
+use std::sync::Arc;
+use std::sync::LazyLock;
+use std::sync::Mutex as StdMutex;
+use std::sync::atomic::AtomicU32;
+
+use portable_pty::CommandBuilder;
+use portable_pty::PtySize;
+use portable_pty::native_pty_system;
+use tokio::sync::Mutex;
+use tokio::sync::mpsc;
+use tokio::sync::oneshot;
+use tokio::time::Duration;
+use tokio::time::Instant;
+use tokio::time::timeout;
+
+use crate::exec_command::exec_command_params::ExecCommandParams;
+use crate::exec_command::exec_command_params::WriteStdinParams;
+use crate::exec_command::exec_command_session::ExecCommandSession;
+use crate::exec_command::session_id::SessionId;
+use codex_protocol::models::FunctionCallOutputPayload;
+
+pub static SESSION_MANAGER: LazyLock<SessionManager> = LazyLock::new(SessionManager::default);
+
+#[derive(Debug, Default)]
+pub struct SessionManager {
+ next_session_id: AtomicU32,
+ sessions: Mutex<HashMap<SessionId, ExecCommandSession>>,
+}
+
+#[derive(Debug)]
+pub struct ExecCommandOutput {
+ wall_time: Duration,
+ exit_status: ExitStatus,
+ original_token_count: Option<u64>,
+ output: String,
+}
+
+impl ExecCommandOutput {
+ fn to_text_output(&self) -> String {
+ let wall_time_secs = self.wall_time.as_secs_f32();
+ let termination_status = match self.exit_status {
+ ExitStatus::Exited(code) => format!("Process exited with code {code}"),
+ ExitStatus::Ongoing(session_id) => {
+ format!("Process running with session ID {}", session_id.0)
+ }
+ };
+ let truncation_status = match self.original_token_count {
+ Some(tokens) => {
+ format!("\nWarning: truncated output (original token count: {tokens})")
+ }
+ None => "".to_string(),
+ };
+ format!(
+ r#"Wall time: {wall_time_secs:.3} seconds
+{termination_status}{truncation_status}
+Output:
+{output}"#,
+ output = self.output
+ )
+ }
+}
+
+#[derive(Debug)]
+pub enum ExitStatus {
+ Exited(i32),
+ Ongoing(SessionId),
+}
+
+pub fn result_into_payload(result: Result<ExecCommandOutput, String>) -> FunctionCallOutputPayload {
+ match result {
+ Ok(output) => FunctionCallOutputPayload {
+ content: output.to_text_output(),
+ success: Some(true),
+ },
+ Err(err) => FunctionCallOutputPayload {
+ content: err,
+ success: Some(false),
+ },
+ }
+}
+
+impl SessionManager {
+ /// Processes the request and is required to send a response via `outgoing`.
+ pub async fn handle_exec_command_request(
+ &self,
+ params: ExecCommandParams,
+ ) -> Result<ExecCommandOutput, String> {
+ // Allocate a session id.
+ let session_id = SessionId(
+ self.next_session_id
+ .fetch_add(1, std::sync::atomic::Ordering::SeqCst),
+ );
+
+ let (session, mut exit_rx) =
+ create_exec_command_session(params.clone())
+ .await
+ .map_err(|err| {
+ format!(
+ "failed to create exec command session for session id {}: {err}",
+ session_id.0
+ )
+ })?;
+
+ // Insert into session map.
+ let mut output_rx = session.output_receiver();
+ self.sessions.lock().await.insert(session_id, session);
+
+ // Collect output until either timeout expires or process exits.
+ // Do not cap during collection; truncate at the end if needed.
+ // Use a modest initial capacity to avoid large preallocation.
+ let cap_bytes_u64 = params.max_output_tokens.saturating_mul(4);
+ let cap_bytes: usize = cap_bytes_u64.min(usize::MAX as u64) as usize;
+ let mut collected: Vec<u8> = Vec::with_capacity(4096);
+
+ let start_time = Instant::now();
+ let deadline = start_time + Duration::from_millis(params.yield_time_ms);
+ let mut exit_code: Option<i32> = None;
+
+ loop {
+ if Instant::now() >= deadline {
+ break;
+ }
+ let remaining = deadline.saturating_duration_since(Instant::now());
+ tokio::select! {
+ biased;
+ exit = &mut exit_rx => {
+ exit_code = exit.ok();
+ // Small grace period to pull remaining buffered output
+ let grace_deadline = Instant::now() + Duration::from_millis(25);
+ while Instant::now() < grace_deadline {
+ match timeout(Duration::from_millis(1), output_rx.recv()).await {
+ Ok(Ok(chunk)) => {
+ collected.extend_from_slice(&chunk);
+ }
+ Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => {
+ // Skip missed messages; keep trying within grace period.
+ continue;
+ }
+ Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => break,
+ Err(_) => break,
+ }
+ }
+ break;
+ }
+ chunk = timeout(remaining, output_rx.recv()) => {
+ match chunk {
+ Ok(Ok(chunk)) => {
+ collected.extend_from_slice(&chunk);
+ }
+ Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => {
+ // Skip missed messages; continue collecting fresh output.
+ }
+ Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => { break; }
+ Err(_) => { break; }
+ }
+ }
+ }
+ }
+
+ let output = String::from_utf8_lossy(&collected).to_string();
+
+ let exit_status = if let Some(code) = exit_code {
+ ExitStatus::Exited(code)
+ } else {
+ ExitStatus::Ongoing(session_id)
+ };
+
+ // If output exceeds cap, truncate the middle and record original token estimate.
+ let (output, original_token_count) = truncate_middle(&output, cap_bytes);
+ Ok(ExecCommandOutput {
+ wall_time: Instant::now().duration_since(start_time),
+ exit_status,
+ original_token_count,
+ output,
+ })
+ }
+
+ /// Write characters to a session's stdin and collect combined output for up to `yield_time_ms`.
+ pub async fn handle_write_stdin_request(
+ &self,
+ params: WriteStdinParams,
+ ) -> Result<ExecCommandOutput, String> {
+ let WriteStdinParams {
+ session_id,
+ chars,
+ yield_time_ms,
+ max_output_tokens,
+ } = params;
+
+ // Grab handles without holding the sessions lock across await points.
+ let (writer_tx, mut output_rx) = {
+ let sessions = self.sessions.lock().await;
+ match sessions.get(&session_id) {
+ Some(session) => (session.writer_sender(), session.output_receiver()),
+ None => {
+ return Err(format!("unknown session id {}", session_id.0));
+ }
+ }
+ };
+
+ // Write stdin if provided.
+ if !chars.is_empty() && writer_tx.send(chars.into_bytes()).await.is_err() {
+ return Err("failed to write to stdin".to_string());
+ }
+
+ // Collect output up to yield_time_ms, truncating to max_output_tokens bytes.
+ let mut collected: Vec<u8> = Vec::with_capacity(4096);
+ let start_time = Instant::now();
+ let deadline = start_time + Duration::from_millis(yield_time_ms);
+ loop {
+ let now = Instant::now();
+ if now >= deadline {
+ break;
+ }
+ let remaining = deadline - now;
+ match timeout(remaining, output_rx.recv()).await {
+ Ok(Ok(chunk)) => {
+ // Collect all output within the time budget; truncate at the end.
+ collected.extend_from_slice(&chunk);
+ }
+ Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => {
+ // Skip missed messages; continue collecting fresh output.
+ }
+ Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => break,
+ Err(_) => break, // timeout
+ }
+ }
+
+ // Return structured output, truncating middle if over cap.
+ let output = String::from_utf8_lossy(&collected).to_string();
+ let cap_bytes_u64 = max_output_tokens.saturating_mul(4);
+ let cap_bytes: usize = cap_bytes_u64.min(usize::MAX as u64) as usize;
+ let (output, original_token_count) = truncate_middle(&output, cap_bytes);
+ Ok(ExecCommandOutput {
+ wall_time: Instant::now().duration_since(start_time),
+ exit_status: ExitStatus::Ongoing(session_id),
+ original_token_count,
+ output,
+ })
+ }
+}
+
+/// Spawn PTY and child process per spawn_exec_command_session logic.
+async fn create_exec_command_session(
+ params: ExecCommandParams,
+) -> anyhow::Result<(ExecCommandSession, oneshot::Receiver<i32>)> {
+ let ExecCommandParams {
+ cmd,
+ yield_time_ms: _,
+ max_output_tokens: _,
+ shell,
+ login,
+ } = params;
+
+ // Use the native pty implementation for the system
+ let pty_system = native_pty_system();
+
+ // Create a new pty
+ let pair = pty_system.openpty(PtySize {
+ rows: 24,
+ cols: 80,
+ pixel_width: 0,
+ pixel_height: 0,
+ })?;
+
+ // Spawn a shell into the pty
+ let mut command_builder = CommandBuilder::new(shell);
+ let shell_mode_opt = if login { "-lc" } else { "-c" };
+ command_builder.arg(shell_mode_opt);
+ command_builder.arg(cmd);
+
+ let mut child = pair.slave.spawn_command(command_builder)?;
+ // Obtain a killer that can signal the process independently of `.wait()`.
+ let killer = child.clone_killer();
+
+ // Channel to forward write requests to the PTY writer.
+ let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
+ // Broadcast for streaming PTY output to readers: subscribers receive from subscription time.
+ let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256);
+
+ // Reader task: drain PTY and forward chunks to output channel.
+ let mut reader = pair.master.try_clone_reader()?;
+ let output_tx_clone = output_tx.clone();
+ let reader_handle = tokio::task::spawn_blocking(move || {
+ let mut buf = [0u8; 8192];
+ loop {
+ match reader.read(&mut buf) {
+ Ok(0) => break, // EOF
+ Ok(n) => {
+ // Forward to broadcast; best-effort if there are subscribers.
+ let _ = output_tx_clone.send(buf[..n].to_vec());
+ }
+ Err(ref e) if e.kind() == ErrorKind::Interrupted => {
+ // Retry on EINTR
+ continue;
+ }
+ Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
+ // We're in a blocking thread; back off briefly and retry.
+ std::thread::sleep(Duration::from_millis(5));
+ continue;
+ }
+ Err(_) => break,
+ }
+ }
+ });
+
+ // Writer task: apply stdin writes to the PTY writer.
+ let writer = pair.master.take_writer()?;
+ let writer = Arc::new(StdMutex::new(writer));
+ let writer_handle = tokio::spawn({
+ let writer = writer.clone();
+ async move {
+ while let Some(bytes) = writer_rx.recv().await {
+ let writer = writer.clone();
+ // Perform blocking write on a blocking thread.
+ let _ = tokio::task::spawn_blocking(move || {
+ if let Ok(mut guard) = writer.lock() {
+ use std::io::Write;
+ let _ = guard.write_all(&bytes);
+ let _ = guard.flush();
+ }
+ })
+ .await;
+ }
+ }
+ });
+
+ // Keep the child alive until it exits, then signal exit code.
+ let (exit_tx, exit_rx) = oneshot::channel::<i32>();
+ let wait_handle = tokio::task::spawn_blocking(move || {
+ let code = match child.wait() {
+ Ok(status) => status.exit_code() as i32,
+ Err(_) => -1,
+ };
+ let _ = exit_tx.send(code);
+ });
+
+ // Create and store the session with channels.
+ let session = ExecCommandSession::new(
+ writer_tx,
+ output_tx,
+ killer,
+ reader_handle,
+ writer_handle,
+ wait_handle,
+ );
+ Ok((session, exit_rx))
+}
+
+/// Truncate the middle of a UTF-8 string to at most `max_bytes` bytes,
+/// preserving the beginning and the end. Returns the possibly truncated
+/// string and `Some(original_token_count)` (estimated at 4 bytes/token)
+/// if truncation occurred; otherwise returns the original string and `None`.
+fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>) {
+ // No truncation needed
+ if s.len() <= max_bytes {
+ return (s.to_string(), None);
+ }
+ let est_tokens = (s.len() as u64).div_ceil(4);
+ if max_bytes == 0 {
+ // Cannot keep any content; still return a full marker (never truncated).
+ return (
+ format!("…{} tokens truncated…", est_tokens),
+ Some(est_tokens),
+ );
+ }
+
+ // Helper to truncate a string to a given byte length on a char boundary.
+ fn truncate_on_boundary(input: &str, max_len: usize) -> &str {
+ if input.len() <= max_len {
+ return input;
+ }
+ let mut end = max_len;
+ while end > 0 && !input.is_char_boundary(end) {
+ end -= 1;
+ }
+ &input[..end]
+ }
+
+ // Given a left/right budget, prefer newline boundaries; otherwise fall back
+ // to UTF-8 char boundaries.
+ fn pick_prefix_end(s: &str, left_budget: usize) -> usize {
+ if let Some(head) = s.get(..left_budget)
+ && let Some(i) = head.rfind('\n')
+ {
+ return i + 1; // keep the newline so suffix starts on a fresh line
+ }
+ truncate_on_boundary(s, left_budget).len()
+ }
+
+ fn pick_suffix_start(s: &str, right_budget: usize) -> usize {
+ let start_tail = s.len().saturating_sub(right_budget);
+ if let Some(tail) = s.get(start_tail..)
+ && let Some(i) = tail.find('\n')
+ {
+ return start_tail + i + 1; // start after newline
+ }
+ // Fall back to a char boundary at or after start_tail.
+ let mut idx = start_tail.min(s.len());
+ while idx < s.len() && !s.is_char_boundary(idx) {
+ idx += 1;
+ }
+ idx
+ }
+
+ // Refine marker length and budgets until stable. Marker is never truncated.
+ let mut guess_tokens = est_tokens; // worst-case: everything truncated
+ for _ in 0..4 {
+ let marker = format!("…{} tokens truncated…", guess_tokens);
+ let marker_len = marker.len();
+ let keep_budget = max_bytes.saturating_sub(marker_len);
+ if keep_budget == 0 {
+ // No room for any content within the cap; return a full, untruncated marker
+ // that reflects the entire truncated content.
+ return (
+ format!("…{} tokens truncated…", est_tokens),
+ Some(est_tokens),
+ );
+ }
+
+ let left_budget = keep_budget / 2;
+ let right_budget = keep_budget - left_budget;
+ let prefix_end = pick_prefix_end(s, left_budget);
+ let mut suffix_start = pick_suffix_start(s, right_budget);
+ if suffix_start < prefix_end {
+ suffix_start = prefix_end;
+ }
+ let kept_content_bytes = prefix_end + (s.len() - suffix_start);
+ let truncated_content_bytes = s.len().saturating_sub(kept_content_bytes);
+ let new_tokens = (truncated_content_bytes as u64).div_ceil(4);
+ if new_tokens == guess_tokens {
+ let mut out = String::with_capacity(marker_len + kept_content_bytes + 1);
+ out.push_str(&s[..prefix_end]);
+ out.push_str(&marker);
+ // Place marker on its own line for symmetry when we keep line boundaries.
+ out.push('\n');
+ out.push_str(&s[suffix_start..]);
+ return (out, Some(est_tokens));
+ }
+ guess_tokens = new_tokens;
+ }
+
+ // Fallback: use last guess to build output.
+ let marker = format!("…{} tokens truncated…", guess_tokens);
+ let marker_len = marker.len();
+ let keep_budget = max_bytes.saturating_sub(marker_len);
+ if keep_budget == 0 {
+ return (
+ format!("…{} tokens truncated…", est_tokens),
+ Some(est_tokens),
+ );
+ }
+ let left_budget = keep_budget / 2;
+ let right_budget = keep_budget - left_budget;
+ let prefix_end = pick_prefix_end(s, left_budget);
+ let suffix_start = pick_suffix_start(s, right_budget);
+ let mut out = String::with_capacity(marker_len + prefix_end + (s.len() - suffix_start) + 1);
+ out.push_str(&s[..prefix_end]);
+ out.push_str(&marker);
+ out.push('\n');
+ out.push_str(&s[suffix_start..]);
+ (out, Some(est_tokens))
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::exec_command::session_id::SessionId;
+
+ /// Test that verifies that [`SessionManager::handle_exec_command_request()`]
+ /// and [`SessionManager::handle_write_stdin_request()`] work as expected
+ /// in the presence of a process that never terminates (but produces
+ /// output continuously).
+ #[cfg(unix)]
+ #[allow(clippy::print_stderr)]
+ #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+ async fn session_manager_streams_and_truncates_from_now() {
+ use crate::exec_command::exec_command_params::ExecCommandParams;
+ use crate::exec_command::exec_command_params::WriteStdinParams;
+ use tokio::time::sleep;
+
+ let session_manager = SessionManager::default();
+ // Long-running loop that prints an increasing counter every ~100ms.
+ // Use Python for a portable, reliable sleep across shells/PTYs.
+ let cmd = r#"python3 - <<'PY'
+import sys, time
+count = 0
+while True:
+ print(count)
+ sys.stdout.flush()
+ count += 100
+ time.sleep(0.1)
+PY"#
+ .to_string();
+
+ // Start the session and collect ~3s of output.
+ let params = ExecCommandParams {
+ cmd,
+ yield_time_ms: 3_000,
+ max_output_tokens: 1_000, // large enough to avoid truncation here
+ shell: "/bin/bash".to_string(),
+ login: false,
+ };
+ let initial_output = match session_manager
+ .handle_exec_command_request(params.clone())
+ .await
+ {
+ Ok(v) => v,
+ Err(e) => {
+ // PTY may be restricted in some sandboxes; skip in that case.
+ if e.contains("openpty") || e.contains("Operation not permitted") {
+ eprintln!("skipping test due to restricted PTY: {e}");
+ return;
+ }
+ panic!("exec request failed unexpectedly: {e}");
+ }
+ };
+ eprintln!("initial output: {initial_output:?}");
+
+ // Should be ongoing (we launched a never-ending loop).
+ let session_id = match initial_output.exit_status {
+ ExitStatus::Ongoing(id) => id,
+ _ => panic!("expected ongoing session"),
+ };
+
+ // Parse the numeric lines and get the max observed value in the first window.
+ let first_nums = extract_monotonic_numbers(&initial_output.output);
+ assert!(
+ !first_nums.is_empty(),
+ "expected some output from first window"
+ );
+ let first_max = *first_nums.iter().max().unwrap();
+
+ // Wait ~4s so counters progress while we're not reading.
+ sleep(Duration::from_millis(4_000)).await;
+
+ // Now read ~3s of output "from now" only.
+ // Use a small token cap so truncation occurs and we test middle truncation.
+ let write_params = WriteStdinParams {
+ session_id,
+ chars: String::new(),
+ yield_time_ms: 3_000,
+ max_output_tokens: 16, // 16 tokens ~= 64 bytes -> likely truncation
+ };
+ let second = session_manager
+ .handle_write_stdin_request(write_params)
+ .await
+ .expect("write stdin should succeed");
+
+ // Verify truncation metadata and size bound (cap is tokens*4 bytes).
+ assert!(second.original_token_count.is_some());
+ let cap_bytes = (16u64 * 4) as usize;
+ assert!(second.output.len() <= cap_bytes);
+ // New middle marker should be present.
+ assert!(
+ second.output.contains("tokens truncated") && second.output.contains('…'),
+ "expected truncation marker in output, got: {}",
+ second.output
+ );
+
+ // Minimal freshness check: the earliest number we see in the second window
+ // should be significantly larger than the last from the first window.
+ let second_nums = extract_monotonic_numbers(&second.output);
+ assert!(
+ !second_nums.is_empty(),
+ "expected some numeric output from second window"
+ );
+ let second_min = *second_nums.iter().min().unwrap();
+
+ // We slept 4 seconds (~40 ticks at 100ms/tick, each +100), so expect
+ // an increase of roughly 4000 or more. Allow a generous margin.
+ assert!(
+ second_min >= first_max + 2000,
+ "second_min={second_min} first_max={first_max}",
+ );
+ }
+
+ #[cfg(unix)]
+ fn extract_monotonic_numbers(s: &str) -> Vec<i64> {
+ s.lines()
+ .filter_map(|line| {
+ if !line.is_empty()
+ && line.chars().all(|c| c.is_ascii_digit())
+ && let Ok(n) = line.parse::<i64>()
+ {
+ // Our generator increments by 100; ignore spurious fragments.
+ if n % 100 == 0 {
+ return Some(n);
+ }
+ }
+ None
+ })
+ .collect()
+ }
+
+ #[test]
+ fn to_text_output_exited_no_truncation() {
+ let out = ExecCommandOutput {
+ wall_time: Duration::from_millis(1234),
+ exit_status: ExitStatus::Exited(0),
+ original_token_count: None,
+ output: "hello".to_string(),
+ };
+ let text = out.to_text_output();
+ let expected = r#"Wall time: 1.234 seconds
+Process exited with code 0
+Output:
+hello"#;
+ assert_eq!(expected, text);
+ }
+
+ #[test]
+ fn to_text_output_ongoing_with_truncation() {
+ let out = ExecCommandOutput {
+ wall_time: Duration::from_millis(500),
+ exit_status: ExitStatus::Ongoing(SessionId(42)),
+ original_token_count: Some(1000),
+ output: "abc".to_string(),
+ };
+ let text = out.to_text_output();
+ let expected = r#"Wall time: 0.500 seconds
+Process running with session ID 42
+Warning: truncated output (original token count: 1000)
+Output:
+abc"#;
+ assert_eq!(expected, text);
+ }
+
+ #[test]
+ fn truncate_middle_no_newlines_fallback() {
+ // A long string with no newlines that exceeds the cap.
+ let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
+ let max_bytes = 16; // force truncation
+ let (out, original) = truncate_middle(s, max_bytes);
+ // For very small caps, we return the full, untruncated marker,
+ // even if it exceeds the cap.
+ assert_eq!(out, "…16 tokens truncated…");
+ // Original string length is 62 bytes => ceil(62/4) = 16 tokens.
+ assert_eq!(original, Some(16));
+ }
+
+ #[test]
+ fn truncate_middle_prefers_newline_boundaries() {
+ // Build a multi-line string of 20 numbered lines (each "NNN\n").
+ let mut s = String::new();
+ for i in 1..=20 {
+ s.push_str(&format!("{i:03}\n"));
+ }
+ // Total length: 20 lines * 4 bytes per line = 80 bytes.
+ assert_eq!(s.len(), 80);
+
+ // Choose a cap that forces truncation while leaving room for
+ // a few lines on each side after accounting for the marker.
+ let max_bytes = 64;
+ // Expect exact output: first 4 lines, marker, last 4 lines, and correct token estimate (80/4 = 20).
+ assert_eq!(
+ truncate_middle(&s, max_bytes),
+ (
+ r#"001
+002
+003
+004
+…12 tokens truncated…
+017
+018
+019
+020
+"#
+ .to_string(),
+ Some(20)
+ )
+ );
+ }
+}
diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs
index 6d4699bceb..ae18332087 100644
--- a/codex-rs/core/src/lib.rs
+++ b/codex-rs/core/src/lib.rs
@@ -20,6 +20,7 @@ mod conversation_history;
mod environment_context;
pub mod error;
pub mod exec;
+mod exec_command;
pub mod exec_env;
mod flags;
pub mod git_info;
diff --git a/codex-rs/core/src/openai_tools.rs b/codex-rs/core/src/openai_tools.rs
index bb5e6dacbf..272c901dc2 100644
--- a/codex-rs/core/src/openai_tools.rs
+++ b/codex-rs/core/src/openai_tools.rs
@@ -56,6 +56,7 @@ pub enum ConfigShellToolType {
DefaultShell,
ShellWithRequest { sandbox_policy: SandboxPolicy },
LocalShell,
+ StreamableShell,
}
#[derive(Debug, Clone)]
@@ -72,13 +73,16 @@ impl ToolsConfig {
sandbox_policy: SandboxPolicy,
include_plan_tool: bool,
include_apply_patch_tool: bool,
+ use_streamable_shell_tool: bool,
) -> Self {
- let mut shell_type = if model_family.uses_local_shell_tool {
+ let mut shell_type = if use_streamable_shell_tool {
+ ConfigShellToolType::StreamableShell
+ } else if model_family.uses_local_shell_tool {
ConfigShellToolType::LocalShell
} else {
ConfigShellToolType::DefaultShell
};
- if matches!(approval_policy, AskForApproval::OnRequest) {
+ if matches!(approval_policy, AskForApproval::OnRequest) && !use_streamable_shell_tool {
shell_type = ConfigShellToolType::ShellWithRequest {
sandbox_policy: sandbox_policy.clone(),
}
@@ -492,6 +496,14 @@ pub(crate) fn get_openai_tools(
ConfigShellToolType::LocalShell => {
tools.push(OpenAiTool::LocalShell {});
}
+ ConfigShellToolType::StreamableShell => {
+ tools.push(OpenAiTool::Function(
+ crate::exec_command::create_exec_command_tool_for_responses_api(),
+ ));
+ tools.push(OpenAiTool::Function(
+ crate::exec_command::create_write_stdin_tool_for_responses_api(),
+ ));
+ }
}
if config.plan_tool {
@@ -564,6 +576,7 @@ mod tests {
SandboxPolicy::ReadOnly,
true,
false,
+ /*use_experimental_streamable_shell_tool*/ false,
);
let tools = get_openai_tools(&config, Some(HashMap::new()));
@@ -579,6 +592,7 @@ mod tests {
SandboxPolicy::ReadOnly,
true,
false,
+ /*use_experimental_streamable_shell_tool*/ false,
);
let tools = get_openai_tools(&config, Some(HashMap::new()));
@@ -594,6 +608,7 @@ mod tests {
SandboxPolicy::ReadOnly,
false,
false,
+ /*use_experimental_streamable_shell_tool*/ false,
);
let tools = get_openai_tools(
&config,
@@ -688,6 +703,7 @@ mod tests {
SandboxPolicy::ReadOnly,
false,
false,
+ /*use_experimental_streamable_shell_tool*/ false,
);
let tools = get_openai_tools(
@@ -744,6 +760,7 @@ mod tests {
SandboxPolicy::ReadOnly,
false,
false,
+ /*use_experimental_streamable_shell_tool*/ false,
);
let tools = get_openai_tools(
@@ -795,6 +812,7 @@ mod tests {
SandboxPolicy::ReadOnly,
false,
false,
+ /*use_experimental_streamable_shell_tool*/ false,
);
let tools = get_openai_tools(
@@ -849,6 +867,7 @@ mod tests {
SandboxPolicy::ReadOnly,
false,
false,
+ /*use_experimental_streamable_shell_tool*/ false,
);
let tools = get_openai_tools(
Review Comments
codex-rs/core/src/exec_command/session_manager.rs
- Created: 2025-08-22 23:23:04 UTC | Link: https://github.com/openai/codex/pull/2574#discussion_r2294924169
@@ -0,0 +1,580 @@
+use std::collections::HashMap;
+use std::io::Read;
+use std::sync::Arc;
+use std::sync::LazyLock;
+use std::sync::Mutex as StdMutex;
+use std::sync::atomic::AtomicU32;
+
+use portable_pty::CommandBuilder;
+use portable_pty::PtySize;
+use portable_pty::native_pty_system;
+use tokio::sync::Mutex;
+use tokio::sync::mpsc;
+use tokio::sync::oneshot;
+use tokio::time::Duration;
+use tokio::time::Instant;
+use tokio::time::timeout;
+
+use crate::exec_command::exec_command_params::ExecCommandParams;
+use crate::exec_command::exec_command_params::WriteStdinParams;
+use crate::exec_command::exec_command_session::ExecCommandSession;
+use crate::exec_command::session_id::SessionId;
+use crate::models::FunctionCallOutputPayload;
+
+pub static SESSION_MANAGER: LazyLock<SessionManager> = LazyLock::new(SessionManager::default);
+
+#[derive(Debug, Default)]
+pub struct SessionManager {
+ next_session_id: AtomicU32,
+ sessions: Mutex<HashMap<SessionId, ExecCommandSession>>,
+}
+
+#[derive(Debug)]
+pub struct ExecCommandOutput {
+ wall_time: Duration,
+ exit_status: ExitStatus,
+ original_token_count: Option<u64>,
+ output: String,
+}
+
+impl ExecCommandOutput {
+ fn to_text_output(&self) -> String {
+ let wall_time_secs = self.wall_time.as_secs_f32();
+ let termination_status = match self.exit_status {
+ ExitStatus::Exited(code) => format!("Process exited with code {code}"),
+ ExitStatus::Ongoing(session_id) => {
+ format!("Process running with session ID {}", session_id.0)
+ }
+ };
+ let truncation_status = match self.original_token_count {
+ Some(tokens) => {
+ format!("\nWarning: truncated output (original token count: {tokens})")
+ }
+ None => "".to_string(),
+ };
+ format!(
+ r#"Wall time: {wall_time_secs:.3} seconds
+{termination_status}{truncation_status}
+Output:
+{output}"#,
+ output = self.output
+ )
+ }
+}
+
+#[derive(Debug)]
+pub enum ExitStatus {
+ Exited(i32),
+ Ongoing(SessionId),
+}
+
+impl From<Result<ExecCommandOutput, String>> for FunctionCallOutputPayload {
+ fn from(val: Result<ExecCommandOutput, String>) -> Self {
+ match val {
+ Ok(output) => FunctionCallOutputPayload {
+ content: output.to_text_output(),
+ success: Some(true),
+ },
+ Err(err) => FunctionCallOutputPayload {
+ content: err,
+ success: Some(false),
+ },
+ }
+ }
+}
+
+impl SessionManager {
+ /// Processes the request and is required to send a response via `outgoing`.
+ pub async fn handle_exec_command_request(
+ &self,
+ params: ExecCommandParams,
+ ) -> Result<ExecCommandOutput, String> {
+ // Allocate a session id.
+ let session_id = SessionId(
+ self.next_session_id
+ .fetch_add(1, std::sync::atomic::Ordering::SeqCst),
+ );
+
+ let (session, mut exit_rx) =
+ create_exec_command_session(params.clone())
+ .await
+ .map_err(|err| {
+ format!(
+ "failed to create exec command session for session id {}: {err}",
+ session_id.0
+ )
+ })?;
+
+ // Insert into session map.
+ let mut output_rx = session.output_receiver();
+ self.sessions.lock().await.insert(session_id, session);
+
+ // Collect output until either timeout expires or process exits.
+ // Do not cap during collection; truncate at the end if needed.
+ // Use a modest initial capacity to avoid large preallocation.
+ let cap_bytes_u64 = params.max_output_tokens.saturating_mul(4);
+ let cap_bytes: usize = cap_bytes_u64.min(usize::MAX as u64) as usize;
+ let mut collected: Vec<u8> = Vec::with_capacity(4096);
+
+ let start_time = Instant::now();
+ let deadline = start_time + Duration::from_millis(params.yield_time_ms);
+ let mut exit_code: Option<i32> = None;
+
+ loop {
+ if Instant::now() >= deadline {
+ break;
+ }
+ let remaining = deadline.saturating_duration_since(Instant::now());
+ tokio::select! {
+ biased;
+ exit = &mut exit_rx => {
+ exit_code = exit.ok();
+ // Small grace period to pull remaining buffered output
+ let grace_deadline = Instant::now() + Duration::from_millis(25);
+ while Instant::now() < grace_deadline {
+ match timeout(Duration::from_millis(1), output_rx.recv()).await {
+ Ok(Ok(chunk)) => {
+ collected.extend_from_slice(&chunk);
+ }
+ Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => {
+ // Skip missed messages; keep trying within grace period.
+ continue;
+ }
+ Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => break,
+ Err(_) => break,
+ }
+ }
+ break;
+ }
+ chunk = timeout(remaining, output_rx.recv()) => {
+ match chunk {
+ Ok(Ok(chunk)) => {
+ collected.extend_from_slice(&chunk);
+ }
+ Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => {
+ // Skip missed messages; continue collecting fresh output.
+ }
+ Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => { break; }
+ Err(_) => { break; }
+ }
+ }
+ }
+ }
+
+ let output = String::from_utf8_lossy(&collected).to_string();
+
+ let exit_status = if let Some(code) = exit_code {
+ ExitStatus::Exited(code)
+ } else {
+ ExitStatus::Ongoing(session_id)
+ };
+
+ // If output exceeds cap, truncate the middle and record original token estimate.
+ let (output, original_token_count) = truncate_middle(&output, cap_bytes);
+ Ok(ExecCommandOutput {
+ wall_time: Instant::now().duration_since(start_time),
+ exit_status,
+ original_token_count,
+ output,
+ })
+ }
+
+ /// Write characters to a session's stdin and collect combined output for up to `yield_time_ms`.
+ pub async fn handle_write_stdin_request(
+ &self,
+ params: WriteStdinParams,
+ ) -> Result<ExecCommandOutput, String> {
+ let WriteStdinParams {
+ session_id,
+ chars,
+ yield_time_ms,
+ max_output_tokens,
+ } = params;
+
+ // Grab handles without holding the sessions lock across await points.
+ let (writer_tx, mut output_rx) = {
+ let sessions = self.sessions.lock().await;
+ match sessions.get(&session_id) {
+ Some(session) => (session.writer_sender(), session.output_receiver()),
+ None => {
+ return Err(format!("unknown session id {}", session_id.0));
+ }
+ }
+ };
+
+ // Write stdin if provided.
+ if !chars.is_empty() && writer_tx.send(chars.into_bytes()).await.is_err() {
+ return Err("failed to write to stdin".to_string());
+ }
+
+ // Collect output up to yield_time_ms, truncating to max_output_tokens bytes.
+ let mut collected: Vec<u8> = Vec::with_capacity(4096);
+ let start_time = Instant::now();
+ let deadline = start_time + Duration::from_millis(yield_time_ms);
+ loop {
+ let now = Instant::now();
+ if now >= deadline {
+ break;
+ }
+ let remaining = deadline - now;
+ match timeout(remaining, output_rx.recv()).await {
+ Ok(Ok(chunk)) => {
+ // Collect all output within the time budget; truncate at the end.
+ collected.extend_from_slice(&chunk);
+ }
+ Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => {
+ // Skip missed messages; continue collecting fresh output.
+ }
+ Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => break,
+ Err(_) => break, // timeout
+ }
+ }
+
+ // Return structured output, truncating middle if over cap.
+ let output = String::from_utf8_lossy(&collected).to_string();
+ let cap_bytes_u64 = max_output_tokens.saturating_mul(4);
+ let cap_bytes: usize = cap_bytes_u64.min(usize::MAX as u64) as usize;
+ let (output, original_token_count) = truncate_middle(&output, cap_bytes);
+ Ok(ExecCommandOutput {
+ wall_time: Instant::now().duration_since(start_time),
+ exit_status: ExitStatus::Ongoing(session_id),
+ original_token_count,
+ output,
+ })
+ }
+}
+
+/// Spawn PTY and child process per spawn_exec_command_session logic.
+async fn create_exec_command_session(
+ params: ExecCommandParams,
+) -> anyhow::Result<(ExecCommandSession, oneshot::Receiver<i32>)> {
+ let ExecCommandParams {
+ cmd,
+ yield_time_ms: _,
+ max_output_tokens: _,
+ shell,
+ login,
+ } = params;
+
+ // Use the native pty implementation for the system
+ let pty_system = native_pty_system();
+
+ // Create a new pty
+ let pair = pty_system.openpty(PtySize {
+ rows: 24,
+ cols: 80,
+ pixel_width: 0,
+ pixel_height: 0,
+ })?;
+
+ // Spawn a shell into the pty
+ let mut command_builder = CommandBuilder::new(shell);
+ let shell_mode_opt = if login { "-lc" } else { "-c" };
+ command_builder.arg(shell_mode_opt);
+ command_builder.arg(cmd);
+
+ let mut child = pair.slave.spawn_command(command_builder)?;
+ // Obtain a killer that can signal the process independently of `.wait()`.
+ let killer = child.clone_killer();
+
+ // Channel to forward write requests to the PTY writer.
+ let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
+ // Broadcast for streaming PTY output to readers: subscribers receive from subscription time.
+ let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256);
+
+ // Reader task: drain PTY and forward chunks to output channel.
+ let mut reader = pair.master.try_clone_reader()?;
+ let output_tx_clone = output_tx.clone();
+ let reader_handle = tokio::task::spawn_blocking(move || {
+ let mut buf = [0u8; 8192];
+ loop {
+ match reader.read(&mut buf) {
+ Ok(0) => break, // EOF
+ Ok(n) => {
+ // Forward to broadcast; best-effort if there are subscribers.
+ let _ = output_tx_clone.send(buf[..n].to_vec());
+ }
+ Err(_) => break,
@hanson-openai I think this is the read that you are worried about BlockingIO. I'll add some logic.