Compare commits

...

1 Commits

Author SHA1 Message Date
jif-oai
71230af19d Bound state log payloads 2026-05-18 10:59:53 +02:00

View File

@@ -19,6 +19,8 @@
//! # }
//! ```
use std::fmt;
use std::fmt::Write as _;
use std::future::Future;
use std::sync::OnceLock;
use std::time::Duration;
@@ -36,8 +38,8 @@ use tracing::span::Record;
use tracing_subscriber::Layer;
use tracing_subscriber::field::RecordFields;
use tracing_subscriber::fmt::FormatFields;
use tracing_subscriber::fmt::FormattedFields;
use tracing_subscriber::fmt::format::DefaultFields;
use tracing_subscriber::fmt::format::Writer;
use tracing_subscriber::registry::LookupSpan;
use uuid::Uuid;
@@ -47,6 +49,8 @@ use crate::StateRuntime;
const LOG_QUEUE_CAPACITY: usize = 512;
const LOG_BATCH_SIZE: usize = 128;
const LOG_FLUSH_INTERVAL: Duration = Duration::from_secs(2);
const MAX_LOG_VALUE_BYTES: usize = 16 * 1024;
const LOG_VALUE_TRUNCATED_MARKER: &str = "...[truncated]...";
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct LogSinkQueueConfig {
@@ -245,7 +249,7 @@ struct SpanFieldVisitor {
}
impl SpanFieldVisitor {
fn record_field(&mut self, field: &Field, value: String) {
fn record_thread_id(&mut self, field: &Field, value: String) {
if field.name() == "thread_id" && self.thread_id.is_none() {
self.thread_id = Some(value);
}
@@ -254,31 +258,33 @@ impl SpanFieldVisitor {
impl Visit for SpanFieldVisitor {
fn record_i64(&mut self, field: &Field, value: i64) {
self.record_field(field, value.to_string());
self.record_thread_id(field, value.to_string());
}
fn record_u64(&mut self, field: &Field, value: u64) {
self.record_field(field, value.to_string());
self.record_thread_id(field, value.to_string());
}
fn record_bool(&mut self, field: &Field, value: bool) {
self.record_field(field, value.to_string());
self.record_thread_id(field, value.to_string());
}
fn record_f64(&mut self, field: &Field, value: f64) {
self.record_field(field, value.to_string());
self.record_thread_id(field, value.to_string());
}
fn record_str(&mut self, field: &Field, value: &str) {
self.record_field(field, value.to_string());
self.record_thread_id(field, value.to_string());
}
fn record_error(&mut self, field: &Field, value: &(dyn std::error::Error + 'static)) {
self.record_field(field, value.to_string());
self.record_thread_id(field, value.to_string());
}
fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) {
self.record_field(field, format!("{value:?}"));
if field.name() == "thread_id" && self.thread_id.is_none() {
self.thread_id = Some(format!("{value:?}"));
}
}
}
@@ -310,28 +316,28 @@ fn format_feedback_log_body<S>(
where
S: tracing::Subscriber + for<'a> LookupSpan<'a>,
{
let mut feedback_log_body = String::new();
let mut feedback_log_body = BoundedLogWriter::new(MAX_LOG_VALUE_BYTES);
if let Some(scope) = ctx.event_scope(event) {
for span in scope.from_root() {
let extensions = span.extensions();
if let Some(log_context) = extensions.get::<SpanLogContext>() {
feedback_log_body.push_str(&log_context.name);
let _ = feedback_log_body.write_str(&log_context.name);
if !log_context.formatted_fields.is_empty() {
feedback_log_body.push('{');
feedback_log_body.push_str(&log_context.formatted_fields);
feedback_log_body.push('}');
let _ = feedback_log_body.write_char('{');
let _ = feedback_log_body.write_str(&log_context.formatted_fields);
let _ = feedback_log_body.write_char('}');
}
} else {
feedback_log_body.push_str(span.metadata().name());
let _ = feedback_log_body.write_str(span.metadata().name());
}
feedback_log_body.push(':');
let _ = feedback_log_body.write_char(':');
}
if !feedback_log_body.is_empty() {
feedback_log_body.push(' ');
let _ = feedback_log_body.write_char(' ');
}
}
feedback_log_body.push_str(&format_fields(event));
feedback_log_body
let _ = feedback_log_body.write_str(&format_fields(event));
feedback_log_body.finish()
}
fn format_fields<R>(fields: R) -> String
@@ -339,16 +345,162 @@ where
R: RecordFields,
{
let formatter = DefaultFields::default();
let mut formatted = FormattedFields::<DefaultFields>::new(String::new());
let _ = formatter.format_fields(formatted.as_writer(), fields);
formatted.fields
let mut formatted = BoundedLogWriter::new(MAX_LOG_VALUE_BYTES);
let _ = formatter.format_fields(Writer::new(&mut formatted), fields);
formatted.finish()
}
fn append_fields(fields: &mut String, values: &Record<'_>) {
let formatter = DefaultFields::default();
let mut formatted = FormattedFields::<DefaultFields>::new(std::mem::take(fields));
let _ = formatter.add_fields(&mut formatted, values);
*fields = formatted.fields;
let mut formatted = BoundedLogWriter::new(MAX_LOG_VALUE_BYTES);
let current = std::mem::take(fields);
if !current.is_empty() {
let _ = formatted.write_str(&current);
let _ = formatted.write_char(' ');
}
let _ = formatter.format_fields(Writer::new(&mut formatted), values);
*fields = formatted.finish();
}
struct BoundedLogWriter {
max_bytes: usize,
head: String,
tail: String,
full: Option<String>,
original_bytes: usize,
}
impl BoundedLogWriter {
fn new(max_bytes: usize) -> Self {
Self {
max_bytes,
head: String::new(),
tail: String::new(),
full: Some(String::new()),
original_bytes: 0,
}
}
fn finish(self) -> String {
if let Some(full) = self.full {
return full;
}
if self.max_bytes <= LOG_VALUE_TRUNCATED_MARKER.len() {
let boundary = utf8_prefix_boundary(LOG_VALUE_TRUNCATED_MARKER, self.max_bytes);
return LOG_VALUE_TRUNCATED_MARKER[..boundary].to_string();
}
format!("{}{}{}", self.head, LOG_VALUE_TRUNCATED_MARKER, self.tail)
}
fn is_empty(&self) -> bool {
self.original_bytes == 0
}
fn preview_capacity(&self) -> usize {
self.max_bytes
.saturating_sub(LOG_VALUE_TRUNCATED_MARKER.len())
}
fn head_capacity(&self) -> usize {
self.preview_capacity() / 2
}
fn tail_capacity(&self) -> usize {
self.preview_capacity() - self.head_capacity()
}
fn push_head(&mut self, value: &str) {
let remaining = self.head_capacity().saturating_sub(self.head.len());
if remaining == 0 {
return;
}
let boundary = utf8_prefix_boundary(value, remaining);
self.head.push_str(&value[..boundary]);
}
fn push_tail(&mut self, value: &str) {
let capacity = self.tail_capacity();
if capacity == 0 {
return;
}
if value.len() >= capacity {
let start = utf8_suffix_boundary(value, capacity);
self.tail.clear();
self.tail.push_str(&value[start..]);
return;
}
self.tail.push_str(value);
if self.tail.len() > capacity {
let start = utf8_suffix_boundary(&self.tail, capacity);
self.tail.replace_range(..start, "");
}
}
}
impl fmt::Write for BoundedLogWriter {
fn write_str(&mut self, value: &str) -> fmt::Result {
let original_bytes = self.original_bytes.saturating_add(value.len());
self.push_head(value);
self.push_tail(value);
if let Some(full) = &mut self.full {
if original_bytes <= self.max_bytes {
full.push_str(value);
} else {
self.full = None;
}
}
self.original_bytes = original_bytes;
Ok(())
}
}
fn bounded_log_value(value: &str) -> String {
let mut writer = BoundedLogWriter::new(MAX_LOG_VALUE_BYTES);
let _ = writer.write_str(value);
writer.finish()
}
fn bounded_log_debug(value: &dyn std::fmt::Debug) -> String {
let mut writer = BoundedLogWriter::new(MAX_LOG_VALUE_BYTES);
let _ = writer.write_fmt(format_args!("{value:?}"));
writer.finish()
}
fn bounded_log_display(value: &dyn fmt::Display) -> String {
let mut writer = BoundedLogWriter::new(MAX_LOG_VALUE_BYTES);
let _ = writer.write_fmt(format_args!("{value}"));
writer.finish()
}
fn utf8_prefix_boundary(value: &str, max_bytes: usize) -> usize {
if value.len() <= max_bytes {
return value.len();
}
let mut boundary = max_bytes;
while boundary > 0 && !value.is_char_boundary(boundary) {
boundary -= 1;
}
boundary
}
fn utf8_suffix_boundary(value: &str, max_bytes: usize) -> usize {
if value.len() <= max_bytes {
return 0;
}
let mut boundary = value.len().saturating_sub(max_bytes);
while boundary < value.len() && !value.is_char_boundary(boundary) {
boundary += 1;
}
boundary
}
fn current_process_log_uuid() -> &'static str {
@@ -411,43 +563,51 @@ struct MessageVisitor {
}
impl MessageVisitor {
fn record_field(&mut self, field: &Field, value: String) {
fn record_field(&mut self, field: &Field, value: &str) {
if field.name() == "message" && self.message.is_none() {
self.message = Some(value.clone());
self.message = Some(bounded_log_value(value));
}
if field.name() == "thread_id" && self.thread_id.is_none() {
self.thread_id = Some(value);
self.thread_id = Some(value.to_string());
}
}
}
impl Visit for MessageVisitor {
fn record_i64(&mut self, field: &Field, value: i64) {
self.record_field(field, value.to_string());
self.record_field(field, &value.to_string());
}
fn record_u64(&mut self, field: &Field, value: u64) {
self.record_field(field, value.to_string());
self.record_field(field, &value.to_string());
}
fn record_bool(&mut self, field: &Field, value: bool) {
self.record_field(field, value.to_string());
self.record_field(field, &value.to_string());
}
fn record_f64(&mut self, field: &Field, value: f64) {
self.record_field(field, value.to_string());
self.record_field(field, &value.to_string());
}
fn record_str(&mut self, field: &Field, value: &str) {
self.record_field(field, value.to_string());
self.record_field(field, value);
}
fn record_error(&mut self, field: &Field, value: &(dyn std::error::Error + 'static)) {
self.record_field(field, value.to_string());
if field.name() == "message" && self.message.is_none() {
self.message = Some(bounded_log_display(value));
} else if field.name() == "thread_id" && self.thread_id.is_none() {
self.thread_id = Some(value.to_string());
}
}
fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) {
self.record_field(field, format!("{value:?}"));
if field.name() == "message" && self.message.is_none() {
self.message = Some(bounded_log_debug(value));
} else if field.name() == "thread_id" && self.thread_id.is_none() {
self.thread_id = Some(format!("{value:?}"));
}
}
}
@@ -504,6 +664,16 @@ mod tests {
}
}
fn assert_bounded_log_value(value: &str) {
assert!(
value.len() <= MAX_LOG_VALUE_BYTES,
"log value exceeded cap: {} > {}",
value.len(),
MAX_LOG_VALUE_BYTES
);
assert!(value.contains(LOG_VALUE_TRUNCATED_MARKER));
}
#[derive(Clone, Default)]
struct SharedWriter {
bytes: Arc<Mutex<Vec<u8>>>,
@@ -633,6 +803,75 @@ mod tests {
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn large_event_payloads_are_bounded_before_sqlite_insert() {
let (sender, mut receiver) = mpsc::channel(8);
let layer = LogDbLayer {
sender,
process_uuid: "process-1".to_string(),
};
let large = "x".repeat(MAX_LOG_VALUE_BYTES * 3);
let guard = tracing_subscriber::registry()
.with(
layer
.clone()
.with_filter(Targets::new().with_default(tracing::Level::TRACE)),
)
.set_default();
tracing::info!(large_field = %large, "{large}");
drop(guard);
let entry = match receiver.recv().await.expect("queued log entry") {
LogDbCommand::Entry(entry) => entry,
LogDbCommand::Flush(_) => panic!("expected queued entry"),
};
assert_bounded_log_value(entry.message.as_deref().expect("message"));
assert_bounded_log_value(
entry
.feedback_log_body
.as_deref()
.expect("feedback log body"),
);
}
#[tokio::test]
async fn large_span_fields_are_bounded_before_sqlite_insert() {
let (sender, mut receiver) = mpsc::channel(8);
let layer = LogDbLayer {
sender,
process_uuid: "process-1".to_string(),
};
let large = "x".repeat(MAX_LOG_VALUE_BYTES * 3);
let guard = tracing_subscriber::registry()
.with(
layer
.clone()
.with_filter(Targets::new().with_default(tracing::Level::TRACE)),
)
.set_default();
let span = tracing::info_span!("large-span", large_field = %large);
let _span_guard = span.enter();
tracing::info!("small-message");
drop(_span_guard);
drop(guard);
let entry = match receiver.recv().await.expect("queued log entry") {
LogDbCommand::Entry(entry) => entry,
LogDbCommand::Flush(_) => panic!("expected queued entry"),
};
assert_eq!(entry.message.as_deref(), Some("small-message"));
assert_bounded_log_value(
entry
.feedback_log_body
.as_deref()
.expect("feedback log body"),
);
}
#[tokio::test]
async fn configured_batch_size_flushes_without_explicit_flush() {
let codex_home = temp_codex_home();