Compare commits

...

1 Commits

Author SHA1 Message Date
Dylan
305847c010 Add structured shell tool variant 2025-10-03 22:54:28 -07:00
3 changed files with 259 additions and 16 deletions

View File

@@ -5,16 +5,23 @@ use crate::codex::TurnContext;
use crate::exec::ExecParams;
use crate::exec_env::create_env;
use crate::function_tool::FunctionCallError;
use crate::tools::ExecResponseFormat;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::handle_container_exec_with_params;
use crate::tools::handle_container_exec_with_params_with_format;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
pub struct ShellHandler;
/// Tool handler for shell-style tool calls.
///
/// Each instance carries the [`ExecResponseFormat`] that it forwards to
/// `handle_container_exec_with_params_with_format` when executing a command.
pub struct ShellHandler {
    // Format forwarded with every exec request issued by this handler.
    response_format: ExecResponseFormat,
}
impl ShellHandler {
/// Creates a handler that stores `response_format` for use when it
/// executes commands.
pub fn new(response_format: ExecResponseFormat) -> Self {
    Self { response_format }
}
fn to_exec_params(params: ShellToolCallParams, turn_context: &TurnContext) -> ExecParams {
ExecParams {
command: params.command,
@@ -63,7 +70,7 @@ impl ToolHandler for ShellHandler {
))
})?;
let exec_params = Self::to_exec_params(params, turn);
let content = handle_container_exec_with_params(
let content = handle_container_exec_with_params_with_format(
tool_name.as_str(),
exec_params,
session,
@@ -71,6 +78,7 @@ impl ToolHandler for ShellHandler {
tracker,
sub_id.to_string(),
call_id.clone(),
self.response_format,
)
.await?;
Ok(ToolOutput::Function {
@@ -80,7 +88,7 @@ impl ToolHandler for ShellHandler {
}
ToolPayload::LocalShell { params } => {
let exec_params = Self::to_exec_params(params, turn);
let content = handle_container_exec_with_params(
let content = handle_container_exec_with_params_with_format(
tool_name.as_str(),
exec_params,
session,
@@ -88,6 +96,7 @@ impl ToolHandler for ShellHandler {
tracker,
sub_id.to_string(),
call_id.clone(),
self.response_format,
)
.await?;
Ok(ToolOutput::Function {

View File

@@ -45,6 +45,12 @@ pub(crate) const TELEMETRY_PREVIEW_TRUNCATION_NOTICE: &str =
"[... telemetry preview truncated ...]";
// TODO(jif) break this down
/// Selects how the result of an exec tool call is rendered into the string
/// handed back to the caller.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum ExecResponseFormat {
    /// Render with `format_exec_output_apply_patch` (a serialized JSON payload).
    LegacyJson,
    /// Render with `format_exec_output_structured` (line-oriented plain text
    /// with exit code and wall time headers).
    StructuredText,
}
pub(crate) async fn handle_container_exec_with_params(
tool_name: &str,
params: ExecParams,
@@ -53,6 +59,29 @@ pub(crate) async fn handle_container_exec_with_params(
turn_diff_tracker: &mut TurnDiffTracker,
sub_id: String,
call_id: String,
) -> Result<String, FunctionCallError> {
handle_container_exec_with_params_with_format(
tool_name,
params,
sess,
turn_context,
turn_diff_tracker,
sub_id,
call_id,
ExecResponseFormat::LegacyJson,
)
.await
}
pub(crate) async fn handle_container_exec_with_params_with_format(
tool_name: &str,
params: ExecParams,
sess: &Session,
turn_context: &TurnContext,
turn_diff_tracker: &mut TurnDiffTracker,
sub_id: String,
call_id: String,
response_format: ExecResponseFormat,
) -> Result<String, FunctionCallError> {
let otel_event_manager = turn_context.client.get_otel_event_manager();
@@ -148,7 +177,7 @@ pub(crate) async fn handle_container_exec_with_params(
match output_result {
Ok(output) => {
let ExecToolCallOutput { exit_code, .. } = &output;
let content = format_exec_output_apply_patch(&output);
let content = format_exec_output_with_style(&output, response_format);
if *exit_code == 0 {
Ok(content)
} else {
@@ -156,12 +185,14 @@ pub(crate) async fn handle_container_exec_with_params(
}
}
Err(ExecError::Function(err)) => Err(err),
Err(ExecError::Codex(CodexErr::Sandbox(SandboxErr::Timeout { output }))) => Err(
FunctionCallError::RespondToModel(format_exec_output_apply_patch(&output)),
),
Err(ExecError::Codex(err)) => Err(FunctionCallError::RespondToModel(format!(
"execution error: {err:?}"
))),
Err(ExecError::Codex(CodexErr::Sandbox(SandboxErr::Timeout { output }))) => {
Err(FunctionCallError::RespondToModel(
format_exec_output_with_style(&output, response_format),
))
}
Err(ExecError::Codex(err)) => Err(FunctionCallError::RespondToModel(
format_unexpected_exec_error(err, response_format),
)),
}
}
@@ -201,6 +232,155 @@ pub fn format_exec_output_apply_patch(exec_output: &ExecToolCallOutput) -> Strin
serde_json::to_string(&payload).expect("serialize ExecOutput")
}
fn format_exec_output_with_style(
exec_output: &ExecToolCallOutput,
response_format: ExecResponseFormat,
) -> String {
match response_format {
ExecResponseFormat::LegacyJson => format_exec_output_apply_patch(exec_output),
ExecResponseFormat::StructuredText => format_exec_output_structured(exec_output),
}
}
fn format_unexpected_exec_error(err: CodexErr, response_format: ExecResponseFormat) -> String {
match response_format {
ExecResponseFormat::LegacyJson => format!("execution error: {err:?}"),
ExecResponseFormat::StructuredText => format_structured_error(&format!("{err:?}")),
}
}
/// Builds the structured-text body for a failure that has no exec output.
///
/// The exit code and wall time are reported as `N/A`, followed by an
/// `Error:` line carrying `message` and an empty `Output:` section (the
/// result ends with a newline after `Output:`).
fn format_structured_error(message: &str) -> String {
    let error_line = format!("Error: {message}");
    [
        "Exit code: N/A",
        "Wall time: N/A seconds",
        error_line.as_str(),
        "Output:",
        // Empty final section so the joined text ends with a newline.
        "",
    ]
    .join("\n")
}
/// Renders `duration` as a number of seconds using four significant digits
/// (delegates to `format_significant_digits`).
fn format_wall_time(duration: std::time::Duration) -> String {
    let seconds = duration.as_secs_f64();
    format_significant_digits(seconds, 4)
}
/// Formats `value` using at most `digits` significant digits.
///
/// * Non-finite values (`NaN`, `±inf`) are returned in their `Display` form.
/// * Zero is rendered as `"0"`.
/// * When the decimal exponent of the rounded value is below `-4` or at
///   least `digits`, scientific notation is used with `digits - 1`
///   fractional digits (e.g. `1.235e4`); trailing zeros are kept there.
/// * Otherwise fixed-point notation is used and trailing zeros (and a
///   dangling decimal point) are removed.
///
/// `digits` is expected to be at least 1; `0` is not a meaningful request.
fn format_significant_digits(value: f64, digits: usize) -> String {
    if !value.is_finite() {
        return value.to_string();
    }
    if value == 0.0 {
        return "0".to_string();
    }

    let initial_exponent = value.abs().log10().floor() as i32;

    // Round to the requested number of significant digits by scaling to an
    // integer and back. For extreme magnitudes the scale factor (or the
    // re-scaled result) overflows and the candidate becomes NaN or infinite;
    // in that case keep the original value and let the formatter below do
    // the decimal rounding instead of printing "NaN"/"inf" for a finite input.
    let scale = 10_f64.powf((digits as f64 - 1.0) - initial_exponent as f64);
    let candidate = (value * scale).round() / scale;
    let rounded_value = if candidate.is_finite() {
        candidate
    } else {
        value
    };

    // Rounding may have carried into the next power of ten (e.g. 9.9996 ->
    // 10.00), so the exponent is recomputed from the rounded value.
    let abs_rounded = rounded_value.abs();
    let exponent = if abs_rounded == 0.0 {
        0
    } else {
        abs_rounded.log10().floor() as i32
    };

    let use_exp = exponent < -4 || exponent >= digits as i32;
    if use_exp {
        return format!("{rounded_value:.prec$e}", prec = digits.saturating_sub(1));
    }

    let decimal_places = (digits as i32 - exponent - 1).max(0) as usize;
    let mut s = format!("{rounded_value:.dp$}", dp = decimal_places);
    if s.contains('.') {
        // Drop insignificant trailing zeros, then a dangling decimal point.
        while s.ends_with('0') {
            s.pop();
        }
        if s.ends_with('.') {
            s.pop();
        }
    }
    s
}
/// Renders an exec result as line-oriented text.
///
/// Layout, one item per line:
/// `Exit code: <code>`, `Wall time: <seconds> seconds`, an optional
/// `Total output lines: <n>` (only when
/// `aggregated_output.truncated_after_lines` is set), `Output:`, and finally
/// the text produced by `format_exec_output_str`.
pub fn format_exec_output_structured(exec_output: &ExecToolCallOutput) -> String {
    let exit_code = exec_output.exit_code;
    let wall_time = format_wall_time(exec_output.duration);

    let mut sections = vec![
        format!("Exit code: {exit_code}"),
        format!("Wall time: {} seconds", wall_time),
    ];
    // Adds the line count only when the field is `Some`.
    sections.extend(
        exec_output
            .aggregated_output
            .truncated_after_lines
            .map(|total_lines| format!("Total output lines: {total_lines}")),
    );
    sections.push("Output:".to_string());
    sections.push(format_exec_output_str(exec_output));

    sections.join("\n")
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::exec::StreamOutput;
    use pretty_assertions::assert_eq;
    use std::time::Duration;

    /// Builds a successful exec result with distinct stdout/stderr text and a
    /// 1.2345 second duration.
    fn sample_output() -> ExecToolCallOutput {
        let stream = |text: &str| StreamOutput::new(text.to_string());
        ExecToolCallOutput {
            exit_code: 0,
            stdout: stream("stdout"),
            stderr: stream("stderr"),
            aggregated_output: stream("stdout\nstderr"),
            duration: Duration::from_secs_f64(1.2345),
            timed_out: false,
        }
    }

    /// The structured format lists exit code, wall time and output in order.
    #[test]
    fn structured_format_basic() {
        let expected = "Exit code: 0\nWall time: 1.235 seconds\nOutput:\nstdout\nstderr";
        assert_eq!(format_exec_output_structured(&sample_output()), expected);
    }

    /// A set `truncated_after_lines` adds the total-lines header.
    #[test]
    fn structured_format_includes_truncation_metadata() {
        let mut truncated = sample_output();
        truncated.aggregated_output.truncated_after_lines = Some(200);
        let rendered = format_exec_output_structured(&truncated);
        assert!(rendered.contains("Total output lines: 200"));
    }

    /// Covers zero, plain rounding, scientific notation and small fractions.
    #[test]
    fn significant_digit_formatting_matches_expectations() {
        let cases = [
            (0.0, "0"),
            (1.23456, "1.235"),
            (12345.0, "1.235e4"),
            (0.000123456, "0.0001235"),
        ];
        for (input, expected) in cases {
            assert_eq!(format_significant_digits(input, 4), expected);
        }
    }

    /// Structured errors report `N/A` metadata and an empty output section.
    #[test]
    fn structured_error_includes_metadata() {
        assert_eq!(
            format_structured_error("unexpected failure"),
            "Exit code: N/A\nWall time: N/A seconds\nError: unexpected failure\nOutput:\n"
        );
    }
}
pub fn format_exec_output_str(exec_output: &ExecToolCallOutput) -> String {
let ExecToolCallOutput {
aggregated_output, ..

View File

@@ -187,7 +187,7 @@ fn create_unified_exec_tool() -> ToolSpec {
})
}
fn create_shell_tool() -> ToolSpec {
fn shell_tool_properties() -> BTreeMap<String, JsonSchema> {
let mut properties = BTreeMap::new();
properties.insert(
"command".to_string(),
@@ -222,9 +222,15 @@ fn create_shell_tool() -> ToolSpec {
},
);
properties
}
fn create_shell_tool_with_metadata(name: &str, description: &str) -> ToolSpec {
let properties = shell_tool_properties();
ToolSpec::Function(ResponsesApiTool {
name: "shell".to_string(),
description: "Runs a shell command and returns its output.".to_string(),
name: name.to_string(),
description: description.to_string(),
strict: false,
parameters: JsonSchema::Object {
properties,
@@ -234,6 +240,17 @@ fn create_shell_tool() -> ToolSpec {
})
}
/// Builds the spec for the tool named `shell` via
/// `create_shell_tool_with_metadata`.
fn create_shell_tool() -> ToolSpec {
    let name = "shell";
    let description = "Runs a shell command and returns its output.";
    create_shell_tool_with_metadata(name, description)
}
/// Builds the spec for the tool named `shell_v2` via
/// `create_shell_tool_with_metadata`; only the name and description differ
/// from `create_shell_tool`.
fn create_shell_v2_tool() -> ToolSpec {
    let name = "shell_v2";
    let description = "Runs a shell command and returns its output with structured metadata.";
    create_shell_tool_with_metadata(name, description)
}
fn create_view_image_tool() -> ToolSpec {
// Support only local filesystem path.
let mut properties = BTreeMap::new();
@@ -501,6 +518,7 @@ pub(crate) fn build_specs(
use crate::exec_command::WRITE_STDIN_TOOL_NAME;
use crate::exec_command::create_exec_command_tool_for_responses_api;
use crate::exec_command::create_write_stdin_tool_for_responses_api;
use crate::tools::ExecResponseFormat;
use crate::tools::handlers::ApplyPatchHandler;
use crate::tools::handlers::ExecStreamHandler;
use crate::tools::handlers::McpHandler;
@@ -513,7 +531,8 @@ pub(crate) fn build_specs(
let mut builder = ToolRegistryBuilder::new();
let shell_handler = Arc::new(ShellHandler);
let shell_handler = Arc::new(ShellHandler::new(ExecResponseFormat::LegacyJson));
let shell_v2_handler = Arc::new(ShellHandler::new(ExecResponseFormat::StructuredText));
let exec_stream_handler = Arc::new(ExecStreamHandler);
let unified_exec_handler = Arc::new(UnifiedExecHandler);
let plan_handler = Arc::new(PlanHandler);
@@ -528,6 +547,8 @@ pub(crate) fn build_specs(
match &config.shell_type {
ConfigShellToolType::Default => {
builder.push_spec(create_shell_tool());
builder.push_spec(create_shell_v2_tool());
builder.register_handler("shell_v2", shell_v2_handler.clone());
}
ConfigShellToolType::Local => {
builder.push_spec(ToolSpec::LocalShell {});
@@ -548,7 +569,8 @@ pub(crate) fn build_specs(
// Always register shell aliases so older prompts remain compatible.
builder.register_handler("shell", shell_handler.clone());
builder.register_handler("container.exec", shell_handler.clone());
builder.register_handler("local_shell", shell_handler);
builder.register_handler("local_shell", shell_handler.clone());
builder.register_handler("shell_v2", shell_v2_handler);
if config.plan_tool {
builder.push_spec(PLAN_TOOL.clone());
@@ -639,6 +661,23 @@ mod tests {
}
}
/// With every optional tool disabled, the built specs contain exactly the
/// `shell` and `shell_v2` tools.
#[test]
fn test_build_specs_with_shell_variants() {
    let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
    let params = ToolsConfigParams {
        model_family: &model_family,
        include_plan_tool: false,
        include_apply_patch_tool: false,
        include_web_search_request: false,
        use_streamable_shell_tool: false,
        include_view_image_tool: false,
        experimental_unified_exec_tool: false,
    };
    let config = ToolsConfig::new(&params);

    let (tools, _) = build_specs(&config, Some(HashMap::new())).build();

    let expected_names = ["shell", "shell_v2"];
    assert_eq_tool_names(&tools, &expected_names);
}
#[test]
fn test_build_specs() {
let model_family = find_family_for_model("codex-mini-latest")
@@ -1158,6 +1197,21 @@ mod tests {
assert_eq!(description, expected);
}
/// `create_shell_v2_tool` yields a function tool with the expected name and
/// description.
#[test]
fn test_shell_v2_tool() {
    let tool = super::create_shell_v2_tool();
    let (name, description) = match &tool {
        ToolSpec::Function(ResponsesApiTool {
            name, description, ..
        }) => (name, description),
        _ => panic!("expected function tool"),
    };

    assert_eq!(name, "shell_v2");

    let expected = "Runs a shell command and returns its output with structured metadata.";
    assert_eq!(description, expected);
}
#[test]
fn test_get_openai_tools_mcp_tools_with_additional_properties_schema() {
let model_family = find_family_for_model("gpt-5-codex")