mirror of
https://github.com/openai/codex.git
synced 2026-04-24 06:35:50 +00:00
Also need to expand tilda
This commit is contained in:
@@ -14,6 +14,7 @@ use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
use wildmatch::WildMatchPattern;
|
||||
|
||||
use dirs::home_dir;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use serde::Deserializer;
|
||||
@@ -156,11 +157,11 @@ impl<'de> Deserialize<'de> for McpServerConfig {
|
||||
throw_if_set("stdio", "http_headers", raw.http_headers.as_ref())?;
|
||||
throw_if_set("stdio", "env_http_headers", raw.env_http_headers.as_ref())?;
|
||||
McpServerTransportConfig::Stdio {
|
||||
command,
|
||||
command: expand_tilde_string(command),
|
||||
args: raw.args.clone().unwrap_or_default(),
|
||||
env: raw.env.clone(),
|
||||
env_vars: raw.env_vars.clone().unwrap_or_default(),
|
||||
cwd: raw.cwd.take(),
|
||||
cwd: raw.cwd.take().map(expand_tilde_pathbuf),
|
||||
}
|
||||
} else if let Some(url) = raw.url.clone() {
|
||||
throw_if_set("streamable_http", "args", raw.args.as_ref())?;
|
||||
@@ -194,6 +195,50 @@ const fn default_enabled() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn expand_tilde_pathbuf(path: PathBuf) -> PathBuf {
|
||||
let Some(path_str) = path.to_str() else {
|
||||
return path;
|
||||
};
|
||||
|
||||
if cfg!(target_os = "windows") {
|
||||
return path;
|
||||
}
|
||||
|
||||
let Some(home) = home_dir() else {
|
||||
return path;
|
||||
};
|
||||
|
||||
if path_str == "~" {
|
||||
return home;
|
||||
}
|
||||
|
||||
if let Some(rest) = path_str.strip_prefix("~/") {
|
||||
return home.join(rest);
|
||||
}
|
||||
|
||||
path
|
||||
}
|
||||
|
||||
fn expand_tilde_string(value: String) -> String {
|
||||
if cfg!(target_os = "windows") {
|
||||
return value;
|
||||
}
|
||||
|
||||
let Some(home) = home_dir() else {
|
||||
return value;
|
||||
};
|
||||
|
||||
if value == "~" {
|
||||
return home.to_string_lossy().into_owned();
|
||||
}
|
||||
|
||||
if let Some(rest) = value.strip_prefix("~/") {
|
||||
return home.join(rest).to_string_lossy().into_owned();
|
||||
}
|
||||
|
||||
value
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema)]
|
||||
#[serde(untagged, deny_unknown_fields, rename_all = "snake_case")]
|
||||
pub enum McpServerTransportConfig {
|
||||
@@ -876,6 +921,37 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_stdio_command_server_config_expands_tilde_cwd() {
|
||||
let cfg: McpServerConfig = toml::from_str(
|
||||
r#"
|
||||
command = "echo"
|
||||
cwd = "~/tmp"
|
||||
"#,
|
||||
)
|
||||
.expect("should deserialize command config with tilde cwd");
|
||||
|
||||
let expected_cwd = if cfg!(target_os = "windows") {
|
||||
PathBuf::from("~/tmp")
|
||||
} else {
|
||||
let Some(home) = home_dir() else {
|
||||
return;
|
||||
};
|
||||
home.join("tmp")
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
cfg.transport,
|
||||
McpServerTransportConfig::Stdio {
|
||||
command: "echo".to_string(),
|
||||
args: vec![],
|
||||
env: None,
|
||||
env_vars: Vec::new(),
|
||||
cwd: Some(expected_cwd),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_disabled_server_config() {
|
||||
let cfg: McpServerConfig = toml::from_str(
|
||||
|
||||
Reference in New Issue
Block a user