mirror of
https://github.com/openai/codex.git
synced 2026-04-19 12:14:48 +00:00
Compare commits
2 Commits
codex-debu
...
dh--app-se
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6ae8c38daf | ||
|
|
95bda92f33 |
@@ -346,6 +346,10 @@ client_request_definitions! {
|
||||
params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>,
|
||||
response: v1::UserInfoResponse,
|
||||
},
|
||||
FindFilesStream {
|
||||
params: FindFilesStreamParams,
|
||||
response: FindFilesStreamResponse,
|
||||
},
|
||||
FuzzyFileSearch {
|
||||
params: FuzzyFileSearchParams,
|
||||
response: FuzzyFileSearchResponse,
|
||||
@@ -579,6 +583,18 @@ pub struct FuzzyFileSearchParams {
|
||||
pub cancellation_token: Option<String>,
|
||||
}
|
||||
|
||||
/// Parameters for the `findFilesStream` client request.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(rename_all = "camelCase")]
pub struct FindFilesStreamParams {
    /// Fuzzy-search query string; an empty query short-circuits to a single
    /// empty terminal chunk.
    pub query: String,
    /// Directories to search; each root gets its own search session and the
    /// results are merged.
    pub roots: Vec<String>,
    /// Exclusion patterns forwarded to the file-search session.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub exclude: Vec<String>,
    // if provided, will cancel any previous request that used the same value
    pub cancellation_token: Option<String>,
}
|
||||
|
||||
/// Superset of [`codex_file_search::FileMatch`]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
pub struct FuzzyFileSearchResult {
|
||||
@@ -594,6 +610,24 @@ pub struct FuzzyFileSearchResponse {
|
||||
pub files: Vec<FuzzyFileSearchResult>,
|
||||
}
|
||||
|
||||
/// Empty acknowledgement for `findFilesStream`; the actual results arrive
/// afterwards as `findFilesStream/chunk` notifications.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(rename_all = "camelCase")]
pub struct FindFilesStreamResponse {}
|
||||
|
||||
/// One chunk of streamed find-files results (`findFilesStream/chunk`).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(rename_all = "camelCase")]
pub struct FindFilesStreamChunkNotification {
    /// Id of the `findFilesStream` request this chunk belongs to.
    pub request_id: RequestId,
    /// The query these results answer.
    pub query: String,
    /// Matches carried in this chunk.
    pub files: Vec<FuzzyFileSearchResult>,
    /// Total matches found across all roots; may exceed the number of files
    /// actually emitted (results are capped per root).
    pub total_match_count: usize,
    /// Zero-based index of this chunk within the current result set.
    pub chunk_index: usize,
    /// Number of chunks in the current result set.
    pub chunk_count: usize,
    /// True while the search is still producing results; false on the final set.
    pub running: bool,
}
|
||||
|
||||
server_notification_definitions! {
|
||||
/// NEW NOTIFICATIONS
|
||||
Error => "error" (v2::ErrorNotification),
|
||||
@@ -622,6 +656,7 @@ server_notification_definitions! {
|
||||
ContextCompacted => "thread/compacted" (v2::ContextCompactedNotification),
|
||||
DeprecationNotice => "deprecationNotice" (v2::DeprecationNoticeNotification),
|
||||
ConfigWarning => "configWarning" (v2::ConfigWarningNotification),
|
||||
FindFilesStreamChunk => "findFilesStream/chunk" (FindFilesStreamChunkNotification),
|
||||
|
||||
/// Notifies the user of world-writable directories on Windows, which cannot be protected by the sandbox.
|
||||
WindowsWorldWritableWarning => "windows/worldWritableWarning" (v2::WindowsWorldWritableWarningNotification),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::bespoke_event_handling::apply_bespoke_event_handling;
|
||||
use crate::error_code::INTERNAL_ERROR_CODE;
|
||||
use crate::error_code::INVALID_REQUEST_ERROR_CODE;
|
||||
use crate::find_files_stream::run_find_files_stream;
|
||||
use crate::fuzzy_file_search::run_fuzzy_file_search;
|
||||
use crate::models::supported_models;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
@@ -34,6 +35,8 @@ use codex_app_server_protocol::DynamicToolSpec as ApiDynamicToolSpec;
|
||||
use codex_app_server_protocol::ExecOneOffCommandResponse;
|
||||
use codex_app_server_protocol::FeedbackUploadParams;
|
||||
use codex_app_server_protocol::FeedbackUploadResponse;
|
||||
use codex_app_server_protocol::FindFilesStreamParams;
|
||||
use codex_app_server_protocol::FindFilesStreamResponse;
|
||||
use codex_app_server_protocol::ForkConversationParams;
|
||||
use codex_app_server_protocol::ForkConversationResponse;
|
||||
use codex_app_server_protocol::FuzzyFileSearchParams;
|
||||
@@ -269,6 +272,7 @@ pub(crate) struct CodexMessageProcessor {
|
||||
pending_rollbacks: PendingRollbacks,
|
||||
turn_summary_store: TurnSummaryStore,
|
||||
pending_fuzzy_searches: Arc<Mutex<HashMap<String, Arc<AtomicBool>>>>,
|
||||
pending_find_files_streams: Arc<Mutex<HashMap<String, Arc<AtomicBool>>>>,
|
||||
feedback: CodexFeedback,
|
||||
}
|
||||
|
||||
@@ -325,6 +329,7 @@ impl CodexMessageProcessor {
|
||||
pending_rollbacks: Arc::new(Mutex::new(HashMap::new())),
|
||||
turn_summary_store: Arc::new(Mutex::new(HashMap::new())),
|
||||
pending_fuzzy_searches: Arc::new(Mutex::new(HashMap::new())),
|
||||
pending_find_files_streams: Arc::new(Mutex::new(HashMap::new())),
|
||||
feedback,
|
||||
}
|
||||
}
|
||||
@@ -572,6 +577,9 @@ impl CodexMessageProcessor {
|
||||
} => {
|
||||
self.get_user_info(request_id).await;
|
||||
}
|
||||
ClientRequest::FindFilesStream { request_id, params } => {
|
||||
self.find_files_stream(request_id, params).await;
|
||||
}
|
||||
ClientRequest::FuzzyFileSearch { request_id, params } => {
|
||||
self.fuzzy_file_search(request_id, params).await;
|
||||
}
|
||||
@@ -4537,6 +4545,55 @@ impl CodexMessageProcessor {
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
|
||||
    /// Handle a `findFilesStream` request: acknowledge immediately, then stream
    /// match chunks as server notifications from a background task.
    async fn find_files_stream(&mut self, request_id: RequestId, params: FindFilesStreamParams) {
        let FindFilesStreamParams {
            query,
            roots,
            exclude,
            cancellation_token,
        } = params;

        // Resolve this stream's cancellation flag. Reusing a token cancels the
        // stream previously registered under it and installs a fresh flag.
        let cancel_flag = match cancellation_token.clone() {
            Some(token) => {
                let mut pending_streams = self.pending_find_files_streams.lock().await;
                if let Some(existing) = pending_streams.get(&token) {
                    existing.store(true, Ordering::Relaxed);
                }
                let flag = Arc::new(AtomicBool::new(false));
                pending_streams.insert(token.clone(), flag.clone());
                flag
            }
            // Without a token the stream cannot be cancelled externally.
            None => Arc::new(AtomicBool::new(false)),
        };

        // Acknowledge right away; results follow as findFilesStream/chunk notifications.
        self.outgoing
            .send_response(request_id.clone(), FindFilesStreamResponse {})
            .await;

        let outgoing = self.outgoing.clone();
        let pending_streams = self.pending_find_files_streams.clone();
        tokio::spawn(async move {
            run_find_files_stream(
                request_id.clone(),
                query,
                roots,
                exclude,
                cancel_flag.clone(),
                outgoing,
            )
            .await;

            // Deregister the token entry, but only if it still points at this
            // stream's flag — a newer request may already have replaced it.
            if let Some(token) = cancellation_token {
                let mut pending_streams = pending_streams.lock().await;
                if let Some(current_flag) = pending_streams.get(&token)
                    && Arc::ptr_eq(current_flag, &cancel_flag)
                {
                    pending_streams.remove(&token);
                }
            }
        });
    }
|
||||
|
||||
async fn upload_feedback(&self, request_id: RequestId, params: FeedbackUploadParams) {
|
||||
if !self.config.feedback_enabled {
|
||||
let error = JSONRPCErrorError {
|
||||
|
||||
215
codex-rs/app-server/src/find_files_stream.rs
Normal file
215
codex-rs/app-server/src/find_files_stream.rs
Normal file
@@ -0,0 +1,215 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use codex_app_server_protocol::FindFilesStreamChunkNotification;
|
||||
use codex_app_server_protocol::FuzzyFileSearchResult;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use codex_app_server_protocol::ServerNotification;
|
||||
use codex_file_search as file_search;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
|
||||
/// Per-root cap on matches returned by a search session.
const DEFAULT_LIMIT: usize = 50;
/// Worker threads used by each file-search session.
const DEFAULT_THREADS: usize = 4;
/// Maximum number of files carried in a single `findFilesStream/chunk` notification.
const DEFAULT_CHUNK_SIZE: usize = 100;
|
||||
|
||||
/// Drive a streaming fuzzy file search across `roots`, emitting
/// `findFilesStream/chunk` notifications until every root completes or the
/// stream is cancelled via `cancellation_flag`.
pub(crate) async fn run_find_files_stream(
    request_id: RequestId,
    query: String,
    roots: Vec<String>,
    exclude: Vec<String>,
    cancellation_flag: Arc<AtomicBool>,
    outgoing: Arc<OutgoingMessageSender>,
) {
    // Nothing to search: emit a single empty, terminal (running=false) chunk.
    if query.is_empty() || roots.is_empty() {
        send_chunks(&outgoing, request_id, query, Vec::new(), 0, false).await;
        return;
    }

    let (tx, mut rx) = mpsc::unbounded_channel::<StreamEvent>();
    let mut sessions = Vec::new();
    let mut started_roots = Vec::new();

    // Start one search session per root; all report into the shared channel.
    for root in roots.iter() {
        let reporter = Arc::new(StreamReporter {
            root: root.clone(),
            tx: tx.clone(),
            cancellation_flag: cancellation_flag.clone(),
        });
        let session = file_search::create_session(
            root.as_ref(),
            file_search::SessionOptions {
                limit: std::num::NonZeroUsize::new(DEFAULT_LIMIT)
                    .unwrap_or(std::num::NonZeroUsize::MIN),
                exclude: exclude.clone(),
                threads: std::num::NonZeroUsize::new(DEFAULT_THREADS)
                    .unwrap_or(std::num::NonZeroUsize::MIN),
                compute_indices: true,
                respect_gitignore: true,
            },
            reporter,
        );
        match session {
            Ok(session) => {
                session.update_query(&query);
                sessions.push(session);
                started_roots.push(root.clone());
            }
            Err(err) => {
                warn!("find-files-stream failed to start for root '{root}': {err}");
                // Synthesize a completion so the aggregation loop can still
                // reach the all-done state.
                let _ = tx.send(StreamEvent::Complete { root: root.clone() });
            }
        }
    }

    // Drop our sender so `rx` closes once every reporter is done.
    drop(tx);

    // Latest snapshot per root, plus per-root completion tracking for the
    // roots whose sessions actually started.
    let mut snapshots: HashMap<String, file_search::FileSearchSnapshot> = HashMap::new();
    let mut completed: HashMap<String, bool> = HashMap::new();
    for root in started_roots.iter() {
        completed.insert(root.clone(), false);
    }

    while let Some(event) = rx.recv().await {
        // Cancelled: stop immediately without emitting a terminal chunk.
        if cancellation_flag.load(Ordering::Relaxed) {
            break;
        }

        match event {
            StreamEvent::Update { root, snapshot } => {
                snapshots.insert(root.clone(), snapshot);
                // Interim results: running=true signals more may follow.
                send_aggregate_chunks(&outgoing, request_id.clone(), &query, &snapshots, true)
                    .await;
            }
            StreamEvent::Complete { root } => {
                if let Some(entry) = completed.get_mut(&root) {
                    *entry = true;
                }
                // Terminal chunk (running=false) once every started root is done.
                if completed.values().all(|done| *done) {
                    send_aggregate_chunks(&outgoing, request_id, &query, &snapshots, false).await;
                    break;
                }
            }
        }
    }

    // Keep sessions alive for the duration of the loop, then tear them down.
    drop(sessions);
}
|
||||
|
||||
fn aggregate_results(
|
||||
snapshots: &HashMap<String, file_search::FileSearchSnapshot>,
|
||||
) -> (Vec<FuzzyFileSearchResult>, usize) {
|
||||
let mut results = Vec::new();
|
||||
let mut total_match_count: usize = 0;
|
||||
|
||||
for (root, snapshot) in snapshots {
|
||||
total_match_count = total_match_count.saturating_add(snapshot.total_match_count);
|
||||
for matched in snapshot.matches.iter() {
|
||||
results.push(FuzzyFileSearchResult {
|
||||
root: root.clone(),
|
||||
path: matched.path.clone(),
|
||||
file_name: file_search::file_name_from_path(&matched.path),
|
||||
score: matched.score,
|
||||
indices: matched.indices.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
results.sort_by(|a, b| {
|
||||
use std::cmp::Ordering;
|
||||
match b.score.cmp(&a.score) {
|
||||
Ordering::Equal => match a.path.cmp(&b.path) {
|
||||
Ordering::Equal => a.root.cmp(&b.root),
|
||||
other => other,
|
||||
},
|
||||
other => other,
|
||||
}
|
||||
});
|
||||
|
||||
(results, total_match_count)
|
||||
}
|
||||
|
||||
async fn send_aggregate_chunks(
|
||||
outgoing: &OutgoingMessageSender,
|
||||
request_id: RequestId,
|
||||
query: &str,
|
||||
snapshots: &HashMap<String, file_search::FileSearchSnapshot>,
|
||||
running: bool,
|
||||
) {
|
||||
let (results, total_match_count) = aggregate_results(snapshots);
|
||||
send_chunks(
|
||||
outgoing,
|
||||
request_id,
|
||||
query.to_string(),
|
||||
results,
|
||||
total_match_count,
|
||||
running,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn send_chunks(
|
||||
outgoing: &OutgoingMessageSender,
|
||||
request_id: RequestId,
|
||||
query: String,
|
||||
files: Vec<FuzzyFileSearchResult>,
|
||||
total_match_count: usize,
|
||||
running: bool,
|
||||
) {
|
||||
let chunk_count = files.len().max(1).div_ceil(DEFAULT_CHUNK_SIZE);
|
||||
for chunk_index in 0..chunk_count {
|
||||
let start = chunk_index * DEFAULT_CHUNK_SIZE;
|
||||
let end = (start + DEFAULT_CHUNK_SIZE).min(files.len());
|
||||
let chunk = files.get(start..end).unwrap_or_default().to_vec();
|
||||
let notification = FindFilesStreamChunkNotification {
|
||||
request_id: request_id.clone(),
|
||||
query: query.clone(),
|
||||
files: chunk,
|
||||
total_match_count,
|
||||
chunk_index,
|
||||
chunk_count,
|
||||
running,
|
||||
};
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::FindFilesStreamChunk(notification))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Events sent from per-root reporters to the aggregation loop.
enum StreamEvent {
    /// A root produced a new snapshot of matches (replaces its previous one).
    Update {
        root: String,
        snapshot: file_search::FileSearchSnapshot,
    },
    /// A root's search finished — or its session failed to start.
    Complete {
        root: String,
    },
}
|
||||
|
||||
/// Bridges `file_search` session callbacks onto the stream's event channel.
struct StreamReporter {
    // Root directory this reporter's session searches.
    root: String,
    // Channel back to the aggregation loop in `run_find_files_stream`.
    tx: mpsc::UnboundedSender<StreamEvent>,
    // Shared cancellation flag; when set, updates are dropped.
    cancellation_flag: Arc<AtomicBool>,
}
|
||||
|
||||
impl file_search::SessionReporter for StreamReporter {
|
||||
fn on_update(&self, snapshot: &file_search::FileSearchSnapshot) {
|
||||
if self.cancellation_flag.load(Ordering::Relaxed) {
|
||||
return;
|
||||
}
|
||||
let _ = self.tx.send(StreamEvent::Update {
|
||||
root: self.root.clone(),
|
||||
snapshot: snapshot.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
fn on_complete(&self) {
|
||||
let _ = self.tx.send(StreamEvent::Complete {
|
||||
root: self.root.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -43,6 +43,7 @@ mod config_api;
|
||||
mod dynamic_tools;
|
||||
mod error_code;
|
||||
mod filters;
|
||||
mod find_files_stream;
|
||||
mod fuzzy_file_search;
|
||||
mod message_processor;
|
||||
mod models;
|
||||
|
||||
@@ -599,6 +599,23 @@ impl McpProcess {
|
||||
self.send_request("fuzzyFileSearch", Some(params)).await
|
||||
}
|
||||
|
||||
/// Send a `findFilesStream` JSON-RPC request.
|
||||
pub async fn send_find_files_stream_request(
|
||||
&mut self,
|
||||
query: &str,
|
||||
roots: Vec<String>,
|
||||
cancellation_token: Option<String>,
|
||||
) -> anyhow::Result<i64> {
|
||||
let mut params = serde_json::json!({
|
||||
"query": query,
|
||||
"roots": roots,
|
||||
});
|
||||
if let Some(token) = cancellation_token {
|
||||
params["cancellationToken"] = serde_json::json!(token);
|
||||
}
|
||||
self.send_request("findFilesStream", Some(params)).await
|
||||
}
|
||||
|
||||
async fn send_request(
|
||||
&mut self,
|
||||
method: &str,
|
||||
|
||||
441
codex-rs/app-server/tests/suite/find_file_stream.rs
Normal file
441
codex-rs/app-server/tests/suite/find_file_stream.rs
Normal file
@@ -0,0 +1,441 @@
|
||||
use std::collections::BTreeSet;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use app_test_support::McpProcess;
|
||||
use codex_app_server_protocol::FindFilesStreamChunkNotification;
|
||||
use codex_app_server_protocol::FindFilesStreamResponse;
|
||||
use codex_app_server_protocol::JSONRPCNotification;
|
||||
use codex_app_server_protocol::JSONRPCResponse;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::timeout;
|
||||
|
||||
/// Upper bound on waiting for any single response or notification.
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(10);
/// Notification method name for streamed find-files chunks.
const CHUNK_METHOD: &str = "findFilesStream/chunk";
|
||||
|
||||
// One root, query "alp": exactly alpha.rs should match, with match indices set.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn find_files_stream_single_root_single_match() -> Result<()> {
    let codex_home = TempDir::new()?;
    let root = TempDir::new()?;

    std::fs::write(root.path().join("alpha.rs"), "fn alpha() {}")?;
    std::fs::write(root.path().join("beta.rs"), "fn beta() {}")?;

    let mut mcp = McpProcess::new(codex_home.path()).await?;
    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;

    let root_path = root.path().to_string_lossy().to_string();
    let request_id = mcp
        .send_find_files_stream_request("alp", vec![root_path.clone()], None)
        .await?;

    let chunks = collect_final_chunks(&mut mcp, request_id).await?;
    let files = flatten_files(&chunks);

    assert_eq!(files.len(), 1, "files={files:?}");
    assert_eq!(files[0].root, root_path);
    assert_eq!(files[0].path, "alpha.rs");
    assert_eq!(files[0].file_name, "alpha.rs");
    assert!(files[0].indices.is_some(), "expected indices for match");

    Ok(())
}
|
||||
|
||||
// An empty query is acknowledged, then yields exactly one empty, terminal
// (running=false) chunk.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn find_files_stream_empty_query_emits_single_empty_chunk() -> Result<()> {
    let codex_home = TempDir::new()?;
    let root = TempDir::new()?;

    std::fs::write(root.path().join("alpha.rs"), "fn alpha() {}")?;

    let mut mcp = McpProcess::new(codex_home.path()).await?;
    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;

    let request_id = mcp
        .send_find_files_stream_request("", vec![root.path().to_string_lossy().to_string()], None)
        .await?;

    // The response itself is an empty acknowledgement.
    let response = read_response(&mut mcp, request_id).await?;
    let parsed: FindFilesStreamResponse = serde_json::from_value(response.result)?;
    assert_eq!(parsed, FindFilesStreamResponse {});

    let (chunks, mismatched_count) = collect_chunks_until_complete(&mut mcp, request_id).await?;
    assert_eq!(mismatched_count, 0, "unexpected mismatched notifications");
    assert_eq!(chunks.len(), 1, "chunks={chunks:?}");
    let chunk = &chunks[0];
    assert_eq!(chunk.files.len(), 0);
    assert_eq!(chunk.total_match_count, 0);
    assert_eq!(chunk.chunk_index, 0);
    assert_eq!(chunk.chunk_count, 1);
    assert!(!chunk.running);

    Ok(())
}
|
||||
|
||||
// With no roots to search, the stream still terminates with one empty chunk.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn find_files_stream_empty_roots_emits_single_empty_chunk() -> Result<()> {
    let codex_home = TempDir::new()?;

    let mut mcp = McpProcess::new(codex_home.path()).await?;
    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;

    let request_id = mcp
        .send_find_files_stream_request("alp", Vec::new(), None)
        .await?;

    let chunks = collect_final_chunks(&mut mcp, request_id).await?;
    assert_eq!(chunks.len(), 1, "chunks={chunks:?}");
    let chunk = &chunks[0];
    assert_eq!(chunk.files.len(), 0);
    assert_eq!(chunk.total_match_count, 0);
    assert_eq!(chunk.chunk_count, 1);
    assert!(!chunk.running);

    Ok(())
}
|
||||
|
||||
// A query matching nothing still produces a terminal chunk with no files.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn find_files_stream_no_matches_returns_empty_files() -> Result<()> {
    let codex_home = TempDir::new()?;
    let root = TempDir::new()?;

    std::fs::write(root.path().join("alpha.rs"), "fn alpha() {}")?;

    let mut mcp = McpProcess::new(codex_home.path()).await?;
    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;

    let request_id = mcp
        .send_find_files_stream_request(
            "zzz",
            vec![root.path().to_string_lossy().to_string()],
            None,
        )
        .await?;

    let chunks = collect_final_chunks(&mut mcp, request_id).await?;
    let chunk = chunks
        .iter()
        .find(|chunk| chunk.chunk_index == 0)
        .ok_or_else(|| anyhow!("missing chunk 0"))?;

    assert_eq!(chunk.files.len(), 0);
    assert_eq!(chunk.total_match_count, 0);
    assert!(!chunk.running);

    Ok(())
}
|
||||
|
||||
// Matches found under different roots are merged into a single result set.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn find_files_stream_merges_results_across_roots() -> Result<()> {
    let codex_home = TempDir::new()?;
    let root_a = TempDir::new()?;
    let root_b = TempDir::new()?;

    std::fs::write(root_a.path().join("alpha.rs"), "fn alpha() {}")?;
    std::fs::write(root_b.path().join("alpine.rs"), "fn alpine() {}")?;
    std::fs::write(root_b.path().join("beta.rs"), "fn beta() {}")?;

    let mut mcp = McpProcess::new(codex_home.path()).await?;
    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;

    let root_a_path = root_a.path().to_string_lossy().to_string();
    let root_b_path = root_b.path().to_string_lossy().to_string();

    let request_id = mcp
        .send_find_files_stream_request("alp", vec![root_a_path.clone(), root_b_path.clone()], None)
        .await?;

    let chunks = collect_final_chunks(&mut mcp, request_id).await?;
    let files = flatten_files(&chunks);

    // Compare (root, path) pairs order-independently.
    let observed: BTreeSet<(String, String)> = files
        .into_iter()
        .map(|file| (file.root, file.path))
        .collect();
    let expected: BTreeSet<(String, String)> = BTreeSet::from([
        (root_a_path, "alpha.rs".to_string()),
        (root_b_path, "alpine.rs".to_string()),
    ]);

    assert_eq!(observed, expected);

    Ok(())
}
|
||||
|
||||
// Reissuing with the same cancellation token hands the stream over to the new
// request: subsequent notifications carry the new request id and query.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn find_files_stream_same_token_updates_request_id_and_query() -> Result<()> {
    let codex_home = TempDir::new()?;
    let root = TempDir::new()?;

    std::fs::write(root.path().join("alpha.rs"), "fn alpha() {}")?;
    std::fs::write(root.path().join("beta.rs"), "fn beta() {}")?;

    // Create enough extra files to keep the stream active while we issue a follow-up query.
    write_matching_files(root.path(), "alpha-extra", 150)?;

    let mut mcp = McpProcess::new(codex_home.path()).await?;
    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;

    let root_path = root.path().to_string_lossy().to_string();
    let token = "shared-token".to_string();

    let first_request_id = mcp
        .send_find_files_stream_request("alp", vec![root_path.clone()], Some(token.clone()))
        .await?;
    let _first_response = read_response(&mut mcp, first_request_id).await?;

    let second_request_id = mcp
        .send_find_files_stream_request("bet", vec![root_path.clone()], Some(token))
        .await?;

    let (chunks, _mismatched_count) =
        collect_chunks_until_complete(&mut mcp, second_request_id).await?;
    assert_eq!(
        chunks[0].request_id,
        RequestId::Integer(second_request_id),
        "expected notifications to adopt latest request id"
    );
    assert_eq!(chunks[0].query, "bet");

    let files = flatten_files(&chunks);
    assert!(files.iter().any(|file| file.path == "beta.rs"));
    assert!(
        chunks
            .iter()
            .all(|chunk| chunk.request_id == RequestId::Integer(second_request_id))
    );

    Ok(())
}
|
||||
|
||||
// Reusing a token with different roots cancels the old stream; the completed
// result set only ever references the new root.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn find_files_stream_same_token_with_different_roots_cancels_old_stream() -> Result<()> {
    let codex_home = TempDir::new()?;
    let root_a = TempDir::new()?;
    let root_b = TempDir::new()?;

    std::fs::write(root_a.path().join("alpha.rs"), "fn alpha() {}")?;
    std::fs::write(root_b.path().join("beta.rs"), "fn beta() {}")?;

    // Bulk files keep the first stream busy long enough to be superseded.
    write_matching_files(root_a.path(), "alpha-extra", 120)?;

    let mut mcp = McpProcess::new(codex_home.path()).await?;
    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;

    let token = "root-swap-token".to_string();
    let root_a_path = root_a.path().to_string_lossy().to_string();
    let root_b_path = root_b.path().to_string_lossy().to_string();

    let first_request_id = mcp
        .send_find_files_stream_request("alp", vec![root_a_path], Some(token.clone()))
        .await?;
    let _first_response = read_response(&mut mcp, first_request_id).await?;

    let second_request_id = mcp
        .send_find_files_stream_request("alp", vec![root_b_path.clone()], Some(token))
        .await?;

    let (chunks, _mismatched_count) =
        collect_chunks_until_complete(&mut mcp, second_request_id).await?;

    let files = flatten_files(&chunks);
    assert!(files.iter().all(|file| file.root == root_b_path));
    assert!(files.iter().all(|file| file.path != "alpha.rs"));

    Ok(())
}
|
||||
|
||||
// The per-root result cap (50) bounds the emitted files, while
// total_match_count still reports every match found.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn find_files_stream_enforces_limit_per_root() -> Result<()> {
    let codex_home = TempDir::new()?;
    let root = TempDir::new()?;

    write_matching_files(root.path(), "limit-case", 60)?;

    let mut mcp = McpProcess::new(codex_home.path()).await?;
    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;

    let request_id = mcp
        .send_find_files_stream_request(
            "limit-case",
            vec![root.path().to_string_lossy().to_string()],
            None,
        )
        .await?;

    let chunks = collect_final_chunks(&mut mcp, request_id).await?;
    let files = flatten_files(&chunks);

    assert_eq!(
        files.len(),
        50,
        "expected limit-per-root to cap emitted matches"
    );
    assert!(
        chunks[0].total_match_count >= 60,
        "expected total match count to reflect all matches"
    );

    Ok(())
}
|
||||
|
||||
// 3 roots x 55 matches, capped to 50 per root = 150 files, which must be
// delivered as two chunks (chunk size 100).
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn find_files_stream_chunks_results_when_over_chunk_size() -> Result<()> {
    let codex_home = TempDir::new()?;
    let root_a = TempDir::new()?;
    let root_b = TempDir::new()?;
    let root_c = TempDir::new()?;

    write_matching_files(root_a.path(), "chunk-case", 55)?;
    write_matching_files(root_b.path(), "chunk-case", 55)?;
    write_matching_files(root_c.path(), "chunk-case", 55)?;

    let mut mcp = McpProcess::new(codex_home.path()).await?;
    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;

    let request_id = mcp
        .send_find_files_stream_request(
            "chunk-case",
            vec![
                root_a.path().to_string_lossy().to_string(),
                root_b.path().to_string_lossy().to_string(),
                root_c.path().to_string_lossy().to_string(),
            ],
            None,
        )
        .await?;

    let chunks = collect_final_chunks(&mut mcp, request_id).await?;
    let chunk_indices: BTreeSet<usize> = chunks.iter().map(|chunk| chunk.chunk_index).collect();

    assert_eq!(chunks[0].chunk_count, 2);
    assert_eq!(chunk_indices, BTreeSet::from([0, 1]));
    assert_eq!(flatten_files(&chunks).len(), 150);

    Ok(())
}
|
||||
|
||||
// Match indices ("ace" against "abcde.rs" -> 0, 2, 4) are sorted and unique.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn find_files_stream_emits_sorted_unique_indices() -> Result<()> {
    let codex_home = TempDir::new()?;
    let root = TempDir::new()?;

    std::fs::write(root.path().join("abcde.rs"), "fn main() {}")?;

    let mut mcp = McpProcess::new(codex_home.path()).await?;
    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;

    let request_id = mcp
        .send_find_files_stream_request(
            "ace",
            vec![root.path().to_string_lossy().to_string()],
            None,
        )
        .await?;

    let chunks = collect_final_chunks(&mut mcp, request_id).await?;
    let files = flatten_files(&chunks);

    assert_eq!(files.len(), 1, "files={files:?}");
    let indices = files[0]
        .indices
        .clone()
        .ok_or_else(|| anyhow!("missing indices"))?;
    assert_eq!(indices, vec![0, 2, 4]);
    assert!(is_sorted_unique(&indices));

    Ok(())
}
|
||||
|
||||
async fn collect_final_chunks(
|
||||
mcp: &mut McpProcess,
|
||||
request_id: i64,
|
||||
) -> anyhow::Result<Vec<FindFilesStreamChunkNotification>> {
|
||||
let _response = read_response(mcp, request_id).await?;
|
||||
let (chunks, mismatched_count) = collect_chunks_until_complete(mcp, request_id).await?;
|
||||
if mismatched_count != 0 {
|
||||
anyhow::bail!("saw {mismatched_count} notifications for other request ids");
|
||||
}
|
||||
Ok(chunks)
|
||||
}
|
||||
|
||||
async fn read_response(mcp: &mut McpProcess, request_id: i64) -> anyhow::Result<JSONRPCResponse> {
|
||||
timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
|
||||
)
|
||||
.await?
|
||||
}
|
||||
|
||||
/// Read chunk notifications until a complete, terminal (running=false) set for
/// `request_id` has been assembled. Returns the chunks in index order plus the
/// number of notifications observed for other request ids.
async fn collect_chunks_until_complete(
    mcp: &mut McpProcess,
    request_id: i64,
) -> anyhow::Result<(Vec<FindFilesStreamChunkNotification>, usize)> {
    // State of the most recent result set; reset whenever the query or chunk
    // count changes (a newer aggregation supersedes earlier chunks).
    let mut latest_query = String::new();
    let mut latest_chunk_count = 0usize;
    let mut latest_chunks = std::collections::BTreeMap::new();
    let mut mismatched_count = 0usize;

    loop {
        let notification = timeout(
            DEFAULT_READ_TIMEOUT,
            mcp.read_stream_until_notification_message(CHUNK_METHOD),
        )
        .await??;
        let chunk = parse_chunk(notification)?;

        // Count, but otherwise ignore, chunks belonging to other requests.
        if chunk.request_id != RequestId::Integer(request_id) {
            mismatched_count += 1;
            continue;
        }

        if chunk.query != latest_query || chunk.chunk_count != latest_chunk_count {
            latest_query.clear();
            latest_query.push_str(&chunk.query);
            latest_chunk_count = chunk.chunk_count;
            latest_chunks.clear();
        }

        latest_chunks.insert(chunk.chunk_index, chunk.clone());

        // Done once the terminal chunk arrived and every index is present.
        if !chunk.running && latest_chunks.len() == latest_chunk_count {
            let chunks = latest_chunks.into_values().collect();
            return Ok((chunks, mismatched_count));
        }
    }
}
|
||||
|
||||
fn parse_chunk(
|
||||
notification: JSONRPCNotification,
|
||||
) -> anyhow::Result<FindFilesStreamChunkNotification> {
|
||||
let params = notification
|
||||
.params
|
||||
.ok_or_else(|| anyhow!("notification missing params"))?;
|
||||
let chunk = serde_json::from_value::<FindFilesStreamChunkNotification>(params)?;
|
||||
Ok(chunk)
|
||||
}
|
||||
|
||||
fn flatten_files(
|
||||
chunks: &[FindFilesStreamChunkNotification],
|
||||
) -> Vec<codex_app_server_protocol::FuzzyFileSearchResult> {
|
||||
let mut files = Vec::new();
|
||||
for chunk in chunks {
|
||||
files.extend(chunk.files.clone());
|
||||
}
|
||||
files
|
||||
}
|
||||
|
||||
fn write_matching_files(root: &std::path::Path, prefix: &str, count: usize) -> Result<()> {
|
||||
for index in 0..count {
|
||||
let file_name = format!("{prefix}-{index:03}.rs");
|
||||
std::fs::write(root.join(file_name), "fn main() {}")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// True when `indices` is strictly increasing (i.e. sorted with no duplicates).
fn is_sorted_unique(indices: &[u32]) -> bool {
    indices
        .iter()
        .zip(indices.iter().skip(1))
        .all(|(a, b)| a < b)
}
|
||||
@@ -3,6 +3,7 @@ mod auth;
|
||||
mod codex_message_processor_flow;
|
||||
mod config;
|
||||
mod create_thread;
|
||||
mod find_file_stream;
|
||||
mod fork_thread;
|
||||
mod fuzzy_file_search;
|
||||
mod interrupt;
|
||||
|
||||
Reference in New Issue
Block a user