mirror of
https://github.com/openai/codex.git
synced 2026-04-24 06:35:50 +00:00
codex-client: add AnyTransport + UDS transport implementation
Add a Unix-domain-socket transport alongside the existing reqwest transport so callers can route OpenAI requests over a local socket (for the Android service proxy) without touching the rest of the client stack. Details: - Introduce `UdsTransport` (unix-only) and `AnyTransport` to unify reqwest and UDS transports behind the existing `HttpTransport` trait. - Implement request preparation with explicit Host/Content-Type handling and optional request-body compression, mirroring the reqwest behavior. - Support both execute + streaming responses using hyper HTTP/1.1 on a UDS socket with proper timeout handling and error mapping. - Re-export `AnyTransport` and `UdsTransport` from codex-client and codex-api so higher layers can select transports without reaching into internals. - Expand codex-client dependencies (hyper, hyper-util, http-body-util, net features on tokio) needed by the UDS implementation.
This commit is contained in:
@@ -9,9 +9,12 @@ pub mod sse;
|
||||
pub mod telemetry;
|
||||
|
||||
pub use crate::requests::headers::build_conversation_headers;
|
||||
pub use codex_client::AnyTransport;
|
||||
pub use codex_client::RequestTelemetry;
|
||||
pub use codex_client::ReqwestTransport;
|
||||
pub use codex_client::TransportError;
|
||||
#[cfg(unix)]
|
||||
pub use codex_client::UdsTransport;
|
||||
|
||||
pub use crate::auth::AuthProvider;
|
||||
pub use crate::common::CompactionInput;
|
||||
|
||||
@@ -9,7 +9,10 @@ async-trait = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
eventsource-stream = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
http-body-util = { workspace = true }
|
||||
http = { workspace = true }
|
||||
hyper = { workspace = true, features = ["client", "http1"] }
|
||||
hyper-util = { workspace = true, features = ["tokio"] }
|
||||
opentelemetry = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
reqwest = { workspace = true, features = ["json", "stream"] }
|
||||
@@ -19,7 +22,7 @@ rustls-pki-types = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tokio = { workspace = true, features = ["macros", "rt", "time", "sync"] }
|
||||
tokio = { workspace = true, features = ["macros", "net", "rt", "sync", "time"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-opentelemetry = { workspace = true }
|
||||
codex-utils-rustls-provider = { workspace = true }
|
||||
|
||||
@@ -30,7 +30,10 @@ pub use crate::retry::backoff;
|
||||
pub use crate::retry::run_with_retry;
|
||||
pub use crate::sse::sse_stream;
|
||||
pub use crate::telemetry::RequestTelemetry;
|
||||
pub use crate::transport::AnyTransport;
|
||||
pub use crate::transport::ByteStream;
|
||||
pub use crate::transport::HttpTransport;
|
||||
pub use crate::transport::ReqwestTransport;
|
||||
pub use crate::transport::StreamResponse;
|
||||
#[cfg(unix)]
|
||||
pub use crate::transport::UdsTransport;
|
||||
|
||||
@@ -11,6 +11,8 @@ use futures::stream::BoxStream;
|
||||
use http::HeaderMap;
|
||||
use http::Method;
|
||||
use http::StatusCode;
|
||||
use http_body_util::BodyExt;
|
||||
use http_body_util::Full;
|
||||
use tracing::Level;
|
||||
use tracing::enabled;
|
||||
use tracing::trace;
|
||||
@@ -119,6 +121,204 @@ impl ReqwestTransport {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum AnyTransport {
|
||||
Reqwest(ReqwestTransport),
|
||||
#[cfg(unix)]
|
||||
Uds(UdsTransport),
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct UdsTransport {
|
||||
socket_path: std::path::PathBuf,
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl UdsTransport {
|
||||
pub fn new(socket_path: std::path::PathBuf) -> Self {
|
||||
Self { socket_path }
|
||||
}
|
||||
|
||||
async fn send_request(
|
||||
&self,
|
||||
req: Request,
|
||||
) -> Result<hyper::Response<hyper::body::Incoming>, TransportError> {
|
||||
use hyper::client::conn::http1;
|
||||
use hyper_util::rt::TokioIo;
|
||||
use tokio::net::UnixStream;
|
||||
|
||||
let PreparedRequest {
|
||||
method,
|
||||
uri,
|
||||
headers,
|
||||
body,
|
||||
timeout,
|
||||
} = prepare_request(req)?;
|
||||
|
||||
let request_body = match body {
|
||||
Some(body) => Full::new(Bytes::from(body)),
|
||||
None => Full::new(Bytes::new()),
|
||||
};
|
||||
|
||||
let request = {
|
||||
let mut builder = hyper::Request::builder().method(method).uri(uri);
|
||||
for (name, value) in headers.iter() {
|
||||
builder = builder.header(name, value);
|
||||
}
|
||||
builder
|
||||
.body(request_body)
|
||||
.map_err(|err| TransportError::Build(err.to_string()))?
|
||||
};
|
||||
|
||||
let connect = async {
|
||||
let stream = UnixStream::connect(&self.socket_path)
|
||||
.await
|
||||
.map_err(|err| TransportError::Network(err.to_string()))?;
|
||||
let io = TokioIo::new(stream);
|
||||
let (mut sender, conn) = http1::handshake(io)
|
||||
.await
|
||||
.map_err(|err| TransportError::Network(err.to_string()))?;
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = conn.await {
|
||||
tracing::debug!("UDS connection closed: {err}");
|
||||
}
|
||||
});
|
||||
sender
|
||||
.send_request(request)
|
||||
.await
|
||||
.map_err(|err| TransportError::Network(err.to_string()))
|
||||
};
|
||||
|
||||
if let Some(timeout) = timeout {
|
||||
tokio::time::timeout(timeout, connect)
|
||||
.await
|
||||
.map_err(|_| TransportError::Timeout)?
|
||||
} else {
|
||||
connect.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
struct PreparedRequest {
|
||||
method: Method,
|
||||
uri: http::Uri,
|
||||
headers: HeaderMap,
|
||||
body: Option<Vec<u8>>,
|
||||
timeout: Option<std::time::Duration>,
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn prepare_request(req: Request) -> Result<PreparedRequest, TransportError> {
|
||||
use http::header::CONTENT_ENCODING;
|
||||
use http::header::CONTENT_TYPE;
|
||||
use http::header::HOST;
|
||||
|
||||
let Request {
|
||||
method,
|
||||
url,
|
||||
mut headers,
|
||||
body,
|
||||
compression,
|
||||
timeout,
|
||||
} = req;
|
||||
|
||||
let uri = build_uds_uri(&url)?;
|
||||
|
||||
if !headers.contains_key(HOST)
|
||||
&& let Ok(host) = host_header_from_url(&url)
|
||||
{
|
||||
headers.insert(HOST, host);
|
||||
}
|
||||
|
||||
let body_bytes = if let Some(body) = body {
|
||||
if compression != RequestCompression::None {
|
||||
if headers.contains_key(CONTENT_ENCODING) {
|
||||
return Err(TransportError::Build(
|
||||
"request compression was requested but content-encoding is already set"
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let json =
|
||||
serde_json::to_vec(&body).map_err(|err| TransportError::Build(err.to_string()))?;
|
||||
let pre_compression_bytes = json.len();
|
||||
let compression_start = std::time::Instant::now();
|
||||
let (compressed, content_encoding) = match compression {
|
||||
RequestCompression::None => unreachable!("guarded by compression != None"),
|
||||
RequestCompression::Zstd => (
|
||||
zstd::stream::encode_all(std::io::Cursor::new(json), 3)
|
||||
.map_err(|err| TransportError::Build(err.to_string()))?,
|
||||
http::HeaderValue::from_static("zstd"),
|
||||
),
|
||||
};
|
||||
let post_compression_bytes = compressed.len();
|
||||
let compression_duration = compression_start.elapsed();
|
||||
|
||||
headers.insert(CONTENT_ENCODING, content_encoding);
|
||||
if !headers.contains_key(CONTENT_TYPE) {
|
||||
headers.insert(
|
||||
CONTENT_TYPE,
|
||||
http::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
pre_compression_bytes,
|
||||
post_compression_bytes,
|
||||
compression_duration_ms = compression_duration.as_millis(),
|
||||
"Compressed request body with zstd"
|
||||
);
|
||||
|
||||
Some(compressed)
|
||||
} else {
|
||||
if !headers.contains_key(CONTENT_TYPE) {
|
||||
headers.insert(
|
||||
CONTENT_TYPE,
|
||||
http::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
}
|
||||
Some(serde_json::to_vec(&body).map_err(|err| TransportError::Build(err.to_string()))?)
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(PreparedRequest {
|
||||
method,
|
||||
uri,
|
||||
headers,
|
||||
body: body_bytes,
|
||||
timeout,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn build_uds_uri(url: &str) -> Result<http::Uri, TransportError> {
|
||||
let url = reqwest::Url::parse(url).map_err(|err| TransportError::Build(err.to_string()))?;
|
||||
let path = url.path();
|
||||
let path_and_query = match url.query() {
|
||||
Some(query) => format!("{path}?{query}"),
|
||||
None => path.to_string(),
|
||||
};
|
||||
http::Uri::builder()
|
||||
.path_and_query(path_and_query)
|
||||
.build()
|
||||
.map_err(|err| TransportError::Build(err.to_string()))
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn host_header_from_url(url: &str) -> Result<http::HeaderValue, TransportError> {
|
||||
let url = reqwest::Url::parse(url).map_err(|err| TransportError::Build(err.to_string()))?;
|
||||
let host = match (url.host_str(), url.port()) {
|
||||
(Some(host), Some(port)) => format!("{host}:{port}"),
|
||||
(Some(host), None) => host.to_string(),
|
||||
_ => return Err(TransportError::Build("missing host".to_string())),
|
||||
};
|
||||
http::HeaderValue::from_str(&host).map_err(|err| TransportError::Build(err.to_string()))
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HttpTransport for ReqwestTransport {
|
||||
async fn execute(&self, req: Request) -> Result<Response, TransportError> {
|
||||
@@ -187,3 +387,101 @@ impl HttpTransport for ReqwestTransport {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[async_trait]
|
||||
impl HttpTransport for UdsTransport {
|
||||
async fn execute(&self, req: Request) -> Result<Response, TransportError> {
|
||||
if enabled!(Level::TRACE) {
|
||||
trace!(
|
||||
"{} to {}: {}",
|
||||
req.method,
|
||||
req.url,
|
||||
req.body.as_ref().unwrap_or_default()
|
||||
);
|
||||
}
|
||||
|
||||
let url = req.url.clone();
|
||||
let resp = self.send_request(req).await?;
|
||||
let status = resp.status();
|
||||
let headers = resp.headers().clone();
|
||||
let bytes = resp
|
||||
.into_body()
|
||||
.collect()
|
||||
.await
|
||||
.map_err(|err| TransportError::Network(err.to_string()))?
|
||||
.to_bytes();
|
||||
if !status.is_success() {
|
||||
let body = String::from_utf8(bytes.to_vec()).ok();
|
||||
return Err(TransportError::Http {
|
||||
status,
|
||||
url: Some(url),
|
||||
headers: Some(headers),
|
||||
body,
|
||||
});
|
||||
}
|
||||
Ok(Response {
|
||||
status,
|
||||
headers,
|
||||
body: bytes,
|
||||
})
|
||||
}
|
||||
|
||||
async fn stream(&self, req: Request) -> Result<StreamResponse, TransportError> {
|
||||
if enabled!(Level::TRACE) {
|
||||
trace!(
|
||||
"{} to {}: {}",
|
||||
req.method,
|
||||
req.url,
|
||||
req.body.as_ref().unwrap_or_default()
|
||||
);
|
||||
}
|
||||
|
||||
let url = req.url.clone();
|
||||
let resp = self.send_request(req).await?;
|
||||
let status = resp.status();
|
||||
let headers = resp.headers().clone();
|
||||
if !status.is_success() {
|
||||
let body = resp
|
||||
.into_body()
|
||||
.collect()
|
||||
.await
|
||||
.ok()
|
||||
.and_then(|collected| String::from_utf8(collected.to_bytes().to_vec()).ok());
|
||||
return Err(TransportError::Http {
|
||||
status,
|
||||
url: Some(url),
|
||||
headers: Some(headers),
|
||||
body,
|
||||
});
|
||||
}
|
||||
let stream = resp
|
||||
.into_body()
|
||||
.into_data_stream()
|
||||
.map(|result| result.map_err(|err| TransportError::Network(err.to_string())));
|
||||
Ok(StreamResponse {
|
||||
status,
|
||||
headers,
|
||||
bytes: Box::pin(stream),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HttpTransport for AnyTransport {
|
||||
async fn execute(&self, req: Request) -> Result<Response, TransportError> {
|
||||
match self {
|
||||
AnyTransport::Reqwest(transport) => transport.execute(req).await,
|
||||
#[cfg(unix)]
|
||||
AnyTransport::Uds(transport) => transport.execute(req).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn stream(&self, req: Request) -> Result<StreamResponse, TransportError> {
|
||||
match self {
|
||||
AnyTransport::Reqwest(transport) => transport.stream(req).await,
|
||||
#[cfg(unix)]
|
||||
AnyTransport::Uds(transport) => transport.stream(req).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user