diff --git a/codex-rs/codex-api/src/endpoint/images.rs b/codex-rs/codex-api/src/endpoint/images.rs new file mode 100644 index 0000000000..9d1bd41eea --- /dev/null +++ b/codex-rs/codex-api/src/endpoint/images.rs @@ -0,0 +1,302 @@ +use crate::auth::SharedAuthProvider; +use crate::endpoint::session::EndpointSession; +use crate::error::ApiError; +use crate::images::ImageEditRequest; +use crate::images::ImageGenerationRequest; +use crate::images::ImageResponse; +use crate::provider::Provider; +use codex_client::HttpTransport; +use codex_client::RequestTelemetry; +use http::HeaderMap; +use http::Method; +use serde::Serialize; +use serde_json::to_value; +use std::sync::Arc; + +pub struct ImagesClient { + session: EndpointSession, +} + +impl ImagesClient { + pub fn new(transport: T, provider: Provider, auth: SharedAuthProvider) -> Self { + Self { + session: EndpointSession::new(transport, provider, auth), + } + } + + pub fn with_telemetry(self, request: Option>) -> Self { + Self { + session: self.session.with_request_telemetry(request), + } + } + + pub async fn generate( + &self, + request: &ImageGenerationRequest, + extra_headers: HeaderMap, + ) -> Result { + self.post_image_request( + "images/generations", + request, + extra_headers, + "image generation", + ) + .await + } + + pub async fn edit( + &self, + request: &ImageEditRequest, + extra_headers: HeaderMap, + ) -> Result { + self.post_image_request("images/edits", request, extra_headers, "image edit") + .await + } + + async fn post_image_request( + &self, + path: &str, + request: &R, + extra_headers: HeaderMap, + operation: &str, + ) -> Result { + let body = to_value(request) + .map_err(|e| ApiError::Stream(format!("failed to encode {operation} request: {e}")))?; + let resp = self + .session + .execute(Method::POST, path, extra_headers, Some(body)) + .await?; + serde_json::from_slice(&resp.body) + .map_err(|e| ApiError::Stream(format!("failed to decode {operation} response: {e}"))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::auth::AuthProvider; + use crate::images::ImageBackground; + use crate::images::ImageData; + use crate::images::ImageQuality; + use crate::images::ImageUrl; + use crate::provider::RetryConfig; + use async_trait::async_trait; + use codex_client::Request; + use codex_client::RequestBody; + use codex_client::Response; + use codex_client::StreamResponse; + use codex_client::TransportError; + use http::StatusCode; + use pretty_assertions::assert_eq; + use serde_json::json; + use std::sync::Mutex; + use std::time::Duration; + + #[derive(Clone, Default)] + struct DummyAuth; + + impl AuthProvider for DummyAuth { + fn add_auth_headers(&self, _headers: &mut HeaderMap) {} + } + + #[derive(Clone)] + struct CapturingTransport { + last_request: Arc>>, + response_body: Arc>, + } + + impl CapturingTransport { + fn new(response_body: Vec) -> Self { + Self { + last_request: Arc::new(Mutex::new(None)), + response_body: Arc::new(response_body), + } + } + } + + #[async_trait] + impl HttpTransport for CapturingTransport { + async fn execute(&self, req: Request) -> Result { + *self.last_request.lock().expect("lock request store") = Some(req); + Ok(Response { + status: StatusCode::OK, + headers: HeaderMap::new(), + body: self.response_body.as_ref().clone().into(), + }) + } + + async fn stream(&self, _req: Request) -> Result { + Err(TransportError::Build("stream should not run".to_string())) + } + } + + fn provider() -> Provider { + Provider { + name: "test".to_string(), + base_url: "https://example.com/api/codex".to_string(), + query_params: None, + headers: HeaderMap::new(), + retry: RetryConfig { + max_attempts: 1, + base_delay: Duration::from_millis(1), + retry_429: false, + retry_5xx: true, + retry_transport: true, + }, + stream_idle_timeout: Duration::from_secs(1), + } + } + + fn response_body() -> Vec { + serde_json::to_vec(&json!({ + "created": 1778832973u64, + "background": "opaque", + "data": [{"b64_json": "REDACT"}], + "output_format": "png", + "quality": "medium", + "size": "1024x1536", + "usage": { + "input_tokens": 1474, + "input_tokens_details": { + "image_tokens": 1457, + "text_tokens": 17, + }, + "output_tokens": 1372, + "output_tokens_details": { + "image_tokens": 1372, + "text_tokens": 0, + }, + "total_tokens": 2846, + } + })) + .expect("serialize response") + } + + fn expected_response() -> ImageResponse { + ImageResponse { + created: 1778832973, + background: Some(ImageBackground::Opaque), + data: vec![ImageData { + b64_json: "REDACT".to_string(), + }], + quality: Some(ImageQuality::Medium), + size: Some("1024x1536".to_string()), + } + } + + fn captured_request(transport: &CapturingTransport) -> Request { + transport + .last_request + .lock() + .expect("lock request store") + .clone() + .expect("request should be captured") + } + + #[tokio::test] + async fn generate_posts_typed_request_and_parses_image_response() { + let transport = CapturingTransport::new(response_body()); + let client = ImagesClient::new(transport.clone(), provider(), Arc::new(DummyAuth)); + + let response = client + .generate( + &ImageGenerationRequest { + prompt: "a red fox in a field".to_string(), + background: Some(ImageBackground::Opaque), + model: "gpt-image-1.5".to_string(), + n: None, + quality: Some(ImageQuality::Medium), + size: Some("1024x1536".to_string()), + }, + HeaderMap::new(), + ) + .await + .expect("image generation request should succeed"); + + assert_eq!(response, expected_response()); + + let request = captured_request(&transport); + assert_eq!( + request.url, + "https://example.com/api/codex/images/generations" + ); + assert_eq!( + request.body.as_ref().and_then(RequestBody::json), + Some(&json!({ + "prompt": "a red fox in a field", + "background": "opaque", + "model": "gpt-image-1.5", + "quality": "medium", + "size": "1024x1536", + })) + ); + } + + #[tokio::test] + async fn edit_posts_typed_request_and_parses_image_response() { + let transport = CapturingTransport::new(response_body()); + let client = ImagesClient::new(transport.clone(), provider(), Arc::new(DummyAuth)); + + let response = client + .edit( + &ImageEditRequest { + images: vec![ImageUrl { + image_url: "data:image/png;base64,Zm9v".to_string(), + }], + prompt: "add a red hat".to_string(), + background: None, + model: "gpt-image-1.5".to_string(), + n: None, + quality: None, + size: None, + }, + HeaderMap::new(), + ) + .await + .expect("image edit request should succeed"); + + assert_eq!(response, expected_response()); + + let request = captured_request(&transport); + assert_eq!(request.url, "https://example.com/api/codex/images/edits"); + assert_eq!( + request.body.as_ref().and_then(RequestBody::json), + Some(&json!({ + "images": [{"image_url": "data:image/png;base64,Zm9v"}], + "prompt": "add a red hat", + "model": "gpt-image-1.5", + })) + ); + } + + #[tokio::test] + async fn image_response_requires_image_data() { + let transport = CapturingTransport::new( + serde_json::to_vec(&json!({"created": 1778832973u64})).expect("serialize response"), + ); + let client = ImagesClient::new(transport, provider(), Arc::new(DummyAuth)); + + let error = client + .generate( + &ImageGenerationRequest { + prompt: "a red fox in a field".to_string(), + background: None, + model: "gpt-image-1.5".to_string(), + n: None, + quality: None, + size: None, + }, + HeaderMap::new(), + ) + .await + .expect_err("image response without data should fail"); + + let ApiError::Stream(message) = error else { + panic!("expected image response decode error"); + }; + assert!( + message.starts_with("failed to decode image generation response: missing field `data`"), + "{message}" + ); + } +} diff --git a/codex-rs/codex-api/src/endpoint/mod.rs b/codex-rs/codex-api/src/endpoint/mod.rs index 21ebf372a1..106c5d73ff 100644 --- a/codex-rs/codex-api/src/endpoint/mod.rs +++ b/codex-rs/codex-api/src/endpoint/mod.rs @@ -1,4 +1,5 @@ pub(crate) mod compact; +pub(crate) mod images; pub(crate) mod memories; pub(crate) mod models; pub(crate) mod realtime_call; @@ -9,6 +10,7 @@ pub(crate) mod search; mod session; pub use compact::CompactClient; +pub use images::ImagesClient; pub use memories::MemoriesClient; pub use models::ModelsClient; pub use realtime_call::RealtimeCallClient; diff --git a/codex-rs/codex-api/src/images.rs b/codex-rs/codex-api/src/images.rs new file mode 100644 index 0000000000..f915a5f78d --- /dev/null +++ b/codex-rs/codex-api/src/images.rs @@ -0,0 +1,70 @@ +use serde::Deserialize; +use serde::Serialize; + +#[derive(Debug, Clone, Serialize, PartialEq)] +pub struct ImageGenerationRequest { + pub prompt: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub background: Option, + pub model: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub quality: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub size: Option, +} + +#[derive(Debug, Clone, Serialize, PartialEq)] +pub struct ImageEditRequest { + pub images: Vec, + pub prompt: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub background: Option, + pub model: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub quality: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub size: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct ImageUrl { + pub image_url: String, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum ImageBackground { + Transparent, + Opaque, + Auto, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum ImageQuality { + Low, + Medium, + High, + Auto, +} + +#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] +pub struct ImageResponse { + pub created: u64, + pub data: Vec, + #[serde(default)] + pub background: Option, + #[serde(default)] + pub quality: Option, + #[serde(default)] + pub size: Option, +} + +#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] +pub struct ImageData { + pub b64_json: String, +} diff --git a/codex-rs/codex-api/src/lib.rs b/codex-rs/codex-api/src/lib.rs index f47791b799..08176d8dec 100644 --- a/codex-rs/codex-api/src/lib.rs +++ b/codex-rs/codex-api/src/lib.rs @@ -4,6 +4,7 @@ pub(crate) mod common; pub(crate) mod endpoint; pub(crate) mod error; pub(crate) mod files; +pub(crate) mod images; pub(crate) mod provider; pub(crate) mod rate_limits; pub(crate) mod requests; @@ -41,6 +42,7 @@ pub use crate::common::WS_REQUEST_HEADER_TRACESTATE_CLIENT_METADATA_KEY; pub use crate::common::create_text_param_for_request; pub use crate::common::response_create_client_metadata; pub use crate::endpoint::CompactClient; +pub use crate::endpoint::ImagesClient; pub use crate::endpoint::MemoriesClient; pub use crate::endpoint::ModelsClient; pub use crate::endpoint::RealtimeCallClient; @@ -63,6 +65,13 @@ pub use crate::endpoint::SearchClient; pub use crate::endpoint::session_update_session_json; pub use crate::error::ApiError; pub use crate::files::upload_local_file; +pub use crate::images::ImageBackground; +pub use crate::images::ImageData; +pub use crate::images::ImageEditRequest; +pub use crate::images::ImageGenerationRequest; +pub use crate::images::ImageQuality; +pub use crate::images::ImageResponse; +pub use crate::images::ImageUrl; pub use crate::provider::Provider; pub use crate::provider::RetryConfig; pub use crate::provider::is_azure_responses_provider;