mirror of
https://github.com/openai/codex.git
synced 2026-04-29 08:56:38 +00:00
feat: polluted memories (#13008)
Add a feature flag to disable memory creation for "polluted"
This commit is contained in:
@@ -11,7 +11,9 @@ use core_test_support::responses::ResponsesRequest;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::ev_web_search_call_done;
|
||||
use core_test_support::responses::mount_sse_once;
|
||||
use core_test_support::responses::mount_sse_sequence;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
@@ -157,6 +159,156 @@ async fn memories_startup_phase2_tracks_added_and_removed_inputs_across_runs() -
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn web_search_pollution_moves_selected_thread_into_removed_phase2_inputs() -> Result<()> {
|
||||
let server = start_mock_server().await;
|
||||
let home = Arc::new(TempDir::new()?);
|
||||
let db = init_state_db(&home).await?;
|
||||
|
||||
let mut initial_builder = test_codex().with_home(home.clone()).with_config(|config| {
|
||||
config.features.enable(Feature::Sqlite);
|
||||
config.features.enable(Feature::MemoryTool);
|
||||
config.memories.max_raw_memories_for_global = 1;
|
||||
config.memories.no_memories_if_mcp_or_web_search = true;
|
||||
});
|
||||
let initial = initial_builder.build(&server).await?;
|
||||
mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_response_created("resp-initial-1"),
|
||||
ev_assistant_message("msg-initial-1", "initial turn complete"),
|
||||
ev_completed("resp-initial-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
initial.submit_turn("hello before memories").await?;
|
||||
let rollout_path = initial
|
||||
.session_configured
|
||||
.rollout_path
|
||||
.clone()
|
||||
.expect("rollout path");
|
||||
let thread_id = initial.session_configured.session_id;
|
||||
let updated_at = {
|
||||
let deadline = Instant::now() + Duration::from_secs(10);
|
||||
loop {
|
||||
if let Some(metadata) = db.get_thread(thread_id).await? {
|
||||
break metadata.updated_at;
|
||||
}
|
||||
assert!(
|
||||
Instant::now() < deadline,
|
||||
"timed out waiting for thread metadata for {thread_id}"
|
||||
);
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
}
|
||||
};
|
||||
|
||||
seed_stage1_output_for_existing_thread(
|
||||
db.as_ref(),
|
||||
thread_id,
|
||||
updated_at.timestamp(),
|
||||
"raw memory seeded for web search pollution",
|
||||
"rollout summary seeded for web search pollution",
|
||||
Some("pollution-rollout"),
|
||||
)
|
||||
.await?;
|
||||
|
||||
shutdown_test_codex(&initial).await?;
|
||||
|
||||
let responses = mount_sse_sequence(
|
||||
&server,
|
||||
vec![
|
||||
sse(vec![
|
||||
ev_response_created("resp-phase2-1"),
|
||||
ev_assistant_message("msg-phase2-1", "phase2 complete"),
|
||||
ev_completed("resp-phase2-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_response_created("resp-web-1"),
|
||||
ev_web_search_call_done("ws-1", "completed", "weather seattle"),
|
||||
ev_completed("resp-web-1"),
|
||||
]),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut resumed_builder = test_codex().with_home(home.clone()).with_config(|config| {
|
||||
config.features.enable(Feature::Sqlite);
|
||||
config.features.enable(Feature::MemoryTool);
|
||||
config.memories.max_raw_memories_for_global = 1;
|
||||
config.memories.no_memories_if_mcp_or_web_search = true;
|
||||
});
|
||||
let resumed = resumed_builder
|
||||
.resume(&server, home.clone(), rollout_path.clone())
|
||||
.await?;
|
||||
|
||||
let first_phase2_request = wait_for_request(&responses, 1).await.remove(0);
|
||||
let first_phase2_prompt = phase2_prompt_text(&first_phase2_request);
|
||||
assert!(
|
||||
first_phase2_prompt.contains("- selected inputs this run: 1"),
|
||||
"expected seeded thread to be selected before pollution: {first_phase2_prompt}"
|
||||
);
|
||||
assert!(
|
||||
first_phase2_prompt.contains("- newly added since the last successful Phase 2 run: 1"),
|
||||
"expected seeded thread to be added before pollution: {first_phase2_prompt}"
|
||||
);
|
||||
assert!(
|
||||
first_phase2_prompt.contains(&format!("- [added] thread_id={thread_id},")),
|
||||
"expected selected thread in first phase2 prompt: {first_phase2_prompt}"
|
||||
);
|
||||
|
||||
wait_for_phase2_success(db.as_ref(), thread_id).await?;
|
||||
|
||||
resumed
|
||||
.submit_turn("search the web for weather seattle")
|
||||
.await?;
|
||||
assert_eq!(
|
||||
{
|
||||
let deadline = Instant::now() + Duration::from_secs(10);
|
||||
loop {
|
||||
let memory_mode = db.get_thread_memory_mode(thread_id).await?;
|
||||
if memory_mode.as_deref() == Some("polluted") {
|
||||
break memory_mode;
|
||||
}
|
||||
assert!(
|
||||
Instant::now() < deadline,
|
||||
"timed out waiting for polluted memory mode for {thread_id}"
|
||||
);
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
}
|
||||
}
|
||||
.as_deref(),
|
||||
Some("polluted")
|
||||
);
|
||||
|
||||
let selection = {
|
||||
let deadline = Instant::now() + Duration::from_secs(10);
|
||||
loop {
|
||||
let selection = db.get_phase2_input_selection(1, 30).await?;
|
||||
if selection.selected.is_empty()
|
||||
&& selection.retained_thread_ids.is_empty()
|
||||
&& selection.removed.len() == 1
|
||||
&& selection.removed[0].thread_id == thread_id
|
||||
{
|
||||
break selection;
|
||||
}
|
||||
assert!(
|
||||
Instant::now() < deadline,
|
||||
"timed out waiting for polluted thread to move into removed phase2 inputs: \
|
||||
{selection:?}"
|
||||
);
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
}
|
||||
};
|
||||
assert_eq!(responses.requests().len(), 2);
|
||||
assert!(selection.selected.is_empty());
|
||||
assert_eq!(selection.retained_thread_ids, Vec::<ThreadId>::new());
|
||||
assert_eq!(selection.removed.len(), 1);
|
||||
assert_eq!(selection.removed[0].thread_id, thread_id);
|
||||
|
||||
shutdown_test_codex(&resumed).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn build_test_codex(server: &wiremock::MockServer, home: Arc<TempDir>) -> Result<TestCodex> {
|
||||
let mut builder = test_codex().with_home(home).with_config(|config| {
|
||||
config.features.enable(Feature::Sqlite);
|
||||
@@ -195,46 +347,33 @@ async fn seed_stage1_output(
|
||||
let metadata = metadata_builder.build("test-provider");
|
||||
db.upsert_thread(&metadata).await?;
|
||||
|
||||
let claim = db
|
||||
.try_claim_stage1_job(
|
||||
thread_id,
|
||||
ThreadId::new(),
|
||||
updated_at.timestamp(),
|
||||
3_600,
|
||||
64,
|
||||
)
|
||||
.await?;
|
||||
let ownership_token = match claim {
|
||||
codex_state::Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
|
||||
other => panic!("unexpected stage-1 claim outcome: {other:?}"),
|
||||
};
|
||||
|
||||
assert!(
|
||||
db.mark_stage1_job_succeeded(
|
||||
thread_id,
|
||||
&ownership_token,
|
||||
updated_at.timestamp(),
|
||||
raw_memory,
|
||||
rollout_summary,
|
||||
Some(rollout_slug),
|
||||
)
|
||||
.await?,
|
||||
"stage-1 success should enqueue global consolidation"
|
||||
);
|
||||
seed_stage1_output_for_existing_thread(
|
||||
db,
|
||||
thread_id,
|
||||
updated_at.timestamp(),
|
||||
raw_memory,
|
||||
rollout_summary,
|
||||
Some(rollout_slug),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(thread_id)
|
||||
}
|
||||
|
||||
async fn wait_for_single_request(mock: &ResponseMock) -> ResponsesRequest {
|
||||
wait_for_request(mock, 1).await.remove(0)
|
||||
}
|
||||
|
||||
async fn wait_for_request(mock: &ResponseMock, expected_count: usize) -> Vec<ResponsesRequest> {
|
||||
let deadline = Instant::now() + Duration::from_secs(10);
|
||||
loop {
|
||||
let requests = mock.requests();
|
||||
if let Some(request) = requests.into_iter().next() {
|
||||
return request;
|
||||
if requests.len() >= expected_count {
|
||||
return requests;
|
||||
}
|
||||
assert!(
|
||||
Instant::now() < deadline,
|
||||
"timed out waiting for phase2 request"
|
||||
"timed out waiting for {expected_count} phase2 requests"
|
||||
);
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
}
|
||||
@@ -272,6 +411,39 @@ async fn wait_for_phase2_success(
|
||||
}
|
||||
}
|
||||
|
||||
async fn seed_stage1_output_for_existing_thread(
|
||||
db: &codex_state::StateRuntime,
|
||||
thread_id: ThreadId,
|
||||
updated_at: i64,
|
||||
raw_memory: &str,
|
||||
rollout_summary: &str,
|
||||
rollout_slug: Option<&str>,
|
||||
) -> Result<()> {
|
||||
let owner = ThreadId::new();
|
||||
let claim = db
|
||||
.try_claim_stage1_job(thread_id, owner, updated_at, 3_600, 64)
|
||||
.await?;
|
||||
let ownership_token = match claim {
|
||||
codex_state::Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
|
||||
other => panic!("unexpected stage-1 claim outcome: {other:?}"),
|
||||
};
|
||||
|
||||
assert!(
|
||||
db.mark_stage1_job_succeeded(
|
||||
thread_id,
|
||||
&ownership_token,
|
||||
updated_at,
|
||||
raw_memory,
|
||||
rollout_summary,
|
||||
rollout_slug,
|
||||
)
|
||||
.await?,
|
||||
"stage-1 success should enqueue global consolidation"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn read_rollout_summary_bodies(memory_root: &Path) -> Result<Vec<String>> {
|
||||
let mut dir = tokio::fs::read_dir(memory_root.join("rollout_summaries")).await?;
|
||||
let mut summaries = Vec::new();
|
||||
|
||||
@@ -71,6 +71,7 @@ async fn write_rollout_with_user_event(dir: &Path, thread_id: ThreadId) -> io::R
|
||||
model_provider: None,
|
||||
base_instructions: None,
|
||||
dynamic_tools: None,
|
||||
memory_mode: None,
|
||||
},
|
||||
git: None,
|
||||
};
|
||||
@@ -114,6 +115,7 @@ async fn write_rollout_with_meta_only(dir: &Path, thread_id: ThreadId) -> io::Re
|
||||
model_provider: None,
|
||||
base_instructions: None,
|
||||
dynamic_tools: None,
|
||||
memory_mode: None,
|
||||
},
|
||||
git: None,
|
||||
};
|
||||
|
||||
@@ -1,23 +1,36 @@
|
||||
use anyhow::Result;
|
||||
use codex_core::config::types::McpServerConfig;
|
||||
use codex_core::config::types::McpServerTransportConfig;
|
||||
use codex_core::features::Feature;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::dynamic_tools::DynamicToolSpec;
|
||||
use codex_protocol::protocol::AskForApproval;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
use codex_protocol::protocol::Op;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::RolloutLine;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use codex_protocol::protocol::SessionMeta;
|
||||
use codex_protocol::protocol::SessionMetaLine;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::UserMessageEvent;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::ev_web_search_call_done;
|
||||
use core_test_support::responses::mount_sse_once;
|
||||
use core_test_support::responses::mount_sse_sequence;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::stdio_server_bin;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use core_test_support::wait_for_event_match;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use tokio::time::Duration;
|
||||
use tracing_subscriber::prelude::*;
|
||||
@@ -128,6 +141,7 @@ async fn backfill_scans_existing_rollouts() -> Result<()> {
|
||||
model_provider: None,
|
||||
base_instructions: None,
|
||||
dynamic_tools: Some(dynamic_tools_for_hook),
|
||||
memory_mode: None,
|
||||
},
|
||||
git: None,
|
||||
};
|
||||
@@ -253,6 +267,148 @@ async fn user_messages_persist_in_state_db() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn web_search_marks_thread_memory_mode_polluted_when_configured() -> Result<()> {
|
||||
let server = start_mock_server().await;
|
||||
mount_sse_sequence(
|
||||
&server,
|
||||
vec![responses::sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_web_search_call_done("ws-1", "completed", "weather seattle"),
|
||||
ev_completed("resp-1"),
|
||||
])],
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.features.enable(Feature::Sqlite);
|
||||
config.memories.no_memories_if_mcp_or_web_search = true;
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
let db = test.codex.state_db().expect("state db enabled");
|
||||
let thread_id = test.session_configured.session_id;
|
||||
|
||||
test.submit_turn("search the web").await?;
|
||||
|
||||
let mut memory_mode = None;
|
||||
for _ in 0..100 {
|
||||
memory_mode = db.get_thread_memory_mode(thread_id).await?;
|
||||
if memory_mode.as_deref() == Some("polluted") {
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
}
|
||||
|
||||
assert_eq!(memory_mode.as_deref(), Some("polluted"));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn mcp_call_marks_thread_memory_mode_polluted_when_configured() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let call_id = "call-123";
|
||||
let server_name = "rmcp";
|
||||
let tool_name = format!("mcp__{server_name}__echo");
|
||||
mount_sse_once(
|
||||
&server,
|
||||
responses::sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
mount_sse_once(
|
||||
&server,
|
||||
responses::sse(vec![
|
||||
responses::ev_assistant_message("msg-1", "rmcp echo tool completed."),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
let rmcp_test_server_bin = stdio_server_bin()?;
|
||||
let mut builder = test_codex().with_config(move |config| {
|
||||
config.features.enable(Feature::Sqlite);
|
||||
config.memories.no_memories_if_mcp_or_web_search = true;
|
||||
|
||||
let mut servers = config.mcp_servers.get().clone();
|
||||
servers.insert(
|
||||
server_name.to_string(),
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::Stdio {
|
||||
command: rmcp_test_server_bin,
|
||||
args: Vec::new(),
|
||||
env: Some(HashMap::from([(
|
||||
"MCP_TEST_VALUE".to_string(),
|
||||
"propagated-env".to_string(),
|
||||
)])),
|
||||
env_vars: Vec::new(),
|
||||
cwd: None,
|
||||
},
|
||||
enabled: true,
|
||||
required: false,
|
||||
disabled_reason: None,
|
||||
startup_timeout_sec: Some(Duration::from_secs(10)),
|
||||
tool_timeout_sec: None,
|
||||
enabled_tools: None,
|
||||
disabled_tools: None,
|
||||
scopes: None,
|
||||
oauth_resource: None,
|
||||
},
|
||||
);
|
||||
config
|
||||
.mcp_servers
|
||||
.set(servers)
|
||||
.expect("test mcp servers should accept any configuration");
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
let db = test.codex.state_db().expect("state db enabled");
|
||||
let thread_id = test.session_configured.session_id;
|
||||
|
||||
test.codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![UserInput::Text {
|
||||
text: "call the rmcp echo tool".to_string(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: test.cwd_path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::new_read_only_policy(),
|
||||
model: test.session_configured.model.clone(),
|
||||
effort: None,
|
||||
summary: None,
|
||||
collaboration_mode: None,
|
||||
personality: None,
|
||||
})
|
||||
.await?;
|
||||
wait_for_event(&test.codex, |event| {
|
||||
matches!(event, EventMsg::McpToolCallEnd(_))
|
||||
})
|
||||
.await;
|
||||
wait_for_event_match(&test.codex, |event| match event {
|
||||
EventMsg::Error(err) => Some(Err(anyhow::anyhow!(err.message.clone()))),
|
||||
EventMsg::TurnComplete(_) => Some(Ok(())),
|
||||
_ => None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let mut memory_mode = None;
|
||||
for _ in 0..100 {
|
||||
memory_mode = db.get_thread_memory_mode(thread_id).await?;
|
||||
if memory_mode.as_deref() == Some("polluted") {
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
}
|
||||
|
||||
assert_eq!(memory_mode.as_deref(), Some("polluted"));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
async fn tool_call_logs_include_thread_id() -> Result<()> {
|
||||
let server = start_mock_server().await;
|
||||
|
||||
Reference in New Issue
Block a user