mirror of
https://github.com/openai/codex.git
synced 2026-05-02 18:37:01 +00:00
feat: polluted memories (#13008)
Add a feature flag to disable memory creation for "polluted"
This commit is contained in:
@@ -277,7 +277,8 @@ SELECT
|
||||
FROM stage1_outputs AS so
|
||||
LEFT JOIN threads AS t
|
||||
ON t.id = so.thread_id
|
||||
WHERE length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0
|
||||
WHERE t.memory_mode = 'enabled'
|
||||
AND (length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0)
|
||||
ORDER BY so.source_updated_at DESC, so.thread_id DESC
|
||||
LIMIT ?
|
||||
"#,
|
||||
@@ -304,11 +305,13 @@ LIMIT ?
|
||||
/// `thread_id DESC`
|
||||
/// - previously selected rows are identified by `selected_for_phase2 = 1`
|
||||
/// - `previous_selected` contains the current persisted rows that belonged
|
||||
/// to the last successful phase-2 baseline
|
||||
/// to the last successful phase-2 baseline, even if those threads are no
|
||||
/// longer memory-eligible
|
||||
/// - `retained_thread_ids` records which current rows still match the exact
|
||||
/// snapshot selected in the last successful phase-2 run
|
||||
/// - removed rows are previously selected rows that are still present in
|
||||
/// `stage1_outputs` but fall outside the current top-`n` selection
|
||||
/// `stage1_outputs` but are no longer in the current selection, including
|
||||
/// threads that are no longer memory-eligible
|
||||
pub async fn get_phase2_input_selection(
|
||||
&self,
|
||||
n: usize,
|
||||
@@ -336,7 +339,8 @@ SELECT
|
||||
FROM stage1_outputs AS so
|
||||
LEFT JOIN threads AS t
|
||||
ON t.id = so.thread_id
|
||||
WHERE (length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0)
|
||||
WHERE t.memory_mode = 'enabled'
|
||||
AND (length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0)
|
||||
AND (
|
||||
(so.last_usage IS NOT NULL AND so.last_usage >= ?)
|
||||
OR (so.last_usage IS NULL AND so.source_updated_at >= ?)
|
||||
@@ -421,6 +425,51 @@ ORDER BY so.source_updated_at DESC, so.thread_id DESC
|
||||
})
|
||||
}
|
||||
|
||||
/// Marks a thread as polluted and enqueues phase-2 forgetting when the
|
||||
/// thread participated in the last successful phase-2 baseline.
|
||||
pub async fn mark_thread_memory_mode_polluted(
|
||||
&self,
|
||||
thread_id: ThreadId,
|
||||
) -> anyhow::Result<bool> {
|
||||
let now = Utc::now().timestamp();
|
||||
let thread_id = thread_id.to_string();
|
||||
let mut tx = self.pool.begin().await?;
|
||||
let rows_affected = sqlx::query(
|
||||
r#"
|
||||
UPDATE threads
|
||||
SET memory_mode = 'polluted'
|
||||
WHERE id = ? AND memory_mode != 'polluted'
|
||||
"#,
|
||||
)
|
||||
.bind(thread_id.as_str())
|
||||
.execute(&mut *tx)
|
||||
.await?
|
||||
.rows_affected();
|
||||
|
||||
if rows_affected == 0 {
|
||||
tx.commit().await?;
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let selected_for_phase2 = sqlx::query_scalar::<_, i64>(
|
||||
r#"
|
||||
SELECT selected_for_phase2
|
||||
FROM stage1_outputs
|
||||
WHERE thread_id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(thread_id.as_str())
|
||||
.fetch_optional(&mut *tx)
|
||||
.await?
|
||||
.unwrap_or(0);
|
||||
if selected_for_phase2 != 0 {
|
||||
enqueue_global_consolidation_with_executor(&mut *tx, now).await?;
|
||||
}
|
||||
|
||||
tx.commit().await?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Attempts to claim a stage-1 job for a thread at `source_updated_at`.
|
||||
///
|
||||
/// Claim semantics:
|
||||
@@ -2569,6 +2618,71 @@ VALUES (?, ?, ?, ?, ?)
|
||||
let _ = tokio::fs::remove_dir_all(codex_home).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_stage1_outputs_for_global_skips_polluted_threads() {
|
||||
let codex_home = unique_temp_dir();
|
||||
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
|
||||
.await
|
||||
.expect("initialize runtime");
|
||||
|
||||
let thread_id_enabled =
|
||||
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
|
||||
let thread_id_polluted =
|
||||
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
|
||||
let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
|
||||
|
||||
for (thread_id, workspace) in [
|
||||
(thread_id_enabled, "workspace-enabled"),
|
||||
(thread_id_polluted, "workspace-polluted"),
|
||||
] {
|
||||
runtime
|
||||
.upsert_thread(&test_thread_metadata(
|
||||
&codex_home,
|
||||
thread_id,
|
||||
codex_home.join(workspace),
|
||||
))
|
||||
.await
|
||||
.expect("upsert thread");
|
||||
|
||||
let claim = runtime
|
||||
.try_claim_stage1_job(thread_id, owner, 100, 3600, 64)
|
||||
.await
|
||||
.expect("claim stage1");
|
||||
let ownership_token = match claim {
|
||||
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
|
||||
other => panic!("unexpected stage1 claim outcome: {other:?}"),
|
||||
};
|
||||
assert!(
|
||||
runtime
|
||||
.mark_stage1_job_succeeded(
|
||||
thread_id,
|
||||
ownership_token.as_str(),
|
||||
100,
|
||||
"raw memory",
|
||||
"summary",
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("mark stage1 succeeded"),
|
||||
"stage1 success should persist output"
|
||||
);
|
||||
}
|
||||
|
||||
runtime
|
||||
.set_thread_memory_mode(thread_id_polluted, "polluted")
|
||||
.await
|
||||
.expect("mark thread polluted");
|
||||
|
||||
let outputs = runtime
|
||||
.list_stage1_outputs_for_global(10)
|
||||
.await
|
||||
.expect("list stage1 outputs for global");
|
||||
assert_eq!(outputs.len(), 1);
|
||||
assert_eq!(outputs[0].thread_id, thread_id_enabled);
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(codex_home).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn get_phase2_input_selection_reports_added_retained_and_removed_rows() {
|
||||
let codex_home = unique_temp_dir();
|
||||
@@ -2681,6 +2795,197 @@ VALUES (?, ?, ?, ?, ?)
|
||||
let _ = tokio::fs::remove_dir_all(codex_home).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn get_phase2_input_selection_marks_polluted_previous_selection_as_removed() {
|
||||
let codex_home = unique_temp_dir();
|
||||
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
|
||||
.await
|
||||
.expect("initialize runtime");
|
||||
|
||||
let thread_id_enabled =
|
||||
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
|
||||
let thread_id_polluted =
|
||||
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
|
||||
let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
|
||||
|
||||
for (thread_id, updated_at) in [(thread_id_enabled, 100), (thread_id_polluted, 101)] {
|
||||
runtime
|
||||
.upsert_thread(&test_thread_metadata(
|
||||
&codex_home,
|
||||
thread_id,
|
||||
codex_home.join(thread_id.to_string()),
|
||||
))
|
||||
.await
|
||||
.expect("upsert thread");
|
||||
|
||||
let claim = runtime
|
||||
.try_claim_stage1_job(thread_id, owner, updated_at, 3600, 64)
|
||||
.await
|
||||
.expect("claim stage1");
|
||||
let ownership_token = match claim {
|
||||
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
|
||||
other => panic!("unexpected stage1 claim outcome: {other:?}"),
|
||||
};
|
||||
assert!(
|
||||
runtime
|
||||
.mark_stage1_job_succeeded(
|
||||
thread_id,
|
||||
ownership_token.as_str(),
|
||||
updated_at,
|
||||
&format!("raw-{updated_at}"),
|
||||
&format!("summary-{updated_at}"),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("mark stage1 succeeded"),
|
||||
"stage1 success should persist output"
|
||||
);
|
||||
}
|
||||
|
||||
let claim = runtime
|
||||
.try_claim_global_phase2_job(owner, 3600)
|
||||
.await
|
||||
.expect("claim phase2");
|
||||
let (ownership_token, input_watermark) = match claim {
|
||||
Phase2JobClaimOutcome::Claimed {
|
||||
ownership_token,
|
||||
input_watermark,
|
||||
} => (ownership_token, input_watermark),
|
||||
other => panic!("unexpected phase2 claim outcome: {other:?}"),
|
||||
};
|
||||
let selected_outputs = runtime
|
||||
.list_stage1_outputs_for_global(10)
|
||||
.await
|
||||
.expect("list stage1 outputs for global");
|
||||
assert!(
|
||||
runtime
|
||||
.mark_global_phase2_job_succeeded(
|
||||
ownership_token.as_str(),
|
||||
input_watermark,
|
||||
&selected_outputs,
|
||||
)
|
||||
.await
|
||||
.expect("mark phase2 success"),
|
||||
"phase2 success should persist selected rows"
|
||||
);
|
||||
|
||||
runtime
|
||||
.set_thread_memory_mode(thread_id_polluted, "polluted")
|
||||
.await
|
||||
.expect("mark thread polluted");
|
||||
|
||||
let selection = runtime
|
||||
.get_phase2_input_selection(2, 36_500)
|
||||
.await
|
||||
.expect("load phase2 input selection");
|
||||
|
||||
assert_eq!(selection.selected.len(), 1);
|
||||
assert_eq!(selection.selected[0].thread_id, thread_id_enabled);
|
||||
assert_eq!(selection.previous_selected.len(), 2);
|
||||
assert!(
|
||||
selection
|
||||
.previous_selected
|
||||
.iter()
|
||||
.any(|item| item.thread_id == thread_id_enabled)
|
||||
);
|
||||
assert!(
|
||||
selection
|
||||
.previous_selected
|
||||
.iter()
|
||||
.any(|item| item.thread_id == thread_id_polluted)
|
||||
);
|
||||
assert_eq!(selection.retained_thread_ids, vec![thread_id_enabled]);
|
||||
assert_eq!(selection.removed.len(), 1);
|
||||
assert_eq!(selection.removed[0].thread_id, thread_id_polluted);
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(codex_home).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mark_thread_memory_mode_polluted_enqueues_phase2_for_selected_threads() {
|
||||
let codex_home = unique_temp_dir();
|
||||
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
|
||||
.await
|
||||
.expect("initialize runtime");
|
||||
|
||||
let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
|
||||
let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
|
||||
runtime
|
||||
.upsert_thread(&test_thread_metadata(
|
||||
&codex_home,
|
||||
thread_id,
|
||||
codex_home.join("workspace"),
|
||||
))
|
||||
.await
|
||||
.expect("upsert thread");
|
||||
|
||||
let claim = runtime
|
||||
.try_claim_stage1_job(thread_id, owner, 100, 3600, 64)
|
||||
.await
|
||||
.expect("claim stage1");
|
||||
let ownership_token = match claim {
|
||||
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
|
||||
other => panic!("unexpected stage1 claim outcome: {other:?}"),
|
||||
};
|
||||
assert!(
|
||||
runtime
|
||||
.mark_stage1_job_succeeded(
|
||||
thread_id,
|
||||
ownership_token.as_str(),
|
||||
100,
|
||||
"raw",
|
||||
"summary",
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("mark stage1 succeeded"),
|
||||
"stage1 success should persist output"
|
||||
);
|
||||
|
||||
let phase2_claim = runtime
|
||||
.try_claim_global_phase2_job(owner, 3600)
|
||||
.await
|
||||
.expect("claim phase2");
|
||||
let (phase2_token, input_watermark) = match phase2_claim {
|
||||
Phase2JobClaimOutcome::Claimed {
|
||||
ownership_token,
|
||||
input_watermark,
|
||||
} => (ownership_token, input_watermark),
|
||||
other => panic!("unexpected phase2 claim outcome: {other:?}"),
|
||||
};
|
||||
let selected_outputs = runtime
|
||||
.list_stage1_outputs_for_global(10)
|
||||
.await
|
||||
.expect("list stage1 outputs");
|
||||
assert!(
|
||||
runtime
|
||||
.mark_global_phase2_job_succeeded(
|
||||
phase2_token.as_str(),
|
||||
input_watermark,
|
||||
&selected_outputs,
|
||||
)
|
||||
.await
|
||||
.expect("mark phase2 success"),
|
||||
"phase2 success should persist selected rows"
|
||||
);
|
||||
|
||||
assert!(
|
||||
runtime
|
||||
.mark_thread_memory_mode_polluted(thread_id)
|
||||
.await
|
||||
.expect("mark thread polluted"),
|
||||
"thread should transition to polluted"
|
||||
);
|
||||
|
||||
let next_claim = runtime
|
||||
.try_claim_global_phase2_job(owner, 3600)
|
||||
.await
|
||||
.expect("claim phase2 after pollution");
|
||||
assert!(matches!(next_claim, Phase2JobClaimOutcome::Claimed { .. }));
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(codex_home).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn get_phase2_input_selection_treats_regenerated_selected_rows_as_added() {
|
||||
let codex_home = unique_temp_dir();
|
||||
|
||||
@@ -35,6 +35,14 @@ WHERE id = ?
|
||||
.transpose()
|
||||
}
|
||||
|
||||
pub async fn get_thread_memory_mode(&self, id: ThreadId) -> anyhow::Result<Option<String>> {
|
||||
let row = sqlx::query("SELECT memory_mode FROM threads WHERE id = ?")
|
||||
.bind(id.to_string())
|
||||
.fetch_optional(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(row.and_then(|row| row.try_get("memory_mode").ok()))
|
||||
}
|
||||
|
||||
/// Get dynamic tools for a thread, if present.
|
||||
pub async fn get_dynamic_tools(
|
||||
&self,
|
||||
@@ -199,6 +207,19 @@ FROM threads
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn set_thread_memory_mode(
|
||||
&self,
|
||||
thread_id: ThreadId,
|
||||
memory_mode: &str,
|
||||
) -> anyhow::Result<bool> {
|
||||
let result = sqlx::query("UPDATE threads SET memory_mode = ? WHERE id = ?")
|
||||
.bind(memory_mode)
|
||||
.bind(thread_id.to_string())
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
|
||||
async fn upsert_thread_with_creation_memory_mode(
|
||||
&self,
|
||||
metadata: &crate::ThreadMetadata,
|
||||
@@ -357,6 +378,16 @@ ON CONFLICT(thread_id, position) DO NOTHING
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
if let Some(memory_mode) = extract_memory_mode(items)
|
||||
&& let Err(err) = self
|
||||
.set_thread_memory_mode(builder.id, memory_mode.as_str())
|
||||
.await
|
||||
{
|
||||
if let Some(otel) = otel {
|
||||
otel.counter(DB_ERROR_METRIC, 1, &[("stage", "set_thread_memory_mode")]);
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
let dynamic_tools = extract_dynamic_tools(items);
|
||||
if let Some(dynamic_tools) = dynamic_tools
|
||||
&& let Err(err) = self
|
||||
@@ -438,6 +469,16 @@ pub(super) fn extract_dynamic_tools(items: &[RolloutItem]) -> Option<Option<Vec<
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn extract_memory_mode(items: &[RolloutItem]) -> Option<String> {
|
||||
items.iter().rev().find_map(|item| match item {
|
||||
RolloutItem::SessionMeta(meta_line) => meta_line.meta.memory_mode.clone(),
|
||||
RolloutItem::ResponseItem(_)
|
||||
| RolloutItem::Compacted(_)
|
||||
| RolloutItem::TurnContext(_)
|
||||
| RolloutItem::EventMsg(_) => None,
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn push_thread_filters<'a>(
|
||||
builder: &mut QueryBuilder<'a, Sqlite>,
|
||||
archived_only: bool,
|
||||
@@ -518,7 +559,11 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::runtime::test_support::test_thread_metadata;
|
||||
use crate::runtime::test_support::unique_temp_dir;
|
||||
use codex_protocol::protocol::SessionMeta;
|
||||
use codex_protocol::protocol::SessionMetaLine;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[tokio::test]
|
||||
async fn upsert_thread_keeps_creation_memory_mode_for_existing_rows() {
|
||||
@@ -557,4 +602,56 @@ mod tests {
|
||||
.expect("memory mode should remain readable");
|
||||
assert_eq!(memory_mode, "disabled");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn apply_rollout_items_restores_memory_mode_from_session_meta() {
|
||||
let codex_home = unique_temp_dir();
|
||||
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
let thread_id =
|
||||
ThreadId::from_string("00000000-0000-0000-0000-000000000456").expect("valid thread id");
|
||||
let metadata = test_thread_metadata(&codex_home, thread_id, codex_home.clone());
|
||||
|
||||
runtime
|
||||
.upsert_thread(&metadata)
|
||||
.await
|
||||
.expect("initial upsert should succeed");
|
||||
|
||||
let builder = ThreadMetadataBuilder::new(
|
||||
thread_id,
|
||||
metadata.rollout_path.clone(),
|
||||
metadata.created_at,
|
||||
SessionSource::Cli,
|
||||
);
|
||||
let items = vec![RolloutItem::SessionMeta(SessionMetaLine {
|
||||
meta: SessionMeta {
|
||||
id: thread_id,
|
||||
forked_from_id: None,
|
||||
timestamp: metadata.created_at.to_rfc3339(),
|
||||
cwd: PathBuf::new(),
|
||||
originator: String::new(),
|
||||
cli_version: String::new(),
|
||||
source: SessionSource::Cli,
|
||||
agent_nickname: None,
|
||||
agent_role: None,
|
||||
model_provider: None,
|
||||
base_instructions: None,
|
||||
dynamic_tools: None,
|
||||
memory_mode: Some("polluted".to_string()),
|
||||
},
|
||||
git: None,
|
||||
})];
|
||||
|
||||
runtime
|
||||
.apply_rollout_items(&builder, &items, None, None)
|
||||
.await
|
||||
.expect("apply_rollout_items should succeed");
|
||||
|
||||
let memory_mode = runtime
|
||||
.get_thread_memory_mode(thread_id)
|
||||
.await
|
||||
.expect("memory mode should load");
|
||||
assert_eq!(memory_mode.as_deref(), Some("polluted"));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user