Compare commits

...

2 Commits

Author SHA1 Message Date
nicholasclark-openai
e9e5bea81b Propagate RMCP trace context
Add traceparent/tracestate propagation for RMCP streamable HTTP and stdio calls, and cover it with transport-level tests.

Co-authored-by: Codex <noreply@openai.com>
2026-03-25 15:19:26 -07:00
nicholasclark-openai
bf5f89b535 Add Responses and HTTP child spans
Layer the Responses-specific request spans and HTTP transport child spans on top of the MCP tracing branch, and add focused tracing assertions for the Responses path.

Co-authored-by: Codex <noreply@openai.com>
2026-03-25 15:19:23 -07:00
13 changed files with 640 additions and 84 deletions

5
codex-rs/Cargo.lock generated
View File

@@ -2514,6 +2514,7 @@ dependencies = [
"axum",
"codex-client",
"codex-keyring-store",
"codex-otel",
"codex-protocol",
"codex-utils-cargo-bin",
"codex-utils-home-dir",
@@ -2521,6 +2522,8 @@ dependencies = [
"futures",
"keyring",
"oauth2",
"opentelemetry",
"opentelemetry_sdk",
"pretty_assertions",
"reqwest",
"rmcp",
@@ -2535,6 +2538,8 @@ dependencies = [
"tiny_http",
"tokio",
"tracing",
"tracing-opentelemetry",
"tracing-subscriber",
"urlencoding",
"webbrowser",
"which 8.0.0",

View File

@@ -10,7 +10,9 @@ use reqwest::Response;
use serde::Serialize;
use std::fmt::Display;
use std::time::Duration;
use tracing::Instrument;
use tracing::Span;
use tracing::field::Empty;
use tracing_opentelemetry::OpenTelemetrySpanExt;
#[derive(Clone, Debug)]
@@ -111,10 +113,39 @@ impl CodexRequestBuilder {
}
pub async fn send(self) -> Result<Response, reqwest::Error> {
let headers = trace_headers();
let parsed_url = reqwest::Url::parse(&self.url).ok();
let path = parsed_url
.as_ref()
.map(|url| url.path().to_string())
.unwrap_or_else(|| self.url.clone());
let request_span = tracing::info_span!(
"http.client",
otel.kind = "client",
http.request.method = %self.method,
http.response.status_code = Empty,
url.path = %path,
server.address = Empty,
server.port = Empty,
);
if let Some(url) = parsed_url.as_ref() {
if let Some(host) = url.host_str() {
request_span.record("server.address", host);
}
if let Some(port) = url.port_or_known_default() {
request_span.record("server.port", port as i64);
}
}
let headers = trace_headers_for_span(&request_span);
match self.builder.headers(headers).send().await {
match async { self.builder.headers(headers).send().await }
.instrument(request_span.clone())
.await
{
Ok(response) => {
request_span.record(
"http.response.status_code",
response.status().as_u16() as i64,
);
tracing::debug!(
method = %self.method,
url = %self.url,
@@ -127,11 +158,14 @@ impl CodexRequestBuilder {
Ok(response)
}
Err(error) => {
let status = error.status();
let status = error.status().map(|status| status.as_u16() as i64);
if let Some(status) = status {
request_span.record("http.response.status_code", status);
}
tracing::debug!(
method = %self.method,
url = %self.url,
status = status.map(|s| s.as_u16()),
status,
error = %error,
"Request failed"
);
@@ -154,13 +188,15 @@ impl<'a> Injector for HeaderMapInjector<'a> {
}
}
#[cfg(test)]
fn trace_headers() -> HeaderMap {
trace_headers_for_span(&Span::current())
}
fn trace_headers_for_span(span: &Span) -> HeaderMap {
let mut headers = HeaderMap::new();
global::get_text_map_propagator(|prop| {
prop.inject_context(
&Span::current().context(),
&mut HeaderMapInjector(&mut headers),
);
prop.inject_context(&span.context(), &mut HeaderMapInjector(&mut headers));
});
headers
}
@@ -204,6 +240,33 @@ mod tests {
assert_eq!(extracted_context.span_id(), span_context.span_id());
}
#[test]
fn inject_trace_headers_for_span_uses_explicit_span_context() {
global::set_text_map_propagator(TraceContextPropagator::new());
let provider = SdkTracerProvider::builder().build();
let tracer = provider.tracer("test-tracer");
let subscriber =
tracing_subscriber::registry().with(tracing_opentelemetry::layer().with_tracer(tracer));
let _guard = subscriber.set_default();
let parent = trace_span!("parent");
let _parent_entered = parent.enter();
let child = trace_span!("child");
let child_context = child.context().span().span_context().clone();
let headers = trace_headers_for_span(&child);
let extractor = HeaderMapExtractor(&headers);
let extracted = TraceContextPropagator::new().extract(&extractor);
let extracted_span = extracted.span();
let extracted_context = extracted_span.span_context();
assert!(extracted_context.is_valid());
assert_eq!(extracted_context.trace_id(), child_context.trace_id());
assert_eq!(extracted_context.span_id(), child_context.span_id());
}
struct HeaderMapExtractor<'a>(&'a HeaderMap);
impl<'a> Extractor for HeaderMapExtractor<'a> {

View File

@@ -86,6 +86,7 @@ use tokio::sync::oneshot;
use tokio::sync::oneshot::error::TryRecvError;
use tokio_tungstenite::tungstenite::Error;
use tokio_tungstenite::tungstenite::Message;
use tracing::Instrument;
use tracing::instrument;
use tracing::trace;
use tracing::warn;
@@ -1023,6 +1024,7 @@ impl ModelClientSession {
let mut pending_retry = PendingUnauthorizedRetry::default();
loop {
let client_setup = self.client.current_client_setup().await?;
let api_base_url = client_setup.api_provider.base_url.clone();
let transport = ReqwestTransport::new(build_reqwest_client());
let request_auth_context = AuthRequestTelemetryContext::new(
client_setup.auth.as_ref().map(CodexAuth::auth_mode),
@@ -1052,10 +1054,20 @@ impl ModelClientSession {
client_setup.api_auth,
)
.with_telemetry(Some(request_telemetry), Some(sse_telemetry));
let stream_result = client.stream_request(request, options).await;
let request_span = crate::network_trace::responses_http_request_span(
&self.client.state.conversation_id,
turn_metadata_header,
&self.client.state.provider.name,
&model_info.slug,
&api_base_url,
);
let stream_result = async { client.stream_request(request, options).await }
.instrument(request_span.clone())
.await;
match stream_result {
Ok(stream) => {
let _entered = request_span.enter();
let (stream, _) = map_response_stream(stream, session_telemetry.clone());
return Ok(stream);
}
@@ -1414,72 +1426,76 @@ where
{
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
let (tx_last_response, rx_last_response) = oneshot::channel::<LastResponse>();
let current_span = tracing::Span::current();
tokio::spawn(async move {
let mut logged_error = false;
let mut tx_last_response = Some(tx_last_response);
let mut items_added: Vec<ResponseItem> = Vec::new();
let mut api_stream = api_stream;
while let Some(event) = api_stream.next().await {
match event {
Ok(ResponseEvent::OutputItemDone(item)) => {
items_added.push(item.clone());
if tx_event
.send(Ok(ResponseEvent::OutputItemDone(item)))
.await
.is_err()
{
return;
tokio::spawn(
async move {
let mut logged_error = false;
let mut tx_last_response = Some(tx_last_response);
let mut items_added: Vec<ResponseItem> = Vec::new();
let mut api_stream = api_stream;
while let Some(event) = api_stream.next().await {
match event {
Ok(ResponseEvent::OutputItemDone(item)) => {
items_added.push(item.clone());
if tx_event
.send(Ok(ResponseEvent::OutputItemDone(item)))
.await
.is_err()
{
return;
}
}
}
Ok(ResponseEvent::Completed {
response_id,
token_usage,
}) => {
if let Some(usage) = &token_usage {
session_telemetry.sse_event_completed(
usage.input_tokens,
usage.output_tokens,
Some(usage.cached_input_tokens),
Some(usage.reasoning_output_tokens),
usage.total_tokens,
);
Ok(ResponseEvent::Completed {
response_id,
token_usage,
}) => {
if let Some(usage) = &token_usage {
session_telemetry.sse_event_completed(
usage.input_tokens,
usage.output_tokens,
Some(usage.cached_input_tokens),
Some(usage.reasoning_output_tokens),
usage.total_tokens,
);
}
if let Some(sender) = tx_last_response.take() {
let _ = sender.send(LastResponse {
response_id: response_id.clone(),
items_added: std::mem::take(&mut items_added),
});
}
if tx_event
.send(Ok(ResponseEvent::Completed {
response_id,
token_usage,
}))
.await
.is_err()
{
return;
}
}
if let Some(sender) = tx_last_response.take() {
let _ = sender.send(LastResponse {
response_id: response_id.clone(),
items_added: std::mem::take(&mut items_added),
});
Ok(event) => {
if tx_event.send(Ok(event)).await.is_err() {
return;
}
}
if tx_event
.send(Ok(ResponseEvent::Completed {
response_id,
token_usage,
}))
.await
.is_err()
{
return;
}
}
Ok(event) => {
if tx_event.send(Ok(event)).await.is_err() {
return;
}
}
Err(err) => {
let mapped = map_api_error(err);
if !logged_error {
session_telemetry.see_event_completed_failed(&mapped);
logged_error = true;
}
if tx_event.send(Err(mapped)).await.is_err() {
return;
Err(err) => {
let mapped = map_api_error(err);
if !logged_error {
session_telemetry.see_event_completed_failed(&mapped);
logged_error = true;
}
if tx_event.send(Err(mapped)).await.is_err() {
return;
}
}
}
}
}
});
.instrument(current_span),
);
(ResponseStream { rx_event }, rx_last_response)
}

View File

@@ -52,6 +52,7 @@ mod mcp_tool_approval_templates;
pub mod models_manager;
mod network_policy_decision;
pub mod network_proxy_loader;
mod network_trace;
mod original_image_detail;
mod packages;
pub use mcp_connection_manager::MCP_SANDBOX_STATE_CAPABILITY;

View File

@@ -0,0 +1,85 @@
use codex_protocol::ThreadId;
use serde::Deserialize;
use tracing::Span;
use tracing::field::Empty;
use url::Url;
#[derive(Debug, Default, Deserialize)]
struct TurnTraceCorrelation {
session_id: Option<String>,
turn_id: Option<String>,
}
struct CorrelationFields {
conversation_id: String,
session_id: String,
turn_id: Option<String>,
}
impl CorrelationFields {
fn from_turn_metadata_header(
conversation_id: &ThreadId,
turn_metadata_header: Option<&str>,
) -> Self {
let conversation_id = conversation_id.to_string();
let correlation = turn_metadata_header
.and_then(|header| serde_json::from_str::<TurnTraceCorrelation>(header).ok())
.unwrap_or_default();
let session_id = correlation
.session_id
.unwrap_or_else(|| conversation_id.clone());
Self {
conversation_id,
session_id,
turn_id: correlation.turn_id,
}
}
fn record_on(&self, span: &Span) {
span.record("conversation.id", self.conversation_id.as_str());
span.record("session.id", self.session_id.as_str());
if let Some(turn_id) = self.turn_id.as_deref() {
span.record("turn.id", turn_id);
}
}
}
fn record_server_fields(span: &Span, url: Option<&str>) {
let Some(url) = url else {
return;
};
let Ok(parsed) = Url::parse(url) else {
return;
};
if let Some(host) = parsed.host_str() {
span.record("server.address", host);
}
if let Some(port) = parsed.port_or_known_default() {
span.record("server.port", port as i64);
}
}
pub(crate) fn responses_http_request_span(
conversation_id: &ThreadId,
turn_metadata_header: Option<&str>,
provider_name: &str,
model: &str,
base_url: &str,
) -> Span {
let span = tracing::info_span!(
"responses_http.request",
otel.kind = "client",
provider = provider_name,
model,
transport = "responses_http",
api.path = "responses",
conversation.id = Empty,
session.id = Empty,
turn.id = Empty,
server.address = Empty,
server.port = Empty,
);
CorrelationFields::from_turn_metadata_header(conversation_id, turn_metadata_header)
.record_on(&span);
record_server_fields(&span, Some(base_url));
span
}

View File

@@ -197,11 +197,12 @@ impl Session {
let ctx = Arc::clone(&turn_context);
let task_for_run = Arc::clone(&task);
let task_cancellation_token = cancellation_token.child_token();
// Task-owned turn spans keep a core-owned span open for the
// full task lifecycle after the submission dispatch span ends.
// Task-owned turn spans keep a core-owned parent span open for the
// full turn lifecycle after the submission dispatch span ends.
let task_span = info_span!(
"turn",
otel.name = span_name,
conversation.id = %self.conversation_id,
thread.id = %self.conversation_id,
turn.id = %turn_context.sub_id,
model = %turn_context.model_info.slug,

View File

@@ -718,6 +718,57 @@ async fn record_responses_sets_span_fields_for_response_events() {
}
}
#[tokio::test(flavor = "current_thread")]
#[traced_test]
async fn responses_request_span_records_turn_correlation_fields() {
let server = start_mock_server().await;
mount_sse_once(
&server,
sse(vec![ev_response_created("resp-1"), ev_completed("resp-1")]),
)
.await;
let TestCodex { codex, .. } = test_codex().build(&server).await.unwrap();
codex
.submit(Op::UserInput {
items: vec![UserInput::Text {
text: "hello".into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
})
.await
.unwrap();
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await;
logs_assert(|lines: &[&str]| {
lines
.iter()
.find(|line| {
line.contains("turn{otel.name=\"session_task.turn\"")
&& line.contains("conversation.id=")
&& line.contains("thread.id=")
&& line.contains("turn.id=")
&& line.contains("model=")
&& line.contains("responses_http.request{")
&& line.contains("otel.kind=\"client\"")
&& line.contains("transport=\"responses_http\"")
&& line.contains("conversation.id=")
&& line.contains("session.id=")
&& line.contains("turn.id=")
})
.map(|_| Ok(()))
.unwrap_or_else(|| {
Err(
"missing responses_http.request span nested under session_task.turn"
.to_string(),
)
})
});
}
#[tokio::test]
#[traced_test]
async fn handle_response_item_records_tool_result_for_custom_tool_call() {

View File

@@ -36,6 +36,7 @@ use core_test_support::responses::mount_sse_once;
use core_test_support::skip_if_no_network;
use core_test_support::stdio_server_bin;
use core_test_support::test_codex::test_codex;
use core_test_support::tracing::install_test_tracing;
use core_test_support::wait_for_event;
use core_test_support::wait_for_event_with_timeout;
use reqwest::Client;
@@ -687,6 +688,8 @@ async fn stdio_server_propagates_whitelisted_env_vars() -> anyhow::Result<()> {
async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> {
skip_if_no_network!(Ok(()));
let _trace_test_context = install_test_tracing("rmcp-integration-tests");
let server = responses::start_mock_server().await;
let call_id = "call-456";
@@ -733,6 +736,7 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> {
.kill_on_drop(true)
.env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr)
.env("MCP_TEST_VALUE", expected_env_value)
.env("MCP_EXPECT_TRACEPARENT", "1")
.spawn()?;
wait_for_streamable_http_server(&mut http_server_child, &bind_addr, Duration::from_secs(5))

View File

@@ -15,6 +15,7 @@ axum = { workspace = true, default-features = false, features = [
] }
codex-client = { workspace = true }
codex-keyring-store = { workspace = true }
codex-otel = { workspace = true }
codex-protocol = { workspace = true }
codex-utils-pty = { workspace = true }
codex-utils-home-dir = { workspace = true }
@@ -60,9 +61,13 @@ which = { workspace = true }
[dev-dependencies]
codex-utils-cargo-bin = { workspace = true }
opentelemetry = { workspace = true }
opentelemetry_sdk = { workspace = true }
pretty_assertions = { workspace = true }
serial_test = { workspace = true }
tempfile = { workspace = true }
tracing-opentelemetry = { workspace = true }
tracing-subscriber = { workspace = true }
[target.'cfg(target_os = "linux")'.dependencies]
keyring = { workspace = true, features = ["linux-native-async-persistent"] }

View File

@@ -315,7 +315,7 @@ impl ServerHandler for TestToolServer {
async fn call_tool(
&self,
request: CallToolRequestParams,
_context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
) -> Result<CallToolResult, McpError> {
match request.name.as_ref() {
"echo" | "echo-tool" => {
@@ -333,9 +333,19 @@ impl ServerHandler for TestToolServer {
};
let env_snapshot: HashMap<String, String> = std::env::vars().collect();
let traceparent = context
.meta
.get("x-codex-traceparent")
.and_then(serde_json::Value::as_str);
let tracestate = context
.meta
.get("x-codex-tracestate")
.and_then(serde_json::Value::as_str);
let structured_content = json!({
"echo": format!("ECHOING: {}", args.message),
"env": env_snapshot.get("MCP_TEST_VALUE"),
"traceparent": traceparent,
"tracestate": tracestate,
});
Ok(CallToolResult {

View File

@@ -58,6 +58,7 @@ struct TestToolServer {
const MEMO_URI: &str = "memo://codex/example-note";
const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server.";
const MCP_SESSION_ID_HEADER: &str = "mcp-session-id";
const TRACEPARENT_HEADER: &str = "traceparent";
const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure";
impl TestToolServer {
@@ -347,6 +348,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
} else {
router
};
let router = if std::env::var("MCP_EXPECT_TRACEPARENT").is_ok() {
router.layer(middleware::from_fn(require_traceparent_on_session_post))
} else {
router
};
axum::serve(listener, router).await?;
task::yield_now().await;
@@ -389,6 +395,23 @@ async fn arm_session_post_failure(
Ok(StatusCode::NO_CONTENT)
}
async fn require_traceparent_on_session_post(request: Request<Body>, next: Next) -> Response {
if request.uri().path() != "/mcp"
|| request.method() != Method::POST
|| !request.headers().contains_key(MCP_SESSION_ID_HEADER)
{
return next.run(request).await;
}
if request.headers().contains_key(TRACEPARENT_HEADER) {
next.run(request).await
} else {
let mut response = Response::new(Body::from("missing traceparent header"));
*response.status_mut() = StatusCode::BAD_REQUEST;
response
}
}
async fn fail_session_post_when_armed(
State(state): State<SessionFailureState>,
request: Request<Body>,

View File

@@ -10,6 +10,9 @@ use std::time::Duration;
use anyhow::Result;
use anyhow::anyhow;
use codex_client::build_reqwest_client_with_custom_ca;
use codex_otel::current_span_w3c_trace_context;
use codex_otel::span_w3c_trace_context;
use codex_protocol::protocol::W3cTraceContext;
use futures::FutureExt;
use futures::StreamExt;
use futures::future::BoxFuture;
@@ -64,6 +67,8 @@ use tokio::io::BufReader;
use tokio::process::Command;
use tokio::sync::Mutex;
use tokio::time;
use tracing::Instrument;
use tracing::field::Empty;
use tracing::info;
use tracing::warn;
@@ -82,6 +87,14 @@ const JSON_MIME_TYPE: &str = "application/json";
const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id";
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192;
const TRACEPARENT_HEADER: &str = "traceparent";
const TRACESTATE_HEADER: &str = "tracestate";
const TRACEPARENT_META_KEY: &str = "x-codex-traceparent";
const TRACESTATE_META_KEY: &str = "x-codex-tracestate";
tokio::task_local! {
static SERVICE_OPERATION_TRACE_CONTEXT: Option<W3cTraceContext>;
}
#[derive(Clone)]
struct StreamableHttpResponseClient {
@@ -98,6 +111,24 @@ impl StreamableHttpResponseClient {
) -> StreamableHttpError<StreamableHttpResponseClientError> {
StreamableHttpError::Client(StreamableHttpResponseClientError::from(error))
}
fn apply_trace_context(
request: reqwest::RequestBuilder,
trace: Option<&W3cTraceContext>,
) -> reqwest::RequestBuilder {
let Some(trace) = trace else {
return request;
};
let mut request = request;
if let Some(traceparent) = trace.traceparent.as_deref() {
request = request.header(TRACEPARENT_HEADER, traceparent);
}
if let Some(tracestate) = trace.tracestate.as_deref() {
request = request.header(TRACESTATE_HEADER, tracestate);
}
request
}
}
fn build_http_client(default_headers: &HeaderMap) -> Result<reqwest::Client> {
@@ -123,6 +154,7 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
session_id: Option<Arc<str>>,
auth_token: Option<String>,
) -> std::result::Result<StreamableHttpPostResponse, StreamableHttpError<Self::Error>> {
let trace = current_service_operation_trace_context();
let mut request = self
.inner
.post(uri.as_ref())
@@ -133,6 +165,7 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
if let Some(session_id_value) = session_id.as_ref() {
request = request.header(HEADER_SESSION_ID, session_id_value.as_ref());
}
request = Self::apply_trace_context(request, trace.as_ref());
let response = request
.json(&message)
@@ -224,10 +257,12 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
session: Arc<str>,
auth_token: Option<String>,
) -> std::result::Result<(), StreamableHttpError<Self::Error>> {
let trace = current_service_operation_trace_context();
let mut request_builder = self.inner.delete(uri.as_ref());
if let Some(auth_header) = auth_token {
request_builder = request_builder.bearer_auth(auth_header);
}
request_builder = Self::apply_trace_context(request_builder, trace.as_ref());
let response = request_builder
.header(HEADER_SESSION_ID, session.as_ref())
.send()
@@ -254,6 +289,7 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
BoxStream<'static, std::result::Result<Sse, sse_stream::Error>>,
StreamableHttpError<Self::Error>,
> {
let trace = current_service_operation_trace_context();
let mut request_builder = self
.inner
.get(uri.as_ref())
@@ -265,6 +301,7 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
if let Some(auth_header) = auth_token {
request_builder = request_builder.bearer_auth(auth_header);
}
request_builder = Self::apply_trace_context(request_builder, trace.as_ref());
let response = request_builder
.send()
@@ -733,6 +770,8 @@ impl RmcpClient {
let rmcp_params = rmcp_params.clone();
let meta = meta.clone();
async move {
let trace = current_service_operation_trace_context();
let meta = merge_trace_context_into_meta(meta, trace.as_ref());
let result = service
.peer()
.send_request_with_option(
@@ -1052,41 +1091,104 @@ impl RmcpClient {
Fut: std::future::Future<Output = std::result::Result<T, rmcp::service::ServiceError>>,
{
let service = self.service().await?;
match Self::run_service_operation_once(Arc::clone(&service), label, timeout, &operation)
.await
let operation_span = self.service_operation_span(label);
let operation_trace = span_w3c_trace_context(&operation_span);
match Self::run_service_operation_once(
Arc::clone(&service),
label,
timeout,
operation_span,
operation_trace,
&operation,
)
.await
{
Ok(result) => Ok(result),
Err(error) if Self::is_session_expired_404(&error) => {
self.reinitialize_after_session_expiry(&service).await?;
let recovered_service = self.service().await?;
Self::run_service_operation_once(recovered_service, label, timeout, &operation)
.await
.map_err(Into::into)
let operation_span = self.service_operation_span(label);
let operation_trace = span_w3c_trace_context(&operation_span);
Self::run_service_operation_once(
recovered_service,
label,
timeout,
operation_span,
operation_trace,
&operation,
)
.await
.map_err(Into::into)
}
Err(error) => Err(error.into()),
}
}
fn service_operation_span(&self, label: &str) -> tracing::Span {
let span = tracing::info_span!(
"mcp.client.operation",
otel.kind = "client",
rpc.system = "jsonrpc",
rpc.method = label,
mcp.transport = Empty,
mcp.server.name = Empty,
server.address = Empty,
server.port = Empty,
);
match &self.transport_recipe {
TransportRecipe::Stdio { .. } => {
span.record("mcp.transport", "stdio");
}
TransportRecipe::StreamableHttp {
server_name, url, ..
} => {
span.record("mcp.transport", "streamable_http");
span.record("mcp.server.name", server_name.as_str());
if let Ok(parsed_url) = reqwest::Url::parse(url) {
if let Some(host) = parsed_url.host_str() {
span.record("server.address", host);
}
if let Some(port) = parsed_url.port_or_known_default() {
span.record("server.port", port as i64);
}
}
}
}
span
}
async fn run_service_operation_once<T, F, Fut>(
service: Arc<RunningService<RoleClient, LoggingClientHandler>>,
label: &str,
timeout: Option<Duration>,
operation_span: tracing::Span,
operation_trace: Option<W3cTraceContext>,
operation: &F,
) -> std::result::Result<T, ClientOperationError>
where
F: Fn(Arc<RunningService<RoleClient, LoggingClientHandler>>) -> Fut,
Fut: std::future::Future<Output = std::result::Result<T, rmcp::service::ServiceError>>,
{
match timeout {
Some(duration) => time::timeout(duration, operation(service))
SERVICE_OPERATION_TRACE_CONTEXT
.scope(operation_trace, async move {
async move {
match timeout {
Some(duration) => time::timeout(duration, operation(service))
.await
.map_err(|_| ClientOperationError::Timeout {
label: label.to_string(),
duration,
})?
.map_err(ClientOperationError::from),
None => operation(service).await.map_err(ClientOperationError::from),
}
}
.instrument(operation_span)
.await
.map_err(|_| ClientOperationError::Timeout {
label: label.to_string(),
duration,
})?
.map_err(ClientOperationError::from),
None => operation(service).await.map_err(ClientOperationError::from),
}
})
.await
}
fn is_session_expired_404(error: &ClientOperationError) -> bool {
@@ -1161,6 +1263,39 @@ impl RmcpClient {
}
}
fn merge_trace_context_into_meta(
meta: Option<rmcp::model::Meta>,
trace: Option<&W3cTraceContext>,
) -> Option<rmcp::model::Meta> {
let mut meta = meta.unwrap_or_default();
let Some(trace) = trace else {
return (!meta.is_empty()).then_some(meta);
};
if let Some(traceparent) = trace.traceparent.as_ref() {
meta.insert(
TRACEPARENT_META_KEY.to_string(),
serde_json::Value::String(traceparent.clone()),
);
}
if let Some(tracestate) = trace.tracestate.as_ref() {
meta.insert(
TRACESTATE_META_KEY.to_string(),
serde_json::Value::String(tracestate.clone()),
);
}
(!meta.is_empty()).then_some(meta)
}
fn current_service_operation_trace_context() -> Option<W3cTraceContext> {
SERVICE_OPERATION_TRACE_CONTEXT
.try_with(Clone::clone)
.ok()
.flatten()
.or_else(current_span_w3c_trace_context)
}
async fn create_oauth_transport_and_runtime(
server_name: &str,
url: &str,
@@ -1207,3 +1342,75 @@ async fn create_oauth_transport_and_runtime(
Ok((transport, runtime))
}
#[cfg(test)]
mod trace_tests {
use super::StreamableHttpResponseClient;
use super::TRACEPARENT_HEADER;
use super::TRACEPARENT_META_KEY;
use super::TRACESTATE_HEADER;
use super::TRACESTATE_META_KEY;
use super::merge_trace_context_into_meta;
use codex_protocol::protocol::W3cTraceContext;
use pretty_assertions::assert_eq;
use reqwest::Client;
use serde_json::Value;
#[test]
fn merge_trace_context_into_meta_preserves_existing_fields() {
let trace = W3cTraceContext {
traceparent: Some("00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb-01".into()),
tracestate: Some("vendor=value".into()),
};
let meta = rmcp::model::Meta(serde_json::Map::from_iter([(
"existing".to_string(),
Value::String("value".into()),
)]));
let merged = merge_trace_context_into_meta(Some(meta), Some(&trace)).expect("meta");
assert_eq!(
merged,
rmcp::model::Meta(serde_json::Map::from_iter([
("existing".to_string(), Value::String("value".into())),
(
TRACEPARENT_META_KEY.to_string(),
Value::String("00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb-01".into())
),
(
TRACESTATE_META_KEY.to_string(),
Value::String("vendor=value".into())
),
]))
);
}
#[test]
fn apply_trace_context_injects_http_headers() {
let trace = W3cTraceContext {
traceparent: Some("00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb-01".into()),
tracestate: Some("vendor=value".into()),
};
let request = StreamableHttpResponseClient::apply_trace_context(
Client::new().post("http://example.com"),
Some(&trace),
)
.build()
.expect("request");
assert_eq!(
request
.headers()
.get(TRACEPARENT_HEADER)
.and_then(|value| value.to_str().ok()),
Some("00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb-01")
);
assert_eq!(
request
.headers()
.get(TRACESTATE_HEADER)
.and_then(|value| value.to_str().ok()),
Some("vendor=value")
);
}
}

View File

@@ -7,6 +7,10 @@ use codex_rmcp_client::ElicitationResponse;
use codex_rmcp_client::RmcpClient;
use codex_utils_cargo_bin::CargoBinError;
use futures::FutureExt as _;
use opentelemetry::global;
use opentelemetry::trace::TracerProvider as _;
use opentelemetry_sdk::propagation::TraceContextPropagator;
use opentelemetry_sdk::trace::SdkTracerProvider;
use rmcp::model::AnnotateAble;
use rmcp::model::ClientCapabilities;
use rmcp::model::ElicitationCapability;
@@ -18,6 +22,10 @@ use rmcp::model::ProtocolVersion;
use rmcp::model::ReadResourceRequestParams;
use rmcp::model::ResourceContents;
use serde_json::json;
use tracing::Instrument;
use tracing::dispatcher::DefaultGuard;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
const RESOURCE_URI: &str = "memo://codex/example-note";
@@ -53,6 +61,25 @@ fn init_params() -> InitializeRequestParams {
}
}
struct TestTracingContext {
_provider: SdkTracerProvider,
_guard: DefaultGuard,
}
fn install_test_tracing(tracer_name: &str) -> TestTracingContext {
global::set_text_map_propagator(TraceContextPropagator::new());
let provider = SdkTracerProvider::builder().build();
let tracer = provider.tracer(tracer_name.to_string());
let subscriber =
tracing_subscriber::registry().with(tracing_opentelemetry::layer().with_tracer(tracer));
TestTracingContext {
_provider: provider,
_guard: subscriber.set_default(),
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn rmcp_client_can_list_and_read_resources() -> anyhow::Result<()> {
let client = RmcpClient::new_stdio_client(
@@ -149,3 +176,61 @@ async fn rmcp_client_can_list_and_read_resources() -> anyhow::Result<()> {
Ok(())
}
#[tokio::test(flavor = "current_thread")]
async fn stdio_tool_call_propagates_trace_metadata() -> anyhow::Result<()> {
let _trace = install_test_tracing("rmcp-stdio-trace-test");
let client = RmcpClient::new_stdio_client(
stdio_server_bin()?.into(),
Vec::<OsString>::new(),
None,
&[],
None,
)
.await?;
client
.initialize(
init_params(),
Some(Duration::from_secs(5)),
Box::new(|_, _| {
async {
Ok(ElicitationResponse {
action: ElicitationAction::Accept,
content: Some(json!({})),
meta: None,
})
}
.boxed()
}),
)
.await?;
let result = async {
client
.call_tool(
"echo".to_string(),
Some(json!({ "message": "ping" })),
None,
Some(Duration::from_secs(5)),
)
.await
}
.instrument(tracing::info_span!("rmcp.client.trace_test"))
.await?;
assert_eq!(result.is_error, Some(false));
let structured = result.structured_content.expect("structured content");
assert_eq!(structured["echo"], json!("ECHOING: ping"));
assert_eq!(structured["env"], serde_json::Value::Null);
assert!(
structured["tracestate"].is_null()
|| structured["tracestate"].as_str().is_some_and(str::is_empty)
);
let traceparent = structured["traceparent"]
.as_str()
.expect("traceparent should be propagated via request metadata");
assert!(traceparent.starts_with("00-"));
Ok(())
}