Compare commits

...

6 Commits

Author SHA1 Message Date
Owen Lin
b168272203 Add SQLite operation telemetry 2026-05-07 19:25:18 -07:00
Owen Lin
9ddf828d4c Add SQLite init and fallback telemetry 2026-05-07 19:24:28 -07:00
starr-openai
a3de5bde6e Add stdio exec-server client transport (#20664)
## Why

Configured environments need to connect to exec-server instances that
are not necessarily already listening on a websocket URL. A
command-backed stdio transport lets Codex start an exec-server process,
speak JSON-RPC over its stdio streams, and clean up that child process
with the client lifetime.

**Stack position:** this is PR 2 of 5. It builds on the server-side
stdio listener from PR 1 and provides the client transport used by later
environment/config PRs.

## What Changed

- Add `ExecServerTransport` variants for websocket URLs and stdio shell
commands.
- Add stdio command connection support for `ExecServerClient`.
- Move websocket/stdio transport setup into `client_transport.rs` so
`client.rs` stays focused on shared JSON-RPC client, session, HTTP, and
notification behavior.
- Tie stdio child process cleanup to the JSON-RPC connection lifetime
with a RAII lifetime guard.
- Keep existing websocket environment behavior by adapting URL-backed
remotes to `ExecServerTransport::WebSocketUrl`.

## Stack

- 1. https://github.com/openai/codex/pull/20663 - Add stdio exec-server
listener
- **2. This PR:** https://github.com/openai/codex/pull/20664 - Add stdio
exec-server client transport
- 3. https://github.com/openai/codex/pull/20665 - Make environment
providers own default selection
- 4. https://github.com/openai/codex/pull/20666 - Add CODEX_HOME
environments TOML provider
- 5. https://github.com/openai/codex/pull/20667 - Load configured
environments from CODEX_HOME

Split from original draft: https://github.com/openai/codex/pull/20508

## Validation

Not run locally; this was split out of the original draft stack and then
refactored to separate transport setup from the base client.

---------

Co-authored-by: Codex <noreply@openai.com>
2026-05-07 23:48:50 +00:00
Zanie Blue
79154e6952 Use --locked in cargo build and lint invocations (#21602)
This ensures CI fails if the committed lockfile is outdated
2026-05-07 23:14:18 +00:00
William Woodruff
893038f77c [codex] Apply a Dependabot cooldown of 7 days (#21599)
This adds 7-day cooldowns to all of our Dependabot ecosystem blocks. Our
Dependabot runs will continue at the same cadence as before, but the
scheduled PRs will not suggest updates that are fewer than 7 days old
themselves. This serves two purposes: to let dependencies "bake" for a
bit in terms of stability before we adopt them, and to give third-party
security services/tooling a chance to detect and revoke malware.

This should have no functional changes/consequences besides how rapidly
we get (non-security) updates. Dependabot security PRs can still be
scheduled and will bypass the cooldown.
2026-05-07 16:07:46 -07:00
bbrown-oai
31b233c7c6 codex-otel: add configurable trace metadata (#21556)
Add Codex config for static trace span attributes and structured W3C
tracestate field upserts. The config flows through OtelSettings so
callers can attach trace metadata without touching every span call site.

Apply span attributes with an SDK span processor so every exported
trace span carries the configured metadata. Model tracestate as nested
member fields so configured keys can be upserted while unrelated
propagated state in the same member is preserved.

Validate configured tracestate before installing provider-global state,
including header-unsafe values the SDK does not reject by itself. This
keeps Codex from propagating malformed trace context from config.

Update the config schema, public docs, and OTLP loopback coverage for
config parsing, span export, propagation, and invalid-header rejection.
2026-05-07 16:06:57 -07:00
52 changed files with 2761 additions and 583 deletions

View File

@@ -6,25 +6,37 @@ updates:
directory: .github/actions/codex
schedule:
interval: weekly
cooldown:
default-days: 7
- package-ecosystem: cargo
directories:
- codex-rs
- codex-rs/*
schedule:
interval: weekly
cooldown:
default-days: 7
- package-ecosystem: devcontainers
directory: /
schedule:
interval: weekly
cooldown:
default-days: 7
- package-ecosystem: docker
directory: codex-cli
schedule:
interval: weekly
cooldown:
default-days: 7
- package-ecosystem: github-actions
directory: /
schedule:
interval: weekly
cooldown:
default-days: 7
- package-ecosystem: rust-toolchain
directory: codex-rs
schedule:
interval: weekly
cooldown:
default-days: 7

View File

@@ -445,7 +445,7 @@ jobs:
cargo chef cook --recipe-path "$RECIPE" --target ${{ matrix.target }} --release
- name: cargo clippy
run: cargo clippy --target ${{ matrix.target }} --tests --profile ${{ matrix.profile }} --timings -- -D warnings
run: cargo clippy --target ${{ matrix.target }} --tests --profile ${{ matrix.profile }} --timings --locked -- -D warnings
- name: Upload Cargo timings (clippy)
if: always()

View File

@@ -75,7 +75,7 @@ jobs:
- name: Cargo build
working-directory: tools/argument-comment-lint
shell: bash
run: cargo build --release --target ${{ matrix.target }}
run: cargo build --release --target ${{ matrix.target }} --locked
- name: Stage artifact
shell: bash

View File

@@ -109,7 +109,7 @@ jobs:
for binary in ${{ matrix.binaries }}; do
build_args+=(--bin "$binary")
done
cargo build --target ${{ matrix.target }} --release --timings "${build_args[@]}"
cargo build --target ${{ matrix.target }} --release --timings --locked "${build_args[@]}"
- name: Upload Cargo timings
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0

View File

@@ -261,7 +261,7 @@ jobs:
run: |
set -euo pipefail
target="${{ matrix.target }}"
cargo build --target "$target" --release --timings --bin bwrap
cargo build --target "$target" --release --timings --locked --bin bwrap
bwrap_path="target/${target}/release/bwrap"
if [[ ! -f "$bwrap_path" ]]; then
@@ -281,7 +281,7 @@ jobs:
build_args+=(--bin "$binary")
done
echo "CARGO_PROFILE_RELEASE_LTO: ${CARGO_PROFILE_RELEASE_LTO}"
cargo build --target ${{ matrix.target }} --release --timings "${build_args[@]}"
cargo build --target ${{ matrix.target }} --release --timings --locked "${build_args[@]}"
- name: Upload Cargo timings
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0

View File

@@ -60,6 +60,7 @@ const PROJECT_LOCAL_CONFIG_DENYLIST: &[&str] = &[
"profile",
"profiles",
"experimental_realtime_ws_base_url",
"otel",
];
async fn first_layer_config_error_from_entries(layers: &[ConfigLayerEntry]) -> Option<ConfigError> {

View File

@@ -514,6 +514,12 @@ pub struct OtelConfigToml {
/// Optional metrics exporter
pub metrics_exporter: Option<OtelExporterKind>,
/// Attributes to add to every exported trace span.
pub span_attributes: Option<BTreeMap<String, String>>,
/// Semicolon-separated `key:value` fields to upsert into W3C tracestate members.
pub tracestate: Option<BTreeMap<String, BTreeMap<String, String>>>,
}
/// Effective OTEL settings after defaults are applied.
@@ -524,6 +530,8 @@ pub struct OtelConfig {
pub exporter: OtelExporterKind,
pub trace_exporter: OtelExporterKind,
pub metrics_exporter: OtelExporterKind,
pub span_attributes: BTreeMap<String, String>,
pub tracestate: BTreeMap<String, BTreeMap<String, String>>,
}
impl Default for OtelConfig {
@@ -534,6 +542,8 @@ impl Default for OtelConfig {
exporter: OtelExporterKind::None,
trace_exporter: OtelExporterKind::None,
metrics_exporter: OtelExporterKind::Statsig,
span_attributes: BTreeMap::new(),
tracestate: BTreeMap::new(),
}
}
}

View File

@@ -1695,6 +1695,13 @@
],
"description": "Optional metrics exporter"
},
"span_attributes": {
"additionalProperties": {
"type": "string"
},
"description": "Attributes to add to every exported trace span.",
"type": "object"
},
"trace_exporter": {
"allOf": [
{
@@ -1702,6 +1709,16 @@
}
],
"description": "Optional trace exporter"
},
"tracestate": {
"additionalProperties": {
"additionalProperties": {
"type": "string"
},
"type": "object"
},
"description": "Semicolon-separated `key:value` fields to upsert into W3C tracestate members.",
"type": "object"
}
},
"type": "object"

View File

@@ -1752,6 +1752,9 @@ notify = ["sh", "-c", "echo attacker"]
profile = "attacker"
experimental_realtime_ws_base_url = "wss://attacker.example/realtime"
[otel]
environment = "attacker"
[profiles.attacker]
model = "attacker-model"
model_instructions_file = 1
@@ -1801,6 +1804,7 @@ wire_api = "responses"
"profile",
"profiles",
"experimental_realtime_ws_base_url",
"otel",
];
let expected_startup_warnings = vec![format!(
concat!(

View File

@@ -44,6 +44,9 @@ use codex_config::types::Notice;
use codex_config::types::NotificationCondition;
use codex_config::types::NotificationMethod;
use codex_config::types::Notifications;
use codex_config::types::OtelConfig;
use codex_config::types::OtelConfigToml;
use codex_config::types::OtelExporterKind;
use codex_config::types::SandboxWorkspaceWrite;
use codex_config::types::SessionPickerViewMode;
use codex_config::types::SkillsConfig;
@@ -7118,6 +7121,119 @@ async fn trace_exporter_defaults_to_none_when_log_exporter_is_set() -> std::io::
Ok(())
}
#[tokio::test]
async fn load_config_applies_otel_trace_metadata() -> std::io::Result<()> {
let mut fixture = create_test_fixture()?;
fixture.cfg = toml::from_str(
r#"
[otel.span_attributes]
"example.trace_attr" = "enabled"
[otel.tracestate.example]
alpha = "one"
beta = "two"
"#,
)
.expect("TOML deserialization should succeed");
let config = Config::load_from_base_config_with_overrides(
fixture.cfg.clone(),
ConfigOverrides {
cwd: Some(fixture.cwd_path()),
..Default::default()
},
fixture.codex_home(),
)
.await?;
assert_eq!(
config.otel.span_attributes,
BTreeMap::from([("example.trace_attr".to_string(), "enabled".to_string())])
);
assert_eq!(
config.otel.tracestate,
BTreeMap::from([(
"example".to_string(),
BTreeMap::from([
("alpha".to_string(), "one".to_string()),
("beta".to_string(), "two".to_string()),
]),
)])
);
Ok(())
}
#[tokio::test]
async fn load_config_drops_invalid_otel_trace_metadata_entries() -> std::io::Result<()> {
let mut fixture = create_test_fixture()?;
fixture.cfg = toml::from_str(
r#"
[otel]
environment = "test"
[otel.span_attributes]
"" = "missing-key"
"example.trace_attr" = "enabled"
[otel.tracestate.example]
alpha = "one"
beta = "two\ntoo"
[otel.tracestate.bad]
alpha = "one\ntwo"
"#,
)
.expect("TOML deserialization should succeed");
let config = Config::load_from_base_config_with_overrides(
fixture.cfg.clone(),
ConfigOverrides {
cwd: Some(fixture.cwd_path()),
..Default::default()
},
fixture.codex_home(),
)
.await?;
assert_eq!(config.otel.environment, "test");
assert_eq!(
config.otel.span_attributes,
BTreeMap::from([("example.trace_attr".to_string(), "enabled".to_string())])
);
assert_eq!(
config.otel.tracestate,
BTreeMap::from([(
"example".to_string(),
BTreeMap::from([("alpha".to_string(), "one".to_string())]),
)])
);
assert!(
config.startup_warnings.iter().any(|warning| {
warning.contains("Ignoring invalid `otel.span_attributes` config")
&& warning.contains("configured span attribute key must not be empty")
}),
"{:?}",
config.startup_warnings
);
assert!(
config.startup_warnings.iter().any(|warning| {
warning.contains("Ignoring invalid `otel.tracestate` config")
&& warning.contains("invalid configured tracestate value for example.beta")
}),
"{:?}",
config.startup_warnings
);
assert!(
config.startup_warnings.iter().any(|warning| {
warning.contains("Ignoring invalid `otel.tracestate` config")
&& warning.contains("invalid configured tracestate value for bad.alpha")
}),
"{:?}",
config.startup_warnings
);
Ok(())
}
#[tokio::test]
async fn explicit_null_service_tier_override_sets_fast_default_opt_out() -> std::io::Result<()> {
let fixture = create_test_fixture()?;

View File

@@ -37,7 +37,6 @@ use codex_config::profile_toml::ConfigProfile;
use codex_config::sandbox_mode_requirement_for_permission_profile;
use codex_config::types::ApprovalsReviewer;
use codex_config::types::AuthCredentialsStoreMode;
use codex_config::types::DEFAULT_OTEL_ENVIRONMENT;
use codex_config::types::History;
use codex_config::types::McpServerConfig;
use codex_config::types::McpServerDisabledReason;
@@ -46,9 +45,6 @@ use codex_config::types::MemoriesConfig;
use codex_config::types::ModelAvailabilityNuxConfig;
use codex_config::types::Notice;
use codex_config::types::OAuthCredentialsStoreMode;
use codex_config::types::OtelConfig;
use codex_config::types::OtelConfigToml;
use codex_config::types::OtelExporterKind;
use codex_config::types::SessionPickerViewMode;
use codex_config::types::ToolSuggestConfig;
use codex_config::types::ToolSuggestDisabledTool;
@@ -132,6 +128,7 @@ pub(crate) mod agent_roles;
pub mod edit;
mod managed_features;
mod network_proxy_spec;
mod otel;
mod permissions;
#[cfg(test)]
mod schema;
@@ -2978,6 +2975,7 @@ impl Config {
.value
.set(effective_permission_profile)
.map_err(std::io::Error::from)?;
let otel = otel::resolve_config(cfg.otel.unwrap_or_default(), &mut startup_warnings);
let config = Self {
model,
service_tier,
@@ -3205,26 +3203,7 @@ impl Config {
.as_ref()
.map(|t| t.keymap.clone())
.unwrap_or_default(),
otel: {
let t: OtelConfigToml = cfg.otel.unwrap_or_default();
let log_user_prompt = t.log_user_prompt.unwrap_or(false);
let environment = t
.environment
.unwrap_or(DEFAULT_OTEL_ENVIRONMENT.to_string());
let exporter = t.exporter.unwrap_or(OtelExporterKind::None);
// OTLP HTTP endpoints are signal-specific in our config, so
// enabling log export must not implicitly send spans to a
// /v1/logs endpoint.
let trace_exporter = t.trace_exporter.unwrap_or(OtelExporterKind::None);
let metrics_exporter = t.metrics_exporter.unwrap_or(OtelExporterKind::Statsig);
OtelConfig {
log_user_prompt,
environment,
exporter,
trace_exporter,
metrics_exporter,
}
},
otel,
};
Ok(config)
})

View File

@@ -0,0 +1,117 @@
use std::collections::BTreeMap;
use std::fmt::Display;
use codex_config::types::DEFAULT_OTEL_ENVIRONMENT;
use codex_config::types::OtelConfig;
use codex_config::types::OtelConfigToml;
use codex_config::types::OtelExporterKind;
/// Turn user-supplied `[otel]` TOML settings into an effective `OtelConfig`,
/// filling in defaults and sanitizing user-editable trace metadata.
///
/// Invalid span attributes and tracestate entries are dropped and reported
/// through `startup_warnings` rather than failing startup, because provider
/// initialization installs process-global OTEL state.
pub(crate) fn resolve_config(
    config: OtelConfigToml,
    startup_warnings: &mut Vec<String>,
) -> OtelConfig {
    // Sanitize trace metadata first; malformed config becomes a startup
    // warning instead of an error (span attributes are reported before
    // tracestate, matching warning order expected by callers/tests).
    let span_attributes = resolve_span_attributes(config.span_attributes, startup_warnings);
    let tracestate = resolve_tracestate(config.tracestate, startup_warnings);

    OtelConfig {
        log_user_prompt: config.log_user_prompt.unwrap_or(false),
        environment: config
            .environment
            .unwrap_or_else(|| DEFAULT_OTEL_ENVIRONMENT.to_string()),
        exporter: config.exporter.unwrap_or(OtelExporterKind::None),
        // OTLP HTTP endpoints are signal-specific in our config, so enabling
        // log export must not implicitly send spans to a /v1/logs endpoint.
        trace_exporter: config.trace_exporter.unwrap_or(OtelExporterKind::None),
        metrics_exporter: config
            .metrics_exporter
            .unwrap_or(OtelExporterKind::Statsig),
        span_attributes,
        tracestate,
    }
}
/// Keep only configured span attributes that pass `codex_otel` validation.
///
/// Each rejected entry produces a startup warning; `None` means nothing was
/// configured and yields an empty map with no validation calls.
fn resolve_span_attributes(
    span_attributes: Option<BTreeMap<String, String>>,
    startup_warnings: &mut Vec<String>,
) -> BTreeMap<String, String> {
    let mut accepted = BTreeMap::new();
    for (key, value) in span_attributes.unwrap_or_default() {
        // Validate one attribute at a time so a single bad entry does not
        // discard the rest of the configured attributes.
        let candidate = BTreeMap::from([(key.clone(), value.clone())]);
        match codex_otel::validate_span_attributes(&candidate) {
            Ok(_) => {
                accepted.insert(key, value);
            }
            Err(err) => {
                push_invalid_config_warning("otel.span_attributes", err, startup_warnings);
            }
        }
    }
    accepted
}
/// Filter configured tracestate members down to the entries that survive
/// per-field, per-member, and whole-header validation.
///
/// Rejected fields/members produce startup warnings. If the combined set of
/// surviving members is itself invalid, the entire tracestate is dropped.
fn resolve_tracestate(
    tracestate: Option<BTreeMap<String, BTreeMap<String, String>>>,
    startup_warnings: &mut Vec<String>,
) -> BTreeMap<String, BTreeMap<String, String>> {
    // Nothing configured: skip all validation work.
    let Some(configured) = tracestate else {
        return BTreeMap::new();
    };

    let mut accepted = BTreeMap::new();
    for (member_key, fields) in configured {
        let surviving = resolve_tracestate_member_fields(&member_key, fields, startup_warnings);
        if surviving.is_empty() {
            // Every field was rejected (or none were given); drop the member.
            continue;
        }
        match codex_otel::validate_tracestate_member(&member_key, &surviving) {
            Ok(_) => {
                accepted.insert(member_key, surviving);
            }
            Err(err) => {
                push_invalid_config_warning("otel.tracestate", err, startup_warnings);
            }
        }
    }

    // Tracestate members can be valid individually while the combined W3C
    // tracestate header is not, so validate the filtered set before handing
    // it to provider initialization.
    match codex_otel::validate_tracestate_entries(&accepted) {
        Ok(_) => accepted,
        Err(err) => {
            push_invalid_config_warning("otel.tracestate", err, startup_warnings);
            BTreeMap::new()
        }
    }
}
/// Validate each field of one tracestate member in isolation, keeping only
/// the fields `codex_otel` accepts and warning about the rest.
fn resolve_tracestate_member_fields(
    member_key: &str,
    fields: BTreeMap<String, String>,
    startup_warnings: &mut Vec<String>,
) -> BTreeMap<String, String> {
    let mut accepted = BTreeMap::new();
    for (field_key, value) in fields {
        // Check each field on its own so one bad value does not take the
        // member's other fields with it.
        let candidate = BTreeMap::from([(field_key.clone(), value.clone())]);
        match codex_otel::validate_tracestate_member(member_key, &candidate) {
            Ok(_) => {
                accepted.insert(field_key, value);
            }
            Err(err) => {
                push_invalid_config_warning("otel.tracestate", err, startup_warnings);
            }
        }
    }
    accepted
}
/// Record a rejected OTEL config entry both in the tracing log and in the
/// user-visible startup warnings.
fn push_invalid_config_warning(
    config_key: &str,
    err: impl Display,
    startup_warnings: &mut Vec<String>,
) {
    let warning = format!("Ignoring invalid `{config_key}` config: {err}");
    // Log first, then surface the same text as a startup warning.
    tracing::warn!("{warning}");
    startup_warnings.push(warning);
}

View File

@@ -80,7 +80,7 @@ pub fn build_provider(
let service_name = service_name_override.unwrap_or(originator.value.as_str());
let runtime_metrics = config.features.enabled(Feature::RuntimeMetrics);
OtelProvider::from(&OtelSettings {
let provider = OtelProvider::from(&OtelSettings {
service_name: service_name.to_string(),
service_version: service_version.to_string(),
codex_home: config.codex_home.to_path_buf(),
@@ -89,7 +89,17 @@ pub fn build_provider(
trace_exporter,
metrics_exporter,
runtime_metrics,
})
span_attributes: config.otel.span_attributes.clone(),
tracestate: config.otel.tracestate.clone(),
})?;
if let Some(provider) = provider.as_ref()
&& let Some(metrics) = provider.metrics()
{
let _ = codex_otel::record_process_start_once(metrics, originator.value.as_str());
}
Ok(provider)
}
/// Filter predicate for exporting only Codex-owned events via OTEL.

View File

@@ -17,13 +17,14 @@ use tokio::sync::mpsc;
use tokio::sync::watch;
use tokio::time::timeout;
use tokio_tungstenite::connect_async;
use tracing::debug;
use crate::ProcessId;
use crate::client_api::ExecServerClientConnectOptions;
use crate::client_api::ExecServerTransportParams;
use crate::client_api::HttpClient;
use crate::client_api::RemoteExecServerConnectArgs;
use crate::client_api::StdioExecServerConnectArgs;
use crate::connection::JsonRpcConnection;
use crate::process::ExecProcessEvent;
use crate::process::ExecProcessEventLog;
@@ -105,6 +106,16 @@ impl From<RemoteExecServerConnectArgs> for ExecServerClientConnectOptions {
}
}
impl From<StdioExecServerConnectArgs> for ExecServerClientConnectOptions {
fn from(value: StdioExecServerConnectArgs) -> Self {
Self {
client_name: value.client_name,
initialize_timeout: value.initialize_timeout,
resume_session_id: value.resume_session_id,
}
}
}
impl RemoteExecServerConnectArgs {
pub fn new(websocket_url: String, client_name: String) -> Self {
Self {
@@ -180,29 +191,25 @@ pub struct ExecServerClient {
#[derive(Clone)]
pub(crate) struct LazyRemoteExecServerClient {
websocket_url: String,
transport_params: ExecServerTransportParams,
client: Arc<OnceCell<ExecServerClient>>,
}
impl LazyRemoteExecServerClient {
pub(crate) fn new(websocket_url: String) -> Self {
pub(crate) fn new(transport_params: ExecServerTransportParams) -> Self {
Self {
websocket_url,
transport_params,
client: Arc::new(OnceCell::new()),
}
}
pub(crate) async fn get(&self) -> Result<ExecServerClient, ExecServerError> {
self.client
.get_or_try_init(|| async {
ExecServerClient::connect_websocket(RemoteExecServerConnectArgs {
websocket_url: self.websocket_url.clone(),
client_name: "codex-environment".to_string(),
connect_timeout: Duration::from_secs(5),
initialize_timeout: Duration::from_secs(5),
resume_session_id: None,
})
.await
// TODO: Add reconnect/disconnect handling here instead of reusing
// the first successfully initialized connection forever.
.get_or_try_init(|| {
let transport_params = self.transport_params.clone();
async move { ExecServerClient::connect_for_transport(transport_params).await }
})
.await
.cloned()
@@ -269,32 +276,6 @@ pub enum ExecServerError {
}
impl ExecServerClient {
pub async fn connect_websocket(
args: RemoteExecServerConnectArgs,
) -> Result<Self, ExecServerError> {
let websocket_url = args.websocket_url.clone();
let connect_timeout = args.connect_timeout;
let (stream, _) = timeout(connect_timeout, connect_async(websocket_url.as_str()))
.await
.map_err(|_| ExecServerError::WebSocketConnectTimeout {
url: websocket_url.clone(),
timeout: connect_timeout,
})?
.map_err(|source| ExecServerError::WebSocketConnect {
url: websocket_url.clone(),
source,
})?;
Self::connect(
JsonRpcConnection::from_websocket(
stream,
format!("exec-server websocket {websocket_url}"),
),
args.into(),
)
.await
}
pub async fn initialize(
&self,
options: ExecServerClientConnectOptions,
@@ -443,7 +424,7 @@ impl ExecServerClient {
.clone()
}
async fn connect(
pub(crate) async fn connect(
connection: JsonRpcConnection,
options: ExecServerClientConnectOptions,
) -> Result<Self, ExecServerError> {
@@ -893,18 +874,30 @@ mod tests {
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCResponse;
use pretty_assertions::assert_eq;
use std::collections::HashMap;
#[cfg(unix)]
use std::path::Path;
#[cfg(unix)]
use std::process::Command;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncWrite;
use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
use tokio::io::duplex;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::time::Duration;
#[cfg(unix)]
use tokio::time::sleep;
use tokio::time::timeout;
use super::ExecServerClient;
use super::ExecServerClientConnectOptions;
use crate::ProcessId;
#[cfg(not(windows))]
use crate::client_api::ExecServerTransportParams;
use crate::client_api::StdioExecServerCommand;
use crate::client_api::StdioExecServerConnectArgs;
use crate::connection::JsonRpcConnection;
use crate::process::ExecProcessEvent;
use crate::protocol::EXEC_CLOSED_METHOD;
@@ -942,6 +935,191 @@ mod tests {
.expect("json-rpc line should write");
}
#[cfg(not(windows))]
#[tokio::test]
async fn connect_stdio_command_initializes_json_rpc_client() {
let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
command: StdioExecServerCommand {
program: "sh".to_string(),
args: vec![
"-c".to_string(),
"read _line; printf '%s\\n' '{\"id\":1,\"result\":{\"sessionId\":\"stdio-test\"}}'; read _line; sleep 60".to_string(),
],
env: HashMap::new(),
cwd: None,
},
client_name: "stdio-test-client".to_string(),
initialize_timeout: Duration::from_secs(1),
resume_session_id: None,
})
.await
.expect("stdio client should connect");
assert_eq!(client.session_id().as_deref(), Some("stdio-test"));
}
#[cfg(not(windows))]
#[tokio::test]
async fn connect_for_transport_initializes_stdio_command() {
let client = ExecServerClient::connect_for_transport(
ExecServerTransportParams::StdioCommand(StdioExecServerCommand {
program: "sh".to_string(),
args: vec![
"-c".to_string(),
"read _line; printf '%s\\n' '{\"id\":1,\"result\":{\"sessionId\":\"stdio-test\"}}'; read _line; sleep 60".to_string(),
],
env: HashMap::new(),
cwd: None,
}),
)
.await
.expect("stdio transport should connect");
assert_eq!(client.session_id().as_deref(), Some("stdio-test"));
}
#[cfg(windows)]
#[tokio::test]
async fn connect_stdio_command_initializes_json_rpc_client_on_windows() {
let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
command: StdioExecServerCommand {
program: "powershell".to_string(),
args: vec![
"-NoProfile".to_string(),
"-Command".to_string(),
"$null = [Console]::In.ReadLine(); [Console]::Out.WriteLine('{\"id\":1,\"result\":{\"sessionId\":\"stdio-test\"}}'); $null = [Console]::In.ReadLine(); Start-Sleep -Seconds 60".to_string(),
],
env: HashMap::new(),
cwd: None,
},
client_name: "stdio-test-client".to_string(),
initialize_timeout: Duration::from_secs(1),
resume_session_id: None,
})
.await
.expect("stdio client should connect");
assert_eq!(client.session_id().as_deref(), Some("stdio-test"));
}
#[cfg(unix)]
#[tokio::test]
async fn dropping_stdio_client_terminates_spawned_process() {
let tempdir = tempfile::tempdir().expect("tempdir should be created");
let pid_file = tempdir.path().join("server.pid");
let child_pid_file = tempdir.path().join("server-child.pid");
let stdio_script = format!(
"read _line; \
echo \"$$\" > {}; \
sleep 60 >/dev/null 2>&1 & echo \"$!\" > {}; \
printf '%s\\n' '{{\"id\":1,\"result\":{{\"sessionId\":\"stdio-test\"}}}}'; \
read _line; \
wait",
shell_quote(pid_file.as_path()),
shell_quote(child_pid_file.as_path()),
);
let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
command: StdioExecServerCommand {
program: "sh".to_string(),
args: vec!["-c".to_string(), stdio_script],
env: HashMap::new(),
cwd: None,
},
client_name: "stdio-test-client".to_string(),
initialize_timeout: Duration::from_secs(1),
resume_session_id: None,
})
.await
.expect("stdio client should connect");
let server_pid = read_pid_file(pid_file.as_path()).await;
let child_pid = read_pid_file(child_pid_file.as_path()).await;
assert!(
process_exists(server_pid),
"spawned stdio process should be running before client drop"
);
assert!(
process_exists(child_pid),
"spawned stdio child process should be running before client drop"
);
drop(client);
wait_for_process_exit(server_pid).await;
wait_for_process_exit(child_pid).await;
}
#[cfg(unix)]
#[tokio::test]
async fn malformed_stdio_message_terminates_spawned_process() {
let tempdir = tempfile::tempdir().expect("tempdir should be created");
let pid_file = tempdir.path().join("server.pid");
let stdio_script = format!(
"read _line; \
echo \"$$\" > {}; \
printf '%s\\n' 'not-json'; \
sleep 60",
shell_quote(pid_file.as_path()),
);
let result = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
command: StdioExecServerCommand {
program: "sh".to_string(),
args: vec!["-c".to_string(), stdio_script],
env: HashMap::new(),
cwd: None,
},
client_name: "stdio-test-client".to_string(),
initialize_timeout: Duration::from_secs(1),
resume_session_id: None,
})
.await;
assert!(result.is_err(), "malformed stdio server should not connect");
let server_pid = read_pid_file(pid_file.as_path()).await;
wait_for_process_exit(server_pid).await;
}
#[cfg(unix)]
// Poll (up to 20 tries, 50ms apart, ~1s total) until the shell script under
// test has written its pid file, then parse and return the pid.
// Panics if the file never appears or does not contain a valid pid.
async fn read_pid_file(path: &Path) -> u32 {
    for _ in 0..20 {
        if let Ok(contents) = std::fs::read_to_string(path) {
            return contents
                .trim()
                .parse()
                .expect("pid file should contain a pid");
        }
        sleep(Duration::from_millis(50)).await;
    }
    panic!("pid file {} should be written", path.display());
}
#[cfg(unix)]
// Poll (up to 20 tries, 100ms apart, ~2s total) until `pid` no longer
// exists. Panics if the process is still alive after the deadline, failing
// the calling test.
async fn wait_for_process_exit(pid: u32) {
    for _ in 0..20 {
        if !process_exists(pid) {
            return;
        }
        sleep(Duration::from_millis(100)).await;
    }
    panic!("process {pid} should exit");
}
#[cfg(unix)]
fn process_exists(pid: u32) -> bool {
Command::new("kill")
.arg("-0")
.arg(pid.to_string())
.status()
.is_ok_and(|status| status.success())
}
#[cfg(unix)]
/// Single-quote `path` for safe interpolation into an `sh -c` script,
/// escaping any embedded single quotes with the `'\''` idiom.
fn shell_quote(path: &Path) -> String {
    let text = path.to_string_lossy();
    let escaped = text.replace('\'', "'\\''");
    format!("'{escaped}'")
}
#[tokio::test]
async fn process_events_are_delivered_in_seq_order_when_notifications_are_reordered() {
let (client_stdin, server_reader) = duplex(1 << 20);
@@ -1085,6 +1263,92 @@ mod tests {
server.await.expect("server task should finish");
}
#[tokio::test]
async fn transport_disconnect_fails_sessions_and_rejects_new_sessions() {
let (client_stdin, server_reader) = duplex(1 << 20);
let (mut server_writer, client_stdout) = duplex(1 << 20);
let (disconnect_tx, disconnect_rx) = oneshot::channel();
let server = tokio::spawn(async move {
let mut lines = BufReader::new(server_reader).lines();
let initialize = read_jsonrpc_line(&mut lines).await;
let request = match initialize {
JSONRPCMessage::Request(request) if request.method == INITIALIZE_METHOD => request,
other => panic!("expected initialize request, got {other:?}"),
};
write_jsonrpc_line(
&mut server_writer,
JSONRPCMessage::Response(JSONRPCResponse {
id: request.id,
result: serde_json::to_value(InitializeResponse {
session_id: "session-1".to_string(),
})
.expect("initialize response should serialize"),
}),
)
.await;
let initialized = read_jsonrpc_line(&mut lines).await;
match initialized {
JSONRPCMessage::Notification(notification)
if notification.method == INITIALIZED_METHOD => {}
other => panic!("expected initialized notification, got {other:?}"),
}
let _ = disconnect_rx.await;
drop(server_writer);
});
let client = ExecServerClient::connect(
JsonRpcConnection::from_stdio(
client_stdout,
client_stdin,
"test-exec-server-client".to_string(),
),
ExecServerClientConnectOptions::default(),
)
.await
.expect("client should connect");
let process_id = ProcessId::from("disconnect");
let session = client
.register_session(&process_id)
.await
.expect("session should register");
let mut events = session.subscribe_events();
disconnect_tx.send(()).expect("disconnect should signal");
let event = timeout(Duration::from_secs(1), events.recv())
.await
.expect("session failure should not time out")
.expect("session event stream should stay open");
let ExecProcessEvent::Failed(message) = event else {
panic!("expected session failure after disconnect, got {event:?}");
};
assert_eq!(message, "exec-server transport disconnected");
let response = session
.read(
/*after_seq*/ None, /*max_bytes*/ None, /*wait_ms*/ None,
)
.await
.expect("disconnected session read should synthesize a response");
assert_eq!(
response.failure.as_deref(),
Some("exec-server transport disconnected")
);
assert!(response.closed);
let new_session = client.register_session(&ProcessId::from("new")).await;
assert!(matches!(
new_session,
Err(super::ExecServerError::Disconnected(_))
));
drop(client);
server.await.expect("server task should finish");
}
#[tokio::test]
async fn wake_notifications_do_not_block_other_sessions() {
let (client_stdin, server_reader) = duplex(1 << 20);

View File

@@ -1,3 +1,5 @@
use std::collections::HashMap;
use std::path::PathBuf;
use std::time::Duration;
use futures::future::BoxFuture;
@@ -25,6 +27,32 @@ pub struct RemoteExecServerConnectArgs {
pub resume_session_id: Option<String>,
}
/// Stdio connection arguments for a command-backed exec-server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct StdioExecServerConnectArgs {
pub command: StdioExecServerCommand,
pub client_name: String,
pub initialize_timeout: Duration,
pub resume_session_id: Option<String>,
}
/// Structured process command used to start an exec-server over stdio.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct StdioExecServerCommand {
pub program: String,
pub args: Vec<String>,
pub env: HashMap<String, String>,
pub cwd: Option<PathBuf>,
}
/// Parameters used to connect to a remote exec-server environment.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum ExecServerTransportParams {
WebSocketUrl(String),
#[allow(dead_code)]
StdioCommand(StdioExecServerCommand),
}
/// Sends HTTP requests through a runtime-selected transport.
///
/// This is the HTTP capability counterpart to [`crate::ExecBackend`]. Callers

