mirror of
https://github.com/openai/codex.git
synced 2026-05-04 11:26:33 +00:00
[rmcp-client] Recover from streamable HTTP 404 sessions (#13514)
## Summary - add one-time session recovery in `RmcpClient` for streamable HTTP MCP `404` session expiry - rebuild the transport and retry the failed operation once after reinitializing the client state - extend the test server and integration coverage for `404`, `401`, single-retry, and non-session failure scenarios ## Testing - just fmt - cargo test -p codex-rmcp-client (the post-rebase run lost its final summary in the terminal; the suite had passed earlier before the rebase) - just fix -p codex-rmcp-client
This commit is contained in:
@@ -6,7 +6,9 @@ use std::sync::Arc;
|
||||
|
||||
use axum::Router;
|
||||
use axum::body::Body;
|
||||
use axum::extract::Json;
|
||||
use axum::extract::State;
|
||||
use axum::http::Method;
|
||||
use axum::http::Request;
|
||||
use axum::http::StatusCode;
|
||||
use axum::http::header::AUTHORIZATION;
|
||||
@@ -15,6 +17,7 @@ use axum::middleware;
|
||||
use axum::middleware::Next;
|
||||
use axum::response::Response;
|
||||
use axum::routing::get;
|
||||
use axum::routing::post;
|
||||
use rmcp::ErrorData as McpError;
|
||||
use rmcp::handler::server::ServerHandler;
|
||||
use rmcp::model::CallToolRequestParams;
|
||||
@@ -39,6 +42,7 @@ use rmcp::transport::StreamableHttpService;
|
||||
use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task;
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -50,6 +54,8 @@ struct TestToolServer {
|
||||
|
||||
const MEMO_URI: &str = "memo://codex/example-note";
|
||||
const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server.";
|
||||
const MCP_SESSION_ID_HEADER: &str = "mcp-session-id";
|
||||
const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure";
|
||||
|
||||
impl TestToolServer {
|
||||
fn new() -> Self {
|
||||
@@ -116,6 +122,23 @@ impl TestToolServer {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct SessionFailureState {
|
||||
armed_failure: Arc<Mutex<Option<ArmedFailure>>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct ArmedFailure {
|
||||
status: StatusCode,
|
||||
remaining: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ArmSessionPostFailureRequest {
|
||||
status: u16,
|
||||
remaining: usize,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct EchoArgs {
|
||||
message: String,
|
||||
@@ -251,6 +274,7 @@ fn parse_bind_addr() -> Result<SocketAddr, Box<dyn std::error::Error>> {
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let bind_addr = parse_bind_addr()?;
|
||||
let session_failure_state = SessionFailureState::default();
|
||||
let listener = match tokio::net::TcpListener::bind(&bind_addr).await {
|
||||
Ok(listener) => listener,
|
||||
Err(err) if err.kind() == ErrorKind::PermissionDenied => {
|
||||
@@ -264,6 +288,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
eprintln!("starting rmcp streamable http test server on http://{bind_addr}/mcp");
|
||||
|
||||
let router = Router::new()
|
||||
.route(
|
||||
SESSION_POST_FAILURE_CONTROL_PATH,
|
||||
post(arm_session_post_failure),
|
||||
)
|
||||
.route(
|
||||
"/.well-known/oauth-authorization-server/mcp",
|
||||
get({
|
||||
@@ -291,7 +319,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
Arc::new(LocalSessionManager::default()),
|
||||
StreamableHttpServerConfig::default(),
|
||||
),
|
||||
);
|
||||
)
|
||||
.layer(middleware::from_fn_with_state(
|
||||
session_failure_state.clone(),
|
||||
fail_session_post_when_armed,
|
||||
))
|
||||
.with_state(session_failure_state);
|
||||
|
||||
let router = if let Ok(token) = std::env::var("MCP_EXPECT_BEARER") {
|
||||
let expected = Arc::new(format!("Bearer {token}"));
|
||||
@@ -323,3 +356,52 @@ async fn require_bearer(
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
}
|
||||
|
||||
async fn arm_session_post_failure(
|
||||
State(state): State<SessionFailureState>,
|
||||
Json(request): Json<ArmSessionPostFailureRequest>,
|
||||
) -> Result<StatusCode, StatusCode> {
|
||||
let status = StatusCode::from_u16(request.status).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let armed_failure = if request.remaining == 0 {
|
||||
None
|
||||
} else {
|
||||
Some(ArmedFailure {
|
||||
status,
|
||||
remaining: request.remaining,
|
||||
})
|
||||
};
|
||||
*state.armed_failure.lock().await = armed_failure;
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
async fn fail_session_post_when_armed(
|
||||
State(state): State<SessionFailureState>,
|
||||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
if request.uri().path() != "/mcp"
|
||||
|| request.method() != Method::POST
|
||||
|| !request.headers().contains_key(MCP_SESSION_ID_HEADER)
|
||||
{
|
||||
return next.run(request).await;
|
||||
}
|
||||
|
||||
let mut armed_failure = state.armed_failure.lock().await;
|
||||
if let Some(failure) = armed_failure.as_mut()
|
||||
&& failure.remaining > 0
|
||||
{
|
||||
failure.remaining -= 1;
|
||||
let status = failure.status;
|
||||
if failure.remaining == 0 {
|
||||
*armed_failure = None;
|
||||
}
|
||||
let mut response = Response::new(Body::from(format!(
|
||||
"forced session failure with status {status}"
|
||||
)));
|
||||
*response.status_mut() = status;
|
||||
return response;
|
||||
}
|
||||
|
||||
drop(armed_failure);
|
||||
next.run(request).await
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user