[rmcp-client] Recover from streamable HTTP 404 sessions (#13514)

## Summary
- add one-time session recovery in `RmcpClient` for streamable HTTP MCP
`404` session expiry
- rebuild the transport and retry the failed operation once after
reinitializing the client state
- extend the test server and integration coverage for `404`, `401`,
single-retry, and non-session failure scenarios

## Testing
- just fmt
- cargo test -p codex-rmcp-client (the post-rebase run lost its final
summary in the terminal; the suite had passed before the rebase)
- just fix -p codex-rmcp-client
This commit is contained in:
Casey Chow
2026-03-06 10:02:42 -05:00
committed by GitHub
parent 5d4303510c
commit b3765a07e8
6 changed files with 1046 additions and 213 deletions

View File

@@ -6,7 +6,9 @@ use std::sync::Arc;
use axum::Router;
use axum::body::Body;
use axum::extract::Json;
use axum::extract::State;
use axum::http::Method;
use axum::http::Request;
use axum::http::StatusCode;
use axum::http::header::AUTHORIZATION;
@@ -15,6 +17,7 @@ use axum::middleware;
use axum::middleware::Next;
use axum::response::Response;
use axum::routing::get;
use axum::routing::post;
use rmcp::ErrorData as McpError;
use rmcp::handler::server::ServerHandler;
use rmcp::model::CallToolRequestParams;
@@ -39,6 +42,7 @@ use rmcp::transport::StreamableHttpService;
use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
use serde::Deserialize;
use serde_json::json;
use tokio::sync::Mutex;
use tokio::task;
#[derive(Clone)]
@@ -50,6 +54,8 @@ struct TestToolServer {
// URI of the sample memo resource.
// NOTE(review): presumably the URI under which the test server exposes the
// memo below; its use sites are outside this view, so confirm.
const MEMO_URI: &str = "memo://codex/example-note";
// Text body of the sample memo resource.
const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server.";
// Request header whose presence marks a POST as belonging to an existing
// session; the failure-injection middleware only considers requests carrying it.
const MCP_SESSION_ID_HEADER: &str = "mcp-session-id";
// Path of the test-only control endpoint that arms (or disarms) forced
// failures for session POSTs.
const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure";
impl TestToolServer {
fn new() -> Self {
@@ -116,6 +122,23 @@ impl TestToolServer {
}
}
/// Shared state for the test-only failure injection.
///
/// Clones share the same slot through the `Arc`, so the control handler and
/// the middleware observe the same armed failure.
#[derive(Clone, Default)]
struct SessionFailureState {
    /// The currently armed failure, or `None` when nothing is armed.
    armed_failure: Arc<Mutex<Option<ArmedFailure>>>,
}
/// A forced failure waiting to be applied to upcoming session POSTs.
#[derive(Clone, Debug)]
struct ArmedFailure {
    /// HTTP status returned in place of the real response.
    status: StatusCode,
    /// How many more matching requests should fail; the middleware
    /// decrements this and clears the armed failure once it reaches zero.
    remaining: usize,
}
/// JSON body accepted by the failure-arming control endpoint.
#[derive(Debug, Deserialize)]
struct ArmSessionPostFailureRequest {
    /// Numeric HTTP status to force; validated with `StatusCode::from_u16`
    /// by the handler.
    status: u16,
    /// Number of session POSTs to fail; `0` disarms any pending failure.
    remaining: usize,
}
#[derive(Deserialize)]
struct EchoArgs {
message: String,
@@ -251,6 +274,7 @@ fn parse_bind_addr() -> Result<SocketAddr, Box<dyn std::error::Error>> {
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let bind_addr = parse_bind_addr()?;
let session_failure_state = SessionFailureState::default();
let listener = match tokio::net::TcpListener::bind(&bind_addr).await {
Ok(listener) => listener,
Err(err) if err.kind() == ErrorKind::PermissionDenied => {
@@ -264,6 +288,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
eprintln!("starting rmcp streamable http test server on http://{bind_addr}/mcp");
let router = Router::new()
.route(
SESSION_POST_FAILURE_CONTROL_PATH,
post(arm_session_post_failure),
)
.route(
"/.well-known/oauth-authorization-server/mcp",
get({
@@ -291,7 +319,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
Arc::new(LocalSessionManager::default()),
StreamableHttpServerConfig::default(),
),
);
)
.layer(middleware::from_fn_with_state(
session_failure_state.clone(),
fail_session_post_when_armed,
))
.with_state(session_failure_state);
let router = if let Ok(token) = std::env::var("MCP_EXPECT_BEARER") {
let expected = Arc::new(format!("Bearer {token}"));
@@ -323,3 +356,52 @@ async fn require_bearer(
Err(StatusCode::UNAUTHORIZED)
}
}
async fn arm_session_post_failure(
State(state): State<SessionFailureState>,
Json(request): Json<ArmSessionPostFailureRequest>,
) -> Result<StatusCode, StatusCode> {
let status = StatusCode::from_u16(request.status).map_err(|_| StatusCode::BAD_REQUEST)?;
let armed_failure = if request.remaining == 0 {
None
} else {
Some(ArmedFailure {
status,
remaining: request.remaining,
})
};
*state.armed_failure.lock().await = armed_failure;
Ok(StatusCode::NO_CONTENT)
}
/// Middleware that short-circuits session POSTs to `/mcp` while a failure is armed.
///
/// Only `POST /mcp` requests carrying the session-id header are candidates;
/// everything else is forwarded untouched. For a candidate, one unit of the
/// armed failure is consumed and a response with the armed status is returned
/// instead of calling the inner service. When the counter hits zero the armed
/// failure is cleared.
async fn fail_session_post_when_armed(
    State(state): State<SessionFailureState>,
    request: Request<Body>,
    next: Next,
) -> Response {
    let is_session_post = request.uri().path() == "/mcp"
        && request.method() == Method::POST
        && request.headers().contains_key(MCP_SESSION_ID_HEADER);
    if !is_session_post {
        return next.run(request).await;
    }

    // Scope the guard so the lock is released before the inner service runs.
    let forced_status = {
        let mut slot = state.armed_failure.lock().await;
        match slot.as_mut() {
            Some(pending) if pending.remaining > 0 => {
                pending.remaining -= 1;
                let status = pending.status;
                if pending.remaining == 0 {
                    *slot = None;
                }
                Some(status)
            }
            _ => None,
        }
    };

    match forced_status {
        Some(status) => {
            let body = Body::from(format!("forced session failure with status {status}"));
            let mut response = Response::new(body);
            *response.status_mut() = status;
            response
        }
        None => next.run(request).await,
    }
}