View File

@@ -0,0 +1,127 @@
use std::process::Stdio;
use std::time::Duration;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;
use tokio::process::Command;
use tokio::time::timeout;
use tokio_tungstenite::connect_async;
use tracing::debug;
use tracing::warn;
use crate::ExecServerClient;
use crate::ExecServerError;
use crate::client_api::RemoteExecServerConnectArgs;
use crate::client_api::StdioExecServerCommand;
use crate::client_api::StdioExecServerConnectArgs;
use crate::connection::JsonRpcConnection;
const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment";
const ENVIRONMENT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
const ENVIRONMENT_INITIALIZE_TIMEOUT: Duration = Duration::from_secs(5);
impl ExecServerClient {
    /// Connects using environment defaults over the selected transport.
    ///
    /// Both transports share `ENVIRONMENT_CLIENT_NAME` and
    /// `ENVIRONMENT_INITIALIZE_TIMEOUT`; only the websocket path additionally
    /// applies `ENVIRONMENT_CONNECT_TIMEOUT` to the dial.
    pub(crate) async fn connect_for_transport(
        transport_params: crate::client_api::ExecServerTransportParams,
    ) -> Result<Self, ExecServerError> {
        match transport_params {
            crate::client_api::ExecServerTransportParams::WebSocketUrl(websocket_url) => {
                Self::connect_websocket(RemoteExecServerConnectArgs {
                    websocket_url,
                    client_name: ENVIRONMENT_CLIENT_NAME.to_string(),
                    connect_timeout: ENVIRONMENT_CONNECT_TIMEOUT,
                    initialize_timeout: ENVIRONMENT_INITIALIZE_TIMEOUT,
                    resume_session_id: None,
                })
                .await
            }
            crate::client_api::ExecServerTransportParams::StdioCommand(command) => {
                Self::connect_stdio_command(StdioExecServerConnectArgs {
                    command,
                    client_name: ENVIRONMENT_CLIENT_NAME.to_string(),
                    initialize_timeout: ENVIRONMENT_INITIALIZE_TIMEOUT,
                    resume_session_id: None,
                })
                .await
            }
        }
    }

    /// Dials the websocket URL from `args` and runs the shared connect flow.
    ///
    /// # Errors
    /// - [`ExecServerError::WebSocketConnectTimeout`] when the dial exceeds
    ///   `args.connect_timeout`.
    /// - [`ExecServerError::WebSocketConnect`] when the dial itself fails.
    pub async fn connect_websocket(
        args: RemoteExecServerConnectArgs,
    ) -> Result<Self, ExecServerError> {
        let websocket_url = args.websocket_url.clone();
        let connect_timeout = args.connect_timeout;
        let (stream, _) = timeout(connect_timeout, connect_async(websocket_url.as_str()))
            .await
            .map_err(|_| ExecServerError::WebSocketConnectTimeout {
                url: websocket_url.clone(),
                timeout: connect_timeout,
            })?
            .map_err(|source| ExecServerError::WebSocketConnect {
                url: websocket_url.clone(),
                source,
            })?;
        Self::connect(
            JsonRpcConnection::from_websocket(
                stream,
                format!("exec-server websocket {websocket_url}"),
            ),
            args.into(),
        )
        .await
    }

    /// Spawns `args.command` and speaks JSON-RPC over the child's stdio.
    ///
    /// stdin/stdout/stderr are all piped. stderr lines are drained on a
    /// background task and logged (debug for lines, warn on read errors).
    /// The child handle is attached to the connection via
    /// `with_child_process`, tying its cleanup to the connection lifetime.
    ///
    /// # Errors
    /// - [`ExecServerError::Spawn`] when the process cannot be started.
    /// - [`ExecServerError::Protocol`] when the spawned child unexpectedly has
    ///   no stdin or stdout pipe.
    pub(crate) async fn connect_stdio_command(
        args: StdioExecServerConnectArgs,
    ) -> Result<Self, ExecServerError> {
        let mut child = stdio_command_process(&args.command)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()
            .map_err(ExecServerError::Spawn)?;
        let stdin = child.stdin.take().ok_or_else(|| {
            ExecServerError::Protocol("spawned exec-server command has no stdin".to_string())
        })?;
        let stdout = child.stdout.take().ok_or_else(|| {
            ExecServerError::Protocol("spawned exec-server command has no stdout".to_string())
        })?;
        if let Some(stderr) = child.stderr.take() {
            // Forward child stderr into tracing so diagnostics are not lost.
            tokio::spawn(async move {
                let mut lines = BufReader::new(stderr).lines();
                loop {
                    match lines.next_line().await {
                        Ok(Some(line)) => debug!("exec-server stdio stderr: {line}"),
                        Ok(None) => break,
                        Err(err) => {
                            warn!("failed to read exec-server stdio stderr: {err}");
                            break;
                        }
                    }
                }
            });
        }
        Self::connect(
            JsonRpcConnection::from_stdio(stdout, stdin, "exec-server stdio command".to_string())
                .with_child_process(child),
            args.into(),
        )
        .await
    }
}
fn stdio_command_process(stdio_command: &StdioExecServerCommand) -> Command {
let mut command = Command::new(&stdio_command.program);
command.args(&stdio_command.args);
command.envs(&stdio_command.env);
if let Some(cwd) = &stdio_command.cwd {
command.current_dir(cwd);
}
#[cfg(unix)]
command.process_group(0);
command
}

View File

@@ -1,12 +1,21 @@
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::time::Duration;
use codex_app_server_protocol::JSONRPCMessage;
use futures::SinkExt;
use futures::StreamExt;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::process::Child;
use tokio::sync::mpsc;
use tokio::sync::watch;
use tokio::time::timeout;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite::Message;
use tracing::debug;
use tracing::warn;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncWriteExt;
@@ -14,6 +23,7 @@ use tokio::io::BufReader;
use tokio::io::BufWriter;
pub(crate) const CHANNEL_CAPACITY: usize = 128;
const STDIO_TERMINATION_GRACE_PERIOD: Duration = Duration::from_secs(2);
#[derive(Debug)]
pub(crate) enum JsonRpcConnectionEvent {
@@ -22,11 +32,186 @@ pub(crate) enum JsonRpcConnectionEvent {
Disconnected { reason: Option<String> },
}
/// Transport-specific state carried alongside a JSON-RPC connection.
#[derive(Clone)]
pub(crate) enum JsonRpcTransport {
    /// No extra transport-owned resources (e.g. websocket or raw streams).
    Plain,
    /// A supervised child process speaking JSON-RPC over its stdio.
    Stdio { transport: StdioTransport },
}

impl JsonRpcTransport {
    /// Starts supervising `child_process` and wraps it as a stdio transport.
    fn from_child_process(child_process: Child) -> Self {
        Self::Stdio {
            transport: StdioTransport::spawn(child_process),
        }
    }

    /// Requests shutdown of transport-owned resources.
    ///
    /// No-op for `Plain`; for stdio this asks the supervisor task to
    /// terminate the child process.
    pub(crate) fn terminate(&self) {
        match self {
            Self::Plain => {}
            Self::Stdio { transport } => transport.terminate(),
        }
    }
}
/// Cloneable handle to the supervisor of a stdio child process.
///
/// All clones share one [`StdioTransportHandle`]; termination is requested
/// either explicitly via [`StdioTransport::terminate`] or implicitly when the
/// last clone drops the shared handle.
#[derive(Clone)]
pub(crate) struct StdioTransport {
    // Shared so every clone observes the same termination state.
    handle: Arc<StdioTransportHandle>,
}

/// Shared termination state for a supervised stdio child.
struct StdioTransportHandle {
    // Signals the supervisor task to begin shutting the child down.
    terminate_tx: watch::Sender<bool>,
    // Latch ensuring the termination signal is sent at most once.
    terminate_requested: AtomicBool,
}

impl StdioTransport {
    /// Spawns a supervisor task for `child_process` and returns a handle to it.
    fn spawn(child_process: Child) -> Self {
        let (terminate_tx, terminate_rx) = watch::channel(false);
        let handle = Arc::new(StdioTransportHandle {
            terminate_tx,
            terminate_requested: AtomicBool::new(false),
        });
        spawn_stdio_child_supervisor(child_process, terminate_rx);
        Self { handle }
    }

