mirror of
https://github.com/openai/codex.git
synced 2026-05-01 18:06:47 +00:00
Compare commits
2 Commits
dev/abhina
...
nicholascl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9e5bea81b | ||
|
|
bf5f89b535 |
5
codex-rs/Cargo.lock
generated
5
codex-rs/Cargo.lock
generated
@@ -2514,6 +2514,7 @@ dependencies = [
|
||||
"axum",
|
||||
"codex-client",
|
||||
"codex-keyring-store",
|
||||
"codex-otel",
|
||||
"codex-protocol",
|
||||
"codex-utils-cargo-bin",
|
||||
"codex-utils-home-dir",
|
||||
@@ -2521,6 +2522,8 @@ dependencies = [
|
||||
"futures",
|
||||
"keyring",
|
||||
"oauth2",
|
||||
"opentelemetry",
|
||||
"opentelemetry_sdk",
|
||||
"pretty_assertions",
|
||||
"reqwest",
|
||||
"rmcp",
|
||||
@@ -2535,6 +2538,8 @@ dependencies = [
|
||||
"tiny_http",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-opentelemetry",
|
||||
"tracing-subscriber",
|
||||
"urlencoding",
|
||||
"webbrowser",
|
||||
"which 8.0.0",
|
||||
|
||||
@@ -10,7 +10,9 @@ use reqwest::Response;
|
||||
use serde::Serialize;
|
||||
use std::fmt::Display;
|
||||
use std::time::Duration;
|
||||
use tracing::Instrument;
|
||||
use tracing::Span;
|
||||
use tracing::field::Empty;
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -111,10 +113,39 @@ impl CodexRequestBuilder {
|
||||
}
|
||||
|
||||
pub async fn send(self) -> Result<Response, reqwest::Error> {
|
||||
let headers = trace_headers();
|
||||
let parsed_url = reqwest::Url::parse(&self.url).ok();
|
||||
let path = parsed_url
|
||||
.as_ref()
|
||||
.map(|url| url.path().to_string())
|
||||
.unwrap_or_else(|| self.url.clone());
|
||||
let request_span = tracing::info_span!(
|
||||
"http.client",
|
||||
otel.kind = "client",
|
||||
http.request.method = %self.method,
|
||||
http.response.status_code = Empty,
|
||||
url.path = %path,
|
||||
server.address = Empty,
|
||||
server.port = Empty,
|
||||
);
|
||||
if let Some(url) = parsed_url.as_ref() {
|
||||
if let Some(host) = url.host_str() {
|
||||
request_span.record("server.address", host);
|
||||
}
|
||||
if let Some(port) = url.port_or_known_default() {
|
||||
request_span.record("server.port", port as i64);
|
||||
}
|
||||
}
|
||||
let headers = trace_headers_for_span(&request_span);
|
||||
|
||||
match self.builder.headers(headers).send().await {
|
||||
match async { self.builder.headers(headers).send().await }
|
||||
.instrument(request_span.clone())
|
||||
.await
|
||||
{
|
||||
Ok(response) => {
|
||||
request_span.record(
|
||||
"http.response.status_code",
|
||||
response.status().as_u16() as i64,
|
||||
);
|
||||
tracing::debug!(
|
||||
method = %self.method,
|
||||
url = %self.url,
|
||||
@@ -127,11 +158,14 @@ impl CodexRequestBuilder {
|
||||
Ok(response)
|
||||
}
|
||||
Err(error) => {
|
||||
let status = error.status();
|
||||
let status = error.status().map(|status| status.as_u16() as i64);
|
||||
if let Some(status) = status {
|
||||
request_span.record("http.response.status_code", status);
|
||||
}
|
||||
tracing::debug!(
|
||||
method = %self.method,
|
||||
url = %self.url,
|
||||
status = status.map(|s| s.as_u16()),
|
||||
status,
|
||||
error = %error,
|
||||
"Request failed"
|
||||
);
|
||||
@@ -154,13 +188,15 @@ impl<'a> Injector for HeaderMapInjector<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn trace_headers() -> HeaderMap {
|
||||
trace_headers_for_span(&Span::current())
|
||||
}
|
||||
|
||||
fn trace_headers_for_span(span: &Span) -> HeaderMap {
|
||||
let mut headers = HeaderMap::new();
|
||||
global::get_text_map_propagator(|prop| {
|
||||
prop.inject_context(
|
||||
&Span::current().context(),
|
||||
&mut HeaderMapInjector(&mut headers),
|
||||
);
|
||||
prop.inject_context(&span.context(), &mut HeaderMapInjector(&mut headers));
|
||||
});
|
||||
headers
|
||||
}
|
||||
@@ -204,6 +240,33 @@ mod tests {
|
||||
assert_eq!(extracted_context.span_id(), span_context.span_id());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inject_trace_headers_for_span_uses_explicit_span_context() {
|
||||
global::set_text_map_propagator(TraceContextPropagator::new());
|
||||
|
||||
let provider = SdkTracerProvider::builder().build();
|
||||
let tracer = provider.tracer("test-tracer");
|
||||
let subscriber =
|
||||
tracing_subscriber::registry().with(tracing_opentelemetry::layer().with_tracer(tracer));
|
||||
let _guard = subscriber.set_default();
|
||||
|
||||
let parent = trace_span!("parent");
|
||||
let _parent_entered = parent.enter();
|
||||
let child = trace_span!("child");
|
||||
let child_context = child.context().span().span_context().clone();
|
||||
|
||||
let headers = trace_headers_for_span(&child);
|
||||
|
||||
let extractor = HeaderMapExtractor(&headers);
|
||||
let extracted = TraceContextPropagator::new().extract(&extractor);
|
||||
let extracted_span = extracted.span();
|
||||
let extracted_context = extracted_span.span_context();
|
||||
|
||||
assert!(extracted_context.is_valid());
|
||||
assert_eq!(extracted_context.trace_id(), child_context.trace_id());
|
||||
assert_eq!(extracted_context.span_id(), child_context.span_id());
|
||||
}
|
||||
|
||||
struct HeaderMapExtractor<'a>(&'a HeaderMap);
|
||||
|
||||
impl<'a> Extractor for HeaderMapExtractor<'a> {
|
||||
|
||||
@@ -86,6 +86,7 @@ use tokio::sync::oneshot;
|
||||
use tokio::sync::oneshot::error::TryRecvError;
|
||||
use tokio_tungstenite::tungstenite::Error;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::Instrument;
|
||||
use tracing::instrument;
|
||||
use tracing::trace;
|
||||
use tracing::warn;
|
||||
@@ -1023,6 +1024,7 @@ impl ModelClientSession {
|
||||
let mut pending_retry = PendingUnauthorizedRetry::default();
|
||||
loop {
|
||||
let client_setup = self.client.current_client_setup().await?;
|
||||
let api_base_url = client_setup.api_provider.base_url.clone();
|
||||
let transport = ReqwestTransport::new(build_reqwest_client());
|
||||
let request_auth_context = AuthRequestTelemetryContext::new(
|
||||
client_setup.auth.as_ref().map(CodexAuth::auth_mode),
|
||||
@@ -1052,10 +1054,20 @@ impl ModelClientSession {
|
||||
client_setup.api_auth,
|
||||
)
|
||||
.with_telemetry(Some(request_telemetry), Some(sse_telemetry));
|
||||
let stream_result = client.stream_request(request, options).await;
|
||||
let request_span = crate::network_trace::responses_http_request_span(
|
||||
&self.client.state.conversation_id,
|
||||
turn_metadata_header,
|
||||
&self.client.state.provider.name,
|
||||
&model_info.slug,
|
||||
&api_base_url,
|
||||
);
|
||||
let stream_result = async { client.stream_request(request, options).await }
|
||||
.instrument(request_span.clone())
|
||||
.await;
|
||||
|
||||
match stream_result {
|
||||
Ok(stream) => {
|
||||
let _entered = request_span.enter();
|
||||
let (stream, _) = map_response_stream(stream, session_telemetry.clone());
|
||||
return Ok(stream);
|
||||
}
|
||||
@@ -1414,72 +1426,76 @@ where
|
||||
{
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
|
||||
let (tx_last_response, rx_last_response) = oneshot::channel::<LastResponse>();
|
||||
let current_span = tracing::Span::current();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut logged_error = false;
|
||||
let mut tx_last_response = Some(tx_last_response);
|
||||
let mut items_added: Vec<ResponseItem> = Vec::new();
|
||||
let mut api_stream = api_stream;
|
||||
while let Some(event) = api_stream.next().await {
|
||||
match event {
|
||||
Ok(ResponseEvent::OutputItemDone(item)) => {
|
||||
items_added.push(item.clone());
|
||||
if tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemDone(item)))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
tokio::spawn(
|
||||
async move {
|
||||
let mut logged_error = false;
|
||||
let mut tx_last_response = Some(tx_last_response);
|
||||
let mut items_added: Vec<ResponseItem> = Vec::new();
|
||||
let mut api_stream = api_stream;
|
||||
while let Some(event) = api_stream.next().await {
|
||||
match event {
|
||||
Ok(ResponseEvent::OutputItemDone(item)) => {
|
||||
items_added.push(item.clone());
|
||||
if tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemDone(item)))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
}) => {
|
||||
if let Some(usage) = &token_usage {
|
||||
session_telemetry.sse_event_completed(
|
||||
usage.input_tokens,
|
||||
usage.output_tokens,
|
||||
Some(usage.cached_input_tokens),
|
||||
Some(usage.reasoning_output_tokens),
|
||||
usage.total_tokens,
|
||||
);
|
||||
Ok(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
}) => {
|
||||
if let Some(usage) = &token_usage {
|
||||
session_telemetry.sse_event_completed(
|
||||
usage.input_tokens,
|
||||
usage.output_tokens,
|
||||
Some(usage.cached_input_tokens),
|
||||
Some(usage.reasoning_output_tokens),
|
||||
usage.total_tokens,
|
||||
);
|
||||
}
|
||||
if let Some(sender) = tx_last_response.take() {
|
||||
let _ = sender.send(LastResponse {
|
||||
response_id: response_id.clone(),
|
||||
items_added: std::mem::take(&mut items_added),
|
||||
});
|
||||
}
|
||||
if tx_event
|
||||
.send(Ok(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
}))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
if let Some(sender) = tx_last_response.take() {
|
||||
let _ = sender.send(LastResponse {
|
||||
response_id: response_id.clone(),
|
||||
items_added: std::mem::take(&mut items_added),
|
||||
});
|
||||
Ok(event) => {
|
||||
if tx_event.send(Ok(event)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
if tx_event
|
||||
.send(Ok(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
}))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
Ok(event) => {
|
||||
if tx_event.send(Ok(event)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let mapped = map_api_error(err);
|
||||
if !logged_error {
|
||||
session_telemetry.see_event_completed_failed(&mapped);
|
||||
logged_error = true;
|
||||
}
|
||||
if tx_event.send(Err(mapped)).await.is_err() {
|
||||
return;
|
||||
Err(err) => {
|
||||
let mapped = map_api_error(err);
|
||||
if !logged_error {
|
||||
session_telemetry.see_event_completed_failed(&mapped);
|
||||
logged_error = true;
|
||||
}
|
||||
if tx_event.send(Err(mapped)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
.instrument(current_span),
|
||||
);
|
||||
|
||||
(ResponseStream { rx_event }, rx_last_response)
|
||||
}
|
||||
|
||||
@@ -52,6 +52,7 @@ mod mcp_tool_approval_templates;
|
||||
pub mod models_manager;
|
||||
mod network_policy_decision;
|
||||
pub mod network_proxy_loader;
|
||||
mod network_trace;
|
||||
mod original_image_detail;
|
||||
mod packages;
|
||||
pub use mcp_connection_manager::MCP_SANDBOX_STATE_CAPABILITY;
|
||||
|
||||
85
codex-rs/core/src/network_trace.rs
Normal file
85
codex-rs/core/src/network_trace.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
use codex_protocol::ThreadId;
|
||||
use serde::Deserialize;
|
||||
use tracing::Span;
|
||||
use tracing::field::Empty;
|
||||
use url::Url;
|
||||
|
||||
#[derive(Debug, Default, Deserialize)]
|
||||
struct TurnTraceCorrelation {
|
||||
session_id: Option<String>,
|
||||
turn_id: Option<String>,
|
||||
}
|
||||
|
||||
struct CorrelationFields {
|
||||
conversation_id: String,
|
||||
session_id: String,
|
||||
turn_id: Option<String>,
|
||||
}
|
||||
|
||||
impl CorrelationFields {
|
||||
fn from_turn_metadata_header(
|
||||
conversation_id: &ThreadId,
|
||||
turn_metadata_header: Option<&str>,
|
||||
) -> Self {
|
||||
let conversation_id = conversation_id.to_string();
|
||||
let correlation = turn_metadata_header
|
||||
.and_then(|header| serde_json::from_str::<TurnTraceCorrelation>(header).ok())
|
||||
.unwrap_or_default();
|
||||
let session_id = correlation
|
||||
.session_id
|
||||
.unwrap_or_else(|| conversation_id.clone());
|
||||
Self {
|
||||
conversation_id,
|
||||
session_id,
|
||||
turn_id: correlation.turn_id,
|
||||
}
|
||||
}
|
||||
fn record_on(&self, span: &Span) {
|
||||
span.record("conversation.id", self.conversation_id.as_str());
|
||||
span.record("session.id", self.session_id.as_str());
|
||||
if let Some(turn_id) = self.turn_id.as_deref() {
|
||||
span.record("turn.id", turn_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn record_server_fields(span: &Span, url: Option<&str>) {
|
||||
let Some(url) = url else {
|
||||
return;
|
||||
};
|
||||
let Ok(parsed) = Url::parse(url) else {
|
||||
return;
|
||||
};
|
||||
if let Some(host) = parsed.host_str() {
|
||||
span.record("server.address", host);
|
||||
}
|
||||
if let Some(port) = parsed.port_or_known_default() {
|
||||
span.record("server.port", port as i64);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn responses_http_request_span(
|
||||
conversation_id: &ThreadId,
|
||||
turn_metadata_header: Option<&str>,
|
||||
provider_name: &str,
|
||||
model: &str,
|
||||
base_url: &str,
|
||||
) -> Span {
|
||||
let span = tracing::info_span!(
|
||||
"responses_http.request",
|
||||
otel.kind = "client",
|
||||
provider = provider_name,
|
||||
model,
|
||||
transport = "responses_http",
|
||||
api.path = "responses",
|
||||
conversation.id = Empty,
|
||||
session.id = Empty,
|
||||
turn.id = Empty,
|
||||
server.address = Empty,
|
||||
server.port = Empty,
|
||||
);
|
||||
CorrelationFields::from_turn_metadata_header(conversation_id, turn_metadata_header)
|
||||
.record_on(&span);
|
||||
record_server_fields(&span, Some(base_url));
|
||||
span
|
||||
}
|
||||
@@ -197,11 +197,12 @@ impl Session {
|
||||
let ctx = Arc::clone(&turn_context);
|
||||
let task_for_run = Arc::clone(&task);
|
||||
let task_cancellation_token = cancellation_token.child_token();
|
||||
// Task-owned turn spans keep a core-owned span open for the
|
||||
// full task lifecycle after the submission dispatch span ends.
|
||||
// Task-owned turn spans keep a core-owned parent span open for the
|
||||
// full turn lifecycle after the submission dispatch span ends.
|
||||
let task_span = info_span!(
|
||||
"turn",
|
||||
otel.name = span_name,
|
||||
conversation.id = %self.conversation_id,
|
||||
thread.id = %self.conversation_id,
|
||||
turn.id = %turn_context.sub_id,
|
||||
model = %turn_context.model_info.slug,
|
||||
|
||||
@@ -718,6 +718,57 @@ async fn record_responses_sets_span_fields_for_response_events() {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
#[traced_test]
|
||||
async fn responses_request_span_records_turn_correlation_fields() {
|
||||
let server = start_mock_server().await;
|
||||
mount_sse_once(
|
||||
&server,
|
||||
sse(vec![ev_response_created("resp-1"), ev_completed("resp-1")]),
|
||||
)
|
||||
.await;
|
||||
|
||||
let TestCodex { codex, .. } = test_codex().build(&server).await.unwrap();
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![UserInput::Text {
|
||||
text: "hello".into(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await;
|
||||
|
||||
logs_assert(|lines: &[&str]| {
|
||||
lines
|
||||
.iter()
|
||||
.find(|line| {
|
||||
line.contains("turn{otel.name=\"session_task.turn\"")
|
||||
&& line.contains("conversation.id=")
|
||||
&& line.contains("thread.id=")
|
||||
&& line.contains("turn.id=")
|
||||
&& line.contains("model=")
|
||||
&& line.contains("responses_http.request{")
|
||||
&& line.contains("otel.kind=\"client\"")
|
||||
&& line.contains("transport=\"responses_http\"")
|
||||
&& line.contains("conversation.id=")
|
||||
&& line.contains("session.id=")
|
||||
&& line.contains("turn.id=")
|
||||
})
|
||||
.map(|_| Ok(()))
|
||||
.unwrap_or_else(|| {
|
||||
Err(
|
||||
"missing responses_http.request span nested under session_task.turn"
|
||||
.to_string(),
|
||||
)
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[traced_test]
|
||||
async fn handle_response_item_records_tool_result_for_custom_tool_call() {
|
||||
|
||||
@@ -36,6 +36,7 @@ use core_test_support::responses::mount_sse_once;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::stdio_server_bin;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::tracing::install_test_tracing;
|
||||
use core_test_support::wait_for_event;
|
||||
use core_test_support::wait_for_event_with_timeout;
|
||||
use reqwest::Client;
|
||||
@@ -687,6 +688,8 @@ async fn stdio_server_propagates_whitelisted_env_vars() -> anyhow::Result<()> {
|
||||
async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let _trace_test_context = install_test_tracing("rmcp-integration-tests");
|
||||
|
||||
let server = responses::start_mock_server().await;
|
||||
|
||||
let call_id = "call-456";
|
||||
@@ -733,6 +736,7 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> {
|
||||
.kill_on_drop(true)
|
||||
.env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr)
|
||||
.env("MCP_TEST_VALUE", expected_env_value)
|
||||
.env("MCP_EXPECT_TRACEPARENT", "1")
|
||||
.spawn()?;
|
||||
|
||||
wait_for_streamable_http_server(&mut http_server_child, &bind_addr, Duration::from_secs(5))
|
||||
|
||||
@@ -15,6 +15,7 @@ axum = { workspace = true, default-features = false, features = [
|
||||
] }
|
||||
codex-client = { workspace = true }
|
||||
codex-keyring-store = { workspace = true }
|
||||
codex-otel = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
codex-utils-pty = { workspace = true }
|
||||
codex-utils-home-dir = { workspace = true }
|
||||
@@ -60,9 +61,13 @@ which = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
codex-utils-cargo-bin = { workspace = true }
|
||||
opentelemetry = { workspace = true }
|
||||
opentelemetry_sdk = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
serial_test = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
tracing-opentelemetry = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
keyring = { workspace = true, features = ["linux-native-async-persistent"] }
|
||||
|
||||
|
||||
@@ -315,7 +315,7 @@ impl ServerHandler for TestToolServer {
|
||||
async fn call_tool(
|
||||
&self,
|
||||
request: CallToolRequestParams,
|
||||
_context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
|
||||
context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
match request.name.as_ref() {
|
||||
"echo" | "echo-tool" => {
|
||||
@@ -333,9 +333,19 @@ impl ServerHandler for TestToolServer {
|
||||
};
|
||||
|
||||
let env_snapshot: HashMap<String, String> = std::env::vars().collect();
|
||||
let traceparent = context
|
||||
.meta
|
||||
.get("x-codex-traceparent")
|
||||
.and_then(serde_json::Value::as_str);
|
||||
let tracestate = context
|
||||
.meta
|
||||
.get("x-codex-tracestate")
|
||||
.and_then(serde_json::Value::as_str);
|
||||
let structured_content = json!({
|
||||
"echo": format!("ECHOING: {}", args.message),
|
||||
"env": env_snapshot.get("MCP_TEST_VALUE"),
|
||||
"traceparent": traceparent,
|
||||
"tracestate": tracestate,
|
||||
});
|
||||
|
||||
Ok(CallToolResult {
|
||||
|
||||
@@ -58,6 +58,7 @@ struct TestToolServer {
|
||||
const MEMO_URI: &str = "memo://codex/example-note";
|
||||
const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server.";
|
||||
const MCP_SESSION_ID_HEADER: &str = "mcp-session-id";
|
||||
const TRACEPARENT_HEADER: &str = "traceparent";
|
||||
const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure";
|
||||
|
||||
impl TestToolServer {
|
||||
@@ -347,6 +348,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
} else {
|
||||
router
|
||||
};
|
||||
let router = if std::env::var("MCP_EXPECT_TRACEPARENT").is_ok() {
|
||||
router.layer(middleware::from_fn(require_traceparent_on_session_post))
|
||||
} else {
|
||||
router
|
||||
};
|
||||
|
||||
axum::serve(listener, router).await?;
|
||||
task::yield_now().await;
|
||||
@@ -389,6 +395,23 @@ async fn arm_session_post_failure(
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
async fn require_traceparent_on_session_post(request: Request<Body>, next: Next) -> Response {
|
||||
if request.uri().path() != "/mcp"
|
||||
|| request.method() != Method::POST
|
||||
|| !request.headers().contains_key(MCP_SESSION_ID_HEADER)
|
||||
{
|
||||
return next.run(request).await;
|
||||
}
|
||||
|
||||
if request.headers().contains_key(TRACEPARENT_HEADER) {
|
||||
next.run(request).await
|
||||
} else {
|
||||
let mut response = Response::new(Body::from("missing traceparent header"));
|
||||
*response.status_mut() = StatusCode::BAD_REQUEST;
|
||||
response
|
||||
}
|
||||
}
|
||||
|
||||
async fn fail_session_post_when_armed(
|
||||
State(state): State<SessionFailureState>,
|
||||
request: Request<Body>,
|
||||
|
||||
@@ -10,6 +10,9 @@ use std::time::Duration;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use codex_client::build_reqwest_client_with_custom_ca;
|
||||
use codex_otel::current_span_w3c_trace_context;
|
||||
use codex_otel::span_w3c_trace_context;
|
||||
use codex_protocol::protocol::W3cTraceContext;
|
||||
use futures::FutureExt;
|
||||
use futures::StreamExt;
|
||||
use futures::future::BoxFuture;
|
||||
@@ -64,6 +67,8 @@ use tokio::io::BufReader;
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::time;
|
||||
use tracing::Instrument;
|
||||
use tracing::field::Empty;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
@@ -82,6 +87,14 @@ const JSON_MIME_TYPE: &str = "application/json";
|
||||
const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id";
|
||||
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
|
||||
const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192;
|
||||
const TRACEPARENT_HEADER: &str = "traceparent";
|
||||
const TRACESTATE_HEADER: &str = "tracestate";
|
||||
const TRACEPARENT_META_KEY: &str = "x-codex-traceparent";
|
||||
const TRACESTATE_META_KEY: &str = "x-codex-tracestate";
|
||||
|
||||
tokio::task_local! {
|
||||
static SERVICE_OPERATION_TRACE_CONTEXT: Option<W3cTraceContext>;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct StreamableHttpResponseClient {
|
||||
@@ -98,6 +111,24 @@ impl StreamableHttpResponseClient {
|
||||
) -> StreamableHttpError<StreamableHttpResponseClientError> {
|
||||
StreamableHttpError::Client(StreamableHttpResponseClientError::from(error))
|
||||
}
|
||||
|
||||
fn apply_trace_context(
|
||||
request: reqwest::RequestBuilder,
|
||||
trace: Option<&W3cTraceContext>,
|
||||
) -> reqwest::RequestBuilder {
|
||||
let Some(trace) = trace else {
|
||||
return request;
|
||||
};
|
||||
|
||||
let mut request = request;
|
||||
if let Some(traceparent) = trace.traceparent.as_deref() {
|
||||
request = request.header(TRACEPARENT_HEADER, traceparent);
|
||||
}
|
||||
if let Some(tracestate) = trace.tracestate.as_deref() {
|
||||
request = request.header(TRACESTATE_HEADER, tracestate);
|
||||
}
|
||||
request
|
||||
}
|
||||
}
|
||||
|
||||
fn build_http_client(default_headers: &HeaderMap) -> Result<reqwest::Client> {
|
||||
@@ -123,6 +154,7 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
|
||||
session_id: Option<Arc<str>>,
|
||||
auth_token: Option<String>,
|
||||
) -> std::result::Result<StreamableHttpPostResponse, StreamableHttpError<Self::Error>> {
|
||||
let trace = current_service_operation_trace_context();
|
||||
let mut request = self
|
||||
.inner
|
||||
.post(uri.as_ref())
|
||||
@@ -133,6 +165,7 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
|
||||
if let Some(session_id_value) = session_id.as_ref() {
|
||||
request = request.header(HEADER_SESSION_ID, session_id_value.as_ref());
|
||||
}
|
||||
request = Self::apply_trace_context(request, trace.as_ref());
|
||||
|
||||
let response = request
|
||||
.json(&message)
|
||||
@@ -224,10 +257,12 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
|
||||
session: Arc<str>,
|
||||
auth_token: Option<String>,
|
||||
) -> std::result::Result<(), StreamableHttpError<Self::Error>> {
|
||||
let trace = current_service_operation_trace_context();
|
||||
let mut request_builder = self.inner.delete(uri.as_ref());
|
||||
if let Some(auth_header) = auth_token {
|
||||
request_builder = request_builder.bearer_auth(auth_header);
|
||||
}
|
||||
request_builder = Self::apply_trace_context(request_builder, trace.as_ref());
|
||||
let response = request_builder
|
||||
.header(HEADER_SESSION_ID, session.as_ref())
|
||||
.send()
|
||||
@@ -254,6 +289,7 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
|
||||
BoxStream<'static, std::result::Result<Sse, sse_stream::Error>>,
|
||||
StreamableHttpError<Self::Error>,
|
||||
> {
|
||||
let trace = current_service_operation_trace_context();
|
||||
let mut request_builder = self
|
||||
.inner
|
||||
.get(uri.as_ref())
|
||||
@@ -265,6 +301,7 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
|
||||
if let Some(auth_header) = auth_token {
|
||||
request_builder = request_builder.bearer_auth(auth_header);
|
||||
}
|
||||
request_builder = Self::apply_trace_context(request_builder, trace.as_ref());
|
||||
|
||||
let response = request_builder
|
||||
.send()
|
||||
@@ -733,6 +770,8 @@ impl RmcpClient {
|
||||
let rmcp_params = rmcp_params.clone();
|
||||
let meta = meta.clone();
|
||||
async move {
|
||||
let trace = current_service_operation_trace_context();
|
||||
let meta = merge_trace_context_into_meta(meta, trace.as_ref());
|
||||
let result = service
|
||||
.peer()
|
||||
.send_request_with_option(
|
||||
@@ -1052,41 +1091,104 @@ impl RmcpClient {
|
||||
Fut: std::future::Future<Output = std::result::Result<T, rmcp::service::ServiceError>>,
|
||||
{
|
||||
let service = self.service().await?;
|
||||
match Self::run_service_operation_once(Arc::clone(&service), label, timeout, &operation)
|
||||
.await
|
||||
let operation_span = self.service_operation_span(label);
|
||||
let operation_trace = span_w3c_trace_context(&operation_span);
|
||||
match Self::run_service_operation_once(
|
||||
Arc::clone(&service),
|
||||
label,
|
||||
timeout,
|
||||
operation_span,
|
||||
operation_trace,
|
||||
&operation,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => Ok(result),
|
||||
Err(error) if Self::is_session_expired_404(&error) => {
|
||||
self.reinitialize_after_session_expiry(&service).await?;
|
||||
let recovered_service = self.service().await?;
|
||||
Self::run_service_operation_once(recovered_service, label, timeout, &operation)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
let operation_span = self.service_operation_span(label);
|
||||
let operation_trace = span_w3c_trace_context(&operation_span);
|
||||
Self::run_service_operation_once(
|
||||
recovered_service,
|
||||
label,
|
||||
timeout,
|
||||
operation_span,
|
||||
operation_trace,
|
||||
&operation,
|
||||
)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
Err(error) => Err(error.into()),
|
||||
}
|
||||
}
|
||||
|
||||
fn service_operation_span(&self, label: &str) -> tracing::Span {
|
||||
let span = tracing::info_span!(
|
||||
"mcp.client.operation",
|
||||
otel.kind = "client",
|
||||
rpc.system = "jsonrpc",
|
||||
rpc.method = label,
|
||||
mcp.transport = Empty,
|
||||
mcp.server.name = Empty,
|
||||
server.address = Empty,
|
||||
server.port = Empty,
|
||||
);
|
||||
|
||||
match &self.transport_recipe {
|
||||
TransportRecipe::Stdio { .. } => {
|
||||
span.record("mcp.transport", "stdio");
|
||||
}
|
||||
TransportRecipe::StreamableHttp {
|
||||
server_name, url, ..
|
||||
} => {
|
||||
span.record("mcp.transport", "streamable_http");
|
||||
span.record("mcp.server.name", server_name.as_str());
|
||||
if let Ok(parsed_url) = reqwest::Url::parse(url) {
|
||||
if let Some(host) = parsed_url.host_str() {
|
||||
span.record("server.address", host);
|
||||
}
|
||||
if let Some(port) = parsed_url.port_or_known_default() {
|
||||
span.record("server.port", port as i64);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
span
|
||||
}
|
||||
|
||||
async fn run_service_operation_once<T, F, Fut>(
|
||||
service: Arc<RunningService<RoleClient, LoggingClientHandler>>,
|
||||
label: &str,
|
||||
timeout: Option<Duration>,
|
||||
operation_span: tracing::Span,
|
||||
operation_trace: Option<W3cTraceContext>,
|
||||
operation: &F,
|
||||
) -> std::result::Result<T, ClientOperationError>
|
||||
where
|
||||
F: Fn(Arc<RunningService<RoleClient, LoggingClientHandler>>) -> Fut,
|
||||
Fut: std::future::Future<Output = std::result::Result<T, rmcp::service::ServiceError>>,
|
||||
{
|
||||
match timeout {
|
||||
Some(duration) => time::timeout(duration, operation(service))
|
||||
SERVICE_OPERATION_TRACE_CONTEXT
|
||||
.scope(operation_trace, async move {
|
||||
async move {
|
||||
match timeout {
|
||||
Some(duration) => time::timeout(duration, operation(service))
|
||||
.await
|
||||
.map_err(|_| ClientOperationError::Timeout {
|
||||
label: label.to_string(),
|
||||
duration,
|
||||
})?
|
||||
.map_err(ClientOperationError::from),
|
||||
None => operation(service).await.map_err(ClientOperationError::from),
|
||||
}
|
||||
}
|
||||
.instrument(operation_span)
|
||||
.await
|
||||
.map_err(|_| ClientOperationError::Timeout {
|
||||
label: label.to_string(),
|
||||
duration,
|
||||
})?
|
||||
.map_err(ClientOperationError::from),
|
||||
None => operation(service).await.map_err(ClientOperationError::from),
|
||||
}
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
fn is_session_expired_404(error: &ClientOperationError) -> bool {
|
||||
@@ -1161,6 +1263,39 @@ impl RmcpClient {
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_trace_context_into_meta(
|
||||
meta: Option<rmcp::model::Meta>,
|
||||
trace: Option<&W3cTraceContext>,
|
||||
) -> Option<rmcp::model::Meta> {
|
||||
let mut meta = meta.unwrap_or_default();
|
||||
let Some(trace) = trace else {
|
||||
return (!meta.is_empty()).then_some(meta);
|
||||
};
|
||||
|
||||
if let Some(traceparent) = trace.traceparent.as_ref() {
|
||||
meta.insert(
|
||||
TRACEPARENT_META_KEY.to_string(),
|
||||
serde_json::Value::String(traceparent.clone()),
|
||||
);
|
||||
}
|
||||
if let Some(tracestate) = trace.tracestate.as_ref() {
|
||||
meta.insert(
|
||||
TRACESTATE_META_KEY.to_string(),
|
||||
serde_json::Value::String(tracestate.clone()),
|
||||
);
|
||||
}
|
||||
|
||||
(!meta.is_empty()).then_some(meta)
|
||||
}
|
||||
|
||||
fn current_service_operation_trace_context() -> Option<W3cTraceContext> {
|
||||
SERVICE_OPERATION_TRACE_CONTEXT
|
||||
.try_with(Clone::clone)
|
||||
.ok()
|
||||
.flatten()
|
||||
.or_else(current_span_w3c_trace_context)
|
||||
}
|
||||
|
||||
async fn create_oauth_transport_and_runtime(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
@@ -1207,3 +1342,75 @@ async fn create_oauth_transport_and_runtime(
|
||||
|
||||
Ok((transport, runtime))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod trace_tests {
    use super::StreamableHttpResponseClient;
    use super::TRACEPARENT_HEADER;
    use super::TRACEPARENT_META_KEY;
    use super::TRACESTATE_HEADER;
    use super::TRACESTATE_META_KEY;
    use super::merge_trace_context_into_meta;
    use codex_protocol::protocol::W3cTraceContext;
    use pretty_assertions::assert_eq;
    use reqwest::Client;
    use serde_json::Value;

    /// Fixed `traceparent` value shared by the tests below.
    const SAMPLE_TRACEPARENT: &str = "00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb-01";
    /// Fixed `tracestate` value shared by the tests below.
    const SAMPLE_TRACESTATE: &str = "vendor=value";

    /// Builds a trace context with both fields populated from the sample constants.
    fn sample_trace() -> W3cTraceContext {
        W3cTraceContext {
            traceparent: Some(SAMPLE_TRACEPARENT.into()),
            tracestate: Some(SAMPLE_TRACESTATE.into()),
        }
    }

    /// Merging trace context must keep keys that were already in the metadata.
    #[test]
    fn merge_trace_context_into_meta_preserves_existing_fields() {
        let context = sample_trace();

        let mut existing = serde_json::Map::new();
        existing.insert("existing".to_string(), Value::String("value".into()));

        let merged = merge_trace_context_into_meta(Some(rmcp::model::Meta(existing)), Some(&context))
            .expect("meta");

        let mut expected = serde_json::Map::new();
        expected.insert("existing".to_string(), Value::String("value".into()));
        expected.insert(
            TRACEPARENT_META_KEY.to_string(),
            Value::String(SAMPLE_TRACEPARENT.into()),
        );
        expected.insert(
            TRACESTATE_META_KEY.to_string(),
            Value::String(SAMPLE_TRACESTATE.into()),
        );

        assert_eq!(merged, rmcp::model::Meta(expected));
    }

    /// Applying trace context to a request builder must set both trace headers.
    #[test]
    fn apply_trace_context_injects_http_headers() {
        let context = sample_trace();
        let builder = Client::new().post("http://example.com");

        let request = StreamableHttpResponseClient::apply_trace_context(builder, Some(&context))
            .build()
            .expect("request");
        let headers = request.headers();

        let traceparent = headers
            .get(TRACEPARENT_HEADER)
            .and_then(|value| value.to_str().ok());
        assert_eq!(traceparent, Some(SAMPLE_TRACEPARENT));

        let tracestate = headers
            .get(TRACESTATE_HEADER)
            .and_then(|value| value.to_str().ok());
        assert_eq!(tracestate, Some(SAMPLE_TRACESTATE));
    }
}
|
||||
|
||||
@@ -7,6 +7,10 @@ use codex_rmcp_client::ElicitationResponse;
|
||||
use codex_rmcp_client::RmcpClient;
|
||||
use codex_utils_cargo_bin::CargoBinError;
|
||||
use futures::FutureExt as _;
|
||||
use opentelemetry::global;
|
||||
use opentelemetry::trace::TracerProvider as _;
|
||||
use opentelemetry_sdk::propagation::TraceContextPropagator;
|
||||
use opentelemetry_sdk::trace::SdkTracerProvider;
|
||||
use rmcp::model::AnnotateAble;
|
||||
use rmcp::model::ClientCapabilities;
|
||||
use rmcp::model::ElicitationCapability;
|
||||
@@ -18,6 +22,10 @@ use rmcp::model::ProtocolVersion;
|
||||
use rmcp::model::ReadResourceRequestParams;
|
||||
use rmcp::model::ResourceContents;
|
||||
use serde_json::json;
|
||||
use tracing::Instrument;
|
||||
use tracing::dispatcher::DefaultGuard;
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
|
||||
// Example memo resource URI. NOTE(review): its consumers are outside this
// excerpt — presumably the list/read-resources test; confirm.
const RESOURCE_URI: &str = "memo://codex/example-note";
|
||||
|
||||
@@ -53,6 +61,25 @@ fn init_params() -> InitializeRequestParams {
|
||||
}
|
||||
}
|
||||
|
||||
struct TestTracingContext {
|
||||
_provider: SdkTracerProvider,
|
||||
_guard: DefaultGuard,
|
||||
}
|
||||
|
||||
fn install_test_tracing(tracer_name: &str) -> TestTracingContext {
|
||||
global::set_text_map_propagator(TraceContextPropagator::new());
|
||||
|
||||
let provider = SdkTracerProvider::builder().build();
|
||||
let tracer = provider.tracer(tracer_name.to_string());
|
||||
let subscriber =
|
||||
tracing_subscriber::registry().with(tracing_opentelemetry::layer().with_tracer(tracer));
|
||||
|
||||
TestTracingContext {
|
||||
_provider: provider,
|
||||
_guard: subscriber.set_default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn rmcp_client_can_list_and_read_resources() -> anyhow::Result<()> {
|
||||
let client = RmcpClient::new_stdio_client(
|
||||
@@ -149,3 +176,61 @@ async fn rmcp_client_can_list_and_read_resources() -> anyhow::Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
async fn stdio_tool_call_propagates_trace_metadata() -> anyhow::Result<()> {
|
||||
let _trace = install_test_tracing("rmcp-stdio-trace-test");
|
||||
let client = RmcpClient::new_stdio_client(
|
||||
stdio_server_bin()?.into(),
|
||||
Vec::<OsString>::new(),
|
||||
None,
|
||||
&[],
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
client
|
||||
.initialize(
|
||||
init_params(),
|
||||
Some(Duration::from_secs(5)),
|
||||
Box::new(|_, _| {
|
||||
async {
|
||||
Ok(ElicitationResponse {
|
||||
action: ElicitationAction::Accept,
|
||||
content: Some(json!({})),
|
||||
meta: None,
|
||||
})
|
||||
}
|
||||
.boxed()
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let result = async {
|
||||
client
|
||||
.call_tool(
|
||||
"echo".to_string(),
|
||||
Some(json!({ "message": "ping" })),
|
||||
None,
|
||||
Some(Duration::from_secs(5)),
|
||||
)
|
||||
.await
|
||||
}
|
||||
.instrument(tracing::info_span!("rmcp.client.trace_test"))
|
||||
.await?;
|
||||
|
||||
assert_eq!(result.is_error, Some(false));
|
||||
let structured = result.structured_content.expect("structured content");
|
||||
assert_eq!(structured["echo"], json!("ECHOING: ping"));
|
||||
assert_eq!(structured["env"], serde_json::Value::Null);
|
||||
assert!(
|
||||
structured["tracestate"].is_null()
|
||||
|| structured["tracestate"].as_str().is_some_and(str::is_empty)
|
||||
);
|
||||
let traceparent = structured["traceparent"]
|
||||
.as_str()
|
||||
.expect("traceparent should be propagated via request metadata");
|
||||
assert!(traceparent.starts_with("00-"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user