mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
163 lines
4.8 KiB
Rust
163 lines
4.8 KiB
Rust
use crate::auth::AuthProvider;
|
|
use crate::auth::add_auth_headers;
|
|
use crate::common::CompactionInput;
|
|
use crate::error::ApiError;
|
|
use crate::provider::Provider;
|
|
use crate::provider::WireApi;
|
|
use crate::telemetry::run_with_request_telemetry;
|
|
use codex_client::HttpTransport;
|
|
use codex_client::RequestTelemetry;
|
|
use codex_protocol::models::ResponseItem;
|
|
use http::HeaderMap;
|
|
use http::Method;
|
|
use serde::Deserialize;
|
|
use serde_json::to_value;
|
|
use std::sync::Arc;
|
|
|
|
pub struct CompactClient<T: HttpTransport, A: AuthProvider> {
|
|
transport: T,
|
|
provider: Provider,
|
|
auth: A,
|
|
request_telemetry: Option<Arc<dyn RequestTelemetry>>,
|
|
}
|
|
|
|
impl<T: HttpTransport, A: AuthProvider> CompactClient<T, A> {
|
|
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
|
|
Self {
|
|
transport,
|
|
provider,
|
|
auth,
|
|
request_telemetry: None,
|
|
}
|
|
}
|
|
|
|
pub fn with_telemetry(mut self, request: Option<Arc<dyn RequestTelemetry>>) -> Self {
|
|
self.request_telemetry = request;
|
|
self
|
|
}
|
|
|
|
fn path(&self) -> Result<&'static str, ApiError> {
|
|
match self.provider.wire {
|
|
WireApi::Compact | WireApi::Responses => Ok("responses/compact"),
|
|
WireApi::Chat => Err(ApiError::Stream(
|
|
"compact endpoint requires responses wire api".to_string(),
|
|
)),
|
|
}
|
|
}
|
|
|
|
pub async fn compact(
|
|
&self,
|
|
body: serde_json::Value,
|
|
extra_headers: HeaderMap,
|
|
) -> Result<Vec<ResponseItem>, ApiError> {
|
|
let path = self.path()?;
|
|
let builder = || {
|
|
let mut req = self.provider.build_request(Method::POST, path);
|
|
req.headers.extend(extra_headers.clone());
|
|
req.body = Some(body.clone());
|
|
add_auth_headers(&self.auth, req)
|
|
};
|
|
|
|
let resp = run_with_request_telemetry(
|
|
self.provider.retry.to_policy(),
|
|
self.request_telemetry.clone(),
|
|
builder,
|
|
|req| self.transport.execute(req),
|
|
)
|
|
.await?;
|
|
let parsed: CompactHistoryResponse =
|
|
serde_json::from_slice(&resp.body).map_err(|e| ApiError::Stream(e.to_string()))?;
|
|
Ok(parsed.output)
|
|
}
|
|
|
|
pub async fn compact_input(
|
|
&self,
|
|
input: &CompactionInput<'_>,
|
|
extra_headers: HeaderMap,
|
|
) -> Result<Vec<ResponseItem>, ApiError> {
|
|
let body = to_value(input)
|
|
.map_err(|e| ApiError::Stream(format!("failed to encode compaction input: {e}")))?;
|
|
self.compact(body, extra_headers).await
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct CompactHistoryResponse {
|
|
output: Vec<ResponseItem>,
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
use crate::provider::RetryConfig;
|
|
use async_trait::async_trait;
|
|
use codex_client::Request;
|
|
use codex_client::Response;
|
|
use codex_client::StreamResponse;
|
|
use codex_client::TransportError;
|
|
use http::HeaderMap;
|
|
use std::time::Duration;
|
|
|
|
#[derive(Clone, Default)]
|
|
struct DummyTransport;
|
|
|
|
#[async_trait]
|
|
impl HttpTransport for DummyTransport {
|
|
async fn execute(&self, _req: Request) -> Result<Response, TransportError> {
|
|
Err(TransportError::Build("execute should not run".to_string()))
|
|
}
|
|
|
|
async fn stream(&self, _req: Request) -> Result<StreamResponse, TransportError> {
|
|
Err(TransportError::Build("stream should not run".to_string()))
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Default)]
|
|
struct DummyAuth;
|
|
|
|
impl AuthProvider for DummyAuth {
|
|
fn bearer_token(&self) -> Option<String> {
|
|
None
|
|
}
|
|
}
|
|
|
|
fn provider(wire: WireApi) -> Provider {
|
|
Provider {
|
|
name: "test".to_string(),
|
|
base_url: "https://example.com/v1".to_string(),
|
|
query_params: None,
|
|
wire,
|
|
headers: HeaderMap::new(),
|
|
retry: RetryConfig {
|
|
max_attempts: 1,
|
|
base_delay: Duration::from_millis(1),
|
|
retry_429: false,
|
|
retry_5xx: true,
|
|
retry_transport: true,
|
|
},
|
|
stream_idle_timeout: Duration::from_secs(1),
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn errors_when_wire_is_chat() {
|
|
let client = CompactClient::new(DummyTransport, provider(WireApi::Chat), DummyAuth);
|
|
let input = CompactionInput {
|
|
model: "gpt-test",
|
|
input: &[],
|
|
instructions: "inst",
|
|
};
|
|
let err = client
|
|
.compact_input(&input, HeaderMap::new())
|
|
.await
|
|
.expect_err("expected wire mismatch to fail");
|
|
|
|
match err {
|
|
ApiError::Stream(msg) => {
|
|
assert_eq!(msg, "compact endpoint requires responses wire api");
|
|
}
|
|
other => panic!("unexpected error: {other:?}"),
|
|
}
|
|
}
|
|
}
|