From f75c60087202503d84258b259607b2ea655cd51c Mon Sep 17 00:00:00 2001
From: jif-oai
Date: Tue, 5 May 2026 17:15:21 +0200
Subject: [PATCH] feat: support windowed multi-query memory search (#21204)

## Why

Memory search currently supports either independent substring matches or requiring every query to appear on the same line. That is too restrictive for memory files where related terms often land on nearby lines in the same note or bullet block.

## What changed

- Replace the old `all` match mode with explicit tagged modes: `all_on_same_line` and `all_within_lines { line_count }`.
- Add windowed matching in `codex-rs/memories/mcp/src/local.rs` so callers can require every query to appear within a bounded line range while returning only the minimal qualifying windows.
- Reject invalid zero-width windows and update the MCP tool description plus argument parsing to expose the new mode.
- Add coverage for same-line matching, windowed matching, and invalid `line_count` input.

## Verification

- Added targeted coverage in `codex-rs/memories/mcp/src/local_tests.rs` for `search_supports_all_within_lines_match_mode` and `search_rejects_zero_line_window`.
- Added server-side parsing coverage in `codex-rs/memories/mcp/src/server.rs` for `search_args_accept_windowed_all_match_mode`.
--- codex-rs/memories/mcp/src/backend.rs | 12 +- codex-rs/memories/mcp/src/local.rs | 158 ++++++++++++++++++----- codex-rs/memories/mcp/src/local_tests.rs | 63 ++++++++- codex-rs/memories/mcp/src/server.rs | 34 ++++- 4 files changed, 230 insertions(+), 37 deletions(-) diff --git a/codex-rs/memories/mcp/src/backend.rs b/codex-rs/memories/mcp/src/backend.rs index a4b542cd69..f402a9b8a7 100644 --- a/codex-rs/memories/mcp/src/backend.rs +++ b/codex-rs/memories/mcp/src/backend.rs @@ -87,11 +87,15 @@ pub struct SearchMemoriesResponse { pub truncated: bool, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] -#[serde(rename_all = "snake_case")] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type", rename_all = "snake_case")] pub enum SearchMatchMode { Any, - All, + AllOnSameLine, + AllWithinLines { + #[schemars(range(min = 1))] + line_count: usize, + }, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, JsonSchema)] @@ -136,6 +140,8 @@ pub enum MemoriesBackendError { NotFile { path: String }, #[error("queries must not be empty or contain empty strings")] EmptyQuery, + #[error("all_within_lines.line_count must be a positive integer")] + InvalidMatchWindow, #[error("I/O error while reading memories: {0}")] Io(#[from] std::io::Error), } diff --git a/codex-rs/memories/mcp/src/local.rs b/codex-rs/memories/mcp/src/local.rs index 0f642d43f0..f16dba195c 100644 --- a/codex-rs/memories/mcp/src/local.rs +++ b/codex-rs/memories/mcp/src/local.rs @@ -224,6 +224,12 @@ impl MemoriesBackend for LocalMemoriesBackend { if queries.is_empty() || queries.iter().any(std::string::String::is_empty) { return Err(MemoriesBackendError::EmptyQuery); } + if matches!( + request.match_mode, + SearchMatchMode::AllWithinLines { line_count: 0 } + ) { + return Err(MemoriesBackendError::InvalidMatchWindow); + } let max_results = request.max_results.min(MAX_SEARCH_RESULTS); let start = 
self.resolve_scoped_path(request.path.as_deref()).await?; @@ -240,8 +246,11 @@ impl MemoriesBackend for LocalMemoriesBackend { }; reject_symlink(&display_relative_path(&self.root, &start), &metadata)?; - let matcher = - SearchMatcher::new(queries.clone(), request.match_mode, request.case_sensitive); + let matcher = SearchMatcher::new( + queries.clone(), + request.match_mode.clone(), + request.case_sensitive, + ); let mut matches = Vec::new(); search_entries( &self.root, @@ -329,26 +338,116 @@ async fn search_file( Err(err) => return Err(err.into()), }; let lines = content.lines().collect::<Vec<_>>(); - for (idx, line) in lines.iter().enumerate() { - let matched_queries = matcher.matched_queries(line); - if !matched_queries.is_empty() { - let start_index = idx.saturating_sub(context_lines); - let end_index = idx - .saturating_add(context_lines) - .saturating_add(1) - .min(lines.len()); - matches.push(MemorySearchMatch { - path: display_relative_path(root, path), - match_line_number: idx + 1, - content_start_line_number: start_index + 1, - content: lines[start_index..end_index].join("\n"), - matched_queries, - }); + let line_matches = lines + .iter() + .map(|line| matcher.matched_query_flags(line)) + .collect::<Vec<_>>(); + match &matcher.match_mode { + SearchMatchMode::Any => { + for (idx, matched_query_flags) in line_matches.iter().enumerate() { + if matched_query_flags.iter().any(|matched| *matched) { + matches.push(build_search_match( + root, + path, + &lines, + idx, + idx, + context_lines, + matcher.matched_queries(matched_query_flags), + )); + } + } + } + SearchMatchMode::AllOnSameLine => { + for (idx, matched_query_flags) in line_matches.iter().enumerate() { + if matched_query_flags.iter().all(|matched| *matched) { + matches.push(build_search_match( + root, + path, + &lines, + idx, + idx, + context_lines, + matcher.matched_queries(matched_query_flags), + )); + } + } + } + SearchMatchMode::AllWithinLines { line_count } => { + let mut windows = Vec::new(); + for start_index
in 0..lines.len() { + if !line_matches[start_index].iter().any(|matched| *matched) { + continue; + } + let last_allowed_index = start_index + .saturating_add(line_count.saturating_sub(1)) + .min(lines.len().saturating_sub(1)); + let mut matched_query_flags = vec![false; matcher.queries.len()]; + for (end_index, line_match_flags) in line_matches + .iter() + .enumerate() + .take(last_allowed_index + 1) + .skip(start_index) + { + for (idx, matched) in line_match_flags.iter().enumerate() { + matched_query_flags[idx] |= matched; + } + if matched_query_flags.iter().all(|matched| *matched) { + windows.push((start_index, end_index, matched_query_flags)); + break; + } + } + } + for (idx, (start_index, end_index, matched_query_flags)) in windows.iter().enumerate() { + let strictly_contains_another_window = windows.iter().enumerate().any( + |(other_idx, (other_start_index, other_end_index, _))| { + idx != other_idx + && start_index <= other_start_index + && end_index >= other_end_index + && (start_index != other_start_index || end_index != other_end_index) + }, + ); + if strictly_contains_another_window { + continue; + } + matches.push(build_search_match( + root, + path, + &lines, + *start_index, + *end_index, + context_lines, + matcher.matched_queries(matched_query_flags), + )); + } } } Ok(()) } +fn build_search_match( + root: &Path, + path: &Path, + lines: &[&str], + match_start_index: usize, + match_end_index: usize, + context_lines: usize, + matched_queries: Vec<String>, +) -> MemorySearchMatch { + let content_start_index = match_start_index.saturating_sub(context_lines); + let content_end_index = match_end_index + .saturating_add(context_lines) + .saturating_add(1) + .min(lines.len()); + MemorySearchMatch { + path: display_relative_path(root, path), + match_line_number: match_start_index + 1, + content_start_line_number: content_start_index + 1, + content: lines[content_start_index..content_end_index].join("\n"), + matched_queries, + } +} + struct SearchMatcher { queries: Vec<String>,
normalized_queries: Option<Vec<String>>, @@ -370,23 +469,24 @@ impl SearchMatcher { } } - fn matched_queries(&self, line: &str) -> Vec<String> { + fn matched_query_flags(&self, line: &str) -> Vec<bool> { let line = match self.normalized_queries.as_ref() { Some(_) => Cow::Owned(line.to_lowercase()), None => Cow::Borrowed(line), }; let queries = self.normalized_queries.as_deref().unwrap_or(&self.queries); - let mut matched_queries = Vec::new(); - for (idx, query) in queries.iter().enumerate() { - if line.as_ref().contains(query) { - matched_queries.push(self.queries[idx].clone()); - } - } - match self.match_mode { - SearchMatchMode::Any => matched_queries, - SearchMatchMode::All if matched_queries.len() == self.queries.len() => matched_queries, - SearchMatchMode::All => Vec::new(), - } + queries + .iter() + .map(|query| line.as_ref().contains(query)) + .collect() + } + + fn matched_queries(&self, matched_query_flags: &[bool]) -> Vec<String> { + self.queries + .iter() + .zip(matched_query_flags) + .filter_map(|(query, matched)| matched.then_some(query.clone())) + .collect() } } diff --git a/codex-rs/memories/mcp/src/local_tests.rs b/codex-rs/memories/mcp/src/local_tests.rs index 7773489532..4edf06eef8 100644 --- a/codex-rs/memories/mcp/src/local_tests.rs +++ b/codex-rs/memories/mcp/src/local_tests.rs @@ -752,7 +752,7 @@ async fn search_supports_case_insensitive_matching() { } #[tokio::test] -async fn search_supports_any_and_all_match_modes() { +async fn search_supports_any_and_all_on_same_line_match_modes() { let tempdir = TempDir::new().expect("tempdir"); tokio::fs::write( tempdir.path().join("MEMORY.md"), @@ -793,11 +793,11 @@ async fn search_supports_any_and_all_match_modes() { ); let mut request = search_request(&["alpha", "needle"]); - request.match_mode = SearchMatchMode::All; + request.match_mode = SearchMatchMode::AllOnSameLine; let all_response = backend(&tempdir) .search(request) .await - .expect("search with all match mode"); + .expect("search with all-on-same-line match mode"); assert_eq!(
all_response.matches, vec![MemorySearchMatch { @@ -810,6 +810,63 @@ async fn search_supports_any_and_all_match_modes() { ); } +#[tokio::test] +async fn search_supports_all_within_lines_match_mode() { + let tempdir = TempDir::new().expect("tempdir"); + tokio::fs::write( + tempdir.path().join("MEMORY.md"), + "alpha first\nmiddle\nneedle later\nalpha again needle together\n", + ) + .await + .expect("write memory file"); + + let mut request = search_request(&["alpha", "needle"]); + request.match_mode = SearchMatchMode::AllWithinLines { line_count: 3 }; + request.context_lines = 1; + let response = backend(&tempdir) + .search(request) + .await + .expect("search with all-within-lines match mode"); + + assert_eq!( + response.matches, + vec![ + MemorySearchMatch { + path: "MEMORY.md".to_string(), + match_line_number: 1, + content_start_line_number: 1, + content: "alpha first\nmiddle\nneedle later\nalpha again needle together" + .to_string(), + matched_queries: vec!["alpha".to_string(), "needle".to_string()], + }, + MemorySearchMatch { + path: "MEMORY.md".to_string(), + match_line_number: 4, + content_start_line_number: 3, + content: "needle later\nalpha again needle together".to_string(), + matched_queries: vec!["alpha".to_string(), "needle".to_string()], + }, + ] + ); +} + +#[tokio::test] +async fn search_rejects_zero_line_window() { + let tempdir = TempDir::new().expect("tempdir"); + tokio::fs::write(tempdir.path().join("MEMORY.md"), "needle\n") + .await + .expect("write memory file"); + + let mut request = search_request(&["needle"]); + request.match_mode = SearchMatchMode::AllWithinLines { line_count: 0 }; + let err = backend(&tempdir) + .search(request) + .await + .expect_err("zero-width windows should be rejected"); + + assert!(matches!(err, MemoriesBackendError::InvalidMatchWindow)); +} + #[tokio::test] async fn search_rejects_invalid_cursor() { let tempdir = TempDir::new().expect("tempdir"); diff --git a/codex-rs/memories/mcp/src/server.rs 
b/codex-rs/memories/mcp/src/server.rs index 25cc7252b9..88eeb8fa18 100644 --- a/codex-rs/memories/mcp/src/server.rs +++ b/codex-rs/memories/mcp/src/server.rs @@ -227,7 +227,7 @@ fn search_tool() -> Tool { let mut tool = Tool::new( Cow::Borrowed(SEARCH_TOOL_NAME), Cow::Borrowed( - "Search Codex memory files for line-based substring matches, optionally requiring any or all query substrings on the same line.", + "Search Codex memory files for substring matches, optionally requiring all query substrings on the same line or within a line window.", ), Arc::new(schema::input_schema_for::<SearchArgs>()), ); @@ -273,7 +273,10 @@ fn backend_error_to_mcp(err: MemoriesBackendError) -> McpError { | MemoriesBackendError::InvalidMaxLines | MemoriesBackendError::LineOffsetExceedsFileLength | MemoriesBackendError::NotFile { .. } - | MemoriesBackendError::EmptyQuery => McpError::invalid_params(err.to_string(), None), + | MemoriesBackendError::EmptyQuery + | MemoriesBackendError::InvalidMatchWindow => { + McpError::invalid_params(err.to_string(), None) + } MemoriesBackendError::Io(_) => McpError::internal_error(err.to_string(), None), } } @@ -308,6 +311,33 @@ mod tests { ); } + #[test] + fn search_args_accept_windowed_all_match_mode() { + let args: SearchArgs = parse_args(json!({ + "queries": ["alpha", "needle"], + "match_mode": { + "type": "all_within_lines", + "line_count": 3 + } + })) + .expect("windowed all args should parse"); + + let request = args.into_request(); + + assert_eq!( + request, + SearchMemoriesRequest { + queries: vec!["alpha".to_string(), "needle".to_string()], + match_mode: SearchMatchMode::AllWithinLines { line_count: 3 }, + path: None, + cursor: None, + context_lines: 0, + case_sensitive: true, + max_results: DEFAULT_SEARCH_MAX_RESULTS, + } + ); + } + #[test] fn search_args_reject_legacy_single_query() { let err = parse_args::<SearchArgs>(json!({