feat: polluted memories (#13008)

Add a feature flag to disable memory creation for "polluted"
This commit is contained in:
jif-oai
2026-03-02 12:57:32 +01:00
committed by GitHub
parent b08bdd91e3
commit b649953845
19 changed files with 939 additions and 33 deletions

View File

@@ -277,7 +277,8 @@ SELECT
FROM stage1_outputs AS so
LEFT JOIN threads AS t
ON t.id = so.thread_id
WHERE length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0
WHERE t.memory_mode = 'enabled'
AND (length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0)
ORDER BY so.source_updated_at DESC, so.thread_id DESC
LIMIT ?
"#,
@@ -304,11 +305,13 @@ LIMIT ?
/// `thread_id DESC`
/// - previously selected rows are identified by `selected_for_phase2 = 1`
/// - `previous_selected` contains the current persisted rows that belonged
/// to the last successful phase-2 baseline
/// to the last successful phase-2 baseline, even if those threads are no
/// longer memory-eligible
/// - `retained_thread_ids` records which current rows still match the exact
/// snapshot selected in the last successful phase-2 run
/// - removed rows are previously selected rows that are still present in
/// `stage1_outputs` but fall outside the current top-`n` selection
/// `stage1_outputs` but are no longer in the current selection, including
/// threads that are no longer memory-eligible
pub async fn get_phase2_input_selection(
&self,
n: usize,
@@ -336,7 +339,8 @@ SELECT
FROM stage1_outputs AS so
LEFT JOIN threads AS t
ON t.id = so.thread_id
WHERE (length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0)
WHERE t.memory_mode = 'enabled'
AND (length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0)
AND (
(so.last_usage IS NOT NULL AND so.last_usage >= ?)
OR (so.last_usage IS NULL AND so.source_updated_at >= ?)
@@ -421,6 +425,51 @@ ORDER BY so.source_updated_at DESC, so.thread_id DESC
})
}
/// Marks a thread as polluted and enqueues phase-2 forgetting when the
/// thread participated in the last successful phase-2 baseline.
pub async fn mark_thread_memory_mode_polluted(
&self,
thread_id: ThreadId,
) -> anyhow::Result<bool> {
let now = Utc::now().timestamp();
let thread_id = thread_id.to_string();
let mut tx = self.pool.begin().await?;
let rows_affected = sqlx::query(
r#"
UPDATE threads
SET memory_mode = 'polluted'
WHERE id = ? AND memory_mode != 'polluted'
"#,
)
.bind(thread_id.as_str())
.execute(&mut *tx)
.await?
.rows_affected();
if rows_affected == 0 {
tx.commit().await?;
return Ok(false);
}
let selected_for_phase2 = sqlx::query_scalar::<_, i64>(
r#"
SELECT selected_for_phase2
FROM stage1_outputs
WHERE thread_id = ?
"#,
)
.bind(thread_id.as_str())
.fetch_optional(&mut *tx)
.await?
.unwrap_or(0);
if selected_for_phase2 != 0 {
enqueue_global_consolidation_with_executor(&mut *tx, now).await?;
}
tx.commit().await?;
Ok(true)
}
/// Attempts to claim a stage-1 job for a thread at `source_updated_at`.
///
/// Claim semantics:
@@ -2569,6 +2618,71 @@ VALUES (?, ?, ?, ?, ?)
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn list_stage1_outputs_for_global_skips_polluted_threads() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let thread_id_enabled =
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let thread_id_polluted =
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
for (thread_id, workspace) in [
(thread_id_enabled, "workspace-enabled"),
(thread_id_polluted, "workspace-polluted"),
] {
runtime
.upsert_thread(&test_thread_metadata(
&codex_home,
thread_id,
codex_home.join(workspace),
))
.await
.expect("upsert thread");
let claim = runtime
.try_claim_stage1_job(thread_id, owner, 100, 3600, 64)
.await
.expect("claim stage1");
let ownership_token = match claim {
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stage1 claim outcome: {other:?}"),
};
assert!(
runtime
.mark_stage1_job_succeeded(
thread_id,
ownership_token.as_str(),
100,
"raw memory",
"summary",
None,
)
.await
.expect("mark stage1 succeeded"),
"stage1 success should persist output"
);
}
runtime
.set_thread_memory_mode(thread_id_polluted, "polluted")
.await
.expect("mark thread polluted");
let outputs = runtime
.list_stage1_outputs_for_global(10)
.await
.expect("list stage1 outputs for global");
assert_eq!(outputs.len(), 1);
assert_eq!(outputs[0].thread_id, thread_id_enabled);
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn get_phase2_input_selection_reports_added_retained_and_removed_rows() {
let codex_home = unique_temp_dir();
@@ -2681,6 +2795,197 @@ VALUES (?, ?, ?, ?, ?)
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn get_phase2_input_selection_marks_polluted_previous_selection_as_removed() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let thread_id_enabled =
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let thread_id_polluted =
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
for (thread_id, updated_at) in [(thread_id_enabled, 100), (thread_id_polluted, 101)] {
runtime
.upsert_thread(&test_thread_metadata(
&codex_home,
thread_id,
codex_home.join(thread_id.to_string()),
))
.await
.expect("upsert thread");
let claim = runtime
.try_claim_stage1_job(thread_id, owner, updated_at, 3600, 64)
.await
.expect("claim stage1");
let ownership_token = match claim {
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stage1 claim outcome: {other:?}"),
};
assert!(
runtime
.mark_stage1_job_succeeded(
thread_id,
ownership_token.as_str(),
updated_at,
&format!("raw-{updated_at}"),
&format!("summary-{updated_at}"),
None,
)
.await
.expect("mark stage1 succeeded"),
"stage1 success should persist output"
);
}
let claim = runtime
.try_claim_global_phase2_job(owner, 3600)
.await
.expect("claim phase2");
let (ownership_token, input_watermark) = match claim {
Phase2JobClaimOutcome::Claimed {
ownership_token,
input_watermark,
} => (ownership_token, input_watermark),
other => panic!("unexpected phase2 claim outcome: {other:?}"),
};
let selected_outputs = runtime
.list_stage1_outputs_for_global(10)
.await
.expect("list stage1 outputs for global");
assert!(
runtime
.mark_global_phase2_job_succeeded(
ownership_token.as_str(),
input_watermark,
&selected_outputs,
)
.await
.expect("mark phase2 success"),
"phase2 success should persist selected rows"
);
runtime
.set_thread_memory_mode(thread_id_polluted, "polluted")
.await
.expect("mark thread polluted");
let selection = runtime
.get_phase2_input_selection(2, 36_500)
.await
.expect("load phase2 input selection");
assert_eq!(selection.selected.len(), 1);
assert_eq!(selection.selected[0].thread_id, thread_id_enabled);
assert_eq!(selection.previous_selected.len(), 2);
assert!(
selection
.previous_selected
.iter()
.any(|item| item.thread_id == thread_id_enabled)
);
assert!(
selection
.previous_selected
.iter()
.any(|item| item.thread_id == thread_id_polluted)
);
assert_eq!(selection.retained_thread_ids, vec![thread_id_enabled]);
assert_eq!(selection.removed.len(), 1);
assert_eq!(selection.removed[0].thread_id, thread_id_polluted);
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn mark_thread_memory_mode_polluted_enqueues_phase2_for_selected_threads() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
runtime
.upsert_thread(&test_thread_metadata(
&codex_home,
thread_id,
codex_home.join("workspace"),
))
.await
.expect("upsert thread");
let claim = runtime
.try_claim_stage1_job(thread_id, owner, 100, 3600, 64)
.await
.expect("claim stage1");
let ownership_token = match claim {
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stage1 claim outcome: {other:?}"),
};
assert!(
runtime
.mark_stage1_job_succeeded(
thread_id,
ownership_token.as_str(),
100,
"raw",
"summary",
None,
)
.await
.expect("mark stage1 succeeded"),
"stage1 success should persist output"
);
let phase2_claim = runtime
.try_claim_global_phase2_job(owner, 3600)
.await
.expect("claim phase2");
let (phase2_token, input_watermark) = match phase2_claim {
Phase2JobClaimOutcome::Claimed {
ownership_token,
input_watermark,
} => (ownership_token, input_watermark),
other => panic!("unexpected phase2 claim outcome: {other:?}"),
};
let selected_outputs = runtime
.list_stage1_outputs_for_global(10)
.await
.expect("list stage1 outputs");
assert!(
runtime
.mark_global_phase2_job_succeeded(
phase2_token.as_str(),
input_watermark,
&selected_outputs,
)
.await
.expect("mark phase2 success"),
"phase2 success should persist selected rows"
);
assert!(
runtime
.mark_thread_memory_mode_polluted(thread_id)
.await
.expect("mark thread polluted"),
"thread should transition to polluted"
);
let next_claim = runtime
.try_claim_global_phase2_job(owner, 3600)
.await
.expect("claim phase2 after pollution");
assert!(matches!(next_claim, Phase2JobClaimOutcome::Claimed { .. }));
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn get_phase2_input_selection_treats_regenerated_selected_rows_as_added() {
let codex_home = unique_temp_dir();

View File

@@ -35,6 +35,14 @@ WHERE id = ?
.transpose()
}
pub async fn get_thread_memory_mode(&self, id: ThreadId) -> anyhow::Result<Option<String>> {
let row = sqlx::query("SELECT memory_mode FROM threads WHERE id = ?")
.bind(id.to_string())
.fetch_optional(self.pool.as_ref())
.await?;
Ok(row.and_then(|row| row.try_get("memory_mode").ok()))
}
/// Get dynamic tools for a thread, if present.
pub async fn get_dynamic_tools(
&self,
@@ -199,6 +207,19 @@ FROM threads
.await
}
pub async fn set_thread_memory_mode(
&self,
thread_id: ThreadId,
memory_mode: &str,
) -> anyhow::Result<bool> {
let result = sqlx::query("UPDATE threads SET memory_mode = ? WHERE id = ?")
.bind(memory_mode)
.bind(thread_id.to_string())
.execute(self.pool.as_ref())
.await?;
Ok(result.rows_affected() > 0)
}
async fn upsert_thread_with_creation_memory_mode(
&self,
metadata: &crate::ThreadMetadata,
@@ -357,6 +378,16 @@ ON CONFLICT(thread_id, position) DO NOTHING
}
return Err(err);
}
if let Some(memory_mode) = extract_memory_mode(items)
&& let Err(err) = self
.set_thread_memory_mode(builder.id, memory_mode.as_str())
.await
{
if let Some(otel) = otel {
otel.counter(DB_ERROR_METRIC, 1, &[("stage", "set_thread_memory_mode")]);
}
return Err(err);
}
let dynamic_tools = extract_dynamic_tools(items);
if let Some(dynamic_tools) = dynamic_tools
&& let Err(err) = self
@@ -438,6 +469,16 @@ pub(super) fn extract_dynamic_tools(items: &[RolloutItem]) -> Option<Option<Vec<
})
}
pub(super) fn extract_memory_mode(items: &[RolloutItem]) -> Option<String> {
items.iter().rev().find_map(|item| match item {
RolloutItem::SessionMeta(meta_line) => meta_line.meta.memory_mode.clone(),
RolloutItem::ResponseItem(_)
| RolloutItem::Compacted(_)
| RolloutItem::TurnContext(_)
| RolloutItem::EventMsg(_) => None,
})
}
pub(super) fn push_thread_filters<'a>(
builder: &mut QueryBuilder<'a, Sqlite>,
archived_only: bool,
@@ -518,7 +559,11 @@ mod tests {
use super::*;
use crate::runtime::test_support::test_thread_metadata;
use crate::runtime::test_support::unique_temp_dir;
use codex_protocol::protocol::SessionMeta;
use codex_protocol::protocol::SessionMetaLine;
use codex_protocol::protocol::SessionSource;
use pretty_assertions::assert_eq;
use std::path::PathBuf;
#[tokio::test]
async fn upsert_thread_keeps_creation_memory_mode_for_existing_rows() {
@@ -557,4 +602,56 @@ mod tests {
.expect("memory mode should remain readable");
assert_eq!(memory_mode, "disabled");
}
#[tokio::test]
async fn apply_rollout_items_restores_memory_mode_from_session_meta() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("state db should initialize");
let thread_id =
ThreadId::from_string("00000000-0000-0000-0000-000000000456").expect("valid thread id");
let metadata = test_thread_metadata(&codex_home, thread_id, codex_home.clone());
runtime
.upsert_thread(&metadata)
.await
.expect("initial upsert should succeed");
let builder = ThreadMetadataBuilder::new(
thread_id,
metadata.rollout_path.clone(),
metadata.created_at,
SessionSource::Cli,
);
let items = vec![RolloutItem::SessionMeta(SessionMetaLine {
meta: SessionMeta {
id: thread_id,
forked_from_id: None,
timestamp: metadata.created_at.to_rfc3339(),
cwd: PathBuf::new(),
originator: String::new(),
cli_version: String::new(),
source: SessionSource::Cli,
agent_nickname: None,
agent_role: None,
model_provider: None,
base_instructions: None,
dynamic_tools: None,
memory_mode: Some("polluted".to_string()),
},
git: None,
})];
runtime
.apply_rollout_items(&builder, &items, None, None)
.await
.expect("apply_rollout_items should succeed");
let memory_mode = runtime
.get_thread_memory_mode(thread_id)
.await
.expect("memory mode should load");
assert_eq!(memory_mode.as_deref(), Some("polluted"));
}
}