mirror of
https://github.com/openai/codex.git
synced 2026-05-01 18:06:47 +00:00
Compare commits
4 Commits
windows-sa
...
etraut/add
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6feecbf5d1 | ||
|
|
e7479ee1cc | ||
|
|
78249f7fce | ||
|
|
c47fd545dd |
@@ -12,7 +12,6 @@ use codex_core::CodexAuth;
|
||||
use codex_core::config::types::McpServerConfig;
|
||||
use codex_core::config::types::McpServerTransportConfig;
|
||||
use codex_core::models_manager::manager::RefreshStrategy;
|
||||
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use codex_protocol::openai_models::ConfigShellToolType;
|
||||
use codex_protocol::openai_models::InputModality;
|
||||
@@ -28,6 +27,10 @@ use codex_protocol::protocol::McpToolCallBeginEvent;
|
||||
use codex_protocol::protocol::Op;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use codex_rmcp_client::ElicitationAction;
|
||||
use codex_rmcp_client::ElicitationResponse;
|
||||
use codex_rmcp_client::OAuthCredentialsStoreMode;
|
||||
use codex_rmcp_client::RmcpClient;
|
||||
use codex_utils_cargo_bin::cargo_bin;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::responses::mount_models_once;
|
||||
@@ -36,6 +39,13 @@ use core_test_support::skip_if_no_network;
|
||||
use core_test_support::stdio_server_bin;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use futures::FutureExt;
|
||||
use rmcp::model::ClientCapabilities;
|
||||
use rmcp::model::ElicitationCapability;
|
||||
use rmcp::model::FormElicitationCapability;
|
||||
use rmcp::model::Implementation;
|
||||
use rmcp::model::InitializeRequestParams;
|
||||
use rmcp::model::ProtocolVersion;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use serial_test::serial;
|
||||
@@ -1056,6 +1066,231 @@ async fn streamable_http_with_oauth_round_trip_impl() -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// This test writes to a fallback credentials file in CODEX_HOME.
|
||||
#[serial(codex_home)]
|
||||
#[test]
|
||||
fn streamable_http_with_oauth_refresh_adopts_rotated_credentials() -> anyhow::Result<()> {
|
||||
const TEST_STACK_SIZE_BYTES: usize = 8 * 1024 * 1024;
|
||||
|
||||
let handle = std::thread::Builder::new()
|
||||
.name("streamable_http_with_oauth_refresh_adopts_rotated_credentials".to_string())
|
||||
.stack_size(TEST_STACK_SIZE_BYTES)
|
||||
.spawn(|| -> anyhow::Result<()> {
|
||||
let runtime = tokio::runtime::Builder::new_multi_thread()
|
||||
.worker_threads(1)
|
||||
.enable_all()
|
||||
.build()?;
|
||||
runtime.block_on(streamable_http_with_oauth_refresh_adopts_rotated_credentials_impl())
|
||||
})?;
|
||||
|
||||
match handle.join() {
|
||||
Ok(result) => result,
|
||||
Err(_) => Err(anyhow::anyhow!(
|
||||
"streamable_http_with_oauth_refresh_adopts_rotated_credentials thread panicked"
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
async fn streamable_http_with_oauth_refresh_adopts_rotated_credentials_impl() -> anyhow::Result<()>
|
||||
{
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server_name = "rmcp_http_oauth_refresh_race";
|
||||
let initial_access_token = "initial-access-token";
|
||||
let initial_refresh_token = "initial-refresh-token";
|
||||
let rotated_access_token = "rotated-access-token";
|
||||
let rotated_refresh_token = "rotated-refresh-token";
|
||||
let rmcp_http_server_bin = match cargo_bin("test_streamable_http_server") {
|
||||
Ok(path) => path,
|
||||
Err(err) => {
|
||||
eprintln!("test_streamable_http_server binary not available, skipping test: {err}");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0")?;
|
||||
let port = listener.local_addr()?.port();
|
||||
drop(listener);
|
||||
let bind_addr = format!("127.0.0.1:{port}");
|
||||
let server_url = format!("http://{bind_addr}/mcp");
|
||||
|
||||
let mut http_server_child = Command::new(&rmcp_http_server_bin)
|
||||
.kill_on_drop(true)
|
||||
.env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr)
|
||||
.env("MCP_EXPECT_BEARER", initial_access_token)
|
||||
.env("MCP_EXPECT_REFRESH_TOKEN", initial_refresh_token)
|
||||
.env("MCP_REFRESH_NEXT_ACCESS_TOKEN", rotated_access_token)
|
||||
.env("MCP_REFRESH_NEXT_REFRESH_TOKEN", rotated_refresh_token)
|
||||
.env("MCP_REFRESH_EXPIRES_IN", "3600")
|
||||
.env("MCP_REFRESH_SINGLE_USE", "1")
|
||||
.spawn()?;
|
||||
|
||||
wait_for_streamable_http_server(&mut http_server_child, &bind_addr, Duration::from_secs(5))
|
||||
.await?;
|
||||
|
||||
let temp_home = tempdir()?;
|
||||
let _guard = EnvVarGuard::set("CODEX_HOME", temp_home.path().as_os_str());
|
||||
let initial_expires_at = SystemTime::now()
|
||||
.checked_add(Duration::from_secs(1))
|
||||
.ok_or_else(|| anyhow::anyhow!("failed to compute expiry time"))?
|
||||
.duration_since(UNIX_EPOCH)?
|
||||
.as_millis() as u64;
|
||||
write_fallback_oauth_tokens_with_expiry(
|
||||
temp_home.path(),
|
||||
server_name,
|
||||
&server_url,
|
||||
"test-client-id",
|
||||
initial_access_token,
|
||||
initial_refresh_token,
|
||||
initial_expires_at,
|
||||
)?;
|
||||
|
||||
let client_a = RmcpClient::new_streamable_http_client(
|
||||
server_name,
|
||||
&server_url,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
OAuthCredentialsStoreMode::File,
|
||||
)
|
||||
.await?;
|
||||
let client_b = RmcpClient::new_streamable_http_client(
|
||||
server_name,
|
||||
&server_url,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
OAuthCredentialsStoreMode::File,
|
||||
)
|
||||
.await?;
|
||||
|
||||
client_a
|
||||
.initialize(
|
||||
rmcp_initialize_params(),
|
||||
Some(Duration::from_secs(5)),
|
||||
noop_send_elicitation(),
|
||||
)
|
||||
.await?;
|
||||
client_b
|
||||
.initialize(
|
||||
rmcp_initialize_params(),
|
||||
Some(Duration::from_secs(5)),
|
||||
noop_send_elicitation(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let tools_a = client_a
|
||||
.list_tools(None, Some(Duration::from_secs(5)))
|
||||
.await?;
|
||||
assert_eq!(tools_a.tools.len(), 1);
|
||||
assert_eq!(tools_a.tools[0].name.as_ref(), "echo");
|
||||
assert_stored_oauth_tokens(
|
||||
temp_home.path(),
|
||||
server_name,
|
||||
&server_url,
|
||||
rotated_access_token,
|
||||
rotated_refresh_token,
|
||||
)?;
|
||||
|
||||
let tools_b = client_b
|
||||
.list_tools(None, Some(Duration::from_secs(5)))
|
||||
.await?;
|
||||
assert_eq!(tools_b.tools.len(), 1);
|
||||
assert_eq!(tools_b.tools[0].name.as_ref(), "echo");
|
||||
assert_stored_oauth_tokens(
|
||||
temp_home.path(),
|
||||
server_name,
|
||||
&server_url,
|
||||
rotated_access_token,
|
||||
rotated_refresh_token,
|
||||
)?;
|
||||
|
||||
match http_server_child.try_wait() {
|
||||
Ok(Some(_)) => {}
|
||||
Ok(None) => {
|
||||
let _ = http_server_child.kill().await;
|
||||
}
|
||||
Err(error) => {
|
||||
eprintln!("failed to check streamable http oauth server status: {error}");
|
||||
let _ = http_server_child.kill().await;
|
||||
}
|
||||
}
|
||||
if let Err(error) = http_server_child.wait().await {
|
||||
eprintln!("failed to await streamable http oauth server shutdown: {error}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn rmcp_initialize_params() -> InitializeRequestParams {
|
||||
InitializeRequestParams {
|
||||
meta: None,
|
||||
capabilities: ClientCapabilities {
|
||||
experimental: None,
|
||||
extensions: None,
|
||||
roots: None,
|
||||
sampling: None,
|
||||
elicitation: Some(ElicitationCapability {
|
||||
form: Some(FormElicitationCapability {
|
||||
schema_validation: None,
|
||||
}),
|
||||
url: None,
|
||||
}),
|
||||
tasks: None,
|
||||
},
|
||||
client_info: Implementation {
|
||||
name: "codex-test".into(),
|
||||
version: "0.0.0-test".into(),
|
||||
title: Some("Codex rmcp oauth refresh test".into()),
|
||||
description: None,
|
||||
icons: None,
|
||||
website_url: None,
|
||||
},
|
||||
protocol_version: ProtocolVersion::V_2025_06_18,
|
||||
}
|
||||
}
|
||||
|
||||
fn noop_send_elicitation() -> codex_rmcp_client::SendElicitation {
|
||||
Box::new(|_, _| {
|
||||
async {
|
||||
Ok(ElicitationResponse {
|
||||
action: ElicitationAction::Accept,
|
||||
content: Some(json!({})),
|
||||
})
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
}
|
||||
|
||||
fn assert_stored_oauth_tokens(
|
||||
home: &Path,
|
||||
server_name: &str,
|
||||
server_url: &str,
|
||||
expected_access_token: &str,
|
||||
expected_refresh_token: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let file_path = home.join(".credentials.json");
|
||||
let stored: Value = serde_json::from_slice(&fs::read(&file_path)?)?;
|
||||
let entries = stored
|
||||
.as_object()
|
||||
.ok_or_else(|| anyhow::anyhow!("expected fallback OAuth credential map"))?;
|
||||
let has_expected_tokens = entries.values().any(|entry| {
|
||||
entry.as_object().is_some_and(|entry| {
|
||||
entry.get("server_name").and_then(Value::as_str) == Some(server_name)
|
||||
&& entry.get("server_url").and_then(Value::as_str) == Some(server_url)
|
||||
&& entry.get("access_token").and_then(Value::as_str) == Some(expected_access_token)
|
||||
&& entry.get("refresh_token").and_then(Value::as_str)
|
||||
== Some(expected_refresh_token)
|
||||
})
|
||||
});
|
||||
assert!(
|
||||
has_expected_tokens,
|
||||
"expected stored OAuth credentials for {server_name} at {server_url} to include access_token={expected_access_token} refresh_token={expected_refresh_token}, got {stored}",
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn wait_for_streamable_http_server(
|
||||
server_child: &mut Child,
|
||||
address: &str,
|
||||
@@ -1111,7 +1346,26 @@ fn write_fallback_oauth_tokens(
|
||||
.ok_or_else(|| anyhow::anyhow!("failed to compute expiry time"))?
|
||||
.duration_since(UNIX_EPOCH)?
|
||||
.as_millis() as u64;
|
||||
write_fallback_oauth_tokens_with_expiry(
|
||||
home,
|
||||
server_name,
|
||||
server_url,
|
||||
client_id,
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_at,
|
||||
)
|
||||
}
|
||||
|
||||
fn write_fallback_oauth_tokens_with_expiry(
|
||||
home: &Path,
|
||||
server_name: &str,
|
||||
server_url: &str,
|
||||
client_id: &str,
|
||||
access_token: &str,
|
||||
refresh_token: &str,
|
||||
expires_at: u64,
|
||||
) -> anyhow::Result<()> {
|
||||
let store = serde_json::json!({
|
||||
"stub": {
|
||||
"server_name": server_name,
|
||||
|
||||
@@ -6,6 +6,7 @@ use std::sync::Arc;
|
||||
|
||||
use axum::Router;
|
||||
use axum::body::Body;
|
||||
use axum::extract::Form;
|
||||
use axum::extract::State;
|
||||
use axum::http::Request;
|
||||
use axum::http::StatusCode;
|
||||
@@ -15,6 +16,7 @@ use axum::middleware;
|
||||
use axum::middleware::Next;
|
||||
use axum::response::Response;
|
||||
use axum::routing::get;
|
||||
use axum::routing::post;
|
||||
use rmcp::ErrorData as McpError;
|
||||
use rmcp::handler::server::ServerHandler;
|
||||
use rmcp::model::CallToolRequestParams;
|
||||
@@ -39,6 +41,8 @@ use rmcp::transport::StreamableHttpService;
|
||||
use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::task;
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -48,6 +52,22 @@ struct TestToolServer {
|
||||
resource_templates: Arc<Vec<ResourceTemplate>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AuthState {
|
||||
current_bearer: Arc<RwLock<Option<String>>>,
|
||||
refresh_state: Option<Arc<Mutex<RefreshTokenState>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct RefreshTokenState {
|
||||
current_refresh_token: String,
|
||||
next_access_token: String,
|
||||
next_refresh_token: String,
|
||||
expires_in: u64,
|
||||
single_use: bool,
|
||||
used_once: bool,
|
||||
}
|
||||
|
||||
const MEMO_URI: &str = "memo://codex/example-note";
|
||||
const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server.";
|
||||
|
||||
@@ -263,6 +283,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
};
|
||||
eprintln!("starting rmcp streamable http test server on http://{bind_addr}/mcp");
|
||||
|
||||
let auth_state = AuthState {
|
||||
current_bearer: Arc::new(RwLock::new(
|
||||
std::env::var("MCP_EXPECT_BEARER")
|
||||
.ok()
|
||||
.map(|token| format!("Bearer {token}")),
|
||||
)),
|
||||
refresh_state: refresh_state_from_env(),
|
||||
};
|
||||
|
||||
let router = Router::new()
|
||||
.route(
|
||||
"/.well-known/oauth-authorization-server/mcp",
|
||||
@@ -284,6 +313,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
}
|
||||
}),
|
||||
)
|
||||
.route("/oauth/token", post(oauth_refresh_token))
|
||||
.nest_service(
|
||||
"/mcp",
|
||||
StreamableHttpService::new(
|
||||
@@ -291,28 +321,108 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
Arc::new(LocalSessionManager::default()),
|
||||
StreamableHttpServerConfig::default(),
|
||||
),
|
||||
);
|
||||
|
||||
let router = if let Ok(token) = std::env::var("MCP_EXPECT_BEARER") {
|
||||
let expected = Arc::new(format!("Bearer {token}"));
|
||||
router.layer(middleware::from_fn_with_state(expected, require_bearer))
|
||||
} else {
|
||||
router
|
||||
};
|
||||
)
|
||||
.with_state(auth_state.clone())
|
||||
.layer(middleware::from_fn_with_state(auth_state, require_bearer));
|
||||
|
||||
axum::serve(listener, router).await?;
|
||||
task::yield_now().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn refresh_state_from_env() -> Option<Arc<Mutex<RefreshTokenState>>> {
|
||||
let current_refresh_token = std::env::var("MCP_EXPECT_REFRESH_TOKEN").ok()?;
|
||||
let next_access_token = std::env::var("MCP_REFRESH_NEXT_ACCESS_TOKEN").ok()?;
|
||||
let next_refresh_token = std::env::var("MCP_REFRESH_NEXT_REFRESH_TOKEN").ok()?;
|
||||
let expires_in = std::env::var("MCP_REFRESH_EXPIRES_IN")
|
||||
.ok()
|
||||
.and_then(|value| value.parse::<u64>().ok())
|
||||
.unwrap_or(3600);
|
||||
let single_use = std::env::var("MCP_REFRESH_SINGLE_USE")
|
||||
.ok()
|
||||
.is_some_and(|value| value == "1");
|
||||
|
||||
Some(Arc::new(Mutex::new(RefreshTokenState {
|
||||
current_refresh_token,
|
||||
next_access_token,
|
||||
next_refresh_token,
|
||||
expires_in,
|
||||
single_use,
|
||||
used_once: false,
|
||||
})))
|
||||
}
|
||||
|
||||
async fn oauth_refresh_token(
|
||||
State(state): State<AuthState>,
|
||||
Form(form): Form<HashMap<String, String>>,
|
||||
) -> Response {
|
||||
let Some(refresh_state) = state.refresh_state.clone() else {
|
||||
return json_response(StatusCode::NOT_FOUND, json!({ "error": "not_found" }));
|
||||
};
|
||||
|
||||
if form.get("grant_type").map(String::as_str) != Some("refresh_token") {
|
||||
return json_response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
json!({ "error": "unsupported_grant_type" }),
|
||||
);
|
||||
}
|
||||
|
||||
let provided_refresh_token = form.get("refresh_token").map(String::as_str);
|
||||
let mut refresh_state = refresh_state.lock().await;
|
||||
if refresh_state.single_use && refresh_state.used_once {
|
||||
return json_response(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
json!({
|
||||
"error": "invalid_grant",
|
||||
"error_description": "refresh token was already used",
|
||||
"code": "refresh_token_reused",
|
||||
}),
|
||||
);
|
||||
}
|
||||
if provided_refresh_token != Some(refresh_state.current_refresh_token.as_str()) {
|
||||
return json_response(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
json!({
|
||||
"error": "invalid_grant",
|
||||
"error_description": "refresh token was already used",
|
||||
"code": "refresh_token_reused",
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
let access_token = refresh_state.next_access_token.clone();
|
||||
let refresh_token = refresh_state.next_refresh_token.clone();
|
||||
let expires_in = refresh_state.expires_in;
|
||||
refresh_state.current_refresh_token = refresh_token.clone();
|
||||
refresh_state.used_once = true;
|
||||
*state.current_bearer.write().await = Some(format!("Bearer {access_token}"));
|
||||
|
||||
json_response(
|
||||
StatusCode::OK,
|
||||
json!({
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": refresh_token,
|
||||
"expires_in": expires_in,
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
async fn require_bearer(
|
||||
State(expected): State<Arc<String>>,
|
||||
State(state): State<AuthState>,
|
||||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
if request.uri().path().contains("/.well-known/") {
|
||||
let request_path = request.uri().path();
|
||||
if request_path.contains("/.well-known/") || request_path.contains("/oauth/token") {
|
||||
return Ok(next.run(request).await);
|
||||
}
|
||||
|
||||
let expected = state.current_bearer.read().await.clone();
|
||||
let Some(expected) = expected else {
|
||||
return Ok(next.run(request).await);
|
||||
};
|
||||
|
||||
if request
|
||||
.headers()
|
||||
.get(AUTHORIZATION)
|
||||
@@ -323,3 +433,14 @@ async fn require_bearer(
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
}
|
||||
|
||||
fn json_response(status: StatusCode, body: serde_json::Value) -> Response {
|
||||
#[expect(clippy::expect_used)]
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(
|
||||
serde_json::to_vec(&body).expect("failed to serialize JSON response"),
|
||||
))
|
||||
.expect("valid JSON response")
|
||||
}
|
||||
|
||||
@@ -25,7 +25,10 @@ use oauth2::RefreshToken;
|
||||
use oauth2::Scope;
|
||||
use oauth2::TokenResponse;
|
||||
use oauth2::basic::BasicTokenType;
|
||||
use rmcp::transport::auth::CredentialStore;
|
||||
use rmcp::transport::auth::InMemoryCredentialStore;
|
||||
use rmcp::transport::auth::OAuthTokenResponse;
|
||||
use rmcp::transport::auth::StoredCredentials;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
@@ -273,15 +276,32 @@ struct OAuthPersistorInner {
|
||||
server_name: String,
|
||||
url: String,
|
||||
authorization_manager: Arc<Mutex<AuthorizationManager>>,
|
||||
runtime_credentials: InMemoryCredentialStore,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
last_credentials: Mutex<Option<StoredOAuthTokens>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
enum GuardedRefreshOutcome {
|
||||
NoAction,
|
||||
ReloadedChanged(StoredOAuthTokens),
|
||||
ReloadedNoChange,
|
||||
MissingOrInvalid,
|
||||
ReloadFailed,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
enum GuardedRefreshPersistedCredentials {
|
||||
Loaded(Option<StoredOAuthTokens>),
|
||||
ReloadFailed,
|
||||
}
|
||||
|
||||
impl OAuthPersistor {
|
||||
pub(crate) fn new(
|
||||
server_name: String,
|
||||
url: String,
|
||||
authorization_manager: Arc<Mutex<AuthorizationManager>>,
|
||||
runtime_credentials: InMemoryCredentialStore,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
initial_credentials: Option<StoredOAuthTokens>,
|
||||
) -> Self {
|
||||
@@ -290,6 +310,7 @@ impl OAuthPersistor {
|
||||
server_name,
|
||||
url,
|
||||
authorization_manager,
|
||||
runtime_credentials,
|
||||
store_mode,
|
||||
last_credentials: Mutex::new(initial_credentials),
|
||||
}),
|
||||
@@ -350,28 +371,220 @@ impl OAuthPersistor {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Guard refreshes against multi-process refresh-token reuse.
|
||||
///
|
||||
/// MCP OAuth credentials live in shared storage, but each Codex process also keeps an
|
||||
/// in-memory snapshot. Before refreshing, reload the shared credentials and compare them to
|
||||
/// the cached copy:
|
||||
/// - if the local cache was cleared, reload shared storage first so this process can recover
|
||||
/// when another process logs in and persists fresh credentials;
|
||||
/// - if shared storage changed, another process already refreshed, so adopt those credentials
|
||||
/// in the live runtime and skip the local refresh;
|
||||
/// - if shared storage is unchanged, this process still owns the refresh and can rotate the
|
||||
/// tokens with the authority;
|
||||
/// - if shared storage no longer has credentials, treat that as logged out and clear the live
|
||||
/// runtime instead of sending a stale refresh token.
|
||||
pub(crate) async fn refresh_if_needed(&self) -> Result<()> {
|
||||
let expires_at = {
|
||||
let mut cached_credentials = {
|
||||
let guard = self.inner.last_credentials.lock().await;
|
||||
guard.as_ref().and_then(|tokens| tokens.expires_at)
|
||||
guard.clone()
|
||||
};
|
||||
|
||||
if !token_needs_refresh(expires_at) {
|
||||
return Ok(());
|
||||
if cached_credentials.is_none()
|
||||
&& let Some(credentials) = load_oauth_tokens_when_cache_missing(
|
||||
&self.inner.server_name,
|
||||
&self.inner.url,
|
||||
self.inner.store_mode,
|
||||
)
|
||||
{
|
||||
self.apply_runtime_credentials(Some(credentials.clone()))
|
||||
.await?;
|
||||
cached_credentials = Some(credentials);
|
||||
}
|
||||
|
||||
match self.guarded_refresh_outcome(cached_credentials.as_ref()) {
|
||||
GuardedRefreshOutcome::NoAction => Ok(()),
|
||||
GuardedRefreshOutcome::ReloadedChanged(credentials) => {
|
||||
self.apply_runtime_credentials(Some(credentials)).await
|
||||
}
|
||||
GuardedRefreshOutcome::ReloadedNoChange => {
|
||||
{
|
||||
let manager = self.inner.authorization_manager.clone();
|
||||
let guard = manager.lock().await;
|
||||
guard.refresh_token().await.with_context(|| {
|
||||
format!(
|
||||
"failed to refresh OAuth tokens for server {}",
|
||||
self.inner.server_name
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
self.persist_if_needed().await
|
||||
}
|
||||
GuardedRefreshOutcome::MissingOrInvalid => self.apply_runtime_credentials(None).await,
|
||||
GuardedRefreshOutcome::ReloadFailed => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
fn guarded_refresh_outcome(
|
||||
&self,
|
||||
cached_credentials: Option<&StoredOAuthTokens>,
|
||||
) -> GuardedRefreshOutcome {
|
||||
let Some(cached_credentials) = cached_credentials else {
|
||||
return GuardedRefreshOutcome::NoAction;
|
||||
};
|
||||
|
||||
if !token_needs_refresh(cached_credentials.expires_at) {
|
||||
return GuardedRefreshOutcome::NoAction;
|
||||
}
|
||||
|
||||
match load_oauth_tokens_for_guarded_refresh(
|
||||
&self.inner.server_name,
|
||||
&self.inner.url,
|
||||
self.inner.store_mode,
|
||||
) {
|
||||
GuardedRefreshPersistedCredentials::Loaded(persisted_credentials) => {
|
||||
determine_guarded_refresh_outcome(cached_credentials, persisted_credentials)
|
||||
}
|
||||
GuardedRefreshPersistedCredentials::ReloadFailed => GuardedRefreshOutcome::ReloadFailed,
|
||||
}
|
||||
}
|
||||
|
||||
async fn apply_runtime_credentials(
|
||||
&self,
|
||||
credentials: Option<StoredOAuthTokens>,
|
||||
) -> Result<()> {
|
||||
{
|
||||
let manager = self.inner.authorization_manager.clone();
|
||||
let guard = manager.lock().await;
|
||||
guard.refresh_token().await.with_context(|| {
|
||||
format!(
|
||||
"failed to refresh OAuth tokens for server {}",
|
||||
self.inner.server_name
|
||||
)
|
||||
})?;
|
||||
let mut guard = manager.lock().await;
|
||||
|
||||
match credentials.as_ref() {
|
||||
Some(credentials) => {
|
||||
self.inner
|
||||
.runtime_credentials
|
||||
.save(StoredCredentials {
|
||||
client_id: credentials.client_id.clone(),
|
||||
token_response: Some(credentials.token_response.0.clone()),
|
||||
})
|
||||
.await?;
|
||||
guard
|
||||
.configure_client_id(&credentials.client_id)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to reconfigure OAuth client for server {}",
|
||||
self.inner.server_name
|
||||
)
|
||||
})?;
|
||||
}
|
||||
None => {
|
||||
self.inner.runtime_credentials.clear().await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.persist_if_needed().await
|
||||
let mut last_credentials = self.inner.last_credentials.lock().await;
|
||||
*last_credentials = credentials;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn load_oauth_tokens_for_guarded_refresh(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> GuardedRefreshPersistedCredentials {
|
||||
let keyring_store = DefaultKeyringStore;
|
||||
match store_mode {
|
||||
OAuthCredentialsStoreMode::Auto => {
|
||||
load_oauth_tokens_for_guarded_refresh_with_keyring_fallback(
|
||||
&keyring_store,
|
||||
server_name,
|
||||
url,
|
||||
)
|
||||
}
|
||||
OAuthCredentialsStoreMode::File => guarded_refresh_persisted_credentials_from_load_result(
|
||||
load_oauth_tokens_from_file(server_name, url),
|
||||
server_name,
|
||||
),
|
||||
OAuthCredentialsStoreMode::Keyring => {
|
||||
guarded_refresh_persisted_credentials_from_load_result(
|
||||
load_oauth_tokens_from_keyring(&keyring_store, server_name, url)
|
||||
.with_context(|| "failed to read OAuth tokens from keyring".to_string()),
|
||||
server_name,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_oauth_tokens_when_cache_missing(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Option<StoredOAuthTokens> {
|
||||
match load_oauth_tokens_for_guarded_refresh(server_name, url, store_mode) {
|
||||
GuardedRefreshPersistedCredentials::Loaded(Some(credentials)) => Some(credentials),
|
||||
GuardedRefreshPersistedCredentials::Loaded(None)
|
||||
| GuardedRefreshPersistedCredentials::ReloadFailed => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn load_oauth_tokens_for_guarded_refresh_with_keyring_fallback<K: KeyringStore>(
|
||||
keyring_store: &K,
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
) -> GuardedRefreshPersistedCredentials {
|
||||
match load_oauth_tokens_from_keyring(keyring_store, server_name, url) {
|
||||
Ok(Some(tokens)) => GuardedRefreshPersistedCredentials::Loaded(Some(tokens)),
|
||||
Ok(None) => guarded_refresh_persisted_credentials_from_load_result(
|
||||
load_oauth_tokens_from_file(server_name, url),
|
||||
server_name,
|
||||
),
|
||||
Err(error) => {
|
||||
warn!("failed to read OAuth tokens from keyring: {error}");
|
||||
match load_oauth_tokens_from_file(server_name, url) {
|
||||
Ok(Some(tokens)) => GuardedRefreshPersistedCredentials::Loaded(Some(tokens)),
|
||||
Ok(None) => {
|
||||
warn!(
|
||||
"failed to reload OAuth tokens for server {server_name}: keyring read failed and no fallback file credentials were available"
|
||||
);
|
||||
GuardedRefreshPersistedCredentials::ReloadFailed
|
||||
}
|
||||
Err(file_error) => {
|
||||
warn!(
|
||||
"failed to reload OAuth tokens for server {server_name}: keyring read failed ({error}) and fallback file reload failed: {file_error}"
|
||||
);
|
||||
GuardedRefreshPersistedCredentials::ReloadFailed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn guarded_refresh_outcome_from_load_result(
|
||||
cached_credentials: &StoredOAuthTokens,
|
||||
persisted_credentials: Result<Option<StoredOAuthTokens>>,
|
||||
server_name: &str,
|
||||
) -> GuardedRefreshOutcome {
|
||||
match guarded_refresh_persisted_credentials_from_load_result(persisted_credentials, server_name)
|
||||
{
|
||||
GuardedRefreshPersistedCredentials::Loaded(persisted_credentials) => {
|
||||
determine_guarded_refresh_outcome(cached_credentials, persisted_credentials)
|
||||
}
|
||||
GuardedRefreshPersistedCredentials::ReloadFailed => GuardedRefreshOutcome::ReloadFailed,
|
||||
}
|
||||
}
|
||||
|
||||
fn guarded_refresh_persisted_credentials_from_load_result(
|
||||
persisted_credentials: Result<Option<StoredOAuthTokens>>,
|
||||
server_name: &str,
|
||||
) -> GuardedRefreshPersistedCredentials {
|
||||
match persisted_credentials {
|
||||
Ok(credentials) => GuardedRefreshPersistedCredentials::Loaded(credentials),
|
||||
Err(error) => {
|
||||
warn!("failed to reload OAuth tokens for server {server_name}: {error}");
|
||||
GuardedRefreshPersistedCredentials::ReloadFailed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -521,6 +734,61 @@ fn token_needs_refresh(expires_at: Option<u64>) -> bool {
|
||||
now.saturating_add(REFRESH_SKEW_MILLIS) >= expires_at
|
||||
}
|
||||
|
||||
fn determine_guarded_refresh_outcome(
|
||||
cached_credentials: &StoredOAuthTokens,
|
||||
persisted_credentials: Option<StoredOAuthTokens>,
|
||||
) -> GuardedRefreshOutcome {
|
||||
match persisted_credentials {
|
||||
Some(persisted_credentials)
|
||||
if oauth_tokens_equal_for_refresh(
|
||||
Some(cached_credentials),
|
||||
Some(&persisted_credentials),
|
||||
) =>
|
||||
{
|
||||
GuardedRefreshOutcome::ReloadedNoChange
|
||||
}
|
||||
Some(persisted_credentials) => {
|
||||
GuardedRefreshOutcome::ReloadedChanged(persisted_credentials)
|
||||
}
|
||||
None => GuardedRefreshOutcome::MissingOrInvalid,
|
||||
}
|
||||
}
|
||||
|
||||
fn oauth_tokens_equal_for_refresh(
|
||||
left: Option<&StoredOAuthTokens>,
|
||||
right: Option<&StoredOAuthTokens>,
|
||||
) -> bool {
|
||||
match (left, right) {
|
||||
(None, None) => true,
|
||||
(Some(left), Some(right)) => {
|
||||
left.server_name == right.server_name
|
||||
&& left.url == right.url
|
||||
&& left.client_id == right.client_id
|
||||
&& left.expires_at == right.expires_at
|
||||
&& oauth_token_responses_equal_for_refresh(
|
||||
&left.token_response,
|
||||
&right.token_response,
|
||||
)
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn oauth_token_responses_equal_for_refresh(
|
||||
left: &WrappedOAuthTokenResponse,
|
||||
right: &WrappedOAuthTokenResponse,
|
||||
) -> bool {
|
||||
let left = &left.0;
|
||||
let right = &right.0;
|
||||
|
||||
left.access_token().secret() == right.access_token().secret()
|
||||
&& left.token_type() == right.token_type()
|
||||
&& left.refresh_token().map(RefreshToken::secret)
|
||||
== right.refresh_token().map(RefreshToken::secret)
|
||||
&& left.scopes() == right.scopes()
|
||||
&& left.extra_fields() == right.extra_fields()
|
||||
}
|
||||
|
||||
fn compute_store_key(server_name: &str, server_url: &str) -> Result<String> {
|
||||
let mut payload = JsonMap::new();
|
||||
payload.insert(
|
||||
@@ -855,6 +1123,158 @@ mod tests {
|
||||
assert!(tokens.token_response.0.expires_in().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn guarded_refresh_outcome_reloads_when_persisted_credentials_changed() {
|
||||
let cached = sample_tokens();
|
||||
let mut persisted = sample_tokens();
|
||||
persisted
|
||||
.token_response
|
||||
.0
|
||||
.set_refresh_token(Some(RefreshToken::new("rotated-refresh-token".to_string())));
|
||||
persisted
|
||||
.token_response
|
||||
.0
|
||||
.set_expires_in(Some(&Duration::from_secs(7200)));
|
||||
persisted.expires_at = super::compute_expires_at_millis(&persisted.token_response.0);
|
||||
|
||||
assert_eq!(
|
||||
super::determine_guarded_refresh_outcome(&cached, Some(persisted.clone())),
|
||||
super::GuardedRefreshOutcome::ReloadedChanged(persisted),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn guarded_refresh_outcome_refreshes_when_persisted_credentials_match() {
|
||||
let cached = sample_tokens();
|
||||
let mut persisted = cached.clone();
|
||||
persisted
|
||||
.token_response
|
||||
.0
|
||||
.set_expires_in(Some(&Duration::from_secs(5)));
|
||||
|
||||
assert_eq!(
|
||||
super::determine_guarded_refresh_outcome(&cached, Some(persisted)),
|
||||
super::GuardedRefreshOutcome::ReloadedNoChange,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn guarded_refresh_outcome_clears_when_persisted_credentials_are_missing() {
|
||||
assert_eq!(
|
||||
super::determine_guarded_refresh_outcome(&sample_tokens(), None),
|
||||
super::GuardedRefreshOutcome::MissingOrInvalid,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn guarded_refresh_outcome_keeps_state_recoverable_when_reload_fails() {
|
||||
let error = anyhow::anyhow!("transient read failure");
|
||||
|
||||
assert_eq!(
|
||||
super::guarded_refresh_outcome_from_load_result(
|
||||
&sample_tokens(),
|
||||
Err(error),
|
||||
"test-server",
|
||||
),
|
||||
super::GuardedRefreshOutcome::ReloadFailed,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn guarded_refresh_auto_load_keeps_state_recoverable_when_keyring_fails_without_file() {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockKeyringStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)
|
||||
.expect("store key should compute");
|
||||
store.set_error(&key, KeyringError::Invalid("error".into(), "load".into()));
|
||||
|
||||
assert_eq!(
|
||||
super::load_oauth_tokens_for_guarded_refresh_with_keyring_fallback(
|
||||
&store,
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
),
|
||||
super::GuardedRefreshPersistedCredentials::ReloadFailed,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_cached_credentials_reload_shared_store_from_file() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let tokens = sample_tokens();
|
||||
let expected = tokens.clone();
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
|
||||
let loaded = super::load_oauth_tokens_when_cache_missing(
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
OAuthCredentialsStoreMode::File,
|
||||
)
|
||||
.expect("tokens should reload from shared file store");
|
||||
assert_tokens_match_without_expiry(&loaded, &expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_cached_credentials_ignore_reload_failures() {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockKeyringStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)
|
||||
.expect("store key should compute");
|
||||
store.set_error(&key, KeyringError::Invalid("error".into(), "load".into()));
|
||||
|
||||
assert_eq!(
|
||||
super::load_oauth_tokens_for_guarded_refresh_with_keyring_fallback(
|
||||
&store,
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
),
|
||||
super::GuardedRefreshPersistedCredentials::ReloadFailed,
|
||||
);
|
||||
assert_eq!(
|
||||
super::load_oauth_tokens_when_cache_missing(
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
OAuthCredentialsStoreMode::Auto,
|
||||
),
|
||||
None,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn oauth_tokens_equal_for_refresh_ignores_only_expires_in() {
|
||||
let left = sample_tokens();
|
||||
let mut right = left.clone();
|
||||
right
|
||||
.token_response
|
||||
.0
|
||||
.set_expires_in(Some(&Duration::from_secs(5)));
|
||||
|
||||
assert!(super::oauth_tokens_equal_for_refresh(
|
||||
Some(&left),
|
||||
Some(&right),
|
||||
));
|
||||
|
||||
let mut different_refresh_token = right.clone();
|
||||
different_refresh_token
|
||||
.token_response
|
||||
.0
|
||||
.set_refresh_token(Some(RefreshToken::new("different-refresh".to_string())));
|
||||
assert!(!super::oauth_tokens_equal_for_refresh(
|
||||
Some(&left),
|
||||
Some(&different_refresh_token),
|
||||
));
|
||||
|
||||
let mut different_expiry = right;
|
||||
different_expiry.expires_at = different_expiry.expires_at.map(|value| value + 1000);
|
||||
assert!(!super::oauth_tokens_equal_for_refresh(
|
||||
Some(&left),
|
||||
Some(&different_expiry),
|
||||
));
|
||||
}
|
||||
|
||||
fn assert_tokens_match_without_expiry(
|
||||
actual: &StoredOAuthTokens,
|
||||
expected: &StoredOAuthTokens,
|
||||
|
||||
@@ -39,7 +39,10 @@ use rmcp::service::{self};
|
||||
use rmcp::transport::StreamableHttpClientTransport;
|
||||
use rmcp::transport::auth::AuthClient;
|
||||
use rmcp::transport::auth::AuthError;
|
||||
use rmcp::transport::auth::OAuthState;
|
||||
use rmcp::transport::auth::AuthorizationManager;
|
||||
use rmcp::transport::auth::CredentialStore;
|
||||
use rmcp::transport::auth::InMemoryCredentialStore;
|
||||
use rmcp::transport::auth::StoredCredentials;
|
||||
use rmcp::transport::child_process::TokioChildProcess;
|
||||
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
|
||||
use serde_json::Value;
|
||||
@@ -358,6 +361,12 @@ impl RmcpClient {
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(runtime) = &oauth_persistor
|
||||
&& let Err(error) = runtime.refresh_if_needed().await
|
||||
{
|
||||
warn!("failed to refresh OAuth tokens before initialize: {error}");
|
||||
}
|
||||
|
||||
let service = match timeout {
|
||||
Some(duration) => time::timeout(duration, transport)
|
||||
.await
|
||||
@@ -595,22 +604,20 @@ async fn create_oauth_transport_and_runtime(
|
||||
)> {
|
||||
let http_client =
|
||||
apply_default_headers(reqwest::Client::builder(), &default_headers).build()?;
|
||||
let mut oauth_state = OAuthState::new(url.to_string(), Some(http_client.clone())).await?;
|
||||
|
||||
oauth_state
|
||||
.set_credentials(
|
||||
&initial_tokens.client_id,
|
||||
initial_tokens.token_response.0.clone(),
|
||||
)
|
||||
let runtime_credentials = InMemoryCredentialStore::new();
|
||||
runtime_credentials
|
||||
.save(StoredCredentials {
|
||||
client_id: initial_tokens.client_id.clone(),
|
||||
token_response: Some(initial_tokens.token_response.0.clone()),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let manager = match oauth_state {
|
||||
OAuthState::Authorized(manager) => manager,
|
||||
OAuthState::Unauthorized(manager) => manager,
|
||||
OAuthState::Session(_) | OAuthState::AuthorizedHttpClient(_) => {
|
||||
return Err(anyhow!("unexpected OAuth state during client setup"));
|
||||
}
|
||||
};
|
||||
let mut manager = AuthorizationManager::new(url.to_string()).await?;
|
||||
manager.set_credential_store(runtime_credentials.clone());
|
||||
manager.with_client(http_client.clone())?;
|
||||
let metadata = manager.discover_metadata().await?;
|
||||
manager.set_metadata(metadata);
|
||||
manager.configure_client_id(&initial_tokens.client_id)?;
|
||||
|
||||
let auth_client = AuthClient::new(http_client, manager);
|
||||
let auth_manager = auth_client.auth_manager.clone();
|
||||
@@ -624,6 +631,7 @@ async fn create_oauth_transport_and_runtime(
|
||||
server_name.to_string(),
|
||||
url.to_string(),
|
||||
auth_manager,
|
||||
runtime_credentials,
|
||||
credentials_store,
|
||||
Some(initial_tokens),
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user