diff --git a/codex-rs/codex-api/src/endpoint/mod.rs b/codex-rs/codex-api/src/endpoint/mod.rs index 70c52b9580..21ebf372a1 100644 --- a/codex-rs/codex-api/src/endpoint/mod.rs +++ b/codex-rs/codex-api/src/endpoint/mod.rs @@ -5,6 +5,7 @@ pub(crate) mod realtime_call; pub(crate) mod realtime_websocket; pub(crate) mod responses; pub(crate) mod responses_websocket; +pub(crate) mod search; mod session; pub use compact::CompactClient; @@ -27,3 +28,4 @@ pub use responses_websocket::ResponsesWebsocketClient; pub use responses_websocket::ResponsesWebsocketClose; pub use responses_websocket::ResponsesWebsocketConnection; pub use responses_websocket::ResponsesWebsocketProbe; +pub use search::SearchClient; diff --git a/codex-rs/codex-api/src/endpoint/search.rs b/codex-rs/codex-api/src/endpoint/search.rs new file mode 100644 index 0000000000..2940f872d8 --- /dev/null +++ b/codex-rs/codex-api/src/endpoint/search.rs @@ -0,0 +1,264 @@ +use crate::auth::SharedAuthProvider; +use crate::endpoint::session::EndpointSession; +use crate::error::ApiError; +use crate::provider::Provider; +use crate::search::SearchRequest; +use crate::search::SearchResponse; +use codex_client::HttpTransport; +use codex_client::RequestTelemetry; +use http::HeaderMap; +use http::Method; +use serde_json::to_value; +use std::sync::Arc; + +pub struct SearchClient { + session: EndpointSession, +} + +impl SearchClient { + pub fn new(transport: T, provider: Provider, auth: SharedAuthProvider) -> Self { + Self { + session: EndpointSession::new(transport, provider, auth), + } + } + + pub fn with_telemetry(self, request: Option>) -> Self { + Self { + session: self.session.with_request_telemetry(request), + } + } + + fn path() -> &'static str { + "alpha/search" + } + + pub async fn search( + &self, + request: &SearchRequest, + extra_headers: HeaderMap, + ) -> Result { + let body = to_value(request) + .map_err(|e| ApiError::Stream(format!("failed to encode search request: {e}")))?; + let resp = self + .session + .execute(Method::POST, Self::path(), extra_headers, Some(body)) + .await?; + serde_json::from_slice(&resp.body) + .map_err(|e| ApiError::Stream(format!("failed to decode search response: {e}"))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::auth::AuthProvider; + use crate::provider::RetryConfig; + use crate::search::AllowedCaller; + use crate::search::ApproximateLocation; + use crate::search::LocationType; + use crate::search::OpenOperation; + use crate::search::SearchCommands; + use crate::search::SearchContextSize; + use crate::search::SearchFilters; + use crate::search::SearchImageSettings; + use crate::search::SearchInput; + use crate::search::SearchQuery; + use crate::search::SearchSettings; + use async_trait::async_trait; + use codex_client::Request; + use codex_client::RequestBody; + use codex_client::Response; + use codex_client::StreamResponse; + use codex_client::TransportError; + use codex_protocol::models::ContentItem; + use codex_protocol::models::ResponseItem; + use http::StatusCode; + use pretty_assertions::assert_eq; + use serde_json::json; + use std::sync::Mutex; + use std::time::Duration; + + #[derive(Clone, Default)] + struct DummyAuth; + + impl AuthProvider for DummyAuth { + fn add_auth_headers(&self, _headers: &mut HeaderMap) {} + } + + #[derive(Clone)] + struct CapturingTransport { + last_request: Arc>>, + response_body: Arc>, + } + + impl CapturingTransport { + fn new(response_body: Vec) -> Self { + Self { + last_request: Arc::new(Mutex::new(None)), + response_body: Arc::new(response_body), + } + } + } + + #[async_trait] + impl HttpTransport for CapturingTransport { + async fn execute(&self, req: Request) -> Result { + *self.last_request.lock().expect("lock request store") = Some(req); + Ok(Response { + status: StatusCode::OK, + headers: HeaderMap::new(), + body: self.response_body.as_ref().clone().into(), + }) + } + + async fn stream(&self, _req: Request) -> Result { + Err(TransportError::Build("stream should not run".to_string())) + } + } + + fn provider() -> Provider { + Provider { + name: "test".to_string(), + base_url: "https://example.com/v1".to_string(), + query_params: None, + headers: HeaderMap::new(), + retry: RetryConfig { + max_attempts: 1, + base_delay: Duration::from_millis(1), + retry_429: false, + retry_5xx: true, + retry_transport: true, + }, + stream_idle_timeout: Duration::from_secs(1), + } + } + + #[tokio::test] + async fn search_posts_typed_request_and_parses_encrypted_output() { + let transport = CapturingTransport::new( + serde_json::to_vec(&json!({"encrypted_output": "ciphertext"})) + .expect("serialize response"), + ); + let client = SearchClient::new(transport.clone(), provider(), Arc::new(DummyAuth)); + + let response = client + .search( + &SearchRequest { + id: "search-session".to_string(), + model: Some("gpt-test".to_string()), + reasoning: None, + input: Some(SearchInput::Items(vec![ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ + ContentItem::InputText { + text: "find this".to_string(), + }, + ContentItem::InputImage { + image_url: "https://example.com/image.png".to_string(), + detail: None, + }, + ], + phase: None, + }])), + commands: Some(SearchCommands { + search_query: Some(vec![SearchQuery { + q: "OpenAI news".to_string(), + recency: Some(7), + domains: Some(vec!["openai.com".to_string()]), + }]), + open: Some(vec![OpenOperation { + ref_id: "https://openai.com".to_string(), + lineno: Some(12), + }]), + ..Default::default() + }), + settings: Some(SearchSettings { + user_location: Some(ApproximateLocation { + r#type: LocationType::Approximate, + country: Some("US".to_string()), + region: None, + city: Some("San Francisco".to_string()), + timezone: None, + }), + search_context_size: Some(SearchContextSize::Low), + filters: Some(SearchFilters { + allowed_domains: Some(vec!["openai.com".to_string()]), + blocked_domains: Some(vec!["example.com".to_string()]), + }), + image_settings: Some(SearchImageSettings { + max_results: Some(4), + caption: Some(true), + }), + allowed_callers: Some(vec![AllowedCaller::Direct]), + external_web_access: Some(true), + }), + max_output_tokens: Some(2500), + }, + HeaderMap::new(), + ) + .await + .expect("search request should succeed"); + + assert_eq!( + response, + SearchResponse { + encrypted_output: "ciphertext".to_string(), + } + ); + + let request = transport + .last_request + .lock() + .expect("lock request store") + .clone() + .expect("request should be captured"); + let body = request + .body + .as_ref() + .and_then(RequestBody::json) + .expect("request body should be JSON"); + assert_eq!( + body, + &json!({ + "id": "search-session", + "model": "gpt-test", + "input": [{ + "type": "message", + "role": "user", + "content": [ + {"type": "input_text", "text": "find this"}, + { + "type": "input_image", + "image_url": "https://example.com/image.png" + } + ] + }], + "commands": { + "search_query": [{ + "q": "OpenAI news", + "recency": 7, + "domains": ["openai.com"] + }], + "open": [{"ref_id": "https://openai.com", "lineno": 12}] + }, + "settings": { + "user_location": { + "type": "approximate", + "country": "US", + "city": "San Francisco" + }, + "search_context_size": "low", + "filters": { + "allowed_domains": ["openai.com"], + "blocked_domains": ["example.com"] + }, + "image_settings": {"max_results": 4, "caption": true}, + "allowed_callers": ["direct"], + "external_web_access": true + }, + "max_output_tokens": 2500 + }) + ); + } +} diff --git a/codex-rs/codex-api/src/lib.rs b/codex-rs/codex-api/src/lib.rs index 99470cac59..f47791b799 100644 --- a/codex-rs/codex-api/src/lib.rs +++ b/codex-rs/codex-api/src/lib.rs @@ -7,6 +7,7 @@ pub(crate) mod files; pub(crate) mod provider; pub(crate) mod rate_limits; pub(crate) mod requests; +pub(crate) mod search; pub(crate) mod sse; pub(crate) mod telemetry; @@ -58,6 +59,7 @@ pub use crate::endpoint::ResponsesWebsocketClient; pub use crate::endpoint::ResponsesWebsocketClose; pub use crate::endpoint::ResponsesWebsocketConnection; pub use crate::endpoint::ResponsesWebsocketProbe; +pub use crate::endpoint::SearchClient; pub use crate::endpoint::session_update_session_json; pub use crate::error::ApiError; pub use crate::files::upload_local_file; @@ -65,6 +67,31 @@ pub use crate::provider::Provider; pub use crate::provider::RetryConfig; pub use crate::provider::is_azure_responses_provider; pub use crate::requests::Compression; +pub use crate::search::AllowedCaller; +pub use crate::search::ApproximateLocation; +pub use crate::search::ClickOperation; +pub use crate::search::FinanceAssetType; +pub use crate::search::FinanceOperation; +pub use crate::search::FindOperation; +pub use crate::search::LocationType; +pub use crate::search::OpenOperation; +pub use crate::search::ScreenshotOperation; +pub use crate::search::SearchCommands; +pub use crate::search::SearchContextSize; +pub use crate::search::SearchFilters; +pub use crate::search::SearchImageSettings; +pub use crate::search::SearchInput; +pub use crate::search::SearchQuery; +pub use crate::search::SearchRequest; +pub use crate::search::SearchResponse; +pub use crate::search::SearchResponseLength; +pub use crate::search::SearchSettings; +pub use crate::search::SportsFunction; +pub use crate::search::SportsLeague; +pub use crate::search::SportsOperation; +pub use crate::search::SportsToolName; +pub use crate::search::TimeOperation; +pub use crate::search::WeatherOperation; pub use crate::telemetry::SseTelemetry; pub use crate::telemetry::WebsocketTelemetry; pub use codex_protocol::protocol::RealtimeAudioFrame; diff --git a/codex-rs/codex-api/src/search.rs b/codex-rs/codex-api/src/search.rs new file mode 100644 index 0000000000..b841d06a30 --- /dev/null +++ b/codex-rs/codex-api/src/search.rs @@ -0,0 +1,246 @@ +use crate::common::Reasoning; +use codex_protocol::models::ResponseItem; +use serde::Deserialize; +use serde::Serialize; + +#[derive(Debug, Clone, Serialize, PartialEq)] +pub struct SearchRequest { + pub id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub input: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub commands: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub settings: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_output_tokens: Option, +} + +#[derive(Debug, Clone, Serialize, PartialEq)] +#[serde(untagged)] +pub enum SearchInput { + Text(String), + Items(Vec), +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)] +pub struct SearchCommands { + #[serde(skip_serializing_if = "Option::is_none")] + pub search_query: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub image_query: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub open: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub click: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub find: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub screenshot: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub finance: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub weather: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub sports: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub time: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_length: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct SearchQuery { + pub q: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub recency: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub domains: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct OpenOperation { + pub ref_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub lineno: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct ClickOperation { + pub ref_id: String, + pub id: u64, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct FindOperation { + pub ref_id: String, + pub pattern: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct ScreenshotOperation { + pub ref_id: String, + pub pageno: u64, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct FinanceOperation { + pub ticker: String, + pub r#type: FinanceAssetType, + #[serde(skip_serializing_if = "Option::is_none")] + pub market: Option, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum FinanceAssetType { + Equity, + Fund, + Crypto, + Index, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct WeatherOperation { + pub location: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub start: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub duration: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct SportsOperation { + #[serde(skip_serializing_if = "Option::is_none")] + pub tool: Option, + pub r#fn: SportsFunction, + pub league: SportsLeague, + #[serde(skip_serializing_if = "Option::is_none")] + pub team: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub opponent: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub date_from: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub date_to: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub num_games: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub locale: Option, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum SportsToolName { + Sports, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum SportsFunction { + Schedule, + Standings, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum SportsLeague { + Nba, + Wnba, + Nfl, + Nhl, + Mlb, + Epl, + Ncaamb, + Ncaawb, + Ipl, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct TimeOperation { + pub utc_offset: String, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum SearchResponseLength { + Short, + Medium, + Long, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)] +pub struct SearchSettings { + #[serde(skip_serializing_if = "Option::is_none")] + pub user_location: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub search_context_size: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub filters: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub image_settings: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub allowed_callers: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub external_web_access: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct ApproximateLocation { + pub r#type: LocationType, + #[serde(skip_serializing_if = "Option::is_none")] + pub country: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub region: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub city: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub timezone: Option, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum LocationType { + Approximate, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum SearchContextSize { + Low, + Medium, + High, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] +pub struct SearchFilters { + #[serde(skip_serializing_if = "Option::is_none")] + pub allowed_domains: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub blocked_domains: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] +pub struct SearchImageSettings { + #[serde(skip_serializing_if = "Option::is_none")] + pub max_results: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub caption: Option, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum AllowedCaller { + Direct, + Shell, + CodeInterpreter, +} + +#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] +pub struct SearchResponse { + pub encrypted_output: String, +}