    /// Requests termination of the supervised child (idempotent).
    fn terminate(&self) {
        self.handle.terminate();
    }
}

impl StdioTransportHandle {
    /// Sends the termination signal exactly once, no matter how many callers race.
    fn terminate(&self) {
        // swap returns the previous value, so only the first caller sends.
        if !self.terminate_requested.swap(true, Ordering::AcqRel) {
            let _ = self.terminate_tx.send(true);
        }
    }
}

// Dropping the last clone of the transport also terminates the child, tying
// child-process cleanup to the connection lifetime.
impl Drop for StdioTransportHandle {
    fn drop(&mut self) {
        self.terminate();
    }
}
/// Supervises a stdio child: waits for it to exit on its own, or tears it
/// down when termination is signalled through `terminate_rx`.
fn spawn_stdio_child_supervisor(mut child_process: Child, mut terminate_rx: watch::Receiver<bool>) {
    // The child's pid doubles as its process-group id; the stdio spawn path
    // sets process_group(0) on unix — TODO confirm for other spawn sites.
    let process_group_id = child_process.id();
    tokio::spawn(async move {
        tokio::select! {
            // Child exited on its own: still reap anything left in its tree.
            result = child_process.wait() => {
                log_stdio_child_wait_result(result);
                kill_process_tree(&mut child_process, process_group_id);
            }
            // Termination requested (or all transport handles dropped):
            // attempt a graceful shutdown with a kill fallback.
            () = wait_for_stdio_termination(&mut terminate_rx) => {
                terminate_stdio_child(&mut child_process, process_group_id).await;
            }
        }
    });
}
/// Resolves once termination is signalled, or once the sender side is gone.
async fn wait_for_stdio_termination(terminate_rx: &mut watch::Receiver<bool>) {
    while !*terminate_rx.borrow() {
        // A closed channel also counts as a termination request.
        if terminate_rx.changed().await.is_err() {
            return;
        }
    }
}
/// Gracefully terminates the child's process tree, escalating to a hard kill
/// if it does not exit within the grace period.
async fn terminate_stdio_child(child_process: &mut Child, process_group_id: Option<u32>) {
    terminate_process_tree(child_process, process_group_id);
    // Give the child a bounded window to exit cleanly.
    if let Ok(result) = timeout(STDIO_TERMINATION_GRACE_PERIOD, child_process.wait()).await {
        log_stdio_child_wait_result(result);
        return;
    }
    // Grace period elapsed: force-kill and reap.
    kill_process_tree(child_process, process_group_id);
    log_stdio_child_wait_result(child_process.wait().await);
}
/// Attempts a graceful termination of the child and its process tree.
///
/// Falls back to signalling only the direct child whenever a group id is
/// unavailable or the platform-specific group termination fails.
fn terminate_process_tree(child_process: &mut Child, process_group_id: Option<u32>) {
    let Some(process_group_id) = process_group_id else {
        kill_direct_child(child_process, "terminate");
        return;
    };
    // Terminate the whole group; on failure, at least signal the direct child.
    #[cfg(unix)]
    if let Err(err) = codex_utils_pty::process_group::terminate_process_group(process_group_id) {
        warn!("failed to terminate exec-server stdio process group {process_group_id}: {err}");
        kill_direct_child(child_process, "terminate");
    }
    #[cfg(windows)]
    if !kill_windows_process_tree(process_group_id) {
        kill_direct_child(child_process, "terminate");
    }
    // No process-group support on other platforms: direct child only.
    #[cfg(not(any(unix, windows)))]
    {
        let _ = process_group_id;
        kill_direct_child(child_process, "terminate");
    }
}
/// Forcefully kills the child and, when possible, its whole process tree.
///
/// Mirrors `terminate_process_tree`: when the group kill fails (or no group
/// id is available), fall back to killing the direct child so it is never
/// leaked. The unix branch previously only logged the group-kill failure
/// without the direct-child fallback the other branches have.
fn kill_process_tree(child_process: &mut Child, process_group_id: Option<u32>) {
    let Some(process_group_id) = process_group_id else {
        kill_direct_child(child_process, "kill");
        return;
    };
    #[cfg(unix)]
    if let Err(err) = codex_utils_pty::process_group::kill_process_group(process_group_id) {
        warn!("failed to kill exec-server stdio process group {process_group_id}: {err}");
        // Fall back to the direct child so a failed group kill still reaps it,
        // consistent with terminate_process_tree and the non-unix branches.
        kill_direct_child(child_process, "kill");
    }
    #[cfg(windows)]
    if !kill_windows_process_tree(process_group_id) {
        kill_direct_child(child_process, "kill");
    }
    // No process-group support on other platforms: direct child only.
    #[cfg(not(any(unix, windows)))]
    {
        let _ = process_group_id;
        kill_direct_child(child_process, "kill");
    }
}
/// Sends a kill to just the direct child; failures are logged at debug level
/// (the child may already have exited).
fn kill_direct_child(child_process: &mut Child, action: &str) {
    match child_process.start_kill() {
        Ok(()) => {}
        Err(err) => debug!("failed to {action} exec-server stdio child: {err}"),
    }
}
/// Kills a Windows process tree via `taskkill /T /F`; returns whether the
/// command ran and reported success.
#[cfg(windows)]
fn kill_windows_process_tree(pid: u32) -> bool {
    let pid = pid.to_string();
    std::process::Command::new("taskkill")
        .args(["/PID", pid.as_str(), "/T", "/F"])
        .status()
        .map(|status| status.success())
        .unwrap_or_else(|err| {
            warn!("failed to run taskkill for exec-server stdio process tree {pid}: {err}");
            false
        })
}
/// Logs a failed wait on the stdio child; successful exits are silent.
fn log_stdio_child_wait_result(result: std::io::Result<std::process::ExitStatus>) {
    match result {
        Ok(_) => {}
        Err(err) => debug!("failed to wait for exec-server stdio child: {err}"),
    }
}
pub(crate) struct JsonRpcConnection {
outgoing_tx: mpsc::Sender<JSONRPCMessage>,
incoming_rx: mpsc::Receiver<JsonRpcConnectionEvent>,
disconnected_rx: watch::Receiver<bool>,
task_handles: Vec<tokio::task::JoinHandle<()>>,
pub(crate) outgoing_tx: mpsc::Sender<JSONRPCMessage>,
pub(crate) incoming_rx: mpsc::Receiver<JsonRpcConnectionEvent>,
pub(crate) disconnected_rx: watch::Receiver<bool>,
pub(crate) task_handles: Vec<tokio::task::JoinHandle<()>>,
pub(crate) transport: JsonRpcTransport,
}
impl JsonRpcConnection {
@@ -117,6 +302,7 @@ impl JsonRpcConnection {
incoming_rx,
disconnected_rx,
task_handles: vec![reader_task, writer_task],
transport: JsonRpcTransport::Plain,
}
}
@@ -251,23 +437,13 @@ impl JsonRpcConnection {
incoming_rx,
disconnected_rx,
task_handles: vec![reader_task, writer_task],
transport: JsonRpcTransport::Plain,
}
}
pub(crate) fn into_parts(
self,
) -> (
mpsc::Sender<JSONRPCMessage>,
mpsc::Receiver<JsonRpcConnectionEvent>,
watch::Receiver<bool>,
Vec<tokio::task::JoinHandle<()>>,
) {
(
self.outgoing_tx,
self.incoming_rx,
self.disconnected_rx,
self.task_handles,
)
pub(crate) fn with_child_process(mut self, child_process: Child) -> Self {
self.transport = JsonRpcTransport::from_child_process(child_process);
self
}
}

View File

