mirror of
https://github.com/openai/codex.git
synced 2026-04-24 06:35:50 +00:00
580 lines
20 KiB
Rust
580 lines
20 KiB
Rust
use super::ConnectionSessionState;
|
|
use super::MessageProcessor;
|
|
use super::MessageProcessorArgs;
|
|
use crate::outgoing_message::ConnectionId;
|
|
use crate::outgoing_message::OutgoingMessageSender;
|
|
use crate::transport::AppServerTransport;
|
|
use anyhow::Result;
|
|
use app_test_support::create_mock_responses_server_repeating_assistant;
|
|
use app_test_support::write_mock_responses_config_toml;
|
|
use codex_app_server_protocol::ClientInfo;
|
|
use codex_app_server_protocol::ClientRequest;
|
|
use codex_app_server_protocol::InitializeCapabilities;
|
|
use codex_app_server_protocol::InitializeParams;
|
|
use codex_app_server_protocol::InitializeResponse;
|
|
use codex_app_server_protocol::JSONRPCRequest;
|
|
use codex_app_server_protocol::RequestId;
|
|
use codex_app_server_protocol::ThreadStartParams;
|
|
use codex_app_server_protocol::ThreadStartResponse;
|
|
use codex_app_server_protocol::TurnStartParams;
|
|
use codex_app_server_protocol::TurnStartResponse;
|
|
use codex_app_server_protocol::UserInput;
|
|
use codex_arg0::Arg0DispatchPaths;
|
|
use codex_core::config::Config;
|
|
use codex_core::config::ConfigBuilder;
|
|
use codex_core::config_loader::CloudRequirementsLoader;
|
|
use codex_core::config_loader::LoaderOverrides;
|
|
use codex_feedback::CodexFeedback;
|
|
use codex_protocol::protocol::SessionSource;
|
|
use codex_protocol::protocol::W3cTraceContext;
|
|
use opentelemetry::global;
|
|
use opentelemetry::trace::SpanId;
|
|
use opentelemetry::trace::SpanKind;
|
|
use opentelemetry::trace::TraceId;
|
|
use opentelemetry::trace::TracerProvider as _;
|
|
use opentelemetry_sdk::propagation::TraceContextPropagator;
|
|
use opentelemetry_sdk::trace::InMemorySpanExporter;
|
|
use opentelemetry_sdk::trace::SdkTracerProvider;
|
|
use opentelemetry_sdk::trace::SpanData;
|
|
use pretty_assertions::assert_eq;
|
|
use std::collections::BTreeMap;
|
|
use std::path::Path;
|
|
use std::sync::Arc;
|
|
use std::sync::OnceLock;
|
|
use tempfile::TempDir;
|
|
use tokio::sync::mpsc;
|
|
use tracing_subscriber::layer::SubscriberExt;
|
|
use wiremock::MockServer;
|
|
|
|
const TEST_CONNECTION_ID: ConnectionId = ConnectionId(7);
|
|
|
|
struct TestTracing {
|
|
exporter: InMemorySpanExporter,
|
|
provider: SdkTracerProvider,
|
|
}
|
|
|
|
struct RemoteTrace {
|
|
trace_id: TraceId,
|
|
parent_span_id: SpanId,
|
|
context: W3cTraceContext,
|
|
}
|
|
|
|
impl RemoteTrace {
|
|
fn new(trace_id: &str, parent_span_id: &str) -> Self {
|
|
let trace_id = TraceId::from_hex(trace_id).expect("trace id");
|
|
let parent_span_id = SpanId::from_hex(parent_span_id).expect("parent span id");
|
|
let context = W3cTraceContext {
|
|
traceparent: Some(format!("00-{trace_id}-{parent_span_id}-01")),
|
|
tracestate: Some("vendor=value".to_string()),
|
|
};
|
|
|
|
Self {
|
|
trace_id,
|
|
parent_span_id,
|
|
context,
|
|
}
|
|
}
|
|
}
|
|
|
|
fn init_test_tracing() -> &'static TestTracing {
|
|
static TEST_TRACING: OnceLock<TestTracing> = OnceLock::new();
|
|
TEST_TRACING.get_or_init(|| {
|
|
let exporter = InMemorySpanExporter::default();
|
|
let provider = SdkTracerProvider::builder()
|
|
.with_simple_exporter(exporter.clone())
|
|
.build();
|
|
let tracer = provider.tracer("codex-app-server-message-processor-tests");
|
|
global::set_text_map_propagator(TraceContextPropagator::new());
|
|
let subscriber =
|
|
tracing_subscriber::registry().with(tracing_opentelemetry::layer().with_tracer(tracer));
|
|
tracing::subscriber::set_global_default(subscriber)
|
|
.expect("global tracing subscriber should only be installed once");
|
|
TestTracing { exporter, provider }
|
|
})
|
|
}
|
|
|
|
fn request_from_client_request(request: ClientRequest) -> JSONRPCRequest {
|
|
serde_json::from_value(serde_json::to_value(request).expect("serialize client request"))
|
|
.expect("client request should convert to JSON-RPC")
|
|
}
|
|
|
|
fn tracing_test_guard() -> &'static tokio::sync::Mutex<()> {
|
|
static GUARD: OnceLock<tokio::sync::Mutex<()>> = OnceLock::new();
|
|
GUARD.get_or_init(|| tokio::sync::Mutex::new(()))
|
|
}
|
|
|
|
struct TracingHarness {
|
|
_server: MockServer,
|
|
_codex_home: TempDir,
|
|
processor: MessageProcessor,
|
|
outgoing_rx: mpsc::Receiver<crate::outgoing_message::OutgoingEnvelope>,
|
|
session: ConnectionSessionState,
|
|
tracing: &'static TestTracing,
|
|
}
|
|
|
|
impl TracingHarness {
|
|
async fn new() -> Result<Self> {
|
|
let server = create_mock_responses_server_repeating_assistant("Done").await;
|
|
let codex_home = TempDir::new()?;
|
|
let config = Arc::new(build_test_config(codex_home.path(), &server.uri()).await?);
|
|
let (processor, outgoing_rx) = build_test_processor(config);
|
|
let tracing = init_test_tracing();
|
|
tracing.exporter.reset();
|
|
tracing::callsite::rebuild_interest_cache();
|
|
let mut harness = Self {
|
|
_server: server,
|
|
_codex_home: codex_home,
|
|
processor,
|
|
outgoing_rx,
|
|
session: ConnectionSessionState::default(),
|
|
tracing,
|
|
};
|
|
|
|
let _: InitializeResponse = harness
|
|
.request(
|
|
ClientRequest::Initialize {
|
|
request_id: RequestId::Integer(1),
|
|
params: InitializeParams {
|
|
client_info: ClientInfo {
|
|
name: "codex-app-server-tests".to_string(),
|
|
title: None,
|
|
version: "0.1.0".to_string(),
|
|
},
|
|
capabilities: Some(InitializeCapabilities {
|
|
experimental_api: true,
|
|
..Default::default()
|
|
}),
|
|
},
|
|
},
|
|
None,
|
|
)
|
|
.await;
|
|
assert!(harness.session.initialized);
|
|
|
|
Ok(harness)
|
|
}
|
|
|
|
fn reset_tracing(&self) {
|
|
self.tracing.exporter.reset();
|
|
}
|
|
|
|
async fn shutdown(self) {
|
|
self.processor.shutdown_threads().await;
|
|
self.processor.drain_background_tasks().await;
|
|
}
|
|
|
|
async fn request<T>(&mut self, request: ClientRequest, trace: Option<W3cTraceContext>) -> T
|
|
where
|
|
T: serde::de::DeserializeOwned,
|
|
{
|
|
let request_id = match request.id() {
|
|
RequestId::Integer(request_id) => *request_id,
|
|
request_id => panic!("expected integer request id in test harness, got {request_id:?}"),
|
|
};
|
|
let mut request = request_from_client_request(request);
|
|
request.trace = trace;
|
|
|
|
self.processor
|
|
.process_request(
|
|
TEST_CONNECTION_ID,
|
|
request,
|
|
AppServerTransport::Stdio,
|
|
&mut self.session,
|
|
)
|
|
.await;
|
|
read_response(&mut self.outgoing_rx, request_id).await
|
|
}
|
|
|
|
async fn start_thread(
|
|
&mut self,
|
|
request_id: i64,
|
|
trace: Option<W3cTraceContext>,
|
|
) -> ThreadStartResponse {
|
|
let response = self
|
|
.request(
|
|
ClientRequest::ThreadStart {
|
|
request_id: RequestId::Integer(request_id),
|
|
params: ThreadStartParams {
|
|
ephemeral: Some(true),
|
|
..ThreadStartParams::default()
|
|
},
|
|
},
|
|
trace,
|
|
)
|
|
.await;
|
|
read_thread_started_notification(&mut self.outgoing_rx).await;
|
|
response
|
|
}
|
|
}
|
|
|
|
async fn build_test_config(codex_home: &Path, server_uri: &str) -> Result<Config> {
|
|
write_mock_responses_config_toml(
|
|
codex_home,
|
|
server_uri,
|
|
&BTreeMap::new(),
|
|
8_192,
|
|
Some(false),
|
|
"mock_provider",
|
|
"compact",
|
|
)?;
|
|
|
|
Ok(ConfigBuilder::default()
|
|
.codex_home(codex_home.to_path_buf())
|
|
.build()
|
|
.await?)
|
|
}
|
|
|
|
fn build_test_processor(
|
|
config: Arc<Config>,
|
|
) -> (
|
|
MessageProcessor,
|
|
mpsc::Receiver<crate::outgoing_message::OutgoingEnvelope>,
|
|
) {
|
|
let (outgoing_tx, outgoing_rx) = mpsc::channel(16);
|
|
let outgoing = Arc::new(OutgoingMessageSender::new(outgoing_tx));
|
|
let processor = MessageProcessor::new(MessageProcessorArgs {
|
|
outgoing,
|
|
arg0_paths: Arg0DispatchPaths::default(),
|
|
config,
|
|
cli_overrides: Vec::new(),
|
|
loader_overrides: LoaderOverrides::default(),
|
|
cloud_requirements: CloudRequirementsLoader::default(),
|
|
feedback: CodexFeedback::new(),
|
|
log_db: None,
|
|
config_warnings: Vec::new(),
|
|
session_source: SessionSource::VSCode,
|
|
enable_codex_api_key_env: false,
|
|
});
|
|
(processor, outgoing_rx)
|
|
}
|
|
|
|
fn span_attr<'a>(span: &'a SpanData, key: &str) -> Option<&'a str> {
|
|
span.attributes
|
|
.iter()
|
|
.find(|kv| kv.key.as_str() == key)
|
|
.and_then(|kv| match &kv.value {
|
|
opentelemetry::Value::String(value) => Some(value.as_str()),
|
|
_ => None,
|
|
})
|
|
}
|
|
|
|
fn find_rpc_span_with_trace<'a>(
|
|
spans: &'a [SpanData],
|
|
kind: SpanKind,
|
|
method: &str,
|
|
trace_id: TraceId,
|
|
) -> &'a SpanData {
|
|
spans
|
|
.iter()
|
|
.find(|span| {
|
|
span.span_kind == kind
|
|
&& span_attr(span, "rpc.system") == Some("jsonrpc")
|
|
&& span_attr(span, "rpc.method") == Some(method)
|
|
&& span.span_context.trace_id() == trace_id
|
|
})
|
|
.unwrap_or_else(|| {
|
|
panic!(
|
|
"missing {kind:?} span for rpc.method={method} trace={trace_id}; exported spans:\n{}",
|
|
format_spans(spans)
|
|
)
|
|
})
|
|
}
|
|
|
|
fn find_span_by_name<'a>(spans: &'a [SpanData], name: &str) -> &'a SpanData {
|
|
spans
|
|
.iter()
|
|
.find(|span| span.name.as_ref() == name)
|
|
.unwrap_or_else(|| {
|
|
panic!(
|
|
"missing span named {name}; exported spans:\n{}",
|
|
format_spans(spans)
|
|
)
|
|
})
|
|
}
|
|
|
|
fn find_span_by_name_with_trace<'a>(
|
|
spans: &'a [SpanData],
|
|
name: &str,
|
|
trace_id: TraceId,
|
|
) -> &'a SpanData {
|
|
spans
|
|
.iter()
|
|
.find(|span| span.name.as_ref() == name && span.span_context.trace_id() == trace_id)
|
|
.unwrap_or_else(|| {
|
|
panic!(
|
|
"missing span named {name} for trace={trace_id}; exported spans:\n{}",
|
|
format_spans(spans)
|
|
)
|
|
})
|
|
}
|
|
|
|
fn format_spans(spans: &[SpanData]) -> String {
|
|
spans
|
|
.iter()
|
|
.map(|span| {
|
|
let rpc_method = span_attr(span, "rpc.method").unwrap_or("-");
|
|
format!(
|
|
"name={} span_id={} kind={:?} parent={} trace={} rpc.method={}",
|
|
span.name,
|
|
span.span_context.span_id(),
|
|
span.span_kind,
|
|
span.parent_span_id,
|
|
span.span_context.trace_id(),
|
|
rpc_method
|
|
)
|
|
})
|
|
.collect::<Vec<_>>()
|
|
.join("\n")
|
|
}
|
|
|
|
fn assert_span_descends_from(spans: &[SpanData], child: &SpanData, ancestor: &SpanData) {
|
|
let ancestor_span_id = ancestor.span_context.span_id();
|
|
let mut parent_span_id = child.parent_span_id;
|
|
while parent_span_id != SpanId::INVALID {
|
|
if parent_span_id == ancestor_span_id {
|
|
return;
|
|
}
|
|
let Some(parent_span) = spans
|
|
.iter()
|
|
.find(|span| span.span_context.span_id() == parent_span_id)
|
|
else {
|
|
break;
|
|
};
|
|
parent_span_id = parent_span.parent_span_id;
|
|
}
|
|
|
|
panic!(
|
|
"span {} does not descend from {}; exported spans:\n{}",
|
|
child.name,
|
|
ancestor.name,
|
|
format_spans(spans)
|
|
);
|
|
}
|
|
|
|
async fn read_response<T: serde::de::DeserializeOwned>(
|
|
outgoing_rx: &mut mpsc::Receiver<crate::outgoing_message::OutgoingEnvelope>,
|
|
request_id: i64,
|
|
) -> T {
|
|
loop {
|
|
let envelope = tokio::time::timeout(std::time::Duration::from_secs(5), outgoing_rx.recv())
|
|
.await
|
|
.expect("timed out waiting for response")
|
|
.expect("outgoing channel closed");
|
|
let crate::outgoing_message::OutgoingEnvelope::ToConnection {
|
|
connection_id,
|
|
message,
|
|
} = envelope
|
|
else {
|
|
continue;
|
|
};
|
|
if connection_id != TEST_CONNECTION_ID {
|
|
continue;
|
|
}
|
|
let crate::outgoing_message::OutgoingMessage::Response(response) = message else {
|
|
continue;
|
|
};
|
|
if response.id != RequestId::Integer(request_id) {
|
|
continue;
|
|
}
|
|
return serde_json::from_value(response.result)
|
|
.expect("response payload should deserialize");
|
|
}
|
|
}
|
|
|
|
async fn read_thread_started_notification(
|
|
outgoing_rx: &mut mpsc::Receiver<crate::outgoing_message::OutgoingEnvelope>,
|
|
) {
|
|
loop {
|
|
let envelope = tokio::time::timeout(std::time::Duration::from_secs(5), outgoing_rx.recv())
|
|
.await
|
|
.expect("timed out waiting for thread/started notification")
|
|
.expect("outgoing channel closed");
|
|
match envelope {
|
|
crate::outgoing_message::OutgoingEnvelope::ToConnection {
|
|
connection_id,
|
|
message,
|
|
} => {
|
|
if connection_id != TEST_CONNECTION_ID {
|
|
continue;
|
|
}
|
|
let crate::outgoing_message::OutgoingMessage::AppServerNotification(notification) =
|
|
message
|
|
else {
|
|
continue;
|
|
};
|
|
if matches!(
|
|
notification,
|
|
codex_app_server_protocol::ServerNotification::ThreadStarted(_)
|
|
) {
|
|
return;
|
|
}
|
|
}
|
|
crate::outgoing_message::OutgoingEnvelope::Broadcast { message } => {
|
|
let crate::outgoing_message::OutgoingMessage::AppServerNotification(notification) =
|
|
message
|
|
else {
|
|
continue;
|
|
};
|
|
if matches!(
|
|
notification,
|
|
codex_app_server_protocol::ServerNotification::ThreadStarted(_)
|
|
) {
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn wait_for_exported_spans<F>(tracing: &TestTracing, predicate: F) -> Vec<SpanData>
|
|
where
|
|
F: Fn(&[SpanData]) -> bool,
|
|
{
|
|
let mut last_spans = Vec::new();
|
|
for _ in 0..200 {
|
|
tokio::task::yield_now().await;
|
|
tracing
|
|
.provider
|
|
.force_flush()
|
|
.expect("force flush should succeed");
|
|
let spans = tracing.exporter.get_finished_spans().expect("span export");
|
|
last_spans = spans.clone();
|
|
if predicate(&spans) {
|
|
return spans;
|
|
}
|
|
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
|
}
|
|
|
|
panic!(
|
|
"timed out waiting for expected exported spans:\n{}",
|
|
format_spans(&last_spans)
|
|
);
|
|
}
|
|
|
|
#[tokio::test(flavor = "current_thread")]
|
|
async fn thread_start_jsonrpc_span_exports_server_span_and_parents_children() -> Result<()> {
|
|
let _guard = tracing_test_guard().lock().await;
|
|
let mut harness = TracingHarness::new().await?;
|
|
|
|
let RemoteTrace {
|
|
trace_id: remote_trace_id,
|
|
context: remote_trace,
|
|
..
|
|
} = RemoteTrace::new("00000000000000000000000000000011", "0000000000000022");
|
|
|
|
let _: ThreadStartResponse = harness.start_thread(2, None).await;
|
|
let untraced_spans = wait_for_exported_spans(harness.tracing, |spans| {
|
|
spans
|
|
.iter()
|
|
.any(|span| span.name.as_ref() == "app_server.thread_start.derive_config")
|
|
})
|
|
.await;
|
|
let untraced_derive_config_span =
|
|
find_span_by_name(&untraced_spans, "app_server.thread_start.derive_config");
|
|
assert_ne!(untraced_derive_config_span.parent_span_id, SpanId::INVALID);
|
|
|
|
harness.reset_tracing();
|
|
|
|
let _: ThreadStartResponse = harness.start_thread(3, Some(remote_trace)).await;
|
|
let spans = wait_for_exported_spans(harness.tracing, |spans| {
|
|
spans.iter().any(|span| {
|
|
span.span_kind == SpanKind::Server
|
|
&& span_attr(span, "rpc.method") == Some("thread/start")
|
|
&& span.span_context.trace_id() == remote_trace_id
|
|
}) && spans
|
|
.iter()
|
|
.any(|span| span.name.as_ref() == "app_server.thread_start.notify_started")
|
|
})
|
|
.await;
|
|
|
|
let derive_config_span = find_span_by_name(&spans, "app_server.thread_start.derive_config");
|
|
let server_request_span =
|
|
find_rpc_span_with_trace(&spans, SpanKind::Server, "thread/start", remote_trace_id);
|
|
let thread_spawn_span = find_span_by_name_with_trace(&spans, "thread_spawn", remote_trace_id);
|
|
let session_init_span = find_span_by_name_with_trace(&spans, "session_init", remote_trace_id);
|
|
assert_eq!(server_request_span.name.as_ref(), "thread/start");
|
|
assert_eq!(server_request_span.span_context.trace_id(), remote_trace_id);
|
|
assert_ne!(server_request_span.span_context.span_id(), SpanId::INVALID);
|
|
assert_eq!(
|
|
derive_config_span.span_context.trace_id(),
|
|
remote_trace_id,
|
|
"thread/start startup spans did not inherit the inbound request trace"
|
|
);
|
|
assert_span_descends_from(&spans, derive_config_span, server_request_span);
|
|
assert_span_descends_from(&spans, thread_spawn_span, server_request_span);
|
|
assert_span_descends_from(&spans, session_init_span, server_request_span);
|
|
|
|
let default_model_span = find_span_by_name(&spans, "get_default_model");
|
|
let session_init_rollout_span = find_span_by_name(&spans, "session_init.rollout");
|
|
assert_span_descends_from(&spans, default_model_span, server_request_span);
|
|
assert_span_descends_from(&spans, session_init_rollout_span, server_request_span);
|
|
harness.shutdown().await;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[tokio::test(flavor = "current_thread")]
|
|
async fn turn_start_jsonrpc_span_parents_core_turn_spans() -> Result<()> {
|
|
let _guard = tracing_test_guard().lock().await;
|
|
let mut harness = TracingHarness::new().await?;
|
|
let thread_start_response = harness.start_thread(2, None).await;
|
|
let thread_id = thread_start_response.thread.id.clone();
|
|
|
|
harness.reset_tracing();
|
|
|
|
let RemoteTrace {
|
|
trace_id: remote_trace_id,
|
|
parent_span_id: remote_parent_span_id,
|
|
context: remote_trace,
|
|
} = RemoteTrace::new("00000000000000000000000000000077", "0000000000000088");
|
|
let _: TurnStartResponse = harness
|
|
.request(
|
|
ClientRequest::TurnStart {
|
|
request_id: RequestId::Integer(3),
|
|
params: TurnStartParams {
|
|
thread_id,
|
|
input: vec![UserInput::Text {
|
|
text: "hello".to_string(),
|
|
text_elements: Vec::new(),
|
|
}],
|
|
cwd: None,
|
|
approval_policy: None,
|
|
sandbox_policy: None,
|
|
model: None,
|
|
service_tier: None,
|
|
effort: None,
|
|
summary: None,
|
|
personality: None,
|
|
output_schema: None,
|
|
collaboration_mode: None,
|
|
},
|
|
},
|
|
Some(remote_trace),
|
|
)
|
|
.await;
|
|
let spans = wait_for_exported_spans(harness.tracing, |spans| {
|
|
spans.iter().any(|span| {
|
|
span.span_kind == SpanKind::Server
|
|
&& span_attr(span, "rpc.method") == Some("turn/start")
|
|
&& span.span_context.trace_id() == remote_trace_id
|
|
}) && spans.iter().any(|span| {
|
|
span.name.as_ref() == "op.dispatch.user_input"
|
|
&& span.span_context.trace_id() == remote_trace_id
|
|
})
|
|
})
|
|
.await;
|
|
|
|
let server_request_span =
|
|
find_rpc_span_with_trace(&spans, SpanKind::Server, "turn/start", remote_trace_id);
|
|
let core_turn_span =
|
|
find_span_by_name_with_trace(&spans, "op.dispatch.user_input", remote_trace_id);
|
|
|
|
assert_eq!(server_request_span.parent_span_id, remote_parent_span_id);
|
|
assert!(server_request_span.parent_span_is_remote);
|
|
assert_eq!(server_request_span.span_context.trace_id(), remote_trace_id);
|
|
assert_span_descends_from(&spans, core_turn_span, server_request_span);
|
|
harness.shutdown().await;
|
|
|
|
Ok(())
|
|
}
|