mirror of
https://github.com/openai/codex.git
synced 2026-03-23 15:13:49 +00:00
Compare commits
22 Commits
starr/exec
...
starr/core
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cd1ebd6b5b | ||
|
|
4314d96e22 | ||
|
|
7f57c4590e | ||
|
|
6e84ab68f0 | ||
|
|
736fc3520d | ||
|
|
4b4ccf062a | ||
|
|
69eeead714 | ||
|
|
161d2a0f20 | ||
|
|
f42bf5b89e | ||
|
|
47729f70d4 | ||
|
|
a3e7bf62a7 | ||
|
|
2d037ca8f6 | ||
|
|
ebe2855c5a | ||
|
|
f8e872d6b6 | ||
|
|
4abc243801 | ||
|
|
17060efdbc | ||
|
|
6ed2130b1b | ||
|
|
ff18ed98c9 | ||
|
|
1720036c1b | ||
|
|
6459f85644 | ||
|
|
18889645ac | ||
|
|
ca029c614f |
20
codex-rs/Cargo.lock
generated
20
codex-rs/Cargo.lock
generated
@@ -1844,11 +1844,13 @@ dependencies = [
|
||||
"codex-config",
|
||||
"codex-connectors",
|
||||
"codex-environment",
|
||||
"codex-exec-server",
|
||||
"codex-execpolicy",
|
||||
"codex-file-search",
|
||||
"codex-git",
|
||||
"codex-hooks",
|
||||
"codex-keyring-store",
|
||||
"codex-mcp-types",
|
||||
"codex-network-proxy",
|
||||
"codex-otel",
|
||||
"codex-protocol",
|
||||
@@ -1859,6 +1861,7 @@ dependencies = [
|
||||
"codex-skills",
|
||||
"codex-state",
|
||||
"codex-test-macros",
|
||||
"codex-tool-spec",
|
||||
"codex-utils-absolute-path",
|
||||
"codex-utils-cache",
|
||||
"codex-utils-cargo-bin",
|
||||
@@ -2011,6 +2014,7 @@ dependencies = [
|
||||
"base64 0.22.1",
|
||||
"clap",
|
||||
"codex-app-server-protocol",
|
||||
"codex-environment",
|
||||
"codex-utils-cargo-bin",
|
||||
"codex-utils-pty",
|
||||
"futures",
|
||||
@@ -2227,6 +2231,14 @@ dependencies = [
|
||||
"wiremock",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-mcp-types"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"codex-protocol",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-network-proxy"
|
||||
version = "0.0.0"
|
||||
@@ -2523,6 +2535,14 @@ dependencies = [
|
||||
"syn 2.0.114",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-tool-spec"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-tui"
|
||||
version = "0.0.0"
|
||||
|
||||
@@ -34,6 +34,7 @@ members = [
|
||||
"linux-sandbox",
|
||||
"lmstudio",
|
||||
"login",
|
||||
"mcp-types",
|
||||
"mcp-server",
|
||||
"network-proxy",
|
||||
"ollama",
|
||||
@@ -69,6 +70,7 @@ members = [
|
||||
"state",
|
||||
"codex-experimental-api-macros",
|
||||
"test-macros",
|
||||
"tool-spec",
|
||||
"package-manager",
|
||||
"artifacts",
|
||||
]
|
||||
@@ -107,6 +109,7 @@ codex-config = { path = "config" }
|
||||
codex-core = { path = "core" }
|
||||
codex-environment = { path = "environment" }
|
||||
codex-exec = { path = "exec" }
|
||||
codex-exec-server = { path = "exec-server" }
|
||||
codex-execpolicy = { path = "execpolicy" }
|
||||
codex-experimental-api-macros = { path = "codex-experimental-api-macros" }
|
||||
codex-feedback = { path = "feedback" }
|
||||
@@ -117,6 +120,7 @@ codex-keyring-store = { path = "keyring-store" }
|
||||
codex-linux-sandbox = { path = "linux-sandbox" }
|
||||
codex-lmstudio = { path = "lmstudio" }
|
||||
codex-login = { path = "login" }
|
||||
codex-mcp-types = { path = "mcp-types" }
|
||||
codex-mcp-server = { path = "mcp-server" }
|
||||
codex-network-proxy = { path = "network-proxy" }
|
||||
codex-ollama = { path = "ollama" }
|
||||
@@ -132,6 +136,7 @@ codex-skills = { path = "skills" }
|
||||
codex-state = { path = "state" }
|
||||
codex-stdio-to-uds = { path = "stdio-to-uds" }
|
||||
codex-test-macros = { path = "test-macros" }
|
||||
codex-tool-spec = { path = "tool-spec" }
|
||||
codex-tui = { path = "tui" }
|
||||
codex-tui-app-server = { path = "tui_app_server" }
|
||||
codex-utils-absolute-path = { path = "utils/absolute-path" }
|
||||
|
||||
@@ -34,7 +34,7 @@ pub(crate) struct FsApi {
|
||||
impl Default for FsApi {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
file_system: Arc::new(Environment.get_filesystem()),
|
||||
file_system: Environment::default().get_filesystem(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,13 +35,16 @@ codex-client = { workspace = true }
|
||||
codex-connectors = { workspace = true }
|
||||
codex-config = { workspace = true }
|
||||
codex-environment = { workspace = true }
|
||||
codex-exec-server = { workspace = true }
|
||||
codex-shell-command = { workspace = true }
|
||||
codex-skills = { workspace = true }
|
||||
codex-tool-spec = { workspace = true }
|
||||
codex-execpolicy = { workspace = true }
|
||||
codex-file-search = { workspace = true }
|
||||
codex-git = { workspace = true }
|
||||
codex-hooks = { workspace = true }
|
||||
codex-keyring-store = { workspace = true }
|
||||
codex-mcp-types = { workspace = true }
|
||||
codex-network-proxy = { workspace = true }
|
||||
codex-otel = { workspace = true }
|
||||
codex-artifacts = { workspace = true }
|
||||
@@ -92,7 +95,6 @@ sha2 = { workspace = true }
|
||||
shlex = { workspace = true }
|
||||
similar = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
test-log = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
time = { workspace = true, features = [
|
||||
"formatting",
|
||||
@@ -166,6 +168,7 @@ opentelemetry = { workspace = true }
|
||||
predicates = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
test-case = "3.3.1"
|
||||
test-log = { workspace = true }
|
||||
opentelemetry_sdk = { workspace = true, features = [
|
||||
"experimental_metrics_custom_reader",
|
||||
"metrics",
|
||||
|
||||
@@ -1915,6 +1915,18 @@
|
||||
"description": "Experimental / do not use. Replaces the synthesized realtime startup context appended to websocket session instructions. An empty string disables startup context injection entirely.",
|
||||
"type": "string"
|
||||
},
|
||||
"experimental_unified_exec_exec_server_websocket_url": {
|
||||
"description": "Optional websocket URL for connecting to an existing `codex-exec-server`.",
|
||||
"type": "string"
|
||||
},
|
||||
"experimental_unified_exec_exec_server_workspace_root": {
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/AbsolutePathBuf"
|
||||
}
|
||||
],
|
||||
"description": "Optional absolute path to the executor-visible workspace root that corresponds to the local session cwd."
|
||||
},
|
||||
"experimental_use_freeform_apply_patch": {
|
||||
"type": "boolean"
|
||||
},
|
||||
|
||||
@@ -94,6 +94,7 @@ use crate::auth::RefreshTokenError;
|
||||
use crate::client_common::Prompt;
|
||||
use crate::client_common::ResponseEvent;
|
||||
use crate::client_common::ResponseStream;
|
||||
use crate::client_common::create_tools_json_for_responses_api;
|
||||
use crate::default_client::build_reqwest_client;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result;
|
||||
@@ -104,7 +105,6 @@ use crate::response_debug_context::extract_response_debug_context;
|
||||
use crate::response_debug_context::extract_response_debug_context_from_api_error;
|
||||
use crate::response_debug_context::telemetry_api_error_message;
|
||||
use crate::response_debug_context::telemetry_transport_error_message;
|
||||
use crate::tools::spec::create_tools_json_for_responses_api;
|
||||
use crate::util::FeedbackRequestTags;
|
||||
use crate::util::emit_feedback_auth_recovery_tags;
|
||||
use crate::util::emit_feedback_request_tags_with_auth_env;
|
||||
|
||||
@@ -14,6 +14,8 @@ use std::task::Context;
|
||||
use std::task::Poll;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
mod tools_json;
|
||||
|
||||
/// Review thread system prompt. Edit `core/src/review_prompt.md` to customize.
|
||||
pub const REVIEW_PROMPT: &str = include_str!("../review_prompt.md");
|
||||
|
||||
@@ -44,6 +46,8 @@ pub struct Prompt {
|
||||
pub output_schema: Option<Value>,
|
||||
}
|
||||
|
||||
pub(crate) use tools_json::create_tools_json_for_responses_api;
|
||||
|
||||
impl Prompt {
|
||||
pub(crate) fn get_formatted_input(&self) -> Vec<ResponseItem> {
|
||||
let mut input = self.input.clone();
|
||||
|
||||
23
codex-rs/core/src/client_common/tools_json.rs
Normal file
23
codex-rs/core/src/client_common/tools_json.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
use crate::error::Result;
|
||||
|
||||
use super::tools::ToolSpec;
|
||||
|
||||
/// Returns JSON values that are compatible with Function Calling in the
|
||||
/// Responses API:
|
||||
/// https://platform.openai.com/docs/guides/function-calling?api-mode=responses
|
||||
///
|
||||
/// This helper intentionally lives under `client_common` so the model client
|
||||
/// can serialize tools without depending on the full tool registry builder in
|
||||
/// `tools/spec.rs`.
|
||||
pub(crate) fn create_tools_json_for_responses_api(
|
||||
tools: &[ToolSpec],
|
||||
) -> Result<Vec<serde_json::Value>> {
|
||||
let mut tools_json = Vec::new();
|
||||
|
||||
for tool in tools {
|
||||
let json = serde_json::to_value(tool)?;
|
||||
tools_json.push(json);
|
||||
}
|
||||
|
||||
Ok(tools_json)
|
||||
}
|
||||
12
codex-rs/core/src/client_tools.rs
Normal file
12
codex-rs/core/src/client_tools.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::error::Result;
|
||||
|
||||
pub(crate) fn create_tools_json_for_responses_api<T: Serialize>(tools: &[T]) -> Result<Vec<Value>> {
|
||||
tools
|
||||
.iter()
|
||||
.map(serde_json::to_value)
|
||||
.collect::<serde_json::Result<Vec<_>>>()
|
||||
.map_err(Into::into)
|
||||
}
|
||||
@@ -312,6 +312,7 @@ use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use crate::turn_timing::TurnTimingState;
|
||||
use crate::turn_timing::record_turn_ttfm_metric;
|
||||
use crate::turn_timing::record_turn_ttft_metric;
|
||||
use crate::unified_exec::RemoteExecServerBackend;
|
||||
use crate::unified_exec::UnifiedExecProcessManager;
|
||||
use crate::util::backoff;
|
||||
use crate::windows_sandbox::WindowsSandboxLevelExt;
|
||||
@@ -1772,6 +1773,13 @@ impl Session {
|
||||
});
|
||||
}
|
||||
|
||||
let remote_exec_server =
|
||||
RemoteExecServerBackend::connect_for_config(config.as_ref()).await?;
|
||||
let environment = remote_exec_server
|
||||
.as_ref()
|
||||
.map(|backend| Environment::with_file_system(Arc::new(backend.file_system())))
|
||||
.unwrap_or_default();
|
||||
|
||||
let services = SessionServices {
|
||||
// Initialize the MCP connection manager with an uninitialized
|
||||
// instance. It will be replaced with one created via
|
||||
@@ -1784,8 +1792,9 @@ impl Session {
|
||||
&config.permissions.approval_policy,
|
||||
))),
|
||||
mcp_startup_cancellation_token: Mutex::new(CancellationToken::new()),
|
||||
unified_exec_manager: UnifiedExecProcessManager::new(
|
||||
unified_exec_manager: UnifiedExecProcessManager::with_remote_exec_server(
|
||||
config.background_terminal_max_timeout,
|
||||
remote_exec_server,
|
||||
),
|
||||
shell_zsh_path: config.zsh_path.clone(),
|
||||
main_execve_wrapper_exe: config.main_execve_wrapper_exe.clone(),
|
||||
@@ -1825,7 +1834,7 @@ impl Session {
|
||||
code_mode_service: crate::tools::code_mode::CodeModeService::new(
|
||||
config.js_repl_node_path.clone(),
|
||||
),
|
||||
environment: Arc::new(Environment),
|
||||
environment: Arc::new(environment),
|
||||
};
|
||||
let js_repl = Arc::new(JsReplHandle::with_node_path(
|
||||
config.js_repl_node_path.clone(),
|
||||
@@ -1907,6 +1916,7 @@ impl Session {
|
||||
*cancel_guard = CancellationToken::new();
|
||||
}
|
||||
let (mcp_connection_manager, cancel_token) = McpConnectionManager::new(
|
||||
INITIAL_SUBMIT_ID.to_owned(),
|
||||
&mcp_servers,
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
auth_statuses.clone(),
|
||||
@@ -4042,6 +4052,7 @@ impl Session {
|
||||
*guard = CancellationToken::new();
|
||||
}
|
||||
let (refreshed_manager, cancel_token) = McpConnectionManager::new(
|
||||
INITIAL_SUBMIT_ID.to_owned(),
|
||||
&mcp_servers,
|
||||
store_mode,
|
||||
auth_statuses,
|
||||
@@ -4374,774 +4385,7 @@ fn submission_dispatch_span(sub: &Submission) -> tracing::Span {
|
||||
}
|
||||
|
||||
/// Operation handlers
|
||||
mod handlers {
|
||||
use crate::codex::Session;
|
||||
use crate::codex::SessionSettingsUpdate;
|
||||
use crate::codex::SteerInputError;
|
||||
|
||||
use crate::codex::spawn_review_thread;
|
||||
use crate::config::Config;
|
||||
|
||||
use crate::mcp::auth::compute_auth_statuses;
|
||||
use crate::mcp::collect_mcp_snapshot_from_manager;
|
||||
use crate::review_prompts::resolve_review_request;
|
||||
use crate::rollout::RolloutRecorder;
|
||||
use crate::rollout::session_index;
|
||||
use crate::tasks::CompactTask;
|
||||
use crate::tasks::UndoTask;
|
||||
use crate::tasks::UserShellCommandMode;
|
||||
use crate::tasks::UserShellCommandTask;
|
||||
use crate::tasks::execute_user_shell_command;
|
||||
use codex_protocol::custom_prompts::CustomPrompt;
|
||||
use codex_protocol::protocol::CodexErrorInfo;
|
||||
use codex_protocol::protocol::ErrorEvent;
|
||||
use codex_protocol::protocol::Event;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
use codex_protocol::protocol::ListCustomPromptsResponseEvent;
|
||||
use codex_protocol::protocol::ListSkillsResponseEvent;
|
||||
use codex_protocol::protocol::McpServerRefreshConfig;
|
||||
use codex_protocol::protocol::Op;
|
||||
use codex_protocol::protocol::ReviewDecision;
|
||||
use codex_protocol::protocol::ReviewRequest;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::SkillsListEntry;
|
||||
use codex_protocol::protocol::ThreadNameUpdatedEvent;
|
||||
use codex_protocol::protocol::ThreadRolledBackEvent;
|
||||
use codex_protocol::protocol::TurnAbortReason;
|
||||
use codex_protocol::protocol::WarningEvent;
|
||||
use codex_protocol::request_permissions::RequestPermissionsResponse;
|
||||
use codex_protocol::request_user_input::RequestUserInputResponse;
|
||||
|
||||
use crate::context_manager::is_user_turn_boundary;
|
||||
use codex_protocol::config_types::CollaborationMode;
|
||||
use codex_protocol::config_types::ModeKind;
|
||||
use codex_protocol::config_types::Settings;
|
||||
use codex_protocol::dynamic_tools::DynamicToolResponse;
|
||||
use codex_protocol::mcp::RequestId as ProtocolRequestId;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use codex_rmcp_client::ElicitationAction;
|
||||
use codex_rmcp_client::ElicitationResponse;
|
||||
use serde_json::Value;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
pub async fn interrupt(sess: &Arc<Session>) {
|
||||
sess.interrupt_task().await;
|
||||
}
|
||||
|
||||
pub async fn clean_background_terminals(sess: &Arc<Session>) {
|
||||
sess.close_unified_exec_processes().await;
|
||||
}
|
||||
|
||||
pub async fn override_turn_context(
|
||||
sess: &Session,
|
||||
sub_id: String,
|
||||
updates: SessionSettingsUpdate,
|
||||
) {
|
||||
if let Err(err) = sess.update_settings(updates).await {
|
||||
sess.send_event_raw(Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: err.to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::BadRequest),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn user_input_or_turn(sess: &Arc<Session>, sub_id: String, op: Op) {
|
||||
let (items, updates) = match op {
|
||||
Op::UserTurn {
|
||||
cwd,
|
||||
approval_policy,
|
||||
sandbox_policy,
|
||||
model,
|
||||
effort,
|
||||
summary,
|
||||
service_tier,
|
||||
final_output_json_schema,
|
||||
items,
|
||||
collaboration_mode,
|
||||
personality,
|
||||
} => {
|
||||
let collaboration_mode = collaboration_mode.or_else(|| {
|
||||
Some(CollaborationMode {
|
||||
mode: ModeKind::Default,
|
||||
settings: Settings {
|
||||
model: model.clone(),
|
||||
reasoning_effort: effort,
|
||||
developer_instructions: None,
|
||||
},
|
||||
})
|
||||
});
|
||||
(
|
||||
items,
|
||||
SessionSettingsUpdate {
|
||||
cwd: Some(cwd),
|
||||
approval_policy: Some(approval_policy),
|
||||
approvals_reviewer: None,
|
||||
sandbox_policy: Some(sandbox_policy),
|
||||
windows_sandbox_level: None,
|
||||
collaboration_mode,
|
||||
reasoning_summary: summary,
|
||||
service_tier,
|
||||
final_output_json_schema: Some(final_output_json_schema),
|
||||
personality,
|
||||
app_server_client_name: None,
|
||||
},
|
||||
)
|
||||
}
|
||||
Op::UserInput {
|
||||
items,
|
||||
final_output_json_schema,
|
||||
} => (
|
||||
items,
|
||||
SessionSettingsUpdate {
|
||||
final_output_json_schema: Some(final_output_json_schema),
|
||||
..Default::default()
|
||||
},
|
||||
),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let Ok(current_context) = sess.new_turn_with_sub_id(sub_id, updates).await else {
|
||||
// new_turn_with_sub_id already emits the error event.
|
||||
return;
|
||||
};
|
||||
sess.maybe_emit_unknown_model_warning_for_turn(current_context.as_ref())
|
||||
.await;
|
||||
current_context.session_telemetry.user_prompt(&items);
|
||||
|
||||
// Attempt to inject input into current task.
|
||||
if let Err(SteerInputError::NoActiveTurn(items)) =
|
||||
sess.steer_input(items, /*expected_turn_id*/ None).await
|
||||
{
|
||||
sess.refresh_mcp_servers_if_requested(¤t_context)
|
||||
.await;
|
||||
sess.spawn_task(
|
||||
Arc::clone(¤t_context),
|
||||
items,
|
||||
crate::tasks::RegularTask::new(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_user_shell_command(sess: &Arc<Session>, sub_id: String, command: String) {
|
||||
if let Some((turn_context, cancellation_token)) =
|
||||
sess.active_turn_context_and_cancellation_token().await
|
||||
{
|
||||
let session = Arc::clone(sess);
|
||||
tokio::spawn(async move {
|
||||
execute_user_shell_command(
|
||||
session,
|
||||
turn_context,
|
||||
command,
|
||||
cancellation_token,
|
||||
UserShellCommandMode::ActiveTurnAuxiliary,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
let turn_context = sess.new_default_turn_with_sub_id(sub_id).await;
|
||||
sess.spawn_task(
|
||||
Arc::clone(&turn_context),
|
||||
Vec::new(),
|
||||
UserShellCommandTask::new(command),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn resolve_elicitation(
|
||||
sess: &Arc<Session>,
|
||||
server_name: String,
|
||||
request_id: ProtocolRequestId,
|
||||
decision: codex_protocol::approvals::ElicitationAction,
|
||||
content: Option<Value>,
|
||||
meta: Option<Value>,
|
||||
) {
|
||||
let action = match decision {
|
||||
codex_protocol::approvals::ElicitationAction::Accept => ElicitationAction::Accept,
|
||||
codex_protocol::approvals::ElicitationAction::Decline => ElicitationAction::Decline,
|
||||
codex_protocol::approvals::ElicitationAction::Cancel => ElicitationAction::Cancel,
|
||||
};
|
||||
let content = match action {
|
||||
// Preserve the legacy fallback for clients that only send an action.
|
||||
ElicitationAction::Accept => Some(content.unwrap_or_else(|| serde_json::json!({}))),
|
||||
ElicitationAction::Decline | ElicitationAction::Cancel => None,
|
||||
};
|
||||
let response = ElicitationResponse {
|
||||
action,
|
||||
content,
|
||||
meta,
|
||||
};
|
||||
let request_id = match request_id {
|
||||
ProtocolRequestId::String(value) => {
|
||||
rmcp::model::NumberOrString::String(std::sync::Arc::from(value))
|
||||
}
|
||||
ProtocolRequestId::Integer(value) => rmcp::model::NumberOrString::Number(value),
|
||||
};
|
||||
if let Err(err) = sess
|
||||
.resolve_elicitation(server_name, request_id, response)
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
error = %err,
|
||||
"failed to resolve elicitation request in session"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Propagate a user's exec approval decision to the session.
|
||||
/// Also optionally applies an execpolicy amendment.
|
||||
pub async fn exec_approval(
|
||||
sess: &Arc<Session>,
|
||||
approval_id: String,
|
||||
turn_id: Option<String>,
|
||||
decision: ReviewDecision,
|
||||
) {
|
||||
let event_turn_id = turn_id.unwrap_or_else(|| approval_id.clone());
|
||||
if let ReviewDecision::ApprovedExecpolicyAmendment {
|
||||
proposed_execpolicy_amendment,
|
||||
} = &decision
|
||||
{
|
||||
match sess
|
||||
.persist_execpolicy_amendment(proposed_execpolicy_amendment)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
sess.record_execpolicy_amendment_message(
|
||||
&event_turn_id,
|
||||
proposed_execpolicy_amendment,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Err(err) => {
|
||||
let message = format!("Failed to apply execpolicy amendment: {err}");
|
||||
tracing::warn!("{message}");
|
||||
let warning = EventMsg::Warning(WarningEvent { message });
|
||||
sess.send_event_raw(Event {
|
||||
id: event_turn_id.clone(),
|
||||
msg: warning,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
match decision {
|
||||
ReviewDecision::Abort => {
|
||||
sess.interrupt_task().await;
|
||||
}
|
||||
other => sess.notify_approval(&approval_id, other).await,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn patch_approval(sess: &Arc<Session>, id: String, decision: ReviewDecision) {
|
||||
match decision {
|
||||
ReviewDecision::Abort => {
|
||||
sess.interrupt_task().await;
|
||||
}
|
||||
other => sess.notify_approval(&id, other).await,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn request_user_input_response(
|
||||
sess: &Arc<Session>,
|
||||
id: String,
|
||||
response: RequestUserInputResponse,
|
||||
) {
|
||||
sess.notify_user_input_response(&id, response).await;
|
||||
}
|
||||
|
||||
pub async fn request_permissions_response(
|
||||
sess: &Arc<Session>,
|
||||
id: String,
|
||||
response: RequestPermissionsResponse,
|
||||
) {
|
||||
sess.notify_request_permissions_response(&id, response)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn dynamic_tool_response(
|
||||
sess: &Arc<Session>,
|
||||
id: String,
|
||||
response: DynamicToolResponse,
|
||||
) {
|
||||
sess.notify_dynamic_tool_response(&id, response).await;
|
||||
}
|
||||
|
||||
pub async fn add_to_history(sess: &Arc<Session>, config: &Arc<Config>, text: String) {
|
||||
let id = sess.conversation_id;
|
||||
let config = Arc::clone(config);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = crate::message_history::append_entry(&text, &id, &config).await {
|
||||
warn!("failed to append to message history: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn get_history_entry_request(
|
||||
sess: &Arc<Session>,
|
||||
config: &Arc<Config>,
|
||||
sub_id: String,
|
||||
offset: usize,
|
||||
log_id: u64,
|
||||
) {
|
||||
let config = Arc::clone(config);
|
||||
let sess_clone = Arc::clone(sess);
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Run lookup in blocking thread because it does file IO + locking.
|
||||
let entry_opt = tokio::task::spawn_blocking(move || {
|
||||
crate::message_history::lookup(log_id, offset, &config)
|
||||
})
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::GetHistoryEntryResponse(
|
||||
crate::protocol::GetHistoryEntryResponseEvent {
|
||||
offset,
|
||||
log_id,
|
||||
entry: entry_opt.map(|e| codex_protocol::message_history::HistoryEntry {
|
||||
conversation_id: e.session_id,
|
||||
ts: e.ts,
|
||||
text: e.text,
|
||||
}),
|
||||
},
|
||||
),
|
||||
};
|
||||
|
||||
sess_clone.send_event_raw(event).await;
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn refresh_mcp_servers(sess: &Arc<Session>, refresh_config: McpServerRefreshConfig) {
|
||||
let mut guard = sess.pending_mcp_server_refresh_config.lock().await;
|
||||
*guard = Some(refresh_config);
|
||||
}
|
||||
|
||||
pub async fn reload_user_config(sess: &Arc<Session>) {
|
||||
sess.reload_user_config_layer().await;
|
||||
}
|
||||
|
||||
pub async fn list_mcp_tools(sess: &Session, config: &Arc<Config>, sub_id: String) {
|
||||
let mcp_connection_manager = sess.services.mcp_connection_manager.read().await;
|
||||
let auth = sess.services.auth_manager.auth().await;
|
||||
let mcp_servers = sess
|
||||
.services
|
||||
.mcp_manager
|
||||
.effective_servers(config, auth.as_ref());
|
||||
let snapshot = collect_mcp_snapshot_from_manager(
|
||||
&mcp_connection_manager,
|
||||
compute_auth_statuses(mcp_servers.iter(), config.mcp_oauth_credentials_store_mode)
|
||||
.await,
|
||||
)
|
||||
.await;
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::McpListToolsResponse(snapshot),
|
||||
};
|
||||
sess.send_event_raw(event).await;
|
||||
}
|
||||
|
||||
pub async fn list_custom_prompts(sess: &Session, sub_id: String) {
|
||||
let custom_prompts: Vec<CustomPrompt> =
|
||||
if let Some(dir) = crate::custom_prompts::default_prompts_dir() {
|
||||
crate::custom_prompts::discover_prompts_in(&dir).await
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::ListCustomPromptsResponse(ListCustomPromptsResponseEvent {
|
||||
custom_prompts,
|
||||
}),
|
||||
};
|
||||
sess.send_event_raw(event).await;
|
||||
}
|
||||
|
||||
pub async fn list_skills(
|
||||
sess: &Session,
|
||||
sub_id: String,
|
||||
cwds: Vec<PathBuf>,
|
||||
force_reload: bool,
|
||||
) {
|
||||
let (cwds, session_source) = if cwds.is_empty() {
|
||||
let state = sess.state.lock().await;
|
||||
(
|
||||
vec![state.session_configuration.cwd.clone()],
|
||||
state.session_configuration.session_source.clone(),
|
||||
)
|
||||
} else {
|
||||
let state = sess.state.lock().await;
|
||||
(cwds, state.session_configuration.session_source.clone())
|
||||
};
|
||||
|
||||
let skills_manager = &sess.services.skills_manager;
|
||||
let mut skills = Vec::new();
|
||||
for cwd in cwds {
|
||||
let outcome = crate::skills::filter_skill_load_outcome_for_session_source(
|
||||
skills_manager.skills_for_cwd(&cwd, force_reload).await,
|
||||
&session_source,
|
||||
);
|
||||
let errors = super::errors_to_info(&outcome.errors);
|
||||
let skills_metadata = super::skills_to_info(&outcome.skills, &outcome.disabled_paths);
|
||||
skills.push(SkillsListEntry {
|
||||
cwd,
|
||||
skills: skills_metadata,
|
||||
errors,
|
||||
});
|
||||
}
|
||||
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::ListSkillsResponse(ListSkillsResponseEvent { skills }),
|
||||
};
|
||||
sess.send_event_raw(event).await;
|
||||
}
|
||||
|
||||
pub async fn undo(sess: &Arc<Session>, sub_id: String) {
|
||||
let turn_context = sess.new_default_turn_with_sub_id(sub_id).await;
|
||||
sess.spawn_task(turn_context, Vec::new(), UndoTask::new())
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn compact(sess: &Arc<Session>, sub_id: String) {
|
||||
let turn_context = sess.new_default_turn_with_sub_id(sub_id).await;
|
||||
|
||||
sess.spawn_task(
|
||||
Arc::clone(&turn_context),
|
||||
vec![UserInput::Text {
|
||||
text: turn_context.compact_prompt().to_string(),
|
||||
// Compaction prompt is synthesized; no UI element ranges to preserve.
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
CompactTask,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn drop_memories(sess: &Arc<Session>, config: &Arc<Config>, sub_id: String) {
|
||||
let mut errors = Vec::new();
|
||||
|
||||
if let Some(state_db) = sess.services.state_db.as_deref() {
|
||||
if let Err(err) = state_db.clear_memory_data().await {
|
||||
errors.push(format!("failed clearing memory rows from state db: {err}"));
|
||||
}
|
||||
} else {
|
||||
errors.push("state db unavailable; memory rows were not cleared".to_string());
|
||||
}
|
||||
|
||||
let memory_root = crate::memories::memory_root(&config.codex_home);
|
||||
if let Err(err) = crate::memories::clear_memory_root_contents(&memory_root).await {
|
||||
errors.push(format!(
|
||||
"failed clearing memory directory {}: {err}",
|
||||
memory_root.display()
|
||||
));
|
||||
}
|
||||
|
||||
if errors.is_empty() {
|
||||
sess.send_event_raw(Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::Warning(WarningEvent {
|
||||
message: format!(
|
||||
"Dropped memories at {} and cleared memory rows from state db.",
|
||||
memory_root.display()
|
||||
),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
sess.send_event_raw(Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: format!("Memory drop completed with errors: {}", errors.join("; ")),
|
||||
codex_error_info: Some(CodexErrorInfo::Other),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn update_memories(sess: &Arc<Session>, config: &Arc<Config>, sub_id: String) {
|
||||
let session_source = {
|
||||
let state = sess.state.lock().await;
|
||||
state.session_configuration.session_source.clone()
|
||||
};
|
||||
|
||||
crate::memories::start_memories_startup_task(sess, Arc::clone(config), &session_source);
|
||||
|
||||
sess.send_event_raw(Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::Warning(WarningEvent {
|
||||
message: "Memory update triggered.".to_string(),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn thread_rollback(sess: &Arc<Session>, sub_id: String, num_turns: u32) {
|
||||
if num_turns == 0 {
|
||||
sess.send_event_raw(Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: "num_turns must be >= 1".to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
let has_active_turn = { sess.active_turn.lock().await.is_some() };
|
||||
if has_active_turn {
|
||||
sess.send_event_raw(Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: "Cannot rollback while a turn is in progress.".to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
let turn_context = sess.new_default_turn_with_sub_id(sub_id).await;
|
||||
let rollout_path = {
|
||||
let recorder = {
|
||||
let guard = sess.services.rollout.lock().await;
|
||||
guard.clone()
|
||||
};
|
||||
let Some(recorder) = recorder else {
|
||||
sess.send_event_raw(Event {
|
||||
id: turn_context.sub_id.clone(),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: "thread rollback requires a persisted rollout path".to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
return;
|
||||
};
|
||||
recorder.rollout_path().to_path_buf()
|
||||
};
|
||||
if let Some(recorder) = {
|
||||
let guard = sess.services.rollout.lock().await;
|
||||
guard.clone()
|
||||
} && let Err(err) = recorder.flush().await
|
||||
{
|
||||
sess.send_event_raw(Event {
|
||||
id: turn_context.sub_id.clone(),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: format!(
|
||||
"failed to flush rollout `{}` for rollback replay: {err}",
|
||||
rollout_path.display()
|
||||
),
|
||||
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
let initial_history =
|
||||
match RolloutRecorder::get_rollout_history(rollout_path.as_path()).await {
|
||||
Ok(history) => history,
|
||||
Err(err) => {
|
||||
sess.send_event_raw(Event {
|
||||
id: turn_context.sub_id.clone(),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: format!(
|
||||
"failed to load rollout `{}` for rollback replay: {err}",
|
||||
rollout_path.display()
|
||||
),
|
||||
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let rollback_event = ThreadRolledBackEvent { num_turns };
|
||||
let rollback_msg = EventMsg::ThreadRolledBack(rollback_event.clone());
|
||||
let replay_items = initial_history
|
||||
.get_rollout_items()
|
||||
.into_iter()
|
||||
.chain(std::iter::once(RolloutItem::EventMsg(rollback_msg.clone())))
|
||||
.collect::<Vec<_>>();
|
||||
sess.persist_rollout_items(&[RolloutItem::EventMsg(rollback_msg.clone())])
|
||||
.await;
|
||||
sess.flush_rollout().await;
|
||||
sess.apply_rollout_reconstruction(turn_context.as_ref(), replay_items.as_slice())
|
||||
.await;
|
||||
sess.recompute_token_usage(turn_context.as_ref()).await;
|
||||
|
||||
sess.deliver_event_raw(Event {
|
||||
id: turn_context.sub_id.clone(),
|
||||
msg: rollback_msg,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Persists the thread name in the session index, updates in-memory state, and emits
|
||||
/// a `ThreadNameUpdated` event on success.
|
||||
///
|
||||
/// This appends the name to `CODEX_HOME/sessions_index.jsonl` via `session_index::append_thread_name` for the
|
||||
/// current `thread_id`, then updates `SessionConfiguration::thread_name`.
|
||||
///
|
||||
/// Returns an error event if the name is empty or session persistence is disabled.
|
||||
pub async fn set_thread_name(sess: &Arc<Session>, sub_id: String, name: String) {
|
||||
let Some(name) = crate::util::normalize_thread_name(&name) else {
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: "Thread name cannot be empty.".to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::BadRequest),
|
||||
}),
|
||||
};
|
||||
sess.send_event_raw(event).await;
|
||||
return;
|
||||
};
|
||||
|
||||
let persistence_enabled = {
|
||||
let rollout = sess.services.rollout.lock().await;
|
||||
rollout.is_some()
|
||||
};
|
||||
if !persistence_enabled {
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: "Session persistence is disabled; cannot rename thread.".to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::Other),
|
||||
}),
|
||||
};
|
||||
sess.send_event_raw(event).await;
|
||||
return;
|
||||
};
|
||||
|
||||
let codex_home = sess.codex_home().await;
|
||||
if let Err(e) =
|
||||
session_index::append_thread_name(&codex_home, sess.conversation_id, &name).await
|
||||
{
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: format!("Failed to set thread name: {e}"),
|
||||
codex_error_info: Some(CodexErrorInfo::Other),
|
||||
}),
|
||||
};
|
||||
sess.send_event_raw(event).await;
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
let mut state = sess.state.lock().await;
|
||||
state.session_configuration.thread_name = Some(name.clone());
|
||||
}
|
||||
|
||||
sess.send_event_raw(Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::ThreadNameUpdated(ThreadNameUpdatedEvent {
|
||||
thread_id: sess.conversation_id,
|
||||
thread_name: Some(name),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn shutdown(sess: &Arc<Session>, sub_id: String) -> bool {
|
||||
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
let _ = sess.conversation.shutdown().await;
|
||||
sess.services
|
||||
.unified_exec_manager
|
||||
.terminate_all_processes()
|
||||
.await;
|
||||
sess.guardian_review_session.shutdown().await;
|
||||
info!("Shutting down Codex instance");
|
||||
let history = sess.clone_history().await;
|
||||
let turn_count = history
|
||||
.raw_items()
|
||||
.iter()
|
||||
.filter(|item| is_user_turn_boundary(item))
|
||||
.count();
|
||||
sess.services.session_telemetry.counter(
|
||||
"codex.conversation.turn.count",
|
||||
i64::try_from(turn_count).unwrap_or(0),
|
||||
&[],
|
||||
);
|
||||
|
||||
// Gracefully flush and shutdown rollout recorder on session end so tests
|
||||
// that inspect the rollout file do not race with the background writer.
|
||||
let recorder_opt = {
|
||||
let mut guard = sess.services.rollout.lock().await;
|
||||
guard.take()
|
||||
};
|
||||
if let Some(rec) = recorder_opt
|
||||
&& let Err(e) = rec.shutdown().await
|
||||
{
|
||||
warn!("failed to shutdown rollout recorder: {e}");
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: "Failed to shutdown rollout recorder".to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::Other),
|
||||
}),
|
||||
};
|
||||
sess.send_event_raw(event).await;
|
||||
}
|
||||
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::ShutdownComplete,
|
||||
};
|
||||
sess.send_event_raw(event).await;
|
||||
true
|
||||
}
|
||||
|
||||
pub async fn review(
|
||||
sess: &Arc<Session>,
|
||||
config: &Arc<Config>,
|
||||
sub_id: String,
|
||||
review_request: ReviewRequest,
|
||||
) {
|
||||
let turn_context = sess.new_default_turn_with_sub_id(sub_id.clone()).await;
|
||||
sess.maybe_emit_unknown_model_warning_for_turn(turn_context.as_ref())
|
||||
.await;
|
||||
sess.refresh_mcp_servers_if_requested(&turn_context).await;
|
||||
match resolve_review_request(review_request, turn_context.cwd.as_path()) {
|
||||
Ok(resolved) => {
|
||||
spawn_review_thread(
|
||||
Arc::clone(sess),
|
||||
Arc::clone(config),
|
||||
turn_context.clone(),
|
||||
sub_id,
|
||||
resolved,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Err(err) => {
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: err.to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::Other),
|
||||
}),
|
||||
};
|
||||
sess.send_event(&turn_context, event.msg).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
mod handlers;
|
||||
|
||||
/// Spawn a review thread using the given prompt.
|
||||
async fn spawn_review_thread(
|
||||
|
||||
747
codex-rs/core/src/codex/handlers.rs
Normal file
747
codex-rs/core/src/codex/handlers.rs
Normal file
@@ -0,0 +1,747 @@
|
||||
use crate::codex::Session;
|
||||
use crate::codex::SessionSettingsUpdate;
|
||||
use crate::codex::SteerInputError;
|
||||
|
||||
use crate::codex::spawn_review_thread;
|
||||
use crate::config::Config;
|
||||
|
||||
use crate::mcp::auth::compute_auth_statuses;
|
||||
use crate::mcp::collect_mcp_snapshot_from_manager;
|
||||
use crate::review_prompts::resolve_review_request;
|
||||
use crate::rollout::RolloutRecorder;
|
||||
use crate::rollout::session_index;
|
||||
use crate::tasks::CompactTask;
|
||||
use crate::tasks::UndoTask;
|
||||
use crate::tasks::UserShellCommandMode;
|
||||
use crate::tasks::UserShellCommandTask;
|
||||
use crate::tasks::execute_user_shell_command;
|
||||
use codex_protocol::custom_prompts::CustomPrompt;
|
||||
use codex_protocol::protocol::CodexErrorInfo;
|
||||
use codex_protocol::protocol::ErrorEvent;
|
||||
use codex_protocol::protocol::Event;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
use codex_protocol::protocol::ListCustomPromptsResponseEvent;
|
||||
use codex_protocol::protocol::ListSkillsResponseEvent;
|
||||
use codex_protocol::protocol::McpServerRefreshConfig;
|
||||
use codex_protocol::protocol::Op;
|
||||
use codex_protocol::protocol::ReviewDecision;
|
||||
use codex_protocol::protocol::ReviewRequest;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::SkillsListEntry;
|
||||
use codex_protocol::protocol::ThreadNameUpdatedEvent;
|
||||
use codex_protocol::protocol::ThreadRolledBackEvent;
|
||||
use codex_protocol::protocol::TurnAbortReason;
|
||||
use codex_protocol::protocol::WarningEvent;
|
||||
use codex_protocol::request_permissions::RequestPermissionsResponse;
|
||||
use codex_protocol::request_user_input::RequestUserInputResponse;
|
||||
|
||||
use crate::context_manager::is_user_turn_boundary;
|
||||
use codex_protocol::config_types::CollaborationMode;
|
||||
use codex_protocol::config_types::ModeKind;
|
||||
use codex_protocol::config_types::Settings;
|
||||
use codex_protocol::dynamic_tools::DynamicToolResponse;
|
||||
use codex_protocol::mcp::RequestId as ProtocolRequestId;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use codex_rmcp_client::ElicitationAction;
|
||||
use codex_rmcp_client::ElicitationResponse;
|
||||
use serde_json::Value;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
pub async fn interrupt(sess: &Arc<Session>) {
|
||||
sess.interrupt_task().await;
|
||||
}
|
||||
|
||||
pub async fn clean_background_terminals(sess: &Arc<Session>) {
|
||||
sess.close_unified_exec_processes().await;
|
||||
}
|
||||
|
||||
pub async fn override_turn_context(sess: &Session, sub_id: String, updates: SessionSettingsUpdate) {
|
||||
if let Err(err) = sess.update_settings(updates).await {
|
||||
sess.send_event_raw(Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: err.to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::BadRequest),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn user_input_or_turn(sess: &Arc<Session>, sub_id: String, op: Op) {
|
||||
let (items, updates) = match op {
|
||||
Op::UserTurn {
|
||||
cwd,
|
||||
approval_policy,
|
||||
sandbox_policy,
|
||||
model,
|
||||
effort,
|
||||
summary,
|
||||
service_tier,
|
||||
final_output_json_schema,
|
||||
items,
|
||||
collaboration_mode,
|
||||
personality,
|
||||
} => {
|
||||
let collaboration_mode = collaboration_mode.or_else(|| {
|
||||
Some(CollaborationMode {
|
||||
mode: ModeKind::Default,
|
||||
settings: Settings {
|
||||
model: model.clone(),
|
||||
reasoning_effort: effort,
|
||||
developer_instructions: None,
|
||||
},
|
||||
})
|
||||
});
|
||||
(
|
||||
items,
|
||||
SessionSettingsUpdate {
|
||||
cwd: Some(cwd),
|
||||
approval_policy: Some(approval_policy),
|
||||
approvals_reviewer: None,
|
||||
sandbox_policy: Some(sandbox_policy),
|
||||
windows_sandbox_level: None,
|
||||
collaboration_mode,
|
||||
reasoning_summary: summary,
|
||||
service_tier,
|
||||
final_output_json_schema: Some(final_output_json_schema),
|
||||
personality,
|
||||
app_server_client_name: None,
|
||||
},
|
||||
)
|
||||
}
|
||||
Op::UserInput {
|
||||
items,
|
||||
final_output_json_schema,
|
||||
} => (
|
||||
items,
|
||||
SessionSettingsUpdate {
|
||||
final_output_json_schema: Some(final_output_json_schema),
|
||||
..Default::default()
|
||||
},
|
||||
),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let Ok(current_context) = sess.new_turn_with_sub_id(sub_id, updates).await else {
|
||||
// new_turn_with_sub_id already emits the error event.
|
||||
return;
|
||||
};
|
||||
sess.maybe_emit_unknown_model_warning_for_turn(current_context.as_ref())
|
||||
.await;
|
||||
current_context.session_telemetry.user_prompt(&items);
|
||||
|
||||
// Attempt to inject input into current task.
|
||||
if let Err(SteerInputError::NoActiveTurn(items)) =
|
||||
sess.steer_input(items, /*expected_turn_id*/ None).await
|
||||
{
|
||||
sess.refresh_mcp_servers_if_requested(¤t_context)
|
||||
.await;
|
||||
sess.spawn_task(
|
||||
Arc::clone(¤t_context),
|
||||
items,
|
||||
crate::tasks::RegularTask::new(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_user_shell_command(sess: &Arc<Session>, sub_id: String, command: String) {
|
||||
if let Some((turn_context, cancellation_token)) =
|
||||
sess.active_turn_context_and_cancellation_token().await
|
||||
{
|
||||
let session = Arc::clone(sess);
|
||||
tokio::spawn(async move {
|
||||
execute_user_shell_command(
|
||||
session,
|
||||
turn_context,
|
||||
command,
|
||||
cancellation_token,
|
||||
UserShellCommandMode::ActiveTurnAuxiliary,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
let turn_context = sess.new_default_turn_with_sub_id(sub_id).await;
|
||||
sess.spawn_task(
|
||||
Arc::clone(&turn_context),
|
||||
Vec::new(),
|
||||
UserShellCommandTask::new(command),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn resolve_elicitation(
|
||||
sess: &Arc<Session>,
|
||||
server_name: String,
|
||||
request_id: ProtocolRequestId,
|
||||
decision: codex_protocol::approvals::ElicitationAction,
|
||||
content: Option<Value>,
|
||||
meta: Option<Value>,
|
||||
) {
|
||||
let action = match decision {
|
||||
codex_protocol::approvals::ElicitationAction::Accept => ElicitationAction::Accept,
|
||||
codex_protocol::approvals::ElicitationAction::Decline => ElicitationAction::Decline,
|
||||
codex_protocol::approvals::ElicitationAction::Cancel => ElicitationAction::Cancel,
|
||||
};
|
||||
let content = match action {
|
||||
// Preserve the legacy fallback for clients that only send an action.
|
||||
ElicitationAction::Accept => Some(content.unwrap_or_else(|| serde_json::json!({}))),
|
||||
ElicitationAction::Decline | ElicitationAction::Cancel => None,
|
||||
};
|
||||
let response = ElicitationResponse {
|
||||
action,
|
||||
content,
|
||||
meta,
|
||||
};
|
||||
let request_id = match request_id {
|
||||
ProtocolRequestId::String(value) => {
|
||||
rmcp::model::NumberOrString::String(std::sync::Arc::from(value))
|
||||
}
|
||||
ProtocolRequestId::Integer(value) => rmcp::model::NumberOrString::Number(value),
|
||||
};
|
||||
if let Err(err) = sess
|
||||
.resolve_elicitation(server_name, request_id, response)
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
error = %err,
|
||||
"failed to resolve elicitation request in session"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Propagate a user's exec approval decision to the session.
|
||||
/// Also optionally applies an execpolicy amendment.
|
||||
pub async fn exec_approval(
|
||||
sess: &Arc<Session>,
|
||||
approval_id: String,
|
||||
turn_id: Option<String>,
|
||||
decision: ReviewDecision,
|
||||
) {
|
||||
let event_turn_id = turn_id.unwrap_or_else(|| approval_id.clone());
|
||||
if let ReviewDecision::ApprovedExecpolicyAmendment {
|
||||
proposed_execpolicy_amendment,
|
||||
} = &decision
|
||||
{
|
||||
match sess
|
||||
.persist_execpolicy_amendment(proposed_execpolicy_amendment)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
sess.record_execpolicy_amendment_message(
|
||||
&event_turn_id,
|
||||
proposed_execpolicy_amendment,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Err(err) => {
|
||||
let message = format!("Failed to apply execpolicy amendment: {err}");
|
||||
tracing::warn!("{message}");
|
||||
let warning = EventMsg::Warning(WarningEvent { message });
|
||||
sess.send_event_raw(Event {
|
||||
id: event_turn_id.clone(),
|
||||
msg: warning,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
match decision {
|
||||
ReviewDecision::Abort => {
|
||||
sess.interrupt_task().await;
|
||||
}
|
||||
other => sess.notify_approval(&approval_id, other).await,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn patch_approval(sess: &Arc<Session>, id: String, decision: ReviewDecision) {
|
||||
match decision {
|
||||
ReviewDecision::Abort => {
|
||||
sess.interrupt_task().await;
|
||||
}
|
||||
other => sess.notify_approval(&id, other).await,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn request_user_input_response(
|
||||
sess: &Arc<Session>,
|
||||
id: String,
|
||||
response: RequestUserInputResponse,
|
||||
) {
|
||||
sess.notify_user_input_response(&id, response).await;
|
||||
}
|
||||
|
||||
pub async fn request_permissions_response(
|
||||
sess: &Arc<Session>,
|
||||
id: String,
|
||||
response: RequestPermissionsResponse,
|
||||
) {
|
||||
sess.notify_request_permissions_response(&id, response)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn dynamic_tool_response(sess: &Arc<Session>, id: String, response: DynamicToolResponse) {
|
||||
sess.notify_dynamic_tool_response(&id, response).await;
|
||||
}
|
||||
|
||||
pub async fn add_to_history(sess: &Arc<Session>, config: &Arc<Config>, text: String) {
|
||||
let id = sess.conversation_id;
|
||||
let config = Arc::clone(config);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = crate::message_history::append_entry(&text, &id, &config).await {
|
||||
warn!("failed to append to message history: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn get_history_entry_request(
|
||||
sess: &Arc<Session>,
|
||||
config: &Arc<Config>,
|
||||
sub_id: String,
|
||||
offset: usize,
|
||||
log_id: u64,
|
||||
) {
|
||||
let config = Arc::clone(config);
|
||||
let sess_clone = Arc::clone(sess);
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Run lookup in blocking thread because it does file IO + locking.
|
||||
let entry_opt = tokio::task::spawn_blocking(move || {
|
||||
crate::message_history::lookup(log_id, offset, &config)
|
||||
})
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::GetHistoryEntryResponse(crate::protocol::GetHistoryEntryResponseEvent {
|
||||
offset,
|
||||
log_id,
|
||||
entry: entry_opt.map(|e| codex_protocol::message_history::HistoryEntry {
|
||||
conversation_id: e.session_id,
|
||||
ts: e.ts,
|
||||
text: e.text,
|
||||
}),
|
||||
}),
|
||||
};
|
||||
|
||||
sess_clone.send_event_raw(event).await;
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn refresh_mcp_servers(sess: &Arc<Session>, refresh_config: McpServerRefreshConfig) {
|
||||
let mut guard = sess.pending_mcp_server_refresh_config.lock().await;
|
||||
*guard = Some(refresh_config);
|
||||
}
|
||||
|
||||
pub async fn reload_user_config(sess: &Arc<Session>) {
|
||||
sess.reload_user_config_layer().await;
|
||||
}
|
||||
|
||||
pub async fn list_mcp_tools(sess: &Session, config: &Arc<Config>, sub_id: String) {
|
||||
let mcp_connection_manager = sess.services.mcp_connection_manager.read().await;
|
||||
let auth = sess.services.auth_manager.auth().await;
|
||||
let mcp_servers = sess
|
||||
.services
|
||||
.mcp_manager
|
||||
.effective_servers(config, auth.as_ref());
|
||||
let snapshot = collect_mcp_snapshot_from_manager(
|
||||
&mcp_connection_manager,
|
||||
compute_auth_statuses(mcp_servers.iter(), config.mcp_oauth_credentials_store_mode).await,
|
||||
)
|
||||
.await;
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::McpListToolsResponse(snapshot),
|
||||
};
|
||||
sess.send_event_raw(event).await;
|
||||
}
|
||||
|
||||
pub async fn list_custom_prompts(sess: &Session, sub_id: String) {
|
||||
let custom_prompts: Vec<CustomPrompt> =
|
||||
if let Some(dir) = crate::custom_prompts::default_prompts_dir() {
|
||||
crate::custom_prompts::discover_prompts_in(&dir).await
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::ListCustomPromptsResponse(ListCustomPromptsResponseEvent { custom_prompts }),
|
||||
};
|
||||
sess.send_event_raw(event).await;
|
||||
}
|
||||
|
||||
pub async fn list_skills(sess: &Session, sub_id: String, cwds: Vec<PathBuf>, force_reload: bool) {
|
||||
let (cwds, session_source) = if cwds.is_empty() {
|
||||
let state = sess.state.lock().await;
|
||||
(
|
||||
vec![state.session_configuration.cwd.clone()],
|
||||
state.session_configuration.session_source.clone(),
|
||||
)
|
||||
} else {
|
||||
let state = sess.state.lock().await;
|
||||
(cwds, state.session_configuration.session_source.clone())
|
||||
};
|
||||
|
||||
let skills_manager = &sess.services.skills_manager;
|
||||
let mut skills = Vec::new();
|
||||
for cwd in cwds {
|
||||
let outcome = crate::skills::filter_skill_load_outcome_for_session_source(
|
||||
skills_manager.skills_for_cwd(&cwd, force_reload).await,
|
||||
&session_source,
|
||||
);
|
||||
let errors = super::errors_to_info(&outcome.errors);
|
||||
let skills_metadata = super::skills_to_info(&outcome.skills, &outcome.disabled_paths);
|
||||
skills.push(SkillsListEntry {
|
||||
cwd,
|
||||
skills: skills_metadata,
|
||||
errors,
|
||||
});
|
||||
}
|
||||
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::ListSkillsResponse(ListSkillsResponseEvent { skills }),
|
||||
};
|
||||
sess.send_event_raw(event).await;
|
||||
}
|
||||
|
||||
pub async fn undo(sess: &Arc<Session>, sub_id: String) {
|
||||
let turn_context = sess.new_default_turn_with_sub_id(sub_id).await;
|
||||
sess.spawn_task(turn_context, Vec::new(), UndoTask::new())
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn compact(sess: &Arc<Session>, sub_id: String) {
|
||||
let turn_context = sess.new_default_turn_with_sub_id(sub_id).await;
|
||||
|
||||
sess.spawn_task(
|
||||
Arc::clone(&turn_context),
|
||||
vec![UserInput::Text {
|
||||
text: turn_context.compact_prompt().to_string(),
|
||||
// Compaction prompt is synthesized; no UI element ranges to preserve.
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
CompactTask,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn drop_memories(sess: &Arc<Session>, config: &Arc<Config>, sub_id: String) {
|
||||
let mut errors = Vec::new();
|
||||
|
||||
if let Some(state_db) = sess.services.state_db.as_deref() {
|
||||
if let Err(err) = state_db.clear_memory_data().await {
|
||||
errors.push(format!("failed clearing memory rows from state db: {err}"));
|
||||
}
|
||||
} else {
|
||||
errors.push("state db unavailable; memory rows were not cleared".to_string());
|
||||
}
|
||||
|
||||
let memory_root = crate::memories::memory_root(&config.codex_home);
|
||||
if let Err(err) = crate::memories::clear_memory_root_contents(&memory_root).await {
|
||||
errors.push(format!(
|
||||
"failed clearing memory directory {}: {err}",
|
||||
memory_root.display()
|
||||
));
|
||||
}
|
||||
|
||||
if errors.is_empty() {
|
||||
sess.send_event_raw(Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::Warning(WarningEvent {
|
||||
message: format!(
|
||||
"Dropped memories at {} and cleared memory rows from state db.",
|
||||
memory_root.display()
|
||||
),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
sess.send_event_raw(Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: format!("Memory drop completed with errors: {}", errors.join("; ")),
|
||||
codex_error_info: Some(CodexErrorInfo::Other),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn update_memories(sess: &Arc<Session>, config: &Arc<Config>, sub_id: String) {
|
||||
let session_source = {
|
||||
let state = sess.state.lock().await;
|
||||
state.session_configuration.session_source.clone()
|
||||
};
|
||||
|
||||
crate::memories::start_memories_startup_task(sess, Arc::clone(config), &session_source);
|
||||
|
||||
sess.send_event_raw(Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::Warning(WarningEvent {
|
||||
message: "Memory update triggered.".to_string(),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn thread_rollback(sess: &Arc<Session>, sub_id: String, num_turns: u32) {
|
||||
if num_turns == 0 {
|
||||
sess.send_event_raw(Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: "num_turns must be >= 1".to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
let has_active_turn = { sess.active_turn.lock().await.is_some() };
|
||||
if has_active_turn {
|
||||
sess.send_event_raw(Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: "Cannot rollback while a turn is in progress.".to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
let turn_context = sess.new_default_turn_with_sub_id(sub_id).await;
|
||||
let rollout_path = {
|
||||
let recorder = {
|
||||
let guard = sess.services.rollout.lock().await;
|
||||
guard.clone()
|
||||
};
|
||||
let Some(recorder) = recorder else {
|
||||
sess.send_event_raw(Event {
|
||||
id: turn_context.sub_id.clone(),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: "thread rollback requires a persisted rollout path".to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
return;
|
||||
};
|
||||
recorder.rollout_path().to_path_buf()
|
||||
};
|
||||
if let Some(recorder) = {
|
||||
let guard = sess.services.rollout.lock().await;
|
||||
guard.clone()
|
||||
} && let Err(err) = recorder.flush().await
|
||||
{
|
||||
sess.send_event_raw(Event {
|
||||
id: turn_context.sub_id.clone(),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: format!(
|
||||
"failed to flush rollout `{}` for rollback replay: {err}",
|
||||
rollout_path.display()
|
||||
),
|
||||
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
let initial_history = match RolloutRecorder::get_rollout_history(rollout_path.as_path()).await {
|
||||
Ok(history) => history,
|
||||
Err(err) => {
|
||||
sess.send_event_raw(Event {
|
||||
id: turn_context.sub_id.clone(),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: format!(
|
||||
"failed to load rollout `{}` for rollback replay: {err}",
|
||||
rollout_path.display()
|
||||
),
|
||||
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let rollback_event = ThreadRolledBackEvent { num_turns };
|
||||
let rollback_msg = EventMsg::ThreadRolledBack(rollback_event.clone());
|
||||
let replay_items = initial_history
|
||||
.get_rollout_items()
|
||||
.into_iter()
|
||||
.chain(std::iter::once(RolloutItem::EventMsg(rollback_msg.clone())))
|
||||
.collect::<Vec<_>>();
|
||||
sess.persist_rollout_items(&[RolloutItem::EventMsg(rollback_msg.clone())])
|
||||
.await;
|
||||
sess.flush_rollout().await;
|
||||
sess.apply_rollout_reconstruction(turn_context.as_ref(), replay_items.as_slice())
|
||||
.await;
|
||||
sess.recompute_token_usage(turn_context.as_ref()).await;
|
||||
|
||||
sess.deliver_event_raw(Event {
|
||||
id: turn_context.sub_id.clone(),
|
||||
msg: rollback_msg,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Persists the thread name in the session index, updates in-memory state, and emits
|
||||
/// a `ThreadNameUpdated` event on success.
|
||||
///
|
||||
/// This appends the name to `CODEX_HOME/sessions_index.jsonl` via `session_index::append_thread_name` for the
|
||||
/// current `thread_id`, then updates `SessionConfiguration::thread_name`.
|
||||
///
|
||||
/// Returns an error event if the name is empty or session persistence is disabled.
|
||||
pub async fn set_thread_name(sess: &Arc<Session>, sub_id: String, name: String) {
|
||||
let Some(name) = crate::util::normalize_thread_name(&name) else {
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: "Thread name cannot be empty.".to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::BadRequest),
|
||||
}),
|
||||
};
|
||||
sess.send_event_raw(event).await;
|
||||
return;
|
||||
};
|
||||
|
||||
let persistence_enabled = {
|
||||
let rollout = sess.services.rollout.lock().await;
|
||||
rollout.is_some()
|
||||
};
|
||||
if !persistence_enabled {
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: "Session persistence is disabled; cannot rename thread.".to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::Other),
|
||||
}),
|
||||
};
|
||||
sess.send_event_raw(event).await;
|
||||
return;
|
||||
};
|
||||
|
||||
let codex_home = sess.codex_home().await;
|
||||
if let Err(e) =
|
||||
session_index::append_thread_name(&codex_home, sess.conversation_id, &name).await
|
||||
{
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: format!("Failed to set thread name: {e}"),
|
||||
codex_error_info: Some(CodexErrorInfo::Other),
|
||||
}),
|
||||
};
|
||||
sess.send_event_raw(event).await;
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
let mut state = sess.state.lock().await;
|
||||
state.session_configuration.thread_name = Some(name.clone());
|
||||
}
|
||||
|
||||
sess.send_event_raw(Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::ThreadNameUpdated(ThreadNameUpdatedEvent {
|
||||
thread_id: sess.conversation_id,
|
||||
thread_name: Some(name),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn shutdown(sess: &Arc<Session>, sub_id: String) -> bool {
|
||||
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
let _ = sess.conversation.shutdown().await;
|
||||
sess.services
|
||||
.unified_exec_manager
|
||||
.terminate_all_processes()
|
||||
.await;
|
||||
sess.guardian_review_session.shutdown().await;
|
||||
info!("Shutting down Codex instance");
|
||||
let history = sess.clone_history().await;
|
||||
let turn_count = history
|
||||
.raw_items()
|
||||
.iter()
|
||||
.filter(|item| is_user_turn_boundary(item))
|
||||
.count();
|
||||
sess.services.session_telemetry.counter(
|
||||
"codex.conversation.turn.count",
|
||||
i64::try_from(turn_count).unwrap_or(0),
|
||||
&[],
|
||||
);
|
||||
|
||||
// Gracefully flush and shutdown rollout recorder on session end so tests
|
||||
// that inspect the rollout file do not race with the background writer.
|
||||
let recorder_opt = {
|
||||
let mut guard = sess.services.rollout.lock().await;
|
||||
guard.take()
|
||||
};
|
||||
if let Some(rec) = recorder_opt
|
||||
&& let Err(e) = rec.shutdown().await
|
||||
{
|
||||
warn!("failed to shutdown rollout recorder: {e}");
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: "Failed to shutdown rollout recorder".to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::Other),
|
||||
}),
|
||||
};
|
||||
sess.send_event_raw(event).await;
|
||||
}
|
||||
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::ShutdownComplete,
|
||||
};
|
||||
sess.send_event_raw(event).await;
|
||||
true
|
||||
}
|
||||
|
||||
pub async fn review(
|
||||
sess: &Arc<Session>,
|
||||
config: &Arc<Config>,
|
||||
sub_id: String,
|
||||
review_request: ReviewRequest,
|
||||
) {
|
||||
let turn_context = sess.new_default_turn_with_sub_id(sub_id.clone()).await;
|
||||
sess.maybe_emit_unknown_model_warning_for_turn(turn_context.as_ref())
|
||||
.await;
|
||||
sess.refresh_mcp_servers_if_requested(&turn_context).await;
|
||||
match resolve_review_request(review_request, turn_context.cwd.as_path()) {
|
||||
Ok(resolved) => {
|
||||
spawn_review_thread(
|
||||
Arc::clone(sess),
|
||||
Arc::clone(config),
|
||||
turn_context.clone(),
|
||||
sub_id,
|
||||
resolved,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Err(err) => {
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: err.to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::Other),
|
||||
}),
|
||||
};
|
||||
sess.send_event(&turn_context, event.msg).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2465,7 +2465,7 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) {
|
||||
true,
|
||||
));
|
||||
let network_approval = Arc::new(NetworkApprovalService::default());
|
||||
let environment = Arc::new(codex_environment::Environment);
|
||||
let environment = Arc::new(codex_environment::Environment::default());
|
||||
|
||||
let file_watcher = Arc::new(FileWatcher::noop());
|
||||
let services = SessionServices {
|
||||
@@ -3259,7 +3259,7 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx(
|
||||
true,
|
||||
));
|
||||
let network_approval = Arc::new(NetworkApprovalService::default());
|
||||
let environment = Arc::new(codex_environment::Environment);
|
||||
let environment = Arc::new(codex_environment::Environment::default());
|
||||
|
||||
let file_watcher = Arc::new(FileWatcher::noop());
|
||||
let services = SessionServices {
|
||||
|
||||
@@ -4328,6 +4328,8 @@ fn test_precedence_fixture_with_o3_profile() -> std::io::Result<()> {
|
||||
web_search_mode: Constrained::allow_any(WebSearchMode::Cached),
|
||||
web_search_config: None,
|
||||
use_experimental_unified_exec_tool: !cfg!(windows),
|
||||
experimental_unified_exec_exec_server_websocket_url: None,
|
||||
experimental_unified_exec_exec_server_workspace_root: None,
|
||||
background_terminal_max_timeout: DEFAULT_MAX_BACKGROUND_TERMINAL_TIMEOUT_MS,
|
||||
ghost_snapshot: GhostSnapshotConfig::default(),
|
||||
features: Features::with_defaults().into(),
|
||||
@@ -4469,6 +4471,8 @@ fn test_precedence_fixture_with_gpt3_profile() -> std::io::Result<()> {
|
||||
web_search_mode: Constrained::allow_any(WebSearchMode::Cached),
|
||||
web_search_config: None,
|
||||
use_experimental_unified_exec_tool: !cfg!(windows),
|
||||
experimental_unified_exec_exec_server_websocket_url: None,
|
||||
experimental_unified_exec_exec_server_workspace_root: None,
|
||||
background_terminal_max_timeout: DEFAULT_MAX_BACKGROUND_TERMINAL_TIMEOUT_MS,
|
||||
ghost_snapshot: GhostSnapshotConfig::default(),
|
||||
features: Features::with_defaults().into(),
|
||||
@@ -4608,6 +4612,8 @@ fn test_precedence_fixture_with_zdr_profile() -> std::io::Result<()> {
|
||||
web_search_mode: Constrained::allow_any(WebSearchMode::Cached),
|
||||
web_search_config: None,
|
||||
use_experimental_unified_exec_tool: !cfg!(windows),
|
||||
experimental_unified_exec_exec_server_websocket_url: None,
|
||||
experimental_unified_exec_exec_server_workspace_root: None,
|
||||
background_terminal_max_timeout: DEFAULT_MAX_BACKGROUND_TERMINAL_TIMEOUT_MS,
|
||||
ghost_snapshot: GhostSnapshotConfig::default(),
|
||||
features: Features::with_defaults().into(),
|
||||
@@ -4733,6 +4739,8 @@ fn test_precedence_fixture_with_gpt5_profile() -> std::io::Result<()> {
|
||||
web_search_mode: Constrained::allow_any(WebSearchMode::Cached),
|
||||
web_search_config: None,
|
||||
use_experimental_unified_exec_tool: !cfg!(windows),
|
||||
experimental_unified_exec_exec_server_websocket_url: None,
|
||||
experimental_unified_exec_exec_server_workspace_root: None,
|
||||
background_terminal_max_timeout: DEFAULT_MAX_BACKGROUND_TERMINAL_TIMEOUT_MS,
|
||||
ghost_snapshot: GhostSnapshotConfig::default(),
|
||||
features: Features::with_defaults().into(),
|
||||
|
||||
@@ -539,6 +539,14 @@ pub struct Config {
|
||||
/// If set to `true`, used only the experimental unified exec tool.
|
||||
pub use_experimental_unified_exec_tool: bool,
|
||||
|
||||
/// When set, connect unified-exec process launches to an existing remote
|
||||
/// `codex-exec-server` websocket endpoint.
|
||||
pub experimental_unified_exec_exec_server_websocket_url: Option<String>,
|
||||
|
||||
/// When set, map local session cwd to this executor-visible workspace root
|
||||
/// for remote unified-exec process launches.
|
||||
pub experimental_unified_exec_exec_server_workspace_root: Option<AbsolutePathBuf>,
|
||||
|
||||
/// Maximum poll window for background terminal output (`write_stdin`), in milliseconds.
|
||||
/// Default: `300000` (5 minutes).
|
||||
pub background_terminal_max_timeout: u64,
|
||||
@@ -1325,6 +1333,13 @@ pub struct ConfigToml {
|
||||
/// Default: `300000` (5 minutes).
|
||||
pub background_terminal_max_timeout: Option<u64>,
|
||||
|
||||
/// Optional websocket URL for connecting to an existing `codex-exec-server`.
|
||||
pub experimental_unified_exec_exec_server_websocket_url: Option<String>,
|
||||
|
||||
/// Optional absolute path to the executor-visible workspace root that
|
||||
/// corresponds to the local session cwd.
|
||||
pub experimental_unified_exec_exec_server_workspace_root: Option<AbsolutePathBuf>,
|
||||
|
||||
/// Optional absolute path to the Node runtime used by `js_repl`.
|
||||
pub js_repl_node_path: Option<AbsolutePathBuf>,
|
||||
|
||||
@@ -2474,6 +2489,20 @@ impl Config {
|
||||
|
||||
let include_apply_patch_tool_flag = features.enabled(Feature::ApplyPatchFreeform);
|
||||
let use_experimental_unified_exec_tool = features.enabled(Feature::UnifiedExec);
|
||||
let experimental_unified_exec_exec_server_websocket_url = cfg
|
||||
.experimental_unified_exec_exec_server_websocket_url
|
||||
.clone()
|
||||
.and_then(|value: String| {
|
||||
let trimmed = value.trim();
|
||||
if trimmed.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(trimmed.to_string())
|
||||
}
|
||||
});
|
||||
let experimental_unified_exec_exec_server_workspace_root = cfg
|
||||
.experimental_unified_exec_exec_server_workspace_root
|
||||
.clone();
|
||||
|
||||
let forced_chatgpt_workspace_id =
|
||||
cfg.forced_chatgpt_workspace_id.as_ref().and_then(|value| {
|
||||
@@ -2768,6 +2797,8 @@ impl Config {
|
||||
web_search_mode: constrained_web_search_mode.value,
|
||||
web_search_config,
|
||||
use_experimental_unified_exec_tool,
|
||||
experimental_unified_exec_exec_server_websocket_url,
|
||||
experimental_unified_exec_exec_server_workspace_root,
|
||||
background_terminal_max_timeout,
|
||||
ghost_snapshot,
|
||||
features,
|
||||
|
||||
@@ -217,6 +217,7 @@ pub async fn list_accessible_connectors_from_mcp_tools_with_options_and_status(
|
||||
};
|
||||
|
||||
let (mcp_connection_manager, cancel_token) = McpConnectionManager::new(
|
||||
String::new(),
|
||||
&mcp_servers,
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
auth_status_entries,
|
||||
|
||||
@@ -14,6 +14,8 @@ pub mod auth;
|
||||
mod auth_env_telemetry;
|
||||
mod client;
|
||||
mod client_common;
|
||||
#[cfg(test)]
|
||||
mod client_tools;
|
||||
pub mod codex;
|
||||
mod realtime_context;
|
||||
mod realtime_conversation;
|
||||
@@ -55,9 +57,9 @@ mod network_policy_decision;
|
||||
pub mod network_proxy_loader;
|
||||
mod original_image_detail;
|
||||
mod packages;
|
||||
pub use mcp::types::SandboxState;
|
||||
pub use mcp_connection_manager::MCP_SANDBOX_STATE_CAPABILITY;
|
||||
pub use mcp_connection_manager::MCP_SANDBOX_STATE_METHOD;
|
||||
pub use mcp_connection_manager::SandboxState;
|
||||
pub use text_encoding::bytes_to_string_smart;
|
||||
mod mcp_tool_call;
|
||||
mod memories;
|
||||
|
||||
117
codex-rs/core/src/mcp/config.rs
Normal file
117
codex-rs/core/src/mcp/config.rs
Normal file
@@ -0,0 +1,117 @@
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::CodexAuth;
|
||||
use crate::config::Config;
|
||||
use crate::config::types::McpServerConfig;
|
||||
use crate::config::types::McpServerTransportConfig;
|
||||
|
||||
use super::CODEX_APPS_MCP_SERVER_NAME;
|
||||
|
||||
pub(crate) const CODEX_CONNECTORS_TOKEN_ENV_VAR: &str = "CODEX_CONNECTORS_TOKEN";
|
||||
|
||||
fn codex_apps_mcp_bearer_token_env_var() -> Option<String> {
|
||||
match env::var(CODEX_CONNECTORS_TOKEN_ENV_VAR) {
|
||||
Ok(value) if !value.trim().is_empty() => Some(CODEX_CONNECTORS_TOKEN_ENV_VAR.to_string()),
|
||||
Ok(_) => None,
|
||||
Err(env::VarError::NotPresent) => None,
|
||||
Err(env::VarError::NotUnicode(_)) => Some(CODEX_CONNECTORS_TOKEN_ENV_VAR.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn codex_apps_mcp_bearer_token(auth: Option<&CodexAuth>) -> Option<String> {
|
||||
let token = auth.and_then(|auth| auth.get_token().ok())?;
|
||||
let token = token.trim();
|
||||
if token.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(token.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
fn codex_apps_mcp_http_headers(auth: Option<&CodexAuth>) -> Option<HashMap<String, String>> {
|
||||
let mut headers = HashMap::new();
|
||||
if let Some(token) = codex_apps_mcp_bearer_token(auth) {
|
||||
headers.insert("Authorization".to_string(), format!("Bearer {token}"));
|
||||
}
|
||||
if let Some(account_id) = auth.and_then(CodexAuth::get_account_id) {
|
||||
headers.insert("ChatGPT-Account-ID".to_string(), account_id);
|
||||
}
|
||||
if headers.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(headers)
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_codex_apps_base_url(base_url: &str) -> String {
|
||||
let mut base_url = base_url.trim_end_matches('/').to_string();
|
||||
if (base_url.starts_with("https://chatgpt.com")
|
||||
|| base_url.starts_with("https://chat.openai.com"))
|
||||
&& !base_url.contains("/backend-api")
|
||||
{
|
||||
base_url = format!("{base_url}/backend-api");
|
||||
}
|
||||
base_url
|
||||
}
|
||||
|
||||
pub(crate) fn codex_apps_mcp_url_for_base_url(base_url: &str) -> String {
|
||||
let base_url = normalize_codex_apps_base_url(base_url);
|
||||
if base_url.contains("/backend-api") {
|
||||
format!("{base_url}/wham/apps")
|
||||
} else if base_url.contains("/api/codex") {
|
||||
format!("{base_url}/apps")
|
||||
} else {
|
||||
format!("{base_url}/api/codex/apps")
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn codex_apps_mcp_url(config: &Config) -> String {
|
||||
codex_apps_mcp_url_for_base_url(&config.chatgpt_base_url)
|
||||
}
|
||||
|
||||
fn codex_apps_mcp_server_config(config: &Config, auth: Option<&CodexAuth>) -> McpServerConfig {
|
||||
let bearer_token_env_var = codex_apps_mcp_bearer_token_env_var();
|
||||
let http_headers = if bearer_token_env_var.is_some() {
|
||||
None
|
||||
} else {
|
||||
codex_apps_mcp_http_headers(auth)
|
||||
};
|
||||
let url = codex_apps_mcp_url(config);
|
||||
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
http_headers,
|
||||
env_http_headers: None,
|
||||
},
|
||||
enabled: true,
|
||||
required: false,
|
||||
disabled_reason: None,
|
||||
startup_timeout_sec: Some(Duration::from_secs(30)),
|
||||
tool_timeout_sec: None,
|
||||
enabled_tools: None,
|
||||
disabled_tools: None,
|
||||
scopes: None,
|
||||
oauth_resource: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn with_codex_apps_mcp(
|
||||
mut servers: HashMap<String, McpServerConfig>,
|
||||
connectors_enabled: bool,
|
||||
auth: Option<&CodexAuth>,
|
||||
config: &Config,
|
||||
) -> HashMap<String, McpServerConfig> {
|
||||
if connectors_enabled {
|
||||
servers.insert(
|
||||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
codex_apps_mcp_server_config(config, auth),
|
||||
);
|
||||
} else {
|
||||
servers.remove(CODEX_APPS_MCP_SERVER_NAME);
|
||||
}
|
||||
servers
|
||||
}
|
||||
62
codex-rs/core/src/mcp/manager.rs
Normal file
62
codex-rs/core/src/mcp/manager.rs
Normal file
@@ -0,0 +1,62 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::CodexAuth;
|
||||
use crate::config::Config;
|
||||
use crate::config::types::McpServerConfig;
|
||||
use crate::plugins::PluginsManager;
|
||||
|
||||
use crate::mcp::ToolPluginProvenance;
|
||||
|
||||
pub struct McpManager {
|
||||
plugins_manager: Arc<PluginsManager>,
|
||||
}
|
||||
|
||||
impl McpManager {
|
||||
pub fn new(plugins_manager: Arc<PluginsManager>) -> Self {
|
||||
Self { plugins_manager }
|
||||
}
|
||||
|
||||
pub fn configured_servers(&self, config: &Config) -> HashMap<String, McpServerConfig> {
|
||||
configured_mcp_servers(config, self.plugins_manager.as_ref())
|
||||
}
|
||||
|
||||
pub fn effective_servers(
|
||||
&self,
|
||||
config: &Config,
|
||||
auth: Option<&CodexAuth>,
|
||||
) -> HashMap<String, McpServerConfig> {
|
||||
effective_mcp_servers(config, auth, self.plugins_manager.as_ref())
|
||||
}
|
||||
|
||||
pub fn tool_plugin_provenance(&self, config: &Config) -> ToolPluginProvenance {
|
||||
let loaded_plugins = self.plugins_manager.plugins_for_config(config);
|
||||
ToolPluginProvenance::from_capability_summaries(loaded_plugins.capability_summaries())
|
||||
}
|
||||
}
|
||||
|
||||
fn configured_mcp_servers(
|
||||
config: &Config,
|
||||
plugins_manager: &PluginsManager,
|
||||
) -> HashMap<String, McpServerConfig> {
|
||||
let loaded_plugins = plugins_manager.plugins_for_config(config);
|
||||
let mut servers = config.mcp_servers.get().clone();
|
||||
for (name, plugin_server) in loaded_plugins.effective_mcp_servers() {
|
||||
servers.entry(name).or_insert(plugin_server);
|
||||
}
|
||||
servers
|
||||
}
|
||||
|
||||
fn effective_mcp_servers(
|
||||
config: &Config,
|
||||
auth: Option<&CodexAuth>,
|
||||
plugins_manager: &PluginsManager,
|
||||
) -> HashMap<String, McpServerConfig> {
|
||||
let servers = configured_mcp_servers(config, plugins_manager);
|
||||
crate::mcp::config::with_codex_apps_mcp(
|
||||
servers,
|
||||
config.features.apps_enabled_for_auth(auth),
|
||||
auth,
|
||||
config,
|
||||
)
|
||||
}
|
||||
@@ -1,37 +1,21 @@
|
||||
pub mod auth;
|
||||
pub mod config;
|
||||
pub mod manager;
|
||||
mod skill_dependencies;
|
||||
pub mod snapshot;
|
||||
pub mod types;
|
||||
|
||||
pub(crate) use config::with_codex_apps_mcp;
|
||||
pub(crate) use manager::McpManager;
|
||||
pub(crate) use skill_dependencies::maybe_prompt_and_install_mcp_dependencies;
|
||||
pub(crate) use snapshot::collect_mcp_snapshot_from_manager;
|
||||
pub(crate) use types::split_qualified_tool_name;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_channel::unbounded;
|
||||
use codex_protocol::mcp::Resource;
|
||||
use codex_protocol::mcp::ResourceTemplate;
|
||||
use codex_protocol::mcp::Tool;
|
||||
use codex_protocol::protocol::McpListToolsResponseEvent;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::AuthManager;
|
||||
use crate::CodexAuth;
|
||||
use crate::config::Config;
|
||||
use crate::config::types::McpServerConfig;
|
||||
use crate::config::types::McpServerTransportConfig;
|
||||
use crate::mcp::auth::compute_auth_statuses;
|
||||
use crate::mcp_connection_manager::McpConnectionManager;
|
||||
use crate::mcp_connection_manager::SandboxState;
|
||||
use crate::mcp_connection_manager::codex_apps_tools_cache_key;
|
||||
use crate::plugins::PluginCapabilitySummary;
|
||||
use crate::plugins::PluginsManager;
|
||||
|
||||
const MCP_TOOL_NAME_PREFIX: &str = "mcp";
|
||||
const MCP_TOOL_NAME_DELIMITER: &str = "__";
|
||||
pub(crate) const CODEX_APPS_MCP_SERVER_NAME: &str = "codex_apps";
|
||||
const CODEX_CONNECTORS_TOKEN_ENV_VAR: &str = "CODEX_CONNECTORS_TOKEN";
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub struct ToolPluginProvenance {
|
||||
@@ -54,7 +38,9 @@ impl ToolPluginProvenance {
|
||||
.unwrap_or(&[])
|
||||
}
|
||||
|
||||
fn from_capability_summaries(capability_summaries: &[PluginCapabilitySummary]) -> Self {
|
||||
pub(crate) fn from_capability_summaries(
|
||||
capability_summaries: &[PluginCapabilitySummary],
|
||||
) -> Self {
|
||||
let mut tool_plugin_provenance = Self::default();
|
||||
for plugin in capability_summaries {
|
||||
for connector_id in &plugin.app_connector_ids {
|
||||
@@ -91,358 +77,6 @@ impl ToolPluginProvenance {
|
||||
}
|
||||
}
|
||||
|
||||
fn codex_apps_mcp_bearer_token_env_var() -> Option<String> {
|
||||
match env::var(CODEX_CONNECTORS_TOKEN_ENV_VAR) {
|
||||
Ok(value) if !value.trim().is_empty() => Some(CODEX_CONNECTORS_TOKEN_ENV_VAR.to_string()),
|
||||
Ok(_) => None,
|
||||
Err(env::VarError::NotPresent) => None,
|
||||
Err(env::VarError::NotUnicode(_)) => Some(CODEX_CONNECTORS_TOKEN_ENV_VAR.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn codex_apps_mcp_bearer_token(auth: Option<&CodexAuth>) -> Option<String> {
|
||||
let token = auth.and_then(|auth| auth.get_token().ok())?;
|
||||
let token = token.trim();
|
||||
if token.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(token.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
fn codex_apps_mcp_http_headers(auth: Option<&CodexAuth>) -> Option<HashMap<String, String>> {
|
||||
let mut headers = HashMap::new();
|
||||
if let Some(token) = codex_apps_mcp_bearer_token(auth) {
|
||||
headers.insert("Authorization".to_string(), format!("Bearer {token}"));
|
||||
}
|
||||
if let Some(account_id) = auth.and_then(CodexAuth::get_account_id) {
|
||||
headers.insert("ChatGPT-Account-ID".to_string(), account_id);
|
||||
}
|
||||
if headers.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(headers)
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_codex_apps_base_url(base_url: &str) -> String {
|
||||
let mut base_url = base_url.trim_end_matches('/').to_string();
|
||||
if (base_url.starts_with("https://chatgpt.com")
|
||||
|| base_url.starts_with("https://chat.openai.com"))
|
||||
&& !base_url.contains("/backend-api")
|
||||
{
|
||||
base_url = format!("{base_url}/backend-api");
|
||||
}
|
||||
base_url
|
||||
}
|
||||
|
||||
fn codex_apps_mcp_url_for_base_url(base_url: &str) -> String {
|
||||
let base_url = normalize_codex_apps_base_url(base_url);
|
||||
if base_url.contains("/backend-api") {
|
||||
format!("{base_url}/wham/apps")
|
||||
} else if base_url.contains("/api/codex") {
|
||||
format!("{base_url}/apps")
|
||||
} else {
|
||||
format!("{base_url}/api/codex/apps")
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn codex_apps_mcp_url(config: &Config) -> String {
|
||||
codex_apps_mcp_url_for_base_url(&config.chatgpt_base_url)
|
||||
}
|
||||
|
||||
fn codex_apps_mcp_server_config(config: &Config, auth: Option<&CodexAuth>) -> McpServerConfig {
|
||||
let bearer_token_env_var = codex_apps_mcp_bearer_token_env_var();
|
||||
let http_headers = if bearer_token_env_var.is_some() {
|
||||
None
|
||||
} else {
|
||||
codex_apps_mcp_http_headers(auth)
|
||||
};
|
||||
let url = codex_apps_mcp_url(config);
|
||||
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
http_headers,
|
||||
env_http_headers: None,
|
||||
},
|
||||
enabled: true,
|
||||
required: false,
|
||||
disabled_reason: None,
|
||||
startup_timeout_sec: Some(Duration::from_secs(30)),
|
||||
tool_timeout_sec: None,
|
||||
enabled_tools: None,
|
||||
disabled_tools: None,
|
||||
scopes: None,
|
||||
oauth_resource: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn with_codex_apps_mcp(
|
||||
mut servers: HashMap<String, McpServerConfig>,
|
||||
connectors_enabled: bool,
|
||||
auth: Option<&CodexAuth>,
|
||||
config: &Config,
|
||||
) -> HashMap<String, McpServerConfig> {
|
||||
if connectors_enabled {
|
||||
servers.insert(
|
||||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
codex_apps_mcp_server_config(config, auth),
|
||||
);
|
||||
} else {
|
||||
servers.remove(CODEX_APPS_MCP_SERVER_NAME);
|
||||
}
|
||||
servers
|
||||
}
|
||||
|
||||
pub struct McpManager {
|
||||
plugins_manager: Arc<PluginsManager>,
|
||||
}
|
||||
|
||||
impl McpManager {
|
||||
pub fn new(plugins_manager: Arc<PluginsManager>) -> Self {
|
||||
Self { plugins_manager }
|
||||
}
|
||||
|
||||
pub fn configured_servers(&self, config: &Config) -> HashMap<String, McpServerConfig> {
|
||||
configured_mcp_servers(config, self.plugins_manager.as_ref())
|
||||
}
|
||||
|
||||
pub fn effective_servers(
|
||||
&self,
|
||||
config: &Config,
|
||||
auth: Option<&CodexAuth>,
|
||||
) -> HashMap<String, McpServerConfig> {
|
||||
effective_mcp_servers(config, auth, self.plugins_manager.as_ref())
|
||||
}
|
||||
|
||||
pub fn tool_plugin_provenance(&self, config: &Config) -> ToolPluginProvenance {
|
||||
let loaded_plugins = self.plugins_manager.plugins_for_config(config);
|
||||
ToolPluginProvenance::from_capability_summaries(loaded_plugins.capability_summaries())
|
||||
}
|
||||
}
|
||||
|
||||
fn configured_mcp_servers(
|
||||
config: &Config,
|
||||
plugins_manager: &PluginsManager,
|
||||
) -> HashMap<String, McpServerConfig> {
|
||||
let loaded_plugins = plugins_manager.plugins_for_config(config);
|
||||
let mut servers = config.mcp_servers.get().clone();
|
||||
for (name, plugin_server) in loaded_plugins.effective_mcp_servers() {
|
||||
servers.entry(name).or_insert(plugin_server);
|
||||
}
|
||||
servers
|
||||
}
|
||||
|
||||
fn effective_mcp_servers(
|
||||
config: &Config,
|
||||
auth: Option<&CodexAuth>,
|
||||
plugins_manager: &PluginsManager,
|
||||
) -> HashMap<String, McpServerConfig> {
|
||||
let servers = configured_mcp_servers(config, plugins_manager);
|
||||
with_codex_apps_mcp(
|
||||
servers,
|
||||
config.features.apps_enabled_for_auth(auth),
|
||||
auth,
|
||||
config,
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn collect_mcp_snapshot(config: &Config) -> McpListToolsResponseEvent {
|
||||
let auth_manager = AuthManager::shared(
|
||||
config.codex_home.clone(),
|
||||
/*enable_codex_api_key_env*/ false,
|
||||
config.cli_auth_credentials_store_mode,
|
||||
);
|
||||
let auth = auth_manager.auth().await;
|
||||
let mcp_manager = McpManager::new(Arc::new(PluginsManager::new(config.codex_home.clone())));
|
||||
let mcp_servers = mcp_manager.effective_servers(config, auth.as_ref());
|
||||
let tool_plugin_provenance = mcp_manager.tool_plugin_provenance(config);
|
||||
if mcp_servers.is_empty() {
|
||||
return McpListToolsResponseEvent {
|
||||
tools: HashMap::new(),
|
||||
resources: HashMap::new(),
|
||||
resource_templates: HashMap::new(),
|
||||
auth_statuses: HashMap::new(),
|
||||
};
|
||||
}
|
||||
|
||||
let auth_status_entries =
|
||||
compute_auth_statuses(mcp_servers.iter(), config.mcp_oauth_credentials_store_mode).await;
|
||||
|
||||
let (tx_event, rx_event) = unbounded();
|
||||
drop(rx_event);
|
||||
|
||||
// Use ReadOnly sandbox policy for MCP snapshot collection (safest default)
|
||||
let sandbox_state = SandboxState {
|
||||
sandbox_policy: SandboxPolicy::new_read_only_policy(),
|
||||
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
||||
sandbox_cwd: env::current_dir().unwrap_or_else(|_| PathBuf::from("/")),
|
||||
use_legacy_landlock: config.features.use_legacy_landlock(),
|
||||
};
|
||||
|
||||
let (mcp_connection_manager, cancel_token) = McpConnectionManager::new(
|
||||
&mcp_servers,
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
auth_status_entries.clone(),
|
||||
&config.permissions.approval_policy,
|
||||
tx_event,
|
||||
sandbox_state,
|
||||
config.codex_home.clone(),
|
||||
codex_apps_tools_cache_key(auth.as_ref()),
|
||||
tool_plugin_provenance,
|
||||
)
|
||||
.await;
|
||||
|
||||
let snapshot =
|
||||
collect_mcp_snapshot_from_manager(&mcp_connection_manager, auth_status_entries).await;
|
||||
|
||||
cancel_token.cancel();
|
||||
|
||||
snapshot
|
||||
}
|
||||
|
||||
pub fn split_qualified_tool_name(qualified_name: &str) -> Option<(String, String)> {
|
||||
let mut parts = qualified_name.split(MCP_TOOL_NAME_DELIMITER);
|
||||
let prefix = parts.next()?;
|
||||
if prefix != MCP_TOOL_NAME_PREFIX {
|
||||
return None;
|
||||
}
|
||||
let server_name = parts.next()?;
|
||||
let tool_name: String = parts.collect::<Vec<_>>().join(MCP_TOOL_NAME_DELIMITER);
|
||||
if tool_name.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some((server_name.to_string(), tool_name))
|
||||
}
|
||||
|
||||
pub fn group_tools_by_server(
|
||||
tools: &HashMap<String, Tool>,
|
||||
) -> HashMap<String, HashMap<String, Tool>> {
|
||||
let mut grouped = HashMap::new();
|
||||
for (qualified_name, tool) in tools {
|
||||
if let Some((server_name, tool_name)) = split_qualified_tool_name(qualified_name) {
|
||||
grouped
|
||||
.entry(server_name)
|
||||
.or_insert_with(HashMap::new)
|
||||
.insert(tool_name, tool.clone());
|
||||
}
|
||||
}
|
||||
grouped
|
||||
}
|
||||
|
||||
pub(crate) async fn collect_mcp_snapshot_from_manager(
|
||||
mcp_connection_manager: &McpConnectionManager,
|
||||
auth_status_entries: HashMap<String, crate::mcp::auth::McpAuthStatusEntry>,
|
||||
) -> McpListToolsResponseEvent {
|
||||
let (tools, resources, resource_templates) = tokio::join!(
|
||||
mcp_connection_manager.list_all_tools(),
|
||||
mcp_connection_manager.list_all_resources(),
|
||||
mcp_connection_manager.list_all_resource_templates(),
|
||||
);
|
||||
|
||||
let auth_statuses = auth_status_entries
|
||||
.iter()
|
||||
.map(|(name, entry)| (name.clone(), entry.auth_status))
|
||||
.collect();
|
||||
|
||||
let tools = tools
|
||||
.into_iter()
|
||||
.filter_map(|(name, tool)| match serde_json::to_value(tool.tool) {
|
||||
Ok(value) => match Tool::from_mcp_value(value) {
|
||||
Ok(tool) => Some((name, tool)),
|
||||
Err(err) => {
|
||||
tracing::warn!("Failed to convert MCP tool '{name}': {err}");
|
||||
None
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
tracing::warn!("Failed to serialize MCP tool '{name}': {err}");
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let resources = resources
|
||||
.into_iter()
|
||||
.map(|(name, resources)| {
|
||||
let resources = resources
|
||||
.into_iter()
|
||||
.filter_map(|resource| match serde_json::to_value(resource) {
|
||||
Ok(value) => match Resource::from_mcp_value(value.clone()) {
|
||||
Ok(resource) => Some(resource),
|
||||
Err(err) => {
|
||||
let (uri, resource_name) = match value {
|
||||
Value::Object(obj) => (
|
||||
obj.get("uri")
|
||||
.and_then(|v| v.as_str().map(ToString::to_string)),
|
||||
obj.get("name")
|
||||
.and_then(|v| v.as_str().map(ToString::to_string)),
|
||||
),
|
||||
_ => (None, None),
|
||||
};
|
||||
|
||||
tracing::warn!(
|
||||
"Failed to convert MCP resource (uri={uri:?}, name={resource_name:?}): {err}"
|
||||
);
|
||||
None
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
tracing::warn!("Failed to serialize MCP resource: {err}");
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
(name, resources)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let resource_templates = resource_templates
|
||||
.into_iter()
|
||||
.map(|(name, templates)| {
|
||||
let templates = templates
|
||||
.into_iter()
|
||||
.filter_map(|template| match serde_json::to_value(template) {
|
||||
Ok(value) => match ResourceTemplate::from_mcp_value(value.clone()) {
|
||||
Ok(template) => Some(template),
|
||||
Err(err) => {
|
||||
let (uri_template, template_name) = match value {
|
||||
Value::Object(obj) => (
|
||||
obj.get("uriTemplate")
|
||||
.or_else(|| obj.get("uri_template"))
|
||||
.and_then(|v| v.as_str().map(ToString::to_string)),
|
||||
obj.get("name")
|
||||
.and_then(|v| v.as_str().map(ToString::to_string)),
|
||||
),
|
||||
_ => (None, None),
|
||||
};
|
||||
|
||||
tracing::warn!(
|
||||
"Failed to convert MCP resource template (uri_template={uri_template:?}, name={template_name:?}): {err}"
|
||||
);
|
||||
None
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
tracing::warn!("Failed to serialize MCP resource template: {err}");
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
(name, templates)
|
||||
})
|
||||
.collect();
|
||||
|
||||
McpListToolsResponseEvent {
|
||||
tools,
|
||||
resources,
|
||||
resource_templates,
|
||||
auth_statuses,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "mod_tests.rs"]
|
||||
mod tests;
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
use super::*;
|
||||
use crate::config::CONFIG_TOML_FILE;
|
||||
use crate::config::ConfigBuilder;
|
||||
use crate::config::types::McpServerConfig;
|
||||
use crate::config::types::McpServerTransportConfig;
|
||||
use crate::features::Feature;
|
||||
use crate::plugins::AppConnectorId;
|
||||
use crate::plugins::PluginCapabilitySummary;
|
||||
use crate::plugins::PluginsManager;
|
||||
use codex_protocol::mcp::Tool;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use toml::Value;
|
||||
|
||||
fn write_file(path: &Path, contents: &str) {
|
||||
@@ -79,7 +84,7 @@ fn group_tools_by_server_strips_prefix_and_groups() {
|
||||
expected.insert("alpha".to_string(), expected_alpha);
|
||||
expected.insert("beta".to_string(), expected_beta);
|
||||
|
||||
assert_eq!(group_tools_by_server(&tools), expected);
|
||||
assert_eq!(types::group_tools_by_server(&tools), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -126,19 +131,19 @@ fn tool_plugin_provenance_collects_app_and_mcp_sources() {
|
||||
#[test]
|
||||
fn codex_apps_mcp_url_for_base_url_keeps_existing_paths() {
|
||||
assert_eq!(
|
||||
codex_apps_mcp_url_for_base_url("https://chatgpt.com/backend-api"),
|
||||
config::codex_apps_mcp_url_for_base_url("https://chatgpt.com/backend-api"),
|
||||
"https://chatgpt.com/backend-api/wham/apps"
|
||||
);
|
||||
assert_eq!(
|
||||
codex_apps_mcp_url_for_base_url("https://chat.openai.com"),
|
||||
config::codex_apps_mcp_url_for_base_url("https://chat.openai.com"),
|
||||
"https://chat.openai.com/backend-api/wham/apps"
|
||||
);
|
||||
assert_eq!(
|
||||
codex_apps_mcp_url_for_base_url("http://localhost:8080/api/codex"),
|
||||
config::codex_apps_mcp_url_for_base_url("http://localhost:8080/api/codex"),
|
||||
"http://localhost:8080/api/codex/apps"
|
||||
);
|
||||
assert_eq!(
|
||||
codex_apps_mcp_url_for_base_url("http://localhost:8080"),
|
||||
config::codex_apps_mcp_url_for_base_url("http://localhost:8080"),
|
||||
"http://localhost:8080/api/codex/apps"
|
||||
);
|
||||
}
|
||||
@@ -149,7 +154,7 @@ fn codex_apps_mcp_url_uses_legacy_codex_apps_path() {
|
||||
config.chatgpt_base_url = "https://chatgpt.com".to_string();
|
||||
|
||||
assert_eq!(
|
||||
codex_apps_mcp_url(&config),
|
||||
config::codex_apps_mcp_url(&config),
|
||||
"https://chatgpt.com/backend-api/wham/apps"
|
||||
);
|
||||
}
|
||||
|
||||
191
codex-rs/core/src/mcp/snapshot.rs
Normal file
191
codex-rs/core/src/mcp/snapshot.rs
Normal file
@@ -0,0 +1,191 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_channel::unbounded;
|
||||
use codex_protocol::mcp::Resource;
|
||||
use codex_protocol::mcp::ResourceTemplate;
|
||||
use codex_protocol::mcp::Tool;
|
||||
use codex_protocol::protocol::McpListToolsResponseEvent;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::AuthManager;
|
||||
use crate::config::Config;
|
||||
use crate::mcp::auth::compute_auth_statuses;
|
||||
use crate::mcp::types::SandboxState;
|
||||
use crate::mcp_connection_manager::McpConnectionManager;
|
||||
use crate::mcp_connection_manager::codex_apps_tools_cache_key;
|
||||
use crate::plugins::PluginsManager;
|
||||
|
||||
use super::McpManager;
|
||||
|
||||
pub async fn collect_mcp_snapshot(config: &Config) -> McpListToolsResponseEvent {
|
||||
let auth_manager = AuthManager::shared(
|
||||
config.codex_home.clone(),
|
||||
/*enable_codex_api_key_env*/ false,
|
||||
config.cli_auth_credentials_store_mode,
|
||||
);
|
||||
let auth = auth_manager.auth().await;
|
||||
let mcp_manager = McpManager::new(Arc::new(PluginsManager::new(config.codex_home.clone())));
|
||||
let mcp_servers = mcp_manager.effective_servers(config, auth.as_ref());
|
||||
let tool_plugin_provenance = mcp_manager.tool_plugin_provenance(config);
|
||||
if mcp_servers.is_empty() {
|
||||
return McpListToolsResponseEvent {
|
||||
tools: HashMap::new(),
|
||||
resources: HashMap::new(),
|
||||
resource_templates: HashMap::new(),
|
||||
auth_statuses: HashMap::new(),
|
||||
};
|
||||
}
|
||||
|
||||
let auth_status_entries =
|
||||
compute_auth_statuses(mcp_servers.iter(), config.mcp_oauth_credentials_store_mode).await;
|
||||
|
||||
let (tx_event, rx_event) = unbounded();
|
||||
drop(rx_event);
|
||||
|
||||
// Use ReadOnly sandbox policy for MCP snapshot collection (safest default)
|
||||
let sandbox_state = SandboxState {
|
||||
sandbox_policy: SandboxPolicy::new_read_only_policy(),
|
||||
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
||||
sandbox_cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from("/")),
|
||||
use_legacy_landlock: config.features.use_legacy_landlock(),
|
||||
};
|
||||
|
||||
let (mcp_connection_manager, cancel_token) = McpConnectionManager::new(
|
||||
String::new(),
|
||||
&mcp_servers,
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
auth_status_entries.clone(),
|
||||
&config.permissions.approval_policy,
|
||||
tx_event,
|
||||
sandbox_state,
|
||||
config.codex_home.clone(),
|
||||
codex_apps_tools_cache_key(auth.as_ref()),
|
||||
tool_plugin_provenance,
|
||||
)
|
||||
.await;
|
||||
|
||||
let snapshot =
|
||||
collect_mcp_snapshot_from_manager(&mcp_connection_manager, auth_status_entries).await;
|
||||
|
||||
cancel_token.cancel();
|
||||
|
||||
snapshot
|
||||
}
|
||||
|
||||
pub(crate) async fn collect_mcp_snapshot_from_manager(
|
||||
mcp_connection_manager: &McpConnectionManager,
|
||||
auth_status_entries: HashMap<String, crate::mcp::auth::McpAuthStatusEntry>,
|
||||
) -> McpListToolsResponseEvent {
|
||||
let (tools, resources, resource_templates) = tokio::join!(
|
||||
mcp_connection_manager.list_all_tools(),
|
||||
mcp_connection_manager.list_all_resources(),
|
||||
mcp_connection_manager.list_all_resource_templates(),
|
||||
);
|
||||
|
||||
let auth_statuses = auth_status_entries
|
||||
.iter()
|
||||
.map(|(name, entry)| (name.clone(), entry.auth_status))
|
||||
.collect();
|
||||
|
||||
let tools = tools
|
||||
.into_iter()
|
||||
.filter_map(|(name, tool)| match serde_json::to_value(tool.tool) {
|
||||
Ok(value) => match Tool::from_mcp_value(value) {
|
||||
Ok(tool) => Some((name, tool)),
|
||||
Err(err) => {
|
||||
tracing::warn!("Failed to convert MCP tool '{name}': {err}");
|
||||
None
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
tracing::warn!("Failed to serialize MCP tool '{name}': {err}");
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let resources = resources
|
||||
.into_iter()
|
||||
.map(|(name, resources)| {
|
||||
let resources = resources
|
||||
.into_iter()
|
||||
.filter_map(|resource| match serde_json::to_value(resource) {
|
||||
Ok(value) => match Resource::from_mcp_value(value.clone()) {
|
||||
Ok(resource) => Some(resource),
|
||||
Err(err) => {
|
||||
let (uri, resource_name) = match value {
|
||||
Value::Object(obj) => (
|
||||
obj
|
||||
.get("uri")
|
||||
.and_then(|v| v.as_str().map(ToString::to_string)),
|
||||
obj
|
||||
.get("name")
|
||||
.and_then(|v| v.as_str().map(ToString::to_string)),
|
||||
),
|
||||
_ => (None, None),
|
||||
};
|
||||
|
||||
tracing::warn!(
|
||||
"Failed to convert MCP resource (uri={uri:?}, name={resource_name:?}): {err}"
|
||||
);
|
||||
None
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
tracing::warn!("Failed to serialize MCP resource: {err}");
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
(name, resources)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let resource_templates = resource_templates
|
||||
.into_iter()
|
||||
.map(|(name, templates)| {
|
||||
let templates = templates
|
||||
.into_iter()
|
||||
.filter_map(|template| match serde_json::to_value(template) {
|
||||
Ok(value) => match ResourceTemplate::from_mcp_value(value.clone()) {
|
||||
Ok(template) => Some(template),
|
||||
Err(err) => {
|
||||
let (uri_template, template_name) = match value {
|
||||
Value::Object(obj) => (
|
||||
obj
|
||||
.get("uriTemplate")
|
||||
.or_else(|| obj.get("uri_template"))
|
||||
.and_then(|v| v.as_str().map(ToString::to_string)),
|
||||
obj
|
||||
.get("name")
|
||||
.and_then(|v| v.as_str().map(ToString::to_string)),
|
||||
),
|
||||
_ => (None, None),
|
||||
};
|
||||
|
||||
tracing::warn!(
|
||||
"Failed to convert MCP resource template (uri_template={uri_template:?}, name={template_name:?}): {err}"
|
||||
);
|
||||
None
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
tracing::warn!("Failed to serialize MCP resource template: {err}");
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
(name, templates)
|
||||
})
|
||||
.collect();
|
||||
|
||||
McpListToolsResponseEvent {
|
||||
tools,
|
||||
resources,
|
||||
resource_templates,
|
||||
auth_statuses,
|
||||
}
|
||||
}
|
||||
3
codex-rs/core/src/mcp/types.rs
Normal file
3
codex-rs/core/src/mcp/types.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub use codex_mcp_types::SandboxState;
|
||||
pub use codex_mcp_types::group_tools_by_server;
|
||||
pub use codex_mcp_types::split_qualified_tool_name;
|
||||
@@ -40,7 +40,6 @@ use codex_protocol::protocol::McpStartupCompleteEvent;
|
||||
use codex_protocol::protocol::McpStartupFailure;
|
||||
use codex_protocol::protocol::McpStartupStatus;
|
||||
use codex_protocol::protocol::McpStartupUpdateEvent;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use codex_rmcp_client::ElicitationResponse;
|
||||
use codex_rmcp_client::OAuthCredentialsStoreMode;
|
||||
use codex_rmcp_client::RmcpClient;
|
||||
@@ -78,11 +77,11 @@ use tracing::instrument;
|
||||
use tracing::warn;
|
||||
use url::Url;
|
||||
|
||||
use crate::codex::INITIAL_SUBMIT_ID;
|
||||
use crate::config::types::McpServerConfig;
|
||||
use crate::config::types::McpServerTransportConfig;
|
||||
use crate::connectors::is_connector_id_allowed;
|
||||
use crate::connectors::sanitize_name;
|
||||
use crate::mcp::types::SandboxState;
|
||||
|
||||
/// Delimiter used to separate the server name from the tool name in a fully
|
||||
/// qualified tool name.
|
||||
@@ -581,19 +580,9 @@ impl AsyncManagedClient {
|
||||
pub const MCP_SANDBOX_STATE_CAPABILITY: &str = "codex/sandbox-state";
|
||||
|
||||
/// Custom MCP request to push sandbox state updates.
|
||||
/// When used, the `params` field of the notification is [`SandboxState`].
|
||||
/// When used, the `params` field of the notification is [`crate::SandboxState`].
|
||||
pub const MCP_SANDBOX_STATE_METHOD: &str = "codex/sandbox-state/update";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SandboxState {
|
||||
pub sandbox_policy: SandboxPolicy,
|
||||
pub codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
pub sandbox_cwd: PathBuf,
|
||||
#[serde(default)]
|
||||
pub use_legacy_landlock: bool,
|
||||
}
|
||||
|
||||
/// A thin wrapper around a set of running [`RmcpClient`] instances.
|
||||
pub(crate) struct McpConnectionManager {
|
||||
clients: HashMap<String, AsyncManagedClient>,
|
||||
@@ -633,6 +622,7 @@ impl McpConnectionManager {
|
||||
|
||||
#[allow(clippy::new_ret_no_self, clippy::too_many_arguments)]
|
||||
pub async fn new(
|
||||
submit_id: String,
|
||||
mcp_servers: &HashMap<String, McpServerConfig>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
auth_entries: HashMap<String, McpAuthStatusEntry>,
|
||||
@@ -657,6 +647,7 @@ impl McpConnectionManager {
|
||||
let cancel_token = cancel_token.child_token();
|
||||
let _ = emit_update(
|
||||
&tx_event,
|
||||
submit_id.as_str(),
|
||||
McpStartupUpdateEvent {
|
||||
server: server_name.clone(),
|
||||
status: McpStartupStatus::Starting,
|
||||
@@ -683,6 +674,7 @@ impl McpConnectionManager {
|
||||
);
|
||||
clients.insert(server_name.clone(), async_managed_client.clone());
|
||||
let tx_event = tx_event.clone();
|
||||
let submit_id = submit_id.clone();
|
||||
let auth_entry = auth_entries.get(&server_name).cloned();
|
||||
let sandbox_state = initial_sandbox_state.clone();
|
||||
join_set.spawn(async move {
|
||||
@@ -715,6 +707,7 @@ impl McpConnectionManager {
|
||||
|
||||
let _ = emit_update(
|
||||
&tx_event,
|
||||
submit_id.as_str(),
|
||||
McpStartupUpdateEvent {
|
||||
server: server_name.clone(),
|
||||
status,
|
||||
@@ -730,6 +723,7 @@ impl McpConnectionManager {
|
||||
server_origins,
|
||||
elicitation_requests: elicitation_requests.clone(),
|
||||
};
|
||||
let submit_id_for_task = submit_id.clone();
|
||||
tokio::spawn(async move {
|
||||
let outcomes = join_set.join_all().await;
|
||||
let mut summary = McpStartupCompleteEvent::default();
|
||||
@@ -747,7 +741,7 @@ impl McpConnectionManager {
|
||||
}
|
||||
let _ = tx_event
|
||||
.send(Event {
|
||||
id: INITIAL_SUBMIT_ID.to_owned(),
|
||||
id: submit_id_for_task,
|
||||
msg: EventMsg::McpStartupComplete(summary),
|
||||
})
|
||||
.await;
|
||||
@@ -1133,11 +1127,12 @@ impl McpConnectionManager {
|
||||
|
||||
async fn emit_update(
|
||||
tx_event: &Sender<Event>,
|
||||
submit_id: &str,
|
||||
update: McpStartupUpdateEvent,
|
||||
) -> Result<(), async_channel::SendError<Event>> {
|
||||
tx_event
|
||||
.send(Event {
|
||||
id: INITIAL_SUBMIT_ID.to_owned(),
|
||||
id: submit_id.to_owned(),
|
||||
msg: EventMsg::McpStartupUpdate(update),
|
||||
})
|
||||
.await
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use async_trait::async_trait;
|
||||
use codex_environment::ExecutorFileSystem;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::FunctionCallOutputContentItem;
|
||||
use codex_protocol::models::ImageDetail;
|
||||
|
||||
@@ -46,6 +46,7 @@ use std::path::PathBuf;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct UnifiedExecRequest {
|
||||
pub process_id: i32,
|
||||
pub command: Vec<String>,
|
||||
pub cwd: PathBuf,
|
||||
pub env: HashMap<String, String>,
|
||||
@@ -239,6 +240,7 @@ impl<'a> ToolRuntime<UnifiedExecRequest, UnifiedExecProcess> for UnifiedExecRunt
|
||||
return self
|
||||
.manager
|
||||
.open_session_with_exec_env(
|
||||
req.process_id,
|
||||
&prepared.exec_request,
|
||||
req.tty,
|
||||
prepared.spawn_lifecycle,
|
||||
@@ -275,7 +277,12 @@ impl<'a> ToolRuntime<UnifiedExecRequest, UnifiedExecProcess> for UnifiedExecRunt
|
||||
.env_for(spec, req.network.as_ref())
|
||||
.map_err(|err| ToolError::Codex(err.into()))?;
|
||||
self.manager
|
||||
.open_session_with_exec_env(&exec_env, req.tty, Box::new(NoopSpawnLifecycle))
|
||||
.open_session_with_exec_env(
|
||||
req.process_id,
|
||||
&exec_env,
|
||||
req.tty,
|
||||
Box::new(NoopSpawnLifecycle),
|
||||
)
|
||||
.await
|
||||
.map_err(|err| match err {
|
||||
UnifiedExecError::SandboxDenied { output, .. } => {
|
||||
|
||||
@@ -49,6 +49,10 @@ use codex_protocol::openai_models::WebSearchToolType;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::SubAgentSource;
|
||||
#[cfg(test)]
|
||||
pub use codex_tool_spec::AdditionalProperties;
|
||||
pub use codex_tool_spec::JsonSchema;
|
||||
pub use codex_tool_spec::parse_tool_input_schema;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
@@ -467,62 +471,6 @@ fn supports_image_generation(model_info: &ModelInfo) -> bool {
|
||||
model_info.input_modalities.contains(&InputModality::Image)
|
||||
}
|
||||
|
||||
/// Generic JSON‑Schema subset needed for our tool definitions
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum JsonSchema {
|
||||
Boolean {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
},
|
||||
String {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
},
|
||||
/// MCP schema allows "number" | "integer" for Number
|
||||
#[serde(alias = "integer")]
|
||||
Number {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
},
|
||||
Array {
|
||||
items: Box<JsonSchema>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
},
|
||||
Object {
|
||||
properties: BTreeMap<String, JsonSchema>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
required: Option<Vec<String>>,
|
||||
#[serde(
|
||||
rename = "additionalProperties",
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
additional_properties: Option<AdditionalProperties>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Whether additional properties are allowed, and if so, any required schema
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum AdditionalProperties {
|
||||
Boolean(bool),
|
||||
Schema(Box<JsonSchema>),
|
||||
}
|
||||
|
||||
impl From<bool> for AdditionalProperties {
|
||||
fn from(b: bool) -> Self {
|
||||
Self::Boolean(b)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<JsonSchema> for AdditionalProperties {
|
||||
fn from(s: JsonSchema) -> Self {
|
||||
Self::Schema(Box::new(s))
|
||||
}
|
||||
}
|
||||
|
||||
fn create_network_permissions_schema() -> JsonSchema {
|
||||
JsonSchema::Object {
|
||||
properties: BTreeMap::from([(
|
||||
@@ -2240,19 +2188,16 @@ pub(crate) struct ApplyPatchToolArgs {
|
||||
/// Returns JSON values that are compatible with Function Calling in the
|
||||
/// Responses API:
|
||||
/// https://platform.openai.com/docs/guides/function-calling?api-mode=responses
|
||||
///
|
||||
/// API stability: this helper remains available from `crate::tools::spec`, but
|
||||
/// its implementation moved under `client_common` so the model client can
|
||||
/// serialize tool payloads without depending on the full tool registry builder.
|
||||
#[allow(dead_code)] // Kept for internal tests and legacy in-crate call sites.
|
||||
pub fn create_tools_json_for_responses_api(
|
||||
tools: &[ToolSpec],
|
||||
) -> crate::error::Result<Vec<serde_json::Value>> {
|
||||
let mut tools_json = Vec::new();
|
||||
|
||||
for tool in tools {
|
||||
let json = serde_json::to_value(tool)?;
|
||||
tools_json.push(json);
|
||||
}
|
||||
|
||||
Ok(tools_json)
|
||||
crate::client_common::create_tools_json_for_responses_api(tools)
|
||||
}
|
||||
|
||||
fn push_tool_spec(
|
||||
builder: &mut ToolRegistryBuilder,
|
||||
spec: ToolSpec,
|
||||
@@ -2314,13 +2259,6 @@ fn dynamic_tool_to_openai_tool(
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse the tool input_schema or return an error for invalid schema
|
||||
pub fn parse_tool_input_schema(input_schema: &JsonValue) -> Result<JsonSchema, serde_json::Error> {
|
||||
let mut input_schema = input_schema.clone();
|
||||
sanitize_json_schema(&mut input_schema);
|
||||
serde_json::from_value::<JsonSchema>(input_schema)
|
||||
}
|
||||
|
||||
fn mcp_tool_to_openai_tool_parts(
|
||||
tool: rmcp::model::Tool,
|
||||
) -> Result<(String, JsonSchema, Option<JsonValue>), serde_json::Error> {
|
||||
@@ -2345,13 +2283,7 @@ fn mcp_tool_to_openai_tool_parts(
|
||||
);
|
||||
}
|
||||
|
||||
// Serialize to a raw JSON value so we can sanitize schemas coming from MCP
|
||||
// servers. Some servers omit the top-level or nested `type` in JSON
|
||||
// Schemas (e.g. using enum/anyOf), or use unsupported variants like
|
||||
// `integer`. Our internal JsonSchema is a small subset and requires
|
||||
// `type`, so we coerce/sanitize here for compatibility.
|
||||
sanitize_json_schema(&mut serialized_input_schema);
|
||||
let input_schema = serde_json::from_value::<JsonSchema>(serialized_input_schema)?;
|
||||
let input_schema = parse_tool_input_schema(&serialized_input_schema)?;
|
||||
let structured_content_schema = output_schema
|
||||
.map(|output_schema| serde_json::Value::Object(output_schema.as_ref().clone()))
|
||||
.unwrap_or_else(|| JsonValue::Object(serde_json::Map::new()));
|
||||
@@ -2382,117 +2314,6 @@ fn mcp_call_tool_result_output_schema(structured_content_schema: JsonValue) -> J
|
||||
})
|
||||
}
|
||||
|
||||
/// Sanitize a JSON Schema (as serde_json::Value) so it can fit our limited
|
||||
/// JsonSchema enum. This function:
|
||||
/// - Ensures every schema object has a "type". If missing, infers it from
|
||||
/// common keywords (properties => object, items => array, enum/const/format => string)
|
||||
/// and otherwise defaults to "string".
|
||||
/// - Fills required child fields (e.g. array items, object properties) with
|
||||
/// permissive defaults when absent.
|
||||
fn sanitize_json_schema(value: &mut JsonValue) {
|
||||
match value {
|
||||
JsonValue::Bool(_) => {
|
||||
// JSON Schema boolean form: true/false. Coerce to an accept-all string.
|
||||
*value = json!({ "type": "string" });
|
||||
}
|
||||
JsonValue::Array(arr) => {
|
||||
for v in arr.iter_mut() {
|
||||
sanitize_json_schema(v);
|
||||
}
|
||||
}
|
||||
JsonValue::Object(map) => {
|
||||
// First, recursively sanitize known nested schema holders
|
||||
if let Some(props) = map.get_mut("properties")
|
||||
&& let Some(props_map) = props.as_object_mut()
|
||||
{
|
||||
for (_k, v) in props_map.iter_mut() {
|
||||
sanitize_json_schema(v);
|
||||
}
|
||||
}
|
||||
if let Some(items) = map.get_mut("items") {
|
||||
sanitize_json_schema(items);
|
||||
}
|
||||
// Some schemas use oneOf/anyOf/allOf - sanitize their entries
|
||||
for combiner in ["oneOf", "anyOf", "allOf", "prefixItems"] {
|
||||
if let Some(v) = map.get_mut(combiner) {
|
||||
sanitize_json_schema(v);
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize/ensure type
|
||||
let mut ty = map.get("type").and_then(|v| v.as_str()).map(str::to_string);
|
||||
|
||||
// If type is an array (union), pick first supported; else leave to inference
|
||||
if ty.is_none()
|
||||
&& let Some(JsonValue::Array(types)) = map.get("type")
|
||||
{
|
||||
for t in types {
|
||||
if let Some(tt) = t.as_str()
|
||||
&& matches!(
|
||||
tt,
|
||||
"object" | "array" | "string" | "number" | "integer" | "boolean"
|
||||
)
|
||||
{
|
||||
ty = Some(tt.to_string());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Infer type if still missing
|
||||
if ty.is_none() {
|
||||
if map.contains_key("properties")
|
||||
|| map.contains_key("required")
|
||||
|| map.contains_key("additionalProperties")
|
||||
{
|
||||
ty = Some("object".to_string());
|
||||
} else if map.contains_key("items") || map.contains_key("prefixItems") {
|
||||
ty = Some("array".to_string());
|
||||
} else if map.contains_key("enum")
|
||||
|| map.contains_key("const")
|
||||
|| map.contains_key("format")
|
||||
{
|
||||
ty = Some("string".to_string());
|
||||
} else if map.contains_key("minimum")
|
||||
|| map.contains_key("maximum")
|
||||
|| map.contains_key("exclusiveMinimum")
|
||||
|| map.contains_key("exclusiveMaximum")
|
||||
|| map.contains_key("multipleOf")
|
||||
{
|
||||
ty = Some("number".to_string());
|
||||
}
|
||||
}
|
||||
// If we still couldn't infer, default to string
|
||||
let ty = ty.unwrap_or_else(|| "string".to_string());
|
||||
map.insert("type".to_string(), JsonValue::String(ty.to_string()));
|
||||
|
||||
// Ensure object schemas have properties map
|
||||
if ty == "object" {
|
||||
if !map.contains_key("properties") {
|
||||
map.insert(
|
||||
"properties".to_string(),
|
||||
JsonValue::Object(serde_json::Map::new()),
|
||||
);
|
||||
}
|
||||
// If additionalProperties is an object schema, sanitize it too.
|
||||
// Leave booleans as-is, since JSON Schema allows boolean here.
|
||||
if let Some(ap) = map.get_mut("additionalProperties") {
|
||||
let is_bool = matches!(ap, JsonValue::Bool(_));
|
||||
if !is_bool {
|
||||
sanitize_json_schema(ap);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure array schemas have items
|
||||
if ty == "array" && !map.contains_key("items") {
|
||||
map.insert("items".to_string(), json!({ "type": "string" }));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds the tool registry builder while collecting tool specs for later serialization.
|
||||
#[cfg(test)]
|
||||
pub(crate) fn build_specs(
|
||||
|
||||
@@ -2777,7 +2777,7 @@ fn chat_tools_include_top_level_name() {
|
||||
output_schema: None,
|
||||
})];
|
||||
|
||||
let responses_json = create_tools_json_for_responses_api(&tools).unwrap();
|
||||
let responses_json = crate::client_tools::create_tools_json_for_responses_api(&tools).unwrap();
|
||||
assert_eq!(
|
||||
responses_json,
|
||||
vec![json!({
|
||||
|
||||
116
codex-rs/core/src/unified_exec/backend.rs
Normal file
116
codex-rs/core/src/unified_exec/backend.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codex_exec_server::ExecServerClient;
|
||||
use codex_exec_server::RemoteExecServerConnectArgs;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::exec::SandboxType;
|
||||
use crate::sandboxing::ExecRequest;
|
||||
use crate::unified_exec::RemoteExecServerFileSystem;
|
||||
use crate::unified_exec::SpawnLifecycleHandle;
|
||||
use crate::unified_exec::UnifiedExecError;
|
||||
use crate::unified_exec::UnifiedExecProcess;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct RemoteExecServerBackend {
|
||||
client: ExecServerClient,
|
||||
local_workspace_root: PathBuf,
|
||||
remote_workspace_root: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for RemoteExecServerBackend {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("RemoteExecServerBackend")
|
||||
.field("local_workspace_root", &self.local_workspace_root)
|
||||
.field("remote_workspace_root", &self.remote_workspace_root)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl RemoteExecServerBackend {
|
||||
pub(crate) async fn connect_for_config(
|
||||
config: &Config,
|
||||
) -> Result<Option<Self>, UnifiedExecError> {
|
||||
let Some(websocket_url) = config
|
||||
.experimental_unified_exec_exec_server_websocket_url
|
||||
.clone()
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let client = ExecServerClient::connect_websocket(RemoteExecServerConnectArgs::new(
|
||||
websocket_url,
|
||||
"codex-core".to_string(),
|
||||
))
|
||||
.await
|
||||
.map_err(|err| UnifiedExecError::create_process(err.to_string()))?;
|
||||
|
||||
Ok(Some(Self {
|
||||
client,
|
||||
local_workspace_root: config.cwd.clone(),
|
||||
remote_workspace_root: config
|
||||
.experimental_unified_exec_exec_server_workspace_root
|
||||
.clone()
|
||||
.map(PathBuf::from),
|
||||
}))
|
||||
}
|
||||
|
||||
pub(crate) async fn open_session(
|
||||
&self,
|
||||
process_id: i32,
|
||||
env: &ExecRequest,
|
||||
tty: bool,
|
||||
spawn_lifecycle: SpawnLifecycleHandle,
|
||||
) -> Result<UnifiedExecProcess, UnifiedExecError> {
|
||||
if !spawn_lifecycle.inherited_fds().is_empty() {
|
||||
return Err(UnifiedExecError::create_process(
|
||||
"remote exec-server mode does not support inherited file descriptors".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if env.sandbox != SandboxType::None {
|
||||
return Err(UnifiedExecError::create_process(format!(
|
||||
"remote exec-server mode does not support sandboxed execution yet: {:?}",
|
||||
env.sandbox
|
||||
)));
|
||||
}
|
||||
|
||||
let remote_cwd = self.map_remote_cwd(env.cwd.as_path())?;
|
||||
UnifiedExecProcess::from_exec_server(
|
||||
self.client.clone(),
|
||||
process_id,
|
||||
env,
|
||||
remote_cwd,
|
||||
tty,
|
||||
spawn_lifecycle,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) fn file_system(&self) -> RemoteExecServerFileSystem {
|
||||
RemoteExecServerFileSystem::new(self.client.clone())
|
||||
}
|
||||
|
||||
fn map_remote_cwd(&self, local_cwd: &std::path::Path) -> Result<PathBuf, UnifiedExecError> {
|
||||
let Some(remote_root) = self.remote_workspace_root.as_ref() else {
|
||||
if local_cwd == self.local_workspace_root.as_path() {
|
||||
return Ok(PathBuf::from("."));
|
||||
}
|
||||
return Err(UnifiedExecError::create_process(format!(
|
||||
"remote exec-server mode needs `experimental_unified_exec_exec_server_workspace_root` for non-root cwd `{}`",
|
||||
local_cwd.display()
|
||||
)));
|
||||
};
|
||||
|
||||
let relative =
|
||||
UnifiedExecProcess::relative_cwd_under(local_cwd, &self.local_workspace_root)
|
||||
.ok_or_else(|| {
|
||||
UnifiedExecError::create_process(format!(
|
||||
"cwd `{}` is not under local workspace root `{}`",
|
||||
local_cwd.display(),
|
||||
self.local_workspace_root.display()
|
||||
))
|
||||
})?;
|
||||
Ok(remote_root.join(relative))
|
||||
}
|
||||
}
|
||||
@@ -38,21 +38,25 @@ use crate::codex::TurnContext;
|
||||
use crate::sandboxing::SandboxPermissions;
|
||||
|
||||
mod async_watcher;
|
||||
mod backend;
|
||||
mod errors;
|
||||
mod head_tail_buffer;
|
||||
mod process;
|
||||
mod process_manager;
|
||||
mod remote_filesystem;
|
||||
|
||||
pub(crate) fn set_deterministic_process_ids_for_tests(enabled: bool) {
|
||||
process_manager::set_deterministic_process_ids_for_tests(enabled);
|
||||
}
|
||||
|
||||
pub(crate) use backend::RemoteExecServerBackend;
|
||||
pub(crate) use errors::UnifiedExecError;
|
||||
pub(crate) use process::NoopSpawnLifecycle;
|
||||
#[cfg(unix)]
|
||||
pub(crate) use process::SpawnLifecycle;
|
||||
pub(crate) use process::SpawnLifecycleHandle;
|
||||
pub(crate) use process::UnifiedExecProcess;
|
||||
pub(crate) use remote_filesystem::RemoteExecServerFileSystem;
|
||||
|
||||
pub(crate) const MIN_YIELD_TIME_MS: u64 = 250;
|
||||
// Minimum yield time for an empty `write_stdin`.
|
||||
@@ -123,6 +127,7 @@ impl ProcessStore {
|
||||
pub(crate) struct UnifiedExecProcessManager {
|
||||
process_store: Mutex<ProcessStore>,
|
||||
max_write_stdin_yield_time_ms: u64,
|
||||
remote_exec_server: Option<RemoteExecServerBackend>,
|
||||
}
|
||||
|
||||
impl UnifiedExecProcessManager {
|
||||
@@ -131,6 +136,19 @@ impl UnifiedExecProcessManager {
|
||||
process_store: Mutex::new(ProcessStore::default()),
|
||||
max_write_stdin_yield_time_ms: max_write_stdin_yield_time_ms
|
||||
.max(MIN_EMPTY_YIELD_TIME_MS),
|
||||
remote_exec_server: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn with_remote_exec_server(
|
||||
max_write_stdin_yield_time_ms: u64,
|
||||
remote_exec_server: Option<RemoteExecServerBackend>,
|
||||
) -> Self {
|
||||
Self {
|
||||
process_store: Mutex::new(ProcessStore::default()),
|
||||
max_write_stdin_yield_time_ms: max_write_stdin_yield_time_ms
|
||||
.max(MIN_EMPTY_YIELD_TIME_MS),
|
||||
remote_exec_server,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#![allow(clippy::module_inception)]
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::sync::Mutex;
|
||||
@@ -16,8 +17,13 @@ use crate::exec::ExecToolCallOutput;
|
||||
use crate::exec::SandboxType;
|
||||
use crate::exec::StreamOutput;
|
||||
use crate::exec::is_likely_sandbox_denied;
|
||||
use crate::sandboxing::ExecRequest;
|
||||
use crate::truncate::TruncationPolicy;
|
||||
use crate::truncate::formatted_truncate_text;
|
||||
use codex_exec_server::ExecParams;
|
||||
use codex_exec_server::ExecServerClient;
|
||||
use codex_exec_server::ExecServerEvent;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use codex_utils_pty::ExecCommandSession;
|
||||
use codex_utils_pty::SpawnedPty;
|
||||
|
||||
@@ -56,7 +62,7 @@ pub(crate) struct OutputHandles {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct UnifiedExecProcess {
|
||||
process_handle: ExecCommandSession,
|
||||
process_handle: ProcessBackend,
|
||||
output_rx: broadcast::Receiver<Vec<u8>>,
|
||||
output_buffer: OutputBuffer,
|
||||
output_notify: Arc<Notify>,
|
||||
@@ -69,9 +75,45 @@ pub(crate) struct UnifiedExecProcess {
|
||||
_spawn_lifecycle: SpawnLifecycleHandle,
|
||||
}
|
||||
|
||||
enum ProcessBackend {
|
||||
Local(ExecCommandSession),
|
||||
Remote(RemoteExecSession),
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ProcessBackend {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Local(process_handle) => f.debug_tuple("Local").field(process_handle).finish(),
|
||||
Self::Remote(process_handle) => f.debug_tuple("Remote").field(process_handle).finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct RemoteExecSession {
|
||||
process_key: String,
|
||||
client: ExecServerClient,
|
||||
writer_tx: mpsc::Sender<Vec<u8>>,
|
||||
exited: Arc<AtomicBool>,
|
||||
exit_code: Arc<StdMutex<Option<i32>>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for RemoteExecSession {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("RemoteExecSession")
|
||||
.field("process_key", &self.process_key)
|
||||
.field("exited", &self.exited.load(Ordering::SeqCst))
|
||||
.field(
|
||||
"exit_code",
|
||||
&self.exit_code.lock().ok().and_then(|guard| *guard),
|
||||
)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl UnifiedExecProcess {
|
||||
pub(super) fn new(
|
||||
process_handle: ExecCommandSession,
|
||||
fn new(
|
||||
process_handle: ProcessBackend,
|
||||
initial_output_rx: tokio::sync::broadcast::Receiver<Vec<u8>>,
|
||||
sandbox_type: SandboxType,
|
||||
spawn_lifecycle: SpawnLifecycleHandle,
|
||||
@@ -123,7 +165,10 @@ impl UnifiedExecProcess {
|
||||
}
|
||||
|
||||
pub(super) fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
|
||||
self.process_handle.writer_sender()
|
||||
match &self.process_handle {
|
||||
ProcessBackend::Local(process_handle) => process_handle.writer_sender(),
|
||||
ProcessBackend::Remote(process_handle) => process_handle.writer_tx.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn output_handles(&self) -> OutputHandles {
|
||||
@@ -149,17 +194,38 @@ impl UnifiedExecProcess {
|
||||
}
|
||||
|
||||
pub(super) fn has_exited(&self) -> bool {
|
||||
self.process_handle.has_exited()
|
||||
match &self.process_handle {
|
||||
ProcessBackend::Local(process_handle) => process_handle.has_exited(),
|
||||
ProcessBackend::Remote(process_handle) => process_handle.exited.load(Ordering::SeqCst),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn exit_code(&self) -> Option<i32> {
|
||||
self.process_handle.exit_code()
|
||||
match &self.process_handle {
|
||||
ProcessBackend::Local(process_handle) => process_handle.exit_code(),
|
||||
ProcessBackend::Remote(process_handle) => process_handle
|
||||
.exit_code
|
||||
.lock()
|
||||
.ok()
|
||||
.and_then(|guard| *guard),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn terminate(&self) {
|
||||
self.output_closed.store(true, Ordering::Release);
|
||||
self.output_closed_notify.notify_waiters();
|
||||
self.process_handle.terminate();
|
||||
match &self.process_handle {
|
||||
ProcessBackend::Local(process_handle) => process_handle.terminate(),
|
||||
ProcessBackend::Remote(process_handle) => {
|
||||
let client = process_handle.client.clone();
|
||||
let process_key = process_handle.process_key.clone();
|
||||
if let Ok(handle) = tokio::runtime::Handle::try_current() {
|
||||
handle.spawn(async move {
|
||||
let _ = client.terminate(&process_key).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
self.cancellation_token.cancel();
|
||||
self.output_task.abort();
|
||||
}
|
||||
@@ -232,7 +298,12 @@ impl UnifiedExecProcess {
|
||||
mut exit_rx,
|
||||
} = spawned;
|
||||
let output_rx = codex_utils_pty::combine_output_receivers(stdout_rx, stderr_rx);
|
||||
let managed = Self::new(process_handle, output_rx, sandbox_type, spawn_lifecycle);
|
||||
let managed = Self::new(
|
||||
ProcessBackend::Local(process_handle),
|
||||
output_rx,
|
||||
sandbox_type,
|
||||
spawn_lifecycle,
|
||||
);
|
||||
|
||||
let exit_ready = matches!(exit_rx.try_recv(), Ok(_) | Err(TryRecvError::Closed));
|
||||
|
||||
@@ -262,6 +333,102 @@ impl UnifiedExecProcess {
|
||||
Ok(managed)
|
||||
}
|
||||
|
||||
pub(super) async fn from_exec_server(
|
||||
client: ExecServerClient,
|
||||
process_id: i32,
|
||||
env: &ExecRequest,
|
||||
remote_cwd: std::path::PathBuf,
|
||||
tty: bool,
|
||||
spawn_lifecycle: SpawnLifecycleHandle,
|
||||
) -> Result<Self, UnifiedExecError> {
|
||||
let process_key = process_id.to_string();
|
||||
let mut events_rx = client.event_receiver();
|
||||
let response = client
|
||||
.exec(ExecParams {
|
||||
process_id: process_key.clone(),
|
||||
argv: env.command.clone(),
|
||||
cwd: remote_cwd,
|
||||
env: env.env.clone(),
|
||||
tty,
|
||||
arg0: env.arg0.clone(),
|
||||
sandbox: None,
|
||||
})
|
||||
.await
|
||||
.map_err(|err| UnifiedExecError::create_process(err.to_string()))?;
|
||||
let process_key = response.process_id;
|
||||
|
||||
let (output_tx, output_rx) = broadcast::channel(256);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(256);
|
||||
let exited = Arc::new(AtomicBool::new(false));
|
||||
let exit_code = Arc::new(StdMutex::new(None));
|
||||
|
||||
let managed = Self::new(
|
||||
ProcessBackend::Remote(RemoteExecSession {
|
||||
process_key: process_key.clone(),
|
||||
client: client.clone(),
|
||||
writer_tx,
|
||||
exited: Arc::clone(&exited),
|
||||
exit_code: Arc::clone(&exit_code),
|
||||
}),
|
||||
output_rx,
|
||||
env.sandbox,
|
||||
spawn_lifecycle,
|
||||
);
|
||||
|
||||
{
|
||||
let client = client.clone();
|
||||
let writer_process_key = process_key.clone();
|
||||
tokio::spawn(async move {
|
||||
while let Some(chunk) = writer_rx.recv().await {
|
||||
if client.write(&writer_process_key, chunk).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
{
|
||||
let cancellation_token = managed.cancellation_token();
|
||||
tokio::spawn(async move {
|
||||
while let Ok(event) = events_rx.recv().await {
|
||||
match event {
|
||||
ExecServerEvent::OutputDelta(notification)
|
||||
if notification.process_id == process_key =>
|
||||
{
|
||||
let _ = output_tx.send(notification.chunk.into_inner());
|
||||
}
|
||||
ExecServerEvent::Exited(notification)
|
||||
if notification.process_id == process_key =>
|
||||
{
|
||||
exited.store(true, Ordering::SeqCst);
|
||||
if let Ok(mut guard) = exit_code.lock() {
|
||||
*guard = Some(notification.exit_code);
|
||||
}
|
||||
cancellation_token.cancel();
|
||||
break;
|
||||
}
|
||||
ExecServerEvent::OutputDelta(_) | ExecServerEvent::Exited(_) => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(managed)
|
||||
}
|
||||
|
||||
pub(super) fn relative_cwd_under(
|
||||
local_cwd: &std::path::Path,
|
||||
local_root: &std::path::Path,
|
||||
) -> Option<std::path::PathBuf> {
|
||||
let local_cwd = AbsolutePathBuf::try_from(local_cwd.to_path_buf()).ok()?;
|
||||
let local_root = AbsolutePathBuf::try_from(local_root.to_path_buf()).ok()?;
|
||||
local_cwd
|
||||
.as_path()
|
||||
.strip_prefix(local_root.as_path())
|
||||
.ok()
|
||||
.map(std::path::Path::to_path_buf)
|
||||
}
|
||||
|
||||
fn signal_exit(&self) {
|
||||
self.cancellation_token.cancel();
|
||||
}
|
||||
|
||||
@@ -539,10 +539,17 @@ impl UnifiedExecProcessManager {
|
||||
|
||||
pub(crate) async fn open_session_with_exec_env(
|
||||
&self,
|
||||
process_id: i32,
|
||||
env: &ExecRequest,
|
||||
tty: bool,
|
||||
mut spawn_lifecycle: SpawnLifecycleHandle,
|
||||
) -> Result<UnifiedExecProcess, UnifiedExecError> {
|
||||
if let Some(remote_exec_server) = self.remote_exec_server.as_ref() {
|
||||
return remote_exec_server
|
||||
.open_session(process_id, env, tty, spawn_lifecycle)
|
||||
.await;
|
||||
}
|
||||
|
||||
let (program, args) = env
|
||||
.command
|
||||
.split_first()
|
||||
@@ -610,6 +617,7 @@ impl UnifiedExecProcessManager {
|
||||
})
|
||||
.await;
|
||||
let req = UnifiedExecToolRequest {
|
||||
process_id: request.process_id,
|
||||
command: request.command.clone(),
|
||||
cwd,
|
||||
env,
|
||||
|
||||
152
codex-rs/core/src/unified_exec/remote_filesystem.rs
Normal file
152
codex-rs/core/src/unified_exec/remote_filesystem.rs
Normal file
@@ -0,0 +1,152 @@
|
||||
use std::io;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use base64::Engine as _;
|
||||
use base64::engine::general_purpose::STANDARD;
|
||||
use codex_app_server_protocol::FsCopyParams;
|
||||
use codex_app_server_protocol::FsCreateDirectoryParams;
|
||||
use codex_app_server_protocol::FsGetMetadataParams;
|
||||
use codex_app_server_protocol::FsReadDirectoryParams;
|
||||
use codex_app_server_protocol::FsReadFileParams;
|
||||
use codex_app_server_protocol::FsRemoveParams;
|
||||
use codex_app_server_protocol::FsWriteFileParams;
|
||||
use codex_environment::CopyOptions;
|
||||
use codex_environment::CreateDirectoryOptions;
|
||||
use codex_environment::ExecutorFileSystem;
|
||||
use codex_environment::FileMetadata;
|
||||
use codex_environment::FileSystemResult;
|
||||
use codex_environment::ReadDirectoryEntry;
|
||||
use codex_environment::RemoveOptions;
|
||||
use codex_exec_server::ExecServerClient;
|
||||
use codex_exec_server::ExecServerError;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
|
||||
/// Filesystem backend that forwards environment operations to a remote
|
||||
/// `codex-exec-server`.
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct RemoteExecServerFileSystem {
|
||||
client: ExecServerClient,
|
||||
}
|
||||
|
||||
impl RemoteExecServerFileSystem {
|
||||
pub(crate) fn new(client: ExecServerClient) -> Self {
|
||||
Self { client }
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for RemoteExecServerFileSystem {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("RemoteExecServerFileSystem")
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ExecutorFileSystem for RemoteExecServerFileSystem {
|
||||
async fn read_file(&self, path: &AbsolutePathBuf) -> FileSystemResult<Vec<u8>> {
|
||||
let response = self
|
||||
.client
|
||||
.fs_read_file(FsReadFileParams { path: path.clone() })
|
||||
.await
|
||||
.map_err(map_exec_server_error)?;
|
||||
STANDARD.decode(response.data_base64).map_err(|error| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("remote exec-server returned invalid base64 file data: {error}"),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
async fn write_file(&self, path: &AbsolutePathBuf, contents: Vec<u8>) -> FileSystemResult<()> {
|
||||
self.client
|
||||
.fs_write_file(FsWriteFileParams {
|
||||
path: path.clone(),
|
||||
data_base64: STANDARD.encode(contents),
|
||||
})
|
||||
.await
|
||||
.map_err(map_exec_server_error)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_directory(
|
||||
&self,
|
||||
path: &AbsolutePathBuf,
|
||||
options: CreateDirectoryOptions,
|
||||
) -> FileSystemResult<()> {
|
||||
self.client
|
||||
.fs_create_directory(FsCreateDirectoryParams {
|
||||
path: path.clone(),
|
||||
recursive: Some(options.recursive),
|
||||
})
|
||||
.await
|
||||
.map_err(map_exec_server_error)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_metadata(&self, path: &AbsolutePathBuf) -> FileSystemResult<FileMetadata> {
|
||||
let response = self
|
||||
.client
|
||||
.fs_get_metadata(FsGetMetadataParams { path: path.clone() })
|
||||
.await
|
||||
.map_err(map_exec_server_error)?;
|
||||
Ok(FileMetadata {
|
||||
is_directory: response.is_directory,
|
||||
is_file: response.is_file,
|
||||
created_at_ms: response.created_at_ms,
|
||||
modified_at_ms: response.modified_at_ms,
|
||||
})
|
||||
}
|
||||
|
||||
async fn read_directory(
|
||||
&self,
|
||||
path: &AbsolutePathBuf,
|
||||
) -> FileSystemResult<Vec<ReadDirectoryEntry>> {
|
||||
let response = self
|
||||
.client
|
||||
.fs_read_directory(FsReadDirectoryParams { path: path.clone() })
|
||||
.await
|
||||
.map_err(map_exec_server_error)?;
|
||||
Ok(response
|
||||
.entries
|
||||
.into_iter()
|
||||
.map(|entry| ReadDirectoryEntry {
|
||||
file_name: entry.file_name,
|
||||
is_directory: entry.is_directory,
|
||||
is_file: entry.is_file,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn remove(&self, path: &AbsolutePathBuf, options: RemoveOptions) -> FileSystemResult<()> {
|
||||
self.client
|
||||
.fs_remove(FsRemoveParams {
|
||||
path: path.clone(),
|
||||
recursive: Some(options.recursive),
|
||||
force: Some(options.force),
|
||||
})
|
||||
.await
|
||||
.map_err(map_exec_server_error)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn copy(
|
||||
&self,
|
||||
source_path: &AbsolutePathBuf,
|
||||
destination_path: &AbsolutePathBuf,
|
||||
options: CopyOptions,
|
||||
) -> FileSystemResult<()> {
|
||||
self.client
|
||||
.fs_copy(FsCopyParams {
|
||||
source_path: source_path.clone(),
|
||||
destination_path: destination_path.clone(),
|
||||
recursive: options.recursive,
|
||||
})
|
||||
.await
|
||||
.map_err(map_exec_server_error)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn map_exec_server_error(error: ExecServerError) -> io::Error {
|
||||
io::Error::other(error.to_string())
|
||||
}
|
||||
@@ -8,11 +8,37 @@ pub use fs::FileSystemResult;
|
||||
pub use fs::ReadDirectoryEntry;
|
||||
pub use fs::RemoveOptions;
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct Environment;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Environment {
|
||||
file_system: Arc<dyn ExecutorFileSystem>,
|
||||
}
|
||||
|
||||
impl Environment {
|
||||
pub fn get_filesystem(&self) -> impl ExecutorFileSystem + use<> {
|
||||
fs::LocalFileSystem
|
||||
pub fn local() -> Self {
|
||||
Self {
|
||||
file_system: Arc::new(fs::LocalFileSystem),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_file_system(file_system: Arc<dyn ExecutorFileSystem>) -> Self {
|
||||
Self { file_system }
|
||||
}
|
||||
|
||||
pub fn get_filesystem(&self) -> Arc<dyn ExecutorFileSystem> {
|
||||
Arc::clone(&self.file_system)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Environment {
|
||||
fn default() -> Self {
|
||||
Self::local()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Environment {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Environment").finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
load("//:defs.bzl", "codex_rust_crate")
|
||||
|
||||
codex_rust_crate(
|
||||
name = "exec-server",
|
||||
crate_name = "codex_exec_server",
|
||||
test_tags = ["no-sandbox"],
|
||||
)
|
||||
@@ -15,6 +15,7 @@ workspace = true
|
||||
base64 = { workspace = true }
|
||||
clap = { workspace = true, features = ["derive"] }
|
||||
codex-app-server-protocol = { workspace = true }
|
||||
codex-environment = { workspace = true }
|
||||
codex-utils-pty = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
|
||||
@@ -121,7 +121,8 @@ Request params:
|
||||
"PATH": "/usr/bin:/bin"
|
||||
},
|
||||
"tty": true,
|
||||
"arg0": null
|
||||
"arg0": null,
|
||||
"sandbox": null
|
||||
}
|
||||
```
|
||||
|
||||
@@ -133,6 +134,9 @@ Field definitions:
|
||||
- `tty`: when `true`, spawn a PTY-backed interactive process; when `false`,
|
||||
spawn a pipe-backed process with closed stdin.
|
||||
- `arg0`: optional argv0 override forwarded to `codex-utils-pty`.
|
||||
- `sandbox`: optional sandbox config. Omit it for the current direct-spawn
|
||||
behavior. Explicit `{"mode":"none"}` is accepted; `{"mode":"hostDefault"}`
|
||||
is currently rejected until host-local sandbox materialization is wired up.
|
||||
|
||||
Response:
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
51
codex-rs/exec-server/src/client/jsonrpc_backend.rs
Normal file
51
codex-rs/exec-server/src/client/jsonrpc_backend.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_app_server_protocol::JSONRPCNotification;
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use serde::Serialize;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use super::ExecServerError;
|
||||
|
||||
pub(super) struct JsonRpcBackend {
|
||||
write_tx: mpsc::Sender<JSONRPCMessage>,
|
||||
}
|
||||
|
||||
impl JsonRpcBackend {
|
||||
pub(super) fn new(write_tx: mpsc::Sender<JSONRPCMessage>) -> Self {
|
||||
Self { write_tx }
|
||||
}
|
||||
|
||||
pub(super) async fn notify<P: Serialize>(
|
||||
&self,
|
||||
method: &str,
|
||||
params: &P,
|
||||
) -> Result<(), ExecServerError> {
|
||||
let params = serde_json::to_value(params)?;
|
||||
self.write_tx
|
||||
.send(JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: method.to_string(),
|
||||
params: Some(params),
|
||||
}))
|
||||
.await
|
||||
.map_err(|_| ExecServerError::Closed)
|
||||
}
|
||||
|
||||
pub(super) async fn send_request<P: Serialize>(
|
||||
&self,
|
||||
request_id: RequestId,
|
||||
method: &str,
|
||||
params: &P,
|
||||
) -> Result<(), ExecServerError> {
|
||||
let params = serde_json::to_value(params)?;
|
||||
self.write_tx
|
||||
.send(JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: request_id,
|
||||
method: method.to_string(),
|
||||
params: Some(params),
|
||||
trace: None,
|
||||
}))
|
||||
.await
|
||||
.map_err(|_| ExecServerError::Closed)
|
||||
}
|
||||
}
|
||||
141
codex-rs/exec-server/src/client/local_backend.rs
Normal file
141
codex-rs/exec-server/src/client/local_backend.rs
Normal file
@@ -0,0 +1,141 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_app_server_protocol::FsCopyParams;
|
||||
use codex_app_server_protocol::FsCopyResponse;
|
||||
use codex_app_server_protocol::FsCreateDirectoryParams;
|
||||
use codex_app_server_protocol::FsCreateDirectoryResponse;
|
||||
use codex_app_server_protocol::FsGetMetadataParams;
|
||||
use codex_app_server_protocol::FsGetMetadataResponse;
|
||||
use codex_app_server_protocol::FsReadDirectoryParams;
|
||||
use codex_app_server_protocol::FsReadDirectoryResponse;
|
||||
use codex_app_server_protocol::FsReadFileParams;
|
||||
use codex_app_server_protocol::FsReadFileResponse;
|
||||
use codex_app_server_protocol::FsRemoveParams;
|
||||
use codex_app_server_protocol::FsRemoveResponse;
|
||||
use codex_app_server_protocol::FsWriteFileParams;
|
||||
use codex_app_server_protocol::FsWriteFileResponse;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::protocol::ExecParams;
|
||||
use crate::protocol::ExecResponse;
|
||||
use crate::protocol::INITIALIZED_METHOD;
|
||||
use crate::protocol::InitializeResponse;
|
||||
use crate::protocol::ReadParams;
|
||||
use crate::protocol::ReadResponse;
|
||||
use crate::protocol::TerminateParams;
|
||||
use crate::protocol::TerminateResponse;
|
||||
use crate::protocol::WriteParams;
|
||||
use crate::protocol::WriteResponse;
|
||||
use crate::server::ExecServerHandler;
|
||||
|
||||
use super::ExecServerError;
|
||||
use super::server_result_to_client;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(super) struct LocalBackend {
|
||||
handler: Arc<Mutex<ExecServerHandler>>,
|
||||
}
|
||||
|
||||
impl LocalBackend {
|
||||
pub(super) fn new(handler: ExecServerHandler) -> Self {
|
||||
Self {
|
||||
handler: Arc::new(Mutex::new(handler)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn shutdown(&self) {
|
||||
self.handler.lock().await.shutdown().await;
|
||||
}
|
||||
|
||||
pub(super) async fn initialize(&self) -> Result<InitializeResponse, ExecServerError> {
|
||||
server_result_to_client(self.handler.lock().await.initialize())
|
||||
}
|
||||
|
||||
pub(super) async fn notify(&self, method: &str) -> Result<(), ExecServerError> {
|
||||
match method {
|
||||
INITIALIZED_METHOD => self
|
||||
.handler
|
||||
.lock()
|
||||
.await
|
||||
.initialized()
|
||||
.map_err(ExecServerError::Protocol),
|
||||
other => Err(ExecServerError::Protocol(format!(
|
||||
"unsupported in-process notification method `{other}`"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn exec(&self, params: ExecParams) -> Result<ExecResponse, ExecServerError> {
|
||||
server_result_to_client(self.handler.lock().await.exec(params).await)
|
||||
}
|
||||
|
||||
pub(super) async fn exec_read(
|
||||
&self,
|
||||
params: ReadParams,
|
||||
) -> Result<ReadResponse, ExecServerError> {
|
||||
server_result_to_client(self.handler.lock().await.exec_read(params).await)
|
||||
}
|
||||
|
||||
pub(super) async fn exec_write(
|
||||
&self,
|
||||
params: WriteParams,
|
||||
) -> Result<WriteResponse, ExecServerError> {
|
||||
server_result_to_client(self.handler.lock().await.exec_write(params).await)
|
||||
}
|
||||
|
||||
pub(super) async fn terminate(
|
||||
&self,
|
||||
params: TerminateParams,
|
||||
) -> Result<TerminateResponse, ExecServerError> {
|
||||
server_result_to_client(self.handler.lock().await.terminate(params).await)
|
||||
}
|
||||
|
||||
pub(super) async fn fs_read_file(
|
||||
&self,
|
||||
params: FsReadFileParams,
|
||||
) -> Result<FsReadFileResponse, ExecServerError> {
|
||||
server_result_to_client(self.handler.lock().await.fs_read_file(params).await)
|
||||
}
|
||||
|
||||
pub(super) async fn fs_write_file(
|
||||
&self,
|
||||
params: FsWriteFileParams,
|
||||
) -> Result<FsWriteFileResponse, ExecServerError> {
|
||||
server_result_to_client(self.handler.lock().await.fs_write_file(params).await)
|
||||
}
|
||||
|
||||
pub(super) async fn fs_create_directory(
|
||||
&self,
|
||||
params: FsCreateDirectoryParams,
|
||||
) -> Result<FsCreateDirectoryResponse, ExecServerError> {
|
||||
server_result_to_client(self.handler.lock().await.fs_create_directory(params).await)
|
||||
}
|
||||
|
||||
pub(super) async fn fs_get_metadata(
|
||||
&self,
|
||||
params: FsGetMetadataParams,
|
||||
) -> Result<FsGetMetadataResponse, ExecServerError> {
|
||||
server_result_to_client(self.handler.lock().await.fs_get_metadata(params).await)
|
||||
}
|
||||
|
||||
pub(super) async fn fs_read_directory(
|
||||
&self,
|
||||
params: FsReadDirectoryParams,
|
||||
) -> Result<FsReadDirectoryResponse, ExecServerError> {
|
||||
server_result_to_client(self.handler.lock().await.fs_read_directory(params).await)
|
||||
}
|
||||
|
||||
pub(super) async fn fs_remove(
|
||||
&self,
|
||||
params: FsRemoveParams,
|
||||
) -> Result<FsRemoveResponse, ExecServerError> {
|
||||
server_result_to_client(self.handler.lock().await.fs_remove(params).await)
|
||||
}
|
||||
|
||||
pub(super) async fn fs_copy(
|
||||
&self,
|
||||
params: FsCopyParams,
|
||||
) -> Result<FsCopyResponse, ExecServerError> {
|
||||
server_result_to_client(self.handler.lock().await.fs_copy(params).await)
|
||||
}
|
||||
}
|
||||
72
codex-rs/exec-server/src/client/process.rs
Normal file
72
codex-rs/exec-server/src/client/process.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use super::ExecServerClient;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(super) struct ExecServerOutput {
|
||||
pub(super) stream: crate::protocol::ExecOutputStream,
|
||||
pub(super) chunk: Vec<u8>,
|
||||
}
|
||||
|
||||
pub(super) struct ExecServerProcess {
|
||||
pub(super) process_id: String,
|
||||
pub(super) output_rx: broadcast::Receiver<ExecServerOutput>,
|
||||
pub(super) status: Arc<RemoteProcessStatus>,
|
||||
pub(super) client: ExecServerClient,
|
||||
}
|
||||
|
||||
impl ExecServerProcess {
|
||||
pub(super) fn output_receiver(&self) -> broadcast::Receiver<ExecServerOutput> {
|
||||
self.output_rx.resubscribe()
|
||||
}
|
||||
|
||||
pub(super) fn has_exited(&self) -> bool {
|
||||
self.status.has_exited()
|
||||
}
|
||||
|
||||
pub(super) fn exit_code(&self) -> Option<i32> {
|
||||
self.status.exit_code()
|
||||
}
|
||||
|
||||
pub(super) fn terminate(&self) {
|
||||
let client = self.client.clone();
|
||||
let process_id = self.process_id.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = client.terminate(&process_id).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) struct RemoteProcessStatus {
|
||||
exited: AtomicBool,
|
||||
exit_code: StdMutex<Option<i32>>,
|
||||
}
|
||||
|
||||
impl RemoteProcessStatus {
|
||||
pub(super) fn new() -> Self {
|
||||
Self {
|
||||
exited: AtomicBool::new(false),
|
||||
exit_code: StdMutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn has_exited(&self) -> bool {
|
||||
self.exited.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
pub(super) fn exit_code(&self) -> Option<i32> {
|
||||
self.exit_code.lock().ok().and_then(|guard| *guard)
|
||||
}
|
||||
|
||||
pub(super) fn mark_exited(&self, exit_code: Option<i32>) {
|
||||
self.exited.store(true, Ordering::SeqCst);
|
||||
if let Ok(mut guard) = self.exit_code.lock() {
|
||||
*guard = exit_code;
|
||||
}
|
||||
}
|
||||
}
|
||||
888
codex-rs/exec-server/src/client/tests.rs
Normal file
888
codex-rs/exec-server/src/client/tests.rs
Normal file
@@ -0,0 +1,888 @@
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use super::ExecServerClient;
|
||||
use super::ExecServerClientConnectOptions;
|
||||
use super::ExecServerError;
|
||||
use super::ExecServerOutput;
|
||||
use crate::protocol::EXEC_METHOD;
|
||||
use crate::protocol::EXEC_OUTPUT_DELTA_METHOD;
|
||||
use crate::protocol::EXEC_TERMINATE_METHOD;
|
||||
use crate::protocol::ExecOutputStream;
|
||||
use crate::protocol::ExecParams;
|
||||
use crate::protocol::INITIALIZE_METHOD;
|
||||
use crate::protocol::INITIALIZED_METHOD;
|
||||
use crate::protocol::PROTOCOL_VERSION;
|
||||
use crate::protocol::ReadParams;
|
||||
use codex_app_server_protocol::JSONRPCError;
|
||||
use codex_app_server_protocol::JSONRPCErrorError;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_app_server_protocol::JSONRPCNotification;
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use codex_app_server_protocol::JSONRPCResponse;
|
||||
|
||||
fn test_options() -> ExecServerClientConnectOptions {
|
||||
ExecServerClientConnectOptions {
|
||||
client_name: "test-client".to_string(),
|
||||
initialize_timeout: Duration::from_secs(1),
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_jsonrpc_line<R>(lines: &mut tokio::io::Lines<BufReader<R>>) -> JSONRPCMessage
|
||||
where
|
||||
R: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
let next_line = timeout(Duration::from_secs(1), lines.next_line()).await;
|
||||
let line_result = match next_line {
|
||||
Ok(line_result) => line_result,
|
||||
Err(err) => panic!("timed out waiting for JSON-RPC line: {err}"),
|
||||
};
|
||||
let maybe_line = match line_result {
|
||||
Ok(maybe_line) => maybe_line,
|
||||
Err(err) => panic!("failed to read JSON-RPC line: {err}"),
|
||||
};
|
||||
let line = match maybe_line {
|
||||
Some(line) => line,
|
||||
None => panic!("server connection closed before JSON-RPC line arrived"),
|
||||
};
|
||||
match serde_json::from_str::<JSONRPCMessage>(&line) {
|
||||
Ok(message) => message,
|
||||
Err(err) => panic!("failed to parse JSON-RPC line: {err}"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn write_jsonrpc_line<W>(writer: &mut W, message: JSONRPCMessage)
|
||||
where
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
let encoded = match serde_json::to_string(&message) {
|
||||
Ok(encoded) => encoded,
|
||||
Err(err) => panic!("failed to encode JSON-RPC message: {err}"),
|
||||
};
|
||||
if let Err(err) = writer.write_all(format!("{encoded}\n").as_bytes()).await {
|
||||
panic!("failed to write JSON-RPC line: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_stdio_performs_initialize_handshake() {
|
||||
let (client_stdin, server_reader) = tokio::io::duplex(4096);
|
||||
let (mut server_writer, client_stdout) = tokio::io::duplex(4096);
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(server_reader).lines();
|
||||
|
||||
let initialize = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Request(request) = initialize else {
|
||||
panic!("expected initialize request");
|
||||
};
|
||||
assert_eq!(request.method, INITIALIZE_METHOD);
|
||||
assert_eq!(
|
||||
request.params,
|
||||
Some(serde_json::json!({ "clientName": "test-client" }))
|
||||
);
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id: request.id,
|
||||
result: serde_json::json!({ "protocolVersion": PROTOCOL_VERSION }),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let initialized = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Notification(JSONRPCNotification { method, params }) = initialized
|
||||
else {
|
||||
panic!("expected initialized notification");
|
||||
};
|
||||
assert_eq!(method, INITIALIZED_METHOD);
|
||||
assert_eq!(params, Some(serde_json::json!({})));
|
||||
});
|
||||
|
||||
let client = ExecServerClient::connect_stdio(client_stdin, client_stdout, test_options()).await;
|
||||
if let Err(err) = client {
|
||||
panic!("failed to connect test client: {err}");
|
||||
}
|
||||
|
||||
if let Err(err) = server.await {
|
||||
panic!("server task failed: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_in_process_starts_processes_without_jsonrpc_transport() {
|
||||
let client = match ExecServerClient::connect_in_process(test_options()).await {
|
||||
Ok(client) => client,
|
||||
Err(err) => panic!("failed to connect in-process client: {err}"),
|
||||
};
|
||||
|
||||
let process = match client
|
||||
.start_process(ExecParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
argv: vec!["printf".to_string(), "hello".to_string()],
|
||||
cwd: std::env::current_dir().unwrap_or_else(|err| panic!("missing cwd: {err}")),
|
||||
env: HashMap::new(),
|
||||
tty: false,
|
||||
arg0: None,
|
||||
sandbox: None,
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(process) => process,
|
||||
Err(err) => panic!("failed to start in-process child: {err}"),
|
||||
};
|
||||
|
||||
let mut output = process.output_receiver();
|
||||
let output = timeout(Duration::from_secs(1), output.recv())
|
||||
.await
|
||||
.unwrap_or_else(|err| panic!("timed out waiting for process output: {err}"))
|
||||
.unwrap_or_else(|err| panic!("failed to receive process output: {err}"));
|
||||
assert_eq!(
|
||||
output,
|
||||
ExecServerOutput {
|
||||
stream: crate::protocol::ExecOutputStream::Stdout,
|
||||
chunk: b"hello".to_vec(),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_in_process_read_returns_retained_output_and_exit_state() {
|
||||
let client = match ExecServerClient::connect_in_process(test_options()).await {
|
||||
Ok(client) => client,
|
||||
Err(err) => panic!("failed to connect in-process client: {err}"),
|
||||
};
|
||||
|
||||
let response = match client
|
||||
.exec(ExecParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
argv: vec!["printf".to_string(), "hello".to_string()],
|
||||
cwd: std::env::current_dir().unwrap_or_else(|err| panic!("missing cwd: {err}")),
|
||||
env: HashMap::new(),
|
||||
tty: false,
|
||||
arg0: None,
|
||||
sandbox: None,
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(err) => panic!("failed to start in-process child: {err}"),
|
||||
};
|
||||
|
||||
let process_id = response.process_id.clone();
|
||||
let read = match client
|
||||
.read(ReadParams {
|
||||
process_id: process_id.clone(),
|
||||
after_seq: None,
|
||||
max_bytes: None,
|
||||
wait_ms: Some(1000),
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(read) => read,
|
||||
Err(err) => panic!("failed to read in-process child output: {err}"),
|
||||
};
|
||||
|
||||
assert_eq!(read.chunks.len(), 1);
|
||||
assert_eq!(read.chunks[0].seq, 1);
|
||||
assert_eq!(read.chunks[0].stream, ExecOutputStream::Stdout);
|
||||
assert_eq!(read.chunks[0].chunk.clone().into_inner(), b"hello".to_vec());
|
||||
assert_eq!(read.next_seq, 2);
|
||||
let read = if read.exited {
|
||||
read
|
||||
} else {
|
||||
match client
|
||||
.read(ReadParams {
|
||||
process_id,
|
||||
after_seq: Some(read.next_seq - 1),
|
||||
max_bytes: None,
|
||||
wait_ms: Some(1000),
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(read) => read,
|
||||
Err(err) => panic!("failed to wait for in-process child exit: {err}"),
|
||||
}
|
||||
};
|
||||
assert!(read.exited);
|
||||
assert_eq!(read.exit_code, Some(0));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_in_process_rejects_invalid_exec_params_from_handler() {
|
||||
let client = match ExecServerClient::connect_in_process(test_options()).await {
|
||||
Ok(client) => client,
|
||||
Err(err) => panic!("failed to connect in-process client: {err}"),
|
||||
};
|
||||
|
||||
let result = client
|
||||
.start_process(ExecParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
argv: Vec::new(),
|
||||
cwd: std::env::current_dir().unwrap_or_else(|err| panic!("missing cwd: {err}")),
|
||||
env: HashMap::new(),
|
||||
tty: false,
|
||||
arg0: None,
|
||||
sandbox: None,
|
||||
})
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(ExecServerError::Server { code, message }) => {
|
||||
assert_eq!(code, -32602);
|
||||
assert_eq!(message, "argv must not be empty");
|
||||
}
|
||||
Err(err) => panic!("unexpected in-process exec failure: {err}"),
|
||||
Ok(_) => panic!("expected invalid params error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_in_process_rejects_writes_to_unknown_processes() {
|
||||
let client = match ExecServerClient::connect_in_process(test_options()).await {
|
||||
Ok(client) => client,
|
||||
Err(err) => panic!("failed to connect in-process client: {err}"),
|
||||
};
|
||||
|
||||
let result = client.write("missing", b"input".to_vec()).await;
|
||||
|
||||
match result {
|
||||
Err(ExecServerError::Server { code, message }) => {
|
||||
assert_eq!(code, -32600);
|
||||
assert_eq!(message, "unknown process id missing");
|
||||
}
|
||||
Err(err) => panic!("unexpected in-process write failure: {err}"),
|
||||
Ok(_) => panic!("expected unknown process error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_in_process_terminate_marks_process_exited() {
|
||||
let client = match ExecServerClient::connect_in_process(test_options()).await {
|
||||
Ok(client) => client,
|
||||
Err(err) => panic!("failed to connect in-process client: {err}"),
|
||||
};
|
||||
|
||||
let process = match client
|
||||
.start_process(ExecParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
argv: vec!["sleep".to_string(), "30".to_string()],
|
||||
cwd: std::env::current_dir().unwrap_or_else(|err| panic!("missing cwd: {err}")),
|
||||
env: HashMap::new(),
|
||||
tty: false,
|
||||
arg0: None,
|
||||
sandbox: None,
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(process) => process,
|
||||
Err(err) => panic!("failed to start in-process child: {err}"),
|
||||
};
|
||||
|
||||
if let Err(err) = client.terminate(&process.process_id).await {
|
||||
panic!("failed to terminate in-process child: {err}");
|
||||
}
|
||||
|
||||
timeout(Duration::from_secs(2), async {
|
||||
loop {
|
||||
if process.has_exited() {
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap_or_else(|err| panic!("timed out waiting for in-process child to exit: {err}"));
|
||||
|
||||
assert!(process.has_exited());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dropping_in_process_client_terminates_running_processes() {
|
||||
let marker_path = std::env::temp_dir().join(format!(
|
||||
"codex-exec-server-inprocess-drop-{}-{}",
|
||||
std::process::id(),
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.expect("system time")
|
||||
.as_nanos()
|
||||
));
|
||||
let _ = std::fs::remove_file(&marker_path);
|
||||
|
||||
{
|
||||
let client = match ExecServerClient::connect_in_process(test_options()).await {
|
||||
Ok(client) => client,
|
||||
Err(err) => panic!("failed to connect in-process client: {err}"),
|
||||
};
|
||||
|
||||
let _ = client
|
||||
.exec(ExecParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
argv: vec![
|
||||
"/bin/sh".to_string(),
|
||||
"-c".to_string(),
|
||||
format!("sleep 2; printf dropped > {}", marker_path.display()),
|
||||
],
|
||||
cwd: std::env::current_dir().expect("cwd"),
|
||||
env: HashMap::new(),
|
||||
tty: false,
|
||||
arg0: None,
|
||||
sandbox: None,
|
||||
})
|
||||
.await
|
||||
.unwrap_or_else(|err| panic!("failed to start in-process child: {err}"));
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(3)).await;
|
||||
assert!(
|
||||
!marker_path.exists(),
|
||||
"dropping the in-process client should terminate managed children"
|
||||
);
|
||||
let _ = std::fs::remove_file(&marker_path);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_stdio_returns_initialize_errors() {
|
||||
let (client_stdin, server_reader) = tokio::io::duplex(4096);
|
||||
let (mut server_writer, client_stdout) = tokio::io::duplex(4096);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(server_reader).lines();
|
||||
|
||||
let initialize = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Request(request) = initialize else {
|
||||
panic!("expected initialize request");
|
||||
};
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Error(JSONRPCError {
|
||||
id: request.id,
|
||||
error: JSONRPCErrorError {
|
||||
code: -32600,
|
||||
message: "rejected".to_string(),
|
||||
data: None,
|
||||
},
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
let result = ExecServerClient::connect_stdio(client_stdin, client_stdout, test_options()).await;
|
||||
|
||||
match result {
|
||||
Err(ExecServerError::Server { code, message }) => {
|
||||
assert_eq!(code, -32600);
|
||||
assert_eq!(message, "rejected");
|
||||
}
|
||||
Err(err) => panic!("unexpected initialize failure: {err}"),
|
||||
Ok(_) => panic!("expected initialize failure"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_process_cleans_up_registered_process_after_request_error() {
|
||||
let (client_stdin, server_reader) = tokio::io::duplex(4096);
|
||||
let (mut server_writer, client_stdout) = tokio::io::duplex(4096);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(server_reader).lines();
|
||||
|
||||
let initialize = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Request(initialize_request) = initialize else {
|
||||
panic!("expected initialize request");
|
||||
};
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id: initialize_request.id,
|
||||
result: serde_json::json!({ "protocolVersion": PROTOCOL_VERSION }),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let initialized = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Notification(notification) = initialized else {
|
||||
panic!("expected initialized notification");
|
||||
};
|
||||
assert_eq!(notification.method, INITIALIZED_METHOD);
|
||||
|
||||
let exec_request = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Request(JSONRPCRequest { id, method, .. }) = exec_request else {
|
||||
panic!("expected exec request");
|
||||
};
|
||||
assert_eq!(method, EXEC_METHOD);
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Error(JSONRPCError {
|
||||
id,
|
||||
error: JSONRPCErrorError {
|
||||
code: -32600,
|
||||
message: "duplicate process".to_string(),
|
||||
data: None,
|
||||
},
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
let client =
|
||||
match ExecServerClient::connect_stdio(client_stdin, client_stdout, test_options()).await {
|
||||
Ok(client) => client,
|
||||
Err(err) => panic!("failed to connect test client: {err}"),
|
||||
};
|
||||
|
||||
let result = client
|
||||
.start_process(ExecParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
argv: vec!["bash".to_string(), "-lc".to_string(), "true".to_string()],
|
||||
cwd: std::env::current_dir().unwrap_or_else(|err| panic!("missing cwd: {err}")),
|
||||
env: HashMap::new(),
|
||||
tty: true,
|
||||
arg0: None,
|
||||
sandbox: None,
|
||||
})
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(ExecServerError::Server { code, message }) => {
|
||||
assert_eq!(code, -32600);
|
||||
assert_eq!(message, "duplicate process");
|
||||
}
|
||||
Err(err) => panic!("unexpected start_process failure: {err}"),
|
||||
Ok(_) => panic!("expected start_process failure"),
|
||||
}
|
||||
|
||||
assert!(
|
||||
client.inner.pending.lock().await.is_empty(),
|
||||
"failed requests should not leave pending request state behind"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_stdio_times_out_during_initialize_handshake() {
|
||||
let (client_stdin, server_reader) = tokio::io::duplex(4096);
|
||||
let (_server_writer, client_stdout) = tokio::io::duplex(4096);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(server_reader).lines();
|
||||
let _ = read_jsonrpc_line(&mut lines).await;
|
||||
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||
});
|
||||
|
||||
let result = ExecServerClient::connect_stdio(
|
||||
client_stdin,
|
||||
client_stdout,
|
||||
ExecServerClientConnectOptions {
|
||||
client_name: "test-client".to_string(),
|
||||
initialize_timeout: Duration::from_millis(25),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(ExecServerError::InitializeTimedOut { timeout }) => {
|
||||
assert_eq!(timeout, Duration::from_millis(25));
|
||||
}
|
||||
Err(err) => panic!("unexpected initialize timeout failure: {err}"),
|
||||
Ok(_) => panic!("expected initialize timeout"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_process_preserves_output_stream_metadata() {
|
||||
let (client_stdin, server_reader) = tokio::io::duplex(4096);
|
||||
let (mut server_writer, client_stdout) = tokio::io::duplex(4096);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(server_reader).lines();
|
||||
|
||||
let initialize = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Request(initialize_request) = initialize else {
|
||||
panic!("expected initialize request");
|
||||
};
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id: initialize_request.id,
|
||||
result: serde_json::json!({ "protocolVersion": PROTOCOL_VERSION }),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let initialized = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Notification(notification) = initialized else {
|
||||
panic!("expected initialized notification");
|
||||
};
|
||||
assert_eq!(notification.method, INITIALIZED_METHOD);
|
||||
|
||||
let exec_request = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Request(JSONRPCRequest { id, method, .. }) = exec_request else {
|
||||
panic!("expected exec request");
|
||||
};
|
||||
assert_eq!(method, EXEC_METHOD);
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id,
|
||||
result: serde_json::json!({ "processId": "proc-1" }),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: EXEC_OUTPUT_DELTA_METHOD.to_string(),
|
||||
params: Some(serde_json::json!({
|
||||
"processId": "proc-1",
|
||||
"stream": "stderr",
|
||||
"chunk": "ZXJyb3IK"
|
||||
})),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
});
|
||||
|
||||
let client =
|
||||
match ExecServerClient::connect_stdio(client_stdin, client_stdout, test_options()).await {
|
||||
Ok(client) => client,
|
||||
Err(err) => panic!("failed to connect test client: {err}"),
|
||||
};
|
||||
|
||||
let process = match client
|
||||
.start_process(ExecParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
argv: vec!["bash".to_string(), "-lc".to_string(), "true".to_string()],
|
||||
cwd: std::env::current_dir().unwrap_or_else(|err| panic!("missing cwd: {err}")),
|
||||
env: HashMap::new(),
|
||||
tty: true,
|
||||
arg0: None,
|
||||
sandbox: None,
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(process) => process,
|
||||
Err(err) => panic!("failed to start process: {err}"),
|
||||
};
|
||||
|
||||
let mut output = process.output_receiver();
|
||||
let output = timeout(Duration::from_secs(1), output.recv())
|
||||
.await
|
||||
.unwrap_or_else(|err| panic!("timed out waiting for process output: {err}"))
|
||||
.unwrap_or_else(|err| panic!("failed to receive process output: {err}"));
|
||||
assert_eq!(output.stream, ExecOutputStream::Stderr);
|
||||
assert_eq!(output.chunk, b"error\n".to_vec());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn terminate_does_not_mark_process_exited_before_exit_notification() {
|
||||
let (client_stdin, server_reader) = tokio::io::duplex(4096);
|
||||
let (mut server_writer, client_stdout) = tokio::io::duplex(4096);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(server_reader).lines();
|
||||
|
||||
let initialize = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Request(initialize_request) = initialize else {
|
||||
panic!("expected initialize request");
|
||||
};
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id: initialize_request.id,
|
||||
result: serde_json::json!({ "protocolVersion": PROTOCOL_VERSION }),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let initialized = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Notification(notification) = initialized else {
|
||||
panic!("expected initialized notification");
|
||||
};
|
||||
assert_eq!(notification.method, INITIALIZED_METHOD);
|
||||
|
||||
let exec_request = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Request(JSONRPCRequest { id, method, .. }) = exec_request else {
|
||||
panic!("expected exec request");
|
||||
};
|
||||
assert_eq!(method, EXEC_METHOD);
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id,
|
||||
result: serde_json::json!({ "processId": "proc-1" }),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let terminate_request = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Request(JSONRPCRequest { id, method, .. }) = terminate_request else {
|
||||
panic!("expected terminate request");
|
||||
};
|
||||
assert_eq!(method, EXEC_TERMINATE_METHOD);
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id,
|
||||
result: serde_json::json!({ "running": true }),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
});
|
||||
|
||||
let client =
|
||||
match ExecServerClient::connect_stdio(client_stdin, client_stdout, test_options()).await {
|
||||
Ok(client) => client,
|
||||
Err(err) => panic!("failed to connect test client: {err}"),
|
||||
};
|
||||
|
||||
let process = match client
|
||||
.start_process(ExecParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
argv: vec!["bash".to_string(), "-lc".to_string(), "true".to_string()],
|
||||
cwd: std::env::current_dir().unwrap_or_else(|err| panic!("missing cwd: {err}")),
|
||||
env: HashMap::new(),
|
||||
tty: true,
|
||||
arg0: None,
|
||||
sandbox: None,
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(process) => process,
|
||||
Err(err) => panic!("failed to start process: {err}"),
|
||||
};
|
||||
|
||||
process.terminate();
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
assert!(!process.has_exited(), "terminate should not imply exit");
|
||||
assert_eq!(process.exit_code(), None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_process_uses_protocol_process_ids() {
|
||||
let (client_stdin, server_reader) = tokio::io::duplex(4096);
|
||||
let (mut server_writer, client_stdout) = tokio::io::duplex(4096);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(server_reader).lines();
|
||||
|
||||
let initialize = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Request(initialize_request) = initialize else {
|
||||
panic!("expected initialize request");
|
||||
};
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id: initialize_request.id,
|
||||
result: serde_json::json!({ "protocolVersion": PROTOCOL_VERSION }),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let initialized = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Notification(notification) = initialized else {
|
||||
panic!("expected initialized notification");
|
||||
};
|
||||
assert_eq!(notification.method, INITIALIZED_METHOD);
|
||||
|
||||
let exec_request = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Request(JSONRPCRequest { id, method, .. }) = exec_request else {
|
||||
panic!("expected exec request");
|
||||
};
|
||||
assert_eq!(method, EXEC_METHOD);
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id,
|
||||
result: serde_json::json!({ "processId": "other-proc" }),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
let client =
|
||||
match ExecServerClient::connect_stdio(client_stdin, client_stdout, test_options()).await {
|
||||
Ok(client) => client,
|
||||
Err(err) => panic!("failed to connect test client: {err}"),
|
||||
};
|
||||
|
||||
let process = match client
|
||||
.start_process(ExecParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
argv: vec!["bash".to_string(), "-lc".to_string(), "true".to_string()],
|
||||
cwd: std::env::current_dir().unwrap_or_else(|err| panic!("missing cwd: {err}")),
|
||||
env: HashMap::new(),
|
||||
tty: true,
|
||||
arg0: None,
|
||||
sandbox: None,
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(process) => process,
|
||||
Err(err) => panic!("failed to start process: {err}"),
|
||||
};
|
||||
|
||||
assert_eq!(process.process_id, "other-proc");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_process_routes_output_for_protocol_process_ids() {
|
||||
let (client_stdin, server_reader) = tokio::io::duplex(4096);
|
||||
let (mut server_writer, client_stdout) = tokio::io::duplex(4096);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(server_reader).lines();
|
||||
|
||||
let initialize = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Request(initialize_request) = initialize else {
|
||||
panic!("expected initialize request");
|
||||
};
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id: initialize_request.id,
|
||||
result: serde_json::json!({ "protocolVersion": PROTOCOL_VERSION }),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let initialized = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Notification(notification) = initialized else {
|
||||
panic!("expected initialized notification");
|
||||
};
|
||||
assert_eq!(notification.method, INITIALIZED_METHOD);
|
||||
|
||||
let exec_request = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Request(JSONRPCRequest { id, method, .. }) = exec_request else {
|
||||
panic!("expected exec request");
|
||||
};
|
||||
assert_eq!(method, EXEC_METHOD);
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id,
|
||||
result: serde_json::json!({ "processId": "proc-1" }),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: EXEC_OUTPUT_DELTA_METHOD.to_string(),
|
||||
params: Some(serde_json::json!({
|
||||
"processId": "proc-1",
|
||||
"stream": "stdout",
|
||||
"chunk": "YWxpdmUK"
|
||||
})),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
let client =
|
||||
match ExecServerClient::connect_stdio(client_stdin, client_stdout, test_options()).await {
|
||||
Ok(client) => client,
|
||||
Err(err) => panic!("failed to connect test client: {err}"),
|
||||
};
|
||||
|
||||
let first_process = match client
|
||||
.start_process(ExecParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
argv: vec!["bash".to_string(), "-lc".to_string(), "true".to_string()],
|
||||
cwd: std::env::current_dir().unwrap_or_else(|err| panic!("missing cwd: {err}")),
|
||||
env: HashMap::new(),
|
||||
tty: true,
|
||||
arg0: None,
|
||||
sandbox: None,
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(process) => process,
|
||||
Err(err) => panic!("failed to start first process: {err}"),
|
||||
};
|
||||
|
||||
let mut output = first_process.output_receiver();
|
||||
let output = timeout(Duration::from_secs(1), output.recv())
|
||||
.await
|
||||
.unwrap_or_else(|err| panic!("timed out waiting for process output: {err}"))
|
||||
.unwrap_or_else(|err| panic!("failed to receive process output: {err}"));
|
||||
assert_eq!(output.stream, ExecOutputStream::Stdout);
|
||||
assert_eq!(output.chunk, b"alive\n".to_vec());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_shutdown_marks_processes_exited_without_exit_codes() {
|
||||
let (client_stdin, server_reader) = tokio::io::duplex(4096);
|
||||
let (mut server_writer, client_stdout) = tokio::io::duplex(4096);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(server_reader).lines();
|
||||
|
||||
let initialize = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Request(initialize_request) = initialize else {
|
||||
panic!("expected initialize request");
|
||||
};
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id: initialize_request.id,
|
||||
result: serde_json::json!({ "protocolVersion": PROTOCOL_VERSION }),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let initialized = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Notification(notification) = initialized else {
|
||||
panic!("expected initialized notification");
|
||||
};
|
||||
assert_eq!(notification.method, INITIALIZED_METHOD);
|
||||
|
||||
let exec_request = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Request(JSONRPCRequest { id, method, .. }) = exec_request else {
|
||||
panic!("expected exec request");
|
||||
};
|
||||
assert_eq!(method, EXEC_METHOD);
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id,
|
||||
result: serde_json::json!({ "processId": "proc-1" }),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
drop(server_writer);
|
||||
});
|
||||
|
||||
let client =
|
||||
match ExecServerClient::connect_stdio(client_stdin, client_stdout, test_options()).await {
|
||||
Ok(client) => client,
|
||||
Err(err) => panic!("failed to connect test client: {err}"),
|
||||
};
|
||||
|
||||
let process = match client
|
||||
.start_process(ExecParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
argv: vec!["bash".to_string(), "-lc".to_string(), "true".to_string()],
|
||||
cwd: std::env::current_dir().unwrap_or_else(|err| panic!("missing cwd: {err}")),
|
||||
env: HashMap::new(),
|
||||
tty: true,
|
||||
arg0: None,
|
||||
sandbox: None,
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(process) => process,
|
||||
Err(err) => panic!("failed to start process: {err}"),
|
||||
};
|
||||
|
||||
let _ = process;
|
||||
}
|
||||
@@ -22,6 +22,7 @@ pub(crate) enum JsonRpcConnectionEvent {
|
||||
pub(crate) struct JsonRpcConnection {
|
||||
outgoing_tx: mpsc::Sender<JSONRPCMessage>,
|
||||
incoming_rx: mpsc::Receiver<JsonRpcConnectionEvent>,
|
||||
task_handles: Vec<tokio::task::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl JsonRpcConnection {
|
||||
@@ -35,7 +36,7 @@ impl JsonRpcConnection {
|
||||
|
||||
let reader_label = connection_label.clone();
|
||||
let incoming_tx_for_reader = incoming_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
let reader_task = tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(reader).lines();
|
||||
loop {
|
||||
match lines.next_line().await {
|
||||
@@ -66,7 +67,7 @@ impl JsonRpcConnection {
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
send_disconnected(&incoming_tx_for_reader, None).await;
|
||||
send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await;
|
||||
break;
|
||||
}
|
||||
Err(err) => {
|
||||
@@ -83,7 +84,7 @@ impl JsonRpcConnection {
|
||||
}
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
let writer_task = tokio::spawn(async move {
|
||||
let mut writer = BufWriter::new(writer);
|
||||
while let Some(message) = outgoing_rx.recv().await {
|
||||
if let Err(err) = write_jsonrpc_line_message(&mut writer, &message).await {
|
||||
@@ -102,6 +103,7 @@ impl JsonRpcConnection {
|
||||
Self {
|
||||
outgoing_tx,
|
||||
incoming_rx,
|
||||
task_handles: vec![reader_task, writer_task],
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,7 +117,7 @@ impl JsonRpcConnection {
|
||||
|
||||
let reader_label = connection_label.clone();
|
||||
let incoming_tx_for_reader = incoming_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
let reader_task = tokio::spawn(async move {
|
||||
loop {
|
||||
match websocket_reader.next().await {
|
||||
Some(Ok(Message::Text(text))) => {
|
||||
@@ -165,7 +167,7 @@ impl JsonRpcConnection {
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) => {
|
||||
send_disconnected(&incoming_tx_for_reader, None).await;
|
||||
send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await;
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) => {}
|
||||
@@ -181,14 +183,14 @@ impl JsonRpcConnection {
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
send_disconnected(&incoming_tx_for_reader, None).await;
|
||||
send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
let writer_task = tokio::spawn(async move {
|
||||
while let Some(message) = outgoing_rx.recv().await {
|
||||
match serialize_jsonrpc_message(&message) {
|
||||
Ok(encoded) => {
|
||||
@@ -221,6 +223,7 @@ impl JsonRpcConnection {
|
||||
Self {
|
||||
outgoing_tx,
|
||||
incoming_rx,
|
||||
task_handles: vec![reader_task, writer_task],
|
||||
}
|
||||
}
|
||||
|
||||
@@ -229,8 +232,9 @@ impl JsonRpcConnection {
|
||||
) -> (
|
||||
mpsc::Sender<JSONRPCMessage>,
|
||||
mpsc::Receiver<JsonRpcConnectionEvent>,
|
||||
Vec<tokio::task::JoinHandle<()>>,
|
||||
) {
|
||||
(self.outgoing_tx, self.incoming_rx)
|
||||
(self.outgoing_tx, self.incoming_rx, self.task_handles)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -323,7 +327,7 @@ mod tests {
|
||||
let (connection_writer, reader_from_connection) = tokio::io::duplex(1024);
|
||||
let connection =
|
||||
JsonRpcConnection::from_stdio(connection_reader, connection_writer, "test".to_string());
|
||||
let (outgoing_tx, mut incoming_rx) = connection.into_parts();
|
||||
let (outgoing_tx, mut incoming_rx, _task_handles) = connection.into_parts();
|
||||
|
||||
let incoming_message = JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(7),
|
||||
@@ -371,7 +375,7 @@ mod tests {
|
||||
let (connection_writer, _reader_from_connection) = tokio::io::duplex(1024);
|
||||
let connection =
|
||||
JsonRpcConnection::from_stdio(connection_reader, connection_writer, "test".to_string());
|
||||
let (_outgoing_tx, mut incoming_rx) = connection.into_parts();
|
||||
let (_outgoing_tx, mut incoming_rx, _task_handles) = connection.into_parts();
|
||||
|
||||
if let Err(err) = writer_to_connection.write_all(b"not-json\n").await {
|
||||
panic!("failed to write invalid JSON: {err}");
|
||||
@@ -401,7 +405,7 @@ mod tests {
|
||||
let (connection_writer, _reader_from_connection) = tokio::io::duplex(1024);
|
||||
let connection =
|
||||
JsonRpcConnection::from_stdio(connection_reader, connection_writer, "test".to_string());
|
||||
let (_outgoing_tx, mut incoming_rx) = connection.into_parts();
|
||||
let (_outgoing_tx, mut incoming_rx, _task_handles) = connection.into_parts();
|
||||
drop(writer_to_connection);
|
||||
|
||||
let event = recv_event(&mut incoming_rx).await;
|
||||
|
||||
@@ -13,6 +13,13 @@ pub const EXEC_WRITE_METHOD: &str = "process/write";
|
||||
pub const EXEC_TERMINATE_METHOD: &str = "process/terminate";
|
||||
pub const EXEC_OUTPUT_DELTA_METHOD: &str = "process/output";
|
||||
pub const EXEC_EXITED_METHOD: &str = "process/exited";
|
||||
pub const FS_READ_FILE_METHOD: &str = "fs/readFile";
|
||||
pub const FS_WRITE_FILE_METHOD: &str = "fs/writeFile";
|
||||
pub const FS_CREATE_DIRECTORY_METHOD: &str = "fs/createDirectory";
|
||||
pub const FS_GET_METADATA_METHOD: &str = "fs/getMetadata";
|
||||
pub const FS_READ_DIRECTORY_METHOD: &str = "fs/readDirectory";
|
||||
pub const FS_REMOVE_METHOD: &str = "fs/remove";
|
||||
pub const FS_COPY_METHOD: &str = "fs/copy";
|
||||
pub const PROTOCOL_VERSION: &str = "exec-server.v0";
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
@@ -54,6 +61,21 @@ pub struct ExecParams {
|
||||
pub env: HashMap<String, String>,
|
||||
pub tty: bool,
|
||||
pub arg0: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub sandbox: Option<ExecSandboxConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ExecSandboxConfig {
|
||||
pub mode: ExecSandboxMode,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum ExecSandboxMode {
|
||||
None,
|
||||
HostDefault,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
mod filesystem;
|
||||
mod handler;
|
||||
mod processor;
|
||||
mod routing;
|
||||
|
||||
170
codex-rs/exec-server/src/server/filesystem.rs
Normal file
170
codex-rs/exec-server/src/server/filesystem.rs
Normal file
@@ -0,0 +1,170 @@
|
||||
use std::io;
|
||||
use std::sync::Arc;
|
||||
|
||||
use base64::Engine as _;
|
||||
use base64::engine::general_purpose::STANDARD;
|
||||
use codex_app_server_protocol::FsCopyParams;
|
||||
use codex_app_server_protocol::FsCopyResponse;
|
||||
use codex_app_server_protocol::FsCreateDirectoryParams;
|
||||
use codex_app_server_protocol::FsCreateDirectoryResponse;
|
||||
use codex_app_server_protocol::FsGetMetadataParams;
|
||||
use codex_app_server_protocol::FsGetMetadataResponse;
|
||||
use codex_app_server_protocol::FsReadDirectoryEntry;
|
||||
use codex_app_server_protocol::FsReadDirectoryParams;
|
||||
use codex_app_server_protocol::FsReadDirectoryResponse;
|
||||
use codex_app_server_protocol::FsReadFileParams;
|
||||
use codex_app_server_protocol::FsReadFileResponse;
|
||||
use codex_app_server_protocol::FsRemoveParams;
|
||||
use codex_app_server_protocol::FsRemoveResponse;
|
||||
use codex_app_server_protocol::FsWriteFileParams;
|
||||
use codex_app_server_protocol::FsWriteFileResponse;
|
||||
use codex_app_server_protocol::JSONRPCErrorError;
|
||||
use codex_environment::CopyOptions;
|
||||
use codex_environment::CreateDirectoryOptions;
|
||||
use codex_environment::Environment;
|
||||
use codex_environment::ExecutorFileSystem;
|
||||
use codex_environment::RemoveOptions;
|
||||
|
||||
use crate::server::routing::internal_error;
|
||||
use crate::server::routing::invalid_request;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ExecServerFileSystem {
|
||||
file_system: Arc<dyn ExecutorFileSystem>,
|
||||
}
|
||||
|
||||
impl Default for ExecServerFileSystem {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
file_system: Environment::default().get_filesystem(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecServerFileSystem {
|
||||
pub(crate) async fn read_file(
|
||||
&self,
|
||||
params: FsReadFileParams,
|
||||
) -> Result<FsReadFileResponse, JSONRPCErrorError> {
|
||||
let bytes = self
|
||||
.file_system
|
||||
.read_file(¶ms.path)
|
||||
.await
|
||||
.map_err(map_fs_error)?;
|
||||
Ok(FsReadFileResponse {
|
||||
data_base64: STANDARD.encode(bytes),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn write_file(
|
||||
&self,
|
||||
params: FsWriteFileParams,
|
||||
) -> Result<FsWriteFileResponse, JSONRPCErrorError> {
|
||||
let bytes = STANDARD.decode(params.data_base64).map_err(|err| {
|
||||
invalid_request(format!(
|
||||
"fs/writeFile requires valid base64 dataBase64: {err}"
|
||||
))
|
||||
})?;
|
||||
self.file_system
|
||||
.write_file(¶ms.path, bytes)
|
||||
.await
|
||||
.map_err(map_fs_error)?;
|
||||
Ok(FsWriteFileResponse {})
|
||||
}
|
||||
|
||||
pub(crate) async fn create_directory(
|
||||
&self,
|
||||
params: FsCreateDirectoryParams,
|
||||
) -> Result<FsCreateDirectoryResponse, JSONRPCErrorError> {
|
||||
self.file_system
|
||||
.create_directory(
|
||||
¶ms.path,
|
||||
CreateDirectoryOptions {
|
||||
recursive: params.recursive.unwrap_or(true),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map_err(map_fs_error)?;
|
||||
Ok(FsCreateDirectoryResponse {})
|
||||
}
|
||||
|
||||
pub(crate) async fn get_metadata(
|
||||
&self,
|
||||
params: FsGetMetadataParams,
|
||||
) -> Result<FsGetMetadataResponse, JSONRPCErrorError> {
|
||||
let metadata = self
|
||||
.file_system
|
||||
.get_metadata(¶ms.path)
|
||||
.await
|
||||
.map_err(map_fs_error)?;
|
||||
Ok(FsGetMetadataResponse {
|
||||
is_directory: metadata.is_directory,
|
||||
is_file: metadata.is_file,
|
||||
created_at_ms: metadata.created_at_ms,
|
||||
modified_at_ms: metadata.modified_at_ms,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn read_directory(
|
||||
&self,
|
||||
params: FsReadDirectoryParams,
|
||||
) -> Result<FsReadDirectoryResponse, JSONRPCErrorError> {
|
||||
let entries = self
|
||||
.file_system
|
||||
.read_directory(¶ms.path)
|
||||
.await
|
||||
.map_err(map_fs_error)?;
|
||||
Ok(FsReadDirectoryResponse {
|
||||
entries: entries
|
||||
.into_iter()
|
||||
.map(|entry| FsReadDirectoryEntry {
|
||||
file_name: entry.file_name,
|
||||
is_directory: entry.is_directory,
|
||||
is_file: entry.is_file,
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn remove(
|
||||
&self,
|
||||
params: FsRemoveParams,
|
||||
) -> Result<FsRemoveResponse, JSONRPCErrorError> {
|
||||
self.file_system
|
||||
.remove(
|
||||
¶ms.path,
|
||||
RemoveOptions {
|
||||
recursive: params.recursive.unwrap_or(true),
|
||||
force: params.force.unwrap_or(true),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map_err(map_fs_error)?;
|
||||
Ok(FsRemoveResponse {})
|
||||
}
|
||||
|
||||
pub(crate) async fn copy(
|
||||
&self,
|
||||
params: FsCopyParams,
|
||||
) -> Result<FsCopyResponse, JSONRPCErrorError> {
|
||||
self.file_system
|
||||
.copy(
|
||||
¶ms.source_path,
|
||||
¶ms.destination_path,
|
||||
CopyOptions {
|
||||
recursive: params.recursive,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map_err(map_fs_error)?;
|
||||
Ok(FsCopyResponse {})
|
||||
}
|
||||
}
|
||||
|
||||
fn map_fs_error(err: io::Error) -> JSONRPCErrorError {
|
||||
if err.kind() == io::ErrorKind::InvalidInput {
|
||||
invalid_request(err.to_string())
|
||||
} else {
|
||||
internal_error(err.to_string())
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
734
codex-rs/exec-server/src/server/handler/tests.rs
Normal file
734
codex-rs/exec-server/src/server/handler/tests.rs
Normal file
@@ -0,0 +1,734 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use super::ExecServerHandler;
|
||||
use super::RetainedOutputChunk;
|
||||
use super::RunningProcess;
|
||||
use crate::protocol::ExecOutputStream;
|
||||
use crate::protocol::ExecSandboxConfig;
|
||||
use crate::protocol::ExecSandboxMode;
|
||||
use crate::protocol::InitializeParams;
|
||||
use crate::protocol::InitializeResponse;
|
||||
use crate::protocol::PROTOCOL_VERSION;
|
||||
use crate::protocol::ReadParams;
|
||||
use crate::protocol::TerminateResponse;
|
||||
use crate::protocol::WriteParams;
|
||||
use crate::server::routing::ExecServerClientNotification;
|
||||
use crate::server::routing::ExecServerInboundMessage;
|
||||
use crate::server::routing::ExecServerOutboundMessage;
|
||||
use crate::server::routing::ExecServerRequest;
|
||||
use crate::server::routing::ExecServerResponseMessage;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
|
||||
async fn recv_outbound(
|
||||
outgoing_rx: &mut tokio::sync::mpsc::Receiver<ExecServerOutboundMessage>,
|
||||
) -> ExecServerOutboundMessage {
|
||||
let recv_result = timeout(Duration::from_secs(1), outgoing_rx.recv()).await;
|
||||
let maybe_message = match recv_result {
|
||||
Ok(maybe_message) => maybe_message,
|
||||
Err(err) => panic!("timed out waiting for handler output: {err}"),
|
||||
};
|
||||
match maybe_message {
|
||||
Some(message) => message,
|
||||
None => panic!("handler output channel closed unexpectedly"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn initialize_response_reports_protocol_version() {
|
||||
let (outgoing_tx, mut outgoing_rx) = tokio::sync::mpsc::channel(1);
|
||||
let mut handler = ExecServerHandler::new(outgoing_tx);
|
||||
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Request(
|
||||
ExecServerRequest::Initialize {
|
||||
request_id: RequestId::Integer(1),
|
||||
params: InitializeParams {
|
||||
client_name: "test".to_string(),
|
||||
},
|
||||
},
|
||||
))
|
||||
.await
|
||||
{
|
||||
panic!("initialize should succeed: {err}");
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
recv_outbound(&mut outgoing_rx).await,
|
||||
ExecServerOutboundMessage::Response {
|
||||
request_id: RequestId::Integer(1),
|
||||
response: ExecServerResponseMessage::Initialize(InitializeResponse {
|
||||
protocol_version: PROTOCOL_VERSION.to_string(),
|
||||
}),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn exec_methods_require_initialize() {
|
||||
let (outgoing_tx, mut outgoing_rx) = tokio::sync::mpsc::channel(1);
|
||||
let mut handler = ExecServerHandler::new(outgoing_tx);
|
||||
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Request(ExecServerRequest::Exec {
|
||||
request_id: RequestId::Integer(7),
|
||||
params: crate::protocol::ExecParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
argv: vec!["bash".to_string(), "-lc".to_string(), "true".to_string()],
|
||||
cwd: std::env::current_dir().expect("cwd"),
|
||||
env: HashMap::new(),
|
||||
tty: true,
|
||||
arg0: None,
|
||||
sandbox: None,
|
||||
},
|
||||
}))
|
||||
.await
|
||||
{
|
||||
panic!("request handling should not fail the handler: {err}");
|
||||
}
|
||||
|
||||
let ExecServerOutboundMessage::Error { request_id, error } =
|
||||
recv_outbound(&mut outgoing_rx).await
|
||||
else {
|
||||
panic!("expected invalid-request error");
|
||||
};
|
||||
assert_eq!(request_id, RequestId::Integer(7));
|
||||
assert_eq!(error.code, -32600);
|
||||
assert_eq!(
|
||||
error.message,
|
||||
"client must call initialize before using exec methods"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn exec_methods_require_initialized_notification_after_initialize() {
|
||||
let (outgoing_tx, mut outgoing_rx) = tokio::sync::mpsc::channel(2);
|
||||
let mut handler = ExecServerHandler::new(outgoing_tx);
|
||||
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Request(
|
||||
ExecServerRequest::Initialize {
|
||||
request_id: RequestId::Integer(1),
|
||||
params: InitializeParams {
|
||||
client_name: "test".to_string(),
|
||||
},
|
||||
},
|
||||
))
|
||||
.await
|
||||
{
|
||||
panic!("initialize should succeed: {err}");
|
||||
}
|
||||
let _ = recv_outbound(&mut outgoing_rx).await;
|
||||
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Request(ExecServerRequest::Exec {
|
||||
request_id: RequestId::Integer(2),
|
||||
params: crate::protocol::ExecParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
argv: vec!["bash".to_string(), "-lc".to_string(), "true".to_string()],
|
||||
cwd: std::env::current_dir().expect("cwd"),
|
||||
env: HashMap::new(),
|
||||
tty: true,
|
||||
arg0: None,
|
||||
sandbox: None,
|
||||
},
|
||||
}))
|
||||
.await
|
||||
{
|
||||
panic!("request handling should not fail the handler: {err}");
|
||||
}
|
||||
|
||||
let ExecServerOutboundMessage::Error { request_id, error } =
|
||||
recv_outbound(&mut outgoing_rx).await
|
||||
else {
|
||||
panic!("expected invalid-request error");
|
||||
};
|
||||
assert_eq!(request_id, RequestId::Integer(2));
|
||||
assert_eq!(error.code, -32600);
|
||||
assert_eq!(
|
||||
error.message,
|
||||
"client must send initialized before using exec methods"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn initialized_before_initialize_is_a_protocol_error() {
|
||||
let (outgoing_tx, _outgoing_rx) = tokio::sync::mpsc::channel(1);
|
||||
let mut handler = ExecServerHandler::new(outgoing_tx);
|
||||
|
||||
let result = handler
|
||||
.handle_message(ExecServerInboundMessage::Notification(
|
||||
ExecServerClientNotification::Initialized,
|
||||
))
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(err) => {
|
||||
assert_eq!(
|
||||
err,
|
||||
"received `initialized` notification before `initialize`"
|
||||
);
|
||||
}
|
||||
Ok(()) => panic!("expected protocol error for early initialized notification"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn initialize_may_only_be_sent_once_per_connection() {
|
||||
let (outgoing_tx, mut outgoing_rx) = tokio::sync::mpsc::channel(2);
|
||||
let mut handler = ExecServerHandler::new(outgoing_tx);
|
||||
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Request(
|
||||
ExecServerRequest::Initialize {
|
||||
request_id: RequestId::Integer(1),
|
||||
params: InitializeParams {
|
||||
client_name: "test".to_string(),
|
||||
},
|
||||
},
|
||||
))
|
||||
.await
|
||||
{
|
||||
panic!("initialize should succeed: {err}");
|
||||
}
|
||||
let _ = recv_outbound(&mut outgoing_rx).await;
|
||||
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Request(
|
||||
ExecServerRequest::Initialize {
|
||||
request_id: RequestId::Integer(2),
|
||||
params: InitializeParams {
|
||||
client_name: "test".to_string(),
|
||||
},
|
||||
},
|
||||
))
|
||||
.await
|
||||
{
|
||||
panic!("duplicate initialize should not fail the handler: {err}");
|
||||
}
|
||||
|
||||
let ExecServerOutboundMessage::Error { request_id, error } =
|
||||
recv_outbound(&mut outgoing_rx).await
|
||||
else {
|
||||
panic!("expected invalid-request error");
|
||||
};
|
||||
assert_eq!(request_id, RequestId::Integer(2));
|
||||
assert_eq!(error.code, -32600);
|
||||
assert_eq!(
|
||||
error.message,
|
||||
"initialize may only be sent once per connection"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_default_sandbox_requests_are_rejected_until_supported() {
|
||||
let (outgoing_tx, mut outgoing_rx) = tokio::sync::mpsc::channel(3);
|
||||
let mut handler = ExecServerHandler::new(outgoing_tx);
|
||||
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Request(
|
||||
ExecServerRequest::Initialize {
|
||||
request_id: RequestId::Integer(1),
|
||||
params: InitializeParams {
|
||||
client_name: "test".to_string(),
|
||||
},
|
||||
},
|
||||
))
|
||||
.await
|
||||
{
|
||||
panic!("initialize should succeed: {err}");
|
||||
}
|
||||
let _ = recv_outbound(&mut outgoing_rx).await;
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Notification(
|
||||
ExecServerClientNotification::Initialized,
|
||||
))
|
||||
.await
|
||||
{
|
||||
panic!("initialized should succeed: {err}");
|
||||
}
|
||||
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Request(ExecServerRequest::Exec {
|
||||
request_id: RequestId::Integer(2),
|
||||
params: crate::protocol::ExecParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
argv: vec!["bash".to_string(), "-lc".to_string(), "true".to_string()],
|
||||
cwd: std::env::current_dir().expect("cwd"),
|
||||
env: HashMap::new(),
|
||||
tty: false,
|
||||
arg0: None,
|
||||
sandbox: Some(ExecSandboxConfig {
|
||||
mode: ExecSandboxMode::HostDefault,
|
||||
}),
|
||||
},
|
||||
}))
|
||||
.await
|
||||
{
|
||||
panic!("request handling should not fail the handler: {err}");
|
||||
}
|
||||
|
||||
let ExecServerOutboundMessage::Error { request_id, error } =
|
||||
recv_outbound(&mut outgoing_rx).await
|
||||
else {
|
||||
panic!("expected unsupported sandbox error");
|
||||
};
|
||||
assert_eq!(request_id, RequestId::Integer(2));
|
||||
assert_eq!(error.code, -32600);
|
||||
assert_eq!(
|
||||
error.message,
|
||||
"sandbox mode `hostDefault` is not supported by exec-server yet"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn exec_echoes_client_process_ids() {
|
||||
let (outgoing_tx, mut outgoing_rx) = tokio::sync::mpsc::channel(4);
|
||||
let mut handler = ExecServerHandler::new(outgoing_tx);
|
||||
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Request(
|
||||
ExecServerRequest::Initialize {
|
||||
request_id: RequestId::Integer(1),
|
||||
params: InitializeParams {
|
||||
client_name: "test".to_string(),
|
||||
},
|
||||
},
|
||||
))
|
||||
.await
|
||||
{
|
||||
panic!("initialize should succeed: {err}");
|
||||
}
|
||||
let _ = recv_outbound(&mut outgoing_rx).await;
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Notification(
|
||||
ExecServerClientNotification::Initialized,
|
||||
))
|
||||
.await
|
||||
{
|
||||
panic!("initialized should succeed: {err}");
|
||||
}
|
||||
|
||||
let params = crate::protocol::ExecParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
argv: vec![
|
||||
"bash".to_string(),
|
||||
"-lc".to_string(),
|
||||
"sleep 30".to_string(),
|
||||
],
|
||||
cwd: std::env::current_dir().expect("cwd"),
|
||||
env: HashMap::new(),
|
||||
tty: false,
|
||||
arg0: None,
|
||||
sandbox: None,
|
||||
};
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Request(ExecServerRequest::Exec {
|
||||
request_id: RequestId::Integer(2),
|
||||
params: params.clone(),
|
||||
}))
|
||||
.await
|
||||
{
|
||||
panic!("first exec should succeed: {err}");
|
||||
}
|
||||
let ExecServerOutboundMessage::Response {
|
||||
request_id,
|
||||
response: ExecServerResponseMessage::Exec(first_exec),
|
||||
} = recv_outbound(&mut outgoing_rx).await
|
||||
else {
|
||||
panic!("expected first exec response");
|
||||
};
|
||||
assert_eq!(request_id, RequestId::Integer(2));
|
||||
assert_eq!(first_exec.process_id, "proc-1");
|
||||
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Request(ExecServerRequest::Exec {
|
||||
request_id: RequestId::Integer(3),
|
||||
params: crate::protocol::ExecParams {
|
||||
process_id: "proc-2".to_string(),
|
||||
argv: vec!["bash".to_string(), "-lc".to_string(), "true".to_string()],
|
||||
..params
|
||||
},
|
||||
}))
|
||||
.await
|
||||
{
|
||||
panic!("second exec should succeed: {err}");
|
||||
}
|
||||
|
||||
let ExecServerOutboundMessage::Response {
|
||||
request_id,
|
||||
response: ExecServerResponseMessage::Exec(second_exec),
|
||||
} = recv_outbound(&mut outgoing_rx).await
|
||||
else {
|
||||
panic!("expected second exec response");
|
||||
};
|
||||
assert_eq!(request_id, RequestId::Integer(3));
|
||||
assert_eq!(second_exec.process_id, "proc-2");
|
||||
|
||||
handler.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn writes_to_pipe_backed_processes_are_rejected() {
|
||||
let (outgoing_tx, mut outgoing_rx) = tokio::sync::mpsc::channel(4);
|
||||
let mut handler = ExecServerHandler::new(outgoing_tx);
|
||||
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Request(
|
||||
ExecServerRequest::Initialize {
|
||||
request_id: RequestId::Integer(1),
|
||||
params: InitializeParams {
|
||||
client_name: "test".to_string(),
|
||||
},
|
||||
},
|
||||
))
|
||||
.await
|
||||
{
|
||||
panic!("initialize should succeed: {err}");
|
||||
}
|
||||
let _ = recv_outbound(&mut outgoing_rx).await;
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Notification(
|
||||
ExecServerClientNotification::Initialized,
|
||||
))
|
||||
.await
|
||||
{
|
||||
panic!("initialized should succeed: {err}");
|
||||
}
|
||||
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Request(ExecServerRequest::Exec {
|
||||
request_id: RequestId::Integer(2),
|
||||
params: crate::protocol::ExecParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
argv: vec![
|
||||
"bash".to_string(),
|
||||
"-lc".to_string(),
|
||||
"sleep 30".to_string(),
|
||||
],
|
||||
cwd: std::env::current_dir().expect("cwd"),
|
||||
env: HashMap::new(),
|
||||
tty: false,
|
||||
arg0: None,
|
||||
sandbox: None,
|
||||
},
|
||||
}))
|
||||
.await
|
||||
{
|
||||
panic!("exec should succeed: {err}");
|
||||
}
|
||||
let ExecServerOutboundMessage::Response {
|
||||
response: ExecServerResponseMessage::Exec(exec_response),
|
||||
..
|
||||
} = recv_outbound(&mut outgoing_rx).await
|
||||
else {
|
||||
panic!("expected exec response");
|
||||
};
|
||||
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Request(
|
||||
ExecServerRequest::Write {
|
||||
request_id: RequestId::Integer(3),
|
||||
params: WriteParams {
|
||||
process_id: exec_response.process_id,
|
||||
chunk: b"hello\n".to_vec().into(),
|
||||
},
|
||||
},
|
||||
))
|
||||
.await
|
||||
{
|
||||
panic!("write should not fail the handler: {err}");
|
||||
}
|
||||
|
||||
let ExecServerOutboundMessage::Error { request_id, error } =
|
||||
recv_outbound(&mut outgoing_rx).await
|
||||
else {
|
||||
panic!("expected stdin-closed error");
|
||||
};
|
||||
assert_eq!(request_id, RequestId::Integer(3));
|
||||
assert_eq!(error.code, -32600);
|
||||
assert_eq!(error.message, "stdin is closed for process proc-1");
|
||||
|
||||
handler.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn writes_to_unknown_processes_are_rejected() {
|
||||
let (outgoing_tx, mut outgoing_rx) = tokio::sync::mpsc::channel(2);
|
||||
let mut handler = ExecServerHandler::new(outgoing_tx);
|
||||
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Request(
|
||||
ExecServerRequest::Initialize {
|
||||
request_id: RequestId::Integer(1),
|
||||
params: InitializeParams {
|
||||
client_name: "test".to_string(),
|
||||
},
|
||||
},
|
||||
))
|
||||
.await
|
||||
{
|
||||
panic!("initialize should succeed: {err}");
|
||||
}
|
||||
let _ = recv_outbound(&mut outgoing_rx).await;
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Notification(
|
||||
ExecServerClientNotification::Initialized,
|
||||
))
|
||||
.await
|
||||
{
|
||||
panic!("initialized should succeed: {err}");
|
||||
}
|
||||
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Request(
|
||||
ExecServerRequest::Write {
|
||||
request_id: RequestId::Integer(2),
|
||||
params: WriteParams {
|
||||
process_id: "missing".to_string(),
|
||||
chunk: b"hello\n".to_vec().into(),
|
||||
},
|
||||
},
|
||||
))
|
||||
.await
|
||||
{
|
||||
panic!("write should not fail the handler: {err}");
|
||||
}
|
||||
|
||||
let ExecServerOutboundMessage::Error { request_id, error } =
|
||||
recv_outbound(&mut outgoing_rx).await
|
||||
else {
|
||||
panic!("expected unknown-process error");
|
||||
};
|
||||
assert_eq!(request_id, RequestId::Integer(2));
|
||||
assert_eq!(error.code, -32600);
|
||||
assert_eq!(error.message, "unknown process id missing");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn terminate_unknown_processes_report_running_false() {
|
||||
let (outgoing_tx, mut outgoing_rx) = tokio::sync::mpsc::channel(2);
|
||||
let mut handler = ExecServerHandler::new(outgoing_tx);
|
||||
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Request(
|
||||
ExecServerRequest::Initialize {
|
||||
request_id: RequestId::Integer(1),
|
||||
params: InitializeParams {
|
||||
client_name: "test".to_string(),
|
||||
},
|
||||
},
|
||||
))
|
||||
.await
|
||||
{
|
||||
panic!("initialize should succeed: {err}");
|
||||
}
|
||||
let _ = recv_outbound(&mut outgoing_rx).await;
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Notification(
|
||||
ExecServerClientNotification::Initialized,
|
||||
))
|
||||
.await
|
||||
{
|
||||
panic!("initialized should succeed: {err}");
|
||||
}
|
||||
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Request(
|
||||
ExecServerRequest::Terminate {
|
||||
request_id: RequestId::Integer(2),
|
||||
params: crate::protocol::TerminateParams {
|
||||
process_id: "missing".to_string(),
|
||||
},
|
||||
},
|
||||
))
|
||||
.await
|
||||
{
|
||||
panic!("terminate should not fail the handler: {err}");
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
recv_outbound(&mut outgoing_rx).await,
|
||||
ExecServerOutboundMessage::Response {
|
||||
request_id: RequestId::Integer(2),
|
||||
response: ExecServerResponseMessage::Terminate(TerminateResponse { running: false }),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn terminate_keeps_process_ids_reserved() {
|
||||
let (outgoing_tx, mut outgoing_rx) = tokio::sync::mpsc::channel(2);
|
||||
let mut handler = ExecServerHandler::new(outgoing_tx);
|
||||
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Request(
|
||||
ExecServerRequest::Initialize {
|
||||
request_id: RequestId::Integer(1),
|
||||
params: InitializeParams {
|
||||
client_name: "test".to_string(),
|
||||
},
|
||||
},
|
||||
))
|
||||
.await
|
||||
{
|
||||
panic!("initialize should succeed: {err}");
|
||||
}
|
||||
let _ = recv_outbound(&mut outgoing_rx).await;
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Notification(
|
||||
ExecServerClientNotification::Initialized,
|
||||
))
|
||||
.await
|
||||
{
|
||||
panic!("initialized should succeed: {err}");
|
||||
}
|
||||
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Request(ExecServerRequest::Exec {
|
||||
request_id: RequestId::Integer(2),
|
||||
params: crate::protocol::ExecParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
argv: vec![
|
||||
"bash".to_string(),
|
||||
"-lc".to_string(),
|
||||
"sleep 30".to_string(),
|
||||
],
|
||||
cwd: std::env::current_dir().expect("cwd"),
|
||||
env: HashMap::new(),
|
||||
tty: false,
|
||||
arg0: None,
|
||||
sandbox: None,
|
||||
},
|
||||
}))
|
||||
.await
|
||||
{
|
||||
panic!("exec should not fail the handler: {err}");
|
||||
}
|
||||
let _ = recv_outbound(&mut outgoing_rx).await;
|
||||
|
||||
if let Err(err) = handler
|
||||
.handle_message(ExecServerInboundMessage::Request(
|
||||
ExecServerRequest::Terminate {
|
||||
request_id: RequestId::Integer(3),
|
||||
params: crate::protocol::TerminateParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
},
|
||||
},
|
||||
))
|
||||
.await
|
||||
{
|
||||
panic!("terminate should not fail the handler: {err}");
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
recv_outbound(&mut outgoing_rx).await,
|
||||
ExecServerOutboundMessage::Response {
|
||||
request_id: RequestId::Integer(3),
|
||||
response: ExecServerResponseMessage::Terminate(TerminateResponse { running: true }),
|
||||
}
|
||||
);
|
||||
|
||||
assert!(
|
||||
handler.processes.lock().await.contains_key("proc-1"),
|
||||
"terminated ids should stay reserved until exit cleanup removes them"
|
||||
);
|
||||
|
||||
let deadline = tokio::time::Instant::now() + Duration::from_secs(1);
|
||||
loop {
|
||||
if !handler.processes.lock().await.contains_key("proc-1") {
|
||||
break;
|
||||
}
|
||||
assert!(
|
||||
tokio::time::Instant::now() < deadline,
|
||||
"terminated ids should be removed after the exit retention window"
|
||||
);
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
}
|
||||
|
||||
handler.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_paginates_retained_output_without_skipping_omitted_chunks() {
|
||||
let (outgoing_tx, _outgoing_rx) = tokio::sync::mpsc::channel(1);
|
||||
let mut handler = ExecServerHandler::new(outgoing_tx);
|
||||
let _ = handler.initialize().expect("initialize should succeed");
|
||||
handler.initialized().expect("initialized should succeed");
|
||||
|
||||
let spawned = codex_utils_pty::spawn_pipe_process_no_stdin(
|
||||
"bash",
|
||||
&["-lc".to_string(), "true".to_string()],
|
||||
std::env::current_dir().expect("cwd").as_path(),
|
||||
&HashMap::new(),
|
||||
&None,
|
||||
)
|
||||
.await
|
||||
.expect("spawn test process");
|
||||
{
|
||||
let mut process_map = handler.processes.lock().await;
|
||||
process_map.insert(
|
||||
"proc-1".to_string(),
|
||||
RunningProcess {
|
||||
session: spawned.session,
|
||||
tty: false,
|
||||
output: VecDeque::from([
|
||||
RetainedOutputChunk {
|
||||
seq: 1,
|
||||
stream: ExecOutputStream::Stdout,
|
||||
chunk: b"abc".to_vec(),
|
||||
},
|
||||
RetainedOutputChunk {
|
||||
seq: 2,
|
||||
stream: ExecOutputStream::Stderr,
|
||||
chunk: b"def".to_vec(),
|
||||
},
|
||||
]),
|
||||
retained_bytes: 6,
|
||||
next_seq: 3,
|
||||
exit_code: None,
|
||||
output_notify: Arc::new(Notify::new()),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let first = handler
|
||||
.exec_read(ReadParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
after_seq: Some(0),
|
||||
max_bytes: Some(3),
|
||||
wait_ms: Some(0),
|
||||
})
|
||||
.await
|
||||
.expect("first read should succeed");
|
||||
|
||||
assert_eq!(first.chunks.len(), 1);
|
||||
assert_eq!(first.chunks[0].seq, 1);
|
||||
assert_eq!(first.chunks[0].stream, ExecOutputStream::Stdout);
|
||||
assert_eq!(first.chunks[0].chunk.clone().into_inner(), b"abc".to_vec());
|
||||
assert_eq!(first.next_seq, 2);
|
||||
|
||||
let second = handler
|
||||
.exec_read(ReadParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
after_seq: Some(first.next_seq - 1),
|
||||
max_bytes: Some(3),
|
||||
wait_ms: Some(0),
|
||||
})
|
||||
.await
|
||||
.expect("second read should succeed");
|
||||
|
||||
assert_eq!(second.chunks.len(), 1);
|
||||
assert_eq!(second.chunks[0].seq, 2);
|
||||
assert_eq!(second.chunks[0].stream, ExecOutputStream::Stderr);
|
||||
assert_eq!(second.chunks[0].chunk.clone().into_inner(), b"def".to_vec());
|
||||
assert_eq!(second.next_seq, 3);
|
||||
|
||||
handler.shutdown().await;
|
||||
}
|
||||
@@ -6,17 +6,13 @@ use crate::connection::CHANNEL_CAPACITY;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::connection::JsonRpcConnectionEvent;
|
||||
use crate::server::handler::ExecServerHandler;
|
||||
use crate::server::routing::ExecServerClientNotification;
|
||||
use crate::server::routing::ExecServerInboundMessage;
|
||||
use crate::server::routing::ExecServerOutboundMessage;
|
||||
use crate::server::routing::ExecServerRequest;
|
||||
use crate::server::routing::ExecServerResponseMessage;
|
||||
use crate::server::routing::RoutedExecServerMessage;
|
||||
use crate::server::routing::encode_outbound_message;
|
||||
use crate::server::routing::route_jsonrpc_message;
|
||||
|
||||
pub(crate) async fn run_connection(connection: JsonRpcConnection) {
|
||||
let (json_outgoing_tx, mut incoming_rx) = connection.into_parts();
|
||||
let (json_outgoing_tx, mut incoming_rx, _connection_tasks) = connection.into_parts();
|
||||
let (outgoing_tx, mut outgoing_rx) =
|
||||
mpsc::channel::<ExecServerOutboundMessage>(CHANNEL_CAPACITY);
|
||||
let mut handler = ExecServerHandler::new(outgoing_tx.clone());
|
||||
@@ -40,8 +36,7 @@ pub(crate) async fn run_connection(connection: JsonRpcConnection) {
|
||||
match event {
|
||||
JsonRpcConnectionEvent::Message(message) => match route_jsonrpc_message(message) {
|
||||
Ok(RoutedExecServerMessage::Inbound(message)) => {
|
||||
if let Err(err) = dispatch_to_handler(&mut handler, message, &outgoing_tx).await
|
||||
{
|
||||
if let Err(err) = handler.handle_message(message).await {
|
||||
warn!("closing exec-server connection after protocol error: {err}");
|
||||
break;
|
||||
}
|
||||
@@ -70,70 +65,3 @@ pub(crate) async fn run_connection(connection: JsonRpcConnection) {
|
||||
drop(outgoing_tx);
|
||||
let _ = outbound_task.await;
|
||||
}
|
||||
|
||||
async fn dispatch_to_handler(
|
||||
handler: &mut ExecServerHandler,
|
||||
message: ExecServerInboundMessage,
|
||||
outgoing_tx: &mpsc::Sender<ExecServerOutboundMessage>,
|
||||
) -> Result<(), String> {
|
||||
match message {
|
||||
ExecServerInboundMessage::Request(request) => {
|
||||
let outbound = match request {
|
||||
ExecServerRequest::Initialize { request_id, .. } => request_outbound(
|
||||
request_id,
|
||||
handler
|
||||
.initialize()
|
||||
.map(ExecServerResponseMessage::Initialize),
|
||||
),
|
||||
ExecServerRequest::Exec { request_id, params } => request_outbound(
|
||||
request_id,
|
||||
handler
|
||||
.exec(params)
|
||||
.await
|
||||
.map(ExecServerResponseMessage::Exec),
|
||||
),
|
||||
ExecServerRequest::Read { request_id, params } => request_outbound(
|
||||
request_id,
|
||||
handler
|
||||
.read(params)
|
||||
.await
|
||||
.map(ExecServerResponseMessage::Read),
|
||||
),
|
||||
ExecServerRequest::Write { request_id, params } => request_outbound(
|
||||
request_id,
|
||||
handler
|
||||
.write(params)
|
||||
.await
|
||||
.map(ExecServerResponseMessage::Write),
|
||||
),
|
||||
ExecServerRequest::Terminate { request_id, params } => request_outbound(
|
||||
request_id,
|
||||
handler
|
||||
.terminate(params)
|
||||
.await
|
||||
.map(ExecServerResponseMessage::Terminate),
|
||||
),
|
||||
};
|
||||
outgoing_tx
|
||||
.send(outbound)
|
||||
.await
|
||||
.map_err(|_| "outbound channel closed".to_string())
|
||||
}
|
||||
ExecServerInboundMessage::Notification(ExecServerClientNotification::Initialized) => {
|
||||
handler.initialized()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn request_outbound(
|
||||
request_id: codex_app_server_protocol::RequestId,
|
||||
result: Result<ExecServerResponseMessage, codex_app_server_protocol::JSONRPCErrorError>,
|
||||
) -> ExecServerOutboundMessage {
|
||||
match result {
|
||||
Ok(response) => ExecServerOutboundMessage::Response {
|
||||
request_id,
|
||||
response,
|
||||
},
|
||||
Err(error) => ExecServerOutboundMessage::Error { request_id, error },
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,13 @@ use crate::protocol::ExecExitedNotification;
|
||||
use crate::protocol::ExecOutputDeltaNotification;
|
||||
use crate::protocol::ExecParams;
|
||||
use crate::protocol::ExecResponse;
|
||||
use crate::protocol::FS_COPY_METHOD;
|
||||
use crate::protocol::FS_CREATE_DIRECTORY_METHOD;
|
||||
use crate::protocol::FS_GET_METADATA_METHOD;
|
||||
use crate::protocol::FS_READ_DIRECTORY_METHOD;
|
||||
use crate::protocol::FS_READ_FILE_METHOD;
|
||||
use crate::protocol::FS_REMOVE_METHOD;
|
||||
use crate::protocol::FS_WRITE_FILE_METHOD;
|
||||
use crate::protocol::INITIALIZE_METHOD;
|
||||
use crate::protocol::INITIALIZED_METHOD;
|
||||
use crate::protocol::InitializeParams;
|
||||
@@ -27,6 +34,20 @@ use crate::protocol::TerminateParams;
|
||||
use crate::protocol::TerminateResponse;
|
||||
use crate::protocol::WriteParams;
|
||||
use crate::protocol::WriteResponse;
|
||||
use codex_app_server_protocol::FsCopyParams;
|
||||
use codex_app_server_protocol::FsCopyResponse;
|
||||
use codex_app_server_protocol::FsCreateDirectoryParams;
|
||||
use codex_app_server_protocol::FsCreateDirectoryResponse;
|
||||
use codex_app_server_protocol::FsGetMetadataParams;
|
||||
use codex_app_server_protocol::FsGetMetadataResponse;
|
||||
use codex_app_server_protocol::FsReadDirectoryParams;
|
||||
use codex_app_server_protocol::FsReadDirectoryResponse;
|
||||
use codex_app_server_protocol::FsReadFileParams;
|
||||
use codex_app_server_protocol::FsReadFileResponse;
|
||||
use codex_app_server_protocol::FsRemoveParams;
|
||||
use codex_app_server_protocol::FsRemoveResponse;
|
||||
use codex_app_server_protocol::FsWriteFileParams;
|
||||
use codex_app_server_protocol::FsWriteFileResponse;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) enum ExecServerInboundMessage {
|
||||
@@ -56,6 +77,34 @@ pub(crate) enum ExecServerRequest {
|
||||
request_id: RequestId,
|
||||
params: TerminateParams,
|
||||
},
|
||||
FsReadFile {
|
||||
request_id: RequestId,
|
||||
params: FsReadFileParams,
|
||||
},
|
||||
FsWriteFile {
|
||||
request_id: RequestId,
|
||||
params: FsWriteFileParams,
|
||||
},
|
||||
FsCreateDirectory {
|
||||
request_id: RequestId,
|
||||
params: FsCreateDirectoryParams,
|
||||
},
|
||||
FsGetMetadata {
|
||||
request_id: RequestId,
|
||||
params: FsGetMetadataParams,
|
||||
},
|
||||
FsReadDirectory {
|
||||
request_id: RequestId,
|
||||
params: FsReadDirectoryParams,
|
||||
},
|
||||
FsRemove {
|
||||
request_id: RequestId,
|
||||
params: FsRemoveParams,
|
||||
},
|
||||
FsCopy {
|
||||
request_id: RequestId,
|
||||
params: FsCopyParams,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
@@ -83,6 +132,13 @@ pub(crate) enum ExecServerResponseMessage {
|
||||
Read(ReadResponse),
|
||||
Write(WriteResponse),
|
||||
Terminate(TerminateResponse),
|
||||
FsReadFile(FsReadFileResponse),
|
||||
FsWriteFile(FsWriteFileResponse),
|
||||
FsCreateDirectory(FsCreateDirectoryResponse),
|
||||
FsGetMetadata(FsGetMetadataResponse),
|
||||
FsReadDirectory(FsReadDirectoryResponse),
|
||||
FsRemove(FsRemoveResponse),
|
||||
FsCopy(FsCopyResponse),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
@@ -178,6 +234,27 @@ fn route_request(request: JSONRPCRequest) -> Result<RoutedExecServerMessage, Str
|
||||
EXEC_TERMINATE_METHOD => Ok(parse_request_params(request, |request_id, params| {
|
||||
ExecServerRequest::Terminate { request_id, params }
|
||||
})),
|
||||
FS_READ_FILE_METHOD => Ok(parse_request_params(request, |request_id, params| {
|
||||
ExecServerRequest::FsReadFile { request_id, params }
|
||||
})),
|
||||
FS_WRITE_FILE_METHOD => Ok(parse_request_params(request, |request_id, params| {
|
||||
ExecServerRequest::FsWriteFile { request_id, params }
|
||||
})),
|
||||
FS_CREATE_DIRECTORY_METHOD => Ok(parse_request_params(request, |request_id, params| {
|
||||
ExecServerRequest::FsCreateDirectory { request_id, params }
|
||||
})),
|
||||
FS_GET_METADATA_METHOD => Ok(parse_request_params(request, |request_id, params| {
|
||||
ExecServerRequest::FsGetMetadata { request_id, params }
|
||||
})),
|
||||
FS_READ_DIRECTORY_METHOD => Ok(parse_request_params(request, |request_id, params| {
|
||||
ExecServerRequest::FsReadDirectory { request_id, params }
|
||||
})),
|
||||
FS_REMOVE_METHOD => Ok(parse_request_params(request, |request_id, params| {
|
||||
ExecServerRequest::FsRemove { request_id, params }
|
||||
})),
|
||||
FS_COPY_METHOD => Ok(parse_request_params(request, |request_id, params| {
|
||||
ExecServerRequest::FsCopy { request_id, params }
|
||||
})),
|
||||
other => Ok(RoutedExecServerMessage::ImmediateOutbound(
|
||||
ExecServerOutboundMessage::Error {
|
||||
request_id: request.id,
|
||||
@@ -224,6 +301,13 @@ fn serialize_response(
|
||||
ExecServerResponseMessage::Read(response) => serde_json::to_value(response),
|
||||
ExecServerResponseMessage::Write(response) => serde_json::to_value(response),
|
||||
ExecServerResponseMessage::Terminate(response) => serde_json::to_value(response),
|
||||
ExecServerResponseMessage::FsReadFile(response) => serde_json::to_value(response),
|
||||
ExecServerResponseMessage::FsWriteFile(response) => serde_json::to_value(response),
|
||||
ExecServerResponseMessage::FsCreateDirectory(response) => serde_json::to_value(response),
|
||||
ExecServerResponseMessage::FsGetMetadata(response) => serde_json::to_value(response),
|
||||
ExecServerResponseMessage::FsReadDirectory(response) => serde_json::to_value(response),
|
||||
ExecServerResponseMessage::FsRemove(response) => serde_json::to_value(response),
|
||||
ExecServerResponseMessage::FsCopy(response) => serde_json::to_value(response),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -261,6 +345,8 @@ mod tests {
|
||||
use crate::protocol::ExecExitedNotification;
|
||||
use crate::protocol::ExecParams;
|
||||
use crate::protocol::ExecResponse;
|
||||
use crate::protocol::ExecSandboxConfig;
|
||||
use crate::protocol::ExecSandboxMode;
|
||||
use crate::protocol::INITIALIZE_METHOD;
|
||||
use crate::protocol::INITIALIZED_METHOD;
|
||||
use crate::protocol::InitializeParams;
|
||||
@@ -407,6 +493,51 @@ mod tests {
|
||||
env: std::collections::HashMap::new(),
|
||||
tty: true,
|
||||
arg0: None,
|
||||
sandbox: None,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn routes_exec_requests_with_optional_sandbox_config() {
|
||||
let cwd = std::env::current_dir().expect("cwd");
|
||||
let routed = route_jsonrpc_message(JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(4),
|
||||
method: EXEC_METHOD.to_string(),
|
||||
params: Some(json!({
|
||||
"processId": "proc-1",
|
||||
"argv": ["bash", "-lc", "true"],
|
||||
"cwd": cwd,
|
||||
"env": {},
|
||||
"tty": true,
|
||||
"arg0": null,
|
||||
"sandbox": {
|
||||
"mode": "none",
|
||||
},
|
||||
})),
|
||||
trace: None,
|
||||
}))
|
||||
.expect("exec request with sandbox should route");
|
||||
|
||||
let RoutedExecServerMessage::Inbound(ExecServerInboundMessage::Request(
|
||||
ExecServerRequest::Exec { request_id, params },
|
||||
)) = routed
|
||||
else {
|
||||
panic!("expected typed exec request");
|
||||
};
|
||||
assert_eq!(request_id, RequestId::Integer(4));
|
||||
assert_eq!(
|
||||
params,
|
||||
ExecParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
argv: vec!["bash".to_string(), "-lc".to_string(), "true".to_string()],
|
||||
cwd: std::env::current_dir().expect("cwd"),
|
||||
env: std::collections::HashMap::new(),
|
||||
tty: true,
|
||||
arg0: None,
|
||||
sandbox: Some(ExecSandboxConfig {
|
||||
mode: ExecSandboxMode::None,
|
||||
}),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
@@ -75,6 +75,125 @@ async fn exec_server_accepts_initialize_over_stdio() -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn exec_server_accepts_explicit_none_sandbox_over_stdio() -> anyhow::Result<()> {
|
||||
let binary = cargo_bin("codex-exec-server")?;
|
||||
let mut child = Command::new(binary);
|
||||
child.stdin(Stdio::piped());
|
||||
child.stdout(Stdio::piped());
|
||||
child.stderr(Stdio::inherit());
|
||||
let mut child = child.spawn()?;
|
||||
|
||||
let mut stdin = child.stdin.take().expect("stdin");
|
||||
let stdout = child.stdout.take().expect("stdout");
|
||||
let mut stdout = BufReader::new(stdout).lines();
|
||||
|
||||
send_initialize_over_stdio(&mut stdin, &mut stdout).await?;
|
||||
|
||||
let exec = JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(2),
|
||||
method: "process/start".to_string(),
|
||||
params: Some(serde_json::json!({
|
||||
"processId": "proc-1",
|
||||
"argv": ["printf", "sandbox-none"],
|
||||
"cwd": std::env::current_dir()?,
|
||||
"env": {},
|
||||
"tty": false,
|
||||
"arg0": null,
|
||||
"sandbox": {
|
||||
"mode": "none"
|
||||
}
|
||||
})),
|
||||
trace: None,
|
||||
});
|
||||
stdin
|
||||
.write_all(format!("{}\n", serde_json::to_string(&exec)?).as_bytes())
|
||||
.await?;
|
||||
|
||||
let response_line = timeout(Duration::from_secs(5), stdout.next_line()).await??;
|
||||
let response_line = response_line.expect("exec response line");
|
||||
let response: JSONRPCMessage = serde_json::from_str(&response_line)?;
|
||||
let JSONRPCMessage::Response(JSONRPCResponse { id, result }) = response else {
|
||||
panic!("expected process/start response");
|
||||
};
|
||||
assert_eq!(id, RequestId::Integer(2));
|
||||
assert_eq!(result, serde_json::json!({ "processId": "proc-1" }));
|
||||
|
||||
let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
|
||||
let mut saw_output = false;
|
||||
while !saw_output {
|
||||
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
|
||||
let line = timeout(remaining, stdout.next_line()).await??;
|
||||
let line = line.context("missing process notification")?;
|
||||
let message: JSONRPCMessage = serde_json::from_str(&line)?;
|
||||
if let JSONRPCMessage::Notification(JSONRPCNotification { method, params }) = message
|
||||
&& method == "process/output"
|
||||
{
|
||||
let params = params.context("missing process/output params")?;
|
||||
assert_eq!(params["processId"], "proc-1");
|
||||
assert_eq!(params["stream"], "stdout");
|
||||
assert_eq!(params["chunk"], "c2FuZGJveC1ub25l");
|
||||
saw_output = true;
|
||||
}
|
||||
}
|
||||
|
||||
child.start_kill()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn exec_server_rejects_host_default_sandbox_over_stdio() -> anyhow::Result<()> {
|
||||
let binary = cargo_bin("codex-exec-server")?;
|
||||
let mut child = Command::new(binary);
|
||||
child.stdin(Stdio::piped());
|
||||
child.stdout(Stdio::piped());
|
||||
child.stderr(Stdio::inherit());
|
||||
let mut child = child.spawn()?;
|
||||
|
||||
let mut stdin = child.stdin.take().expect("stdin");
|
||||
let stdout = child.stdout.take().expect("stdout");
|
||||
let mut stdout = BufReader::new(stdout).lines();
|
||||
|
||||
send_initialize_over_stdio(&mut stdin, &mut stdout).await?;
|
||||
|
||||
let exec = JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(2),
|
||||
method: "process/start".to_string(),
|
||||
params: Some(serde_json::json!({
|
||||
"processId": "proc-1",
|
||||
"argv": ["bash", "-lc", "true"],
|
||||
"cwd": std::env::current_dir()?,
|
||||
"env": {},
|
||||
"tty": false,
|
||||
"arg0": null,
|
||||
"sandbox": {
|
||||
"mode": "hostDefault"
|
||||
}
|
||||
})),
|
||||
trace: None,
|
||||
});
|
||||
stdin
|
||||
.write_all(format!("{}\n", serde_json::to_string(&exec)?).as_bytes())
|
||||
.await?;
|
||||
|
||||
let response_line = timeout(Duration::from_secs(5), stdout.next_line()).await??;
|
||||
let response_line = response_line.expect("exec error line");
|
||||
let response: JSONRPCMessage = serde_json::from_str(&response_line)?;
|
||||
let JSONRPCMessage::Error(codex_app_server_protocol::JSONRPCError { id, error }) = response
|
||||
else {
|
||||
panic!("expected process/start error");
|
||||
};
|
||||
assert_eq!(id, RequestId::Integer(2));
|
||||
assert_eq!(error.code, -32600);
|
||||
assert_eq!(
|
||||
error.message,
|
||||
"sandbox mode `hostDefault` is not supported by exec-server yet"
|
||||
);
|
||||
|
||||
child.start_kill()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn exec_server_client_streams_output_and_accepts_writes() -> anyhow::Result<()> {
|
||||
let mut env = std::collections::HashMap::new();
|
||||
@@ -109,6 +228,7 @@ async fn exec_server_client_streams_output_and_accepts_writes() -> anyhow::Resul
|
||||
env,
|
||||
tty: true,
|
||||
arg0: None,
|
||||
sandbox: None,
|
||||
})
|
||||
.await?;
|
||||
let process_id = response.process_id;
|
||||
@@ -174,6 +294,7 @@ async fn exec_server_client_connects_over_websocket() -> anyhow::Result<()> {
|
||||
env,
|
||||
tty: true,
|
||||
arg0: None,
|
||||
sandbox: None,
|
||||
})
|
||||
.await?;
|
||||
let process_id = response.process_id;
|
||||
@@ -248,6 +369,7 @@ async fn websocket_disconnect_terminates_processes_for_that_connection() -> anyh
|
||||
env,
|
||||
tty: false,
|
||||
arg0: None,
|
||||
sandbox: None,
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
@@ -276,6 +398,48 @@ where
|
||||
Ok(websocket_url.to_string())
|
||||
}
|
||||
|
||||
async fn send_initialize_over_stdio<W, R>(
|
||||
stdin: &mut W,
|
||||
stdout: &mut tokio::io::Lines<BufReader<R>>,
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
R: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
let initialize = JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(1),
|
||||
method: "initialize".to_string(),
|
||||
params: Some(serde_json::to_value(InitializeParams {
|
||||
client_name: "exec-server-test".to_string(),
|
||||
})?),
|
||||
trace: None,
|
||||
});
|
||||
stdin
|
||||
.write_all(format!("{}\n", serde_json::to_string(&initialize)?).as_bytes())
|
||||
.await?;
|
||||
|
||||
let response_line = timeout(Duration::from_secs(5), stdout.next_line()).await??;
|
||||
let response_line = response_line
|
||||
.ok_or_else(|| anyhow::anyhow!("missing initialize response line from stdio server"))?;
|
||||
let response: JSONRPCMessage = serde_json::from_str(&response_line)?;
|
||||
let JSONRPCMessage::Response(JSONRPCResponse { id, result }) = response else {
|
||||
panic!("expected initialize response");
|
||||
};
|
||||
assert_eq!(id, RequestId::Integer(1));
|
||||
let initialize_response: InitializeResponse = serde_json::from_value(result)?;
|
||||
assert_eq!(initialize_response.protocol_version, "exec-server.v0");
|
||||
|
||||
let initialized = JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: Some(serde_json::json!({})),
|
||||
});
|
||||
stdin
|
||||
.write_all(format!("{}\n", serde_json::to_string(&initialized)?).as_bytes())
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recv_until_contains(
|
||||
events: &mut broadcast::Receiver<ExecServerEvent>,
|
||||
process_id: &str,
|
||||
|
||||
@@ -6,6 +6,7 @@ mod ephemeral;
|
||||
mod mcp_required_exit;
|
||||
mod originator;
|
||||
mod output_schema;
|
||||
mod remote_exec_server;
|
||||
mod resume;
|
||||
mod sandbox;
|
||||
mod server_error_exit;
|
||||
|
||||
90
codex-rs/exec/tests/suite/remote_exec_server.rs
Normal file
90
codex-rs/exec/tests/suite/remote_exec_server.rs
Normal file
@@ -0,0 +1,90 @@
|
||||
#![cfg(not(target_os = "windows"))]
|
||||
#![allow(clippy::expect_used, clippy::unwrap_used)]
|
||||
|
||||
use std::net::TcpListener;
|
||||
use std::process::Stdio;
|
||||
use std::time::Duration;
|
||||
|
||||
use core_test_support::responses;
|
||||
use core_test_support::test_codex_exec::test_codex_exec;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tokio::process::Command;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn exec_cli_can_route_exec_command_through_remote_exec_server() -> anyhow::Result<()> {
|
||||
let test = test_codex_exec();
|
||||
|
||||
let websocket_listener = TcpListener::bind("127.0.0.1:0")?;
|
||||
let websocket_port = websocket_listener.local_addr()?.port();
|
||||
drop(websocket_listener);
|
||||
let websocket_url = format!("ws://127.0.0.1:{websocket_port}");
|
||||
|
||||
let mut exec_server = Command::new(codex_utils_cargo_bin::cargo_bin("codex-exec-server")?)
|
||||
.arg("--listen")
|
||||
.arg(&websocket_url)
|
||||
.stdin(Stdio::null())
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::inherit())
|
||||
.spawn()?;
|
||||
tokio::time::sleep(Duration::from_millis(250)).await;
|
||||
|
||||
let generated_path = test.cwd_path().join("remote_exec_generated.txt");
|
||||
let server = responses::start_mock_server().await;
|
||||
let response_mock = responses::mount_sse_sequence(
|
||||
&server,
|
||||
vec![
|
||||
responses::sse(vec![
|
||||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_function_call(
|
||||
"call-exec",
|
||||
"exec_command",
|
||||
&serde_json::to_string(&json!({
|
||||
"cmd": "/bin/sh -lc 'printf from-remote > remote_exec_generated.txt'",
|
||||
"yield_time_ms": 500,
|
||||
}))?,
|
||||
),
|
||||
responses::ev_completed("resp-1"),
|
||||
]),
|
||||
responses::sse(vec![
|
||||
responses::ev_response_created("resp-2"),
|
||||
responses::ev_assistant_message("msg-1", "done"),
|
||||
responses::ev_completed("resp-2"),
|
||||
]),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
test.cmd_with_server(&server)
|
||||
.arg("--skip-git-repo-check")
|
||||
.arg("-s")
|
||||
.arg("danger-full-access")
|
||||
.arg("-c")
|
||||
.arg("experimental_use_unified_exec_tool=true")
|
||||
.arg("-c")
|
||||
.arg(format!(
|
||||
"experimental_unified_exec_exec_server_websocket_url={}",
|
||||
serde_json::to_string(&websocket_url)?
|
||||
))
|
||||
.arg("-c")
|
||||
.arg(format!(
|
||||
"experimental_unified_exec_exec_server_workspace_root={}",
|
||||
serde_json::to_string(test.cwd_path().to_string_lossy().as_ref())?
|
||||
))
|
||||
.arg("run remote exec-server command")
|
||||
.assert()
|
||||
.success();
|
||||
|
||||
let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
|
||||
while tokio::time::Instant::now() < deadline && !generated_path.exists() {
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
}
|
||||
assert_eq!(std::fs::read_to_string(&generated_path)?, "from-remote");
|
||||
|
||||
let requests = response_mock.requests();
|
||||
assert_eq!(requests.len(), 2);
|
||||
|
||||
exec_server.start_kill()?;
|
||||
let _ = exec_server.wait().await;
|
||||
Ok(())
|
||||
}
|
||||
17
codex-rs/mcp-types/Cargo.toml
Normal file
17
codex-rs/mcp-types/Cargo.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[package]
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
name = "codex-mcp-types"
|
||||
version.workspace = true
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "codex_mcp_types"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
codex-protocol = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
53
codex-rs/mcp-types/src/lib.rs
Normal file
53
codex-rs/mcp-types/src/lib.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
//! Small, stable MCP data helpers split out of `codex-core`.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codex_protocol::mcp::Tool;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
|
||||
const MCP_TOOL_NAME_PREFIX: &str = "mcp";
|
||||
const MCP_TOOL_NAME_DELIMITER: &str = "__";
|
||||
|
||||
/// State needed by MCP servers to align their own sandboxing decisions with Codex.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SandboxState {
|
||||
pub sandbox_policy: SandboxPolicy,
|
||||
pub codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
pub sandbox_cwd: PathBuf,
|
||||
#[serde(default)]
|
||||
pub use_legacy_landlock: bool,
|
||||
}
|
||||
|
||||
pub fn split_qualified_tool_name(qualified_name: &str) -> Option<(String, String)> {
|
||||
let mut parts = qualified_name.split(MCP_TOOL_NAME_DELIMITER);
|
||||
let prefix = parts.next()?;
|
||||
if prefix != MCP_TOOL_NAME_PREFIX {
|
||||
return None;
|
||||
}
|
||||
let server_name = parts.next()?;
|
||||
let tool_name: String = parts.collect::<Vec<_>>().join(MCP_TOOL_NAME_DELIMITER);
|
||||
if tool_name.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some((server_name.to_string(), tool_name))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn group_tools_by_server(
|
||||
tools: &HashMap<String, Tool>,
|
||||
) -> HashMap<String, HashMap<String, Tool>> {
|
||||
let mut grouped = HashMap::new();
|
||||
for (qualified_name, tool) in tools {
|
||||
if let Some((server_name, tool_name)) = split_qualified_tool_name(qualified_name) {
|
||||
grouped
|
||||
.entry(server_name)
|
||||
.or_insert_with(HashMap::new)
|
||||
.insert(tool_name, tool.clone());
|
||||
}
|
||||
}
|
||||
grouped
|
||||
}
|
||||
17
codex-rs/tool-spec/Cargo.toml
Normal file
17
codex-rs/tool-spec/Cargo.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[package]
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
name = "codex-tool-spec"
|
||||
version.workspace = true
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "codex_tool_spec"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
173
codex-rs/tool-spec/src/lib.rs
Normal file
173
codex-rs/tool-spec/src/lib.rs
Normal file
@@ -0,0 +1,173 @@
|
||||
//! Shared tool-schema types split out of `codex-core` so schema-only churn can
|
||||
//! rebuild independently from the full tool registry/orchestration crate.
|
||||
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value as JsonValue;
|
||||
use serde_json::json;
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
/// Generic JSON-Schema subset needed for our tool definitions.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum JsonSchema {
|
||||
Boolean {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
},
|
||||
String {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
},
|
||||
/// MCP schema allows "number" | "integer" for Number.
|
||||
#[serde(alias = "integer")]
|
||||
Number {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
},
|
||||
Array {
|
||||
items: Box<JsonSchema>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
},
|
||||
Object {
|
||||
properties: BTreeMap<String, JsonSchema>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
required: Option<Vec<String>>,
|
||||
#[serde(
|
||||
rename = "additionalProperties",
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
additional_properties: Option<AdditionalProperties>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Whether additional properties are allowed, and if so, any required schema.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum AdditionalProperties {
|
||||
Boolean(bool),
|
||||
Schema(Box<JsonSchema>),
|
||||
}
|
||||
|
||||
impl From<bool> for AdditionalProperties {
|
||||
fn from(b: bool) -> Self {
|
||||
Self::Boolean(b)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<JsonSchema> for AdditionalProperties {
|
||||
fn from(s: JsonSchema) -> Self {
|
||||
Self::Schema(Box::new(s))
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a tool `input_schema` or return an error for invalid schema.
|
||||
pub fn parse_tool_input_schema(input_schema: &JsonValue) -> Result<JsonSchema, serde_json::Error> {
|
||||
let mut input_schema = input_schema.clone();
|
||||
sanitize_json_schema(&mut input_schema);
|
||||
serde_json::from_value::<JsonSchema>(input_schema)
|
||||
}
|
||||
|
||||
/// Sanitize a JSON Schema (as `serde_json::Value`) so it can fit our limited
|
||||
/// `JsonSchema` enum. This function:
|
||||
/// - Ensures every schema object has a "type". If missing, infers it from
|
||||
/// common keywords (`properties` => object, `items` => array, `enum`/`const`/`format` => string)
|
||||
/// and otherwise defaults to "string".
|
||||
/// - Fills required child fields (e.g. array items, object properties) with
|
||||
/// permissive defaults when absent.
|
||||
fn sanitize_json_schema(value: &mut JsonValue) {
|
||||
match value {
|
||||
JsonValue::Bool(_) => {
|
||||
// JSON Schema boolean form: true/false. Coerce to an accept-all string.
|
||||
*value = json!({ "type": "string" });
|
||||
}
|
||||
JsonValue::Array(arr) => {
|
||||
for v in arr.iter_mut() {
|
||||
sanitize_json_schema(v);
|
||||
}
|
||||
}
|
||||
JsonValue::Object(map) => {
|
||||
if let Some(props) = map.get_mut("properties")
|
||||
&& let Some(props_map) = props.as_object_mut()
|
||||
{
|
||||
for (_k, v) in props_map.iter_mut() {
|
||||
sanitize_json_schema(v);
|
||||
}
|
||||
}
|
||||
if let Some(items) = map.get_mut("items") {
|
||||
sanitize_json_schema(items);
|
||||
}
|
||||
for combiner in ["oneOf", "anyOf", "allOf", "prefixItems"] {
|
||||
if let Some(v) = map.get_mut(combiner) {
|
||||
sanitize_json_schema(v);
|
||||
}
|
||||
}
|
||||
|
||||
let mut ty = map.get("type").and_then(|v| v.as_str()).map(str::to_string);
|
||||
|
||||
if ty.is_none()
|
||||
&& let Some(JsonValue::Array(types)) = map.get("type")
|
||||
{
|
||||
for t in types {
|
||||
if let Some(tt) = t.as_str()
|
||||
&& matches!(
|
||||
tt,
|
||||
"object" | "array" | "string" | "number" | "integer" | "boolean"
|
||||
)
|
||||
{
|
||||
ty = Some(tt.to_string());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ty.is_none() {
|
||||
if map.contains_key("properties")
|
||||
|| map.contains_key("required")
|
||||
|| map.contains_key("additionalProperties")
|
||||
{
|
||||
ty = Some("object".to_string());
|
||||
} else if map.contains_key("items") || map.contains_key("prefixItems") {
|
||||
ty = Some("array".to_string());
|
||||
} else if map.contains_key("enum")
|
||||
|| map.contains_key("const")
|
||||
|| map.contains_key("format")
|
||||
{
|
||||
ty = Some("string".to_string());
|
||||
} else if map.contains_key("minimum")
|
||||
|| map.contains_key("maximum")
|
||||
|| map.contains_key("exclusiveMinimum")
|
||||
|| map.contains_key("exclusiveMaximum")
|
||||
|| map.contains_key("multipleOf")
|
||||
{
|
||||
ty = Some("number".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
let ty = ty.unwrap_or_else(|| "string".to_string());
|
||||
map.insert("type".to_string(), JsonValue::String(ty.clone()));
|
||||
|
||||
if ty == "object" {
|
||||
if !map.contains_key("properties") {
|
||||
map.insert(
|
||||
"properties".to_string(),
|
||||
JsonValue::Object(serde_json::Map::new()),
|
||||
);
|
||||
}
|
||||
if let Some(ap) = map.get_mut("additionalProperties") {
|
||||
let is_bool = matches!(ap, JsonValue::Bool(_));
|
||||
if !is_bool {
|
||||
sanitize_json_schema(ap);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ty == "array" && !map.contains_key("items") {
|
||||
map.insert("items".to_string(), json!({ "type": "string" }));
|
||||
}
|
||||
}
|
||||
JsonValue::Null | JsonValue::Number(_) | JsonValue::String(_) => {}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user