Files
codex/prs/bolinfest/PR-2574.md
2025-09-02 15:17:45 -07:00

60 KiB

PR #2574: feat: StreamableShell with exec_command and write_stdin tools

Description

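This PR introduces a complementary pair of tools, `exec_command` and `write_stdin`, designed to make working with long-running processes more token-efficient. `exec_command` starts a command and returns the output produced within a time budget, plus a session ID if the process is still running. `write_stdin` sends input to that session, or just polls it for new output.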

To test:

```console
codex-rs$ just codex --cd .. --config experimental_use_exec_command_tool=true
```

However, the above alone is unlikely to convince the model to use the new tools, because of this passage in the base instructions:

`codex-rs/core/prompt.md` at commit `e4c275d615`, lines 266–271

So you probably need to add:

```shell
--config experimental_instructions_file=/some/other_instructions.md
```

where `/some/other_instructions.md` tells the model to use `exec_command` and `write_stdin` instead of `shell`.

Full Diff

diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock
index 94702a83ed..dbccbd863e 100644
--- a/codex-rs/Cargo.lock
+++ b/codex-rs/Cargo.lock
@@ -731,6 +731,7 @@ dependencies = [
  "mime_guess",
  "openssl-sys",
  "os_info",
+ "portable-pty",
  "predicates",
  "pretty_assertions",
  "rand 0.9.2",
@@ -1479,6 +1480,12 @@ version = "0.15.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
 
+[[package]]
+name = "downcast-rs"
+version = "1.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2"
+
 [[package]]
 name = "dupe"
 version = "0.9.1"
@@ -1724,6 +1731,17 @@ dependencies = [
  "simd-adler32",
 ]
 
+[[package]]
+name = "filedescriptor"
+version = "0.8.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e40758ed24c9b2eeb76c35fb0aebc66c626084edd827e07e1552279814c6682d"
+dependencies = [
+ "libc",
+ "thiserror 1.0.69",
+ "winapi",
+]
+
 [[package]]
 name = "fixedbitset"
 version = "0.4.2"
@@ -3439,6 +3457,27 @@ dependencies = [
  "portable-atomic",
 ]
 
+[[package]]
+name = "portable-pty"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b4a596a2b3d2752d94f51fac2d4a96737b8705dddd311a32b9af47211f08671e"
+dependencies = [
+ "anyhow",
+ "bitflags 1.3.2",
+ "downcast-rs",
+ "filedescriptor",
+ "lazy_static",
+ "libc",
+ "log",
+ "nix",
+ "serial2",
+ "shared_library",
+ "shell-words",
+ "winapi",
+ "winreg",
+]
+
 [[package]]
 name = "potential_utf"
 version = "0.1.2"
@@ -4366,6 +4405,17 @@ dependencies = [
  "syn 2.0.104",
 ]
 
+[[package]]
+name = "serial2"
+version = "0.2.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "26e1e5956803a69ddd72ce2de337b577898801528749565def03515f82bad5bb"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "winapi",
+]
+
 [[package]]
 name = "sha1"
 version = "0.10.6"
@@ -4397,6 +4447,22 @@ dependencies = [
  "lazy_static",
 ]
 
+[[package]]
+name = "shared_library"
+version = "0.1.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5a9e7e0f2bfae24d8a5b5a66c5b257a83c7412304311512a0c054cd5e619da11"
+dependencies = [
+ "lazy_static",
+ "libc",
+]
+
+[[package]]
+name = "shell-words"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde"
+
 [[package]]
 name = "shlex"
 version = "1.3.0"
@@ -6176,6 +6242,15 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "winreg"
+version = "0.10.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d"
+dependencies = [
+ "winapi",
+]
+
 [[package]]
 name = "winsafe"
 version = "0.0.19"
diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml
index 56815ba03c..2f2fa7cbad 100644
--- a/codex-rs/core/Cargo.toml
+++ b/codex-rs/core/Cargo.toml
@@ -28,6 +28,7 @@ libc = "0.2.175"
 mcp-types = { path = "../mcp-types" }
 mime_guess = "2.0"
 os_info = "3.12.0"
+portable-pty = "0.9.0"
 rand = "0.9"
 regex-lite = "0.1.6"
 reqwest = { version = "0.12", features = ["json", "stream"] }
diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs
index 79b73a3335..e7be2ba997 100644
--- a/codex-rs/core/src/codex.rs
+++ b/codex-rs/core/src/codex.rs
@@ -52,6 +52,11 @@ use crate::exec::SandboxType;
 use crate::exec::StdoutStream;
 use crate::exec::StreamOutput;
 use crate::exec::process_exec_tool_call;
+use crate::exec_command::EXEC_COMMAND_TOOL_NAME;
+use crate::exec_command::ExecCommandParams;
+use crate::exec_command::SESSION_MANAGER;
+use crate::exec_command::WRITE_STDIN_TOOL_NAME;
+use crate::exec_command::WriteStdinParams;
 use crate::exec_env::create_env;
 use crate::mcp_connection_manager::McpConnectionManager;
 use crate::mcp_tool_call::handle_mcp_tool_call;
@@ -488,6 +493,7 @@ impl Session {
                 sandbox_policy.clone(),
                 config.include_plan_tool,
                 config.include_apply_patch_tool,
+                config.use_experimental_streamable_shell_tool,
             ),
             user_instructions,
             base_instructions,
@@ -1069,6 +1075,7 @@ async fn submission_loop(
                     new_sandbox_policy.clone(),
                     config.include_plan_tool,
                     config.include_apply_patch_tool,
+                    config.use_experimental_streamable_shell_tool,
                 );
 
                 let new_turn_context = TurnContext {
@@ -1147,6 +1154,7 @@ async fn submission_loop(
                             sandbox_policy.clone(),
                             config.include_plan_tool,
                             config.include_apply_patch_tool,
+                            config.use_experimental_streamable_shell_tool,
                         ),
                         user_instructions: turn_context.user_instructions.clone(),
                         base_instructions: turn_context.base_instructions.clone(),
@@ -2037,6 +2045,52 @@ async fn handle_function_call(
             .await
         }
         "update_plan" => handle_update_plan(sess, arguments, sub_id, call_id).await,
+        EXEC_COMMAND_TOOL_NAME => {
+            // TODO(mbolin): Sandbox check.
+            let exec_params = match serde_json::from_str::<ExecCommandParams>(&arguments) {
+                Ok(params) => params,
+                Err(e) => {
+                    return ResponseInputItem::FunctionCallOutput {
+                        call_id,
+                        output: FunctionCallOutputPayload {
+                            content: format!("failed to parse function arguments: {e}"),
+                            success: Some(false),
+                        },
+                    };
+                }
+            };
+            let result = SESSION_MANAGER
+                .handle_exec_command_request(exec_params)
+                .await;
+            let function_call_output = crate::exec_command::result_into_payload(result);
+            ResponseInputItem::FunctionCallOutput {
+                call_id,
+                output: function_call_output,
+            }
+        }
+        WRITE_STDIN_TOOL_NAME => {
+            let write_stdin_params = match serde_json::from_str::<WriteStdinParams>(&arguments) {
+                Ok(params) => params,
+                Err(e) => {
+                    return ResponseInputItem::FunctionCallOutput {
+                        call_id,
+                        output: FunctionCallOutputPayload {
+                            content: format!("failed to parse function arguments: {e}"),
+                            success: Some(false),
+                        },
+                    };
+                }
+            };
+            let result = SESSION_MANAGER
+                .handle_write_stdin_request(write_stdin_params)
+                .await;
+            let function_call_output: FunctionCallOutputPayload =
+                crate::exec_command::result_into_payload(result);
+            ResponseInputItem::FunctionCallOutput {
+                call_id,
+                output: function_call_output,
+            }
+        }
         _ => {
             match sess.mcp_connection_manager.parse_tool_name(&name) {
                 Some((server, tool_name)) => {
diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs
index 67a54eb1d7..fbf0387a01 100644
--- a/codex-rs/core/src/config.rs
+++ b/codex-rs/core/src/config.rs
@@ -174,6 +174,8 @@ pub struct Config {
 
     /// If set to `true`, the API key will be signed with the `originator` header.
     pub preferred_auth_method: AuthMode,
+
+    pub use_experimental_streamable_shell_tool: bool,
 }
 
 impl Config {
@@ -469,6 +471,8 @@ pub struct ConfigToml {
     /// Experimental path to a file whose contents replace the built-in BASE_INSTRUCTIONS.
     pub experimental_instructions_file: Option<PathBuf>,
 
+    pub experimental_use_exec_command_tool: Option<bool>,
+
     /// The value for the `originator` header included with Responses API requests.
     pub responses_originator_header_internal_override: Option<String>,
 
@@ -758,6 +762,9 @@ impl Config {
             include_apply_patch_tool: include_apply_patch_tool.unwrap_or(false),
             responses_originator_header,
             preferred_auth_method: cfg.preferred_auth_method.unwrap_or(AuthMode::ChatGPT),
+            use_experimental_streamable_shell_tool: cfg
+                .experimental_use_exec_command_tool
+                .unwrap_or(false),
         };
         Ok(config)
     }
@@ -1124,6 +1131,7 @@ disable_response_storage = true
                 include_apply_patch_tool: false,
                 responses_originator_header: "codex_cli_rs".to_string(),
                 preferred_auth_method: AuthMode::ChatGPT,
+                use_experimental_streamable_shell_tool: false,
             },
             o3_profile_config
         );
@@ -1178,6 +1186,7 @@ disable_response_storage = true
             include_apply_patch_tool: false,
             responses_originator_header: "codex_cli_rs".to_string(),
             preferred_auth_method: AuthMode::ChatGPT,
+            use_experimental_streamable_shell_tool: false,
         };
 
         assert_eq!(expected_gpt3_profile_config, gpt3_profile_config);
@@ -1247,6 +1256,7 @@ disable_response_storage = true
             include_apply_patch_tool: false,
             responses_originator_header: "codex_cli_rs".to_string(),
             preferred_auth_method: AuthMode::ChatGPT,
+            use_experimental_streamable_shell_tool: false,
         };
 
         assert_eq!(expected_zdr_profile_config, zdr_profile_config);
diff --git a/codex-rs/core/src/exec_command/exec_command_params.rs b/codex-rs/core/src/exec_command/exec_command_params.rs
new file mode 100644
index 0000000000..11a3fd4596
--- /dev/null
+++ b/codex-rs/core/src/exec_command/exec_command_params.rs
@@ -0,0 +1,57 @@
+use serde::Deserialize;
+use serde::Serialize;
+
+use crate::exec_command::session_id::SessionId;
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct ExecCommandParams {
+    pub(crate) cmd: String,
+
+    #[serde(default = "default_yield_time")]
+    pub(crate) yield_time_ms: u64,
+
+    #[serde(default = "max_output_tokens")]
+    pub(crate) max_output_tokens: u64,
+
+    #[serde(default = "default_shell")]
+    pub(crate) shell: String,
+
+    #[serde(default = "default_login")]
+    pub(crate) login: bool,
+}
+
+fn default_yield_time() -> u64 {
+    10_000
+}
+
+fn max_output_tokens() -> u64 {
+    10_000
+}
+
+fn default_login() -> bool {
+    true
+}
+
+fn default_shell() -> String {
+    "/bin/bash".to_string()
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct WriteStdinParams {
+    pub(crate) session_id: SessionId,
+    pub(crate) chars: String,
+
+    #[serde(default = "write_stdin_default_yield_time_ms")]
+    pub(crate) yield_time_ms: u64,
+
+    #[serde(default = "write_stdin_default_max_output_tokens")]
+    pub(crate) max_output_tokens: u64,
+}
+
+fn write_stdin_default_yield_time_ms() -> u64 {
+    250
+}
+
+fn write_stdin_default_max_output_tokens() -> u64 {
+    10_000
+}
diff --git a/codex-rs/core/src/exec_command/exec_command_session.rs b/codex-rs/core/src/exec_command/exec_command_session.rs
new file mode 100644
index 0000000000..7503150c9a
--- /dev/null
+++ b/codex-rs/core/src/exec_command/exec_command_session.rs
@@ -0,0 +1,83 @@
+use std::sync::Mutex as StdMutex;
+
+use tokio::sync::broadcast;
+use tokio::sync::mpsc;
+use tokio::task::JoinHandle;
+
+#[derive(Debug)]
+pub(crate) struct ExecCommandSession {
+    /// Queue for writing bytes to the process stdin (PTY master write side).
+    writer_tx: mpsc::Sender<Vec<u8>>,
+    /// Broadcast stream of output chunks read from the PTY. New subscribers
+    /// receive only chunks emitted after they subscribe.
+    output_tx: broadcast::Sender<Vec<u8>>,
+
+    /// Child killer handle for termination on drop (can signal independently
+    /// of a thread blocked in `.wait()`).
+    killer: StdMutex<Option<Box<dyn portable_pty::ChildKiller + Send + Sync>>>,
+
+    /// JoinHandle for the blocking PTY reader task.
+    reader_handle: StdMutex<Option<JoinHandle<()>>>,
+
+    /// JoinHandle for the stdin writer task.
+    writer_handle: StdMutex<Option<JoinHandle<()>>>,
+
+    /// JoinHandle for the child wait task.
+    wait_handle: StdMutex<Option<JoinHandle<()>>>,
+}
+
+impl ExecCommandSession {
+    pub(crate) fn new(
+        writer_tx: mpsc::Sender<Vec<u8>>,
+        output_tx: broadcast::Sender<Vec<u8>>,
+        killer: Box<dyn portable_pty::ChildKiller + Send + Sync>,
+        reader_handle: JoinHandle<()>,
+        writer_handle: JoinHandle<()>,
+        wait_handle: JoinHandle<()>,
+    ) -> Self {
+        Self {
+            writer_tx,
+            output_tx,
+            killer: StdMutex::new(Some(killer)),
+            reader_handle: StdMutex::new(Some(reader_handle)),
+            writer_handle: StdMutex::new(Some(writer_handle)),
+            wait_handle: StdMutex::new(Some(wait_handle)),
+        }
+    }
+
+    pub(crate) fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
+        self.writer_tx.clone()
+    }
+
+    pub(crate) fn output_receiver(&self) -> broadcast::Receiver<Vec<u8>> {
+        self.output_tx.subscribe()
+    }
+}
+
+impl Drop for ExecCommandSession {
+    fn drop(&mut self) {
+        // Best-effort: terminate child first so blocking tasks can complete.
+        if let Ok(mut killer_opt) = self.killer.lock()
+            && let Some(mut killer) = killer_opt.take()
+        {
+            let _ = killer.kill();
+        }
+
+        // Abort background tasks; they may already have exited after kill.
+        if let Ok(mut h) = self.reader_handle.lock()
+            && let Some(handle) = h.take()
+        {
+            handle.abort();
+        }
+        if let Ok(mut h) = self.writer_handle.lock()
+            && let Some(handle) = h.take()
+        {
+            handle.abort();
+        }
+        if let Ok(mut h) = self.wait_handle.lock()
+            && let Some(handle) = h.take()
+        {
+            handle.abort();
+        }
+    }
+}
diff --git a/codex-rs/core/src/exec_command/mod.rs b/codex-rs/core/src/exec_command/mod.rs
new file mode 100644
index 0000000000..2fd88d4ec5
--- /dev/null
+++ b/codex-rs/core/src/exec_command/mod.rs
@@ -0,0 +1,14 @@
+mod exec_command_params;
+mod exec_command_session;
+mod responses_api;
+mod session_id;
+mod session_manager;
+
+pub use exec_command_params::ExecCommandParams;
+pub use exec_command_params::WriteStdinParams;
+pub use responses_api::EXEC_COMMAND_TOOL_NAME;
+pub use responses_api::WRITE_STDIN_TOOL_NAME;
+pub use responses_api::create_exec_command_tool_for_responses_api;
+pub use responses_api::create_write_stdin_tool_for_responses_api;
+pub use session_manager::SESSION_MANAGER;
+pub use session_manager::result_into_payload;
diff --git a/codex-rs/core/src/exec_command/responses_api.rs b/codex-rs/core/src/exec_command/responses_api.rs
new file mode 100644
index 0000000000..70b90dd425
--- /dev/null
+++ b/codex-rs/core/src/exec_command/responses_api.rs
@@ -0,0 +1,98 @@
+use std::collections::BTreeMap;
+
+use crate::openai_tools::JsonSchema;
+use crate::openai_tools::ResponsesApiTool;
+
+pub const EXEC_COMMAND_TOOL_NAME: &str = "exec_command";
+pub const WRITE_STDIN_TOOL_NAME: &str = "write_stdin";
+
+pub fn create_exec_command_tool_for_responses_api() -> ResponsesApiTool {
+    let mut properties = BTreeMap::<String, JsonSchema>::new();
+    properties.insert(
+        "cmd".to_string(),
+        JsonSchema::String {
+            description: Some("The shell command to execute.".to_string()),
+        },
+    );
+    properties.insert(
+        "yield_time_ms".to_string(),
+        JsonSchema::Number {
+            description: Some("The maximum time in milliseconds to wait for output.".to_string()),
+        },
+    );
+    properties.insert(
+        "max_output_tokens".to_string(),
+        JsonSchema::Number {
+            description: Some("The maximum number of tokens to output.".to_string()),
+        },
+    );
+    properties.insert(
+        "shell".to_string(),
+        JsonSchema::String {
+            description: Some("The shell to use. Defaults to \"/bin/bash\".".to_string()),
+        },
+    );
+    properties.insert(
+        "login".to_string(),
+        JsonSchema::Boolean {
+            description: Some(
+                "Whether to run the command as a login shell. Defaults to true.".to_string(),
+            ),
+        },
+    );
+
+    ResponsesApiTool {
+        name: EXEC_COMMAND_TOOL_NAME.to_owned(),
+        description: r#"Execute shell commands on the local machine with streaming output."#
+            .to_string(),
+        strict: false,
+        parameters: JsonSchema::Object {
+            properties,
+            required: Some(vec!["cmd".to_string()]),
+            additional_properties: Some(false),
+        },
+    }
+}
+
+pub fn create_write_stdin_tool_for_responses_api() -> ResponsesApiTool {
+    let mut properties = BTreeMap::<String, JsonSchema>::new();
+    properties.insert(
+        "session_id".to_string(),
+        JsonSchema::Number {
+            description: Some("The ID of the exec_command session.".to_string()),
+        },
+    );
+    properties.insert(
+        "chars".to_string(),
+        JsonSchema::String {
+            description: Some("The characters to write to stdin.".to_string()),
+        },
+    );
+    properties.insert(
+        "yield_time_ms".to_string(),
+        JsonSchema::Number {
+            description: Some(
+                "The maximum time in milliseconds to wait for output after writing.".to_string(),
+            ),
+        },
+    );
+    properties.insert(
+        "max_output_tokens".to_string(),
+        JsonSchema::Number {
+            description: Some("The maximum number of tokens to output.".to_string()),
+        },
+    );
+
+    ResponsesApiTool {
+        name: WRITE_STDIN_TOOL_NAME.to_owned(),
+        description: r#"Write characters to an exec session's stdin. Returns all stdout+stderr received within yield_time_ms.
+Can write control characters (\u0003 for Ctrl-C), or an empty string to just poll stdout+stderr."#
+            .to_string(),
+        strict: false,
+        parameters: JsonSchema::Object {
+            properties,
+            required: Some(vec!["session_id".to_string(), "chars".to_string()]),
+            additional_properties: Some(false),
+        },
+    }
+}
diff --git a/codex-rs/core/src/exec_command/session_id.rs b/codex-rs/core/src/exec_command/session_id.rs
new file mode 100644
index 0000000000..c97c5d5440
--- /dev/null
+++ b/codex-rs/core/src/exec_command/session_id.rs
@@ -0,0 +1,5 @@
+use serde::Deserialize;
+use serde::Serialize;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
+pub(crate) struct SessionId(pub u32);
diff --git a/codex-rs/core/src/exec_command/session_manager.rs b/codex-rs/core/src/exec_command/session_manager.rs
new file mode 100644
index 0000000000..213b874bfa
--- /dev/null
+++ b/codex-rs/core/src/exec_command/session_manager.rs
@@ -0,0 +1,677 @@
+use std::collections::HashMap;
+use std::io::ErrorKind;
+use std::io::Read;
+use std::sync::Arc;
+use std::sync::LazyLock;
+use std::sync::Mutex as StdMutex;
+use std::sync::atomic::AtomicU32;
+
+use portable_pty::CommandBuilder;
+use portable_pty::PtySize;
+use portable_pty::native_pty_system;
+use tokio::sync::Mutex;
+use tokio::sync::mpsc;
+use tokio::sync::oneshot;
+use tokio::time::Duration;
+use tokio::time::Instant;
+use tokio::time::timeout;
+
+use crate::exec_command::exec_command_params::ExecCommandParams;
+use crate::exec_command::exec_command_params::WriteStdinParams;
+use crate::exec_command::exec_command_session::ExecCommandSession;
+use crate::exec_command::session_id::SessionId;
+use codex_protocol::models::FunctionCallOutputPayload;
+
+pub static SESSION_MANAGER: LazyLock<SessionManager> = LazyLock::new(SessionManager::default);
+
+#[derive(Debug, Default)]
+pub struct SessionManager {
+    next_session_id: AtomicU32,
+    sessions: Mutex<HashMap<SessionId, ExecCommandSession>>,
+}
+
+#[derive(Debug)]
+pub struct ExecCommandOutput {
+    wall_time: Duration,
+    exit_status: ExitStatus,
+    original_token_count: Option<u64>,
+    output: String,
+}
+
+impl ExecCommandOutput {
+    fn to_text_output(&self) -> String {
+        let wall_time_secs = self.wall_time.as_secs_f32();
+        let termination_status = match self.exit_status {
+            ExitStatus::Exited(code) => format!("Process exited with code {code}"),
+            ExitStatus::Ongoing(session_id) => {
+                format!("Process running with session ID {}", session_id.0)
+            }
+        };
+        let truncation_status = match self.original_token_count {
+            Some(tokens) => {
+                format!("\nWarning: truncated output (original token count: {tokens})")
+            }
+            None => "".to_string(),
+        };
+        format!(
+            r#"Wall time: {wall_time_secs:.3} seconds
+{termination_status}{truncation_status}
+Output:
+{output}"#,
+            output = self.output
+        )
+    }
+}
+
+#[derive(Debug)]
+pub enum ExitStatus {
+    Exited(i32),
+    Ongoing(SessionId),
+}
+
+pub fn result_into_payload(result: Result<ExecCommandOutput, String>) -> FunctionCallOutputPayload {
+    match result {
+        Ok(output) => FunctionCallOutputPayload {
+            content: output.to_text_output(),
+            success: Some(true),
+        },
+        Err(err) => FunctionCallOutputPayload {
+            content: err,
+            success: Some(false),
+        },
+    }
+}
+
+impl SessionManager {
+    /// Processes the request and is required to send a response via `outgoing`.
+    pub async fn handle_exec_command_request(
+        &self,
+        params: ExecCommandParams,
+    ) -> Result<ExecCommandOutput, String> {
+        // Allocate a session id.
+        let session_id = SessionId(
+            self.next_session_id
+                .fetch_add(1, std::sync::atomic::Ordering::SeqCst),
+        );
+
+        let (session, mut exit_rx) =
+            create_exec_command_session(params.clone())
+                .await
+                .map_err(|err| {
+                    format!(
+                        "failed to create exec command session for session id {}: {err}",
+                        session_id.0
+                    )
+                })?;
+
+        // Insert into session map.
+        let mut output_rx = session.output_receiver();
+        self.sessions.lock().await.insert(session_id, session);
+
+        // Collect output until either timeout expires or process exits.
+        // Do not cap during collection; truncate at the end if needed.
+        // Use a modest initial capacity to avoid large preallocation.
+        let cap_bytes_u64 = params.max_output_tokens.saturating_mul(4);
+        let cap_bytes: usize = cap_bytes_u64.min(usize::MAX as u64) as usize;
+        let mut collected: Vec<u8> = Vec::with_capacity(4096);
+
+        let start_time = Instant::now();
+        let deadline = start_time + Duration::from_millis(params.yield_time_ms);
+        let mut exit_code: Option<i32> = None;
+
+        loop {
+            if Instant::now() >= deadline {
+                break;
+            }
+            let remaining = deadline.saturating_duration_since(Instant::now());
+            tokio::select! {
+                biased;
+                exit = &mut exit_rx => {
+                    exit_code = exit.ok();
+                    // Small grace period to pull remaining buffered output
+                    let grace_deadline = Instant::now() + Duration::from_millis(25);
+                    while Instant::now() < grace_deadline {
+                        match timeout(Duration::from_millis(1), output_rx.recv()).await {
+                            Ok(Ok(chunk)) => {
+                                collected.extend_from_slice(&chunk);
+                            }
+                            Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => {
+                                // Skip missed messages; keep trying within grace period.
+                                continue;
+                            }
+                            Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => break,
+                            Err(_) => break,
+                        }
+                    }
+                    break;
+                }
+                chunk = timeout(remaining, output_rx.recv()) => {
+                    match chunk {
+                        Ok(Ok(chunk)) => {
+                            collected.extend_from_slice(&chunk);
+                        }
+                        Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => {
+                            // Skip missed messages; continue collecting fresh output.
+                        }
+                        Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => { break; }
+                        Err(_) => { break; }
+                    }
+                }
+            }
+        }
+
+        let output = String::from_utf8_lossy(&collected).to_string();
+
+        let exit_status = if let Some(code) = exit_code {
+            ExitStatus::Exited(code)
+        } else {
+            ExitStatus::Ongoing(session_id)
+        };
+
+        // If output exceeds cap, truncate the middle and record original token estimate.
+        let (output, original_token_count) = truncate_middle(&output, cap_bytes);
+        Ok(ExecCommandOutput {
+            wall_time: Instant::now().duration_since(start_time),
+            exit_status,
+            original_token_count,
+            output,
+        })
+    }
+
+    /// Write characters to a session's stdin and collect combined output for up to `yield_time_ms`.
+    pub async fn handle_write_stdin_request(
+        &self,
+        params: WriteStdinParams,
+    ) -> Result<ExecCommandOutput, String> {
+        let WriteStdinParams {
+            session_id,
+            chars,
+            yield_time_ms,
+            max_output_tokens,
+        } = params;
+
+        // Grab handles without holding the sessions lock across await points.
+        let (writer_tx, mut output_rx) = {
+            let sessions = self.sessions.lock().await;
+            match sessions.get(&session_id) {
+                Some(session) => (session.writer_sender(), session.output_receiver()),
+                None => {
+                    return Err(format!("unknown session id {}", session_id.0));
+                }
+            }
+        };
+
+        // Write stdin if provided.
+        if !chars.is_empty() && writer_tx.send(chars.into_bytes()).await.is_err() {
+            return Err("failed to write to stdin".to_string());
+        }
+
+        // Collect output up to yield_time_ms, truncating to max_output_tokens bytes.
+        let mut collected: Vec<u8> = Vec::with_capacity(4096);
+        let start_time = Instant::now();
+        let deadline = start_time + Duration::from_millis(yield_time_ms);
+        loop {
+            let now = Instant::now();
+            if now >= deadline {
+                break;
+            }
+            let remaining = deadline - now;
+            match timeout(remaining, output_rx.recv()).await {
+                Ok(Ok(chunk)) => {
+                    // Collect all output within the time budget; truncate at the end.
+                    collected.extend_from_slice(&chunk);
+                }
+                Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => {
+                    // Skip missed messages; continue collecting fresh output.
+                }
+                Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => break,
+                Err(_) => break, // timeout
+            }
+        }
+
+        // Return structured output, truncating middle if over cap.
+        let output = String::from_utf8_lossy(&collected).to_string();
+        let cap_bytes_u64 = max_output_tokens.saturating_mul(4);
+        let cap_bytes: usize = cap_bytes_u64.min(usize::MAX as u64) as usize;
+        let (output, original_token_count) = truncate_middle(&output, cap_bytes);
+        Ok(ExecCommandOutput {
+            wall_time: Instant::now().duration_since(start_time),
+            exit_status: ExitStatus::Ongoing(session_id),
+            original_token_count,
+            output,
+        })
+    }
+}
+
+/// Spawn PTY and child process per spawn_exec_command_session logic.
+async fn create_exec_command_session(
+    params: ExecCommandParams,
+) -> anyhow::Result<(ExecCommandSession, oneshot::Receiver<i32>)> {
+    let ExecCommandParams {
+        cmd,
+        yield_time_ms: _,
+        max_output_tokens: _,
+        shell,
+        login,
+    } = params;
+
+    // Use the native pty implementation for the system
+    let pty_system = native_pty_system();
+
+    // Create a new pty
+    let pair = pty_system.openpty(PtySize {
+        rows: 24,
+        cols: 80,
+        pixel_width: 0,
+        pixel_height: 0,
+    })?;
+
+    // Spawn a shell into the pty
+    let mut command_builder = CommandBuilder::new(shell);
+    let shell_mode_opt = if login { "-lc" } else { "-c" };
+    command_builder.arg(shell_mode_opt);
+    command_builder.arg(cmd);
+
+    let mut child = pair.slave.spawn_command(command_builder)?;
+    // Obtain a killer that can signal the process independently of `.wait()`.
+    let killer = child.clone_killer();
+
+    // Channel to forward write requests to the PTY writer.
+    let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
+    // Broadcast for streaming PTY output to readers: subscribers receive from subscription time.
+    let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256);
+
+    // Reader task: drain PTY and forward chunks to output channel.
+    let mut reader = pair.master.try_clone_reader()?;
+    let output_tx_clone = output_tx.clone();
+    let reader_handle = tokio::task::spawn_blocking(move || {
+        let mut buf = [0u8; 8192];
+        loop {
+            match reader.read(&mut buf) {
+                Ok(0) => break, // EOF
+                Ok(n) => {
+                    // Forward to broadcast; best-effort if there are subscribers.
+                    let _ = output_tx_clone.send(buf[..n].to_vec());
+                }
+                Err(ref e) if e.kind() == ErrorKind::Interrupted => {
+                    // Retry on EINTR
+                    continue;
+                }
+                Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
+                    // We're in a blocking thread; back off briefly and retry.
+                    std::thread::sleep(Duration::from_millis(5));
+                    continue;
+                }
+                Err(_) => break,
+            }
+        }
+    });
+
+    // Writer task: apply stdin writes to the PTY writer.
+    let writer = pair.master.take_writer()?;
+    let writer = Arc::new(StdMutex::new(writer));
+    let writer_handle = tokio::spawn({
+        let writer = writer.clone();
+        async move {
+            while let Some(bytes) = writer_rx.recv().await {
+                let writer = writer.clone();
+                // Perform blocking write on a blocking thread.
+                let _ = tokio::task::spawn_blocking(move || {
+                    if let Ok(mut guard) = writer.lock() {
+                        use std::io::Write;
+                        let _ = guard.write_all(&bytes);
+                        let _ = guard.flush();
+                    }
+                })
+                .await;
+            }
+        }
+    });
+
+    // Keep the child alive until it exits, then signal exit code.
+    let (exit_tx, exit_rx) = oneshot::channel::<i32>();
+    let wait_handle = tokio::task::spawn_blocking(move || {
+        let code = match child.wait() {
+            Ok(status) => status.exit_code() as i32,
+            Err(_) => -1,
+        };
+        let _ = exit_tx.send(code);
+    });
+
+    // Create and store the session with channels.
+    let session = ExecCommandSession::new(
+        writer_tx,
+        output_tx,
+        killer,
+        reader_handle,
+        writer_handle,
+        wait_handle,
+    );
+    Ok((session, exit_rx))
+}
+
+/// Truncate the middle of a UTF-8 string so the result fits in `max_bytes`
+/// bytes, preserving the beginning and the end.
+///
+/// The elided span is replaced by a marker line `…N tokens truncated…\n`, where
+/// `N` approximates (at 4 bytes/token) how much content was removed. Cut points
+/// prefer newline boundaries and otherwise fall back to UTF-8 char boundaries.
+///
+/// Returns the possibly truncated string and `Some(original_token_count)` (the
+/// whole input's estimated token count, 4 bytes/token) if truncation occurred;
+/// otherwise returns the original string and `None`.
+///
+/// The marker is never truncated: if `max_bytes` cannot hold the marker, its
+/// trailing newline and at least one byte of content, only the marker is
+/// returned, and that result may exceed `max_bytes`. In every other case the
+/// result is at most `max_bytes` bytes.
+fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>) {
+    // No truncation needed
+    if s.len() <= max_bytes {
+        return (s.to_string(), None);
+    }
+    let est_tokens = (s.len() as u64).div_ceil(4);
+    let marker_for = |tokens: u64| format!("…{tokens} tokens truncated…");
+    // Used when the cap leaves no room for content (this includes `max_bytes == 0`):
+    // a full, untruncated marker that reflects the entire input.
+    let marker_only = || (marker_for(est_tokens), Some(est_tokens));
+
+    // Helper to truncate a string to a given byte length on a char boundary.
+    fn truncate_on_boundary(input: &str, max_len: usize) -> &str {
+        if input.len() <= max_len {
+            return input;
+        }
+        let mut end = max_len;
+        while end > 0 && !input.is_char_boundary(end) {
+            end -= 1;
+        }
+        &input[..end]
+    }
+
+    // End of the kept prefix: just past the last newline within `left_budget`,
+    // else the largest char boundary within it. Snapping to a char boundary
+    // first keeps the newline preference even when the budget falls inside a
+    // multi-byte character.
+    fn pick_prefix_end(s: &str, left_budget: usize) -> usize {
+        let head = truncate_on_boundary(s, left_budget);
+        match head.rfind('\n') {
+            Some(i) => i + 1, // keep the newline so the marker starts on a fresh line
+            None => head.len(),
+        }
+    }
+
+    // Start of the kept suffix: just after the first newline within the last
+    // `right_budget` bytes, else the smallest char boundary at or after the
+    // start of that window.
+    fn pick_suffix_start(s: &str, right_budget: usize) -> usize {
+        let mut start_tail = s.len().saturating_sub(right_budget);
+        while start_tail < s.len() && !s.is_char_boundary(start_tail) {
+            start_tail += 1;
+        }
+        match s[start_tail..].find('\n') {
+            Some(i) => start_tail + i + 1, // start after newline
+            None => start_tail,
+        }
+    }
+
+    // Concatenate prefix, marker (on its own line), and suffix.
+    fn assemble(s: &str, prefix_end: usize, marker: &str, suffix_start: usize) -> String {
+        let kept = prefix_end + (s.len() - suffix_start);
+        let mut out = String::with_capacity(kept + marker.len() + 1);
+        out.push_str(&s[..prefix_end]);
+        out.push_str(marker);
+        // Place marker on its own line for symmetry when we keep line boundaries.
+        out.push('\n');
+        out.push_str(&s[suffix_start..]);
+        out
+    }
+
+    // Split `s` around a marker of `marker_len` bytes, returning
+    // `(prefix_end, suffix_start)`, or `None` if no content fits next to it.
+    let split_for = |marker_len: usize| -> Option<(usize, usize)> {
+        // Reserve room for the marker AND the newline emitted after it, so the
+        // assembled output never exceeds `max_bytes`.
+        let keep_budget = max_bytes.saturating_sub(marker_len + 1);
+        if keep_budget == 0 {
+            return None;
+        }
+        let left_budget = keep_budget / 2;
+        let right_budget = keep_budget - left_budget;
+        let prefix_end = pick_prefix_end(s, left_budget);
+        let suffix_start = pick_suffix_start(s, right_budget).max(prefix_end);
+        Some((prefix_end, suffix_start))
+    };
+
+    // The marker's length depends on the token count it reports, which depends
+    // on how much is cut; refine a few times until the count is stable.
+    let mut guess_tokens = est_tokens; // worst-case: everything truncated
+    for _ in 0..4 {
+        let marker = marker_for(guess_tokens);
+        let Some((prefix_end, suffix_start)) = split_for(marker.len()) else {
+            return marker_only();
+        };
+        let new_tokens = ((suffix_start - prefix_end) as u64).div_ceil(4);
+        if new_tokens == guess_tokens {
+            return (assemble(s, prefix_end, &marker, suffix_start), Some(est_tokens));
+        }
+        guess_tokens = new_tokens;
+    }
+
+    // Fallback: use last guess to build output.
+    let marker = marker_for(guess_tokens);
+    let Some((prefix_end, suffix_start)) = split_for(marker.len()) else {
+        return marker_only();
+    };
+    (assemble(s, prefix_end, &marker, suffix_start), Some(est_tokens))
+}
+
+// Unit tests for `ExecCommandOutput::to_text_output` and `truncate_middle`,
+// plus a unix-only end-to-end test that drives a real PTY session.
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::exec_command::session_id::SessionId;
+
+    /// Test that verifies that [`SessionManager::handle_exec_command_request()`]
+    /// and [`SessionManager::handle_write_stdin_request()`] work as expected
+    /// in the presence of a process that never terminates (but produces
+    /// output continuously).
+    ///
+    /// Requires `python3` and `/bin/bash`. Returns early (skips) when session
+    /// creation fails with a PTY-related error (`openpty` /
+    /// `Operation not permitted`).
+    #[cfg(unix)]
+    // `eprintln!` is used for diagnostics and to report a skipped run.
+    #[allow(clippy::print_stderr)]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+    async fn session_manager_streams_and_truncates_from_now() {
+        use crate::exec_command::exec_command_params::ExecCommandParams;
+        use crate::exec_command::exec_command_params::WriteStdinParams;
+        use tokio::time::sleep;
+
+        let session_manager = SessionManager::default();
+        // Long-running loop that prints an increasing counter every ~100ms.
+        // Use Python for a portable, reliable sleep across shells/PTYs.
+        let cmd = r#"python3 - <<'PY'
+import sys, time
+count = 0
+while True:
+    print(count)
+    sys.stdout.flush()
+    count += 100
+    time.sleep(0.1)
+PY"#
+        .to_string();
+
+        // Start the session and collect ~3s of output.
+        let params = ExecCommandParams {
+            cmd,
+            yield_time_ms: 3_000,
+            max_output_tokens: 1_000, // large enough to avoid truncation here
+            shell: "/bin/bash".to_string(),
+            login: false,
+        };
+        let initial_output = match session_manager
+            .handle_exec_command_request(params.clone())
+            .await
+        {
+            Ok(v) => v,
+            Err(e) => {
+                // PTY may be restricted in some sandboxes; skip in that case.
+                if e.contains("openpty") || e.contains("Operation not permitted") {
+                    eprintln!("skipping test due to restricted PTY: {e}");
+                    return;
+                }
+                panic!("exec request failed unexpectedly: {e}");
+            }
+        };
+        eprintln!("initial output: {initial_output:?}");
+
+        // Should be ongoing (we launched a never-ending loop).
+        let session_id = match initial_output.exit_status {
+            ExitStatus::Ongoing(id) => id,
+            _ => panic!("expected ongoing session"),
+        };
+
+        // Parse the numeric lines and get the max observed value in the first window.
+        let first_nums = extract_monotonic_numbers(&initial_output.output);
+        assert!(
+            !first_nums.is_empty(),
+            "expected some output from first window"
+        );
+        let first_max = *first_nums.iter().max().unwrap();
+
+        // Wait ~4s so counters progress while we're not reading.
+        sleep(Duration::from_millis(4_000)).await;
+
+        // Now read ~3s of output "from now" only.
+        // Use a small token cap so truncation occurs and we test middle truncation.
+        let write_params = WriteStdinParams {
+            session_id,
+            chars: String::new(),
+            yield_time_ms: 3_000,
+            max_output_tokens: 16, // 16 tokens ~= 64 bytes -> likely truncation
+        };
+        let second = session_manager
+            .handle_write_stdin_request(write_params)
+            .await
+            .expect("write stdin should succeed");
+
+        // Verify truncation metadata and size bound (cap is tokens*4 bytes).
+        assert!(second.original_token_count.is_some());
+        let cap_bytes = (16u64 * 4) as usize;
+        assert!(second.output.len() <= cap_bytes);
+        // New middle marker should be present.
+        assert!(
+            second.output.contains("tokens truncated") && second.output.contains('…'),
+            "expected truncation marker in output, got: {}",
+            second.output
+        );
+
+        // Minimal freshness check: the earliest number we see in the second window
+        // should be significantly larger than the last from the first window.
+        let second_nums = extract_monotonic_numbers(&second.output);
+        assert!(
+            !second_nums.is_empty(),
+            "expected some numeric output from second window"
+        );
+        let second_min = *second_nums.iter().min().unwrap();
+
+        // We slept 4 seconds (~40 ticks at 100ms/tick, each +100), so expect
+        // an increase of roughly 4000 or more. Allow a generous margin.
+        assert!(
+            second_min >= first_max + 2000,
+            "second_min={second_min} first_max={first_max}",
+        );
+    }
+
+    /// Returns the integers found on lines made up solely of ASCII digits,
+    /// keeping only multiples of 100 (the generator's step) so that partial
+    /// numbers (e.g. fragments next to a truncation cut) are ignored.
+    /// Despite the name, ordering is not checked.
+    #[cfg(unix)]
+    fn extract_monotonic_numbers(s: &str) -> Vec<i64> {
+        s.lines()
+            .filter_map(|line| {
+                if !line.is_empty()
+                    && line.chars().all(|c| c.is_ascii_digit())
+                    && let Ok(n) = line.parse::<i64>()
+                {
+                    // Our generator increments by 100; ignore spurious fragments.
+                    if n % 100 == 0 {
+                        return Some(n);
+                    }
+                }
+                None
+            })
+            .collect()
+    }
+
+    /// An exited process renders its exit code and no truncation warning.
+    #[test]
+    fn to_text_output_exited_no_truncation() {
+        let out = ExecCommandOutput {
+            wall_time: Duration::from_millis(1234),
+            exit_status: ExitStatus::Exited(0),
+            original_token_count: None,
+            output: "hello".to_string(),
+        };
+        let text = out.to_text_output();
+        let expected = r#"Wall time: 1.234 seconds
+Process exited with code 0
+Output:
+hello"#;
+        assert_eq!(expected, text);
+    }
+
+    /// A running process renders its session id and the truncation warning.
+    #[test]
+    fn to_text_output_ongoing_with_truncation() {
+        let out = ExecCommandOutput {
+            wall_time: Duration::from_millis(500),
+            exit_status: ExitStatus::Ongoing(SessionId(42)),
+            original_token_count: Some(1000),
+            output: "abc".to_string(),
+        };
+        let text = out.to_text_output();
+        let expected = r#"Wall time: 0.500 seconds
+Process running with session ID 42
+Warning: truncated output (original token count: 1000)
+Output:
+abc"#;
+        assert_eq!(expected, text);
+    }
+
+    /// When the cap cannot fit the marker plus content, only the marker is returned.
+    #[test]
+    fn truncate_middle_no_newlines_fallback() {
+        // A long string with no newlines that exceeds the cap.
+        let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
+        let max_bytes = 16; // force truncation
+        let (out, original) = truncate_middle(s, max_bytes);
+        // For very small caps, we return the full, untruncated marker,
+        // even if it exceeds the cap.
+        assert_eq!(out, "…16 tokens truncated…");
+        // Original string length is 62 bytes => ceil(62/4) = 16 tokens.
+        assert_eq!(original, Some(16));
+    }
+
+    /// Cuts land on line boundaries when newlines are available.
+    #[test]
+    fn truncate_middle_prefers_newline_boundaries() {
+        // Build a multi-line string of 20 numbered lines (each "NNN\n").
+        let mut s = String::new();
+        for i in 1..=20 {
+            s.push_str(&format!("{i:03}\n"));
+        }
+        // Total length: 20 lines * 4 bytes per line = 80 bytes.
+        assert_eq!(s.len(), 80);
+
+        // Choose a cap that forces truncation while leaving room for
+        // a few lines on each side after accounting for the marker.
+        let max_bytes = 64;
+        // Expect exact output: first 4 lines, marker, last 4 lines, and correct token estimate (80/4 = 20).
+        assert_eq!(
+            truncate_middle(&s, max_bytes),
+            (
+                r#"001
+002
+003
+004
+…12 tokens truncated…
+017
+018
+019
+020
+"#
+                .to_string(),
+                Some(20)
+            )
+        );
+    }
+}
diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs
index 6d4699bceb..ae18332087 100644
--- a/codex-rs/core/src/lib.rs
+++ b/codex-rs/core/src/lib.rs
@@ -20,6 +20,7 @@ mod conversation_history;
 mod environment_context;
 pub mod error;
 pub mod exec;
+mod exec_command;
 pub mod exec_env;
 mod flags;
 pub mod git_info;
diff --git a/codex-rs/core/src/openai_tools.rs b/codex-rs/core/src/openai_tools.rs
index bb5e6dacbf..272c901dc2 100644
--- a/codex-rs/core/src/openai_tools.rs
+++ b/codex-rs/core/src/openai_tools.rs
@@ -56,6 +56,7 @@ pub enum ConfigShellToolType {
     DefaultShell,
     ShellWithRequest { sandbox_policy: SandboxPolicy },
     LocalShell,
+    StreamableShell,
 }
 
 #[derive(Debug, Clone)]
@@ -72,13 +73,16 @@ impl ToolsConfig {
         sandbox_policy: SandboxPolicy,
         include_plan_tool: bool,
         include_apply_patch_tool: bool,
+        use_streamable_shell_tool: bool,
     ) -> Self {
-        let mut shell_type = if model_family.uses_local_shell_tool {
+        let mut shell_type = if use_streamable_shell_tool {
+            ConfigShellToolType::StreamableShell
+        } else if model_family.uses_local_shell_tool {
             ConfigShellToolType::LocalShell
         } else {
             ConfigShellToolType::DefaultShell
         };
-        if matches!(approval_policy, AskForApproval::OnRequest) {
+        if matches!(approval_policy, AskForApproval::OnRequest) && !use_streamable_shell_tool {
             shell_type = ConfigShellToolType::ShellWithRequest {
                 sandbox_policy: sandbox_policy.clone(),
             }
@@ -492,6 +496,14 @@ pub(crate) fn get_openai_tools(
         ConfigShellToolType::LocalShell => {
             tools.push(OpenAiTool::LocalShell {});
         }
+        ConfigShellToolType::StreamableShell => {
+            tools.push(OpenAiTool::Function(
+                crate::exec_command::create_exec_command_tool_for_responses_api(),
+            ));
+            tools.push(OpenAiTool::Function(
+                crate::exec_command::create_write_stdin_tool_for_responses_api(),
+            ));
+        }
     }
 
     if config.plan_tool {
@@ -564,6 +576,7 @@ mod tests {
             SandboxPolicy::ReadOnly,
             true,
             false,
+            /*use_experimental_streamable_shell_tool*/ false,
         );
         let tools = get_openai_tools(&config, Some(HashMap::new()));
 
@@ -579,6 +592,7 @@ mod tests {
             SandboxPolicy::ReadOnly,
             true,
             false,
+            /*use_experimental_streamable_shell_tool*/ false,
         );
         let tools = get_openai_tools(&config, Some(HashMap::new()));
 
@@ -594,6 +608,7 @@ mod tests {
             SandboxPolicy::ReadOnly,
             false,
             false,
+            /*use_experimental_streamable_shell_tool*/ false,
         );
         let tools = get_openai_tools(
             &config,
@@ -688,6 +703,7 @@ mod tests {
             SandboxPolicy::ReadOnly,
             false,
             false,
+            /*use_experimental_streamable_shell_tool*/ false,
         );
 
         let tools = get_openai_tools(
@@ -744,6 +760,7 @@ mod tests {
             SandboxPolicy::ReadOnly,
             false,
             false,
+            /*use_experimental_streamable_shell_tool*/ false,
         );
 
         let tools = get_openai_tools(
@@ -795,6 +812,7 @@ mod tests {
             SandboxPolicy::ReadOnly,
             false,
             false,
+            /*use_experimental_streamable_shell_tool*/ false,
         );
 
         let tools = get_openai_tools(
@@ -849,6 +867,7 @@ mod tests {
             SandboxPolicy::ReadOnly,
             false,
             false,
+            /*use_experimental_streamable_shell_tool*/ false,
         );
 
         let tools = get_openai_tools(

Review Comments

codex-rs/core/src/exec_command/session_manager.rs

@@ -0,0 +1,580 @@
+use std::collections::HashMap;
+use std::io::Read;
+use std::sync::Arc;
+use std::sync::LazyLock;
+use std::sync::Mutex as StdMutex;
+use std::sync::atomic::AtomicU32;
+
+use portable_pty::CommandBuilder;
+use portable_pty::PtySize;
+use portable_pty::native_pty_system;
+use tokio::sync::Mutex;
+use tokio::sync::mpsc;
+use tokio::sync::oneshot;
+use tokio::time::Duration;
+use tokio::time::Instant;
+use tokio::time::timeout;
+
+use crate::exec_command::exec_command_params::ExecCommandParams;
+use crate::exec_command::exec_command_params::WriteStdinParams;
+use crate::exec_command::exec_command_session::ExecCommandSession;
+use crate::exec_command::session_id::SessionId;
+use crate::models::FunctionCallOutputPayload;
+
+/// Process-wide [`SessionManager`], built lazily with `SessionManager::default`
+/// on first access.
+pub static SESSION_MANAGER: LazyLock<SessionManager> = LazyLock::new(SessionManager::default);
+
+/// Owns the sessions created by `handle_exec_command_request` so that
+/// `handle_write_stdin_request` can look them up by [`SessionId`].
+#[derive(Debug, Default)]
+pub struct SessionManager {
+    /// Source of fresh session ids (advanced with `fetch_add`).
+    next_session_id: AtomicU32,
+    /// Live sessions keyed by id, behind a `tokio` mutex locked from async handlers.
+    /// NOTE(review): nothing visible here removes an entry once its process exits.
+    sessions: Mutex<HashMap<SessionId, ExecCommandSession>>,
+}
+
+/// Result of an exec/write-stdin request; rendered to text by `to_text_output`
+/// and converted into a `FunctionCallOutputPayload`.
+#[derive(Debug)]
+pub struct ExecCommandOutput {
+    /// Time spent collecting output for this request.
+    wall_time: Duration,
+    /// Whether the process exited, or is still running under a session id.
+    exit_status: ExitStatus,
+    /// `Some(estimated token count of the full output)` when `output` was truncated.
+    original_token_count: Option<u64>,
+    /// Collected output, decoded lossily as UTF-8 and possibly middle-truncated.
+    output: String,
+}
+
+impl ExecCommandOutput {
+    /// Render this result as the plain-text function-call output: a wall-time
+    /// line, a status line, an optional truncation warning, then the output.
+    fn to_text_output(&self) -> String {
+        let status_line = match self.exit_status {
+            ExitStatus::Exited(code) => format!("Process exited with code {code}"),
+            ExitStatus::Ongoing(session_id) => {
+                format!("Process running with session ID {}", session_id.0)
+            }
+        };
+        // The warning carries its own leading newline so it disappears entirely when absent.
+        let warning_line = self
+            .original_token_count
+            .map(|tokens| format!("\nWarning: truncated output (original token count: {tokens})"))
+            .unwrap_or_default();
+        format!(
+            "Wall time: {:.3} seconds\n{status_line}{warning_line}\nOutput:\n{}",
+            self.wall_time.as_secs_f32(),
+            self.output
+        )
+    }
+}
+
+/// How a request's output collection ended.
+#[derive(Debug)]
+pub enum ExitStatus {
+    /// The process exited with this exit code.
+    Exited(i32),
+    /// The process was still running when collection stopped; this id can be
+    /// passed to `handle_write_stdin_request` to keep interacting with it.
+    Ongoing(SessionId),
+}
+
+impl From<Result<ExecCommandOutput, String>> for FunctionCallOutputPayload {
+    /// Successes become their text rendering with `success: Some(true)`;
+    /// errors pass their message through with `success: Some(false)`.
+    fn from(val: Result<ExecCommandOutput, String>) -> Self {
+        let (content, succeeded) = match val {
+            Ok(output) => (output.to_text_output(), true),
+            Err(err) => (err, false),
+        };
+        FunctionCallOutputPayload {
+            content,
+            success: Some(succeeded),
+        }
+    }
+}
+
+impl SessionManager {
+    /// Start `params.cmd` in a new PTY-backed session and collect its output for
+    /// up to `params.yield_time_ms` milliseconds, or until the process exits.
+    ///
+    /// The session is registered under a freshly allocated id before output is
+    /// collected. If the process has not exited by the deadline, the result is
+    /// `ExitStatus::Ongoing(id)` so the caller can continue with
+    /// `handle_write_stdin_request`. Output longer than
+    /// `params.max_output_tokens * 4` bytes is middle-truncated.
+    ///
+    /// # Errors
+    ///
+    /// Returns a message if the session (PTY and child process) cannot be created.
+    pub async fn handle_exec_command_request(
+        &self,
+        params: ExecCommandParams,
+    ) -> Result<ExecCommandOutput, String> {
+        // Allocate a session id.
+        let session_id = SessionId(
+            self.next_session_id
+                .fetch_add(1, std::sync::atomic::Ordering::SeqCst),
+        );
+
+        let (session, mut exit_rx) =
+            create_exec_command_session(params.clone())
+                .await
+                .map_err(|err| {
+                    format!(
+                        "failed to create exec command session for session id {}: {err}",
+                        session_id.0
+                    )
+                })?;
+
+        // Insert into session map.
+        // The output receiver is taken first because `session` is moved into the map.
+        // NOTE(review): if this is a broadcast subscription, chunks emitted between
+        // spawn and this call are not delivered to it — confirm in ExecCommandSession.
+        let mut output_rx = session.output_receiver();
+        self.sessions.lock().await.insert(session_id, session);
+
+        // Collect output until either timeout expires or process exits.
+        // Do not cap during collection; truncate at the end if needed.
+        // Use a modest initial capacity to avoid large preallocation.
+        let cap_bytes_u64 = params.max_output_tokens.saturating_mul(4);
+        let cap_bytes: usize = cap_bytes_u64.min(usize::MAX as u64) as usize;
+        let mut collected: Vec<u8> = Vec::with_capacity(4096);
+
+        let start_time = Instant::now();
+        let deadline = start_time + Duration::from_millis(params.yield_time_ms);
+        let mut exit_code: Option<i32> = None;
+
+        loop {
+            if Instant::now() >= deadline {
+                break;
+            }
+            let remaining = deadline.saturating_duration_since(Instant::now());
+            tokio::select! {
+                // Poll branches in order: check for process exit before reading more output.
+                biased;
+                exit = &mut exit_rx => {
+                    // `None` if the exit sender was dropped without sending a code.
+                    exit_code = exit.ok();
+                    // Small grace period to pull remaining buffered output
+                    let grace_deadline = Instant::now() + Duration::from_millis(25);
+                    while Instant::now() < grace_deadline {
+                        match timeout(Duration::from_millis(1), output_rx.recv()).await {
+                            Ok(Ok(chunk)) => {
+                                collected.extend_from_slice(&chunk);
+                            }
+                            Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => {
+                                // Skip missed messages; keep trying within grace period.
+                                continue;
+                            }
+                            Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => break,
+                            Err(_) => break,
+                        }
+                    }
+                    break;
+                }
+                chunk = timeout(remaining, output_rx.recv()) => {
+                    match chunk {
+                        Ok(Ok(chunk)) => {
+                            collected.extend_from_slice(&chunk);
+                        }
+                        Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => {
+                            // Skip missed messages; continue collecting fresh output.
+                        }
+                        // NOTE(review): when the output channel closes before `exit_rx`
+                        // fires, the loop ends without an exit code and the result
+                        // below is reported as `Ongoing`.
+                        Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => { break; }
+                        Err(_) => { break; }
+                    }
+                }
+            }
+        }
+
+        let output = String::from_utf8_lossy(&collected).to_string();
+
+        let exit_status = if let Some(code) = exit_code {
+            ExitStatus::Exited(code)
+        } else {
+            ExitStatus::Ongoing(session_id)
+        };
+
+        // If output exceeds cap, truncate the middle and record original token estimate.
+        let (output, original_token_count) = truncate_middle(&output, cap_bytes);
+        Ok(ExecCommandOutput {
+            wall_time: Instant::now().duration_since(start_time),
+            exit_status,
+            original_token_count,
+            output,
+        })
+    }
+
+    /// Write characters to a session's stdin and collect combined output for up to `yield_time_ms`.
+    ///
+    /// An empty `chars` performs no write and only collects output. Output longer
+    /// than `max_output_tokens * 4` bytes is middle-truncated.
+    ///
+    /// # Errors
+    ///
+    /// Returns a message for an unknown session id or if the stdin write channel
+    /// is closed.
+    pub async fn handle_write_stdin_request(
+        &self,
+        params: WriteStdinParams,
+    ) -> Result<ExecCommandOutput, String> {
+        let WriteStdinParams {
+            session_id,
+            chars,
+            yield_time_ms,
+            max_output_tokens,
+        } = params;
+
+        // Grab handles without holding the sessions lock across await points.
+        // The output receiver is obtained before writing so the response to this
+        // write can be collected.
+        let (writer_tx, mut output_rx) = {
+            let sessions = self.sessions.lock().await;
+            match sessions.get(&session_id) {
+                Some(session) => (session.writer_sender(), session.output_receiver()),
+                None => {
+                    return Err(format!("unknown session id {}", session_id.0));
+                }
+            }
+        };
+
+        // Write stdin if provided.
+        if !chars.is_empty() && writer_tx.send(chars.into_bytes()).await.is_err() {
+            return Err("failed to write to stdin".to_string());
+        }
+
+        // Collect output up to yield_time_ms; truncation to roughly
+        // max_output_tokens * 4 bytes happens after collection.
+        let mut collected: Vec<u8> = Vec::with_capacity(4096);
+        let start_time = Instant::now();
+        let deadline = start_time + Duration::from_millis(yield_time_ms);
+        loop {
+            let now = Instant::now();
+            if now >= deadline {
+                break;
+            }
+            let remaining = deadline - now;
+            match timeout(remaining, output_rx.recv()).await {
+                Ok(Ok(chunk)) => {
+                    // Collect all output within the time budget; truncate at the end.
+                    collected.extend_from_slice(&chunk);
+                }
+                Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => {
+                    // Skip missed messages; continue collecting fresh output.
+                }
+                Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => break,
+                Err(_) => break, // timeout
+            }
+        }
+
+        // Return structured output, truncating middle if over cap.
+        let output = String::from_utf8_lossy(&collected).to_string();
+        let cap_bytes_u64 = max_output_tokens.saturating_mul(4);
+        let cap_bytes: usize = cap_bytes_u64.min(usize::MAX as u64) as usize;
+        let (output, original_token_count) = truncate_middle(&output, cap_bytes);
+        Ok(ExecCommandOutput {
+            wall_time: Instant::now().duration_since(start_time),
+            // NOTE(review): always `Ongoing`, even if the process exited during this
+            // window; the exit code is not observable from this method.
+            exit_status: ExitStatus::Ongoing(session_id),
+            original_token_count,
+            output,
+        })
+    }
+}
+
+/// Spawn PTY and child process per spawn_exec_command_session logic.
+async fn create_exec_command_session(
+    params: ExecCommandParams,
+) -> anyhow::Result<(ExecCommandSession, oneshot::Receiver<i32>)> {
+    let ExecCommandParams {
+        cmd,
+        yield_time_ms: _,
+        max_output_tokens: _,
+        shell,
+        login,
+    } = params;
+
+    // Use the native pty implementation for the system
+    let pty_system = native_pty_system();
+
+    // Create a new pty
+    let pair = pty_system.openpty(PtySize {
+        rows: 24,
+        cols: 80,
+        pixel_width: 0,
+        pixel_height: 0,
+    })?;
+
+    // Spawn a shell into the pty
+    let mut command_builder = CommandBuilder::new(shell);
+    let shell_mode_opt = if login { "-lc" } else { "-c" };
+    command_builder.arg(shell_mode_opt);
+    command_builder.arg(cmd);
+
+    let mut child = pair.slave.spawn_command(command_builder)?;
+    // Obtain a killer that can signal the process independently of `.wait()`.
+    let killer = child.clone_killer();
+
+    // Channel to forward write requests to the PTY writer.
+    let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
+    // Broadcast for streaming PTY output to readers: subscribers receive from subscription time.
+    let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256);
+
+    // Reader task: drain PTY and forward chunks to output channel.
+    let mut reader = pair.master.try_clone_reader()?;
+    let output_tx_clone = output_tx.clone();
+    let reader_handle = tokio::task::spawn_blocking(move || {
+        let mut buf = [0u8; 8192];
+        loop {
+            match reader.read(&mut buf) {
+                Ok(0) => break, // EOF
+                Ok(n) => {
+                    // Forward to broadcast; best-effort if there are subscribers.
+                    let _ = output_tx_clone.send(buf[..n].to_vec());
+                }
+                Err(_) => break,

@hanson-openai I think this is the read that you are worried about BlockingIO. I'll add some logic.