feat: polluted memories (#13008)

Add a feature flag to disable memory creation for "polluted" threads — threads that performed MCP tool calls or web searches while `memories.no_memories_if_mcp_or_web_search` is enabled.
This commit is contained in:
jif-oai
2026-03-02 12:57:32 +01:00
committed by GitHub
parent b08bdd91e3
commit b649953845
19 changed files with 939 additions and 33 deletions

View File

@@ -84,6 +84,7 @@ pub fn create_fake_rollout_with_source(
model_provider: model_provider.map(str::to_string),
base_instructions: None,
dynamic_tools: None,
memory_mode: None,
};
let payload = serde_json::to_value(SessionMetaLine {
meta,
@@ -165,6 +166,7 @@ pub fn create_fake_rollout_with_text_elements(
model_provider: model_provider.map(str::to_string),
base_instructions: None,
dynamic_tools: None,
memory_mode: None,
};
let payload = serde_json::to_value(SessionMetaLine {
meta,

View File

@@ -647,6 +647,10 @@
"format": "int64",
"type": "integer"
},
"no_memories_if_mcp_or_web_search": {
"description": "When `true`, web searches and MCP tool calls mark the thread `memory_mode` as `\"polluted\"`.",
"type": "boolean"
},
"phase_1_model": {
"description": "Model used for thread summarisation.",
"type": "string"

View File

@@ -2505,6 +2505,7 @@ persistence = "none"
let memories = r#"
[memories]
no_memories_if_mcp_or_web_search = true
generate_memories = false
use_memories = false
max_raw_memories_for_global = 512
@@ -2519,6 +2520,7 @@ phase_2_model = "gpt-5"
toml::from_str::<ConfigToml>(memories).expect("TOML deserialization should succeed");
assert_eq!(
Some(MemoriesToml {
no_memories_if_mcp_or_web_search: Some(true),
generate_memories: Some(false),
use_memories: Some(false),
max_raw_memories_for_global: Some(512),
@@ -2541,6 +2543,7 @@ phase_2_model = "gpt-5"
assert_eq!(
config.memories,
MemoriesConfig {
no_memories_if_mcp_or_web_search: true,
generate_memories: false,
use_memories: false,
max_raw_memories_for_global: 512,

View File

@@ -371,6 +371,8 @@ pub struct FeedbackConfigToml {
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, JsonSchema)]
#[schemars(deny_unknown_fields)]
pub struct MemoriesToml {
/// When `true`, web searches and MCP tool calls mark the thread `memory_mode` as `"polluted"`.
pub no_memories_if_mcp_or_web_search: Option<bool>,
/// When `false`, newly created threads are stored with `memory_mode = "disabled"` in the state DB.
pub generate_memories: Option<bool>,
/// When `false`, skip injecting memory usage instructions into developer prompts.
@@ -394,6 +396,7 @@ pub struct MemoriesToml {
/// Effective memories settings after defaults are applied.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoriesConfig {
pub no_memories_if_mcp_or_web_search: bool,
pub generate_memories: bool,
pub use_memories: bool,
pub max_raw_memories_for_global: usize,
@@ -408,6 +411,7 @@ pub struct MemoriesConfig {
impl Default for MemoriesConfig {
fn default() -> Self {
Self {
no_memories_if_mcp_or_web_search: false,
generate_memories: true,
use_memories: true,
max_raw_memories_for_global: DEFAULT_MEMORIES_MAX_RAW_MEMORIES_FOR_GLOBAL,
@@ -425,6 +429,9 @@ impl From<MemoriesToml> for MemoriesConfig {
fn from(toml: MemoriesToml) -> Self {
let defaults = Self::default();
Self {
no_memories_if_mcp_or_web_search: toml
.no_memories_if_mcp_or_web_search
.unwrap_or(defaults.no_memories_if_mcp_or_web_search),
generate_memories: toml.generate_memories.unwrap_or(defaults.generate_memories),
use_memories: toml.use_memories.unwrap_or(defaults.use_memories),
max_raw_memories_for_global: toml

View File

@@ -15,6 +15,7 @@ use crate::protocol::EventMsg;
use crate::protocol::McpInvocation;
use crate::protocol::McpToolCallBeginEvent;
use crate::protocol::McpToolCallEndEvent;
use crate::state_db;
use codex_protocol::mcp::CallToolResult;
use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::FunctionCallOutputPayload;
@@ -121,6 +122,7 @@ pub(crate) async fn handle_mcp_tool_call(
});
notify_mcp_tool_call_event(sess.as_ref(), turn_context, tool_call_begin_event)
.await;
maybe_mark_thread_memory_mode_polluted(sess.as_ref(), turn_context).await;
let start = Instant::now();
let result = sess
@@ -189,6 +191,7 @@ pub(crate) async fn handle_mcp_tool_call(
invocation: invocation.clone(),
});
notify_mcp_tool_call_event(sess.as_ref(), turn_context, tool_call_begin_event).await;
maybe_mark_thread_memory_mode_polluted(sess.as_ref(), turn_context).await;
let start = Instant::now();
// Perform the tool call.
@@ -224,6 +227,22 @@ pub(crate) async fn handle_mcp_tool_call(
ResponseInputItem::McpToolCallOutput { call_id, result }
}
/// Flags the current thread's memories as polluted after an MCP tool call.
///
/// This is a no-op unless the `no_memories_if_mcp_or_web_search` memories
/// setting is enabled on the turn's config; the actual DB write (and error
/// logging) is delegated to `state_db::mark_thread_memory_mode_polluted`.
async fn maybe_mark_thread_memory_mode_polluted(sess: &Session, turn_context: &TurnContext) {
    let pollution_enabled = turn_context
        .config
        .memories
        .no_memories_if_mcp_or_web_search;
    if pollution_enabled {
        state_db::mark_thread_memory_mode_polluted(
            sess.services.state_db.as_deref(),
            sess.conversation_id,
            "mcp_tool_call",
        )
        .await;
    }
}
fn sanitize_mcp_tool_result_for_model(
supports_image_input: bool,
result: Result<CallToolResult, String>,

View File

@@ -177,6 +177,7 @@ mod tests {
model_provider: None,
base_instructions: None,
dynamic_tools: None,
memory_mode: None,
},
git: None,
};

View File

@@ -129,6 +129,13 @@ pub(crate) async fn extract_metadata_from_rollout(
}
Ok(ExtractionOutcome {
metadata,
memory_mode: items.iter().rev().find_map(|item| match item {
RolloutItem::SessionMeta(meta_line) => meta_line.meta.memory_mode.clone(),
RolloutItem::ResponseItem(_)
| RolloutItem::Compacted(_)
| RolloutItem::TurnContext(_)
| RolloutItem::EventMsg(_) => None,
}),
parse_errors,
})
}
@@ -272,6 +279,7 @@ pub(crate) async fn backfill_sessions(
);
}
let mut metadata = outcome.metadata;
let memory_mode = outcome.memory_mode.unwrap_or_else(|| "enabled".to_string());
if rollout.archived && metadata.archived_at.is_none() {
let fallback_archived_at = metadata.updated_at;
metadata.archived_at = file_modified_time_utc(&rollout.path)
@@ -282,6 +290,17 @@ pub(crate) async fn backfill_sessions(
stats.failed = stats.failed.saturating_add(1);
warn!("failed to upsert rollout {}: {err}", rollout.path.display());
} else {
if let Err(err) = runtime
.set_thread_memory_mode(metadata.id, memory_mode.as_str())
.await
{
stats.failed = stats.failed.saturating_add(1);
warn!(
"failed to restore memory mode for {}: {err}",
rollout.path.display()
);
continue;
}
stats.upserted = stats.upserted.saturating_add(1);
if let Ok(meta_line) =
rollout::list::read_session_meta_line(&rollout.path).await
@@ -519,6 +538,7 @@ mod tests {
model_provider: Some("openai".to_string()),
base_instructions: None,
dynamic_tools: None,
memory_mode: None,
};
let session_meta_line = SessionMetaLine {
meta: session_meta,
@@ -543,9 +563,71 @@ mod tests {
expected.updated_at = file_modified_time_utc(&path).await.expect("mtime");
assert_eq!(outcome.metadata, expected);
assert_eq!(outcome.memory_mode, None);
assert_eq!(outcome.parse_errors, 0);
}
#[tokio::test]
// Writes a rollout file containing two SessionMeta lines — the first with no
// memory_mode, the second marked "polluted" — and checks that extraction
// returns the memory_mode of the LAST meta line in the file.
async fn extract_metadata_from_rollout_returns_latest_memory_mode() {
let dir = tempdir().expect("tempdir");
let uuid = Uuid::new_v4();
let id = ThreadId::from_string(&uuid.to_string()).expect("thread id");
// Filename follows the rollout naming scheme: rollout-<timestamp>-<uuid>.jsonl.
let path = dir
.path()
.join(format!("rollout-2026-01-27T12-34-56-{uuid}.jsonl"));
let session_meta = SessionMeta {
id,
forked_from_id: None,
timestamp: "2026-01-27T12:34:56Z".to_string(),
cwd: dir.path().to_path_buf(),
originator: "cli".to_string(),
cli_version: "0.0.0".to_string(),
source: SessionSource::default(),
agent_nickname: None,
agent_role: None,
model_provider: Some("openai".to_string()),
base_instructions: None,
dynamic_tools: None,
memory_mode: None,
};
// Same meta, but with memory_mode explicitly set — this later line should win.
let polluted_meta = SessionMeta {
memory_mode: Some("polluted".to_string()),
..session_meta.clone()
};
let lines = vec![
RolloutLine {
timestamp: "2026-01-27T12:34:56Z".to_string(),
item: RolloutItem::SessionMeta(SessionMetaLine {
meta: session_meta,
git: None,
}),
},
RolloutLine {
timestamp: "2026-01-27T12:35:00Z".to_string(),
item: RolloutItem::SessionMeta(SessionMetaLine {
meta: polluted_meta,
git: None,
}),
},
];
// Persist the rollout as JSONL, one serialized RolloutLine per line.
let mut file = File::create(&path).expect("create rollout");
for line in lines {
writeln!(
file,
"{}",
serde_json::to_string(&line).expect("serialize rollout line")
)
.expect("write rollout line");
}
let outcome = extract_metadata_from_rollout(&path, "openai", None)
.await
.expect("extract");
// The second (latest) SessionMeta line set "polluted", so that is reported.
assert_eq!(outcome.memory_mode.as_deref(), Some("polluted"));
}
#[test]
fn builder_from_items_falls_back_to_filename() {
let dir = tempdir().expect("tempdir");
@@ -669,6 +751,7 @@ mod tests {
model_provider: Some("test-provider".to_string()),
base_instructions: None,
dynamic_tools: None,
memory_mode: None,
};
let session_meta_line = SessionMetaLine {
meta: session_meta,

View File

@@ -412,6 +412,8 @@ impl RolloutRecorder {
} else {
Some(dynamic_tools)
},
memory_mode: (!config.memories.generate_memories)
.then_some("disabled".to_string()),
};
(

View File

@@ -1109,6 +1109,7 @@ async fn test_updated_at_uses_file_mtime() -> Result<()> {
model_provider: Some("test-provider".into()),
base_instructions: None,
dynamic_tools: None,
memory_mode: None,
},
git: None,
}),

View File

@@ -337,6 +337,19 @@ pub async fn persist_dynamic_tools(
}
}
/// Best-effort write of `memory_mode = 'polluted'` for `thread_id`.
///
/// When no state-DB runtime is available this silently does nothing; a DB
/// failure is logged (tagged with `stage`, identifying the call site) instead
/// of being propagated to the caller.
pub async fn mark_thread_memory_mode_polluted(
    context: Option<&codex_state::StateRuntime>,
    thread_id: ThreadId,
    stage: &str,
) {
    if let Some(ctx) = context {
        if let Err(err) = ctx.mark_thread_memory_mode_polluted(thread_id).await {
            warn!("state db mark_thread_memory_mode_polluted failed during {stage}: {err}");
        }
    }
}
/// Reconcile rollout items into SQLite, falling back to scanning the rollout file.
pub async fn reconcile_rollout(
context: Option<&codex_state::StateRuntime>,
@@ -375,6 +388,7 @@ pub async fn reconcile_rollout(
}
};
let mut metadata = outcome.metadata;
let memory_mode = outcome.memory_mode.unwrap_or_else(|| "enabled".to_string());
metadata.cwd = normalize_cwd_for_state_db(&metadata.cwd);
match archived_only {
Some(true) if metadata.archived_at.is_none() => {
@@ -392,6 +406,16 @@ pub async fn reconcile_rollout(
);
return;
}
if let Err(err) = ctx
.set_thread_memory_mode(metadata.id, memory_mode.as_str())
.await
{
warn!(
"state db reconcile_rollout memory_mode update failed {}: {err}",
rollout_path.display()
);
return;
}
if let Ok(meta_line) = crate::rollout::list::read_session_meta_line(rollout_path).await {
persist_dynamic_tools(
Some(ctx),

View File

@@ -58,9 +58,31 @@ pub(crate) async fn record_completed_response_item(
) {
sess.record_conversation_items(turn_context, std::slice::from_ref(item))
.await;
maybe_mark_thread_memory_mode_polluted_from_web_search(sess, turn_context, item).await;
record_stage1_output_usage_for_completed_item(turn_context, item).await;
}
/// Marks this thread's memories as polluted when the completed item is a
/// web-search call and the `no_memories_if_mcp_or_web_search` setting is on.
///
/// Any other item kind, or a disabled flag, leaves the thread untouched.
async fn maybe_mark_thread_memory_mode_polluted_from_web_search(
    sess: &Session,
    turn_context: &TurnContext,
    item: &ResponseItem,
) {
    let is_web_search = matches!(item, ResponseItem::WebSearchCall { .. });
    let pollution_enabled = turn_context
        .config
        .memories
        .no_memories_if_mcp_or_web_search;
    if is_web_search && pollution_enabled {
        state_db::mark_thread_memory_mode_polluted(
            sess.services.state_db.as_deref(),
            sess.conversation_id,
            "record_completed_response_item",
        )
        .await;
    }
}
async fn record_stage1_output_usage_for_completed_item(
turn_context: &TurnContext,
item: &ResponseItem,

View File

@@ -11,7 +11,9 @@ use core_test_support::responses::ResponsesRequest;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::ev_web_search_call_done;
use core_test_support::responses::mount_sse_once;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::test_codex::TestCodex;
@@ -157,6 +159,156 @@ async fn memories_startup_phase2_tracks_added_and_removed_inputs_across_runs() -
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
// End-to-end check that a thread selected as phase-2 input is demoted into the
// "removed" bucket after a web search pollutes it:
//   1. run a turn, seed a stage-1 output for the thread, shut down;
//   2. resume, verify the thread appears as a selected/added phase-2 input;
//   3. trigger a web-search turn (no_memories_if_mcp_or_web_search = true);
//   4. verify memory_mode becomes "polluted" and the phase-2 selection moves
//      the thread from selected to removed.
async fn web_search_pollution_moves_selected_thread_into_removed_phase2_inputs() -> Result<()> {
let server = start_mock_server().await;
let home = Arc::new(TempDir::new()?);
let db = init_state_db(&home).await?;
let mut initial_builder = test_codex().with_home(home.clone()).with_config(|config| {
config.features.enable(Feature::Sqlite);
config.features.enable(Feature::MemoryTool);
config.memories.max_raw_memories_for_global = 1;
config.memories.no_memories_if_mcp_or_web_search = true;
});
let initial = initial_builder.build(&server).await?;
// First session: a plain assistant turn, no web search yet.
mount_sse_once(
&server,
sse(vec![
ev_response_created("resp-initial-1"),
ev_assistant_message("msg-initial-1", "initial turn complete"),
ev_completed("resp-initial-1"),
]),
)
.await;
initial.submit_turn("hello before memories").await?;
let rollout_path = initial
.session_configured
.rollout_path
.clone()
.expect("rollout path");
let thread_id = initial.session_configured.session_id;
// Poll until the thread row lands in the state DB; we need its updated_at
// to seed a matching stage-1 output below.
let updated_at = {
let deadline = Instant::now() + Duration::from_secs(10);
loop {
if let Some(metadata) = db.get_thread(thread_id).await? {
break metadata.updated_at;
}
assert!(
Instant::now() < deadline,
"timed out waiting for thread metadata for {thread_id}"
);
tokio::time::sleep(Duration::from_millis(50)).await;
}
};
seed_stage1_output_for_existing_thread(
db.as_ref(),
thread_id,
updated_at.timestamp(),
"raw memory seeded for web search pollution",
"rollout summary seeded for web search pollution",
Some("pollution-rollout"),
)
.await?;
shutdown_test_codex(&initial).await?;
// Second session: first SSE body serves the phase-2 run, second serves the
// polluting web-search turn.
let responses = mount_sse_sequence(
&server,
vec![
sse(vec![
ev_response_created("resp-phase2-1"),
ev_assistant_message("msg-phase2-1", "phase2 complete"),
ev_completed("resp-phase2-1"),
]),
sse(vec![
ev_response_created("resp-web-1"),
ev_web_search_call_done("ws-1", "completed", "weather seattle"),
ev_completed("resp-web-1"),
]),
],
)
.await;
let mut resumed_builder = test_codex().with_home(home.clone()).with_config(|config| {
config.features.enable(Feature::Sqlite);
config.features.enable(Feature::MemoryTool);
config.memories.max_raw_memories_for_global = 1;
config.memories.no_memories_if_mcp_or_web_search = true;
});
let resumed = resumed_builder
.resume(&server, home.clone(), rollout_path.clone())
.await?;
// Before pollution the seeded thread must show up as a newly added
// phase-2 input in the phase-2 prompt.
let first_phase2_request = wait_for_request(&responses, 1).await.remove(0);
let first_phase2_prompt = phase2_prompt_text(&first_phase2_request);
assert!(
first_phase2_prompt.contains("- selected inputs this run: 1"),
"expected seeded thread to be selected before pollution: {first_phase2_prompt}"
);
assert!(
first_phase2_prompt.contains("- newly added since the last successful Phase 2 run: 1"),
"expected seeded thread to be added before pollution: {first_phase2_prompt}"
);
assert!(
first_phase2_prompt.contains(&format!("- [added] thread_id={thread_id},")),
"expected selected thread in first phase2 prompt: {first_phase2_prompt}"
);
wait_for_phase2_success(db.as_ref(), thread_id).await?;
// Now pollute the thread via the mocked web-search turn.
resumed
.submit_turn("search the web for weather seattle")
.await?;
// Poll until the state DB reflects memory_mode = "polluted".
assert_eq!(
{
let deadline = Instant::now() + Duration::from_secs(10);
loop {
let memory_mode = db.get_thread_memory_mode(thread_id).await?;
if memory_mode.as_deref() == Some("polluted") {
break memory_mode;
}
assert!(
Instant::now() < deadline,
"timed out waiting for polluted memory mode for {thread_id}"
);
tokio::time::sleep(Duration::from_millis(50)).await;
}
}
.as_deref(),
Some("polluted")
);
// Poll until the phase-2 selection shows the thread as removed-only.
let selection = {
let deadline = Instant::now() + Duration::from_secs(10);
loop {
let selection = db.get_phase2_input_selection(1, 30).await?;
if selection.selected.is_empty()
&& selection.retained_thread_ids.is_empty()
&& selection.removed.len() == 1
&& selection.removed[0].thread_id == thread_id
{
break selection;
}
assert!(
Instant::now() < deadline,
"timed out waiting for polluted thread to move into removed phase2 inputs: \
{selection:?}"
);
tokio::time::sleep(Duration::from_millis(50)).await;
}
};
// Exactly two mocked responses were consumed: phase-2 run + web search.
assert_eq!(responses.requests().len(), 2);
assert!(selection.selected.is_empty());
assert_eq!(selection.retained_thread_ids, Vec::<ThreadId>::new());
assert_eq!(selection.removed.len(), 1);
assert_eq!(selection.removed[0].thread_id, thread_id);
shutdown_test_codex(&resumed).await?;
Ok(())
}
async fn build_test_codex(server: &wiremock::MockServer, home: Arc<TempDir>) -> Result<TestCodex> {
let mut builder = test_codex().with_home(home).with_config(|config| {
config.features.enable(Feature::Sqlite);
@@ -195,46 +347,33 @@ async fn seed_stage1_output(
let metadata = metadata_builder.build("test-provider");
db.upsert_thread(&metadata).await?;
let claim = db
.try_claim_stage1_job(
thread_id,
ThreadId::new(),
updated_at.timestamp(),
3_600,
64,
)
.await?;
let ownership_token = match claim {
codex_state::Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stage-1 claim outcome: {other:?}"),
};
assert!(
db.mark_stage1_job_succeeded(
thread_id,
&ownership_token,
updated_at.timestamp(),
raw_memory,
rollout_summary,
Some(rollout_slug),
)
.await?,
"stage-1 success should enqueue global consolidation"
);
seed_stage1_output_for_existing_thread(
db,
thread_id,
updated_at.timestamp(),
raw_memory,
rollout_summary,
Some(rollout_slug),
)
.await?;
Ok(thread_id)
}
/// Waits (via `wait_for_request` with an expected count of 1) for the first
/// captured request on `mock` and returns it.
async fn wait_for_single_request(mock: &ResponseMock) -> ResponsesRequest {
wait_for_request(mock, 1).await.remove(0)
}
async fn wait_for_request(mock: &ResponseMock, expected_count: usize) -> Vec<ResponsesRequest> {
let deadline = Instant::now() + Duration::from_secs(10);
loop {
let requests = mock.requests();
if let Some(request) = requests.into_iter().next() {
return request;
if requests.len() >= expected_count {
return requests;
}
assert!(
Instant::now() < deadline,
"timed out waiting for phase2 request"
"timed out waiting for {expected_count} phase2 requests"
);
tokio::time::sleep(Duration::from_millis(50)).await;
}
@@ -272,6 +411,39 @@ async fn wait_for_phase2_success(
}
}
async fn seed_stage1_output_for_existing_thread(
db: &codex_state::StateRuntime,
thread_id: ThreadId,
updated_at: i64,
raw_memory: &str,
rollout_summary: &str,
rollout_slug: Option<&str>,
) -> Result<()> {
let owner = ThreadId::new();
let claim = db
.try_claim_stage1_job(thread_id, owner, updated_at, 3_600, 64)
.await?;
let ownership_token = match claim {
codex_state::Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stage-1 claim outcome: {other:?}"),
};
assert!(
db.mark_stage1_job_succeeded(
thread_id,
&ownership_token,
updated_at,
raw_memory,
rollout_summary,
rollout_slug,
)
.await?,
"stage-1 success should enqueue global consolidation"
);
Ok(())
}
async fn read_rollout_summary_bodies(memory_root: &Path) -> Result<Vec<String>> {
let mut dir = tokio::fs::read_dir(memory_root.join("rollout_summaries")).await?;
let mut summaries = Vec::new();

View File

@@ -71,6 +71,7 @@ async fn write_rollout_with_user_event(dir: &Path, thread_id: ThreadId) -> io::R
model_provider: None,
base_instructions: None,
dynamic_tools: None,
memory_mode: None,
},
git: None,
};
@@ -114,6 +115,7 @@ async fn write_rollout_with_meta_only(dir: &Path, thread_id: ThreadId) -> io::Re
model_provider: None,
base_instructions: None,
dynamic_tools: None,
memory_mode: None,
},
git: None,
};

View File

@@ -1,23 +1,36 @@
use anyhow::Result;
use codex_core::config::types::McpServerConfig;
use codex_core::config::types::McpServerTransportConfig;
use codex_core::features::Feature;
use codex_protocol::ThreadId;
use codex_protocol::dynamic_tools::DynamicToolSpec;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::Op;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::RolloutLine;
use codex_protocol::protocol::SandboxPolicy;
use codex_protocol::protocol::SessionMeta;
use codex_protocol::protocol::SessionMetaLine;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::UserMessageEvent;
use codex_protocol::user_input::UserInput;
use core_test_support::responses;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::ev_web_search_call_done;
use core_test_support::responses::mount_sse_once;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::stdio_server_bin;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use core_test_support::wait_for_event_match;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::collections::HashMap;
use std::fs;
use tokio::time::Duration;
use tracing_subscriber::prelude::*;
@@ -128,6 +141,7 @@ async fn backfill_scans_existing_rollouts() -> Result<()> {
model_provider: None,
base_instructions: None,
dynamic_tools: Some(dynamic_tools_for_hook),
memory_mode: None,
},
git: None,
};
@@ -253,6 +267,148 @@ async fn user_messages_persist_in_state_db() -> Result<()> {
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
// With no_memories_if_mcp_or_web_search enabled, a turn whose mocked response
// contains a completed web-search call must flip the thread's memory_mode to
// "polluted" in the state DB.
async fn web_search_marks_thread_memory_mode_polluted_when_configured() -> Result<()> {
let server = start_mock_server().await;
mount_sse_sequence(
&server,
vec![responses::sse(vec![
ev_response_created("resp-1"),
ev_web_search_call_done("ws-1", "completed", "weather seattle"),
ev_completed("resp-1"),
])],
)
.await;
let mut builder = test_codex().with_config(|config| {
config.features.enable(Feature::Sqlite);
config.memories.no_memories_if_mcp_or_web_search = true;
});
let test = builder.build(&server).await?;
let db = test.codex.state_db().expect("state db enabled");
let thread_id = test.session_configured.session_id;
test.submit_turn("search the web").await?;
// The DB write happens asynchronously; poll up to 100 * 25ms for it.
let mut memory_mode = None;
for _ in 0..100 {
memory_mode = db.get_thread_memory_mode(thread_id).await?;
if memory_mode.as_deref() == Some("polluted") {
break;
}
tokio::time::sleep(Duration::from_millis(25)).await;
}
assert_eq!(memory_mode.as_deref(), Some("polluted"));
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
// With no_memories_if_mcp_or_web_search enabled, invoking an MCP tool (a
// stdio echo server) during a turn must flip the thread's memory_mode to
// "polluted" in the state DB.
async fn mcp_call_marks_thread_memory_mode_polluted_when_configured() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = start_mock_server().await;
let call_id = "call-123";
let server_name = "rmcp";
let tool_name = format!("mcp__{server_name}__echo");
// First mocked model response asks for the MCP tool call...
mount_sse_once(
&server,
responses::sse(vec![
ev_response_created("resp-1"),
ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"),
ev_completed("resp-1"),
]),
)
.await;
// ...second mocked response completes the turn after the tool output.
mount_sse_once(
&server,
responses::sse(vec![
responses::ev_assistant_message("msg-1", "rmcp echo tool completed."),
ev_completed("resp-2"),
]),
)
.await;
let rmcp_test_server_bin = stdio_server_bin()?;
// Register the stdio MCP test server under the "rmcp" name.
let mut builder = test_codex().with_config(move |config| {
config.features.enable(Feature::Sqlite);
config.memories.no_memories_if_mcp_or_web_search = true;
let mut servers = config.mcp_servers.get().clone();
servers.insert(
server_name.to_string(),
McpServerConfig {
transport: McpServerTransportConfig::Stdio {
command: rmcp_test_server_bin,
args: Vec::new(),
env: Some(HashMap::from([(
"MCP_TEST_VALUE".to_string(),
"propagated-env".to_string(),
)])),
env_vars: Vec::new(),
cwd: None,
},
enabled: true,
required: false,
disabled_reason: None,
startup_timeout_sec: Some(Duration::from_secs(10)),
tool_timeout_sec: None,
enabled_tools: None,
disabled_tools: None,
scopes: None,
oauth_resource: None,
},
);
config
.mcp_servers
.set(servers)
.expect("test mcp servers should accept any configuration");
});
let test = builder.build(&server).await?;
let db = test.codex.state_db().expect("state db enabled");
let thread_id = test.session_configured.session_id;
test.codex
.submit(Op::UserTurn {
items: vec![UserInput::Text {
text: "call the rmcp echo tool".to_string(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
cwd: test.cwd_path().to_path_buf(),
approval_policy: AskForApproval::Never,
sandbox_policy: SandboxPolicy::new_read_only_policy(),
model: test.session_configured.model.clone(),
effort: None,
summary: None,
collaboration_mode: None,
personality: None,
})
.await?;
// Wait for the MCP call to finish, then for the turn to complete (or fail).
wait_for_event(&test.codex, |event| {
matches!(event, EventMsg::McpToolCallEnd(_))
})
.await;
wait_for_event_match(&test.codex, |event| match event {
EventMsg::Error(err) => Some(Err(anyhow::anyhow!(err.message.clone()))),
EventMsg::TurnComplete(_) => Some(Ok(())),
_ => None,
})
.await?;
// The DB write happens asynchronously; poll up to 100 * 25ms for it.
let mut memory_mode = None;
for _ in 0..100 {
memory_mode = db.get_thread_memory_mode(thread_id).await?;
if memory_mode.as_deref() == Some("polluted") {
break;
}
tokio::time::sleep(Duration::from_millis(25)).await;
}
assert_eq!(memory_mode.as_deref(), Some("polluted"));
Ok(())
}
#[tokio::test(flavor = "current_thread")]
async fn tool_call_logs_include_thread_id() -> Result<()> {
let server = start_mock_server().await;

View File

@@ -2060,6 +2060,8 @@ pub struct SessionMeta {
pub base_instructions: Option<BaseInstructions>,
#[serde(skip_serializing_if = "Option::is_none")]
pub dynamic_tools: Option<Vec<DynamicToolSpec>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub memory_mode: Option<String>,
}
impl Default for SessionMeta {
@@ -2077,6 +2079,7 @@ impl Default for SessionMeta {
model_provider: None,
base_instructions: None,
dynamic_tools: None,
memory_mode: None,
}
}
}

View File

@@ -242,6 +242,7 @@ mod tests {
model_provider: Some("openai".to_string()),
base_instructions: None,
dynamic_tools: None,
memory_mode: None,
},
git: None,
}),

View File

@@ -45,6 +45,8 @@ pub struct ThreadsPage {
pub struct ExtractionOutcome {
/// The extracted thread metadata.
pub metadata: ThreadMetadata,
/// The explicit thread memory mode from rollout metadata, if present.
pub memory_mode: Option<String>,
/// The number of rollout lines that failed to parse.
pub parse_errors: usize,
}

View File

@@ -277,7 +277,8 @@ SELECT
FROM stage1_outputs AS so
LEFT JOIN threads AS t
ON t.id = so.thread_id
WHERE length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0
WHERE t.memory_mode = 'enabled'
AND (length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0)
ORDER BY so.source_updated_at DESC, so.thread_id DESC
LIMIT ?
"#,
@@ -304,11 +305,13 @@ LIMIT ?
/// `thread_id DESC`
/// - previously selected rows are identified by `selected_for_phase2 = 1`
/// - `previous_selected` contains the current persisted rows that belonged
/// to the last successful phase-2 baseline
/// to the last successful phase-2 baseline, even if those threads are no
/// longer memory-eligible
/// - `retained_thread_ids` records which current rows still match the exact
/// snapshot selected in the last successful phase-2 run
/// - removed rows are previously selected rows that are still present in
/// `stage1_outputs` but fall outside the current top-`n` selection
/// `stage1_outputs` but are no longer in the current selection, including
/// threads that are no longer memory-eligible
pub async fn get_phase2_input_selection(
&self,
n: usize,
@@ -336,7 +339,8 @@ SELECT
FROM stage1_outputs AS so
LEFT JOIN threads AS t
ON t.id = so.thread_id
WHERE (length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0)
WHERE t.memory_mode = 'enabled'
AND (length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0)
AND (
(so.last_usage IS NOT NULL AND so.last_usage >= ?)
OR (so.last_usage IS NULL AND so.source_updated_at >= ?)
@@ -421,6 +425,51 @@ ORDER BY so.source_updated_at DESC, so.thread_id DESC
})
}
/// Marks a thread as polluted and enqueues phase-2 forgetting when the
/// thread participated in the last successful phase-2 baseline.
///
/// Returns `Ok(true)` when the row was updated, `Ok(false)` when the thread
/// was already polluted (making repeat calls idempotent: the consolidation
/// job is only ever enqueued on the first transition to `polluted`).
pub async fn mark_thread_memory_mode_polluted(
&self,
thread_id: ThreadId,
) -> anyhow::Result<bool> {
let now = Utc::now().timestamp();
let thread_id = thread_id.to_string();
// Both statements run in one transaction so the mode flip and the
// consolidation enqueue are atomic.
let mut tx = self.pool.begin().await?;
// The `!= 'polluted'` guard makes the UPDATE a no-op on repeat calls.
let rows_affected = sqlx::query(
r#"
UPDATE threads
SET memory_mode = 'polluted'
WHERE id = ? AND memory_mode != 'polluted'
"#,
)
.bind(thread_id.as_str())
.execute(&mut *tx)
.await?
.rows_affected();
if rows_affected == 0 {
// Already polluted: commit the empty transaction and report no change.
tx.commit().await?;
return Ok(false);
}
// Threads with no stage1_outputs row default to "not selected" (0).
let selected_for_phase2 = sqlx::query_scalar::<_, i64>(
r#"
SELECT selected_for_phase2
FROM stage1_outputs
WHERE thread_id = ?
"#,
)
.bind(thread_id.as_str())
.fetch_optional(&mut *tx)
.await?
.unwrap_or(0);
// Only threads in the last phase-2 baseline need a forgetting pass.
if selected_for_phase2 != 0 {
enqueue_global_consolidation_with_executor(&mut *tx, now).await?;
}
tx.commit().await?;
Ok(true)
}
/// Attempts to claim a stage-1 job for a thread at `source_updated_at`.
///
/// Claim semantics:
@@ -2569,6 +2618,71 @@ VALUES (?, ?, ?, ?, ?)
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
// Seeds stage-1 outputs for two threads, marks one "polluted", and verifies
// that list_stage1_outputs_for_global only returns the still-enabled thread.
async fn list_stage1_outputs_for_global_skips_polluted_threads() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let thread_id_enabled =
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let thread_id_polluted =
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
// Give each thread an upserted row plus a successful stage-1 output.
for (thread_id, workspace) in [
(thread_id_enabled, "workspace-enabled"),
(thread_id_polluted, "workspace-polluted"),
] {
runtime
.upsert_thread(&test_thread_metadata(
&codex_home,
thread_id,
codex_home.join(workspace),
))
.await
.expect("upsert thread");
let claim = runtime
.try_claim_stage1_job(thread_id, owner, 100, 3600, 64)
.await
.expect("claim stage1");
let ownership_token = match claim {
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stage1 claim outcome: {other:?}"),
};
assert!(
runtime
.mark_stage1_job_succeeded(
thread_id,
ownership_token.as_str(),
100,
"raw memory",
"summary",
None,
)
.await
.expect("mark stage1 succeeded"),
"stage1 success should persist output"
);
}
// Pollute one thread; its stage-1 output should now be filtered out.
runtime
.set_thread_memory_mode(thread_id_polluted, "polluted")
.await
.expect("mark thread polluted");
let outputs = runtime
.list_stage1_outputs_for_global(10)
.await
.expect("list stage1 outputs for global");
assert_eq!(outputs.len(), 1);
assert_eq!(outputs[0].thread_id, thread_id_enabled);
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn get_phase2_input_selection_reports_added_retained_and_removed_rows() {
let codex_home = unique_temp_dir();
@@ -2681,6 +2795,197 @@ VALUES (?, ?, ?, ?, ?)
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
// Establishes a phase-2 baseline containing two threads, then pollutes one
// and verifies get_phase2_input_selection: the polluted thread stays visible
// in previous_selected (the old baseline) but is reported as removed, while
// the enabled thread remains selected and retained.
async fn get_phase2_input_selection_marks_polluted_previous_selection_as_removed() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let thread_id_enabled =
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let thread_id_polluted =
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
// Seed both threads with successful stage-1 outputs at distinct timestamps.
for (thread_id, updated_at) in [(thread_id_enabled, 100), (thread_id_polluted, 101)] {
runtime
.upsert_thread(&test_thread_metadata(
&codex_home,
thread_id,
codex_home.join(thread_id.to_string()),
))
.await
.expect("upsert thread");
let claim = runtime
.try_claim_stage1_job(thread_id, owner, updated_at, 3600, 64)
.await
.expect("claim stage1");
let ownership_token = match claim {
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stage1 claim outcome: {other:?}"),
};
assert!(
runtime
.mark_stage1_job_succeeded(
thread_id,
ownership_token.as_str(),
updated_at,
&format!("raw-{updated_at}"),
&format!("summary-{updated_at}"),
None,
)
.await
.expect("mark stage1 succeeded"),
"stage1 success should persist output"
);
}
// Run a full phase-2 cycle so both threads become the persisted baseline.
let claim = runtime
.try_claim_global_phase2_job(owner, 3600)
.await
.expect("claim phase2");
let (ownership_token, input_watermark) = match claim {
Phase2JobClaimOutcome::Claimed {
ownership_token,
input_watermark,
} => (ownership_token, input_watermark),
other => panic!("unexpected phase2 claim outcome: {other:?}"),
};
let selected_outputs = runtime
.list_stage1_outputs_for_global(10)
.await
.expect("list stage1 outputs for global");
assert!(
runtime
.mark_global_phase2_job_succeeded(
ownership_token.as_str(),
input_watermark,
&selected_outputs,
)
.await
.expect("mark phase2 success"),
"phase2 success should persist selected rows"
);
// Pollute one of the two baseline threads.
runtime
.set_thread_memory_mode(thread_id_polluted, "polluted")
.await
.expect("mark thread polluted");
let selection = runtime
.get_phase2_input_selection(2, 36_500)
.await
.expect("load phase2 input selection");
// Only the enabled thread is currently selectable...
assert_eq!(selection.selected.len(), 1);
assert_eq!(selection.selected[0].thread_id, thread_id_enabled);
// ...but the previous baseline still lists both threads.
assert_eq!(selection.previous_selected.len(), 2);
assert!(
selection
.previous_selected
.iter()
.any(|item| item.thread_id == thread_id_enabled)
);
assert!(
selection
.previous_selected
.iter()
.any(|item| item.thread_id == thread_id_polluted)
);
// The enabled thread is retained; the polluted one is reported as removed.
assert_eq!(selection.retained_thread_ids, vec![thread_id_enabled]);
assert_eq!(selection.removed.len(), 1);
assert_eq!(selection.removed[0].thread_id, thread_id_polluted);
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn mark_thread_memory_mode_polluted_enqueues_phase2_for_selected_threads() {
    // Marking a thread "polluted" after it was included in a successful
    // phase-2 selection must make the global phase-2 job claimable again,
    // so the selection can be regenerated without the polluted thread.
    let codex_home = unique_temp_dir();
    let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
        .await
        .expect("initialize runtime");
    let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
    let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
    // Seed a single thread and run stage 1 to completion so it has output
    // eligible for the global phase-2 selection.
    runtime
        .upsert_thread(&test_thread_metadata(
            &codex_home,
            thread_id,
            codex_home.join("workspace"),
        ))
        .await
        .expect("upsert thread");
    let claim = runtime
        .try_claim_stage1_job(thread_id, owner, 100, 3600, 64)
        .await
        .expect("claim stage1");
    let ownership_token = match claim {
        Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
        other => panic!("unexpected stage1 claim outcome: {other:?}"),
    };
    assert!(
        runtime
            .mark_stage1_job_succeeded(
                thread_id,
                ownership_token.as_str(),
                100,
                "raw",
                "summary",
                None,
            )
            .await
            .expect("mark stage1 succeeded"),
        "stage1 success should persist output"
    );
    // Complete one phase-2 cycle; the thread is now part of the persisted
    // selection and (absent new input) phase 2 would normally be idle.
    let phase2_claim = runtime
        .try_claim_global_phase2_job(owner, 3600)
        .await
        .expect("claim phase2");
    let (phase2_token, input_watermark) = match phase2_claim {
        Phase2JobClaimOutcome::Claimed {
            ownership_token,
            input_watermark,
        } => (ownership_token, input_watermark),
        other => panic!("unexpected phase2 claim outcome: {other:?}"),
    };
    let selected_outputs = runtime
        .list_stage1_outputs_for_global(10)
        .await
        .expect("list stage1 outputs");
    assert!(
        runtime
            .mark_global_phase2_job_succeeded(
                phase2_token.as_str(),
                input_watermark,
                &selected_outputs,
            )
            .await
            .expect("mark phase2 success"),
        "phase2 success should persist selected rows"
    );
    // Pollute the selected thread; this should re-enqueue phase 2.
    assert!(
        runtime
            .mark_thread_memory_mode_polluted(thread_id)
            .await
            .expect("mark thread polluted"),
        "thread should transition to polluted"
    );
    // The global phase-2 job is claimable again after the pollution.
    let next_claim = runtime
        .try_claim_global_phase2_job(owner, 3600)
        .await
        .expect("claim phase2 after pollution");
    assert!(matches!(next_claim, Phase2JobClaimOutcome::Claimed { .. }));
    // Best-effort cleanup of the per-test CODEX_HOME directory.
    let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn get_phase2_input_selection_treats_regenerated_selected_rows_as_added() {
let codex_home = unique_temp_dir();

View File

@@ -35,6 +35,14 @@ WHERE id = ?
.transpose()
}
/// Read the stored `memory_mode` for the thread identified by `id`.
///
/// Returns `Ok(None)` when no row exists for that id, or when the
/// `memory_mode` column cannot be read back as a string (e.g. SQL NULL).
pub async fn get_thread_memory_mode(&self, id: ThreadId) -> anyhow::Result<Option<String>> {
    let maybe_row = sqlx::query("SELECT memory_mode FROM threads WHERE id = ?")
        .bind(id.to_string())
        .fetch_optional(self.pool.as_ref())
        .await?;
    match maybe_row {
        None => Ok(None),
        Some(row) => Ok(row.try_get("memory_mode").ok()),
    }
}
/// Get dynamic tools for a thread, if present.
pub async fn get_dynamic_tools(
&self,
@@ -199,6 +207,19 @@ FROM threads
.await
}
/// Overwrite the `memory_mode` column for `thread_id`.
///
/// Returns `true` when a matching thread row was updated, `false` when no
/// row with that id exists.
pub async fn set_thread_memory_mode(
    &self,
    thread_id: ThreadId,
    memory_mode: &str,
) -> anyhow::Result<bool> {
    let update = sqlx::query("UPDATE threads SET memory_mode = ? WHERE id = ?")
        .bind(memory_mode)
        .bind(thread_id.to_string())
        .execute(self.pool.as_ref())
        .await?;
    Ok(update.rows_affected() != 0)
}
async fn upsert_thread_with_creation_memory_mode(
&self,
metadata: &crate::ThreadMetadata,
@@ -357,6 +378,16 @@ ON CONFLICT(thread_id, position) DO NOTHING
}
return Err(err);
}
if let Some(memory_mode) = extract_memory_mode(items)
&& let Err(err) = self
.set_thread_memory_mode(builder.id, memory_mode.as_str())
.await
{
if let Some(otel) = otel {
otel.counter(DB_ERROR_METRIC, 1, &[("stage", "set_thread_memory_mode")]);
}
return Err(err);
}
let dynamic_tools = extract_dynamic_tools(items);
if let Some(dynamic_tools) = dynamic_tools
&& let Err(err) = self
@@ -438,6 +469,16 @@ pub(super) fn extract_dynamic_tools(items: &[RolloutItem]) -> Option<Option<Vec<
})
}
pub(super) fn extract_memory_mode(items: &[RolloutItem]) -> Option<String> {
items.iter().rev().find_map(|item| match item {
RolloutItem::SessionMeta(meta_line) => meta_line.meta.memory_mode.clone(),
RolloutItem::ResponseItem(_)
| RolloutItem::Compacted(_)
| RolloutItem::TurnContext(_)
| RolloutItem::EventMsg(_) => None,
})
}
pub(super) fn push_thread_filters<'a>(
builder: &mut QueryBuilder<'a, Sqlite>,
archived_only: bool,
@@ -518,7 +559,11 @@ mod tests {
use super::*;
use crate::runtime::test_support::test_thread_metadata;
use crate::runtime::test_support::unique_temp_dir;
use codex_protocol::protocol::SessionMeta;
use codex_protocol::protocol::SessionMetaLine;
use codex_protocol::protocol::SessionSource;
use pretty_assertions::assert_eq;
use std::path::PathBuf;
#[tokio::test]
async fn upsert_thread_keeps_creation_memory_mode_for_existing_rows() {
@@ -557,4 +602,56 @@ mod tests {
.expect("memory mode should remain readable");
assert_eq!(memory_mode, "disabled");
}
#[tokio::test]
async fn apply_rollout_items_restores_memory_mode_from_session_meta() {
    // Replaying rollout items that contain a SessionMeta with
    // `memory_mode: Some("polluted")` must persist that mode onto the
    // already-existing thread row.
    let codex_home = unique_temp_dir();
    let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
        .await
        .expect("state db should initialize");
    let thread_id =
        ThreadId::from_string("00000000-0000-0000-0000-000000000456").expect("valid thread id");
    // Create the thread first so apply_rollout_items updates an existing row.
    let metadata = test_thread_metadata(&codex_home, thread_id, codex_home.clone());
    runtime
        .upsert_thread(&metadata)
        .await
        .expect("initial upsert should succeed");
    let builder = ThreadMetadataBuilder::new(
        thread_id,
        metadata.rollout_path.clone(),
        metadata.created_at,
        SessionSource::Cli,
    );
    // Minimal rollout: a single session-meta line whose only interesting
    // field is memory_mode = "polluted".
    let items = vec![RolloutItem::SessionMeta(SessionMetaLine {
        meta: SessionMeta {
            id: thread_id,
            forked_from_id: None,
            timestamp: metadata.created_at.to_rfc3339(),
            cwd: PathBuf::new(),
            originator: String::new(),
            cli_version: String::new(),
            source: SessionSource::Cli,
            agent_nickname: None,
            agent_role: None,
            model_provider: None,
            base_instructions: None,
            dynamic_tools: None,
            memory_mode: Some("polluted".to_string()),
        },
        git: None,
    })];
    runtime
        .apply_rollout_items(&builder, &items, None, None)
        .await
        .expect("apply_rollout_items should succeed");
    // The mode extracted from the session meta must be readable back from
    // the threads table.
    let memory_mode = runtime
        .get_thread_memory_mode(thread_id)
        .await
        .expect("memory mode should load");
    assert_eq!(memory_mode.as_deref(), Some("polluted"));
}
}