Compare commits

...

1 Commits

Author SHA1 Message Date
Sama Setty
9649a2eaee Retry streamable HTTP initialize failures 2026-05-29 16:01:18 -07:00
5 changed files with 468 additions and 70 deletions

View File

@@ -63,11 +63,13 @@ struct TestToolServer {
const MEMO_URI: &str = "memo://codex/example-note";
const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server.";
const MCP_SESSION_ID_HEADER: &str = "mcp-session-id";
const INITIALIZE_FAILURE_CONTROL_PATH: &str = "/test/control/initialize-failure";
const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure";
#[derive(Clone, Default)]
struct SessionFailureState {
armed_failure: Arc<Mutex<Option<ArmedFailure>>>,
struct FailureState {
initialize_failure: Arc<Mutex<Option<ArmedFailure>>>,
session_post_failure: Arc<Mutex<Option<ArmedFailure>>>,
}
#[derive(Clone, Debug)]
@@ -79,7 +81,7 @@ struct ArmedFailure {
}
#[derive(Debug, Deserialize)]
struct ArmSessionPostFailureRequest {
struct ArmFailureRequest {
status: u16,
remaining: usize,
/// Raw `WWW-Authenticate` challenge header field values to add to the failure.
@@ -97,7 +99,7 @@ struct EchoArgs {
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let bind_addr = parse_bind_addr()?;
let session_failure_state = SessionFailureState::default();
let failure_state = FailureState::default();
const MAX_BIND_RETRIES: u32 = 20;
const BIND_RETRY_DELAY: Duration = Duration::from_millis(50);
@@ -125,6 +127,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
eprintln!("starting rmcp streamable http test server on http://{actual_bind_addr}/mcp");
let router = Router::new()
.route(INITIALIZE_FAILURE_CONTROL_PATH, post(arm_initialize_failure))
.route(
SESSION_POST_FAILURE_CONTROL_PATH,
post(arm_session_post_failure),
@@ -162,10 +165,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
),
)
.layer(middleware::from_fn_with_state(
session_failure_state.clone(),
fail_session_post_when_armed,
failure_state.clone(),
fail_post_when_armed,
))
.with_state(session_failure_state);
.with_state(failure_state);
let router = if let Ok(token) = std::env::var("MCP_EXPECT_BEARER") {
let expected = Arc::new(format!("Bearer {token}"));
@@ -404,8 +407,22 @@ async fn require_bearer(
}
async fn arm_session_post_failure(
State(state): State<SessionFailureState>,
Json(request): Json<ArmSessionPostFailureRequest>,
State(state): State<FailureState>,
Json(request): Json<ArmFailureRequest>,
) -> Result<StatusCode, StatusCode> {
arm_failure(&state.session_post_failure, request).await
}
async fn arm_initialize_failure(
State(state): State<FailureState>,
Json(request): Json<ArmFailureRequest>,
) -> Result<StatusCode, StatusCode> {
arm_failure(&state.initialize_failure, request).await
}
async fn arm_failure(
armed_failure: &Arc<Mutex<Option<ArmedFailure>>>,
request: ArmFailureRequest,
) -> Result<StatusCode, StatusCode> {
let status = StatusCode::from_u16(request.status).map_err(|_| StatusCode::BAD_REQUEST)?;
let www_authenticate_headers = request
@@ -413,7 +430,7 @@ async fn arm_session_post_failure(
.into_iter()
.map(|value| HeaderValue::from_str(&value).map_err(|_| StatusCode::BAD_REQUEST))
.collect::<Result<Vec<_>, _>>()?;
let armed_failure = if request.remaining == 0 {
let failure = if request.remaining == 0 {
None
} else {
Some(ArmedFailure {
@@ -422,45 +439,56 @@ async fn arm_session_post_failure(
www_authenticate_headers,
})
};
*state.armed_failure.lock().await = armed_failure;
*armed_failure.lock().await = failure;
Ok(StatusCode::NO_CONTENT)
}
async fn fail_session_post_when_armed(
State(state): State<SessionFailureState>,
async fn fail_post_when_armed(
State(state): State<FailureState>,
request: Request<Body>,
next: Next,
) -> Response {
if request.uri().path() != "/mcp"
|| request.method() != Method::POST
|| !request.headers().contains_key(MCP_SESSION_ID_HEADER)
{
if request.uri().path() != "/mcp" || request.method() != Method::POST {
return next.run(request).await;
}
{
let mut armed_failure = state.armed_failure.lock().await;
if let Some(failure) = armed_failure.as_mut()
&& failure.remaining > 0
{
failure.remaining -= 1;
let status = failure.status;
let www_authenticate_headers = failure.www_authenticate_headers.clone();
if failure.remaining == 0 {
*armed_failure = None;
}
let mut response = Response::new(Body::from(format!(
"forced session failure with status {status}"
)));
*response.status_mut() = status;
for www_authenticate_header in www_authenticate_headers {
response
.headers_mut()
.append(WWW_AUTHENTICATE, www_authenticate_header);
}
return response;
}
let (armed_failure, label) = if request.headers().contains_key(MCP_SESSION_ID_HEADER) {
(&state.session_post_failure, "session")
} else {
(&state.initialize_failure, "initialize")
};
if let Some(response) = consume_failure(armed_failure, label).await {
return response;
}
next.run(request).await
}
async fn consume_failure(
armed_failure: &Arc<Mutex<Option<ArmedFailure>>>,
label: &str,
) -> Option<Response> {
let mut armed_failure = armed_failure.lock().await;
let failure = armed_failure.as_mut()?;
if failure.remaining == 0 {
return None;
}
failure.remaining -= 1;
let status = failure.status;
let www_authenticate_headers = failure.www_authenticate_headers.clone();
if failure.remaining == 0 {
*armed_failure = None;
}
let mut response = Response::new(Body::from(format!(
"forced {label} failure with status {status}"
)));
*response.status_mut() = status;
for www_authenticate_header in www_authenticate_headers {
response
.headers_mut()
.append(WWW_AUTHENTICATE, www_authenticate_header);
}
Some(response)
}

View File

@@ -58,6 +58,8 @@ pub(crate) struct StreamableHttpClientAdapter {
pub(crate) enum StreamableHttpClientAdapterError {
#[error("streamable HTTP session expired with 404 Not Found")]
SessionExpired404,
#[error("streamable HTTP request returned retryable HTTP {0}")]
RetryableHttpStatus(u16),
#[error(transparent)]
HttpRequest(#[from] ExecServerError),
#[error("invalid HTTP header: {0}")]
@@ -182,6 +184,11 @@ impl StreamableHttpClient for StreamableHttpClientAdapter {
) {
return Ok(StreamableHttpPostResponse::Accepted);
}
if is_retryable_http_status(response.status) {
return Err(StreamableHttpError::Client(
StreamableHttpClientAdapterError::RetryableHttpStatus(response.status),
));
}
let content_type = response_header(&response.headers, CONTENT_TYPE);
let session_id = response_header(&response.headers, HEADER_SESSION_ID);
@@ -463,6 +470,10 @@ fn status_is_success(status: u16) -> bool {
StatusCode::from_u16(status).is_ok_and(|status| status.is_success())
}
fn is_retryable_http_status(status: u16) -> bool {
matches!(status, 408 | 429 | 500 | 502 | 503 | 504)
}
async fn collect_body(
body_stream: &mut HttpResponseBodyStream,
) -> std::result::Result<Vec<u8>, StreamableHttpError<StreamableHttpClientAdapterError>> {

View File

@@ -14,6 +14,7 @@ use anyhow::anyhow;
use codex_api::SharedAuthProvider;
use codex_client::maybe_build_rustls_client_config_with_custom_ca;
use codex_config::types::McpServerEnvVar;
use codex_exec_server::ExecServerError;
use codex_exec_server::HttpClient;
use futures::FutureExt;
use futures::future::BoxFuture;
@@ -74,6 +75,8 @@ use crate::utils::apply_default_headers;
use crate::utils::build_default_headers;
use codex_config::types::OAuthCredentialsStoreMode;
const STREAMABLE_HTTP_RETRY_DELAYS_MS: [u64; 2] = [250, 1_000];
enum PendingTransport {
InProcess {
transport: tokio::io::DuplexStream,
@@ -90,6 +93,16 @@ enum PendingTransport {
},
}
impl PendingTransport {
fn is_streamable_http(&self) -> bool {
matches!(
self,
PendingTransport::StreamableHttp { .. }
| PendingTransport::StreamableHttpWithOAuth { .. }
)
}
}
enum ClientState {
Connecting {
transport: Option<PendingTransport>,
@@ -223,6 +236,13 @@ enum ClientOperationError {
Timeout { label: String, duration: Duration },
}
#[derive(Debug, thiserror::Error)]
#[error("handshaking with MCP server failed: {source}")]
struct HandshakeError {
#[source]
source: rmcp::service::ClientInitializeError,
}
pub type Elicitation = CreateElicitationRequestParams;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
@@ -396,9 +416,13 @@ impl RmcpClient {
}
};
let (service, oauth_persistor) =
Self::connect_pending_transport(pending_transport, client_service.clone(), timeout)
.await?;
let (service, oauth_persistor) = self
.connect_pending_transport_with_initialize_retries(
pending_transport,
client_service.clone(),
timeout,
)
.await?;
let initialize_result_rmcp = service
.peer()
@@ -849,15 +873,63 @@ impl RmcpClient {
Some(duration) => time::timeout(duration, transport)
.await
.map_err(|_| anyhow!("timed out handshaking with MCP server after {duration:?}"))?
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
.map_err(|source| HandshakeError { source })?,
None => transport
.await
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
.map_err(|source| HandshakeError { source })?,
};
Ok((Arc::new(service), oauth_persistor))
}
async fn connect_pending_transport_with_initialize_retries(
&self,
initial_transport: PendingTransport,
client_service: ElicitationClientService,
timeout: Option<Duration>,
) -> Result<(
Arc<RunningService<RoleClient, ElicitationClientService>>,
Option<OAuthPersistor>,
)> {
let should_retry = initial_transport.is_streamable_http();
let mut pending_transport = Some(initial_transport);
let retry_schedule = STREAMABLE_HTTP_RETRY_DELAYS_MS
.iter()
.copied()
.map(Some)
.chain(std::iter::once(None));
for (attempt, retry_delay_ms) in retry_schedule.enumerate() {
let transport = match pending_transport.take() {
Some(transport) => transport,
None => Self::create_pending_transport(&self.transport_recipe).await?,
};
match Self::connect_pending_transport(transport, client_service.clone(), timeout).await
{
Ok(result) => return Ok(result),
Err(error) if should_retry && Self::is_retryable_initialize_error(&error) => {
let Some(retry_delay_ms) = retry_delay_ms else {
return Err(error);
};
let delay = Duration::from_millis(retry_delay_ms);
warn!(
attempt = attempt + 1,
max_attempts = STREAMABLE_HTTP_RETRY_DELAYS_MS.len() + 1,
delay_ms = delay.as_millis(),
error = %error,
"streamable HTTP MCP initialize failed with a retryable error; retrying"
);
time::sleep(delay).await;
}
Err(error) => return Err(error),
}
}
unreachable!("initialize retry loop should return on success or final error")
}
async fn run_service_operation<T, F, Fut>(
&self,
label: &str,
@@ -868,31 +940,45 @@ impl RmcpClient {
F: Fn(Arc<RunningService<RoleClient, ElicitationClientService>>) -> Fut,
Fut: std::future::Future<Output = std::result::Result<T, rmcp::service::ServiceError>>,
{
let service = self.service().await?;
match Self::run_service_operation_once(
Arc::clone(&service),
label,
timeout,
self.elicitation_pause_state.clone(),
&operation,
)
.await
{
Ok(result) => Ok(result),
Err(error) if Self::is_session_expired_404(&error) => {
self.reinitialize_after_session_expiry(&service).await?;
let recovered_service = self.service().await?;
Self::run_service_operation_once(
recovered_service,
label,
timeout,
self.elicitation_pause_state.clone(),
&operation,
)
.await
.map_err(Into::into)
let mut session_recovery_attempted = false;
let mut retry_attempt = 0;
loop {
let service = self.service().await?;
match Self::run_service_operation_once(
Arc::clone(&service),
label,
timeout,
self.elicitation_pause_state.clone(),
&operation,
)
.await
{
Ok(result) => return Ok(result),
Err(error)
if !session_recovery_attempted && Self::is_session_expired_404(&error) =>
{
session_recovery_attempted = true;
self.reinitialize_after_session_expiry(&service).await?;
}
Err(error)
if Self::should_retry_tools_list_operation(label, retry_attempt, &error) =>
{
let delay =
Duration::from_millis(STREAMABLE_HTTP_RETRY_DELAYS_MS[retry_attempt]);
retry_attempt += 1;
warn!(
label,
attempt = retry_attempt,
max_attempts = STREAMABLE_HTTP_RETRY_DELAYS_MS.len() + 1,
delay_ms = delay.as_millis(),
error = %error,
"MCP service operation failed with a retryable error; retrying"
);
time::sleep(delay).await;
}
Err(error) => return Err(error.into()),
}
Err(error) => Err(error.into()),
}
}
@@ -941,6 +1027,68 @@ impl RmcpClient {
})
}
fn should_retry_tools_list_operation(
label: &str,
retry_attempt: usize,
error: &ClientOperationError,
) -> bool {
label == "tools/list"
&& retry_attempt < STREAMABLE_HTTP_RETRY_DELAYS_MS.len()
&& Self::is_retryable_service_operation_error(error)
}
fn is_retryable_service_operation_error(error: &ClientOperationError) -> bool {
let ClientOperationError::Service(rmcp::service::ServiceError::TransportSend(error)) =
error
else {
return false;
};
error
.error
.downcast_ref::<StreamableHttpError<StreamableHttpClientAdapterError>>()
.is_some_and(Self::is_retryable_streamable_http_error)
}
fn is_retryable_initialize_error(error: &anyhow::Error) -> bool {
error.chain().any(|source| {
source
.downcast_ref::<HandshakeError>()
.is_some_and(|error| Self::is_retryable_client_initialize_error(&error.source))
|| source
.downcast_ref::<rmcp::service::ClientInitializeError>()
.is_some_and(Self::is_retryable_client_initialize_error)
})
}
fn is_retryable_client_initialize_error(error: &rmcp::service::ClientInitializeError) -> bool {
match error {
rmcp::service::ClientInitializeError::TransportError { error, context }
if context.as_ref() == "send initialize request" =>
{
error
.error
.downcast_ref::<StreamableHttpError<StreamableHttpClientAdapterError>>()
.is_some_and(Self::is_retryable_streamable_http_error)
}
_ => false,
}
}
fn is_retryable_streamable_http_error(
error: &StreamableHttpError<StreamableHttpClientAdapterError>,
) -> bool {
matches!(
error,
StreamableHttpError::Client(
StreamableHttpClientAdapterError::RetryableHttpStatus(_)
| StreamableHttpClientAdapterError::HttpRequest(ExecServerError::HttpRequest(
_
))
)
)
}
async fn reinitialize_after_session_expiry(
&self,
failed_service: &Arc<RunningService<RoleClient, ElicitationClientService>>,

View File

@@ -1,13 +1,196 @@
mod streamable_http_test_support;
use pretty_assertions::assert_eq;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::time::Duration;
use codex_exec_server::Environment;
use codex_exec_server::ExecServerError;
use codex_exec_server::HttpClient;
use codex_exec_server::HttpRequestParams;
use codex_exec_server::HttpRequestResponse;
use codex_exec_server::HttpResponseBodyStream;
use futures::FutureExt as _;
use futures::future::BoxFuture;
use pretty_assertions::assert_eq;
use serde_json::Value;
use streamable_http_test_support::arm_initialize_failure;
use streamable_http_test_support::arm_session_post_failure;
use streamable_http_test_support::call_echo_tool;
use streamable_http_test_support::create_client;
use streamable_http_test_support::create_client_with_http_client;
use streamable_http_test_support::expected_echo_result;
use streamable_http_test_support::spawn_streamable_http_server;
#[derive(Clone)]
struct FailFirstMethodHttpClient {
inner: Arc<dyn HttpClient>,
method: &'static str,
failures_remaining: Arc<AtomicUsize>,
matching_post_attempts: Arc<AtomicUsize>,
}
impl FailFirstMethodHttpClient {
fn new(inner: Arc<dyn HttpClient>, method: &'static str) -> Self {
Self {
inner,
method,
failures_remaining: Arc::new(AtomicUsize::new(1)),
matching_post_attempts: Arc::new(AtomicUsize::new(0)),
}
}
fn matching_post_attempts(&self) -> usize {
self.matching_post_attempts.load(Ordering::SeqCst)
}
}
impl HttpClient for FailFirstMethodHttpClient {
fn http_request(
&self,
params: HttpRequestParams,
) -> BoxFuture<'_, Result<HttpRequestResponse, ExecServerError>> {
self.inner.http_request(params)
}
fn http_request_stream(
&self,
params: HttpRequestParams,
) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>> {
let inner = Arc::clone(&self.inner);
let method = self.method;
let failures_remaining = Arc::clone(&self.failures_remaining);
let matching_post_attempts = Arc::clone(&self.matching_post_attempts);
async move {
if is_json_rpc_method(&params, method) {
matching_post_attempts.fetch_add(1, Ordering::SeqCst);
if failures_remaining
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |remaining| {
remaining.checked_sub(1)
})
.is_ok()
{
return Err(ExecServerError::HttpRequest(
"http/request failed: error sending request for url (simulated no response)"
.to_string(),
));
}
}
inner.http_request_stream(params).await
}
.boxed()
}
}
fn is_json_rpc_method(params: &HttpRequestParams, method: &str) -> bool {
if !params.method.eq_ignore_ascii_case("POST") {
return false;
}
params
.body
.as_ref()
.and_then(|body| serde_json::from_slice::<Value>(&body.0).ok())
.and_then(|body| {
body.get("method")
.and_then(Value::as_str)
.map(str::to_string)
})
.is_some_and(|request_method| request_method == method)
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn streamable_http_initialize_retries_retryable_status() -> anyhow::Result<()> {
let (_server, base_url) = spawn_streamable_http_server().await?;
arm_initialize_failure(&base_url, /*status*/ 503, /*remaining*/ 1).await?;
let client = create_client(&base_url).await?;
let result = call_echo_tool(&client, "after-init-retry").await?;
assert_eq!(result, expected_echo_result("after-init-retry"));
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn streamable_http_initialize_retries_http_request_error() -> anyhow::Result<()> {
let (_server, base_url) = spawn_streamable_http_server().await?;
let http_client = FailFirstMethodHttpClient::new(
Environment::default_for_tests().get_http_client(),
"initialize",
);
let client = create_client_with_http_client(&base_url, Arc::new(http_client.clone())).await?;
let result = call_echo_tool(&client, "after-no-response-retry").await?;
assert_eq!(http_client.matching_post_attempts(), 2);
assert_eq!(result, expected_echo_result("after-no-response-retry"));
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn streamable_http_tools_list_retries_retryable_status() -> anyhow::Result<()> {
let (_server, base_url) = spawn_streamable_http_server().await?;
let client = create_client(&base_url).await?;
arm_session_post_failure(
&base_url,
/*status*/ 503,
/*remaining*/ 1,
/*www_authenticate_headers*/ &[],
)
.await?;
let tools = client
.list_tools(/*params*/ None, Some(Duration::from_secs(5)))
.await?;
assert_eq!(tools.tools.len(), 1);
assert_eq!(tools.tools[0].name, "echo");
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn streamable_http_tools_list_retries_http_request_error() -> anyhow::Result<()> {
let (_server, base_url) = spawn_streamable_http_server().await?;
let http_client = FailFirstMethodHttpClient::new(
Environment::default_for_tests().get_http_client(),
"tools/list",
);
let client = create_client_with_http_client(&base_url, Arc::new(http_client.clone())).await?;
let tools = client
.list_tools(/*params*/ None, Some(Duration::from_secs(5)))
.await?;
assert_eq!(http_client.matching_post_attempts(), 2);
assert_eq!(tools.tools.len(), 1);
assert_eq!(tools.tools[0].name, "echo");
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn streamable_http_initialize_does_not_retry_non_retryable_status() -> anyhow::Result<()> {
let (_server, base_url) = spawn_streamable_http_server().await?;
arm_initialize_failure(&base_url, /*status*/ 403, /*remaining*/ 1).await?;
let error = match create_client(&base_url).await {
Ok(_) => panic!("initialize unexpectedly succeeded after non-retryable HTTP 403"),
Err(error) => error,
};
assert!(format!("{error:#}").contains("403"));
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn streamable_http_404_session_expiry_recovers_and_retries_once() -> anyhow::Result<()> {
let (_server, base_url) = spawn_streamable_http_server().await?;

View File

@@ -20,6 +20,7 @@ use anyhow::Context as _;
use codex_config::types::OAuthCredentialsStoreMode;
use codex_exec_server::Environment;
use codex_exec_server::ExecServerClient;
use codex_exec_server::HttpClient;
use codex_exec_server::RemoteExecServerConnectArgs;
use codex_rmcp_client::ElicitationAction;
use codex_rmcp_client::ElicitationResponse;
@@ -43,6 +44,7 @@ use tokio::process::Child;
use tokio::process::Command;
use tokio::time::sleep;
const INITIALIZE_FAILURE_CONTROL_PATH: &str = "/test/control/initialize-failure";
const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure";
fn streamable_http_server_bin() -> Result<PathBuf, CargoBinError> {
@@ -74,6 +76,14 @@ pub(crate) fn expected_echo_result(message: &str) -> CallToolResult {
}
pub(crate) async fn create_client(base_url: &str) -> anyhow::Result<RmcpClient> {
create_client_with_http_client(base_url, Environment::default_for_tests().get_http_client())
.await
}
pub(crate) async fn create_client_with_http_client(
base_url: &str,
http_client: Arc<dyn HttpClient>,
) -> anyhow::Result<RmcpClient> {
let client = RmcpClient::new_streamable_http_client(
"test-streamable-http",
&format!("{base_url}/mcp"),
@@ -81,7 +91,7 @@ pub(crate) async fn create_client(base_url: &str) -> anyhow::Result<RmcpClient>
/*http_headers*/ None,
/*env_http_headers*/ None,
OAuthCredentialsStoreMode::File,
Environment::default_for_tests().get_http_client(),
http_client,
/*auth_provider*/ None,
)
.await?;
@@ -178,6 +188,24 @@ pub(crate) async fn arm_session_post_failure(
Ok(())
}
pub(crate) async fn arm_initialize_failure(
base_url: &str,
status: u16,
remaining: usize,
) -> anyhow::Result<()> {
let response = reqwest::Client::new()
.post(format!("{base_url}{INITIALIZE_FAILURE_CONTROL_PATH}"))
.json(&json!({
"status": status,
"remaining": remaining,
}))
.send()
.await?;
assert_eq!(response.status(), reqwest::StatusCode::NO_CONTENT);
Ok(())
}
pub(crate) async fn spawn_streamable_http_server() -> anyhow::Result<(Child, String)> {
let listener = TcpListener::bind("127.0.0.1:0")?;
let port = listener.local_addr()?.port();