thread/list supports filtering by model providers

This commit is contained in:
Owen Lin
2025-11-02 12:04:15 -08:00
parent 6fbb3665a2
commit fbe3078c2f
3 changed files with 88 additions and 2 deletions

View File

@@ -247,6 +247,9 @@ pub struct ThreadListParams {
pub limit: Option<i32>,
/// Optional sort order; defaults to descending.
pub order: Option<SortOrder>,
/// Optional provider filter; when set, only sessions recorded under these
/// providers are returned. When present but empty, includes all providers.
pub model_providers: Option<Vec<String>>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]

View File

@@ -1115,6 +1115,7 @@ impl CodexMessageProcessor {
cursor,
limit,
order: _,
model_providers,
} = params;
let page_size = limit.unwrap_or(25).max(1) as usize;
@@ -1126,8 +1127,13 @@ impl CodexMessageProcessor {
};
let cursor_ref = cursor_obj.as_ref();
// v2 API does not filter by provider unless specified; include all.
let model_provider_slice: Option<&[String]> = None;
// v2: include all providers by default; if provided, honor non-empty filters.
let model_provider_vec = match model_providers {
Some(v) if v.is_empty() => None,
Some(v) => Some(v),
None => None,
};
let model_provider_slice = model_provider_vec.as_deref();
let fallback_provider = self.config.model_provider_id.clone();
let page = match RolloutRecorder::list_conversations(

View File

@@ -26,6 +26,7 @@ async fn thread_list_basic_empty() -> Result<()> {
cursor: None,
limit: Some(10),
order: None,
model_providers: None,
})
.await?;
let list_resp: JSONRPCResponse = timeout(
@@ -140,6 +141,7 @@ async fn thread_list_pagination_next_cursor_none_on_last_page() -> Result<()> {
cursor: None,
limit: Some(2),
order: None,
model_providers: None,
})
.await?;
let page1_resp: JSONRPCResponse = timeout(
@@ -160,6 +162,7 @@ async fn thread_list_pagination_next_cursor_none_on_last_page() -> Result<()> {
cursor: Some(cursor1),
limit: Some(2),
order: None,
model_providers: None,
})
.await?;
let page2_resp: JSONRPCResponse = timeout(
@@ -179,3 +182,77 @@ async fn thread_list_pagination_next_cursor_none_on_last_page() -> Result<()> {
Ok(())
}
// Verifies that `thread/list` honors the `model_providers` filter: with two
// rollouts on disk recorded under different providers, requesting only
// "other_provider" must return exactly one thread and no further pages.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn thread_list_respects_provider_filter() -> Result<()> {
let codex_home = TempDir::new()?;
create_minimal_config(codex_home.path())?;
// Create rollouts under two providers.
// First rollout comes from the shared fixture helper; presumably it records
// the default "mock_provider" — confirm against create_fake_rollout's impl.
let _a = create_fake_rollout(
codex_home.path(),
"2025-01-02T10-00-00",
"2025-01-02T10:00:00Z",
)?; // mock_provider
// one with a different provider
// Second rollout is written by hand so we control its `model_provider`
// field; it must live under the sessions/YYYY/MM/DD layout the lister scans.
let uuid = Uuid::new_v4();
let dir = codex_home
.path()
.join("sessions")
.join("2025")
.join("01")
.join("02");
std::fs::create_dir_all(&dir)?;
// Filename follows the rollout naming scheme: rollout-<timestamp>-<uuid>.jsonl.
let file_path = dir.join(format!("rollout-2025-01-02T11-00-00-{uuid}.jsonl"));
// Three JSONL records: session_meta (carries the provider under test),
// one response_item, and one event_msg so the session is non-empty.
let lines = [json!({
"timestamp": "2025-01-02T11:00:00Z",
"type": "session_meta",
"payload": {
"id": uuid,
"timestamp": "2025-01-02T11:00:00Z",
"cwd": "/",
"originator": "codex",
"cli_version": "0.0.0",
"instructions": null,
"source": "vscode",
"model_provider": "other_provider"
}
})
.to_string(),
json!({
"timestamp": "2025-01-02T11:00:00Z",
"type":"response_item",
"payload": {"type":"message","role":"user","content":[{"type":"input_text","text":"X"}]}
})
.to_string(),
json!({
"timestamp": "2025-01-02T11:00:00Z",
"type":"event_msg",
"payload": {"type":"user_message","message":"X","kind":"plain"}
})
.to_string()];
// Trailing newline keeps the file a well-formed JSONL document.
std::fs::write(file_path, lines.join("\n") + "\n")?;
let mut mcp = McpProcess::new(codex_home.path()).await?;
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
// Filter to only other_provider; expect 1 item, nextCursor None.
let list_id = mcp
.send_thread_list_request(ThreadListParams {
cursor: None,
// Page size larger than the fixture count so all matches fit on one page.
limit: Some(10),
order: None,
model_providers: Some(vec!["other_provider".to_string()]),
})
.await?;
let resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(list_id)),
)
.await??;
let ThreadListResponse { data, next_cursor } = to_response::<ThreadListResponse>(resp)?;
// Only the hand-written "other_provider" rollout should match; the
// fixture rollout under the other provider must be excluded.
assert_eq!(data.len(), 1);
// A single page of results implies no continuation cursor.
assert!(next_cursor.is_none());
Ok(())
}