Also need to expand tilda

This commit is contained in:
shijie-openai
2026-01-20 17:15:28 -08:00
parent 4371913278
commit 8a7ee646c5

View File

@@ -14,6 +14,7 @@ use std::path::PathBuf;
use std::time::Duration;
use wildmatch::WildMatchPattern;
use dirs::home_dir;
use schemars::JsonSchema;
use serde::Deserialize;
use serde::Deserializer;
@@ -156,11 +157,11 @@ impl<'de> Deserialize<'de> for McpServerConfig {
throw_if_set("stdio", "http_headers", raw.http_headers.as_ref())?;
throw_if_set("stdio", "env_http_headers", raw.env_http_headers.as_ref())?;
McpServerTransportConfig::Stdio {
command,
command: expand_tilde_string(command),
args: raw.args.clone().unwrap_or_default(),
env: raw.env.clone(),
env_vars: raw.env_vars.clone().unwrap_or_default(),
cwd: raw.cwd.take(),
cwd: raw.cwd.take().map(expand_tilde_pathbuf),
}
} else if let Some(url) = raw.url.clone() {
throw_if_set("streamable_http", "args", raw.args.as_ref())?;
@@ -194,6 +195,50 @@ const fn default_enabled() -> bool {
true
}
fn expand_tilde_pathbuf(path: PathBuf) -> PathBuf {
let Some(path_str) = path.to_str() else {
return path;
};
if cfg!(target_os = "windows") {
return path;
}
let Some(home) = home_dir() else {
return path;
};
if path_str == "~" {
return home;
}
if let Some(rest) = path_str.strip_prefix("~/") {
return home.join(rest);
}
path
}
fn expand_tilde_string(value: String) -> String {
if cfg!(target_os = "windows") {
return value;
}
let Some(home) = home_dir() else {
return value;
};
if value == "~" {
return home.to_string_lossy().into_owned();
}
if let Some(rest) = value.strip_prefix("~/") {
return home.join(rest).to_string_lossy().into_owned();
}
value
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema)]
#[serde(untagged, deny_unknown_fields, rename_all = "snake_case")]
pub enum McpServerTransportConfig {
@@ -876,6 +921,37 @@ mod tests {
);
}
#[test]
fn deserialize_stdio_command_server_config_expands_tilde_cwd() {
let cfg: McpServerConfig = toml::from_str(
r#"
command = "echo"
cwd = "~/tmp"
"#,
)
.expect("should deserialize command config with tilde cwd");
let expected_cwd = if cfg!(target_os = "windows") {
PathBuf::from("~/tmp")
} else {
let Some(home) = home_dir() else {
return;
};
home.join("tmp")
};
assert_eq!(
cfg.transport,
McpServerTransportConfig::Stdio {
command: "echo".to_string(),
args: vec![],
env: None,
env_vars: Vec::new(),
cwd: Some(expected_cwd),
}
);
}
#[test]
fn deserialize_disabled_server_config() {
let cfg: McpServerConfig = toml::from_str(