mirror of
https://github.com/openai/codex.git
synced 2026-02-01 22:47:52 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c986aeb0c1 | ||
|
|
a43ae86b6c | ||
|
|
496cb801e1 |
8
codex-rs/Cargo.lock
generated
8
codex-rs/Cargo.lock
generated
@@ -4772,9 +4772,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rmcp"
|
||||
version = "0.8.0"
|
||||
version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "583d060e99feb3a3683fb48a1e4bf5f8d4a50951f429726f330ee5ff548837f8"
|
||||
checksum = "6f35acda8f89fca5fd8c96cae3c6d5b4c38ea0072df4c8030915f3b5ff469c1c"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"bytes",
|
||||
@@ -4806,9 +4806,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rmcp-macros"
|
||||
version = "0.8.0"
|
||||
version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "421d8b0ba302f479214889486f9550e63feca3af310f1190efcf6e2016802693"
|
||||
checksum = "c9f1d5220aaa23b79c3d02e18f7a554403b3ccea544bbb6c69d6bcb3e854a274"
|
||||
dependencies = [
|
||||
"darling 0.21.3",
|
||||
"proc-macro2",
|
||||
|
||||
@@ -4,6 +4,7 @@ use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use anyhow::bail;
|
||||
use clap::ArgGroup;
|
||||
use codex_common::CliConfigOverrides;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::config::ConfigOverrides;
|
||||
@@ -77,13 +78,61 @@ pub struct AddArgs {
|
||||
/// Name for the MCP server configuration.
|
||||
pub name: String,
|
||||
|
||||
/// Environment variables to set when launching the server.
|
||||
#[arg(long, value_parser = parse_env_pair, value_name = "KEY=VALUE")]
|
||||
pub env: Vec<(String, String)>,
|
||||
#[command(flatten)]
|
||||
pub transport_args: AddMcpTransportArgs,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
#[command(
|
||||
group(
|
||||
ArgGroup::new("transport")
|
||||
.args(["command", "url"])
|
||||
.required(true)
|
||||
.multiple(false)
|
||||
)
|
||||
)]
|
||||
pub struct AddMcpTransportArgs {
|
||||
#[command(flatten)]
|
||||
pub stdio: Option<AddMcpStdioArgs>,
|
||||
|
||||
#[command(flatten)]
|
||||
pub streamable_http: Option<AddMcpStreamableHttpArgs>,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct AddMcpStdioArgs {
|
||||
/// Command to launch the MCP server.
|
||||
#[arg(trailing_var_arg = true, num_args = 1..)]
|
||||
/// Use --url for a streamable HTTP server.
|
||||
#[arg(
|
||||
trailing_var_arg = true,
|
||||
num_args = 0..,
|
||||
)]
|
||||
pub command: Vec<String>,
|
||||
|
||||
/// Environment variables to set when launching the server.
|
||||
/// Only valid with stdio servers.
|
||||
#[arg(
|
||||
long,
|
||||
value_parser = parse_env_pair,
|
||||
value_name = "KEY=VALUE",
|
||||
)]
|
||||
pub env: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
pub struct AddMcpStreamableHttpArgs {
|
||||
/// URL for a streamable HTTP MCP server.
|
||||
#[arg(long)]
|
||||
pub url: String,
|
||||
|
||||
/// Optional environment variable to read for a bearer token.
|
||||
/// Only valid with streamable HTTP servers.
|
||||
#[arg(
|
||||
long = "bearer-token-env-var",
|
||||
value_name = "ENV_VAR",
|
||||
requires = "url"
|
||||
)]
|
||||
pub bearer_token_env_var: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Parser)]
|
||||
@@ -140,37 +189,51 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re
|
||||
// Validate any provided overrides even though they are not currently applied.
|
||||
config_overrides.parse_overrides().map_err(|e| anyhow!(e))?;
|
||||
|
||||
let AddArgs { name, env, command } = add_args;
|
||||
let AddArgs {
|
||||
name,
|
||||
transport_args,
|
||||
} = add_args;
|
||||
|
||||
validate_server_name(&name)?;
|
||||
|
||||
let mut command_parts = command.into_iter();
|
||||
let command_bin = command_parts
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("command is required"))?;
|
||||
let command_args: Vec<String> = command_parts.collect();
|
||||
|
||||
let env_map = if env.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let mut map = HashMap::new();
|
||||
for (key, value) in env {
|
||||
map.insert(key, value);
|
||||
}
|
||||
Some(map)
|
||||
};
|
||||
|
||||
let codex_home = find_codex_home().context("failed to resolve CODEX_HOME")?;
|
||||
let mut servers = load_global_mcp_servers(&codex_home)
|
||||
.await
|
||||
.with_context(|| format!("failed to load MCP servers from {}", codex_home.display()))?;
|
||||
|
||||
let new_entry = McpServerConfig {
|
||||
transport: McpServerTransportConfig::Stdio {
|
||||
command: command_bin,
|
||||
args: command_args,
|
||||
env: env_map,
|
||||
let transport = match transport_args {
|
||||
AddMcpTransportArgs {
|
||||
stdio: Some(stdio), ..
|
||||
} => {
|
||||
let mut command_parts = stdio.command.into_iter();
|
||||
let command_bin = command_parts
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("command is required"))?;
|
||||
let command_args: Vec<String> = command_parts.collect();
|
||||
|
||||
let env_map = if stdio.env.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(stdio.env.into_iter().collect::<HashMap<_, _>>())
|
||||
};
|
||||
McpServerTransportConfig::Stdio {
|
||||
command: command_bin,
|
||||
args: command_args,
|
||||
env: env_map,
|
||||
}
|
||||
}
|
||||
AddMcpTransportArgs {
|
||||
streamable_http: Some(streamable_http),
|
||||
..
|
||||
} => McpServerTransportConfig::StreamableHttp {
|
||||
url: streamable_http.url,
|
||||
bearer_token_env_var: streamable_http.bearer_token_env_var,
|
||||
},
|
||||
AddMcpTransportArgs { .. } => bail!("exactly one of --command or --url must be provided"),
|
||||
};
|
||||
|
||||
let new_entry = McpServerConfig {
|
||||
transport,
|
||||
startup_timeout_sec: None,
|
||||
tool_timeout_sec: None,
|
||||
};
|
||||
@@ -236,7 +299,7 @@ async fn run_login(config_overrides: &CliConfigOverrides, login_args: LoginArgs)
|
||||
_ => bail!("OAuth login is only supported for streamable HTTP servers."),
|
||||
};
|
||||
|
||||
perform_oauth_login(&name, &url).await?;
|
||||
perform_oauth_login(&name, &url, config.mcp_oauth_credentials_store_mode).await?;
|
||||
println!("Successfully logged in to MCP server '{name}'.");
|
||||
Ok(())
|
||||
}
|
||||
@@ -259,7 +322,7 @@ async fn run_logout(config_overrides: &CliConfigOverrides, logout_args: LogoutAr
|
||||
_ => bail!("OAuth logout is only supported for streamable_http transports."),
|
||||
};
|
||||
|
||||
match delete_oauth_tokens(&name, &url) {
|
||||
match delete_oauth_tokens(&name, &url, config.mcp_oauth_credentials_store_mode) {
|
||||
Ok(true) => println!("Removed OAuth credentials for '{name}'."),
|
||||
Ok(false) => println!("No OAuth credentials stored for '{name}'."),
|
||||
Err(err) => return Err(anyhow!("failed to delete OAuth credentials: {err}")),
|
||||
@@ -288,11 +351,14 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) ->
|
||||
"args": args,
|
||||
"env": env,
|
||||
}),
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
} => {
|
||||
serde_json::json!({
|
||||
"type": "streamable_http",
|
||||
"url": url,
|
||||
"bearer_token": bearer_token,
|
||||
"bearer_token_env_var": bearer_token_env_var,
|
||||
})
|
||||
}
|
||||
};
|
||||
@@ -345,13 +411,15 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) ->
|
||||
};
|
||||
stdio_rows.push([name.clone(), command.clone(), args_display, env_display]);
|
||||
}
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
let has_bearer = if bearer_token.is_some() {
|
||||
"True"
|
||||
} else {
|
||||
"False"
|
||||
};
|
||||
http_rows.push([name.clone(), url.clone(), has_bearer.into()]);
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
} => {
|
||||
http_rows.push([
|
||||
name.clone(),
|
||||
url.clone(),
|
||||
bearer_token_env_var.clone().unwrap_or("-".to_string()),
|
||||
]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -396,7 +464,7 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) ->
|
||||
}
|
||||
|
||||
if !http_rows.is_empty() {
|
||||
let mut widths = ["Name".len(), "Url".len(), "Has Bearer Token".len()];
|
||||
let mut widths = ["Name".len(), "Url".len(), "Bearer Token Env Var".len()];
|
||||
for row in &http_rows {
|
||||
for (i, cell) in row.iter().enumerate() {
|
||||
widths[i] = widths[i].max(cell.len());
|
||||
@@ -407,7 +475,7 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) ->
|
||||
"{:<name_w$} {:<url_w$} {:<token_w$}",
|
||||
"Name",
|
||||
"Url",
|
||||
"Has Bearer Token",
|
||||
"Bearer Token Env Var",
|
||||
name_w = widths[0],
|
||||
url_w = widths[1],
|
||||
token_w = widths[2],
|
||||
@@ -447,10 +515,13 @@ async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Re
|
||||
"args": args,
|
||||
"env": env,
|
||||
}),
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => serde_json::json!({
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
} => serde_json::json!({
|
||||
"type": "streamable_http",
|
||||
"url": url,
|
||||
"bearer_token": bearer_token,
|
||||
"bearer_token_env_var": bearer_token_env_var,
|
||||
}),
|
||||
};
|
||||
let output = serde_json::to_string_pretty(&serde_json::json!({
|
||||
@@ -493,11 +564,14 @@ async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Re
|
||||
};
|
||||
println!(" env: {env_display}");
|
||||
}
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
} => {
|
||||
println!(" transport: streamable_http");
|
||||
println!(" url: {url}");
|
||||
let bearer = bearer_token.as_deref().unwrap_or("-");
|
||||
println!(" bearer_token: {bearer}");
|
||||
let env_var = bearer_token_env_var.as_deref().unwrap_or("-");
|
||||
println!(" bearer_token_env_var: {env_var}");
|
||||
}
|
||||
}
|
||||
if let Some(timeout) = server.startup_timeout_sec {
|
||||
|
||||
@@ -93,3 +93,116 @@ async fn add_with_env_preserves_key_order_and_values() -> Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn add_streamable_http_without_manual_token() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let mut add_cmd = codex_command(codex_home.path())?;
|
||||
add_cmd
|
||||
.args(["mcp", "add", "github", "--url", "https://example.com/mcp"])
|
||||
.assert()
|
||||
.success();
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path()).await?;
|
||||
let github = servers.get("github").expect("github server should exist");
|
||||
match &github.transport {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
} => {
|
||||
assert_eq!(url, "https://example.com/mcp");
|
||||
assert!(bearer_token_env_var.is_none());
|
||||
}
|
||||
other => panic!("unexpected transport: {other:?}"),
|
||||
}
|
||||
|
||||
assert!(!codex_home.path().join(".credentials.json").exists());
|
||||
assert!(!codex_home.path().join(".env").exists());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn add_streamable_http_with_custom_env_var() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let mut add_cmd = codex_command(codex_home.path())?;
|
||||
add_cmd
|
||||
.args([
|
||||
"mcp",
|
||||
"add",
|
||||
"issues",
|
||||
"--url",
|
||||
"https://example.com/issues",
|
||||
"--bearer-token-env-var",
|
||||
"GITHUB_TOKEN",
|
||||
])
|
||||
.assert()
|
||||
.success();
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path()).await?;
|
||||
let issues = servers.get("issues").expect("issues server should exist");
|
||||
match &issues.transport {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
} => {
|
||||
assert_eq!(url, "https://example.com/issues");
|
||||
assert_eq!(bearer_token_env_var.as_deref(), Some("GITHUB_TOKEN"));
|
||||
}
|
||||
other => panic!("unexpected transport: {other:?}"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn add_streamable_http_rejects_removed_flag() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let mut add_cmd = codex_command(codex_home.path())?;
|
||||
add_cmd
|
||||
.args([
|
||||
"mcp",
|
||||
"add",
|
||||
"github",
|
||||
"--url",
|
||||
"https://example.com/mcp",
|
||||
"--with-bearer-token",
|
||||
])
|
||||
.assert()
|
||||
.failure()
|
||||
.stderr(contains("--with-bearer-token"));
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path()).await?;
|
||||
assert!(servers.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn add_cant_add_command_and_url() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let mut add_cmd = codex_command(codex_home.path())?;
|
||||
add_cmd
|
||||
.args([
|
||||
"mcp",
|
||||
"add",
|
||||
"github",
|
||||
"--url",
|
||||
"https://example.com/mcp",
|
||||
"--command",
|
||||
"--",
|
||||
"echo",
|
||||
"hello",
|
||||
])
|
||||
.assert()
|
||||
.failure()
|
||||
.stderr(contains("unexpected argument '--command' found"));
|
||||
|
||||
let servers = load_global_mcp_servers(codex_home.path()).await?;
|
||||
assert!(servers.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ use tokio::sync::mpsc;
|
||||
use tokio::time::timeout;
|
||||
use tokio_util::io::ReaderStream;
|
||||
use tracing::debug;
|
||||
use tracing::error;
|
||||
use tracing::trace;
|
||||
use tracing::warn;
|
||||
|
||||
@@ -655,7 +656,7 @@ async fn process_sse<S>(
|
||||
{
|
||||
Ok(Some(Ok(sse))) => sse,
|
||||
Ok(Some(Err(e))) => {
|
||||
debug!("SSE Error: {e:#}");
|
||||
error!("SSE Error: {e:#}");
|
||||
let event = CodexErr::Stream(e.to_string(), None);
|
||||
let _ = tx_event.send(Err(event)).await;
|
||||
return;
|
||||
@@ -716,7 +717,7 @@ async fn process_sse<S>(
|
||||
let event: SseEvent = match serde_json::from_str(&sse.data) {
|
||||
Ok(event) => event,
|
||||
Err(e) => {
|
||||
debug!("Failed to parse SSE event: {e}, data: {}", &sse.data);
|
||||
error!("Failed to parse SSE event: {e}, data: {}", &sse.data);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
@@ -743,7 +744,7 @@ async fn process_sse<S>(
|
||||
"response.output_item.done" => {
|
||||
let Some(item_val) = event.item else { continue };
|
||||
let Ok(item) = serde_json::from_value::<ResponseItem>(item_val) else {
|
||||
debug!("failed to parse ResponseItem from output_item.done");
|
||||
error!("failed to parse ResponseItem from output_item.done");
|
||||
continue;
|
||||
};
|
||||
|
||||
@@ -802,9 +803,7 @@ async fn process_sse<S>(
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let error = format!("failed to parse ErrorResponse: {e}");
|
||||
debug!(error);
|
||||
response_error = Some(CodexErr::Stream(error, None))
|
||||
error!("failed to parse ErrorResponse: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -818,9 +817,7 @@ async fn process_sse<S>(
|
||||
response_completed = Some(r);
|
||||
}
|
||||
Err(e) => {
|
||||
let error = format!("failed to parse ResponseCompleted: {e}");
|
||||
debug!(error);
|
||||
response_error = Some(CodexErr::Stream(error, None));
|
||||
error!("failed to parse ResponseCompleted: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -364,6 +364,7 @@ impl Session {
|
||||
let mcp_fut = McpConnectionManager::new(
|
||||
config.mcp_servers.clone(),
|
||||
config.use_experimental_use_rmcp_client,
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
);
|
||||
let default_shell_fut = shell::default_user_shell();
|
||||
let history_meta_fut = crate::message_history::history_metadata(&config);
|
||||
|
||||
@@ -33,12 +33,15 @@ use codex_protocol::config_types::ReasoningEffort;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use codex_protocol::config_types::SandboxMode;
|
||||
use codex_protocol::config_types::Verbosity;
|
||||
use codex_rmcp_client::OAuthCredentialsStoreMode;
|
||||
use dirs::home_dir;
|
||||
use serde::Deserialize;
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::HashMap;
|
||||
use std::io::ErrorKind;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use tempfile::NamedTempFile;
|
||||
use toml::Value as TomlValue;
|
||||
use toml_edit::Array as TomlArray;
|
||||
@@ -142,6 +145,15 @@ pub struct Config {
|
||||
/// Definition for MCP servers that Codex can reach out to for tool calls.
|
||||
pub mcp_servers: HashMap<String, McpServerConfig>,
|
||||
|
||||
/// Preferred store for MCP OAuth credentials.
|
||||
/// keyring: Use an OS-specific keyring service.
|
||||
/// Credentials stored in the keyring will only be readable by Codex unless the user explicitly grants access via OS-level keyring access.
|
||||
/// https://github.com/openai/codex/blob/main/codex-rs/rmcp-client/src/oauth.rs#L2
|
||||
/// file: CODEX_HOME/.credentials.json
|
||||
/// This file will be readable to Codex and other applications running as the same user.
|
||||
/// auto (default): keyring if available, otherwise file.
|
||||
pub mcp_oauth_credentials_store_mode: OAuthCredentialsStoreMode,
|
||||
|
||||
/// Combined provider map (defaults merged with user-defined overrides).
|
||||
pub model_providers: HashMap<String, ModelProviderInfo>,
|
||||
|
||||
@@ -301,12 +313,35 @@ pub async fn load_global_mcp_servers(
|
||||
return Ok(BTreeMap::new());
|
||||
};
|
||||
|
||||
ensure_no_inline_bearer_tokens(servers_value)?;
|
||||
|
||||
servers_value
|
||||
.clone()
|
||||
.try_into()
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
|
||||
}
|
||||
|
||||
/// We briefly allowed plain text bearer_token fields in MCP server configs.
|
||||
/// We want to warn people who recently added these fields but can remove this after a few months.
|
||||
fn ensure_no_inline_bearer_tokens(value: &TomlValue) -> std::io::Result<()> {
|
||||
let Some(servers_table) = value.as_table() else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
for (server_name, server_value) in servers_table {
|
||||
if let Some(server_table) = server_value.as_table()
|
||||
&& server_table.contains_key("bearer_token")
|
||||
{
|
||||
let message = format!(
|
||||
"mcp_servers.{server_name} uses unsupported `bearer_token`; set `bearer_token_env_var`."
|
||||
);
|
||||
return Err(std::io::Error::new(ErrorKind::InvalidData, message));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write_global_mcp_servers(
|
||||
codex_home: &Path,
|
||||
servers: &BTreeMap<String, McpServerConfig>,
|
||||
@@ -355,10 +390,13 @@ pub fn write_global_mcp_servers(
|
||||
entry["env"] = TomlItem::Table(env_table);
|
||||
}
|
||||
}
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
} => {
|
||||
entry["url"] = toml_edit::value(url.clone());
|
||||
if let Some(token) = bearer_token {
|
||||
entry["bearer_token"] = toml_edit::value(token.clone());
|
||||
if let Some(env_var) = bearer_token_env_var {
|
||||
entry["bearer_token_env_var"] = toml_edit::value(env_var.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -694,6 +732,14 @@ pub struct ConfigToml {
|
||||
#[serde(default)]
|
||||
pub mcp_servers: HashMap<String, McpServerConfig>,
|
||||
|
||||
/// Preferred backend for storing MCP OAuth credentials.
|
||||
/// keyring: Use an OS-specific keyring service.
|
||||
/// https://github.com/openai/codex/blob/main/codex-rs/rmcp-client/src/oauth.rs#L2
|
||||
/// file: Use a file in the Codex home directory.
|
||||
/// auto (default): Use the OS-specific keyring service if available, otherwise use a file.
|
||||
#[serde(default)]
|
||||
pub mcp_oauth_credentials_store: Option<OAuthCredentialsStoreMode>,
|
||||
|
||||
/// User-defined provider entries that extend/override the built-in list.
|
||||
#[serde(default)]
|
||||
pub model_providers: HashMap<String, ModelProviderInfo>,
|
||||
@@ -1074,6 +1120,9 @@ impl Config {
|
||||
user_instructions,
|
||||
base_instructions,
|
||||
mcp_servers: cfg.mcp_servers,
|
||||
// The config.toml omits "_mode" because it's a config file. However, "_mode"
|
||||
// is important in code to differentiate the mode from the store implementation.
|
||||
mcp_oauth_credentials_store_mode: cfg.mcp_oauth_credentials_store.unwrap_or_default(),
|
||||
model_providers,
|
||||
project_doc_max_bytes: cfg.project_doc_max_bytes.unwrap_or(PROJECT_DOC_MAX_BYTES),
|
||||
project_doc_fallback_filenames: cfg
|
||||
@@ -1364,6 +1413,85 @@ exclude_slash_tmp = true
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_defaults_to_auto_oauth_store_mode() -> std::io::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
let cfg = ConfigToml::default();
|
||||
|
||||
let config = Config::load_from_base_config_with_overrides(
|
||||
cfg,
|
||||
ConfigOverrides::default(),
|
||||
codex_home.path().to_path_buf(),
|
||||
)?;
|
||||
|
||||
assert_eq!(
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
OAuthCredentialsStoreMode::Auto,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_honors_explicit_file_oauth_store_mode() -> std::io::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
let cfg = ConfigToml {
|
||||
mcp_oauth_credentials_store: Some(OAuthCredentialsStoreMode::File),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let config = Config::load_from_base_config_with_overrides(
|
||||
cfg,
|
||||
ConfigOverrides::default(),
|
||||
codex_home.path().to_path_buf(),
|
||||
)?;
|
||||
|
||||
assert_eq!(
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
OAuthCredentialsStoreMode::File,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn managed_config_overrides_oauth_store_mode() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
let managed_path = codex_home.path().join("managed_config.toml");
|
||||
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
|
||||
|
||||
std::fs::write(&config_path, "mcp_oauth_credentials_store = \"file\"\n")?;
|
||||
std::fs::write(&managed_path, "mcp_oauth_credentials_store = \"keyring\"\n")?;
|
||||
|
||||
let overrides = crate::config_loader::LoaderOverrides {
|
||||
managed_config_path: Some(managed_path.clone()),
|
||||
#[cfg(target_os = "macos")]
|
||||
managed_preferences_base64: None,
|
||||
};
|
||||
|
||||
let root_value = load_resolved_config(codex_home.path(), Vec::new(), overrides).await?;
|
||||
let cfg: ConfigToml = root_value.try_into().map_err(|e| {
|
||||
tracing::error!("Failed to deserialize overridden config: {e}");
|
||||
std::io::Error::new(std::io::ErrorKind::InvalidData, e)
|
||||
})?;
|
||||
assert_eq!(
|
||||
cfg.mcp_oauth_credentials_store,
|
||||
Some(OAuthCredentialsStoreMode::Keyring),
|
||||
);
|
||||
|
||||
let final_config = Config::load_from_base_config_with_overrides(
|
||||
cfg,
|
||||
ConfigOverrides::default(),
|
||||
codex_home.path().to_path_buf(),
|
||||
)?;
|
||||
assert_eq!(
|
||||
final_config.mcp_oauth_credentials_store_mode,
|
||||
OAuthCredentialsStoreMode::Keyring,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn load_global_mcp_servers_returns_empty_if_missing() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
@@ -1471,6 +1599,31 @@ startup_timeout_ms = 2500
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn load_global_mcp_servers_rejects_inline_bearer_token() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
|
||||
|
||||
std::fs::write(
|
||||
&config_path,
|
||||
r#"
|
||||
[mcp_servers.docs]
|
||||
url = "https://example.com/mcp"
|
||||
bearer_token = "secret"
|
||||
"#,
|
||||
)?;
|
||||
|
||||
let err = load_global_mcp_servers(codex_home.path())
|
||||
.await
|
||||
.expect_err("bearer_token entries should be rejected");
|
||||
|
||||
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
|
||||
assert!(err.to_string().contains("bearer_token"));
|
||||
assert!(err.to_string().contains("bearer_token_env_var"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_global_mcp_servers_serializes_env_sorted() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
@@ -1534,7 +1687,7 @@ ZIG_VAR = "3"
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url: "https://example.com/mcp".to_string(),
|
||||
bearer_token: Some("secret-token".to_string()),
|
||||
bearer_token_env_var: Some("MCP_TOKEN".to_string()),
|
||||
},
|
||||
startup_timeout_sec: Some(Duration::from_secs(2)),
|
||||
tool_timeout_sec: None,
|
||||
@@ -1549,7 +1702,7 @@ ZIG_VAR = "3"
|
||||
serialized,
|
||||
r#"[mcp_servers.docs]
|
||||
url = "https://example.com/mcp"
|
||||
bearer_token = "secret-token"
|
||||
bearer_token_env_var = "MCP_TOKEN"
|
||||
startup_timeout_sec = 2.0
|
||||
"#
|
||||
);
|
||||
@@ -1557,9 +1710,12 @@ startup_timeout_sec = 2.0
|
||||
let loaded = load_global_mcp_servers(codex_home.path()).await?;
|
||||
let docs = loaded.get("docs").expect("docs entry");
|
||||
match &docs.transport {
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
} => {
|
||||
assert_eq!(url, "https://example.com/mcp");
|
||||
assert_eq!(bearer_token.as_deref(), Some("secret-token"));
|
||||
assert_eq!(bearer_token_env_var.as_deref(), Some("MCP_TOKEN"));
|
||||
}
|
||||
other => panic!("unexpected transport {other:?}"),
|
||||
}
|
||||
@@ -1570,7 +1726,7 @@ startup_timeout_sec = 2.0
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url: "https://example.com/mcp".to_string(),
|
||||
bearer_token: None,
|
||||
bearer_token_env_var: None,
|
||||
},
|
||||
startup_timeout_sec: None,
|
||||
tool_timeout_sec: None,
|
||||
@@ -1589,9 +1745,12 @@ url = "https://example.com/mcp"
|
||||
let loaded = load_global_mcp_servers(codex_home.path()).await?;
|
||||
let docs = loaded.get("docs").expect("docs entry");
|
||||
match &docs.transport {
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
} => {
|
||||
assert_eq!(url, "https://example.com/mcp");
|
||||
assert!(bearer_token.is_none());
|
||||
assert!(bearer_token_env_var.is_none());
|
||||
}
|
||||
other => panic!("unexpected transport {other:?}"),
|
||||
}
|
||||
@@ -1896,6 +2055,7 @@ model_verbosity = "high"
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
mcp_servers: HashMap::new(),
|
||||
mcp_oauth_credentials_store_mode: Default::default(),
|
||||
model_providers: fixture.model_provider_map.clone(),
|
||||
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
|
||||
project_doc_fallback_filenames: Vec::new(),
|
||||
@@ -1958,6 +2118,7 @@ model_verbosity = "high"
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
mcp_servers: HashMap::new(),
|
||||
mcp_oauth_credentials_store_mode: Default::default(),
|
||||
model_providers: fixture.model_provider_map.clone(),
|
||||
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
|
||||
project_doc_fallback_filenames: Vec::new(),
|
||||
@@ -2035,6 +2196,7 @@ model_verbosity = "high"
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
mcp_servers: HashMap::new(),
|
||||
mcp_oauth_credentials_store_mode: Default::default(),
|
||||
model_providers: fixture.model_provider_map.clone(),
|
||||
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
|
||||
project_doc_fallback_filenames: Vec::new(),
|
||||
@@ -2098,6 +2260,7 @@ model_verbosity = "high"
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
mcp_servers: HashMap::new(),
|
||||
mcp_oauth_credentials_store_mode: Default::default(),
|
||||
model_providers: fixture.model_provider_map.clone(),
|
||||
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
|
||||
project_doc_fallback_filenames: Vec::new(),
|
||||
|
||||
@@ -48,6 +48,7 @@ impl<'de> Deserialize<'de> for McpServerConfig {
|
||||
|
||||
url: Option<String>,
|
||||
bearer_token: Option<String>,
|
||||
bearer_token_env_var: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
startup_timeout_sec: Option<f64>,
|
||||
@@ -86,11 +87,15 @@ impl<'de> Deserialize<'de> for McpServerConfig {
|
||||
args,
|
||||
env,
|
||||
url,
|
||||
bearer_token,
|
||||
bearer_token_env_var,
|
||||
..
|
||||
} => {
|
||||
throw_if_set("stdio", "url", url.as_ref())?;
|
||||
throw_if_set("stdio", "bearer_token", bearer_token.as_ref())?;
|
||||
throw_if_set(
|
||||
"stdio",
|
||||
"bearer_token_env_var",
|
||||
bearer_token_env_var.as_ref(),
|
||||
)?;
|
||||
McpServerTransportConfig::Stdio {
|
||||
command,
|
||||
args: args.unwrap_or_default(),
|
||||
@@ -100,6 +105,7 @@ impl<'de> Deserialize<'de> for McpServerConfig {
|
||||
RawMcpServerConfig {
|
||||
url: Some(url),
|
||||
bearer_token,
|
||||
bearer_token_env_var,
|
||||
command,
|
||||
args,
|
||||
env,
|
||||
@@ -108,7 +114,11 @@ impl<'de> Deserialize<'de> for McpServerConfig {
|
||||
throw_if_set("streamable_http", "command", command.as_ref())?;
|
||||
throw_if_set("streamable_http", "args", args.as_ref())?;
|
||||
throw_if_set("streamable_http", "env", env.as_ref())?;
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token }
|
||||
throw_if_set("streamable_http", "bearer_token", bearer_token.as_ref())?;
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
}
|
||||
}
|
||||
_ => return Err(SerdeError::custom("invalid transport")),
|
||||
};
|
||||
@@ -135,11 +145,11 @@ pub enum McpServerTransportConfig {
|
||||
/// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http
|
||||
StreamableHttp {
|
||||
url: String,
|
||||
/// A plain text bearer token to use for authentication.
|
||||
/// This bearer token will be included in the HTTP request header as an `Authorization: Bearer <token>` header.
|
||||
/// This should be used with caution because it lives on disk in clear text.
|
||||
/// Name of the environment variable to read for an HTTP bearer token.
|
||||
/// When set, requests will include the token via `Authorization: Bearer <token>`.
|
||||
/// The actual secret value must be provided via the environment.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
bearer_token: Option<String>,
|
||||
bearer_token_env_var: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -506,17 +516,17 @@ mod tests {
|
||||
cfg.transport,
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url: "https://example.com/mcp".to_string(),
|
||||
bearer_token: None
|
||||
bearer_token_env_var: None
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_streamable_http_server_config_with_bearer_token() {
|
||||
fn deserialize_streamable_http_server_config_with_env_var() {
|
||||
let cfg: McpServerConfig = toml::from_str(
|
||||
r#"
|
||||
url = "https://example.com/mcp"
|
||||
bearer_token = "secret"
|
||||
bearer_token_env_var = "GITHUB_TOKEN"
|
||||
"#,
|
||||
)
|
||||
.expect("should deserialize http config");
|
||||
@@ -525,7 +535,7 @@ mod tests {
|
||||
cfg.transport,
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url: "https://example.com/mcp".to_string(),
|
||||
bearer_token: Some("secret".to_string())
|
||||
bearer_token_env_var: Some("GITHUB_TOKEN".to_string())
|
||||
}
|
||||
);
|
||||
}
|
||||
@@ -553,13 +563,18 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_rejects_bearer_token_for_stdio_transport() {
|
||||
toml::from_str::<McpServerConfig>(
|
||||
fn deserialize_rejects_inline_bearer_token_field() {
|
||||
let err = toml::from_str::<McpServerConfig>(
|
||||
r#"
|
||||
command = "echo"
|
||||
url = "https://example.com"
|
||||
bearer_token = "secret"
|
||||
"#,
|
||||
)
|
||||
.expect_err("should reject bearer token for stdio transport");
|
||||
.expect_err("should reject bearer_token field");
|
||||
|
||||
assert!(
|
||||
err.to_string().contains("bearer_token is not supported"),
|
||||
"unexpected error: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::env;
|
||||
use std::ffi::OsString;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
@@ -16,6 +17,7 @@ use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use codex_mcp_client::McpClient;
|
||||
use codex_rmcp_client::OAuthCredentialsStoreMode;
|
||||
use codex_rmcp_client::RmcpClient;
|
||||
use mcp_types::ClientCapabilities;
|
||||
use mcp_types::Implementation;
|
||||
@@ -125,9 +127,11 @@ impl McpClientAdapter {
|
||||
bearer_token: Option<String>,
|
||||
params: mcp_types::InitializeRequestParams,
|
||||
startup_timeout: Duration,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<Self> {
|
||||
let client = Arc::new(
|
||||
RmcpClient::new_streamable_http_client(&server_name, &url, bearer_token).await?,
|
||||
RmcpClient::new_streamable_http_client(&server_name, &url, bearer_token, store_mode)
|
||||
.await?,
|
||||
);
|
||||
client.initialize(params, Some(startup_timeout)).await?;
|
||||
Ok(McpClientAdapter::Rmcp(client))
|
||||
@@ -182,6 +186,7 @@ impl McpConnectionManager {
|
||||
pub async fn new(
|
||||
mcp_servers: HashMap<String, McpServerConfig>,
|
||||
use_rmcp_client: bool,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<(Self, ClientStartErrors)> {
|
||||
// Early exit if no servers are configured.
|
||||
if mcp_servers.is_empty() {
|
||||
@@ -205,6 +210,14 @@ impl McpConnectionManager {
|
||||
let startup_timeout = cfg.startup_timeout_sec.unwrap_or(DEFAULT_STARTUP_TIMEOUT);
|
||||
let tool_timeout = cfg.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT);
|
||||
|
||||
let resolved_bearer_token = match &cfg.transport {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
bearer_token_env_var,
|
||||
..
|
||||
} => resolve_bearer_token(&server_name, bearer_token_env_var.as_deref()),
|
||||
_ => Ok(None),
|
||||
};
|
||||
|
||||
join_set.spawn(async move {
|
||||
let McpServerConfig { transport, .. } = cfg;
|
||||
let params = mcp_types::InitializeRequestParams {
|
||||
@@ -242,13 +255,14 @@ impl McpConnectionManager {
|
||||
)
|
||||
.await
|
||||
}
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
McpServerTransportConfig::StreamableHttp { url, .. } => {
|
||||
McpClientAdapter::new_streamable_http_client(
|
||||
server_name.clone(),
|
||||
url,
|
||||
bearer_token,
|
||||
resolved_bearer_token.unwrap_or_default(),
|
||||
params,
|
||||
startup_timeout,
|
||||
store_mode,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -336,6 +350,33 @@ impl McpConnectionManager {
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_bearer_token(
|
||||
server_name: &str,
|
||||
bearer_token_env_var: Option<&str>,
|
||||
) -> Result<Option<String>> {
|
||||
let Some(env_var) = bearer_token_env_var else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
match env::var(env_var) {
|
||||
Ok(value) => {
|
||||
if value.is_empty() {
|
||||
Err(anyhow!(
|
||||
"Environment variable {env_var} for MCP server '{server_name}' is empty"
|
||||
))
|
||||
} else {
|
||||
Ok(Some(value))
|
||||
}
|
||||
}
|
||||
Err(env::VarError::NotPresent) => Err(anyhow!(
|
||||
"Environment variable {env_var} for MCP server '{server_name}' is not set"
|
||||
)),
|
||||
Err(env::VarError::NotUnicode(_)) => Err(anyhow!(
|
||||
"Environment variable {env_var} for MCP server '{server_name}' contains invalid Unicode"
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Query every server for its available tools and return a single map that
|
||||
/// contains **all** tools. Each key is the fully-qualified name for the tool.
|
||||
async fn list_all_tools(clients: &HashMap<String, ManagedClient>) -> Result<Vec<ToolInfo>> {
|
||||
|
||||
@@ -232,7 +232,7 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> {
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url: server_url,
|
||||
bearer_token: None,
|
||||
bearer_token_env_var: None,
|
||||
},
|
||||
startup_timeout_sec: Some(Duration::from_secs(10)),
|
||||
tool_timeout_sec: None,
|
||||
@@ -412,7 +412,7 @@ async fn streamable_http_with_oauth_round_trip() -> anyhow::Result<()> {
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url: server_url,
|
||||
bearer_token: None,
|
||||
bearer_token_env_var: None,
|
||||
},
|
||||
startup_timeout_sec: Some(Duration::from_secs(10)),
|
||||
tool_timeout_sec: None,
|
||||
|
||||
@@ -5,6 +5,7 @@ mod perform_oauth_login;
|
||||
mod rmcp_client;
|
||||
mod utils;
|
||||
|
||||
pub use oauth::OAuthCredentialsStoreMode;
|
||||
pub use oauth::StoredOAuthTokens;
|
||||
pub use oauth::WrappedOAuthTokenResponse;
|
||||
pub use oauth::delete_oauth_tokens;
|
||||
|
||||
@@ -58,6 +58,21 @@ pub struct StoredOAuthTokens {
|
||||
pub token_response: WrappedOAuthTokenResponse,
|
||||
}
|
||||
|
||||
/// Determine where Codex should store and read MCP credentials.
|
||||
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum OAuthCredentialsStoreMode {
|
||||
/// `Keyring` when available; otherwise, `File`.
|
||||
/// Credentials stored in the keyring will only be readable by Codex unless the user explicitly grants access via OS-level keyring access.
|
||||
#[default]
|
||||
Auto,
|
||||
/// CODEX_HOME/.credentials.json
|
||||
/// This file will be readable to Codex and other applications running as the same user.
|
||||
File,
|
||||
/// Keyring when available, otherwise fail.
|
||||
Keyring,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CredentialStoreError(anyhow::Error);
|
||||
|
||||
@@ -83,15 +98,15 @@ impl fmt::Display for CredentialStoreError {
|
||||
|
||||
impl std::error::Error for CredentialStoreError {}
|
||||
|
||||
trait CredentialStore {
|
||||
trait KeyringStore {
|
||||
fn load(&self, service: &str, account: &str) -> Result<Option<String>, CredentialStoreError>;
|
||||
fn save(&self, service: &str, account: &str, value: &str) -> Result<(), CredentialStoreError>;
|
||||
fn delete(&self, service: &str, account: &str) -> Result<bool, CredentialStoreError>;
|
||||
}
|
||||
|
||||
struct KeyringCredentialStore;
|
||||
struct DefaultKeyringStore;
|
||||
|
||||
impl CredentialStore for KeyringCredentialStore {
|
||||
impl KeyringStore for DefaultKeyringStore {
|
||||
fn load(&self, service: &str, account: &str) -> Result<Option<String>, CredentialStoreError> {
|
||||
let entry = Entry::new(service, account).map_err(CredentialStoreError::new)?;
|
||||
match entry.get_password() {
|
||||
@@ -129,47 +144,85 @@ impl PartialEq for WrappedOAuthTokenResponse {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn load_oauth_tokens(server_name: &str, url: &str) -> Result<Option<StoredOAuthTokens>> {
|
||||
let store = KeyringCredentialStore;
|
||||
load_oauth_tokens_with_store(&store, server_name, url)
|
||||
pub(crate) fn load_oauth_tokens(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<Option<StoredOAuthTokens>> {
|
||||
let keyring_store = DefaultKeyringStore;
|
||||
match store_mode {
|
||||
OAuthCredentialsStoreMode::Auto => {
|
||||
load_oauth_tokens_from_keyring_with_fallback_to_file(&keyring_store, server_name, url)
|
||||
}
|
||||
OAuthCredentialsStoreMode::File => load_oauth_tokens_from_file(server_name, url),
|
||||
OAuthCredentialsStoreMode::Keyring => {
|
||||
load_oauth_tokens_from_keyring(&keyring_store, server_name, url)
|
||||
.with_context(|| "failed to read OAuth tokens from keyring".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_oauth_tokens_with_store<C: CredentialStore>(
|
||||
store: &C,
|
||||
fn load_oauth_tokens_from_keyring_with_fallback_to_file<K: KeyringStore>(
|
||||
keyring_store: &K,
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
) -> Result<Option<StoredOAuthTokens>> {
|
||||
match load_oauth_tokens_from_keyring(keyring_store, server_name, url) {
|
||||
Ok(Some(tokens)) => Ok(Some(tokens)),
|
||||
Ok(None) => load_oauth_tokens_from_file(server_name, url),
|
||||
Err(error) => {
|
||||
warn!("failed to read OAuth tokens from keyring: {error}");
|
||||
load_oauth_tokens_from_file(server_name, url)
|
||||
.with_context(|| format!("failed to read OAuth tokens from keyring: {error}"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_oauth_tokens_from_keyring<K: KeyringStore>(
|
||||
keyring_store: &K,
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
) -> Result<Option<StoredOAuthTokens>> {
|
||||
let key = compute_store_key(server_name, url)?;
|
||||
match store.load(KEYRING_SERVICE, &key) {
|
||||
match keyring_store.load(KEYRING_SERVICE, &key) {
|
||||
Ok(Some(serialized)) => {
|
||||
let tokens: StoredOAuthTokens = serde_json::from_str(&serialized)
|
||||
.context("failed to deserialize OAuth tokens from keyring")?;
|
||||
Ok(Some(tokens))
|
||||
}
|
||||
Ok(None) => load_oauth_tokens_from_file(server_name, url),
|
||||
Err(error) => {
|
||||
let message = error.message();
|
||||
warn!("failed to read OAuth tokens from keyring: {message}");
|
||||
load_oauth_tokens_from_file(server_name, url)
|
||||
.with_context(|| format!("failed to read OAuth tokens from keyring: {message}"))
|
||||
Ok(None) => Ok(None),
|
||||
Err(error) => Err(error.into_error()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn save_oauth_tokens(
|
||||
server_name: &str,
|
||||
tokens: &StoredOAuthTokens,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<()> {
|
||||
let keyring_store = DefaultKeyringStore;
|
||||
match store_mode {
|
||||
OAuthCredentialsStoreMode::Auto => save_oauth_tokens_with_keyring_with_fallback_to_file(
|
||||
&keyring_store,
|
||||
server_name,
|
||||
tokens,
|
||||
),
|
||||
OAuthCredentialsStoreMode::File => save_oauth_tokens_to_file(tokens),
|
||||
OAuthCredentialsStoreMode::Keyring => {
|
||||
save_oauth_tokens_with_keyring(&keyring_store, server_name, tokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn save_oauth_tokens(server_name: &str, tokens: &StoredOAuthTokens) -> Result<()> {
|
||||
let store = KeyringCredentialStore;
|
||||
save_oauth_tokens_with_store(&store, server_name, tokens)
|
||||
}
|
||||
|
||||
fn save_oauth_tokens_with_store<C: CredentialStore>(
|
||||
store: &C,
|
||||
fn save_oauth_tokens_with_keyring<K: KeyringStore>(
|
||||
keyring_store: &K,
|
||||
server_name: &str,
|
||||
tokens: &StoredOAuthTokens,
|
||||
) -> Result<()> {
|
||||
let serialized = serde_json::to_string(tokens).context("failed to serialize OAuth tokens")?;
|
||||
|
||||
let key = compute_store_key(server_name, &tokens.url)?;
|
||||
match store.save(KEYRING_SERVICE, &key, &serialized) {
|
||||
match keyring_store.save(KEYRING_SERVICE, &key, &serialized) {
|
||||
Ok(()) => {
|
||||
if let Err(error) = delete_oauth_tokens_from_file(&key) {
|
||||
warn!("failed to remove OAuth tokens from fallback storage: {error:?}");
|
||||
@@ -177,31 +230,61 @@ fn save_oauth_tokens_with_store<C: CredentialStore>(
|
||||
Ok(())
|
||||
}
|
||||
Err(error) => {
|
||||
let message = error.message();
|
||||
warn!("failed to write OAuth tokens to keyring: {message}");
|
||||
let message = format!(
|
||||
"failed to write OAuth tokens to keyring: {}",
|
||||
error.message()
|
||||
);
|
||||
warn!("{message}");
|
||||
Err(error.into_error().context(message))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn save_oauth_tokens_with_keyring_with_fallback_to_file<K: KeyringStore>(
|
||||
keyring_store: &K,
|
||||
server_name: &str,
|
||||
tokens: &StoredOAuthTokens,
|
||||
) -> Result<()> {
|
||||
match save_oauth_tokens_with_keyring(keyring_store, server_name, tokens) {
|
||||
Ok(()) => Ok(()),
|
||||
Err(error) => {
|
||||
let message = error.to_string();
|
||||
warn!("falling back to file storage for OAuth tokens: {message}");
|
||||
save_oauth_tokens_to_file(tokens)
|
||||
.with_context(|| format!("failed to write OAuth tokens to keyring: {message}"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn delete_oauth_tokens(server_name: &str, url: &str) -> Result<bool> {
|
||||
let store = KeyringCredentialStore;
|
||||
delete_oauth_tokens_with_store(&store, server_name, url)
|
||||
pub fn delete_oauth_tokens(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<bool> {
|
||||
let keyring_store = DefaultKeyringStore;
|
||||
delete_oauth_tokens_from_keyring_and_file(&keyring_store, store_mode, server_name, url)
|
||||
}
|
||||
|
||||
fn delete_oauth_tokens_with_store<C: CredentialStore>(
|
||||
store: &C,
|
||||
fn delete_oauth_tokens_from_keyring_and_file<K: KeyringStore>(
|
||||
keyring_store: &K,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
) -> Result<bool> {
|
||||
let key = compute_store_key(server_name, url)?;
|
||||
let keyring_removed = match store.delete(KEYRING_SERVICE, &key) {
|
||||
let keyring_result = keyring_store.delete(KEYRING_SERVICE, &key);
|
||||
let keyring_removed = match keyring_result {
|
||||
Ok(removed) => removed,
|
||||
Err(error) => {
|
||||
let message = error.message();
|
||||
warn!("failed to delete OAuth tokens from keyring: {message}");
|
||||
return Err(error.into_error()).context("failed to delete OAuth tokens from keyring");
|
||||
match store_mode {
|
||||
OAuthCredentialsStoreMode::Auto | OAuthCredentialsStoreMode::Keyring => {
|
||||
return Err(error.into_error())
|
||||
.context("failed to delete OAuth tokens from keyring");
|
||||
}
|
||||
OAuthCredentialsStoreMode::File => false,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -218,6 +301,7 @@ struct OAuthPersistorInner {
|
||||
server_name: String,
|
||||
url: String,
|
||||
authorization_manager: Arc<Mutex<AuthorizationManager>>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
last_credentials: Mutex<Option<StoredOAuthTokens>>,
|
||||
}
|
||||
|
||||
@@ -225,14 +309,16 @@ impl OAuthPersistor {
|
||||
pub(crate) fn new(
|
||||
server_name: String,
|
||||
url: String,
|
||||
manager: Arc<Mutex<AuthorizationManager>>,
|
||||
authorization_manager: Arc<Mutex<AuthorizationManager>>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
initial_credentials: Option<StoredOAuthTokens>,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(OAuthPersistorInner {
|
||||
server_name,
|
||||
url,
|
||||
authorization_manager: manager,
|
||||
authorization_manager,
|
||||
store_mode,
|
||||
last_credentials: Mutex::new(initial_credentials),
|
||||
}),
|
||||
}
|
||||
@@ -257,15 +343,18 @@ impl OAuthPersistor {
|
||||
};
|
||||
let mut last_credentials = self.inner.last_credentials.lock().await;
|
||||
if last_credentials.as_ref() != Some(&stored) {
|
||||
save_oauth_tokens(&self.inner.server_name, &stored)?;
|
||||
save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?;
|
||||
*last_credentials = Some(stored);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let mut last_serialized = self.inner.last_credentials.lock().await;
|
||||
if last_serialized.take().is_some()
|
||||
&& let Err(error) =
|
||||
delete_oauth_tokens(&self.inner.server_name, &self.inner.url)
|
||||
&& let Err(error) = delete_oauth_tokens(
|
||||
&self.inner.server_name,
|
||||
&self.inner.url,
|
||||
self.inner.store_mode,
|
||||
)
|
||||
{
|
||||
warn!(
|
||||
"failed to remove OAuth tokens for server {}: {error}",
|
||||
@@ -542,7 +631,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
impl CredentialStore for MockCredentialStore {
|
||||
impl KeyringStore for MockCredentialStore {
|
||||
fn load(
|
||||
&self,
|
||||
_service: &str,
|
||||
@@ -643,7 +732,8 @@ mod tests {
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.save(KEYRING_SERVICE, &key, &serialized)?;
|
||||
|
||||
let loaded = super::load_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?;
|
||||
let loaded =
|
||||
super::load_oauth_tokens_from_keyring(&store, &tokens.server_name, &tokens.url)?;
|
||||
assert_eq!(loaded, Some(expected));
|
||||
Ok(())
|
||||
}
|
||||
@@ -657,8 +747,12 @@ mod tests {
|
||||
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
|
||||
let loaded = super::load_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?
|
||||
.expect("tokens should load from fallback");
|
||||
let loaded = super::load_oauth_tokens_from_keyring_with_fallback_to_file(
|
||||
&store,
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
)?
|
||||
.expect("tokens should load from fallback");
|
||||
assert_tokens_match_without_expiry(&loaded, &expected);
|
||||
Ok(())
|
||||
}
|
||||
@@ -674,8 +768,12 @@ mod tests {
|
||||
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
|
||||
let loaded = super::load_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?
|
||||
.expect("tokens should load from fallback");
|
||||
let loaded = super::load_oauth_tokens_from_keyring_with_fallback_to_file(
|
||||
&store,
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
)?
|
||||
.expect("tokens should load from fallback");
|
||||
assert_tokens_match_without_expiry(&loaded, &expected);
|
||||
Ok(())
|
||||
}
|
||||
@@ -689,7 +787,11 @@ mod tests {
|
||||
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
|
||||
super::save_oauth_tokens_with_store(&store, &tokens.server_name, &tokens)?;
|
||||
super::save_oauth_tokens_with_keyring_with_fallback_to_file(
|
||||
&store,
|
||||
&tokens.server_name,
|
||||
&tokens,
|
||||
)?;
|
||||
|
||||
let fallback_path = super::fallback_file_path()?;
|
||||
assert!(!fallback_path.exists(), "fallback file should be removed");
|
||||
@@ -706,7 +808,11 @@ mod tests {
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.set_error(&key, KeyringError::Invalid("error".into(), "save".into()));
|
||||
|
||||
super::save_oauth_tokens_with_store(&store, &tokens.server_name, &tokens)?;
|
||||
super::save_oauth_tokens_with_keyring_with_fallback_to_file(
|
||||
&store,
|
||||
&tokens.server_name,
|
||||
&tokens,
|
||||
)?;
|
||||
|
||||
let fallback_path = super::fallback_file_path()?;
|
||||
assert!(fallback_path.exists(), "fallback file should be created");
|
||||
@@ -734,8 +840,34 @@ mod tests {
|
||||
store.save(KEYRING_SERVICE, &key, &serialized)?;
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
|
||||
let removed =
|
||||
super::delete_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?;
|
||||
let removed = super::delete_oauth_tokens_from_keyring_and_file(
|
||||
&store,
|
||||
OAuthCredentialsStoreMode::Auto,
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
)?;
|
||||
assert!(removed);
|
||||
assert!(!store.contains(&key));
|
||||
assert!(!super::fallback_file_path()?.exists());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn delete_oauth_tokens_file_mode_removes_keyring_only_entry() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockCredentialStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let serialized = serde_json::to_string(&tokens)?;
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.save(KEYRING_SERVICE, &key, &serialized)?;
|
||||
assert!(store.contains(&key));
|
||||
|
||||
let removed = super::delete_oauth_tokens_from_keyring_and_file(
|
||||
&store,
|
||||
OAuthCredentialsStoreMode::Auto,
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
)?;
|
||||
assert!(removed);
|
||||
assert!(!store.contains(&key));
|
||||
assert!(!super::fallback_file_path()?.exists());
|
||||
@@ -751,8 +883,12 @@ mod tests {
|
||||
store.set_error(&key, KeyringError::Invalid("error".into(), "delete".into()));
|
||||
super::save_oauth_tokens_to_file(&tokens).unwrap();
|
||||
|
||||
let result =
|
||||
super::delete_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url);
|
||||
let result = super::delete_oauth_tokens_from_keyring_and_file(
|
||||
&store,
|
||||
OAuthCredentialsStoreMode::Auto,
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
);
|
||||
assert!(result.is_err());
|
||||
assert!(super::fallback_file_path().unwrap().exists());
|
||||
Ok(())
|
||||
|
||||
@@ -12,6 +12,7 @@ use tokio::sync::oneshot;
|
||||
use tokio::time::timeout;
|
||||
use urlencoding::decode;
|
||||
|
||||
use crate::OAuthCredentialsStoreMode;
|
||||
use crate::StoredOAuthTokens;
|
||||
use crate::WrappedOAuthTokenResponse;
|
||||
use crate::save_oauth_tokens;
|
||||
@@ -26,7 +27,11 @@ impl Drop for CallbackServerGuard {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn perform_oauth_login(server_name: &str, server_url: &str) -> Result<()> {
|
||||
pub async fn perform_oauth_login(
|
||||
server_name: &str,
|
||||
server_url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<()> {
|
||||
let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?);
|
||||
let guard = CallbackServerGuard {
|
||||
server: Arc::clone(&server),
|
||||
@@ -81,7 +86,7 @@ pub async fn perform_oauth_login(server_name: &str, server_url: &str) -> Result<
|
||||
client_id,
|
||||
token_response: WrappedOAuthTokenResponse(credentials),
|
||||
};
|
||||
save_oauth_tokens(server_name, &stored)?;
|
||||
save_oauth_tokens(server_name, &stored, store_mode)?;
|
||||
|
||||
drop(guard);
|
||||
Ok(())
|
||||
|
||||
@@ -35,6 +35,7 @@ use tracing::warn;
|
||||
|
||||
use crate::load_oauth_tokens;
|
||||
use crate::logging_client_handler::LoggingClientHandler;
|
||||
use crate::oauth::OAuthCredentialsStoreMode;
|
||||
use crate::oauth::OAuthPersistor;
|
||||
use crate::oauth::StoredOAuthTokens;
|
||||
use crate::utils::convert_call_tool_result;
|
||||
@@ -119,10 +120,11 @@ impl RmcpClient {
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
bearer_token: Option<String>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<Self> {
|
||||
let initial_oauth_tokens = match bearer_token {
|
||||
Some(_) => None,
|
||||
None => match load_oauth_tokens(server_name, url) {
|
||||
None => match load_oauth_tokens(server_name, url, store_mode) {
|
||||
Ok(tokens) => tokens,
|
||||
Err(err) => {
|
||||
warn!("failed to read tokens for server `{server_name}`: {err}");
|
||||
@@ -132,7 +134,8 @@ impl RmcpClient {
|
||||
};
|
||||
let transport = if let Some(initial_tokens) = initial_oauth_tokens.clone() {
|
||||
let (transport, oauth_persistor) =
|
||||
create_oauth_transport_and_runtime(server_name, url, initial_tokens).await?;
|
||||
create_oauth_transport_and_runtime(server_name, url, initial_tokens, store_mode)
|
||||
.await?;
|
||||
PendingTransport::StreamableHttpWithOAuth {
|
||||
transport,
|
||||
oauth_persistor,
|
||||
@@ -286,6 +289,7 @@ async fn create_oauth_transport_and_runtime(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
initial_tokens: StoredOAuthTokens,
|
||||
credentials_store: OAuthCredentialsStoreMode,
|
||||
) -> Result<(
|
||||
StreamableHttpClientTransport<AuthClient<reqwest::Client>>,
|
||||
OAuthPersistor,
|
||||
@@ -320,6 +324,7 @@ async fn create_oauth_transport_and_runtime(
|
||||
server_name.to_string(),
|
||||
url.to_string(),
|
||||
auth_manager,
|
||||
credentials_store,
|
||||
Some(initial_tokens),
|
||||
);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user