Compare commits

...

4 Commits

Author SHA1 Message Date
Eric Traut
6feecbf5d1 codex: address PR review feedback (#12815) 2026-02-25 15:22:28 -08:00
Eric Traut
e7479ee1cc codex: address PR review feedback (#12815) 2026-02-25 14:20:57 -08:00
Eric Traut
78249f7fce codex: address PR review feedback (#12815) 2026-02-25 13:55:02 -08:00
Eric Traut
c47fd545dd Add guarded MCP OAuth refresh flow 2026-02-25 12:40:23 -08:00
4 changed files with 841 additions and 38 deletions

View File

@@ -12,7 +12,6 @@ use codex_core::CodexAuth;
use codex_core::config::types::McpServerConfig;
use codex_core::config::types::McpServerTransportConfig;
use codex_core::models_manager::manager::RefreshStrategy;
use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::openai_models::ConfigShellToolType;
use codex_protocol::openai_models::InputModality;
@@ -28,6 +27,10 @@ use codex_protocol::protocol::McpToolCallBeginEvent;
use codex_protocol::protocol::Op;
use codex_protocol::protocol::SandboxPolicy;
use codex_protocol::user_input::UserInput;
use codex_rmcp_client::ElicitationAction;
use codex_rmcp_client::ElicitationResponse;
use codex_rmcp_client::OAuthCredentialsStoreMode;
use codex_rmcp_client::RmcpClient;
use codex_utils_cargo_bin::cargo_bin;
use core_test_support::responses;
use core_test_support::responses::mount_models_once;
@@ -36,6 +39,13 @@ use core_test_support::skip_if_no_network;
use core_test_support::stdio_server_bin;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use futures::FutureExt;
use rmcp::model::ClientCapabilities;
use rmcp::model::ElicitationCapability;
use rmcp::model::FormElicitationCapability;
use rmcp::model::Implementation;
use rmcp::model::InitializeRequestParams;
use rmcp::model::ProtocolVersion;
use serde_json::Value;
use serde_json::json;
use serial_test::serial;
@@ -1056,6 +1066,231 @@ async fn streamable_http_with_oauth_round_trip_impl() -> anyhow::Result<()> {
Ok(())
}
/// This test writes to a fallback credentials file in CODEX_HOME.
#[serial(codex_home)]
#[test]
fn streamable_http_with_oauth_refresh_adopts_rotated_credentials() -> anyhow::Result<()> {
const TEST_STACK_SIZE_BYTES: usize = 8 * 1024 * 1024;
let handle = std::thread::Builder::new()
.name("streamable_http_with_oauth_refresh_adopts_rotated_credentials".to_string())
.stack_size(TEST_STACK_SIZE_BYTES)
.spawn(|| -> anyhow::Result<()> {
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(1)
.enable_all()
.build()?;
runtime.block_on(streamable_http_with_oauth_refresh_adopts_rotated_credentials_impl())
})?;
match handle.join() {
Ok(result) => result,
Err(_) => Err(anyhow::anyhow!(
"streamable_http_with_oauth_refresh_adopts_rotated_credentials thread panicked"
)),
}
}
#[allow(clippy::expect_used)]
async fn streamable_http_with_oauth_refresh_adopts_rotated_credentials_impl() -> anyhow::Result<()>
{
skip_if_no_network!(Ok(()));
let server_name = "rmcp_http_oauth_refresh_race";
let initial_access_token = "initial-access-token";
let initial_refresh_token = "initial-refresh-token";
let rotated_access_token = "rotated-access-token";
let rotated_refresh_token = "rotated-refresh-token";
let rmcp_http_server_bin = match cargo_bin("test_streamable_http_server") {
Ok(path) => path,
Err(err) => {
eprintln!("test_streamable_http_server binary not available, skipping test: {err}");
return Ok(());
}
};
let listener = TcpListener::bind("127.0.0.1:0")?;
let port = listener.local_addr()?.port();
drop(listener);
let bind_addr = format!("127.0.0.1:{port}");
let server_url = format!("http://{bind_addr}/mcp");
let mut http_server_child = Command::new(&rmcp_http_server_bin)
.kill_on_drop(true)
.env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr)
.env("MCP_EXPECT_BEARER", initial_access_token)
.env("MCP_EXPECT_REFRESH_TOKEN", initial_refresh_token)
.env("MCP_REFRESH_NEXT_ACCESS_TOKEN", rotated_access_token)
.env("MCP_REFRESH_NEXT_REFRESH_TOKEN", rotated_refresh_token)
.env("MCP_REFRESH_EXPIRES_IN", "3600")
.env("MCP_REFRESH_SINGLE_USE", "1")
.spawn()?;
wait_for_streamable_http_server(&mut http_server_child, &bind_addr, Duration::from_secs(5))
.await?;
let temp_home = tempdir()?;
let _guard = EnvVarGuard::set("CODEX_HOME", temp_home.path().as_os_str());
let initial_expires_at = SystemTime::now()
.checked_add(Duration::from_secs(1))
.ok_or_else(|| anyhow::anyhow!("failed to compute expiry time"))?
.duration_since(UNIX_EPOCH)?
.as_millis() as u64;
write_fallback_oauth_tokens_with_expiry(
temp_home.path(),
server_name,
&server_url,
"test-client-id",
initial_access_token,
initial_refresh_token,
initial_expires_at,
)?;
let client_a = RmcpClient::new_streamable_http_client(
server_name,
&server_url,
None,
None,
None,
OAuthCredentialsStoreMode::File,
)
.await?;
let client_b = RmcpClient::new_streamable_http_client(
server_name,
&server_url,
None,
None,
None,
OAuthCredentialsStoreMode::File,
)
.await?;
client_a
.initialize(
rmcp_initialize_params(),
Some(Duration::from_secs(5)),
noop_send_elicitation(),
)
.await?;
client_b
.initialize(
rmcp_initialize_params(),
Some(Duration::from_secs(5)),
noop_send_elicitation(),
)
.await?;
let tools_a = client_a
.list_tools(None, Some(Duration::from_secs(5)))
.await?;
assert_eq!(tools_a.tools.len(), 1);
assert_eq!(tools_a.tools[0].name.as_ref(), "echo");
assert_stored_oauth_tokens(
temp_home.path(),
server_name,
&server_url,
rotated_access_token,
rotated_refresh_token,
)?;
let tools_b = client_b
.list_tools(None, Some(Duration::from_secs(5)))
.await?;
assert_eq!(tools_b.tools.len(), 1);
assert_eq!(tools_b.tools[0].name.as_ref(), "echo");
assert_stored_oauth_tokens(
temp_home.path(),
server_name,
&server_url,
rotated_access_token,
rotated_refresh_token,
)?;
match http_server_child.try_wait() {
Ok(Some(_)) => {}
Ok(None) => {
let _ = http_server_child.kill().await;
}
Err(error) => {
eprintln!("failed to check streamable http oauth server status: {error}");
let _ = http_server_child.kill().await;
}
}
if let Err(error) = http_server_child.wait().await {
eprintln!("failed to await streamable http oauth server shutdown: {error}");
}
Ok(())
}
fn rmcp_initialize_params() -> InitializeRequestParams {
InitializeRequestParams {
meta: None,
capabilities: ClientCapabilities {
experimental: None,
extensions: None,
roots: None,
sampling: None,
elicitation: Some(ElicitationCapability {
form: Some(FormElicitationCapability {
schema_validation: None,
}),
url: None,
}),
tasks: None,
},
client_info: Implementation {
name: "codex-test".into(),
version: "0.0.0-test".into(),
title: Some("Codex rmcp oauth refresh test".into()),
description: None,
icons: None,
website_url: None,
},
protocol_version: ProtocolVersion::V_2025_06_18,
}
}
fn noop_send_elicitation() -> codex_rmcp_client::SendElicitation {
Box::new(|_, _| {
async {
Ok(ElicitationResponse {
action: ElicitationAction::Accept,
content: Some(json!({})),
})
}
.boxed()
})
}
fn assert_stored_oauth_tokens(
home: &Path,
server_name: &str,
server_url: &str,
expected_access_token: &str,
expected_refresh_token: &str,
) -> anyhow::Result<()> {
let file_path = home.join(".credentials.json");
let stored: Value = serde_json::from_slice(&fs::read(&file_path)?)?;
let entries = stored
.as_object()
.ok_or_else(|| anyhow::anyhow!("expected fallback OAuth credential map"))?;
let has_expected_tokens = entries.values().any(|entry| {
entry.as_object().is_some_and(|entry| {
entry.get("server_name").and_then(Value::as_str) == Some(server_name)
&& entry.get("server_url").and_then(Value::as_str) == Some(server_url)
&& entry.get("access_token").and_then(Value::as_str) == Some(expected_access_token)
&& entry.get("refresh_token").and_then(Value::as_str)
== Some(expected_refresh_token)
})
});
assert!(
has_expected_tokens,
"expected stored OAuth credentials for {server_name} at {server_url} to include access_token={expected_access_token} refresh_token={expected_refresh_token}, got {stored}",
);
Ok(())
}
async fn wait_for_streamable_http_server(
server_child: &mut Child,
address: &str,
@@ -1111,7 +1346,26 @@ fn write_fallback_oauth_tokens(
.ok_or_else(|| anyhow::anyhow!("failed to compute expiry time"))?
.duration_since(UNIX_EPOCH)?
.as_millis() as u64;
write_fallback_oauth_tokens_with_expiry(
home,
server_name,
server_url,
client_id,
access_token,
refresh_token,
expires_at,
)
}
fn write_fallback_oauth_tokens_with_expiry(
home: &Path,
server_name: &str,
server_url: &str,
client_id: &str,
access_token: &str,
refresh_token: &str,
expires_at: u64,
) -> anyhow::Result<()> {
let store = serde_json::json!({
"stub": {
"server_name": server_name,

View File

@@ -6,6 +6,7 @@ use std::sync::Arc;
use axum::Router;
use axum::body::Body;
use axum::extract::Form;
use axum::extract::State;
use axum::http::Request;
use axum::http::StatusCode;
@@ -15,6 +16,7 @@ use axum::middleware;
use axum::middleware::Next;
use axum::response::Response;
use axum::routing::get;
use axum::routing::post;
use rmcp::ErrorData as McpError;
use rmcp::handler::server::ServerHandler;
use rmcp::model::CallToolRequestParams;
@@ -39,6 +41,8 @@ use rmcp::transport::StreamableHttpService;
use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
use serde::Deserialize;
use serde_json::json;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tokio::task;
#[derive(Clone)]
@@ -48,6 +52,22 @@ struct TestToolServer {
resource_templates: Arc<Vec<ResourceTemplate>>,
}
#[derive(Clone)]
struct AuthState {
current_bearer: Arc<RwLock<Option<String>>>,
refresh_state: Option<Arc<Mutex<RefreshTokenState>>>,
}
#[derive(Debug)]
struct RefreshTokenState {
current_refresh_token: String,
next_access_token: String,
next_refresh_token: String,
expires_in: u64,
single_use: bool,
used_once: bool,
}
const MEMO_URI: &str = "memo://codex/example-note";
const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server.";
@@ -263,6 +283,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
};
eprintln!("starting rmcp streamable http test server on http://{bind_addr}/mcp");
let auth_state = AuthState {
current_bearer: Arc::new(RwLock::new(
std::env::var("MCP_EXPECT_BEARER")
.ok()
.map(|token| format!("Bearer {token}")),
)),
refresh_state: refresh_state_from_env(),
};
let router = Router::new()
.route(
"/.well-known/oauth-authorization-server/mcp",
@@ -284,6 +313,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
}),
)
.route("/oauth/token", post(oauth_refresh_token))
.nest_service(
"/mcp",
StreamableHttpService::new(
@@ -291,28 +321,108 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
Arc::new(LocalSessionManager::default()),
StreamableHttpServerConfig::default(),
),
);
let router = if let Ok(token) = std::env::var("MCP_EXPECT_BEARER") {
let expected = Arc::new(format!("Bearer {token}"));
router.layer(middleware::from_fn_with_state(expected, require_bearer))
} else {
router
};
)
.with_state(auth_state.clone())
.layer(middleware::from_fn_with_state(auth_state, require_bearer));
axum::serve(listener, router).await?;
task::yield_now().await;
Ok(())
}
fn refresh_state_from_env() -> Option<Arc<Mutex<RefreshTokenState>>> {
let current_refresh_token = std::env::var("MCP_EXPECT_REFRESH_TOKEN").ok()?;
let next_access_token = std::env::var("MCP_REFRESH_NEXT_ACCESS_TOKEN").ok()?;
let next_refresh_token = std::env::var("MCP_REFRESH_NEXT_REFRESH_TOKEN").ok()?;
let expires_in = std::env::var("MCP_REFRESH_EXPIRES_IN")
.ok()
.and_then(|value| value.parse::<u64>().ok())
.unwrap_or(3600);
let single_use = std::env::var("MCP_REFRESH_SINGLE_USE")
.ok()
.is_some_and(|value| value == "1");
Some(Arc::new(Mutex::new(RefreshTokenState {
current_refresh_token,
next_access_token,
next_refresh_token,
expires_in,
single_use,
used_once: false,
})))
}
async fn oauth_refresh_token(
State(state): State<AuthState>,
Form(form): Form<HashMap<String, String>>,
) -> Response {
let Some(refresh_state) = state.refresh_state.clone() else {
return json_response(StatusCode::NOT_FOUND, json!({ "error": "not_found" }));
};
if form.get("grant_type").map(String::as_str) != Some("refresh_token") {
return json_response(
StatusCode::BAD_REQUEST,
json!({ "error": "unsupported_grant_type" }),
);
}
let provided_refresh_token = form.get("refresh_token").map(String::as_str);
let mut refresh_state = refresh_state.lock().await;
if refresh_state.single_use && refresh_state.used_once {
return json_response(
StatusCode::UNAUTHORIZED,
json!({
"error": "invalid_grant",
"error_description": "refresh token was already used",
"code": "refresh_token_reused",
}),
);
}
if provided_refresh_token != Some(refresh_state.current_refresh_token.as_str()) {
return json_response(
StatusCode::UNAUTHORIZED,
json!({
"error": "invalid_grant",
"error_description": "refresh token was already used",
"code": "refresh_token_reused",
}),
);
}
let access_token = refresh_state.next_access_token.clone();
let refresh_token = refresh_state.next_refresh_token.clone();
let expires_in = refresh_state.expires_in;
refresh_state.current_refresh_token = refresh_token.clone();
refresh_state.used_once = true;
*state.current_bearer.write().await = Some(format!("Bearer {access_token}"));
json_response(
StatusCode::OK,
json!({
"access_token": access_token,
"token_type": "Bearer",
"refresh_token": refresh_token,
"expires_in": expires_in,
}),
)
}
async fn require_bearer(
State(expected): State<Arc<String>>,
State(state): State<AuthState>,
request: Request<Body>,
next: Next,
) -> Result<Response, StatusCode> {
if request.uri().path().contains("/.well-known/") {
let request_path = request.uri().path();
if request_path.contains("/.well-known/") || request_path.contains("/oauth/token") {
return Ok(next.run(request).await);
}
let expected = state.current_bearer.read().await.clone();
let Some(expected) = expected else {
return Ok(next.run(request).await);
};
if request
.headers()
.get(AUTHORIZATION)
@@ -323,3 +433,14 @@ async fn require_bearer(
Err(StatusCode::UNAUTHORIZED)
}
}
fn json_response(status: StatusCode, body: serde_json::Value) -> Response {
#[expect(clippy::expect_used)]
Response::builder()
.status(status)
.header(CONTENT_TYPE, "application/json")
.body(Body::from(
serde_json::to_vec(&body).expect("failed to serialize JSON response"),
))
.expect("valid JSON response")
}

View File

@@ -25,7 +25,10 @@ use oauth2::RefreshToken;
use oauth2::Scope;
use oauth2::TokenResponse;
use oauth2::basic::BasicTokenType;
use rmcp::transport::auth::CredentialStore;
use rmcp::transport::auth::InMemoryCredentialStore;
use rmcp::transport::auth::OAuthTokenResponse;
use rmcp::transport::auth::StoredCredentials;
use schemars::JsonSchema;
use serde::Deserialize;
use serde::Serialize;
@@ -273,15 +276,32 @@ struct OAuthPersistorInner {
server_name: String,
url: String,
authorization_manager: Arc<Mutex<AuthorizationManager>>,
runtime_credentials: InMemoryCredentialStore,
store_mode: OAuthCredentialsStoreMode,
last_credentials: Mutex<Option<StoredOAuthTokens>>,
}
#[derive(Debug, Clone, PartialEq)]
enum GuardedRefreshOutcome {
NoAction,
ReloadedChanged(StoredOAuthTokens),
ReloadedNoChange,
MissingOrInvalid,
ReloadFailed,
}
#[derive(Debug, PartialEq)]
enum GuardedRefreshPersistedCredentials {
Loaded(Option<StoredOAuthTokens>),
ReloadFailed,
}
impl OAuthPersistor {
pub(crate) fn new(
server_name: String,
url: String,
authorization_manager: Arc<Mutex<AuthorizationManager>>,
runtime_credentials: InMemoryCredentialStore,
store_mode: OAuthCredentialsStoreMode,
initial_credentials: Option<StoredOAuthTokens>,
) -> Self {
@@ -290,6 +310,7 @@ impl OAuthPersistor {
server_name,
url,
authorization_manager,
runtime_credentials,
store_mode,
last_credentials: Mutex::new(initial_credentials),
}),
@@ -350,28 +371,220 @@ impl OAuthPersistor {
Ok(())
}
/// Guard refreshes against multi-process refresh-token reuse.
///
/// MCP OAuth credentials live in shared storage, but each Codex process also keeps an
/// in-memory snapshot. Before refreshing, reload the shared credentials and compare them to
/// the cached copy:
/// - if the local cache was cleared, reload shared storage first so this process can recover
/// when another process logs in and persists fresh credentials;
/// - if shared storage changed, another process already refreshed, so adopt those credentials
/// in the live runtime and skip the local refresh;
/// - if shared storage is unchanged, this process still owns the refresh and can rotate the
/// tokens with the authority;
/// - if shared storage no longer has credentials, treat that as logged out and clear the live
/// runtime instead of sending a stale refresh token.
pub(crate) async fn refresh_if_needed(&self) -> Result<()> {
let expires_at = {
let mut cached_credentials = {
let guard = self.inner.last_credentials.lock().await;
guard.as_ref().and_then(|tokens| tokens.expires_at)
guard.clone()
};
if !token_needs_refresh(expires_at) {
return Ok(());
if cached_credentials.is_none()
&& let Some(credentials) = load_oauth_tokens_when_cache_missing(
&self.inner.server_name,
&self.inner.url,
self.inner.store_mode,
)
{
self.apply_runtime_credentials(Some(credentials.clone()))
.await?;
cached_credentials = Some(credentials);
}
match self.guarded_refresh_outcome(cached_credentials.as_ref()) {
GuardedRefreshOutcome::NoAction => Ok(()),
GuardedRefreshOutcome::ReloadedChanged(credentials) => {
self.apply_runtime_credentials(Some(credentials)).await
}
GuardedRefreshOutcome::ReloadedNoChange => {
{
let manager = self.inner.authorization_manager.clone();
let guard = manager.lock().await;
guard.refresh_token().await.with_context(|| {
format!(
"failed to refresh OAuth tokens for server {}",
self.inner.server_name
)
})?;
}
self.persist_if_needed().await
}
GuardedRefreshOutcome::MissingOrInvalid => self.apply_runtime_credentials(None).await,
GuardedRefreshOutcome::ReloadFailed => Ok(()),
}
}
fn guarded_refresh_outcome(
&self,
cached_credentials: Option<&StoredOAuthTokens>,
) -> GuardedRefreshOutcome {
let Some(cached_credentials) = cached_credentials else {
return GuardedRefreshOutcome::NoAction;
};
if !token_needs_refresh(cached_credentials.expires_at) {
return GuardedRefreshOutcome::NoAction;
}
match load_oauth_tokens_for_guarded_refresh(
&self.inner.server_name,
&self.inner.url,
self.inner.store_mode,
) {
GuardedRefreshPersistedCredentials::Loaded(persisted_credentials) => {
determine_guarded_refresh_outcome(cached_credentials, persisted_credentials)
}
GuardedRefreshPersistedCredentials::ReloadFailed => GuardedRefreshOutcome::ReloadFailed,
}
}
async fn apply_runtime_credentials(
&self,
credentials: Option<StoredOAuthTokens>,
) -> Result<()> {
{
let manager = self.inner.authorization_manager.clone();
let guard = manager.lock().await;
guard.refresh_token().await.with_context(|| {
format!(
"failed to refresh OAuth tokens for server {}",
self.inner.server_name
)
})?;
let mut guard = manager.lock().await;
match credentials.as_ref() {
Some(credentials) => {
self.inner
.runtime_credentials
.save(StoredCredentials {
client_id: credentials.client_id.clone(),
token_response: Some(credentials.token_response.0.clone()),
})
.await?;
guard
.configure_client_id(&credentials.client_id)
.with_context(|| {
format!(
"failed to reconfigure OAuth client for server {}",
self.inner.server_name
)
})?;
}
None => {
self.inner.runtime_credentials.clear().await?;
}
}
}
self.persist_if_needed().await
let mut last_credentials = self.inner.last_credentials.lock().await;
*last_credentials = credentials;
Ok(())
}
}
fn load_oauth_tokens_for_guarded_refresh(
server_name: &str,
url: &str,
store_mode: OAuthCredentialsStoreMode,
) -> GuardedRefreshPersistedCredentials {
let keyring_store = DefaultKeyringStore;
match store_mode {
OAuthCredentialsStoreMode::Auto => {
load_oauth_tokens_for_guarded_refresh_with_keyring_fallback(
&keyring_store,
server_name,
url,
)
}
OAuthCredentialsStoreMode::File => guarded_refresh_persisted_credentials_from_load_result(
load_oauth_tokens_from_file(server_name, url),
server_name,
),
OAuthCredentialsStoreMode::Keyring => {
guarded_refresh_persisted_credentials_from_load_result(
load_oauth_tokens_from_keyring(&keyring_store, server_name, url)
.with_context(|| "failed to read OAuth tokens from keyring".to_string()),
server_name,
)
}
}
}
fn load_oauth_tokens_when_cache_missing(
server_name: &str,
url: &str,
store_mode: OAuthCredentialsStoreMode,
) -> Option<StoredOAuthTokens> {
match load_oauth_tokens_for_guarded_refresh(server_name, url, store_mode) {
GuardedRefreshPersistedCredentials::Loaded(Some(credentials)) => Some(credentials),
GuardedRefreshPersistedCredentials::Loaded(None)
| GuardedRefreshPersistedCredentials::ReloadFailed => None,
}
}
fn load_oauth_tokens_for_guarded_refresh_with_keyring_fallback<K: KeyringStore>(
keyring_store: &K,
server_name: &str,
url: &str,
) -> GuardedRefreshPersistedCredentials {
match load_oauth_tokens_from_keyring(keyring_store, server_name, url) {
Ok(Some(tokens)) => GuardedRefreshPersistedCredentials::Loaded(Some(tokens)),
Ok(None) => guarded_refresh_persisted_credentials_from_load_result(
load_oauth_tokens_from_file(server_name, url),
server_name,
),
Err(error) => {
warn!("failed to read OAuth tokens from keyring: {error}");
match load_oauth_tokens_from_file(server_name, url) {
Ok(Some(tokens)) => GuardedRefreshPersistedCredentials::Loaded(Some(tokens)),
Ok(None) => {
warn!(
"failed to reload OAuth tokens for server {server_name}: keyring read failed and no fallback file credentials were available"
);
GuardedRefreshPersistedCredentials::ReloadFailed
}
Err(file_error) => {
warn!(
"failed to reload OAuth tokens for server {server_name}: keyring read failed ({error}) and fallback file reload failed: {file_error}"
);
GuardedRefreshPersistedCredentials::ReloadFailed
}
}
}
}
}
#[cfg(test)]
fn guarded_refresh_outcome_from_load_result(
cached_credentials: &StoredOAuthTokens,
persisted_credentials: Result<Option<StoredOAuthTokens>>,
server_name: &str,
) -> GuardedRefreshOutcome {
match guarded_refresh_persisted_credentials_from_load_result(persisted_credentials, server_name)
{
GuardedRefreshPersistedCredentials::Loaded(persisted_credentials) => {
determine_guarded_refresh_outcome(cached_credentials, persisted_credentials)
}
GuardedRefreshPersistedCredentials::ReloadFailed => GuardedRefreshOutcome::ReloadFailed,
}
}
fn guarded_refresh_persisted_credentials_from_load_result(
persisted_credentials: Result<Option<StoredOAuthTokens>>,
server_name: &str,
) -> GuardedRefreshPersistedCredentials {
match persisted_credentials {
Ok(credentials) => GuardedRefreshPersistedCredentials::Loaded(credentials),
Err(error) => {
warn!("failed to reload OAuth tokens for server {server_name}: {error}");
GuardedRefreshPersistedCredentials::ReloadFailed
}
}
}
@@ -521,6 +734,61 @@ fn token_needs_refresh(expires_at: Option<u64>) -> bool {
now.saturating_add(REFRESH_SKEW_MILLIS) >= expires_at
}
fn determine_guarded_refresh_outcome(
cached_credentials: &StoredOAuthTokens,
persisted_credentials: Option<StoredOAuthTokens>,
) -> GuardedRefreshOutcome {
match persisted_credentials {
Some(persisted_credentials)
if oauth_tokens_equal_for_refresh(
Some(cached_credentials),
Some(&persisted_credentials),
) =>
{
GuardedRefreshOutcome::ReloadedNoChange
}
Some(persisted_credentials) => {
GuardedRefreshOutcome::ReloadedChanged(persisted_credentials)
}
None => GuardedRefreshOutcome::MissingOrInvalid,
}
}
fn oauth_tokens_equal_for_refresh(
left: Option<&StoredOAuthTokens>,
right: Option<&StoredOAuthTokens>,
) -> bool {
match (left, right) {
(None, None) => true,
(Some(left), Some(right)) => {
left.server_name == right.server_name
&& left.url == right.url
&& left.client_id == right.client_id
&& left.expires_at == right.expires_at
&& oauth_token_responses_equal_for_refresh(
&left.token_response,
&right.token_response,
)
}
_ => false,
}
}
fn oauth_token_responses_equal_for_refresh(
left: &WrappedOAuthTokenResponse,
right: &WrappedOAuthTokenResponse,
) -> bool {
let left = &left.0;
let right = &right.0;
left.access_token().secret() == right.access_token().secret()
&& left.token_type() == right.token_type()
&& left.refresh_token().map(RefreshToken::secret)
== right.refresh_token().map(RefreshToken::secret)
&& left.scopes() == right.scopes()
&& left.extra_fields() == right.extra_fields()
}
fn compute_store_key(server_name: &str, server_url: &str) -> Result<String> {
let mut payload = JsonMap::new();
payload.insert(
@@ -855,6 +1123,158 @@ mod tests {
assert!(tokens.token_response.0.expires_in().is_none());
}
#[test]
fn guarded_refresh_outcome_reloads_when_persisted_credentials_changed() {
let cached = sample_tokens();
let mut persisted = sample_tokens();
persisted
.token_response
.0
.set_refresh_token(Some(RefreshToken::new("rotated-refresh-token".to_string())));
persisted
.token_response
.0
.set_expires_in(Some(&Duration::from_secs(7200)));
persisted.expires_at = super::compute_expires_at_millis(&persisted.token_response.0);
assert_eq!(
super::determine_guarded_refresh_outcome(&cached, Some(persisted.clone())),
super::GuardedRefreshOutcome::ReloadedChanged(persisted),
);
}
#[test]
fn guarded_refresh_outcome_refreshes_when_persisted_credentials_match() {
let cached = sample_tokens();
let mut persisted = cached.clone();
persisted
.token_response
.0
.set_expires_in(Some(&Duration::from_secs(5)));
assert_eq!(
super::determine_guarded_refresh_outcome(&cached, Some(persisted)),
super::GuardedRefreshOutcome::ReloadedNoChange,
);
}
#[test]
fn guarded_refresh_outcome_clears_when_persisted_credentials_are_missing() {
assert_eq!(
super::determine_guarded_refresh_outcome(&sample_tokens(), None),
super::GuardedRefreshOutcome::MissingOrInvalid,
);
}
#[test]
fn guarded_refresh_outcome_keeps_state_recoverable_when_reload_fails() {
let error = anyhow::anyhow!("transient read failure");
assert_eq!(
super::guarded_refresh_outcome_from_load_result(
&sample_tokens(),
Err(error),
"test-server",
),
super::GuardedRefreshOutcome::ReloadFailed,
);
}
#[test]
fn guarded_refresh_auto_load_keeps_state_recoverable_when_keyring_fails_without_file() {
let _env = TempCodexHome::new();
let store = MockKeyringStore::default();
let tokens = sample_tokens();
let key = super::compute_store_key(&tokens.server_name, &tokens.url)
.expect("store key should compute");
store.set_error(&key, KeyringError::Invalid("error".into(), "load".into()));
assert_eq!(
super::load_oauth_tokens_for_guarded_refresh_with_keyring_fallback(
&store,
&tokens.server_name,
&tokens.url,
),
super::GuardedRefreshPersistedCredentials::ReloadFailed,
);
}
#[test]
fn missing_cached_credentials_reload_shared_store_from_file() -> Result<()> {
let _env = TempCodexHome::new();
let tokens = sample_tokens();
let expected = tokens.clone();
super::save_oauth_tokens_to_file(&tokens)?;
let loaded = super::load_oauth_tokens_when_cache_missing(
&tokens.server_name,
&tokens.url,
OAuthCredentialsStoreMode::File,
)
.expect("tokens should reload from shared file store");
assert_tokens_match_without_expiry(&loaded, &expected);
Ok(())
}
#[test]
fn missing_cached_credentials_ignore_reload_failures() {
let _env = TempCodexHome::new();
let store = MockKeyringStore::default();
let tokens = sample_tokens();
let key = super::compute_store_key(&tokens.server_name, &tokens.url)
.expect("store key should compute");
store.set_error(&key, KeyringError::Invalid("error".into(), "load".into()));
assert_eq!(
super::load_oauth_tokens_for_guarded_refresh_with_keyring_fallback(
&store,
&tokens.server_name,
&tokens.url,
),
super::GuardedRefreshPersistedCredentials::ReloadFailed,
);
assert_eq!(
super::load_oauth_tokens_when_cache_missing(
&tokens.server_name,
&tokens.url,
OAuthCredentialsStoreMode::Auto,
),
None,
);
}
#[test]
fn oauth_tokens_equal_for_refresh_ignores_only_expires_in() {
let left = sample_tokens();
let mut right = left.clone();
right
.token_response
.0
.set_expires_in(Some(&Duration::from_secs(5)));
assert!(super::oauth_tokens_equal_for_refresh(
Some(&left),
Some(&right),
));
let mut different_refresh_token = right.clone();
different_refresh_token
.token_response
.0
.set_refresh_token(Some(RefreshToken::new("different-refresh".to_string())));
assert!(!super::oauth_tokens_equal_for_refresh(
Some(&left),
Some(&different_refresh_token),
));
let mut different_expiry = right;
different_expiry.expires_at = different_expiry.expires_at.map(|value| value + 1000);
assert!(!super::oauth_tokens_equal_for_refresh(
Some(&left),
Some(&different_expiry),
));
}
fn assert_tokens_match_without_expiry(
actual: &StoredOAuthTokens,
expected: &StoredOAuthTokens,

View File

@@ -39,7 +39,10 @@ use rmcp::service::{self};
use rmcp::transport::StreamableHttpClientTransport;
use rmcp::transport::auth::AuthClient;
use rmcp::transport::auth::AuthError;
use rmcp::transport::auth::OAuthState;
use rmcp::transport::auth::AuthorizationManager;
use rmcp::transport::auth::CredentialStore;
use rmcp::transport::auth::InMemoryCredentialStore;
use rmcp::transport::auth::StoredCredentials;
use rmcp::transport::child_process::TokioChildProcess;
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
use serde_json::Value;
@@ -358,6 +361,12 @@ impl RmcpClient {
}
};
if let Some(runtime) = &oauth_persistor
&& let Err(error) = runtime.refresh_if_needed().await
{
warn!("failed to refresh OAuth tokens before initialize: {error}");
}
let service = match timeout {
Some(duration) => time::timeout(duration, transport)
.await
@@ -595,22 +604,20 @@ async fn create_oauth_transport_and_runtime(
)> {
let http_client =
apply_default_headers(reqwest::Client::builder(), &default_headers).build()?;
let mut oauth_state = OAuthState::new(url.to_string(), Some(http_client.clone())).await?;
oauth_state
.set_credentials(
&initial_tokens.client_id,
initial_tokens.token_response.0.clone(),
)
let runtime_credentials = InMemoryCredentialStore::new();
runtime_credentials
.save(StoredCredentials {
client_id: initial_tokens.client_id.clone(),
token_response: Some(initial_tokens.token_response.0.clone()),
})
.await?;
let manager = match oauth_state {
OAuthState::Authorized(manager) => manager,
OAuthState::Unauthorized(manager) => manager,
OAuthState::Session(_) | OAuthState::AuthorizedHttpClient(_) => {
return Err(anyhow!("unexpected OAuth state during client setup"));
}
};
let mut manager = AuthorizationManager::new(url.to_string()).await?;
manager.set_credential_store(runtime_credentials.clone());
manager.with_client(http_client.clone())?;
let metadata = manager.discover_metadata().await?;
manager.set_metadata(metadata);
manager.configure_client_id(&initial_tokens.client_id)?;
let auth_client = AuthClient::new(http_client, manager);
let auth_manager = auth_client.auth_manager.clone();
@@ -624,6 +631,7 @@ async fn create_oauth_transport_and_runtime(
server_name.to_string(),
url.to_string(),
auth_manager,
runtime_credentials,
credentials_store,
Some(initial_tokens),
);