feat: add opt-in provider runtime abstraction (#17713)

## Summary

- Add `codex-model-provider` as the runtime home for model-provider
behavior that does not belong in `codex-core`, `codex-login`, or
`codex-api`.
- The new crate wraps configured `ModelProviderInfo` in a
`ModelProvider` trait object that can resolve the API provider config,
provider-scoped auth manager, and request auth provider for each call.
- This centralizes provider auth behavior in one place today, and gives
us an extension point for future provider-specific auth, model listing,
request setup, and related runtime behavior.

## Tests
Ran tests manually to make sure that provider auth under different
configs still work as expected.

---------

Co-authored-by: pakrym-oai <pakrym@openai.com>
This commit is contained in:
Celia Chen
2026-04-16 19:27:45 -07:00
committed by GitHub
parent 91e8eebd03
commit a803790a10
45 changed files with 577 additions and 369 deletions

View File

@@ -1,4 +1,4 @@
use crate::auth::AuthProvider;
use crate::auth::SharedAuthProvider;
use crate::common::CompactionInput;
use crate::endpoint::session::EndpointSession;
use crate::error::ApiError;
@@ -12,12 +12,12 @@ use serde::Deserialize;
use serde_json::to_value;
use std::sync::Arc;
pub struct CompactClient<T: HttpTransport, A: AuthProvider> {
session: EndpointSession<T, A>,
pub struct CompactClient<T: HttpTransport> {
session: EndpointSession<T>,
}
impl<T: HttpTransport, A: AuthProvider> CompactClient<T, A> {
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
impl<T: HttpTransport> CompactClient<T> {
pub fn new(transport: T, provider: Provider, auth: SharedAuthProvider) -> Self {
Self {
session: EndpointSession::new(transport, provider, auth),
}
@@ -86,18 +86,8 @@ mod tests {
}
}
#[derive(Clone, Default)]
struct DummyAuth;
impl AuthProvider for DummyAuth {
fn add_auth_headers(&self, _headers: &mut HeaderMap) {}
}
#[test]
fn path_is_responses_compact() {
assert_eq!(
CompactClient::<DummyTransport, DummyAuth>::path(),
"responses/compact"
);
assert_eq!(CompactClient::<DummyTransport>::path(), "responses/compact");
}
}

View File

@@ -1,4 +1,4 @@
use crate::auth::AuthProvider;
use crate::auth::SharedAuthProvider;
use crate::common::MemorySummarizeInput;
use crate::common::MemorySummarizeOutput;
use crate::endpoint::session::EndpointSession;
@@ -12,12 +12,12 @@ use serde::Deserialize;
use serde_json::to_value;
use std::sync::Arc;
pub struct MemoriesClient<T: HttpTransport, A: AuthProvider> {
session: EndpointSession<T, A>,
pub struct MemoriesClient<T: HttpTransport> {
session: EndpointSession<T>,
}
impl<T: HttpTransport, A: AuthProvider> MemoriesClient<T, A> {
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
impl<T: HttpTransport> MemoriesClient<T> {
pub fn new(transport: T, provider: Provider, auth: SharedAuthProvider) -> Self {
Self {
session: EndpointSession::new(transport, provider, auth),
}
@@ -67,6 +67,7 @@ struct SummarizeResponse {
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::AuthProvider;
use crate::common::RawMemory;
use crate::common::RawMemoryMetadata;
use crate::provider::RetryConfig;
@@ -157,7 +158,7 @@ mod tests {
#[test]
fn path_is_memories_trace_summarize_for_wire_compatibility() {
assert_eq!(
MemoriesClient::<DummyTransport, DummyAuth>::path(),
MemoriesClient::<DummyTransport>::path(),
"memories/trace_summarize"
);
}
@@ -178,7 +179,7 @@ mod tests {
let client = MemoriesClient::new(
transport.clone(),
provider("https://example.com/api/codex"),
DummyAuth,
Arc::new(DummyAuth),
);
let input = MemorySummarizeInput {

View File

@@ -1,4 +1,4 @@
use crate::auth::AuthProvider;
use crate::auth::SharedAuthProvider;
use crate::endpoint::session::EndpointSession;
use crate::error::ApiError;
use crate::provider::Provider;
@@ -11,12 +11,12 @@ use http::Method;
use http::header::ETAG;
use std::sync::Arc;
pub struct ModelsClient<T: HttpTransport, A: AuthProvider> {
session: EndpointSession<T, A>,
pub struct ModelsClient<T: HttpTransport> {
session: EndpointSession<T>,
}
impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
impl<T: HttpTransport> ModelsClient<T> {
pub fn new(transport: T, provider: Provider, auth: SharedAuthProvider) -> Self {
Self {
session: EndpointSession::new(transport, provider, auth),
}
@@ -76,6 +76,7 @@ impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::AuthProvider;
use crate::provider::RetryConfig;
use async_trait::async_trait;
use codex_client::Request;
@@ -165,7 +166,7 @@ mod tests {
let client = ModelsClient::new(
transport.clone(),
provider("https://example.com/api/codex"),
DummyAuth,
Arc::new(DummyAuth),
);
let (models, _) = client
@@ -229,7 +230,7 @@ mod tests {
let client = ModelsClient::new(
transport,
provider("https://example.com/api/codex"),
DummyAuth,
Arc::new(DummyAuth),
);
let (models, _) = client
@@ -256,7 +257,7 @@ mod tests {
let client = ModelsClient::new(
transport,
provider("https://example.com/api/codex"),
DummyAuth,
Arc::new(DummyAuth),
);
let (models, etag) = client

View File

@@ -1,4 +1,4 @@
use crate::auth::AuthProvider;
use crate::auth::SharedAuthProvider;
use crate::endpoint::realtime_websocket::RealtimeSessionConfig;
use crate::endpoint::realtime_websocket::session_update_session_json;
use crate::endpoint::session::EndpointSession;
@@ -24,8 +24,8 @@ use tracing::trace;
const MULTIPART_BOUNDARY: &str = "codex-realtime-call-boundary";
const MULTIPART_CONTENT_TYPE: &str = "multipart/form-data; boundary=codex-realtime-call-boundary";
pub struct RealtimeCallClient<T: HttpTransport, A: AuthProvider> {
session: EndpointSession<T, A>,
pub struct RealtimeCallClient<T: HttpTransport> {
session: EndpointSession<T>,
}
/// Answer from creating a WebRTC Realtime call.
@@ -44,8 +44,8 @@ struct BackendRealtimeCallRequest<'a> {
session: &'a Value,
}
impl<T: HttpTransport, A: AuthProvider> RealtimeCallClient<T, A> {
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
impl<T: HttpTransport> RealtimeCallClient<T> {
pub fn new(transport: T, provider: Provider, auth: SharedAuthProvider) -> Self {
Self {
session: EndpointSession::new(transport, provider, auth),
}
@@ -221,6 +221,7 @@ fn decode_call_id_from_location(headers: &HeaderMap) -> Result<String, ApiError>
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::AuthProvider;
use crate::endpoint::realtime_websocket::RealtimeEventParser;
use crate::endpoint::realtime_websocket::RealtimeOutputModality;
use crate::endpoint::realtime_websocket::RealtimeSessionMode;
@@ -327,7 +328,7 @@ mod tests {
let client = RealtimeCallClient::new(
transport.clone(),
provider("https://api.openai.com/v1"),
DummyAuth,
Arc::new(DummyAuth),
);
let response = client
@@ -370,7 +371,7 @@ mod tests {
let client = RealtimeCallClient::new(
transport.clone(),
provider("https://chatgpt.com/backend-api/codex"),
DummyAuth,
Arc::new(DummyAuth),
);
let response = client
@@ -404,7 +405,7 @@ mod tests {
let client = RealtimeCallClient::new(
transport.clone(),
provider("https://api.openai.com/v1"),
DummyAuth,
Arc::new(DummyAuth),
);
let response = client
@@ -466,7 +467,7 @@ mod tests {
let client = RealtimeCallClient::new(
transport.clone(),
provider("https://chatgpt.com/backend-api/codex"),
DummyAuth,
Arc::new(DummyAuth),
);
let response = client
@@ -512,8 +513,11 @@ mod tests {
#[tokio::test]
async fn errors_when_location_is_missing() {
let transport = CapturingTransport::without_location();
let client =
RealtimeCallClient::new(transport, provider("https://api.openai.com/v1"), DummyAuth);
let client = RealtimeCallClient::new(
transport,
provider("https://api.openai.com/v1"),
Arc::new(DummyAuth),
);
let err = client
.create("v=offer\r\n".to_string())

View File

@@ -1,4 +1,4 @@
use crate::auth::AuthProvider;
use crate::auth::SharedAuthProvider;
use crate::common::ResponseStream;
use crate::common::ResponsesApiRequest;
use crate::endpoint::session::EndpointSession;
@@ -23,8 +23,8 @@ use std::sync::Arc;
use std::sync::OnceLock;
use tracing::instrument;
pub struct ResponsesClient<T: HttpTransport, A: AuthProvider> {
session: EndpointSession<T, A>,
pub struct ResponsesClient<T: HttpTransport> {
session: EndpointSession<T>,
sse_telemetry: Option<Arc<dyn SseTelemetry>>,
}
@@ -37,8 +37,8 @@ pub struct ResponsesOptions {
pub turn_state: Option<Arc<OnceLock<String>>>,
}
impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
impl<T: HttpTransport> ResponsesClient<T> {
pub fn new(transport: T, provider: Provider, auth: SharedAuthProvider) -> Self {
Self {
session: EndpointSession::new(transport, provider, auth),
sse_telemetry: None,

View File

@@ -1,4 +1,4 @@
use crate::auth::AuthProvider;
use crate::auth::SharedAuthProvider;
use crate::common::ResponseEvent;
use crate::common::ResponseStream;
use crate::common::ResponsesWsRequest;
@@ -279,13 +279,13 @@ impl ResponsesWebsocketConnection {
}
}
pub struct ResponsesWebsocketClient<A: AuthProvider> {
pub struct ResponsesWebsocketClient {
provider: Provider,
auth: A,
auth: SharedAuthProvider,
}
impl<A: AuthProvider> ResponsesWebsocketClient<A> {
pub fn new(provider: Provider, auth: A) -> Self {
impl ResponsesWebsocketClient {
pub fn new(provider: Provider, auth: SharedAuthProvider) -> Self {
Self { provider, auth }
}

View File

@@ -1,4 +1,4 @@
use crate::auth::AuthProvider;
use crate::auth::SharedAuthProvider;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::telemetry::run_with_request_telemetry;
@@ -14,15 +14,15 @@ use serde_json::Value;
use std::sync::Arc;
use tracing::instrument;
pub(crate) struct EndpointSession<T: HttpTransport, A: AuthProvider> {
pub(crate) struct EndpointSession<T: HttpTransport> {
transport: T,
provider: Provider,
auth: A,
auth: SharedAuthProvider,
request_telemetry: Option<Arc<dyn RequestTelemetry>>,
}
impl<T: HttpTransport, A: AuthProvider> EndpointSession<T, A> {
pub(crate) fn new(transport: T, provider: Provider, auth: A) -> Self {
impl<T: HttpTransport> EndpointSession<T> {
pub(crate) fn new(transport: T, provider: Provider, auth: SharedAuthProvider) -> Self {
Self {
transport,
provider,