@@ -7,6 +7,7 @@ use crate::ExecutorFileSystem;
use crate::HttpClient;
use crate::client::LazyRemoteExecServerClient;
use crate::client::http_client::ReqwestHttpClient;
use crate::client_api::ExecServerTransportParams;
use crate::environment_provider::DefaultEnvironmentProvider;
use crate::environment_provider::EnvironmentProvider;
use crate::environment_provider::normalize_exec_server_url;
@@ -274,7 +275,9 @@ impl Environment {
exec_server_url: String,
local_runtime_paths: Option<ExecServerRuntimePaths>,
) -> Self {
let client = LazyRemoteExecServerClient::new(exec_server_url.clone());
let client = LazyRemoteExecServerClient::new(ExecServerTransportParams::WebSocketUrl(
exec_server_url.clone(),
));
let exec_backend: Arc<dyn ExecBackend> = Arc::new(RemoteProcess::new(client.clone()));
let filesystem: Arc<dyn ExecutorFileSystem> =
Arc::new(RemoteFileSystem::new(client.clone()));

View File

@@ -1,5 +1,6 @@
mod client;
mod client_api;
mod client_transport;
mod connection;
mod environment;
mod environment_provider;

View File

@@ -23,6 +23,7 @@ use tokio::task::JoinHandle;
use crate::connection::JsonRpcConnection;
use crate::connection::JsonRpcConnectionEvent;
use crate::connection::JsonRpcTransport;
#[derive(Debug)]
pub(crate) enum RpcCallError {
@@ -58,11 +59,9 @@ pub(crate) enum RpcServerOutboundMessage {
request_id: RequestId,
error: JSONRPCErrorError,
},
#[allow(dead_code)]
Notification(JSONRPCNotification),
}
#[allow(dead_code)]
#[derive(Clone)]
pub(crate) struct RpcNotificationSender {
outgoing_tx: mpsc::Sender<RpcServerOutboundMessage>,
@@ -84,7 +83,6 @@ impl RpcNotificationSender {
.map_err(|_| internal_error("RPC connection closed while sending response".into()))
}
#[allow(dead_code)]
pub(crate) async fn notify<P: Serialize>(
&self,
method: &str,
@@ -229,43 +227,55 @@ pub(crate) struct RpcClient {
disconnected_rx: watch::Receiver<bool>,
next_request_id: AtomicI64,
transport_tasks: Vec<JoinHandle<()>>,
transport: JsonRpcTransport,
reader_task: JoinHandle<()>,
}
impl RpcClient {
pub(crate) fn new(connection: JsonRpcConnection) -> (Self, mpsc::Receiver<RpcClientEvent>) {
let (write_tx, mut incoming_rx, disconnected_rx, transport_tasks) = connection.into_parts();
let JsonRpcConnection {
outgoing_tx: write_tx,
mut incoming_rx,
disconnected_rx,
task_handles: transport_tasks,
transport,
} = connection;
let pending = Arc::new(Mutex::new(HashMap::<RequestId, PendingRequest>::new()));
let (event_tx, event_rx) = mpsc::channel(128);
let pending_for_reader = Arc::clone(&pending);
let transport_for_reader = transport.clone();
let reader_task = tokio::spawn(async move {
while let Some(event) = incoming_rx.recv().await {
let disconnect_reason = loop {
let Some(event) = incoming_rx.recv().await else {
break None;
};
match event {
JsonRpcConnectionEvent::Message(message) => {
if let Err(err) =
handle_server_message(&pending_for_reader, &event_tx, message).await
{
let _ = err;
break;
break None;
}
}
JsonRpcConnectionEvent::MalformedMessage { reason } => {
let _ = reason;
break;
break None;
}
JsonRpcConnectionEvent::Disconnected { reason } => {
let _ = event_tx.send(RpcClientEvent::Disconnected { reason }).await;
drain_pending(&pending_for_reader).await;
return;
break reason;
}
}
}
};
let _ = event_tx
.send(RpcClientEvent::Disconnected { reason: None })
.send(RpcClientEvent::Disconnected {
reason: disconnect_reason,
})
.await;
drain_pending(&pending_for_reader).await;
transport_for_reader.terminate();
});
(
@@ -275,6 +285,7 @@ impl RpcClient {
disconnected_rx,
next_request_id: AtomicI64::new(1),
transport_tasks,
transport,
reader_task,
},
event_rx,
@@ -357,7 +368,6 @@ impl RpcClient {
}
#[cfg(test)]
#[allow(dead_code)]
pub(crate) async fn pending_request_count(&self) -> usize {
self.pending.lock().await.len()
}
@@ -365,6 +375,7 @@ impl RpcClient {
impl Drop for RpcClient {
fn drop(&mut self) {
self.transport.terminate();
for task in &self.transport_tasks {
task.abort();
}
@@ -565,11 +576,9 @@ mod tests {
async fn rpc_client_matches_out_of_order_responses_by_request_id() {
let (client_stdin, server_reader) = tokio::io::duplex(4096);
let (mut server_writer, client_stdout) = tokio::io::duplex(4096);
let (client, _events_rx) = RpcClient::new(JsonRpcConnection::from_stdio(
client_stdout,
client_stdin,
"test-rpc".to_string(),
));
let connection =
JsonRpcConnection::from_stdio(client_stdout, client_stdin, "test-rpc".to_string());
let (client, _events_rx) = RpcClient::new(connection);
let server = tokio::spawn(async move {
let mut lines = BufReader::new(server_reader).lines();

View File

@@ -47,8 +47,13 @@ async fn run_connection(
runtime_paths: ExecServerRuntimePaths,
) {
let router = Arc::new(build_router());
let (json_outgoing_tx, mut incoming_rx, mut disconnected_rx, connection_tasks) =
connection.into_parts();
let JsonRpcConnection {
outgoing_tx: json_outgoing_tx,
mut incoming_rx,
mut disconnected_rx,
task_handles: connection_tasks,
transport: _transport,
} = connection;
let (outgoing_tx, mut outgoing_rx) =
mpsc::channel::<RpcServerOutboundMessage>(CHANNEL_CAPACITY);
let notifications = RpcNotificationSender::new(outgoing_tx.clone());

View File

@@ -200,10 +200,15 @@ pub async fn run_main(
mod tests {
use super::*;
use codex_config::types::OtelExporterKind;
use codex_config::types::OtelHttpProtocol;
use codex_core::config::ConfigBuilder;
use pretty_assertions::assert_eq;
use std::collections::HashMap;
use tempfile::TempDir;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
#[test]
fn mcp_server_defaults_analytics_to_enabled() {
@@ -212,14 +217,21 @@ mod tests {
#[tokio::test]
async fn mcp_server_builds_otel_provider_with_logs_traces_and_metrics() -> anyhow::Result<()> {
let collector = MockServer::start().await;
Mock::given(method("POST"))
.respond_with(ResponseTemplate::new(200))
.mount(&collector)
.await;
let codex_home = TempDir::new()?;
let mut config = ConfigBuilder::default()
.codex_home(codex_home.path().to_path_buf())
.build()
.await?;
let exporter = OtelExporterKind::OtlpGrpc {
endpoint: "http://localhost:4317".to_string(),
let exporter = OtelExporterKind::OtlpHttp {
endpoint: collector.uri(),
headers: HashMap::new(),
protocol: OtelHttpProtocol::Binary,
tls: None,
};
config.otel.exporter = exporter.clone();

View File

@@ -39,6 +39,8 @@ let settings = OtelSettings {
tls: None,
},
metrics_exporter: OtelExporter::None,
span_attributes: std::collections::BTreeMap::new(),
tracestate: std::collections::BTreeMap::new(),
};
if let Some(provider) = OtelProvider::from(&settings)? {
@@ -49,6 +51,26 @@ if let Some(provider) = OtelProvider::from(&settings)? {
}
```
Configured span attributes and W3C tracestate member fields are applied to
exported trace spans and propagated trace context:
```toml
[otel.span_attributes]
"example.trace_attr" = "enabled"
[otel.tracestate.example]
alpha = "one"
beta = "two"
```
Configured tracestate member keys and their encoded values must conform to the W3C tracestate grammar.
Each nested table is encoded as semicolon-separated `key:value` fields inside
that member. If propagated trace context already has the named member, Codex
upserts configured fields and preserves other fields in that member. This
config shape does not support setting opaque tracestate member values. Invalid
trace metadata entries are ignored during config load and reported as startup
warnings.
## SessionTelemetry (events)
`SessionTelemetry` adds consistent metadata to tracing events and helps record

View File

@@ -1,3 +1,4 @@
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::path::PathBuf;
@@ -34,6 +35,18 @@ pub(crate) fn resolve_exporter(exporter: &OtelExporter) -> OtelExporter {
}
}
/// Validates configured span attributes before they are attached to exported spans.
///
/// The only rejected shape is an empty attribute key; values are unrestricted.
pub fn validate_span_attributes(attributes: &BTreeMap<String, String>) -> std::io::Result<()> {
    match attributes.keys().find(|key| key.is_empty()) {
        Some(_) => Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "configured span attribute key must not be empty",
        )),
        None => Ok(()),
    }
}
#[derive(Clone, Debug)]
pub struct OtelSettings {
pub environment: String,
@@ -44,6 +57,8 @@ pub struct OtelSettings {
pub trace_exporter: OtelExporter,
pub metrics_exporter: OtelExporter,
pub runtime_metrics: bool,
pub span_attributes: BTreeMap<String, String>,
pub tracestate: BTreeMap<String, BTreeMap<String, String>>,
}
/// Resolved Statsig metrics settings that another process can use to recreate

View File

@@ -16,6 +16,7 @@ pub use crate::config::OtelHttpProtocol;
pub use crate::config::OtelSettings;
pub use crate::config::OtelTlsConfig;
pub use crate::config::StatsigMetricsSettings;
pub use crate::config::validate_span_attributes;
pub use crate::events::session_telemetry::AuthEnvTelemetryMetadata;
pub use crate::events::session_telemetry::SessionTelemetry;
pub use crate::events::session_telemetry::SessionTelemetryMetadata;
@@ -31,6 +32,8 @@ pub use crate::trace_context::set_parent_from_context;
pub use crate::trace_context::set_parent_from_w3c_trace_context;
pub use crate::trace_context::span_w3c_trace_context;
pub use crate::trace_context::traceparent_context_from_env;
pub use crate::trace_context::validate_tracestate_entries;
pub use crate::trace_context::validate_tracestate_member;
pub use codex_utils_string::sanitize_metric_tag_value;
#[derive(Debug, Clone, Serialize, Display)]

View File

@@ -41,6 +41,7 @@ use std::time::Duration;
use tracing::debug;
const ENV_ATTRIBUTE: &str = "env";
const ARCH_ATTRIBUTE: &str = "arch";
const METER_NAME: &str = "codex";
const DURATION_UNIT: &str = "ms";
const DURATION_DESCRIPTION: &str = "Duration in milliseconds.";
@@ -198,13 +199,13 @@ impl MetricsClient {
validate_tags(&default_tags)?;
let mut resource_attributes = Vec::with_capacity(4);
let mut resource_attributes = Vec::with_capacity(5);
resource_attributes.push(KeyValue::new(
semconv::attribute::SERVICE_VERSION,
service_version,
));
resource_attributes.push(KeyValue::new(ENV_ATTRIBUTE, environment));
resource_attributes.extend(os_resource_attributes());
resource_attributes.extend(platform_resource_attributes());
let resource = Resource::builder()
.with_service_name(service_name)
@@ -290,12 +291,13 @@ impl MetricsClient {
}
}
fn os_resource_attributes() -> Vec<KeyValue> {
fn platform_resource_attributes() -> Vec<KeyValue> {
let os_info = os_info::get();
let os_type_raw = os_info.os_type().to_string();
let os_type = sanitize_metric_tag_value(os_type_raw.as_str());
let os_version_raw = os_info.version().to_string();
let os_version = sanitize_metric_tag_value(os_version_raw.as_str());
let arch = sanitize_metric_tag_value(std::env::consts::ARCH);
let mut attributes = Vec::new();
if os_type != "unspecified" {
attributes.push(KeyValue::new("os", os_type));
@@ -303,6 +305,9 @@ fn os_resource_attributes() -> Vec<KeyValue> {
if os_version != "unspecified" {
attributes.push(KeyValue::new("os_version", os_version));
}
if arch != "unspecified" {
attributes.push(KeyValue::new(ARCH_ATTRIBUTE, arch));
}
attributes
}

View File

@@ -2,6 +2,7 @@ mod client;
mod config;
mod error;
pub(crate) mod names;
mod process;
pub(crate) mod runtime_metrics;
pub(crate) mod tags;
pub(crate) mod timer;
@@ -13,9 +14,12 @@ pub use crate::metrics::config::MetricsConfig;
pub use crate::metrics::config::MetricsExporter;
pub use crate::metrics::error::MetricsError;
pub use crate::metrics::error::Result;
pub use crate::metrics::process::record_process_start_once;
pub use names::*;
use std::sync::OnceLock;
pub use tags::ORIGINATOR_TAG;
pub use tags::SessionMetricTagValues;
pub use tags::bounded_originator_tag_value;
static GLOBAL_METRICS: OnceLock<MetricsClient> = OnceLock::new();
static GLOBAL_STATSIG_METRICS_SETTINGS: OnceLock<StatsigMetricsSettings> = OnceLock::new();

View File

@@ -1,6 +1,7 @@
pub const TOOL_CALL_COUNT_METRIC: &str = "codex.tool.call";
pub const TOOL_CALL_DURATION_METRIC: &str = "codex.tool.call.duration_ms";
pub const TOOL_CALL_UNIFIED_EXEC_METRIC: &str = "codex.tool.unified_exec";
pub const PROCESS_START_METRIC: &str = "codex.process.start";
pub const API_CALL_COUNT_METRIC: &str = "codex.api_request";
pub const API_CALL_DURATION_METRIC: &str = "codex.api_request.duration_ms";
pub const SSE_EVENT_COUNT_METRIC: &str = "codex.sse_event";

View File

@@ -0,0 +1,27 @@
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use super::client::MetricsClient;
use super::error::Result;
use super::names::PROCESS_START_METRIC;
use super::tags::ORIGINATOR_TAG;
use super::tags::bounded_originator_tag_value;
// Process-wide latch: flips to true on the first recording attempt.
static PROCESS_START_RECORDED: AtomicBool = AtomicBool::new(false);

/// Record the process start counter at most once for this process.
///
/// Returns `Ok(true)` only for the call that actually recorded the metric;
/// every later call returns `Ok(false)` without touching the counter.
pub fn record_process_start_once(metrics: &MetricsClient, originator: &str) -> Result<bool> {
    // swap returns the previous value, so exactly one caller observes false.
    if PROCESS_START_RECORDED.swap(true, Ordering::Relaxed) {
        return Ok(false);
    }
    let originator_tag = (ORIGINATOR_TAG, bounded_originator_tag_value(originator));
    metrics.counter(PROCESS_START_METRIC, /*inc*/ 1, &[originator_tag])?;
    Ok(true)
}

View File

@@ -1,6 +1,7 @@
use crate::metrics::Result;
use crate::metrics::validation::validate_tag_key;
use crate::metrics::validation::validate_tag_value;
use codex_utils_string::sanitize_metric_tag_value;
pub const APP_VERSION_TAG: &str = "app.version";
pub const AUTH_MODE_TAG: &str = "auth_mode";
@@ -9,6 +10,24 @@ pub const ORIGINATOR_TAG: &str = "originator";
pub const SERVICE_NAME_TAG: &str = "service_name";
pub const SESSION_SOURCE_TAG: &str = "session_source";
const OTHER_ORIGINATOR_TAG_VALUE: &str = "other";
/// Returns a sanitized, low-cardinality originator value that is safe to use as a metric tag.
///
/// Only values in the closed allow-list pass through; everything else maps to
/// the shared "other" bucket so tag cardinality stays bounded.
pub fn bounded_originator_tag_value(originator: &str) -> &'static str {
    const KNOWN_ORIGINATORS: [&str; 9] = [
        "codex_desktop",
        "codex_cli_rs",
        "codex-tui",
        "codex_vscode",
        "none",
        "codex_exec",
        "codex-cli",
        "codex_sdk_ts",
        "codex-app-server-sdk",
    ];
    let sanitized = sanitize_metric_tag_value(originator);
    KNOWN_ORIGINATORS
        .into_iter()
        .find(|known| *known == sanitized.as_str())
        .unwrap_or(OTHER_ORIGINATOR_TAG_VALUE)
}
pub struct SessionMetricTagValues<'a> {
pub auth_mode: Option<&'a str>,
pub session_source: &'a str,

View File

@@ -7,8 +7,10 @@ use crate::metrics::MetricsConfig;
use crate::targets::is_log_export_target;
use crate::targets::is_trace_safe_target;
use gethostname::gethostname;
use opentelemetry::Context;
use opentelemetry::KeyValue;
use opentelemetry::global;
use opentelemetry::trace::Span as _;
use opentelemetry::trace::TracerProvider as _;
use opentelemetry_appender_tracing::layer::OpenTelemetryTracingBridge;
use opentelemetry_otlp::LogExporter;
@@ -22,15 +24,22 @@ use opentelemetry_otlp::WithTonicConfig;
use opentelemetry_otlp::tonic_types::metadata::MetadataMap;
use opentelemetry_otlp::tonic_types::transport::ClientTlsConfig;
use opentelemetry_sdk::Resource;
use opentelemetry_sdk::error::OTelSdkResult;
use opentelemetry_sdk::logs::SdkLoggerProvider;
use opentelemetry_sdk::propagation::TraceContextPropagator;
use opentelemetry_sdk::runtime;
use opentelemetry_sdk::trace::BatchSpanProcessor;
use opentelemetry_sdk::trace::SdkTracerProvider;
use opentelemetry_sdk::trace::Span;
use opentelemetry_sdk::trace::SpanData;
use opentelemetry_sdk::trace::SpanProcessor;
use opentelemetry_sdk::trace::Tracer;
use opentelemetry_sdk::trace::TracerProviderBuilder;
use opentelemetry_sdk::trace::span_processor_with_async_runtime::BatchSpanProcessor as TokioBatchSpanProcessor;
use opentelemetry_semantic_conventions as semconv;
use std::collections::BTreeMap;
use std::error::Error;
use std::time::Duration;
use tracing::debug;
use tracing_subscriber::Layer;
use tracing_subscriber::registry::LookupSpan;
@@ -68,8 +77,28 @@ impl OtelProvider {
pub fn from(settings: &OtelSettings) -> Result<Option<Self>, Box<dyn Error>> {
let log_enabled = !matches!(settings.exporter, OtelExporter::None);
let trace_enabled = !matches!(settings.trace_exporter, OtelExporter::None);
let metric_exporter = crate::config::resolve_exporter(&settings.metrics_exporter);
let metrics_enabled = !matches!(metric_exporter, OtelExporter::None);
if !log_enabled && !trace_enabled && !metrics_enabled {
// Tracestate propagation is process-global; clear it when these
// settings do not install an active provider.
crate::trace_context::set_tracestate_entries(BTreeMap::new())?;
debug!("No OTEL exporter enabled in settings.");
return Ok(None);
}
// Provider setup below installs process-global OTEL state that cannot
// be rolled back, so reject invalid trace metadata before any setup
// path can mutate those globals.
if trace_enabled && settings.span_attributes.keys().any(String::is_empty) {
return Err(Box::new(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"configured span attribute key must not be empty",
)));
}
crate::trace_context::validate_tracestate_entries(&settings.tracestate)?;
let metrics = if matches!(metric_exporter, OtelExporter::None) {
None
} else {
@@ -85,20 +114,6 @@ impl OtelProvider {
Some(MetricsClient::new(config)?)
};
if let Some(metrics) = metrics.as_ref() {
crate::metrics::install_global(metrics.clone());
if matches!(settings.metrics_exporter, OtelExporter::Statsig) {
crate::metrics::install_global_statsig_settings(StatsigMetricsSettings {
environment: settings.environment.clone(),
});
}
}
if !log_enabled && !trace_enabled && metrics.is_none() {
debug!("No OTEL exporter enabled in settings.");
return Ok(None);
}
let log_resource = make_resource(settings, ResourceKind::Logs);
let trace_resource = make_resource(settings, ResourceKind::Traces);
let logger = log_enabled
@@ -106,17 +121,32 @@ impl OtelProvider {
.transpose()?;
let tracer_provider = trace_enabled
.then(|| build_tracer_provider(&trace_resource, &settings.trace_exporter))
.then(|| {
build_tracer_provider(
&trace_resource,
&settings.trace_exporter,
settings.span_attributes.clone(),
)
})
.transpose()?;
let tracer = tracer_provider
.as_ref()
.map(|provider| provider.tracer(settings.service_name.clone()));
crate::trace_context::set_tracestate_entries(settings.tracestate.clone())?;
if let Some(provider) = tracer_provider.clone() {
global::set_tracer_provider(provider);
global::set_text_map_propagator(TraceContextPropagator::new());
}
if let Some(metrics) = metrics.as_ref() {
crate::metrics::install_global(metrics.clone());
if matches!(settings.metrics_exporter, OtelExporter::Statsig) {
crate::metrics::install_global_statsig_settings(StatsigMetricsSettings {
environment: settings.environment.clone(),
});
}
}
Ok(Some(Self {
logger,
tracer_provider,
@@ -222,6 +252,47 @@ fn normalize_host_name(host_name: &str) -> Option<String> {
(!host_name.is_empty()).then(|| host_name.to_owned())
}
/// Builds the base tracer-provider builder, attaching a span-attribute
/// processor only when attributes are actually configured.
fn tracer_provider_builder(
    resource: &Resource,
    span_attributes: BTreeMap<String, String>,
) -> TracerProviderBuilder {
    let base = SdkTracerProvider::builder().with_resource(resource.clone());
    if !span_attributes.is_empty() {
        return base.with_span_processor(SpanAttributesProcessor {
            attributes: span_attributes,
        });
    }
    base
}
/// Applies configured attributes when spans start.
///
/// Resource attributes describe the provider process. These attributes are
/// per-span metadata, so they need to be attached before each span is exported.
#[derive(Debug)]
struct SpanAttributesProcessor {
    // Attribute key/value pairs copied onto every span at start.
    attributes: BTreeMap<String, String>,
}

impl SpanProcessor for SpanAttributesProcessor {
    // Attach each configured attribute as the span starts so it is present
    // on the span when it is later exported.
    fn on_start(&self, span: &mut Span, _cx: &Context) {
        for (key, value) in self.attributes.iter() {
            span.set_attribute(KeyValue::new(key.clone(), value.clone()));
        }
    }

    // This processor holds no buffered state, so end/flush/shutdown are no-ops.
    fn on_end(&self, _span: SpanData) {}

    fn force_flush(&self) -> OTelSdkResult {
        Ok(())
    }

    fn shutdown_with_timeout(&self, _timeout: Duration) -> OTelSdkResult {
        Ok(())
    }
}
fn build_logger(
resource: &Resource,
exporter: &OtelExporter,
@@ -294,9 +365,10 @@ fn build_logger(
fn build_tracer_provider(
resource: &Resource,
exporter: &OtelExporter,
span_attributes: BTreeMap<String, String>,
) -> Result<SdkTracerProvider, Box<dyn Error>> {
let span_exporter = match crate::config::resolve_exporter(exporter) {
OtelExporter::None => return Ok(SdkTracerProvider::builder().build()),
OtelExporter::None => return Ok(tracer_provider_builder(resource, span_attributes).build()),
OtelExporter::Statsig => unreachable!("statsig exporter should be resolved"),
OtelExporter::OtlpGrpc {
endpoint,
@@ -353,8 +425,7 @@ fn build_tracer_provider(
TokioBatchSpanProcessor::builder(exporter_builder.build()?, runtime::Tokio)
.build();
return Ok(SdkTracerProvider::builder()
.with_resource(resource.clone())
return Ok(tracer_provider_builder(resource, span_attributes)
.with_span_processor(processor)
.build());
}
@@ -382,8 +453,7 @@ fn build_tracer_provider(
let processor = BatchSpanProcessor::builder(span_exporter).build();
Ok(SdkTracerProvider::builder()
.with_resource(resource.clone())
Ok(tracer_provider_builder(resource, span_attributes)
.with_span_processor(processor)
.build())
}
@@ -467,6 +537,8 @@ mod tests {
trace_exporter: OtelExporter::None,
metrics_exporter: OtelExporter::None,
runtime_metrics: false,
span_attributes: BTreeMap::new(),
tracestate: BTreeMap::new(),
}
}
}

View File

@@ -1,11 +1,16 @@
use std::collections::BTreeMap;
use std::collections::BTreeSet;
use std::collections::HashMap;
use std::env;
use std::str::FromStr;
use std::sync::OnceLock;
use std::sync::RwLock;
use codex_protocol::protocol::W3cTraceContext;
use opentelemetry::Context;
use opentelemetry::propagation::TextMapPropagator;
use opentelemetry::trace::TraceContextExt;
use opentelemetry::trace::TraceState;
use opentelemetry_sdk::propagation::TraceContextPropagator;
use tracing::Span;
use tracing::debug;
@@ -16,6 +21,11 @@ const TRACEPARENT_ENV_VAR: &str = "TRACEPARENT";
const TRACESTATE_ENV_VAR: &str = "TRACESTATE";
static TRACEPARENT_CONTEXT: OnceLock<Option<Context>> = OnceLock::new();
// Trace context propagation can happen outside the provider object, so configured
// tracestate lives beside the process-global tracer provider.
static TRACESTATE_ENTRIES: OnceLock<RwLock<BTreeMap<String, BTreeMap<String, String>>>> =
OnceLock::new();
pub fn current_span_w3c_trace_context() -> Option<W3cTraceContext> {
span_w3c_trace_context(&Span::current())
}
@@ -28,13 +38,28 @@ pub fn span_w3c_trace_context(span: &Span) -> Option<W3cTraceContext> {
let mut headers = HashMap::new();
TraceContextPropagator::new().inject_context(&context, &mut headers);
let tracestate = headers.remove("tracestate");
let configured_tracestate_guard = tracestate_entries()
.read()
.unwrap_or_else(std::sync::PoisonError::into_inner);
Some(W3cTraceContext {
traceparent: headers.remove("traceparent"),
tracestate: headers.remove("tracestate"),
tracestate: merge_tracestate_entries(tracestate.as_deref(), &configured_tracestate_guard),
})
}
/// Installs the configured tracestate members used when propagating W3C trace
/// context.
///
/// Entries are validated before installation so malformed configuration is
/// rejected instead of leaking into propagated headers.
pub(crate) fn set_tracestate_entries(
    entries: BTreeMap<String, BTreeMap<String, String>>,
) -> Result<(), Box<dyn std::error::Error>> {
    validate_tracestate_entries(&entries)?;
    // A poisoned lock is recovered rather than propagated: the stored map is
    // replaced wholesale, so a poisoned state cannot be observed afterwards.
    let mut guard = match tracestate_entries().write() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    *guard = entries;
    Ok(())
}
pub fn current_span_trace_id() -> Option<String> {
let context = Span::current().context();
let span = context.span();
@@ -103,6 +128,177 @@ fn load_traceparent_context() -> Option<Context> {
}
}
/// Lazily-initialized process-global storage for configured tracestate members.
fn tracestate_entries() -> &'static RwLock<BTreeMap<String, BTreeMap<String, String>>> {
    TRACESTATE_ENTRIES.get_or_init(Default::default)
}
/// Merges the configured tracestate members into an optional propagated
/// `tracestate` header value, returning `None` when the merged header is empty.
fn merge_tracestate_entries(
    tracestate: Option<&str>,
    configured_entries: &BTreeMap<String, BTreeMap<String, String>>,
) -> Option<String> {
    // Start from the propagated tracestate when it parses; a malformed value
    // is logged and dropped rather than failing propagation.
    let mut trace_state = tracestate
        .and_then(|tracestate| match TraceState::from_str(tracestate) {
            Ok(trace_state) => Some(trace_state),
            Err(err) => {
                warn!("ignoring invalid tracestate while propagating trace context: {err}");
                None
            }
        })
        .unwrap_or_default();
    // TraceState::insert places members at the front. Reverse iteration keeps
    // deterministic map order while upserting fields inside configured members.
    for (key, fields) in configured_entries.iter().rev() {
        let value = merge_tracestate_member_fields(trace_state.get(key), fields);
        trace_state = match trace_state.insert(key.clone(), value) {
            Ok(trace_state) => trace_state,
            Err(err) => {
                // A rejected insert aborts the remaining configured members;
                // the already-merged state is still propagated.
                warn!("ignoring configured tracestate while propagating trace context: {err}");
                break;
            }
        };
    }
    let tracestate = trace_state.header();
    (!tracestate.is_empty()).then_some(tracestate)
}
/// Validates configured tracestate members before they are propagated in W3C trace context.
///
/// # Errors
/// Returns an error when any member key or encoded field value would produce a
/// tracestate header that other W3C Trace Context extractors could reject.
pub fn validate_tracestate_entries(
    entries: &BTreeMap<String, BTreeMap<String, String>>,
) -> Result<(), Box<dyn std::error::Error>> {
    // Reject malformed entries before installing them so propagated trace
    // context remains acceptable to other W3C Trace Context extractors. The
    // SDK validates member keys and list structure, but configured member
    // fields are joined into header values here and need stricter validation.
    let mut encoded = Vec::with_capacity(entries.len());
    for (key, fields) in entries {
        encoded.push(encode_tracestate_member_fields(key, fields)?);
    }
    // Hand the encoded pairs to the SDK for key/list-structure validation.
    let pairs = encoded
        .iter()
        .map(|(key, value)| (key.as_str(), value.as_str()));
    if let Err(err) = TraceState::from_key_value(pairs) {
        return Err(Box::new(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("invalid configured tracestate: {err}"),
        )));
    }
    Ok(())
}
/// Validates one configured tracestate member and its encoded field value.
///
/// # Errors
/// Returns an error when the member key or any field fails encoding, or when
/// the SDK rejects the resulting key/value pair.
pub fn validate_tracestate_member(
    member_key: &str,
    fields: &BTreeMap<String, String>,
) -> Result<(), Box<dyn std::error::Error>> {
    let (key, value) = encode_tracestate_member_fields(member_key, fields)?;
    if let Err(err) = TraceState::from_key_value([(key.as_str(), value.as_str())]) {
        return Err(Box::new(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("invalid configured tracestate: {err}"),
        )));
    }
    Ok(())
}
/// Encodes one configured member's fields into a `(key, value)` pair suitable
/// for a W3C tracestate list entry.
///
/// # Errors
/// Returns an `InvalidInput` error for any field key, field value, or joined
/// member value that is not header-safe.
fn encode_tracestate_member_fields(
    member_key: &str,
    fields: &BTreeMap<String, String>,
) -> Result<(String, String), Box<dyn std::error::Error>> {
    // Configured fields are encoded into one opaque tracestate member value.
    // Validate both the field grammar and the final header value so malformed
    // config cannot produce propagated trace context that downstream W3C
    // extractors reject.
    let mut parts = Vec::with_capacity(fields.len());
    for (field_key, field_value) in fields {
        if !is_configured_tracestate_field_key(field_key) {
            return Err(invalid_tracestate_config(format!(
                "invalid configured tracestate field key {member_key}.{field_key}"
            )));
        }
        if !is_configured_tracestate_field_value(field_value) {
            return Err(invalid_tracestate_config(format!(
                "invalid configured tracestate value for {member_key}.{field_key}"
            )));
        }
        parts.push(format!("{field_key}:{field_value}"));
    }
    let value = parts.join(";");
    if is_header_safe_tracestate_member_value(&value) {
        Ok((member_key.to_string(), value))
    } else {
        Err(invalid_tracestate_config(format!(
            "invalid configured tracestate value for {member_key}"
        )))
    }
}
/// Returns true when `field_key` is usable as a configured tracestate field
/// key: non-empty, visible ASCII only, and free of the `:`/`;`/`,`/`=`
/// delimiter bytes used by the member/field encoding.
fn is_configured_tracestate_field_key(field_key: &str) -> bool {
    if field_key.is_empty() {
        return false;
    }
    field_key
        .bytes()
        .all(|byte| (b'!'..=b'~').contains(&byte) && !matches!(byte, b':' | b';' | b',' | b'='))
}
/// Returns true when `value` may appear as one configured field value: every
/// byte must be valid inside a tracestate member value and must not be the
/// `;` field separator. The empty string is accepted.
fn is_configured_tracestate_field_value(value: &str) -> bool {
    !value
        .bytes()
        .any(|byte| byte == b';' || !is_tracestate_member_value_byte(byte))
}
fn is_header_safe_tracestate_member_value(value: &str) -> bool {
value.is_empty()
|| (value.bytes().all(is_tracestate_member_value_byte)
&& value.as_bytes().last().is_some_and(|byte| *byte != b' '))
}
/// Returns true when `byte` may appear inside a W3C tracestate member value:
/// printable ASCII (space through `~`) except the `,` and `=` list delimiters.
fn is_tracestate_member_value_byte(byte: u8) -> bool {
    (b' '..=b'~').contains(&byte) && byte != b',' && byte != b'='
}
/// Wraps a configuration error `message` as a boxed `InvalidInput` I/O error.
fn invalid_tracestate_config(message: String) -> Box<dyn std::error::Error> {
    let error = std::io::Error::new(std::io::ErrorKind::InvalidInput, message);
    Box::new(error)
}
/// Merges configured fields into one member's existing semicolon-separated
/// value, upserting configured keys in place and appending the rest.
fn merge_tracestate_member_fields(
    existing: Option<&str>,
    configured_fields: &BTreeMap<String, String>,
) -> String {
    // W3C TraceState treats member values as opaque strings. The config models
    // values as semicolon-separated key:value fields so selected fields can be
    // upserted without replacing unrelated fields in the same member.
    let mut merged: Vec<String> = Vec::new();
    let mut handled: BTreeSet<&str> = BTreeSet::new();
    for field in existing
        .unwrap_or_default()
        .split(';')
        .filter(|field| !field.is_empty())
    {
        match field.split_once(':') {
            Some((field_key, _)) => {
                let first_occurrence = handled.insert(field_key);
                match configured_fields.get(field_key) {
                    // Upsert: emit the configured value at the field's first
                    // position and drop later duplicates of the same key.
                    Some(value) if first_occurrence => {
                        merged.push(format!("{field_key}:{value}"));
                    }
                    Some(_) => {}
                    // Unconfigured keyed fields pass through unchanged.
                    None => merged.push(field.to_string()),
                }
            }
            // Fields without a key:value separator are kept verbatim.
            None => merged.push(field.to_string()),
        }
    }
    // Append configured fields that were absent from the existing value.
    merged.extend(
        configured_fields
            .iter()
            .filter(|(field_key, _)| !handled.contains(field_key.as_str()))
            .map(|(field_key, value)| format!("{field_key}:{value}")),
    );
    merged.join(";")
}
#[cfg(test)]
mod tests {
use super::context_from_trace_headers;

View File

@@ -5,18 +5,25 @@ use codex_otel::OtelHttpProtocol;
use codex_otel::OtelProvider;
use codex_otel::OtelSettings;
use codex_otel::Result;
use codex_otel::current_span_w3c_trace_context;
use codex_otel::set_parent_from_w3c_trace_context;
use codex_protocol::protocol::W3cTraceContext;
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::io::Read as _;
use std::io::Write as _;
use std::net::TcpListener;
use std::net::TcpStream;
use std::path::PathBuf;
use std::sync::Mutex;
use std::sync::mpsc;
use std::thread;
use std::time::Duration;
use std::time::Instant;
use tracing_subscriber::layer::SubscriberExt;
static TRACE_CONTEXT_CONFIG_LOCK: Mutex<()> = Mutex::new(());
struct CapturedRequest {
path: String,
content_type: Option<String>,
@@ -217,9 +224,41 @@ fn otlp_http_exporter_sends_metrics_to_collector() -> Result<()> {
Ok(())
}
#[test]
// Verifies that provider construction fails when a configured tracestate
// field value contains a byte that is not header-safe (here, a newline).
fn otel_provider_rejects_header_unsafe_configured_tracestate() {
    let result = OtelProvider::from(&OtelSettings {
        environment: "test".to_string(),
        service_name: "codex-cli".to_string(),
        service_version: env!("CARGO_PKG_VERSION").to_string(),
        codex_home: PathBuf::from("."),
        exporter: OtelExporter::None,
        // A trace exporter is configured, but the endpoint is never contacted:
        // validation is expected to fail before any export happens.
        trace_exporter: OtelExporter::OtlpHttp {
            endpoint: "http://127.0.0.1:1/v1/traces".to_string(),
            headers: HashMap::new(),
            protocol: OtelHttpProtocol::Json,
            tls: None,
        },
        metrics_exporter: OtelExporter::None,
        runtime_metrics: false,
        span_attributes: BTreeMap::new(),
        // The embedded "\n" is outside the tracestate member-value byte range.
        tracestate: BTreeMap::from([(
            "example".to_string(),
            BTreeMap::from([("alpha".to_string(), "one\ntwo".to_string())]),
        )]),
    });
    let Err(err) = result else {
        panic!("expected header-unsafe configured tracestate to be rejected");
    };
    // The error message should identify the offending configured value.
    assert!(err.to_string().contains("configured tracestate value"));
}
#[test]
fn otlp_http_exporter_sends_traces_to_collector()
-> std::result::Result<(), Box<dyn std::error::Error>> {
let _trace_context_config_guard = TRACE_CONTEXT_CONFIG_LOCK
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let listener = TcpListener::bind("127.0.0.1:0").expect("bind");
let addr = listener.local_addr().expect("local_addr");
listener.set_nonblocking(true).expect("set_nonblocking");
@@ -266,12 +305,23 @@ fn otlp_http_exporter_sends_traces_to_collector()
},
metrics_exporter: OtelExporter::None,
runtime_metrics: false,
span_attributes: BTreeMap::from([(
"test.configured_attribute".to_string(),
"configured-value".to_string(),
)]),
tracestate: BTreeMap::from([(
"example".to_string(),
BTreeMap::from([
("alpha".to_string(), "one".to_string()),
("beta".to_string(), "two".to_string()),
]),
)]),
})?
.expect("otel provider");
let tracing_layer = otel.tracing_layer().expect("tracing layer");
let subscriber = tracing_subscriber::registry().with(tracing_layer);
tracing::subscriber::with_default(subscriber, || {
let propagated_trace = tracing::subscriber::with_default(subscriber, || {
let span = tracing::info_span!(
"trace-loopback",
otel.name = "trace-loopback",
@@ -279,11 +329,28 @@ fn otlp_http_exporter_sends_traces_to_collector()
rpc.system = "jsonrpc",
rpc.method = "trace-loopback",
);
assert!(set_parent_from_w3c_trace_context(
&span,
&W3cTraceContext {
traceparent: Some(
"00-00000000000000000000000000000001-0000000000000002-01".to_string(),
),
tracestate: Some("example=alpha:zero;keep:yes,other=value".to_string()),
},
));
let _guard = span.enter();
let propagated_trace =
current_span_w3c_trace_context().expect("current span should have trace context");
tracing::info!("trace loopback event");
propagated_trace
});
otel.shutdown();
assert_eq!(
propagated_trace.tracestate.as_deref(),
Some("example=alpha:one;keep:yes;beta:two,other=value")
);
server.join().expect("server join");
let captured = rx.recv_timeout(Duration::from_secs(1)).expect("captured");
@@ -321,6 +388,11 @@ fn otlp_http_exporter_sends_traces_to_collector()
"expected service name not found; body prefix: {}",
&body.chars().take(2000).collect::<String>()
);
assert!(
body.contains("test.configured_attribute") && body.contains("configured-value"),
"expected configured span attribute not found; body prefix: {}",
&body.chars().take(2000).collect::<String>()
);
Ok(())
}
@@ -328,6 +400,9 @@ fn otlp_http_exporter_sends_traces_to_collector()
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn otlp_http_exporter_sends_traces_to_collector_in_tokio_runtime()
-> std::result::Result<(), Box<dyn std::error::Error>> {
let _trace_context_config_guard = TRACE_CONTEXT_CONFIG_LOCK
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let listener = TcpListener::bind("127.0.0.1:0").expect("bind");
let addr = listener.local_addr().expect("local_addr");
listener.set_nonblocking(true).expect("set_nonblocking");
@@ -374,6 +449,8 @@ async fn otlp_http_exporter_sends_traces_to_collector_in_tokio_runtime()
},
metrics_exporter: OtelExporter::None,
runtime_metrics: false,
span_attributes: BTreeMap::new(),
tracestate: BTreeMap::new(),
})?
.expect("otel provider");
let tracing_layer = otel.tracing_layer().expect("tracing layer");
@@ -436,6 +513,9 @@ async fn otlp_http_exporter_sends_traces_to_collector_in_tokio_runtime()
#[test]
fn otlp_http_exporter_sends_traces_to_collector_in_current_thread_tokio_runtime()
-> std::result::Result<(), Box<dyn std::error::Error>> {
let _trace_context_config_guard = TRACE_CONTEXT_CONFIG_LOCK
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let listener = TcpListener::bind("127.0.0.1:0").expect("bind");
let addr = listener.local_addr().expect("local_addr");
listener.set_nonblocking(true).expect("set_nonblocking");
@@ -490,6 +570,8 @@ fn otlp_http_exporter_sends_traces_to_collector_in_current_thread_tokio_runtime(
},
metrics_exporter: OtelExporter::None,
runtime_metrics: false,
span_attributes: BTreeMap::new(),
tracestate: BTreeMap::new(),
})
.map_err(|err| err.to_string())?
.expect("otel provider");

View File

@@ -10,6 +10,7 @@ pub(crate) mod metadata;
pub(crate) mod policy;
pub(crate) mod recorder;
pub(crate) mod session_index;
pub(crate) mod sqlite_metrics;
pub mod state_db;
pub(crate) mod default_client {

View File

@@ -1279,6 +1279,7 @@ async fn find_thread_path_by_id_str_in_subdir(
tracing::warn!(
"state db discrepancy during find_thread_path_by_id_str_in_subdir: mismatched_db_path"
);
crate::sqlite_metrics::record_fallback("find_thread_path", "mismatch");
}
Err(err) => {
tracing::debug!(
@@ -1296,6 +1297,7 @@ async fn find_thread_path_by_id_str_in_subdir(
tracing::warn!(
"state db discrepancy during find_thread_path_by_id_str_in_subdir: stale_db_path"
);
crate::sqlite_metrics::record_fallback("find_thread_path", "stale_path");
}
}
@@ -1323,6 +1325,12 @@ async fn find_thread_path_by_id_str_in_subdir(
tracing::warn!(
"state db discrepancy during find_thread_path_by_id_str_in_subdir: falling_back"
);
let reason = if state_db_ctx.is_some() {
"missing_row"
} else {
"db_unavailable"
};
crate::sqlite_metrics::record_fallback("find_thread_path", reason);
state_db::read_repair_rollout_path(
state_db_ctx,
thread_id,

View File

@@ -450,6 +450,7 @@ impl RolloutRecorder {
if state_db_ctx.is_none() {
// Keep legacy behavior when SQLite is unavailable: return filesystem results
// at the requested page size.
crate::sqlite_metrics::record_fallback("list_threads", "db_unavailable");
return Ok(page_from_filesystem_scan(
fs_page,
sort_direction,
@@ -569,6 +570,7 @@ impl RolloutRecorder {
}
if listing_has_metadata_filters {
let page = page_from_filesystem_scan(fs_page, sort_direction, page_size, sort_key);
crate::sqlite_metrics::record_fallback("list_threads", "db_error");
return Ok(fill_missing_thread_item_metadata_from_state_db(
state_db_ctx.as_deref(),
page,
@@ -578,6 +580,7 @@ impl RolloutRecorder {
// If SQLite listing still fails, return the filesystem page rather than failing the list.
tracing::error!("Falling back on rollout system");
tracing::warn!("state db discrepancy during list_threads_with_db_fallback: falling_back");
crate::sqlite_metrics::record_fallback("list_threads", "db_error");
Ok(page_from_filesystem_scan(
fs_page,
sort_direction,

View File

@@ -0,0 +1,49 @@
use std::sync::Arc;
use std::time::Duration;
use codex_otel::ORIGINATOR_TAG;
use codex_otel::bounded_originator_tag_value;
use codex_state::DbMetricsRecorder;
use codex_state::DbMetricsRecorderHandle;
use crate::default_client::originator;
/// `DbMetricsRecorder` implementation that forwards SQLite DB metrics to the
/// OTel metrics client, tagging each emission with the originator.
struct OtelDbMetrics {
    // Client used to emit counters and duration measurements.
    metrics: codex_otel::MetricsClient,
    // Bounded originator tag value appended to every metric's tags.
    originator: &'static str,
}
impl DbMetricsRecorder for OtelDbMetrics {
    /// Increments counter `name` by `inc`, with the originator tag appended.
    fn counter(&self, name: &str, inc: i64, tags: &[(&str, &str)]) {
        // Whatever the client returns is discarded: metric emission must not
        // affect the database operation being recorded.
        let tagged = sqlite_originator_tags(tags, self.originator);
        let _ = self.metrics.counter(name, inc, &tagged);
    }

    /// Records a duration measurement for `name`, with the originator tag appended.
    fn record_duration(&self, name: &str, duration: Duration, tags: &[(&str, &str)]) {
        let tagged = sqlite_originator_tags(tags, self.originator);
        let _ = self.metrics.record_duration(name, duration, &tagged);
    }
}
/// Builds a DB metrics recorder backed by the process-global OTel metrics
/// client, or returns `None` when no global client is installed.
pub(crate) fn global() -> Option<DbMetricsRecorderHandle> {
    let metrics = codex_otel::global()?;
    let recorder = OtelDbMetrics {
        metrics,
        // Tag value is bounded once here so every emission reuses it.
        originator: bounded_originator_tag_value(originator().value.as_str()),
    };
    Some(Arc::new(recorder) as DbMetricsRecorderHandle)
}
/// Records one SQLite-to-filesystem fallback event for `caller` with `reason`,
/// using the process-global metrics recorder when one is available.
pub(crate) fn record_fallback(caller: &'static str, reason: &'static str) {
    codex_state::record_db_fallback_metric(global().as_deref(), caller, reason);
}
/// Copies `tags` and appends the originator tag as the final entry.
fn sqlite_originator_tags<'a>(
    tags: &[(&'a str, &'a str)],
    originator: &'static str,
) -> Vec<(&'a str, &'a str)> {
    // Reserve once for the copy plus the appended originator pair.
    let mut tagged = Vec::with_capacity(tags.len() + 1);
    tagged.extend_from_slice(tags);
    tagged.push((ORIGINATOR_TAG, originator));
    tagged
}

View File

@@ -4,6 +4,7 @@ use crate::list::Cursor;
use crate::list::SortDirection;
use crate::list::ThreadSortKey;
use crate::metadata;
use crate::sqlite_metrics;
use chrono::DateTime;
use chrono::Utc;
use codex_protocol::ThreadId;
@@ -106,52 +107,80 @@ async fn try_init_with_roots_inner(
default_model_provider_id: String,
backfill_lease_seconds: Option<i64>,
) -> anyhow::Result<StateDbHandle> {
let runtime =
codex_state::StateRuntime::init(sqlite_home.clone(), default_model_provider_id.clone())
.await
.map_err(|err| {
anyhow::anyhow!(
"failed to initialize state runtime at {}: {err}",
sqlite_home.display()
)
})?;
let metrics = sqlite_metrics::global();
let runtime = codex_state::StateRuntime::init_with_metrics(
sqlite_home.clone(),
default_model_provider_id.clone(),
metrics.clone(),
)
.await
.map_err(|err| {
anyhow::anyhow!(
"failed to initialize state runtime at {}: {err}",
sqlite_home.display()
)
})?;
let backfill_gate_started = Instant::now();
let backfill_gate_result = wait_for_startup_backfill(
runtime.as_ref(),
codex_home.as_path(),
default_model_provider_id.as_str(),
backfill_lease_seconds,
)
.await;
codex_state::record_db_init_backfill_gate_metric(
metrics.as_deref(),
backfill_gate_started.elapsed(),
&backfill_gate_result,
);
backfill_gate_result?;
Ok(runtime)
}
async fn wait_for_startup_backfill(
runtime: &codex_state::StateRuntime,
codex_home: &Path,
default_model_provider_id: &str,
backfill_lease_seconds: Option<i64>,
) -> anyhow::Result<()> {
let wait_started = Instant::now();
let mut reported_wait = false;
loop {
let backfill_state = runtime.get_backfill_state().await.map_err(|err| {
anyhow::anyhow!(
"failed to read backfill state at {}: {err}",
codex_home.display()
)
})?;
let backfill_state = match runtime.get_backfill_state().await {
Ok(state) => state,
Err(err) => {
return Err(anyhow::anyhow!(
"failed to read backfill state at {}: {err}",
codex_home.display()
));
}
};
if backfill_state.status == codex_state::BackfillStatus::Complete {
return Ok(runtime);
return Ok(());
}
if let Some(backfill_lease_seconds) = backfill_lease_seconds {
metadata::backfill_sessions_with_lease(
runtime.as_ref(),
codex_home.as_path(),
default_model_provider_id.as_str(),
runtime,
codex_home,
default_model_provider_id,
backfill_lease_seconds,
)
.await;
} else {
metadata::backfill_sessions(
runtime.as_ref(),
codex_home.as_path(),
default_model_provider_id.as_str(),
)
.await;
metadata::backfill_sessions(runtime, codex_home, default_model_provider_id).await;
}
let backfill_state = runtime.get_backfill_state().await.map_err(|err| {
anyhow::anyhow!(
"failed to read backfill state at {} after startup backfill: {err}",
codex_home.display()
)
})?;
let backfill_state = match runtime.get_backfill_state().await {
Ok(state) => state,
Err(err) => {
return Err(anyhow::anyhow!(
"failed to read backfill state at {} after startup backfill: {err}",
codex_home.display()
));
}
};
if backfill_state.status == codex_state::BackfillStatus::Complete {
return Ok(runtime);
return Ok(());
}
if wait_started.elapsed() >= STARTUP_BACKFILL_WAIT_TIMEOUT {
return Err(anyhow::anyhow!(
@@ -193,22 +222,36 @@ fn emit_startup_warning(message: &str) {
/// Unlike [`init`], this helper does not run rollout backfill. It is for
/// optional local reads from non-owning contexts such as remote app-server mode.
pub async fn get_state_db(config: &impl RolloutConfigView) -> Option<StateDbHandle> {
let metrics = sqlite_metrics::global();
let state_path = codex_state::state_db_path(config.sqlite_home());
if !tokio::fs::try_exists(&state_path).await.unwrap_or(false) {
codex_state::record_db_fallback_metric(
metrics.as_deref(),
"get_state_db",
"db_unavailable",
);
return None;
}
let runtime = codex_state::StateRuntime::init(
let runtime = match codex_state::StateRuntime::init_with_metrics(
config.sqlite_home().to_path_buf(),
config.model_provider_id().to_string(),
metrics.clone(),
)
.await
.ok()?;
require_backfill_complete(runtime, config.sqlite_home()).await
{
Ok(runtime) => runtime,
Err(_) => {
codex_state::record_db_fallback_metric(metrics.as_deref(), "get_state_db", "db_error");
return None;
}
};
require_backfill_complete(runtime, config.sqlite_home(), metrics.as_deref()).await
}
async fn require_backfill_complete(
runtime: StateDbHandle,
codex_home: &Path,
metrics: Option<&dyn codex_state::DbMetricsRecorder>,
) -> Option<StateDbHandle> {
match runtime.get_backfill_state().await {
Ok(state) if state.status == codex_state::BackfillStatus::Complete => Some(runtime),
@@ -218,6 +261,7 @@ async fn require_backfill_complete(
codex_home.display(),
state.status.as_str()
);
codex_state::record_db_fallback_metric(metrics, "get_state_db", "backfill_incomplete");
None
}
Err(err) => {
@@ -225,6 +269,7 @@ async fn require_backfill_complete(
"failed to read backfill state at {}: {err}",
codex_home.display()
);
codex_state::record_db_fallback_metric(metrics, "get_state_db", "db_error");
None
}
}

View File

@@ -10,12 +10,13 @@ mod migrations;
mod model;
mod paths;
mod runtime;
mod telemetry;
pub use model::LogEntry;
pub use model::LogQuery;
pub use model::LogRow;
pub use model::Phase2JobClaimOutcome;
/// Preferred entrypoint: owns configuration and metrics.
/// Preferred entrypoint: owns SQLite configuration and optional metrics injection.
pub use runtime::StateRuntime;
/// Low-level storage engine: useful for focused tests.
@@ -56,6 +57,8 @@ pub use runtime::logs_db_filename;
pub use runtime::logs_db_path;
pub use runtime::state_db_filename;
pub use runtime::state_db_path;
pub use telemetry::DbMetricsRecorder;
pub use telemetry::DbMetricsRecorderHandle;
/// Environment variable for overriding the SQLite state database home directory.
pub const SQLITE_HOME_ENV: &str = "CODEX_SQLITE_HOME";
@@ -71,3 +74,31 @@ pub const DB_ERROR_METRIC: &str = "codex.db.error";
pub const DB_METRIC_BACKFILL: &str = "codex.db.backfill";
/// Metrics on backfill duration. Tags: [status]
pub const DB_METRIC_BACKFILL_DURATION_MS: &str = "codex.db.backfill.duration_ms";
/// SQLite startup initialization attempts. Tags: [status, phase, db, error]
pub const DB_INIT_METRIC: &str = "codex.sqlite.init.count";
/// SQLite startup initialization duration. Tags: [status, phase, db, error]
pub const DB_INIT_DURATION_METRIC: &str = "codex.sqlite.init.duration_ms";
/// SQLite logical operation attempts. Tags: [status, db, operation, access, error]
pub const DB_OPERATION_METRIC: &str = "codex.sqlite.operation.count";
/// SQLite logical operation duration. Tags: [status, db, operation, access, error]
pub const DB_OPERATION_DURATION_METRIC: &str = "codex.sqlite.operation.duration_ms";
/// Filesystem fallback after SQLite could not serve a request. Tags: [caller, reason]
pub const DB_FALLBACK_METRIC: &str = "codex.sqlite.fallback.count";
/// SQLite log queue loss or flush failure. Tags: [event, reason]
pub const DB_LOG_QUEUE_METRIC: &str = "codex.sqlite.log_queue.count";
pub fn record_db_fallback_metric(
metrics: Option<&dyn DbMetricsRecorder>,
caller: &'static str,
reason: &'static str,
) {
telemetry::record_fallback(metrics, caller, reason);
}
pub fn record_db_init_backfill_gate_metric(
metrics: Option<&dyn DbMetricsRecorder>,
duration: std::time::Duration,
result: &anyhow::Result<()>,
) {
telemetry::record_init_backfill_gate(metrics, duration, result);
}

View File

@@ -43,6 +43,7 @@ use uuid::Uuid;
use crate::LogEntry;
use crate::StateRuntime;
use crate::telemetry;
const LOG_QUEUE_CAPACITY: usize = 512;
const LOG_BATCH_SIZE: usize = 128;
@@ -94,6 +95,7 @@ where
pub struct LogDbLayer {
sender: mpsc::Sender<LogDbCommand>,
process_uuid: String,
metrics: Option<crate::DbMetricsRecorderHandle>,
}
pub fn start(state_db: std::sync::Arc<StateRuntime>) -> LogDbLayer {
@@ -105,6 +107,7 @@ impl Clone for LogDbLayer {
Self {
sender: self.sender.clone(),
process_uuid: self.process_uuid.clone(),
metrics: self.metrics.clone(),
}
}
}
@@ -120,22 +123,34 @@ impl LogDbLayer {
) -> Self {
let config = config.normalized();
let (sender, receiver) = mpsc::channel(config.queue_capacity);
let metrics = state_db.metrics_handle();
tokio::spawn(run_inserter(state_db, receiver, config));
Self {
sender,
process_uuid: current_process_log_uuid().to_string(),
metrics,
}
}
pub async fn flush(&self) {
let (tx, rx) = oneshot::channel();
if self.sender.send(LogDbCommand::Flush(tx)).await.is_ok() {
let _ = rx.await;
if self.sender.send(LogDbCommand::Flush(tx)).await.is_err() {
telemetry::record_log_queue(self.metrics.as_deref(), "flush_failed", "closed");
return;
}
if rx.await.is_err() {
telemetry::record_log_queue(self.metrics.as_deref(), "flush_failed", "closed");
}
}
fn try_send(&self, entry: LogEntry) {
let _ = self.sender.try_send(LogDbCommand::Entry(Box::new(entry)));
if let Err(err) = self.sender.try_send(LogDbCommand::Entry(Box::new(entry))) {
let reason = match err {
mpsc::error::TrySendError::Full(_) => "full",
mpsc::error::TrySendError::Closed(_) => "closed",
};
telemetry::record_log_queue(self.metrics.as_deref(), "dropped", reason);
}
}
}
@@ -401,7 +416,9 @@ async fn flush(state_db: &StateRuntime, buffer: &mut Vec<LogEntry>) {
return;
}
let entries = buffer.split_off(0);
let _ = state_db.insert_logs(entries.as_slice()).await;
if state_db.insert_logs(entries.as_slice()).await.is_err() {
telemetry::record_log_queue(state_db.metrics(), "flush_failed", "insert_failed");
}
}
#[derive(Default)]
@@ -721,6 +738,7 @@ mod tests {
let layer = LogDbLayer {
sender,
process_uuid: "process-1".to_string(),
metrics: None,
};
layer.try_send(test_entry("first-queued-log"));
@@ -741,6 +759,7 @@ mod tests {
let layer = LogDbLayer {
sender,
process_uuid: "process-1".to_string(),
metrics: None,
};
layer.try_send(test_entry("queued-before-flush"));

View File

@@ -27,6 +27,10 @@ use crate::model::datetime_to_epoch_millis;
use crate::model::datetime_to_epoch_seconds;
use crate::model::epoch_millis_to_datetime;
use crate::paths::file_modified_time_utc;
use crate::telemetry::DbAccess;
use crate::telemetry::DbKind;
use crate::telemetry::DbMetricsRecorder;
use crate::telemetry::DbMetricsRecorderHandle;
use chrono::DateTime;
use chrono::Utc;
use codex_protocol::ThreadId;
@@ -52,10 +56,12 @@ use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::AtomicI64;
use std::time::Duration;
use std::time::Instant;
use tracing::warn;
mod agent_jobs;
mod backfill;
mod db;
mod goals;
mod logs;
mod memories;
@@ -70,6 +76,9 @@ pub use goals::ThreadGoalUpdate;
pub use remote_control::RemoteControlEnrollmentRecord;
pub use threads::ThreadFilterOptions;
use db::DbOperation;
use db::InstrumentedDb;
// "Partition" is the retained-log-content bucket we cap at 10 MiB:
// - one bucket per non-null thread_id
// - one bucket per threadless (thread_id IS NULL) non-null process_uuid
@@ -83,8 +92,8 @@ const LOG_PARTITION_ROW_LIMIT: i64 = 1_000;
pub struct StateRuntime {
codex_home: PathBuf,
default_provider: String,
pool: Arc<sqlx::SqlitePool>,
logs_pool: Arc<sqlx::SqlitePool>,
state_db: InstrumentedDb,
logs_db: InstrumentedDb,
thread_updated_at_millis: Arc<AtomicI64>,
}
@@ -93,8 +102,18 @@ impl StateRuntime {
///
/// This opens (and migrates) the SQLite databases under `codex_home`,
/// keeping logs in a dedicated file to reduce lock contention with the
/// rest of the state store.
/// rest of the state store. Use [`Self::init_with_metrics`] when the caller
/// has a metrics sink to attach.
pub async fn init(codex_home: PathBuf, default_provider: String) -> anyhow::Result<Arc<Self>> {
Self::init_with_metrics(codex_home, default_provider, /*metrics*/ None).await
}
/// Initialize the state runtime with an explicit metrics client.
pub async fn init_with_metrics(
codex_home: PathBuf,
default_provider: String,
metrics: Option<DbMetricsRecorderHandle>,
) -> anyhow::Result<Arc<Self>> {
tokio::fs::create_dir_all(&codex_home).await?;
let state_migrator = runtime_state_migrator();
let logs_migrator = runtime_logs_migrator();
@@ -116,28 +135,49 @@ impl StateRuntime {
.await;
let state_path = state_db_path(codex_home.as_path());
let logs_path = logs_db_path(codex_home.as_path());
let pool = match open_state_sqlite(&state_path, &state_migrator).await {
let pool = match open_state_sqlite(&state_path, &state_migrator, metrics.as_deref()).await {
Ok(db) => Arc::new(db),
Err(err) => {
warn!("failed to open state db at {}: {err}", state_path.display());
return Err(err);
}
};
let logs_pool = match open_logs_sqlite(&logs_path, &logs_migrator).await {
let logs_pool = match open_logs_sqlite(&logs_path, &logs_migrator, metrics.as_deref()).await
{
Ok(db) => Arc::new(db),
Err(err) => {
warn!("failed to open logs db at {}: {err}", logs_path.display());
return Err(err);
}
};
let thread_updated_at_millis: Option<i64> =
let started = Instant::now();
let backfill_state_result = ensure_backfill_state_row_in_pool(pool.as_ref()).await;
crate::telemetry::record_init_result(
metrics.as_deref(),
DbKind::State,
"ensure_backfill_state",
started.elapsed(),
&backfill_state_result,
);
backfill_state_result?;
let started = Instant::now();
let thread_updated_at_millis_result: anyhow::Result<Option<i64>> =
sqlx::query_scalar("SELECT MAX(threads.updated_at_ms) FROM threads")
.fetch_one(pool.as_ref())
.await?;
.await
.map_err(anyhow::Error::from);
crate::telemetry::record_init_result(
metrics.as_deref(),
DbKind::State,
"post_init_query",
started.elapsed(),
&thread_updated_at_millis_result,
);
let thread_updated_at_millis = thread_updated_at_millis_result?;
let thread_updated_at_millis = thread_updated_at_millis.unwrap_or(0);
let runtime = Arc::new(Self {
pool,
logs_pool,
state_db: InstrumentedDb::new(pool, DbKind::State, metrics.clone()),
logs_db: InstrumentedDb::new(logs_pool, DbKind::Logs, metrics),
codex_home,
default_provider,
thread_updated_at_millis: Arc::new(AtomicI64::new(thread_updated_at_millis)),
@@ -155,6 +195,14 @@ impl StateRuntime {
pub fn codex_home(&self) -> &Path {
self.codex_home.as_path()
}
pub(crate) fn metrics(&self) -> Option<&dyn DbMetricsRecorder> {
self.state_db.metrics()
}
pub(crate) fn metrics_handle(&self) -> Option<DbMetricsRecorderHandle> {
self.state_db.metrics_handle()
}
}
fn base_sqlite_options(path: &Path) -> SqliteConnectOptions {
@@ -165,29 +213,90 @@ fn base_sqlite_options(path: &Path) -> SqliteConnectOptions {
.synchronous(SqliteSynchronous::Normal)
.busy_timeout(Duration::from_secs(5))
.log_statements(LevelFilter::Off)
.log_slow_statements(LevelFilter::Warn, Duration::from_millis(250))
}
async fn open_state_sqlite(path: &Path, migrator: &Migrator) -> anyhow::Result<SqlitePool> {
/// Opens (and migrates) the state database at `path`, recording init metrics
/// for the "open_state" and "migrate_state" phases when `metrics` is present.
async fn open_state_sqlite(
    path: &Path,
    migrator: &Migrator,
    metrics: Option<&dyn DbMetricsRecorder>,
) -> anyhow::Result<SqlitePool> {
    // New state DBs should use incremental auto-vacuum, but retrofitting an
    // existing DB requires a full VACUUM. Do not attempt that during process
    // startup: it is maintenance work that can contend with foreground writers.
    open_sqlite(
        path,
        migrator,
        metrics,
        DbKind::State,
        "open_state",
        "migrate_state",
    )
    .await
}
/// Opens (and migrates) the logs database at `path`, recording init metrics
/// for the "open_logs" and "migrate_logs" phases when `metrics` is present.
async fn open_logs_sqlite(
    path: &Path,
    migrator: &Migrator,
    metrics: Option<&dyn DbMetricsRecorder>,
) -> anyhow::Result<SqlitePool> {
    open_sqlite(
        path,
        migrator,
        metrics,
        DbKind::Logs,
        "open_logs",
        "migrate_logs",
    )
    .await
}
async fn open_sqlite(
path: &Path,
migrator: &Migrator,
metrics: Option<&dyn DbMetricsRecorder>,
db: DbKind,
open_phase: &'static str,
migrate_phase: &'static str,
) -> anyhow::Result<SqlitePool> {
let options = base_sqlite_options(path).auto_vacuum(SqliteAutoVacuum::Incremental);
let pool = SqlitePoolOptions::new()
let started = Instant::now();
let pool_result = SqlitePoolOptions::new()
.max_connections(5)
.acquire_slow_level(LevelFilter::Warn)
.acquire_slow_threshold(Duration::from_millis(250))
.connect_with(options)
.await?;
migrator.run(&pool).await?;
.await
.map_err(anyhow::Error::from);
crate::telemetry::record_init_result(metrics, db, open_phase, started.elapsed(), &pool_result);
let pool = pool_result?;
let started = Instant::now();
let migrate_result = migrator.run(&pool).await.map_err(anyhow::Error::from);
crate::telemetry::record_init_result(
metrics,
db,
migrate_phase,
started.elapsed(),
&migrate_result,
);
migrate_result?;
Ok(pool)
}
async fn open_logs_sqlite(path: &Path, migrator: &Migrator) -> anyhow::Result<SqlitePool> {
let options = base_sqlite_options(path).auto_vacuum(SqliteAutoVacuum::Incremental);
let pool = SqlitePoolOptions::new()
.max_connections(5)
.connect_with(options)
.await?;
migrator.run(&pool).await?;
Ok(pool)
async fn ensure_backfill_state_row_in_pool(pool: &sqlx::SqlitePool) -> anyhow::Result<()> {
sqlx::query(
r#"
INSERT INTO backfill_state (id, status, last_watermark, last_success_at, updated_at)
VALUES (?, ?, NULL, NULL, ?)
ON CONFLICT(id) DO NOTHING
"#,
)
.bind(1_i64)
.bind(crate::BackfillStatus::Pending.as_str())
.bind(Utc::now().timestamp())
.execute(pool)
.await?;
Ok(())
}
fn db_filename(base_name: &str, version: u32) -> String {
@@ -355,9 +464,13 @@ mod tests {
strict_pool.close().await;
let tolerant_migrator = runtime_state_migrator();
let tolerant_pool = open_state_sqlite(state_path.as_path(), &tolerant_migrator)
.await
.expect("runtime migrator should tolerate newer applied migrations");
let tolerant_pool = open_state_sqlite(
state_path.as_path(),
&tolerant_migrator,
/*metrics*/ None,
)
.await
.expect("runtime migrator should tolerate newer applied migrations");
tolerant_pool.close().await;
let _ = tokio::fs::remove_dir_all(codex_home).await;

View File

@@ -19,7 +19,7 @@ impl StateRuntime {
.map(i64::try_from)
.transpose()
.map_err(|_| anyhow::anyhow!("invalid max_runtime_seconds value"))?;
let mut tx = self.pool.begin().await?;
let mut tx = self.state_db.pool().begin().await?;
sqlx::query(
r#"
INSERT INTO agent_jobs (
@@ -122,7 +122,7 @@ WHERE id = ?
"#,
)
.bind(job_id)
.fetch_optional(self.pool.as_ref())
.fetch_optional(self.state_db.pool())
.await?;
row.map(AgentJob::try_from).transpose()
}
@@ -166,7 +166,7 @@ WHERE job_id =
}
let rows: Vec<AgentJobItemRow> = builder
.build_query_as::<AgentJobItemRow>()
.fetch_all(self.pool.as_ref())
.fetch_all(self.state_db.pool())
.await?;
rows.into_iter().map(AgentJobItem::try_from).collect()
}
@@ -199,7 +199,7 @@ WHERE job_id = ? AND item_id = ?
)
.bind(job_id)
.bind(item_id)
.fetch_optional(self.pool.as_ref())
.fetch_optional(self.state_db.pool())
.await?;
row.map(AgentJobItem::try_from).transpose()
}
@@ -222,7 +222,7 @@ WHERE id = ?
.bind(now)
.bind(now)
.bind(job_id)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(())
}
@@ -240,7 +240,7 @@ WHERE id = ?
.bind(now)
.bind(now)
.bind(job_id)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(())
}
@@ -263,7 +263,7 @@ WHERE id = ?
.bind(now)
.bind(error_message)
.bind(job_id)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(())
}
@@ -288,7 +288,7 @@ WHERE id = ? AND status IN (?, ?)
.bind(job_id)
.bind(AgentJobStatus::Pending.as_str())
.bind(AgentJobStatus::Running.as_str())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -302,7 +302,7 @@ WHERE id = ?
"#,
)
.bind(job_id)
.fetch_optional(self.pool.as_ref())
.fetch_optional(self.state_db.pool())
.await?;
let Some(row) = row else {
return Ok(false);
@@ -334,7 +334,7 @@ WHERE job_id = ? AND item_id = ? AND status = ?
.bind(job_id)
.bind(item_id)
.bind(AgentJobItemStatus::Pending.as_str())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -364,7 +364,7 @@ WHERE job_id = ? AND item_id = ? AND status = ?
.bind(job_id)
.bind(item_id)
.bind(AgentJobItemStatus::Pending.as_str())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -393,7 +393,7 @@ WHERE job_id = ? AND item_id = ? AND status = ?
.bind(job_id)
.bind(item_id)
.bind(AgentJobItemStatus::Running.as_str())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -417,7 +417,7 @@ WHERE job_id = ? AND item_id = ? AND status = ?
.bind(job_id)
.bind(item_id)
.bind(AgentJobItemStatus::Running.as_str())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -458,7 +458,7 @@ WHERE
.bind(item_id)
.bind(AgentJobItemStatus::Running.as_str())
.bind(reporting_thread_id)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -490,7 +490,7 @@ WHERE
.bind(job_id)
.bind(item_id)
.bind(AgentJobItemStatus::Running.as_str())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -524,7 +524,7 @@ WHERE
.bind(job_id)
.bind(item_id)
.bind(AgentJobItemStatus::Running.as_str())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -547,7 +547,7 @@ WHERE job_id = ?
.bind(AgentJobItemStatus::Completed.as_str())
.bind(AgentJobItemStatus::Failed.as_str())
.bind(job_id)
.fetch_one(self.pool.as_ref())
.fetch_one(self.state_db.pool())
.await?;
let total_items: i64 = row.try_get("total_items")?;

View File

@@ -2,17 +2,20 @@ use super::*;
impl StateRuntime {
pub async fn get_backfill_state(&self) -> anyhow::Result<crate::BackfillState> {
self.ensure_backfill_state_row().await?;
let row = sqlx::query(
r#"
self.state_db
.read(DbOperation::GetBackfillState, |pool| async move {
let row = sqlx::query(
r#"
SELECT status, last_watermark, last_success_at
FROM backfill_state
WHERE id = 1
"#,
)
.fetch_one(self.pool.as_ref())
.await?;
crate::BackfillState::try_from_row(&row)
)
.fetch_one(&pool)
.await?;
crate::BackfillState::try_from_row(&row)
})
.await
}
/// Attempt to claim ownership of rollout metadata backfill.
@@ -21,69 +24,83 @@ WHERE id = 1
/// Returns `false` if backfill is already complete or currently owned by a
/// non-expired worker.
pub async fn try_claim_backfill(&self, lease_seconds: i64) -> anyhow::Result<bool> {
self.ensure_backfill_state_row().await?;
let now = Utc::now().timestamp();
let lease_cutoff = now.saturating_sub(lease_seconds.max(0));
let result = sqlx::query(
r#"
self.state_db
.write(DbOperation::TryClaimBackfill, |pool| async move {
ensure_backfill_state_row_in_pool(&pool).await?;
let now = Utc::now().timestamp();
let lease_cutoff = now.saturating_sub(lease_seconds.max(0));
let result = sqlx::query(
r#"
UPDATE backfill_state
SET status = ?, updated_at = ?
WHERE id = 1
AND status != ?
AND (status != ? OR updated_at <= ?)
"#,
)
.bind(crate::BackfillStatus::Running.as_str())
.bind(now)
.bind(crate::BackfillStatus::Complete.as_str())
.bind(crate::BackfillStatus::Running.as_str())
.bind(lease_cutoff)
.execute(self.pool.as_ref())
.await?;
Ok(result.rows_affected() == 1)
)
.bind(crate::BackfillStatus::Running.as_str())
.bind(now)
.bind(crate::BackfillStatus::Complete.as_str())
.bind(crate::BackfillStatus::Running.as_str())
.bind(lease_cutoff)
.execute(&pool)
.await?;
Ok(result.rows_affected() == 1)
})
.await
}
/// Mark rollout metadata backfill as running.
pub async fn mark_backfill_running(&self) -> anyhow::Result<()> {
self.ensure_backfill_state_row().await?;
sqlx::query(
r#"
self.state_db
.write(DbOperation::MarkBackfillRunning, |pool| async move {
ensure_backfill_state_row_in_pool(&pool).await?;
sqlx::query(
r#"
UPDATE backfill_state
SET status = ?, updated_at = ?
WHERE id = 1
"#,
)
.bind(crate::BackfillStatus::Running.as_str())
.bind(Utc::now().timestamp())
.execute(self.pool.as_ref())
.await?;
Ok(())
)
.bind(crate::BackfillStatus::Running.as_str())
.bind(Utc::now().timestamp())
.execute(&pool)
.await?;
Ok(())
})
.await
}
/// Persist rollout metadata backfill progress.
pub async fn checkpoint_backfill(&self, watermark: &str) -> anyhow::Result<()> {
self.ensure_backfill_state_row().await?;
sqlx::query(
r#"
self.state_db
.write(DbOperation::CheckpointBackfill, |pool| async move {
ensure_backfill_state_row_in_pool(&pool).await?;
sqlx::query(
r#"
UPDATE backfill_state
SET status = ?, last_watermark = ?, updated_at = ?
WHERE id = 1
"#,
)
.bind(crate::BackfillStatus::Running.as_str())
.bind(watermark)
.bind(Utc::now().timestamp())
.execute(self.pool.as_ref())
.await?;
Ok(())
)
.bind(crate::BackfillStatus::Running.as_str())
.bind(watermark)
.bind(Utc::now().timestamp())
.execute(&pool)
.await?;
Ok(())
})
.await
}
/// Mark rollout metadata backfill as complete.
pub async fn mark_backfill_complete(&self, last_watermark: Option<&str>) -> anyhow::Result<()> {
self.ensure_backfill_state_row().await?;
let now = Utc::now().timestamp();
sqlx::query(
r#"
self.state_db
.write(DbOperation::MarkBackfillComplete, |pool| async move {
ensure_backfill_state_row_in_pool(&pool).await?;
let now = Utc::now().timestamp();
sqlx::query(
r#"
UPDATE backfill_state
SET
status = ?,
@@ -92,30 +109,16 @@ SET
updated_at = ?
WHERE id = 1
"#,
)
.bind(crate::BackfillStatus::Complete.as_str())
.bind(last_watermark)
.bind(now)
.bind(now)
.execute(self.pool.as_ref())
.await?;
Ok(())
}
async fn ensure_backfill_state_row(&self) -> anyhow::Result<()> {
sqlx::query(
r#"
INSERT INTO backfill_state (id, status, last_watermark, last_success_at, updated_at)
VALUES (?, ?, NULL, NULL, ?)
ON CONFLICT(id) DO NOTHING
"#,
)
.bind(1_i64)
.bind(crate::BackfillStatus::Pending.as_str())
.bind(Utc::now().timestamp())
.execute(self.pool.as_ref())
.await?;
Ok(())
)
.bind(crate::BackfillStatus::Complete.as_str())
.bind(last_watermark)
.bind(now)
.bind(now)
.execute(&pool)
.await?;
Ok(())
})
.await
}
}
@@ -286,7 +289,7 @@ WHERE id = 1
)
.bind(crate::BackfillStatus::Running.as_str())
.bind(stale_updated_at)
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("force stale backfill lease");

View File

@@ -0,0 +1,158 @@
use std::future::Future;
use std::sync::Arc;
use std::time::Instant;
use sqlx::SqlitePool;
use crate::telemetry::DbAccess;
use crate::telemetry::DbKind;
use crate::telemetry::DbMetricsRecorder;
use crate::telemetry::DbMetricsRecorderHandle;
/// SQLite pool plus Codex-level operation telemetry context.
#[derive(Clone)]
pub(super) struct InstrumentedDb {
    // Shared connection pool; clones of this struct refer to the same pool.
    pool: Arc<SqlitePool>,
    // Which database this pool backs; used to tag telemetry samples.
    kind: DbKind,
    // Optional metrics sink; `None` disables operation telemetry.
    metrics: Option<DbMetricsRecorderHandle>,
}
impl InstrumentedDb {
    /// Bundle a pool with the telemetry context used to label its samples.
    pub(super) fn new(
        pool: Arc<SqlitePool>,
        kind: DbKind,
        metrics: Option<DbMetricsRecorderHandle>,
    ) -> Self {
        Self { pool, kind, metrics }
    }

    /// Borrow the underlying SQLite pool for direct query execution.
    pub(super) fn pool(&self) -> &SqlitePool {
        &self.pool
    }

    /// Borrow the attached metrics recorder, if any.
    pub(super) fn metrics(&self) -> Option<&dyn DbMetricsRecorder> {
        self.metrics.as_deref()
    }

    /// Clone the shared metrics recorder handle, if any.
    pub(super) fn metrics_handle(&self) -> Option<DbMetricsRecorderHandle> {
        self.metrics.clone()
    }

    /// Run `op` against the pool, recorded with the `Read` access label.
    pub(super) async fn read<T, Op, Fut>(&self, operation: DbOperation, op: Op) -> anyhow::Result<T>
    where
        Op: FnOnce(SqlitePool) -> Fut,
        Fut: Future<Output = anyhow::Result<T>>,
    {
        self.record_operation(operation, DbAccess::Read, op).await
    }

    /// Run `op` against the pool, recorded with the `Write` access label.
    pub(super) async fn write<T, Op, Fut>(&self, operation: DbOperation, op: Op) -> anyhow::Result<T>
    where
        Op: FnOnce(SqlitePool) -> Fut,
        Fut: Future<Output = anyhow::Result<T>>,
    {
        self.record_operation(operation, DbAccess::Write, op).await
    }

    /// Run `op` against the pool, recorded with the `Transaction` access label.
    pub(super) async fn transaction<T, Op, Fut>(
        &self,
        operation: DbOperation,
        op: Op,
    ) -> anyhow::Result<T>
    where
        Op: FnOnce(SqlitePool) -> Fut,
        Fut: Future<Output = anyhow::Result<T>>,
    {
        self.record_operation(operation, DbAccess::Transaction, op)
            .await
    }

    /// Run `op` against the pool, recorded with the `Maintenance` access label.
    pub(super) async fn maintenance<T, Op, Fut>(
        &self,
        operation: DbOperation,
        op: Op,
    ) -> anyhow::Result<T>
    where
        Op: FnOnce(SqlitePool) -> Fut,
        Fut: Future<Output = anyhow::Result<T>>,
    {
        self.record_operation(operation, DbAccess::Maintenance, op)
            .await
    }

    /// Emit one telemetry sample for a call that was timed by the caller.
    pub(super) fn record_result<T>(
        &self,
        operation: DbOperation,
        access: DbAccess,
        started: Instant,
        result: &anyhow::Result<T>,
    ) {
        crate::telemetry::record_operation_result(
            self.metrics(),
            self.kind,
            operation.as_str(),
            access,
            started.elapsed(),
            result,
        );
    }

    /// Time `op`, emit its telemetry sample, and pass the result through.
    async fn record_operation<T, Op, Fut>(
        &self,
        operation: DbOperation,
        access: DbAccess,
        op: Op,
    ) -> anyhow::Result<T>
    where
        Op: FnOnce(SqlitePool) -> Fut,
        Fut: Future<Output = anyhow::Result<T>>,
    {
        let begin = Instant::now();
        let outcome = op(self.pool().clone()).await;
        self.record_result(operation, access, begin, &outcome);
        outcome
    }
}
/// Per-call operation labels used to tag SQLite telemetry samples.
#[derive(Clone, Copy)]
pub(super) enum DbOperation {
    CheckpointBackfill,
    FindRolloutPathById,
    GetBackfillState,
    GetDynamicTools,
    GetThread,
    InsertLogs,
    ListThreads,
    LogsStartupMaintenance,
    MarkBackfillComplete,
    MarkBackfillRunning,
    PersistDynamicTools,
    TouchThreadUpdatedAt,
    TryClaimBackfill,
    UpsertThread,
}
impl DbOperation {
    /// Snake_case label for this operation, passed to telemetry as the
    /// operation-name dimension of each recorded sample.
    fn as_str(self) -> &'static str {
        match self {
            Self::CheckpointBackfill => "checkpoint_backfill",
            Self::FindRolloutPathById => "find_rollout_path_by_id",
            Self::GetBackfillState => "get_backfill_state",
            Self::GetDynamicTools => "get_dynamic_tools",
            Self::GetThread => "get_thread",
            Self::InsertLogs => "insert_logs",
            Self::ListThreads => "list_threads",
            Self::LogsStartupMaintenance => "logs_startup_maintenance",
            Self::MarkBackfillComplete => "mark_backfill_complete",
            Self::MarkBackfillRunning => "mark_backfill_running",
            Self::PersistDynamicTools => "persist_dynamic_tools",
            Self::TouchThreadUpdatedAt => "touch_thread_updated_at",
            Self::TryClaimBackfill => "try_claim_backfill",
            Self::UpsertThread => "upsert_thread",
        }
    }
}

View File

@@ -42,7 +42,7 @@ WHERE thread_id = ?
"#,
)
.bind(thread_id.to_string())
.fetch_optional(self.pool.as_ref())
.fetch_optional(self.state_db.pool())
.await?;
row.map(|row| thread_goal_from_row(&row)).transpose()
@@ -99,7 +99,7 @@ RETURNING
.bind(token_budget)
.bind(now_ms)
.bind(now_ms)
.fetch_one(self.pool.as_ref())
.fetch_one(self.state_db.pool())
.await?;
thread_goal_from_row(&row)
@@ -148,7 +148,7 @@ RETURNING
.bind(token_budget)
.bind(now_ms)
.bind(now_ms)
.fetch_optional(self.pool.as_ref())
.fetch_optional(self.state_db.pool())
.await?;
row.map(|row| thread_goal_from_row(&row)).transpose()
@@ -196,7 +196,7 @@ WHERE thread_id = ?
.bind(thread_id.to_string())
.bind(expected_goal_id)
.bind(expected_goal_id)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?
}
(Some(status), None) => {
@@ -224,7 +224,7 @@ WHERE thread_id = ?
.bind(thread_id.to_string())
.bind(expected_goal_id)
.bind(expected_goal_id)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?
}
(None, Some(token_budget)) => {
@@ -250,7 +250,7 @@ WHERE thread_id = ?
.bind(thread_id.to_string())
.bind(expected_goal_id)
.bind(expected_goal_id)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?
}
(None, None) => {
@@ -289,7 +289,7 @@ WHERE thread_id = ?
.bind(crate::ThreadGoalStatus::Paused.as_str())
.bind(now_ms)
.bind(thread_id.to_string())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
if result.rows_affected() == 0 {
@@ -307,7 +307,7 @@ WHERE thread_id = ?
"#,
)
.bind(thread_id.to_string())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
@@ -392,7 +392,7 @@ RETURNING
query = query.bind(expected_goal_id);
}
let row = query.fetch_optional(self.pool.as_ref()).await?;
let row = query.fetch_optional(self.state_db.pool()).await?;
let Some(row) = row else {
return Ok(ThreadGoalAccountingOutcome::Unchanged(

View File

@@ -9,41 +9,52 @@ impl StateRuntime {
/// Insert a batch of log entries into the logs table.
pub async fn insert_logs(&self, entries: &[LogEntry]) -> anyhow::Result<()> {
if entries.is_empty() {
return Ok(());
}
let started = Instant::now();
let result: anyhow::Result<()> = async {
if entries.is_empty() {
return Ok(());
}
let mut tx = self.logs_pool.begin().await?;
let mut builder = QueryBuilder::<Sqlite>::new(
"INSERT INTO logs (ts, ts_nanos, level, target, feedback_log_body, thread_id, process_uuid, module_path, file, line, estimated_bytes) ",
let mut tx = self.logs_db.pool().begin().await?;
let mut builder = QueryBuilder::<Sqlite>::new(
"INSERT INTO logs (ts, ts_nanos, level, target, feedback_log_body, thread_id, process_uuid, module_path, file, line, estimated_bytes) ",
);
builder.push_values(entries, |mut row, entry| {
let feedback_log_body = entry.feedback_log_body.as_ref().or(entry.message.as_ref());
// Keep about 10 MiB of reader-visible log content per partition.
// Both `query_logs` and `/feedback` read the persisted
// `feedback_log_body`, while `LogEntry.message` is only a write-time
// fallback for callers that still populate the old field.
let estimated_bytes = feedback_log_body.map_or(0, String::len) as i64
+ entry.level.len() as i64
+ entry.target.len() as i64
+ entry.module_path.as_ref().map_or(0, String::len) as i64
+ entry.file.as_ref().map_or(0, String::len) as i64;
row.push_bind(entry.ts)
.push_bind(entry.ts_nanos)
.push_bind(&entry.level)
.push_bind(&entry.target)
.push_bind(feedback_log_body)
.push_bind(&entry.thread_id)
.push_bind(&entry.process_uuid)
.push_bind(&entry.module_path)
.push_bind(&entry.file)
.push_bind(entry.line)
.push_bind(estimated_bytes);
});
builder.build().execute(&mut *tx).await?;
self.prune_logs_after_insert(entries, &mut tx).await?;
tx.commit().await?;
Ok(())
}
.await;
self.logs_db.record_result(
DbOperation::InsertLogs,
DbAccess::Transaction,
started,
&result,
);
builder.push_values(entries, |mut row, entry| {
let feedback_log_body = entry.feedback_log_body.as_ref().or(entry.message.as_ref());
// Keep about 10 MiB of reader-visible log content per partition.
// Both `query_logs` and `/feedback` read the persisted
// `feedback_log_body`, while `LogEntry.message` is only a write-time
// fallback for callers that still populate the old field.
let estimated_bytes = feedback_log_body.map_or(0, String::len) as i64
+ entry.level.len() as i64
+ entry.target.len() as i64
+ entry.module_path.as_ref().map_or(0, String::len) as i64
+ entry.file.as_ref().map_or(0, String::len) as i64;
row.push_bind(entry.ts)
.push_bind(entry.ts_nanos)
.push_bind(&entry.level)
.push_bind(&entry.target)
.push_bind(feedback_log_body)
.push_bind(&entry.thread_id)
.push_bind(&entry.process_uuid)
.push_bind(&entry.module_path)
.push_bind(&entry.file)
.push_bind(entry.line)
.push_bind(estimated_bytes);
});
builder.build().execute(&mut *tx).await?;
self.prune_logs_after_insert(entries, &mut tx).await?;
tx.commit().await?;
Ok(())
result
}
/// Enforce per-partition retained-log-content caps after a successful batch insert.
@@ -285,28 +296,27 @@ WHERE id IN (
Ok(())
}
/// Delete all log rows with a timestamp strictly before `cutoff_ts`.
///
/// Returns the number of rows removed.
pub(crate) async fn delete_logs_before(&self, cutoff_ts: i64) -> anyhow::Result<u64> {
    let result = sqlx::query("DELETE FROM logs WHERE ts < ?")
        .bind(cutoff_ts)
        .execute(self.logs_db.pool())
        .await?;
    Ok(result.rows_affected())
}

pub(crate) async fn run_logs_startup_maintenance(&self) -> anyhow::Result<()> {
    self.logs_db
        .maintenance(DbOperation::LogsStartupMaintenance, |pool| async move {
            let Some(cutoff) =
                Utc::now().checked_sub_signed(chrono::Duration::days(LOG_RETENTION_DAYS))
            else {
                return Ok(());
            };
            sqlx::query("DELETE FROM logs WHERE ts < ?")
                .bind(cutoff.timestamp())
                .execute(&pool)
                .await?;
            // Startup cleanup should not wait behind or block foreground work.
            // PASSIVE checkpoints copy whatever is immediately available and skip
            // frames that would require waiting on active readers or writers.
            sqlx::query("PRAGMA wal_checkpoint(PASSIVE)")
                .execute(&pool)
                .await?;
            Ok(())
        })
        .await
}
/// Query logs with optional filters.
@@ -326,7 +336,7 @@ WHERE id IN (
let rows = builder
.build_query_as::<LogRow>()
.fetch_all(self.logs_pool.as_ref())
.fetch_all(self.logs_db.pool())
.await?;
Ok(rows)
}
@@ -398,7 +408,7 @@ ORDER BY ts DESC, ts_nanos DESC, id DESC
}
let rows = sql
.bind(LOG_PARTITION_SIZE_LIMIT_BYTES)
.fetch_all(self.logs_pool.as_ref())
.fetch_all(self.logs_db.pool())
.await?;
let mut lines = Vec::new();
@@ -431,7 +441,7 @@ ORDER BY ts DESC, ts_nanos DESC, id DESC
let mut builder =
QueryBuilder::<Sqlite>::new("SELECT MAX(id) AS max_id FROM logs WHERE 1 = 1");
push_log_filters(&mut builder, query);
let row = builder.build().fetch_one(self.logs_pool.as_ref()).await?;
let row = builder.build().fetch_one(self.logs_db.pool()).await?;
let max_id: Option<i64> = row.try_get("max_id")?;
Ok(max_id.unwrap_or(0))
}

View File

@@ -30,7 +30,7 @@ impl StateRuntime {
/// stage-1 (`memory_stage1`) and phase-2 (`memory_consolidate_global`)
/// memory pipelines.
pub async fn clear_memory_data(&self) -> anyhow::Result<()> {
let mut tx = self.pool.begin().await?;
let mut tx = self.state_db.pool().begin().await?;
sqlx::query(
r#"
@@ -68,7 +68,7 @@ WHERE kind = ? OR kind = ?
}
let now = Utc::now().timestamp();
let mut tx = self.pool.begin().await?;
let mut tx = self.state_db.pool().begin().await?;
let mut updated_rows = 0;
for thread_id in thread_ids {
@@ -209,7 +209,7 @@ LEFT JOIN jobs
let items = builder
.build()
.fetch_all(self.pool.as_ref())
.fetch_all(self.state_db.pool())
.await?
.into_iter()
.map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
@@ -279,7 +279,7 @@ LIMIT ?
"#,
)
.bind(n as i64)
.fetch_all(self.pool.as_ref())
.fetch_all(self.state_db.pool())
.await?;
rows.into_iter()
@@ -323,7 +323,7 @@ WHERE thread_id IN (
)
.bind(cutoff)
.bind(limit as i64)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?
.rows_affected();
@@ -400,7 +400,7 @@ ORDER BY selected.thread_id ASC
.bind(cutoff)
.bind(cutoff)
.bind(n as i64)
.fetch_all(self.pool.as_ref())
.fetch_all(self.state_db.pool())
.await?;
let mut selected = Vec::with_capacity(current_rows.len());
@@ -421,7 +421,7 @@ ORDER BY selected.thread_id ASC
) -> anyhow::Result<bool> {
let now = Utc::now().timestamp();
let thread_id = thread_id.to_string();
let mut tx = self.pool.begin().await?;
let mut tx = self.state_db.pool().begin().await?;
let rows_affected = sqlx::query(
r#"
UPDATE threads
@@ -489,7 +489,7 @@ WHERE thread_id = ?
let thread_id = thread_id.to_string();
let worker_id = worker_id.to_string();
let mut tx = self.pool.begin_with("BEGIN IMMEDIATE").await?;
let mut tx = self.state_db.pool().begin_with("BEGIN IMMEDIATE").await?;
let existing_output = sqlx::query(
r#"
@@ -673,7 +673,7 @@ WHERE kind = ? AND job_key = ?
let now = Utc::now().timestamp();
let thread_id = thread_id.to_string();
let mut tx = self.pool.begin().await?;
let mut tx = self.state_db.pool().begin().await?;
let rows_affected = sqlx::query(
r#"
UPDATE jobs
@@ -750,7 +750,7 @@ WHERE excluded.source_updated_at >= stage1_outputs.source_updated_at
let now = Utc::now().timestamp();
let thread_id = thread_id.to_string();
let mut tx = self.pool.begin().await?;
let mut tx = self.state_db.pool().begin().await?;
let rows_affected = sqlx::query(
r#"
UPDATE jobs
@@ -848,7 +848,7 @@ WHERE kind = ? AND job_key = ?
.bind(JOB_KIND_MEMORY_STAGE1)
.bind(thread_id.as_str())
.bind(ownership_token)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?
.rows_affected();
@@ -863,7 +863,7 @@ WHERE kind = ? AND job_key = ?
/// Phase 2 does not use this watermark as a dirty check; git workspace diffing
/// decides whether consolidation work exists after the lock is claimed.
pub async fn enqueue_global_consolidation(&self, input_watermark: i64) -> anyhow::Result<()> {
enqueue_global_consolidation_with_executor(self.pool.as_ref(), input_watermark).await
enqueue_global_consolidation_with_executor(self.state_db.pool(), input_watermark).await
}
/// Attempts to claim the global phase-2 consolidation lock.
@@ -890,7 +890,7 @@ WHERE kind = ? AND job_key = ?
let ownership_token = Uuid::new_v4().to_string();
let worker_id = worker_id.to_string();
let mut tx = self.pool.begin_with("BEGIN IMMEDIATE").await?;
let mut tx = self.state_db.pool().begin_with("BEGIN IMMEDIATE").await?;
let existing_job = sqlx::query(
r#"
@@ -1035,7 +1035,7 @@ WHERE kind = ? AND job_key = ?
.bind(JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL)
.bind(MEMORY_CONSOLIDATION_JOB_KEY)
.bind(ownership_token)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?
.rows_affected();
@@ -1058,7 +1058,7 @@ WHERE kind = ? AND job_key = ?
completed_watermark: i64,
selected_outputs: &[Stage1Output],
) -> anyhow::Result<bool> {
let mut tx = self.pool.begin().await?;
let mut tx = self.state_db.pool().begin().await?;
let rows_affected =
mark_global_phase2_job_succeeded_row(&mut *tx, ownership_token, completed_watermark)
.await?;
@@ -1136,7 +1136,7 @@ WHERE kind = ? AND job_key = ?
.bind(JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL)
.bind(MEMORY_CONSOLIDATION_JOB_KEY)
.bind(ownership_token)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?
.rows_affected();
@@ -1178,7 +1178,7 @@ WHERE kind = ? AND job_key = ?
.bind(JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL)
.bind(MEMORY_CONSOLIDATION_JOB_KEY)
.bind(ownership_token)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?
.rows_affected();
@@ -1300,7 +1300,7 @@ mod tests {
.bind(Utc::now().timestamp() - PHASE2_SUCCESS_COOLDOWN_SECONDS - 1)
.bind(JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL)
.bind(MEMORY_CONSOLIDATION_JOB_KEY)
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("age phase2 success beyond cooldown");
}
@@ -1409,7 +1409,7 @@ mod tests {
sqlx::query("UPDATE jobs SET lease_until = 0 WHERE kind = 'memory_stage1' AND job_key = ?")
.bind(thread_id.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("force stale lease");
@@ -1789,7 +1789,7 @@ mod tests {
.expect("upsert disabled thread");
sqlx::query("UPDATE threads SET memory_mode = 'disabled' WHERE id = ?")
.bind(disabled_thread_id.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("disable thread memory mode");
@@ -1890,7 +1890,7 @@ mod tests {
.expect("upsert disabled thread");
sqlx::query("UPDATE threads SET memory_mode = 'disabled' WHERE id = ?")
.bind(disabled_thread_id.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("disable existing thread");
@@ -1900,7 +1900,7 @@ mod tests {
.expect("clear memory data");
let stage1_outputs_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM stage1_outputs")
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("count stage1 outputs");
assert_eq!(stage1_outputs_count, 0);
@@ -1909,7 +1909,7 @@ mod tests {
sqlx::query_scalar("SELECT COUNT(*) FROM jobs WHERE kind = ? OR kind = ?")
.bind(JOB_KIND_MEMORY_STAGE1)
.bind(JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL)
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("count memory jobs");
assert_eq!(memory_jobs_count, 0);
@@ -1917,7 +1917,7 @@ mod tests {
let enabled_memory_mode: String =
sqlx::query_scalar("SELECT memory_mode FROM threads WHERE id = ?")
.bind(enabled_thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("read enabled thread memory mode");
assert_eq!(enabled_memory_mode, "enabled");
@@ -1925,7 +1925,7 @@ mod tests {
let disabled_memory_mode: String =
sqlx::query_scalar("SELECT memory_mode FROM threads WHERE id = ?")
.bind(disabled_thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("read disabled thread memory mode");
assert_eq!(disabled_memory_mode, "disabled");
@@ -2000,7 +2000,7 @@ INSERT INTO jobs (
.bind(lease_until)
.bind(3)
.bind(metadata.updated_at.timestamp())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("seed running stage1 job");
}
@@ -2034,7 +2034,7 @@ WHERE kind = 'memory_stage1'
"#,
)
.bind(Utc::now().timestamp())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("count running stage1 jobs")
.try_get::<i64, _>("count")
@@ -2191,7 +2191,7 @@ WHERE kind = 'memory_stage1'
let count_before =
sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?")
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("count before delete")
.try_get::<i64, _>("count")
@@ -2200,14 +2200,14 @@ WHERE kind = 'memory_stage1'
sqlx::query("DELETE FROM threads WHERE id = ?")
.bind(thread_id.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("delete thread");
let count_after =
sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?")
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("count after delete")
.try_get::<i64, _>("count")
@@ -2258,7 +2258,7 @@ WHERE kind = 'memory_stage1'
let output_row_count =
sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?")
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("load stage1 output count")
.try_get::<i64, _>("count")
@@ -2279,7 +2279,7 @@ WHERE kind = 'memory_stage1'
let global_job_row_count = sqlx::query("SELECT COUNT(*) AS count FROM jobs WHERE kind = ?")
.bind("memory_consolidate_global")
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("load phase2 job row count")
.try_get::<i64, _>("count")
@@ -2383,7 +2383,7 @@ WHERE kind = 'memory_stage1'
let output_row_count =
sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?")
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("load stage1 output count after delete")
.try_get::<i64, _>("count")
@@ -2494,7 +2494,7 @@ WHERE kind = 'memory_stage1'
)
.bind("memory_stage1")
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("load stage1 job row after newer-source claim");
assert_eq!(
@@ -2620,7 +2620,7 @@ WHERE kind = 'memory_stage1'
sqlx::query("SELECT retry_remaining FROM jobs WHERE kind = ? AND job_key = ?")
.bind("memory_consolidate_global")
.bind("global")
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("load phase2 job row after retry exhaustion");
assert_eq!(
@@ -2787,7 +2787,7 @@ VALUES (?, ?, ?, ?, ?)
.bind("raw memory")
.bind("summary")
.bind(100_i64)
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("insert non-empty stage1 output");
sqlx::query(
@@ -2801,7 +2801,7 @@ VALUES (?, ?, ?, ?, ?)
.bind("")
.bind("")
.bind(101_i64)
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("insert empty stage1 output");
@@ -3292,7 +3292,7 @@ VALUES (?, ?, ?, ?, ?)
"SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?",
)
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("load selected_for_phase2");
assert_eq!(selected_for_phase2, 1);
@@ -3585,7 +3585,7 @@ VALUES (?, ?, ?, ?, ?)
"SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?",
)
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("load selected snapshot after phase2");
assert_eq!(selected_for_phase2, 1);
@@ -3698,7 +3698,7 @@ VALUES (?, ?, ?, ?, ?)
"SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?",
)
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("load selected_for_phase2");
assert_eq!(selected_for_phase2, 0);
@@ -3802,13 +3802,13 @@ VALUES (?, ?, ?, ?, ?)
let row_a =
sqlx::query("SELECT usage_count, last_usage FROM stage1_outputs WHERE thread_id = ?")
.bind(thread_a.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("load stage1 usage row a");
let row_b =
sqlx::query("SELECT usage_count, last_usage FROM stage1_outputs WHERE thread_id = ?")
.bind(thread_b.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("load stage1 usage row b");
@@ -3908,7 +3908,7 @@ VALUES (?, ?, ?, ?, ?)
.bind(usage_count)
.bind(last_usage.timestamp())
.bind(thread_id.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("update usage metadata");
}
@@ -4004,7 +4004,7 @@ VALUES (?, ?, ?, ?, ?)
.bind(usage_count)
.bind(last_usage.map(|value| value.timestamp()))
.bind(thread_id.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("update usage metadata");
}
@@ -4089,13 +4089,13 @@ VALUES (?, ?, ?, ?, ?)
sqlx::query("UPDATE stage1_outputs SET generated_at = ? WHERE thread_id = ?")
.bind(300_i64)
.bind(older_thread.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("update older generated_at");
sqlx::query("UPDATE stage1_outputs SET generated_at = ? WHERE thread_id = ?")
.bind(150_i64)
.bind(newer_thread.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("update newer generated_at");
@@ -4201,14 +4201,14 @@ VALUES (?, ?, ?, ?, ?)
.bind(3_i64)
.bind(now - Duration::days(40).num_seconds())
.bind(stale_used.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("set stale used metadata");
sqlx::query(
"UPDATE stage1_outputs SET selected_for_phase2 = 1, selected_for_phase2_source_updated_at = source_updated_at WHERE thread_id = ?",
)
.bind(stale_selected.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("mark selected for phase2");
sqlx::query(
@@ -4217,13 +4217,13 @@ VALUES (?, ?, ?, ?, ?)
.bind(8_i64)
.bind(now - Duration::days(2).num_seconds())
.bind(fresh_used.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("set fresh used metadata");
let before_jobs_count =
sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM jobs WHERE kind = 'memory_stage1'")
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("count stage1 jobs before prune");
@@ -4236,7 +4236,7 @@ VALUES (?, ?, ?, ?, ?)
let remaining = sqlx::query_scalar::<_, String>(
"SELECT thread_id FROM stage1_outputs ORDER BY thread_id",
)
.fetch_all(runtime.pool.as_ref())
.fetch_all(runtime.state_db.pool())
.await
.expect("load remaining stage1 outputs");
let mut expected_remaining = vec![fresh_used.to_string(), stale_selected.to_string()];
@@ -4245,7 +4245,7 @@ VALUES (?, ?, ?, ?, ?)
let after_jobs_count =
sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM jobs WHERE kind = 'memory_stage1'")
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("count stage1 jobs after prune");
assert_eq!(after_jobs_count, before_jobs_count);
@@ -4323,7 +4323,7 @@ VALUES (?, ?, ?, ?, ?)
assert_eq!(pruned, 2);
let remaining_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM stage1_outputs")
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("count remaining stage1 outputs");
assert_eq!(remaining_count, 1);
@@ -4539,7 +4539,7 @@ VALUES (?, ?, ?, ?, ?)
.bind(Utc::now().timestamp() - 1)
.bind("memory_consolidate_global")
.bind("global")
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("expire global consolidation lease");
@@ -4675,7 +4675,7 @@ VALUES (?, ?, ?, ?, ?)
sqlx::query("UPDATE jobs SET ownership_token = NULL WHERE kind = ? AND job_key = ?")
.bind("memory_consolidate_global")
.bind("global")
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("clear ownership token");

View File

@@ -44,7 +44,7 @@ WHERE websocket_url = ? AND account_id = ? AND app_server_client_name = ?
.bind(remote_control_app_server_client_name_key(
app_server_client_name,
))
.fetch_optional(self.pool.as_ref())
.fetch_optional(self.state_db.pool())
.await?;
row.map(|row| {
@@ -92,7 +92,7 @@ ON CONFLICT(websocket_url, account_id, app_server_client_name) DO UPDATE SET
.bind(&enrollment.environment_id)
.bind(&enrollment.server_name)
.bind(Utc::now().timestamp())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(())
}
@@ -114,7 +114,7 @@ WHERE websocket_url = ? AND account_id = ? AND app_server_client_name = ?
.bind(remote_control_app_server_client_name_key(
app_server_client_name,
))
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected())
}

View File

@@ -5,8 +5,10 @@ use std::sync::atomic::Ordering;
impl StateRuntime {
pub async fn get_thread(&self, id: ThreadId) -> anyhow::Result<Option<crate::ThreadMetadata>> {
let row = sqlx::query(
r#"
self.state_db
.read(DbOperation::GetThread, |pool| async move {
let row = sqlx::query(
r#"
SELECT
threads.id,
threads.rollout_path,
@@ -34,18 +36,20 @@ SELECT
FROM threads
WHERE threads.id = ?
"#,
)
.bind(id.to_string())
.fetch_optional(self.pool.as_ref())
.await?;
row.map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
.transpose()
)
.bind(id.to_string())
.fetch_optional(&pool)
.await?;
row.map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
.transpose()
})
.await
}
pub async fn get_thread_memory_mode(&self, id: ThreadId) -> anyhow::Result<Option<String>> {
let row = sqlx::query("SELECT memory_mode FROM threads WHERE id = ?")
.bind(id.to_string())
.fetch_optional(self.pool.as_ref())
.fetch_optional(self.state_db.pool())
.await?;
Ok(row.and_then(|row| row.try_get("memory_mode").ok()))
}
@@ -55,33 +59,37 @@ WHERE threads.id = ?
&self,
thread_id: ThreadId,
) -> anyhow::Result<Option<Vec<DynamicToolSpec>>> {
let rows = sqlx::query(
r#"
self.state_db
.read(DbOperation::GetDynamicTools, |pool| async move {
let rows = sqlx::query(
r#"
SELECT namespace, name, description, input_schema, defer_loading
FROM thread_dynamic_tools
WHERE thread_id = ?
ORDER BY position ASC
"#,
)
.bind(thread_id.to_string())
.fetch_all(self.pool.as_ref())
.await?;
if rows.is_empty() {
return Ok(None);
}
let mut tools = Vec::with_capacity(rows.len());
for row in rows {
let input_schema: String = row.try_get("input_schema")?;
let input_schema = serde_json::from_str::<Value>(input_schema.as_str())?;
tools.push(DynamicToolSpec {
namespace: row.try_get("namespace")?,
name: row.try_get("name")?,
description: row.try_get("description")?,
input_schema,
defer_loading: row.try_get("defer_loading")?,
});
}
Ok(Some(tools))
)
.bind(thread_id.to_string())
.fetch_all(&pool)
.await?;
if rows.is_empty() {
return Ok(None);
}
let mut tools = Vec::with_capacity(rows.len());
for row in rows {
let input_schema: String = row.try_get("input_schema")?;
let input_schema = serde_json::from_str::<Value>(input_schema.as_str())?;
tools.push(DynamicToolSpec {
namespace: row.try_get("namespace")?,
name: row.try_get("name")?,
description: row.try_get("description")?,
input_schema,
defer_loading: row.try_get("defer_loading")?,
});
}
Ok(Some(tools))
})
.await
}
/// Persist or replace the directional parent-child edge for a spawned thread.
@@ -106,7 +114,7 @@ ON CONFLICT(child_thread_id) DO UPDATE SET
.bind(parent_thread_id.to_string())
.bind(child_thread_id.to_string())
.bind(status.as_ref())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(())
}
@@ -120,7 +128,7 @@ ON CONFLICT(child_thread_id) DO UPDATE SET
sqlx::query("UPDATE thread_spawn_edges SET status = ? WHERE child_thread_id = ?")
.bind(status.as_ref())
.bind(child_thread_id.to_string())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(())
}
@@ -186,7 +194,7 @@ LIMIT 2
)
.bind(parent_thread_id.to_string())
.bind(agent_path)
.fetch_all(self.pool.as_ref())
.fetch_all(self.state_db.pool())
.await?;
one_thread_id_from_rows(rows, agent_path)
}
@@ -218,7 +226,7 @@ LIMIT 2
)
.bind(root_thread_id.to_string())
.bind(agent_path)
.fetch_all(self.pool.as_ref())
.fetch_all(self.state_db.pool())
.await?;
one_thread_id_from_rows(rows, agent_path)
}
@@ -241,7 +249,7 @@ LIMIT 2
sql = sql.bind(status.to_string());
}
let rows = sql.fetch_all(self.pool.as_ref()).await?;
let rows = sql.fetch_all(self.state_db.pool()).await?;
rows.into_iter()
.map(|row| {
ThreadId::try_from(row.try_get::<String, _>("child_thread_id")?).map_err(Into::into)
@@ -283,7 +291,7 @@ ORDER BY depth ASC, child_thread_id ASC
sql = sql.bind(status.clone()).bind(status);
}
let rows = sql.fetch_all(self.pool.as_ref()).await?;
let rows = sql.fetch_all(self.state_db.pool()).await?;
rows.into_iter()
.map(|row| {
ThreadId::try_from(row.try_get::<String, _>("child_thread_id")?).map_err(Into::into)
@@ -309,7 +317,7 @@ ON CONFLICT(child_thread_id) DO NOTHING
.bind(parent_thread_id.to_string())
.bind(child_thread_id.to_string())
.bind(crate::DirectionalThreadSpawnEdgeStatus::Open.as_ref())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(())
}
@@ -332,22 +340,26 @@ ON CONFLICT(child_thread_id) DO NOTHING
id: ThreadId,
archived_only: Option<bool>,
) -> anyhow::Result<Option<PathBuf>> {
let mut builder =
QueryBuilder::<Sqlite>::new("SELECT rollout_path FROM threads WHERE id = ");
builder.push_bind(id.to_string());
match archived_only {
Some(true) => {
builder.push(" AND archived = 1");
}
Some(false) => {
builder.push(" AND archived = 0");
}
None => {}
}
let row = builder.build().fetch_optional(self.pool.as_ref()).await?;
Ok(row
.and_then(|r| r.try_get::<String, _>("rollout_path").ok())
.map(PathBuf::from))
self.state_db
.read(DbOperation::FindRolloutPathById, |pool| async move {
let mut builder =
QueryBuilder::<Sqlite>::new("SELECT rollout_path FROM threads WHERE id = ");
builder.push_bind(id.to_string());
match archived_only {
Some(true) => {
builder.push(" AND archived = 1");
}
Some(false) => {
builder.push(" AND archived = 0");
}
None => {}
}
let row = builder.build().fetch_optional(&pool).await?;
Ok(row
.and_then(|r| r.try_get::<String, _>("rollout_path").ok())
.map(PathBuf::from))
})
.await
}
/// Find the newest thread whose user-facing title exactly matches `title`.
@@ -389,7 +401,7 @@ ON CONFLICT(child_thread_id) DO NOTHING
/*limit*/ 1,
);
let row = builder.build().fetch_optional(self.pool.as_ref()).await?;
let row = builder.build().fetch_optional(self.state_db.pool()).await?;
row.map(|row| ThreadRow::try_from_row(&row).and_then(crate::ThreadMetadata::try_from))
.transpose()
}
@@ -400,35 +412,39 @@ ON CONFLICT(child_thread_id) DO NOTHING
page_size: usize,
filters: ThreadFilterOptions<'_>,
) -> anyhow::Result<crate::ThreadsPage> {
let limit = page_size.saturating_add(1);
let sort_key = filters.sort_key;
let sort_direction = filters.sort_direction;
self.state_db
.read(DbOperation::ListThreads, |pool| async move {
let limit = page_size.saturating_add(1);
let sort_key = filters.sort_key;
let sort_direction = filters.sort_direction;
let mut builder = QueryBuilder::<Sqlite>::new("");
push_thread_select_columns(&mut builder);
builder.push(" FROM threads");
push_thread_filters(&mut builder, filters);
push_thread_order_and_limit(&mut builder, sort_key, sort_direction, limit);
let mut builder = QueryBuilder::<Sqlite>::new("");
push_thread_select_columns(&mut builder);
builder.push(" FROM threads");
push_thread_filters(&mut builder, filters);
push_thread_order_and_limit(&mut builder, sort_key, sort_direction, limit);
let rows = builder.build().fetch_all(self.pool.as_ref()).await?;
let mut items = rows
.into_iter()
.map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
.collect::<Result<Vec<_>, _>>()?;
let num_scanned_rows = items.len();
let next_anchor = if items.len() > page_size {
items.pop();
items
.last()
.and_then(|item| anchor_from_item(item, sort_key))
} else {
None
};
Ok(ThreadsPage {
items,
next_anchor,
num_scanned_rows,
})
let rows = builder.build().fetch_all(&pool).await?;
let mut items = rows
.into_iter()
.map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
.collect::<Result<Vec<_>, _>>()?;
let num_scanned_rows = items.len();
let next_anchor = if items.len() > page_size {
items.pop();
items
.last()
.and_then(|item| anchor_from_item(item, sort_key))
} else {
None
};
Ok(ThreadsPage {
items,
next_anchor,
num_scanned_rows,
})
})
.await
}
/// List thread ids using the underlying database (no rollout scanning).
@@ -457,7 +473,7 @@ ON CONFLICT(child_thread_id) DO NOTHING
);
push_thread_order_and_limit(&mut builder, sort_key, SortDirection::Desc, limit);
let rows = builder.build().fetch_all(self.pool.as_ref()).await?;
let rows = builder.build().fetch_all(self.state_db.pool()).await?;
rows.into_iter()
.map(|row| {
let id: String = row.try_get("id")?;
@@ -547,7 +563,7 @@ ON CONFLICT(id) DO NOTHING
.bind(metadata.git_branch.as_deref())
.bind(metadata.git_origin_url.as_deref())
.bind("enabled")
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
self.insert_thread_spawn_edge_from_source_if_absent(metadata.id, metadata.source.as_str())
.await?;
@@ -562,7 +578,7 @@ ON CONFLICT(id) DO NOTHING
let result = sqlx::query("UPDATE threads SET memory_mode = ? WHERE id = ?")
.bind(memory_mode)
.bind(thread_id.to_string())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -575,7 +591,7 @@ ON CONFLICT(id) DO NOTHING
let result = sqlx::query("UPDATE threads SET title = ? WHERE id = ?")
.bind(title)
.bind(thread_id.to_string())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -586,14 +602,19 @@ ON CONFLICT(id) DO NOTHING
updated_at: DateTime<Utc>,
) -> anyhow::Result<bool> {
let updated_at = self.allocate_thread_updated_at(updated_at)?;
let result =
sqlx::query("UPDATE threads SET updated_at = ?, updated_at_ms = ? WHERE id = ?")
self.state_db
.write(DbOperation::TouchThreadUpdatedAt, |pool| async move {
let result = sqlx::query(
"UPDATE threads SET updated_at = ?, updated_at_ms = ? WHERE id = ?",
)
.bind(datetime_to_epoch_seconds(updated_at))
.bind(datetime_to_epoch_millis(updated_at))
.bind(thread_id.to_string())
.execute(self.pool.as_ref())
.execute(&pool)
.await?;
Ok(result.rows_affected() > 0)
Ok(result.rows_affected() > 0)
})
.await
}
/// Allocate a persisted `updated_at` value for thread-list cursor ordering.
@@ -666,7 +687,7 @@ WHERE id = ?
.bind(git_origin_url.is_some())
.bind(git_origin_url.flatten())
.bind(thread_id.to_string())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -676,12 +697,14 @@ WHERE id = ?
metadata: &crate::ThreadMetadata,
creation_memory_mode: Option<&str>,
) -> anyhow::Result<()> {
let updated_at = self.allocate_thread_updated_at(metadata.updated_at)?;
// Backfill/reconcile callers merge existing git info before upserting, but that
// read/modify/write is not atomic. Preserve non-null SQLite git fields here so
// an explicit metadata update cannot be lost if a stale rollout upsert lands later.
sqlx::query(
r#"
let started = Instant::now();
let result: anyhow::Result<()> = async {
let updated_at = self.allocate_thread_updated_at(metadata.updated_at)?;
// Backfill/reconcile callers merge existing git info before upserting, but that
// read/modify/write is not atomic. Preserve non-null SQLite git fields here so
// an explicit metadata update cannot be lost if a stale rollout upsert lands later.
sqlx::query(
r#"
INSERT INTO threads (
id,
rollout_path,
@@ -738,48 +761,56 @@ ON CONFLICT(id) DO UPDATE SET
git_branch = COALESCE(threads.git_branch, excluded.git_branch),
git_origin_url = COALESCE(threads.git_origin_url, excluded.git_origin_url)
"#,
)
.bind(metadata.id.to_string())
.bind(metadata.rollout_path.display().to_string())
.bind(datetime_to_epoch_seconds(metadata.created_at))
.bind(datetime_to_epoch_seconds(updated_at))
.bind(datetime_to_epoch_millis(metadata.created_at))
.bind(datetime_to_epoch_millis(updated_at))
.bind(metadata.source.as_str())
.bind(
metadata
.thread_source
.map(codex_protocol::protocol::ThreadSource::as_str),
)
.bind(metadata.agent_nickname.as_deref())
.bind(metadata.agent_role.as_deref())
.bind(metadata.agent_path.as_deref())
.bind(metadata.model_provider.as_str())
.bind(metadata.model.as_deref())
.bind(
metadata
.reasoning_effort
.as_ref()
.map(crate::extract::enum_to_string),
)
.bind(metadata.cwd.display().to_string())
.bind(metadata.cli_version.as_str())
.bind(metadata.title.as_str())
.bind(metadata.sandbox_policy.as_str())
.bind(metadata.approval_mode.as_str())
.bind(metadata.tokens_used)
.bind(metadata.first_user_message.as_deref().unwrap_or_default())
.bind(metadata.archived_at.is_some())
.bind(metadata.archived_at.map(datetime_to_epoch_seconds))
.bind(metadata.git_sha.as_deref())
.bind(metadata.git_branch.as_deref())
.bind(metadata.git_origin_url.as_deref())
.bind(creation_memory_mode.unwrap_or("enabled"))
.execute(self.pool.as_ref())
.await?;
self.insert_thread_spawn_edge_from_source_if_absent(metadata.id, metadata.source.as_str())
)
.bind(metadata.id.to_string())
.bind(metadata.rollout_path.display().to_string())
.bind(datetime_to_epoch_seconds(metadata.created_at))
.bind(datetime_to_epoch_seconds(updated_at))
.bind(datetime_to_epoch_millis(metadata.created_at))
.bind(datetime_to_epoch_millis(updated_at))
.bind(metadata.source.as_str())
.bind(
metadata
.thread_source
.map(codex_protocol::protocol::ThreadSource::as_str),
)
.bind(metadata.agent_nickname.as_deref())
.bind(metadata.agent_role.as_deref())
.bind(metadata.agent_path.as_deref())
.bind(metadata.model_provider.as_str())
.bind(metadata.model.as_deref())
.bind(
metadata
.reasoning_effort
.as_ref()
.map(crate::extract::enum_to_string),
)
.bind(metadata.cwd.display().to_string())
.bind(metadata.cli_version.as_str())
.bind(metadata.title.as_str())
.bind(metadata.sandbox_policy.as_str())
.bind(metadata.approval_mode.as_str())
.bind(metadata.tokens_used)
.bind(metadata.first_user_message.as_deref().unwrap_or_default())
.bind(metadata.archived_at.is_some())
.bind(metadata.archived_at.map(datetime_to_epoch_seconds))
.bind(metadata.git_sha.as_deref())
.bind(metadata.git_branch.as_deref())
.bind(metadata.git_origin_url.as_deref())
.bind(creation_memory_mode.unwrap_or("enabled"))
.execute(self.state_db.pool())
.await?;
Ok(())
self.insert_thread_spawn_edge_from_source_if_absent(
metadata.id,
metadata.source.as_str(),
)
.await?;
Ok(())
}
.await;
self.state_db
.record_result(DbOperation::UpsertThread, DbAccess::Write, started, &result);
result
}
/// Persist dynamic tools for a thread if none have been stored yet.
@@ -791,19 +822,21 @@ ON CONFLICT(id) DO UPDATE SET
thread_id: ThreadId,
tools: Option<&[DynamicToolSpec]>,
) -> anyhow::Result<()> {
let Some(tools) = tools else {
return Ok(());
};
if tools.is_empty() {
return Ok(());
}
let thread_id = thread_id.to_string();
let mut tx = self.pool.begin().await?;
for (idx, tool) in tools.iter().enumerate() {
let position = i64::try_from(idx).unwrap_or(i64::MAX);
let input_schema = serde_json::to_string(&tool.input_schema)?;
sqlx::query(
r#"
self.state_db
.transaction(DbOperation::PersistDynamicTools, |pool| async move {
let Some(tools) = tools else {
return Ok(());
};
if tools.is_empty() {
return Ok(());
}
let thread_id = thread_id.to_string();
let mut tx = pool.begin().await?;
for (idx, tool) in tools.iter().enumerate() {
let position = i64::try_from(idx).unwrap_or(i64::MAX);
let input_schema = serde_json::to_string(&tool.input_schema)?;
sqlx::query(
r#"
INSERT INTO thread_dynamic_tools (
thread_id,
position,
@@ -815,19 +848,21 @@ INSERT INTO thread_dynamic_tools (
) VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(thread_id, position) DO NOTHING
"#,
)
.bind(thread_id.as_str())
.bind(position)
.bind(tool.namespace.as_deref())
.bind(tool.name.as_str())
.bind(tool.description.as_str())
.bind(input_schema)
.bind(tool.defer_loading)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
Ok(())
)
.bind(thread_id.as_str())
.bind(position)
.bind(tool.namespace.as_deref())
.bind(tool.name.as_str())
.bind(tool.description.as_str())
.bind(input_schema)
.bind(tool.defer_loading)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
Ok(())
})
.await
}
/// Apply rollout items incrementally using the underlying database.
@@ -937,7 +972,7 @@ ON CONFLICT(thread_id, position) DO NOTHING
pub async fn delete_thread(&self, thread_id: ThreadId) -> anyhow::Result<u64> {
let result = sqlx::query("DELETE FROM threads WHERE id = ?")
.bind(thread_id.to_string())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected())
}
@@ -1171,7 +1206,7 @@ mod tests {
let memory_mode: String =
sqlx::query_scalar("SELECT memory_mode FROM threads WHERE id = ?")
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("memory mode should be readable");
assert_eq!(memory_mode, "disabled");
@@ -1185,7 +1220,7 @@ mod tests {
let memory_mode: String =
sqlx::query_scalar("SELECT memory_mode FROM threads WHERE id = ?")
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("memory mode should remain readable");
assert_eq!(memory_mode, "disabled");
@@ -1539,7 +1574,7 @@ mod tests {
.bind(123_i64)
.bind("newer preview")
.bind(thread_id.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("concurrent metadata write should succeed");
@@ -1739,7 +1774,7 @@ mod tests {
"SELECT created_at, updated_at, created_at_ms, updated_at_ms FROM threads WHERE id = ?",
)
.bind(second_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("thread timestamp row should load");
assert_eq!(
@@ -1773,7 +1808,7 @@ mod tests {
sqlx::query("UPDATE threads SET updated_at = ? WHERE id = ?")
.bind(1_700_001_112_i64)
.bind(first_id.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("legacy timestamp write should succeed");
let legacy = runtime
@@ -1985,7 +2020,7 @@ INSERT INTO thread_spawn_edges (
.bind(parent_thread_id.to_string())
.bind(future_child_thread_id.to_string())
.bind("future")
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("future-status child edge insert should succeed");

View File

@@ -0,0 +1,342 @@
use std::borrow::Cow;
use std::sync::Arc;
use std::time::Duration;
use crate::DB_FALLBACK_METRIC;
use crate::DB_INIT_DURATION_METRIC;
use crate::DB_INIT_METRIC;
use crate::DB_LOG_QUEUE_METRIC;
use crate::DB_OPERATION_DURATION_METRIC;
use crate::DB_OPERATION_METRIC;
/// Low-cardinality metrics sink used by the SQLite state runtime.
///
/// Implementations should ignore recording errors locally. Database operations
/// must never fail because telemetry delivery failed.
///
/// The tag slices passed to both methods are small, fixed sets of static
/// key/value pairs (status, db, operation, ...), so implementations can use
/// them directly as metric dimensions without cardinality concerns.
pub trait DbMetricsRecorder: Send + Sync + 'static {
    /// Increment a counter metric by `inc` with low-cardinality tags.
    fn counter(&self, name: &str, inc: i64, tags: &[(&str, &str)]);

    /// Record an elapsed duration metric with low-cardinality tags.
    fn record_duration(&self, name: &str, duration: Duration, tags: &[(&str, &str)]);
}

/// Shared recorder handle stored by `StateRuntime` and cloned by log layers.
pub type DbMetricsRecorderHandle = Arc<dyn DbMetricsRecorder>;
/// Which SQLite database an event refers to; rendered as the `db` metric tag.
#[derive(Clone, Copy)]
pub(crate) enum DbKind {
    State,
    Logs,
}

impl DbKind {
    /// Stable, low-cardinality tag value for this database kind.
    fn as_str(self) -> &'static str {
        if matches!(self, DbKind::State) {
            "state"
        } else {
            "logs"
        }
    }
}
/// How the runtime touched the database; rendered as the `access` metric tag.
#[derive(Clone, Copy)]
pub(crate) enum DbAccess {
    Read,
    Write,
    Transaction,
    Maintenance,
}

impl DbAccess {
    /// Stable, low-cardinality tag value for this access mode.
    fn as_str(self) -> &'static str {
        match self {
            DbAccess::Read => "read",
            DbAccess::Write => "write",
            DbAccess::Transaction => "transaction",
            DbAccess::Maintenance => "maintenance",
        }
    }
}
pub(crate) fn record_init_result<T>(
metrics: Option<&dyn DbMetricsRecorder>,
db: DbKind,
phase: &'static str,
duration: Duration,
result: &anyhow::Result<T>,
) {
let outcome = DbOutcomeTags::from_result(result);
let tags = [
("status", outcome.status),
("phase", phase),
("db", db.as_str()),
("error", outcome.error),
];
record_counter(metrics, DB_INIT_METRIC, &tags);
record_duration(metrics, DB_INIT_DURATION_METRIC, duration, &tags);
}
/// Count one database fallback event attributed to a caller and reason.
///
/// Both `caller` and `reason` must be static, low-cardinality strings.
pub fn record_fallback(
    metrics: Option<&dyn DbMetricsRecorder>,
    caller: &'static str,
    reason: &'static str,
) {
    record_counter(
        metrics,
        DB_FALLBACK_METRIC,
        &[("caller", caller), ("reason", reason)],
    );
}
/// Record the outcome of the state-DB backfill gate as an init-phase metric.
pub fn record_init_backfill_gate(
    metrics: Option<&dyn DbMetricsRecorder>,
    duration: Duration,
    result: &anyhow::Result<()>,
) {
    // Fixed phase tag; only the state database runs the backfill gate.
    const PHASE: &str = "backfill_gate";
    record_init_result(metrics, DbKind::State, PHASE, duration, result);
}
/// Count one log-queue event; `event` and `reason` are static tag values.
pub(crate) fn record_log_queue(
    metrics: Option<&dyn DbMetricsRecorder>,
    event: &'static str,
    reason: &'static str,
) {
    record_counter(
        metrics,
        DB_LOG_QUEUE_METRIC,
        &[("event", event), ("reason", reason)],
    );
}
pub(crate) fn classify_error(err: &anyhow::Error) -> &'static str {
for cause in err.chain() {
if let Some(sqlx_err) = cause.downcast_ref::<sqlx::Error>() {
return classify_sqlx_error(sqlx_err);
}
if cause
.downcast_ref::<sqlx::migrate::MigrateError>()
.is_some()
{
return "migration";
}
if cause.downcast_ref::<serde_json::Error>().is_some() {
return "serde";
}
if cause.downcast_ref::<std::io::Error>().is_some() {
return "io";
}
}
"unknown"
}
/// Map a SQLite result code (as a decimal string) to a low-cardinality tag.
///
/// Extended result codes embed the primary code in their low byte, so e.g.
/// `2067` (an extended constraint code) classifies as "constraint".
pub(crate) fn classify_sqlite_code(code: &str) -> &'static str {
    let Ok(parsed) = code.parse::<i32>() else {
        return "unknown";
    };
    // Mask down to the primary result code before classifying.
    match parsed & 0xff {
        5 => "busy",
        6 => "locked",
        8 => "readonly",
        10 => "io",
        11 => "corrupt",
        13 => "full",
        14 => "cantopen",
        17 => "schema",
        19 => "constraint",
        _ => "unknown",
    }
}
pub(crate) fn record_operation_result<T>(
metrics: Option<&dyn DbMetricsRecorder>,
db: DbKind,
operation: &'static str,
access: DbAccess,
duration: Duration,
result: &anyhow::Result<T>,
) {
let outcome = DbOutcomeTags::from_result(result);
let tags = [
("status", outcome.status),
("db", db.as_str()),
("operation", operation),
("access", access.as_str()),
("error", outcome.error),
];
record_counter(metrics, DB_OPERATION_METRIC, &tags);
record_duration(metrics, DB_OPERATION_DURATION_METRIC, duration, &tags);
}
/// Status/error tag pair derived from an operation result.
struct DbOutcomeTags {
    status: &'static str,
    error: &'static str,
}

impl DbOutcomeTags {
    /// Success maps to ("success", "none"); failure classifies the error chain.
    fn from_result<T>(result: &anyhow::Result<T>) -> Self {
        if let Err(err) = result {
            Self {
                status: "failed",
                error: classify_error(err),
            }
        } else {
            Self {
                status: "success",
                error: "none",
            }
        }
    }
}
fn classify_sqlx_error(err: &sqlx::Error) -> &'static str {
match err {
sqlx::Error::Database(database_error) => {
let code = database_error
.code()
.unwrap_or(Cow::Borrowed("none"))
.to_string();
classify_sqlite_code(code.as_str())
}
sqlx::Error::PoolTimedOut => "pool_timeout",
sqlx::Error::Io(_) => "io",
sqlx::Error::ColumnDecode { source, .. } if source.is::<serde_json::Error>() => "serde",
sqlx::Error::Decode(source) if source.is::<serde_json::Error>() => "serde",
_ => "unknown",
}
}
/// Forward a unit counter increment to the recorder, if one is installed.
fn record_counter(metrics: Option<&dyn DbMetricsRecorder>, name: &str, tags: &[(&str, &str)]) {
    let Some(recorder) = metrics else { return };
    recorder.counter(name, /*inc*/ 1, tags);
}

/// Forward an elapsed-duration sample to the recorder, if one is installed.
fn record_duration(
    metrics: Option<&dyn DbMetricsRecorder>,
    name: &str,
    duration: Duration,
    tags: &[(&str, &str)],
) {
    let Some(recorder) = metrics else { return };
    recorder.record_duration(name, duration, tags);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DB_FALLBACK_METRIC;
    use crate::DB_OPERATION_METRIC;
    use pretty_assertions::assert_eq;
    use std::collections::BTreeMap;
    use std::sync::Mutex;

    /// In-memory recorder that captures counter events for assertions.
    /// Duration samples are intentionally ignored (see `record_duration`).
    #[derive(Default)]
    struct TestMetrics {
        // Mutex-guarded so the recorder satisfies `Send + Sync` like real
        // implementations must.
        events: Mutex<Vec<MetricEvent>>,
    }

    /// Snapshot of one recorded counter: metric name plus its tag map.
    /// Tags are stored in a BTreeMap so comparisons are order-insensitive.
    #[derive(Debug, Eq, PartialEq)]
    struct MetricEvent {
        name: String,
        tags: BTreeMap<String, String>,
    }

    impl TestMetrics {
        /// Clone out the recorded events for comparison against expectations.
        fn events(&self) -> Vec<MetricEvent> {
            self.events
                .lock()
                .expect("metrics lock")
                .iter()
                .map(|event| MetricEvent {
                    name: event.name.clone(),
                    tags: event.tags.clone(),
                })
                .collect()
        }
    }

    impl DbMetricsRecorder for TestMetrics {
        fn counter(&self, name: &str, _inc: i64, tags: &[(&str, &str)]) {
            self.events.lock().expect("metrics lock").push(MetricEvent {
                name: name.to_string(),
                tags: tags_to_map(tags),
            });
        }

        // Durations are dropped: these tests only assert on counter tags.
        fn record_duration(&self, _name: &str, _duration: Duration, _tags: &[(&str, &str)]) {}
    }

    /// Convert a borrowed tag slice into an owned, sorted map.
    fn tags_to_map(tags: &[(&str, &str)]) -> BTreeMap<String, String> {
        tags.iter()
            .map(|(key, value)| ((*key).to_string(), (*value).to_string()))
            .collect()
    }

    #[test]
    fn classifies_sqlite_primary_codes() {
        assert_eq!(classify_sqlite_code("5"), "busy");
        assert_eq!(classify_sqlite_code("6"), "locked");
        assert_eq!(classify_sqlite_code("14"), "cantopen");
        // Extended code: 2067 & 0xff == 19, the primary constraint code.
        assert_eq!(classify_sqlite_code("2067"), "constraint");
    }

    #[test]
    fn classifies_non_sqlite_errors() {
        let io_error =
            anyhow::Error::new(std::io::Error::new(std::io::ErrorKind::NotFound, "missing"));
        assert_eq!(classify_error(&io_error), "io");
        let serde_error =
            anyhow::Error::new(serde_json::from_str::<serde_json::Value>("not-json").unwrap_err());
        assert_eq!(classify_error(&serde_error), "serde");
        // Errors with no recognized cause fall through to "unknown".
        let unknown_error = anyhow::anyhow!("plain failure");
        assert_eq!(classify_error(&unknown_error), "unknown");
    }

    #[test]
    fn classifies_sqlx_pool_timeout() {
        let err = anyhow::Error::new(sqlx::Error::PoolTimedOut);
        assert_eq!(classify_error(&err), "pool_timeout");
    }

    #[test]
    fn records_operation_metric_with_stable_tags() {
        let metrics = TestMetrics::default();
        let result: anyhow::Result<()> = Ok(());
        record_operation_result(
            Some(&metrics),
            DbKind::State,
            "list_threads",
            DbAccess::Read,
            Duration::from_millis(3),
            &result,
        );
        // Only the counter event is captured; the duration sample is ignored
        // by TestMetrics.
        assert_eq!(
            metrics.events(),
            vec![MetricEvent {
                name: DB_OPERATION_METRIC.to_string(),
                tags: BTreeMap::from([
                    ("access".to_string(), "read".to_string()),
                    ("db".to_string(), "state".to_string()),
                    ("error".to_string(), "none".to_string()),
                    ("operation".to_string(), "list_threads".to_string()),
                    ("status".to_string(), "success".to_string()),
                ]),
            }]
        );
    }

    #[test]
    fn records_fallback_metric_with_reason() {
        let metrics = TestMetrics::default();
        record_fallback(Some(&metrics), "list_threads", "db_error");
        assert_eq!(
            metrics.events(),
            vec![MetricEvent {
                name: DB_FALLBACK_METRIC.to_string(),
                tags: BTreeMap::from([
                    ("caller".to_string(), "list_threads".to_string()),
                    ("reason".to_string(), "db_error".to_string()),
                ]),
            }]
        );
    }
}

View File

@@ -5,6 +5,7 @@ use codex_otel::OtelExporter;
use codex_otel::OtelProvider;
use codex_otel::OtelSettings;
use codex_otel::StatsigMetricsSettings;
use std::collections::BTreeMap;
use std::path::Path;
const WFP_SETUP_SERVICE_NAME: &str = "codex-windows-sandbox-setup";
@@ -54,6 +55,8 @@ fn build_wfp_metrics_provider(
trace_exporter: OtelExporter::None,
metrics_exporter: OtelExporter::Statsig,
runtime_metrics: false,
span_attributes: BTreeMap::new(),
tracestate: BTreeMap::new(),
})
.map_err(|err| anyhow::anyhow!("failed to initialize WFP setup metrics provider: {err}"))
}

View File

@@ -26,3 +26,24 @@ When enabled, Codex appends a `Co-authored-by:` trailer using the configured
attribution value. If `commit_attribution` is omitted, Codex uses
`Codex <noreply@openai.com>`. Set `commit_attribution = ""` to disable the
trailer while leaving the feature flag enabled.
## OpenTelemetry Trace Metadata
Codex can add static OpenTelemetry span attributes to exported trace spans and
static W3C tracestate fields to propagated trace context:
```toml
[otel.span_attributes]
"example.trace_attr" = "enabled"
[otel.tracestate.example]
alpha = "one"
beta = "two"
```
Nested `otel.tracestate` tables are encoded as semicolon-separated `key:value`
fields inside the named tracestate member. If propagated trace context already
has the named member, Codex upserts configured fields and preserves other fields
in that member. This config shape does not support setting opaque tracestate
member values. Invalid trace metadata entries are ignored during config load and
reported as startup warnings.