feat: polluted memories (#13008)

Add a feature flag to disable memory creation for "polluted"
This commit is contained in:
jif-oai
2026-03-02 12:57:32 +01:00
committed by GitHub
parent b08bdd91e3
commit b649953845
19 changed files with 939 additions and 33 deletions

View File

@@ -11,7 +11,9 @@ use core_test_support::responses::ResponsesRequest;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::ev_web_search_call_done;
use core_test_support::responses::mount_sse_once;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::test_codex::TestCodex;
@@ -157,6 +159,156 @@ async fn memories_startup_phase2_tracks_added_and_removed_inputs_across_runs() -
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn web_search_pollution_moves_selected_thread_into_removed_phase2_inputs() -> Result<()> {
let server = start_mock_server().await;
let home = Arc::new(TempDir::new()?);
let db = init_state_db(&home).await?;
let mut initial_builder = test_codex().with_home(home.clone()).with_config(|config| {
config.features.enable(Feature::Sqlite);
config.features.enable(Feature::MemoryTool);
config.memories.max_raw_memories_for_global = 1;
config.memories.no_memories_if_mcp_or_web_search = true;
});
let initial = initial_builder.build(&server).await?;
mount_sse_once(
&server,
sse(vec![
ev_response_created("resp-initial-1"),
ev_assistant_message("msg-initial-1", "initial turn complete"),
ev_completed("resp-initial-1"),
]),
)
.await;
initial.submit_turn("hello before memories").await?;
let rollout_path = initial
.session_configured
.rollout_path
.clone()
.expect("rollout path");
let thread_id = initial.session_configured.session_id;
let updated_at = {
let deadline = Instant::now() + Duration::from_secs(10);
loop {
if let Some(metadata) = db.get_thread(thread_id).await? {
break metadata.updated_at;
}
assert!(
Instant::now() < deadline,
"timed out waiting for thread metadata for {thread_id}"
);
tokio::time::sleep(Duration::from_millis(50)).await;
}
};
seed_stage1_output_for_existing_thread(
db.as_ref(),
thread_id,
updated_at.timestamp(),
"raw memory seeded for web search pollution",
"rollout summary seeded for web search pollution",
Some("pollution-rollout"),
)
.await?;
shutdown_test_codex(&initial).await?;
let responses = mount_sse_sequence(
&server,
vec![
sse(vec![
ev_response_created("resp-phase2-1"),
ev_assistant_message("msg-phase2-1", "phase2 complete"),
ev_completed("resp-phase2-1"),
]),
sse(vec![
ev_response_created("resp-web-1"),
ev_web_search_call_done("ws-1", "completed", "weather seattle"),
ev_completed("resp-web-1"),
]),
],
)
.await;
let mut resumed_builder = test_codex().with_home(home.clone()).with_config(|config| {
config.features.enable(Feature::Sqlite);
config.features.enable(Feature::MemoryTool);
config.memories.max_raw_memories_for_global = 1;
config.memories.no_memories_if_mcp_or_web_search = true;
});
let resumed = resumed_builder
.resume(&server, home.clone(), rollout_path.clone())
.await?;
let first_phase2_request = wait_for_request(&responses, 1).await.remove(0);
let first_phase2_prompt = phase2_prompt_text(&first_phase2_request);
assert!(
first_phase2_prompt.contains("- selected inputs this run: 1"),
"expected seeded thread to be selected before pollution: {first_phase2_prompt}"
);
assert!(
first_phase2_prompt.contains("- newly added since the last successful Phase 2 run: 1"),
"expected seeded thread to be added before pollution: {first_phase2_prompt}"
);
assert!(
first_phase2_prompt.contains(&format!("- [added] thread_id={thread_id},")),
"expected selected thread in first phase2 prompt: {first_phase2_prompt}"
);
wait_for_phase2_success(db.as_ref(), thread_id).await?;
resumed
.submit_turn("search the web for weather seattle")
.await?;
assert_eq!(
{
let deadline = Instant::now() + Duration::from_secs(10);
loop {
let memory_mode = db.get_thread_memory_mode(thread_id).await?;
if memory_mode.as_deref() == Some("polluted") {
break memory_mode;
}
assert!(
Instant::now() < deadline,
"timed out waiting for polluted memory mode for {thread_id}"
);
tokio::time::sleep(Duration::from_millis(50)).await;
}
}
.as_deref(),
Some("polluted")
);
let selection = {
let deadline = Instant::now() + Duration::from_secs(10);
loop {
let selection = db.get_phase2_input_selection(1, 30).await?;
if selection.selected.is_empty()
&& selection.retained_thread_ids.is_empty()
&& selection.removed.len() == 1
&& selection.removed[0].thread_id == thread_id
{
break selection;
}
assert!(
Instant::now() < deadline,
"timed out waiting for polluted thread to move into removed phase2 inputs: \
{selection:?}"
);
tokio::time::sleep(Duration::from_millis(50)).await;
}
};
assert_eq!(responses.requests().len(), 2);
assert!(selection.selected.is_empty());
assert_eq!(selection.retained_thread_ids, Vec::<ThreadId>::new());
assert_eq!(selection.removed.len(), 1);
assert_eq!(selection.removed[0].thread_id, thread_id);
shutdown_test_codex(&resumed).await?;
Ok(())
}
async fn build_test_codex(server: &wiremock::MockServer, home: Arc<TempDir>) -> Result<TestCodex> {
let mut builder = test_codex().with_home(home).with_config(|config| {
config.features.enable(Feature::Sqlite);
@@ -195,46 +347,33 @@ async fn seed_stage1_output(
let metadata = metadata_builder.build("test-provider");
db.upsert_thread(&metadata).await?;
let claim = db
.try_claim_stage1_job(
thread_id,
ThreadId::new(),
updated_at.timestamp(),
3_600,
64,
)
.await?;
let ownership_token = match claim {
codex_state::Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stage-1 claim outcome: {other:?}"),
};
assert!(
db.mark_stage1_job_succeeded(
thread_id,
&ownership_token,
updated_at.timestamp(),
raw_memory,
rollout_summary,
Some(rollout_slug),
)
.await?,
"stage-1 success should enqueue global consolidation"
);
seed_stage1_output_for_existing_thread(
db,
thread_id,
updated_at.timestamp(),
raw_memory,
rollout_summary,
Some(rollout_slug),
)
.await?;
Ok(thread_id)
}
async fn wait_for_single_request(mock: &ResponseMock) -> ResponsesRequest {
wait_for_request(mock, 1).await.remove(0)
}
async fn wait_for_request(mock: &ResponseMock, expected_count: usize) -> Vec<ResponsesRequest> {
let deadline = Instant::now() + Duration::from_secs(10);
loop {
let requests = mock.requests();
if let Some(request) = requests.into_iter().next() {
return request;
if requests.len() >= expected_count {
return requests;
}
assert!(
Instant::now() < deadline,
"timed out waiting for phase2 request"
"timed out waiting for {expected_count} phase2 requests"
);
tokio::time::sleep(Duration::from_millis(50)).await;
}
@@ -272,6 +411,39 @@ async fn wait_for_phase2_success(
}
}
async fn seed_stage1_output_for_existing_thread(
db: &codex_state::StateRuntime,
thread_id: ThreadId,
updated_at: i64,
raw_memory: &str,
rollout_summary: &str,
rollout_slug: Option<&str>,
) -> Result<()> {
let owner = ThreadId::new();
let claim = db
.try_claim_stage1_job(thread_id, owner, updated_at, 3_600, 64)
.await?;
let ownership_token = match claim {
codex_state::Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stage-1 claim outcome: {other:?}"),
};
assert!(
db.mark_stage1_job_succeeded(
thread_id,
&ownership_token,
updated_at,
raw_memory,
rollout_summary,
rollout_slug,
)
.await?,
"stage-1 success should enqueue global consolidation"
);
Ok(())
}
async fn read_rollout_summary_bodies(memory_root: &Path) -> Result<Vec<String>> {
let mut dir = tokio::fs::read_dir(memory_root.join("rollout_summaries")).await?;
let mut summaries = Vec::new();