This commit is contained in:
celia-oai
2026-04-15 17:31:10 -07:00
parent 025091d2ab
commit 1a1b5f695a
20 changed files with 94 additions and 67 deletions

View File

@@ -30,8 +30,8 @@ pub(crate) async fn run_responses_command(
let base_auth_manager = codex_login::AuthManager::shared_from_config(
&config, /*enable_codex_api_key_env*/ true,
);
let provider = create_model_provider(config.model_provider, Some(base_auth_manager));
let provider_auth = provider.resolve_auth().await?;
let model_provider = create_model_provider(config.model_provider, Some(base_auth_manager));
let provider_auth = model_provider.resolve_auth().await?;
let client = codex_api::ResponsesClient::new(
codex_api::ReqwestTransport::new(codex_login::default_client::build_reqwest_client()),
provider_auth.api_provider,

View File

@@ -1,3 +1,5 @@
use std::sync::Arc;
use http::HeaderMap;
/// Adds authentication headers to API requests.
@@ -7,12 +9,25 @@ use http::HeaderMap;
/// reach this interface.
pub trait AuthProvider: Send + Sync {
fn add_auth_headers(&self, headers: &mut HeaderMap);
}
fn auth_header_attached(&self) -> bool {
false
}
/// Shared auth handle passed through API clients.
pub type SharedAuthProvider = Arc<dyn AuthProvider>;
fn auth_header_name(&self) -> Option<&'static str> {
None
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct AuthHeaderTelemetry {
pub attached: bool,
pub name: Option<&'static str>,
}
pub fn auth_header_telemetry(auth: &dyn AuthProvider) -> AuthHeaderTelemetry {
let mut headers = HeaderMap::new();
auth.add_auth_headers(&mut headers);
let name = headers
.contains_key(http::header::AUTHORIZATION)
.then_some("authorization");
AuthHeaderTelemetry {
attached: name.is_some(),
name,
}
}

View File

@@ -1,4 +1,4 @@
use crate::auth::AuthProvider;
use crate::auth::SharedAuthProvider;
use crate::common::CompactionInput;
use crate::endpoint::session::EndpointSession;
use crate::error::ApiError;
@@ -17,7 +17,7 @@ pub struct CompactClient<T: HttpTransport> {
}
impl<T: HttpTransport> CompactClient<T> {
pub fn new(transport: T, provider: Provider, auth: Arc<dyn AuthProvider>) -> Self {
pub fn new(transport: T, provider: Provider, auth: SharedAuthProvider) -> Self {
Self {
session: EndpointSession::new(transport, provider, auth),
}

View File

@@ -1,4 +1,4 @@
use crate::auth::AuthProvider;
use crate::auth::SharedAuthProvider;
use crate::common::MemorySummarizeInput;
use crate::common::MemorySummarizeOutput;
use crate::endpoint::session::EndpointSession;
@@ -17,7 +17,7 @@ pub struct MemoriesClient<T: HttpTransport> {
}
impl<T: HttpTransport> MemoriesClient<T> {
pub fn new(transport: T, provider: Provider, auth: Arc<dyn AuthProvider>) -> Self {
pub fn new(transport: T, provider: Provider, auth: SharedAuthProvider) -> Self {
Self {
session: EndpointSession::new(transport, provider, auth),
}
@@ -67,6 +67,7 @@ struct SummarizeResponse {
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::AuthProvider;
use crate::common::RawMemory;
use crate::common::RawMemoryMetadata;
use crate::provider::RetryConfig;

View File

@@ -1,4 +1,4 @@
use crate::auth::AuthProvider;
use crate::auth::SharedAuthProvider;
use crate::endpoint::session::EndpointSession;
use crate::error::ApiError;
use crate::provider::Provider;
@@ -16,7 +16,7 @@ pub struct ModelsClient<T: HttpTransport> {
}
impl<T: HttpTransport> ModelsClient<T> {
pub fn new(transport: T, provider: Provider, auth: Arc<dyn AuthProvider>) -> Self {
pub fn new(transport: T, provider: Provider, auth: SharedAuthProvider) -> Self {
Self {
session: EndpointSession::new(transport, provider, auth),
}
@@ -76,6 +76,7 @@ impl<T: HttpTransport> ModelsClient<T> {
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::AuthProvider;
use crate::provider::RetryConfig;
use async_trait::async_trait;
use codex_client::Request;

View File

@@ -1,4 +1,4 @@
use crate::auth::AuthProvider;
use crate::auth::SharedAuthProvider;
use crate::endpoint::realtime_websocket::RealtimeSessionConfig;
use crate::endpoint::realtime_websocket::session_update_session_json;
use crate::endpoint::session::EndpointSession;
@@ -45,7 +45,7 @@ struct BackendRealtimeCallRequest<'a> {
}
impl<T: HttpTransport> RealtimeCallClient<T> {
pub fn new(transport: T, provider: Provider, auth: Arc<dyn AuthProvider>) -> Self {
pub fn new(transport: T, provider: Provider, auth: SharedAuthProvider) -> Self {
Self {
session: EndpointSession::new(transport, provider, auth),
}
@@ -221,6 +221,7 @@ fn decode_call_id_from_location(headers: &HeaderMap) -> Result<String, ApiError>
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::AuthProvider;
use crate::endpoint::realtime_websocket::RealtimeEventParser;
use crate::endpoint::realtime_websocket::RealtimeOutputModality;
use crate::endpoint::realtime_websocket::RealtimeSessionMode;

View File

@@ -1,4 +1,4 @@
use crate::auth::AuthProvider;
use crate::auth::SharedAuthProvider;
use crate::common::ResponseStream;
use crate::common::ResponsesApiRequest;
use crate::endpoint::session::EndpointSession;
@@ -38,7 +38,7 @@ pub struct ResponsesOptions {
}
impl<T: HttpTransport> ResponsesClient<T> {
pub fn new(transport: T, provider: Provider, auth: Arc<dyn AuthProvider>) -> Self {
pub fn new(transport: T, provider: Provider, auth: SharedAuthProvider) -> Self {
Self {
session: EndpointSession::new(transport, provider, auth),
sse_telemetry: None,

View File

@@ -1,4 +1,4 @@
use crate::auth::AuthProvider;
use crate::auth::SharedAuthProvider;
use crate::common::ResponseEvent;
use crate::common::ResponseStream;
use crate::common::ResponsesWsRequest;
@@ -281,11 +281,11 @@ impl ResponsesWebsocketConnection {
pub struct ResponsesWebsocketClient {
provider: Provider,
auth: Arc<dyn AuthProvider>,
auth: SharedAuthProvider,
}
impl ResponsesWebsocketClient {
pub fn new(provider: Provider, auth: Arc<dyn AuthProvider>) -> Self {
pub fn new(provider: Provider, auth: SharedAuthProvider) -> Self {
Self { provider, auth }
}

View File

@@ -1,4 +1,4 @@
use crate::auth::AuthProvider;
use crate::auth::SharedAuthProvider;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::telemetry::run_with_request_telemetry;
@@ -17,12 +17,12 @@ use tracing::instrument;
pub(crate) struct EndpointSession<T: HttpTransport> {
transport: T,
provider: Provider,
auth: Arc<dyn AuthProvider>,
auth: SharedAuthProvider,
request_telemetry: Option<Arc<dyn RequestTelemetry>>,
}
impl<T: HttpTransport> EndpointSession<T> {
pub(crate) fn new(transport: T, provider: Provider, auth: Arc<dyn AuthProvider>) -> Self {
pub(crate) fn new(transport: T, provider: Provider, auth: SharedAuthProvider) -> Self {
Self {
transport,
provider,

View File

@@ -96,7 +96,7 @@ pub fn openai_file_uri(file_id: &str) -> String {
pub async fn upload_local_file(
base_url: &str,
auth: &impl AuthProvider,
auth: &dyn AuthProvider,
path: &Path,
) -> Result<UploadedOpenAiFile, OpenAiFileError> {
let metadata = tokio::fs::metadata(path)
@@ -252,7 +252,7 @@ pub async fn upload_local_file(
}
fn authorized_request(
auth: &impl AuthProvider,
auth: &dyn AuthProvider,
method: reqwest::Method,
url: &str,
) -> reqwest::RequestBuilder {

View File

@@ -16,7 +16,10 @@ pub use codex_client::ReqwestTransport;
pub use codex_client::TransportError;
pub use crate::api_bridge::map_api_error;
pub use crate::auth::AuthHeaderTelemetry;
pub use crate::auth::AuthProvider;
pub use crate::auth::SharedAuthProvider;
pub use crate::auth::auth_header_telemetry;
pub use crate::common::CompactionInput;
pub use crate::common::MemorySummarizeInput;
pub use crate::common::MemorySummarizeOutput;

View File

@@ -52,9 +52,11 @@ use codex_api::ResponsesOptions as ApiResponsesOptions;
use codex_api::ResponsesWebsocketClient as ApiWebSocketResponsesClient;
use codex_api::ResponsesWebsocketConnection as ApiWebSocketConnection;
use codex_api::ResponsesWsRequest;
use codex_api::SharedAuthProvider;
use codex_api::SseTelemetry;
use codex_api::TransportError;
use codex_api::WebsocketTelemetry;
use codex_api::auth_header_telemetry;
use codex_api::build_conversation_headers;
use codex_api::create_text_param_for_request;
use codex_api::response_create_client_metadata;
@@ -280,26 +282,26 @@ impl ModelClient {
auth_manager: Option<Arc<AuthManager>>,
conversation_id: ThreadId,
installation_id: String,
provider: ModelProviderInfo,
provider_info: ModelProviderInfo,
session_source: SessionSource,
model_verbosity: Option<VerbosityConfig>,
enable_request_compression: bool,
include_timing_metrics: bool,
beta_features_header: Option<String>,
) -> Self {
let provider = create_model_provider(provider, auth_manager);
let codex_api_key_env_enabled = provider
let model_provider = create_model_provider(provider_info, auth_manager);
let codex_api_key_env_enabled = model_provider
.auth_manager()
.as_ref()
.is_some_and(|manager| manager.codex_api_key_env_enabled());
let auth_env_telemetry =
collect_auth_env_telemetry(provider.info(), codex_api_key_env_enabled);
collect_auth_env_telemetry(model_provider.info(), codex_api_key_env_enabled);
Self {
state: Arc::new(ModelClientState {
conversation_id,
window_generation: AtomicU64::new(0),
installation_id,
provider,
provider: model_provider,
auth_env_telemetry,
session_source,
model_verbosity,
@@ -655,7 +657,7 @@ impl ModelClient {
&self,
session_telemetry: &SessionTelemetry,
api_provider: codex_api::Provider,
api_auth: Arc<dyn AuthProvider>,
api_auth: SharedAuthProvider,
turn_state: Option<Arc<OnceLock<String>>>,
turn_metadata_header: Option<&str>,
auth_context: AuthRequestTelemetryContext,
@@ -1662,13 +1664,14 @@ impl AuthRequestTelemetryContext {
api_auth: &dyn AuthProvider,
retry: PendingUnauthorizedRetry,
) -> Self {
let auth_telemetry = auth_header_telemetry(api_auth);
Self {
auth_mode: auth_mode.map(|mode| match mode {
AuthMode::ApiKey => "ApiKey",
AuthMode::Chatgpt | AuthMode::ChatgptAuthTokens => "Chatgpt",
}),
auth_header_attached: api_auth.auth_header_attached(),
auth_header_name: api_auth.auth_header_name(),
auth_header_attached: auth_telemetry.attached,
auth_header_name: auth_telemetry.name,
retry_after_unauthorized: retry.retry_after_unauthorized,
recovery_mode: retry.recovery_mode,
recovery_phase: retry.recovery_phase,
@@ -1679,7 +1682,7 @@ impl AuthRequestTelemetryContext {
struct WebsocketConnectParams<'a> {
session_telemetry: &'a SessionTelemetry,
api_provider: codex_api::Provider,
api_auth: Arc<dyn AuthProvider>,
api_auth: SharedAuthProvider,
turn_metadata_header: Option<&'a str>,
options: &'a ApiResponsesOptions,
auth_context: AuthRequestTelemetryContext,

View File

@@ -1570,7 +1570,7 @@ impl Session {
conversation_id: ThreadId,
auth_manager: Option<Arc<AuthManager>>,
session_telemetry: &SessionTelemetry,
provider: ModelProviderInfo,
provider_info: ModelProviderInfo,
session_configuration: &SessionConfiguration,
user_shell: &shell::Shell,
shell_zsh_path: Option<&PathBuf>,
@@ -1596,7 +1596,7 @@ impl Session {
let image_generation_tool_auth_allowed =
image_generation_tool_auth_allowed(auth_manager.as_deref());
let auth_manager_for_context = auth_manager.clone();
let provider_for_context = create_model_provider(provider, auth_manager);
let provider_for_context = create_model_provider(provider_info, auth_manager);
let session_telemetry_for_context = session_telemetry;
let tools_config = ToolsConfig::new(&ToolsConfigParams {
model_info: &model_info,

View File

@@ -58,6 +58,7 @@ pub(crate) enum InitialContextInjection {
DoNotInject,
}
// TODO(celia-oai): Move this onto ModelProvider crate.
pub(crate) fn should_use_remote_compact_task(provider: &dyn ModelProvider) -> bool {
provider.info().supports_remote_compaction()
}

View File

@@ -190,7 +190,7 @@ fn build_token_limited_compacted_history_appends_summary_message() {
#[test]
fn should_use_remote_compact_task_for_azure_provider() {
let provider = create_model_provider(
let model_provider = create_model_provider(
ModelProviderInfo {
name: "Azure".into(),
base_url: Some("https://example.com/openai".into()),
@@ -212,7 +212,7 @@ fn should_use_remote_compact_task_for_azure_provider() {
/*auth_manager*/ None,
);
assert!(should_use_remote_compact_task(provider.as_ref()));
assert!(should_use_remote_compact_task(model_provider.as_ref()));
}
#[tokio::test]

View File

@@ -354,10 +354,10 @@ async fn spawn_agent_uses_explorer_role_and_preserves_approval_policy() {
let manager = thread_manager();
session.services.agent_control = manager.agent_control();
let mut config = (*turn.config).clone();
let provider =
let provider_info =
built_in_model_providers(/* openai_base_url */ /*openai_base_url*/ None)["ollama"].clone();
config.model_provider_id = "ollama".to_string();
config.model_provider = provider.clone();
config.model_provider = provider_info.clone();
config
.permissions
.approval_policy
@@ -366,7 +366,7 @@ async fn spawn_agent_uses_explorer_role_and_preserves_approval_policy() {
turn.approval_policy
.set(AskForApproval::OnRequest)
.expect("approval policy should be set");
turn.provider = create_model_provider(provider, turn.auth_manager.clone());
turn.provider = create_model_provider(provider_info, turn.auth_manager.clone());
turn.config = Arc::new(config);
let invocation = invocation(

View File

@@ -1,6 +1,6 @@
use std::sync::Arc;
use codex_api::AuthProvider;
use codex_api::SharedAuthProvider;
use codex_login::AuthManager;
use codex_login::CodexAuth;
use codex_model_provider_info::ModelProviderInfo;
@@ -55,7 +55,6 @@ fn bearer_auth_provider_from_auth(
pub(crate) fn resolve_provider_auth(
auth: Option<CodexAuth>,
provider: &ModelProviderInfo,
) -> codex_protocol::error::Result<Arc<dyn AuthProvider>> {
let api_auth = bearer_auth_provider_from_auth(auth, provider)?;
Ok(Arc::new(api_auth))
) -> codex_protocol::error::Result<SharedAuthProvider> {
Ok(Arc::new(bearer_auth_provider_from_auth(auth, provider)?))
}

View File

@@ -31,16 +31,6 @@ impl AuthProvider for BearerAuthProvider {
let _ = headers.insert("ChatGPT-Account-ID", header);
}
}
fn auth_header_attached(&self) -> bool {
self.token
.as_ref()
.is_some_and(|token| HeaderValue::from_str(&format!("Bearer {token}")).is_ok())
}
fn auth_header_name(&self) -> Option<&'static str> {
self.auth_header_attached().then_some("authorization")
}
}
#[cfg(test)]
@@ -55,8 +45,13 @@ mod tests {
account_id: None,
};
assert!(auth.auth_header_attached());
assert_eq!(auth.auth_header_name(), Some("authorization"));
assert_eq!(
codex_api::auth_header_telemetry(&auth),
codex_api::AuthHeaderTelemetry {
attached: true,
name: Some("authorization"),
}
);
}
#[test]

View File

@@ -2,8 +2,8 @@ use std::fmt;
use std::sync::Arc;
use async_trait::async_trait;
use codex_api::AuthProvider;
use codex_api::Provider;
use codex_api::SharedAuthProvider;
use codex_login::AuthManager;
use codex_login::CodexAuth;
use codex_model_provider_info::ModelProviderInfo;
@@ -38,17 +38,20 @@ pub struct ResolvedProviderAuth {
/// Provider configuration adapted for the API client.
pub api_provider: Provider,
/// Auth provider used to attach request credentials.
pub api_auth: Arc<dyn AuthProvider>,
pub api_auth: SharedAuthProvider,
}
/// Creates the default runtime model provider for configured provider metadata.
pub fn create_model_provider(
info: ModelProviderInfo,
provider_info: ModelProviderInfo,
auth_manager: Option<Arc<AuthManager>>,
) -> SharedModelProvider {
let auth_manager =
auth_manager.map(|auth_manager| auth_manager_for_provider(auth_manager, &info));
Arc::new(ConfiguredModelProvider { info, auth_manager })
auth_manager.map(|auth_manager| auth_manager_for_provider(auth_manager, &provider_info));
Arc::new(ConfiguredModelProvider {
info: provider_info,
auth_manager,
})
}
/// Runtime model provider backed by configured `ModelProviderInfo`.

View File

@@ -7,6 +7,7 @@ use codex_api::ModelsClient;
use codex_api::RequestTelemetry;
use codex_api::ReqwestTransport;
use codex_api::TransportError;
use codex_api::auth_header_telemetry;
use codex_api::map_api_error;
use codex_app_server_protocol::AuthMode;
use codex_feedback::FeedbackRequestTags;
@@ -205,14 +206,17 @@ impl ModelsManager {
}
/// Construct a manager with an explicit provider used for remote model refreshes.
// TODO(celia-oai): Revisit this ownership direction: the model provider should likely
// own or return the models manager instead of requiring the manager to construct and use
// a provider from provider info.
pub fn new_with_provider(
codex_home: PathBuf,
auth_manager: Arc<AuthManager>,
model_catalog: Option<ModelsResponse>,
collaboration_modes_config: CollaborationModesConfig,
provider: ModelProviderInfo,
provider_info: ModelProviderInfo,
) -> Self {
let provider = create_model_provider(provider, Some(auth_manager));
let model_provider = create_model_provider(provider_info, Some(auth_manager));
let cache_path = codex_home.join(MODEL_CACHE_FILE);
let cache_manager = ModelsCacheManager::new(cache_path, DEFAULT_MODEL_CACHE_TTL);
let catalog_mode = if model_catalog.is_some() {
@@ -229,7 +233,7 @@ impl ModelsManager {
collaboration_modes_config,
etag: RwLock::new(None),
cache_manager,
provider,
provider: model_provider,
}
}
@@ -440,10 +444,11 @@ impl ModelsManager {
let auth_mode = provider_auth.auth.as_ref().map(CodexAuth::auth_mode);
let auth_env = collect_auth_env_telemetry(self.provider.info(), codex_api_key_env_enabled);
let transport = ReqwestTransport::new(build_reqwest_client());
let auth_telemetry = auth_header_telemetry(provider_auth.api_auth.as_ref());
let request_telemetry: Arc<dyn RequestTelemetry> = Arc::new(ModelsRequestTelemetry {
auth_mode: auth_mode.map(|mode| TelemetryAuthMode::from(mode).to_string()),
auth_header_attached: provider_auth.api_auth.auth_header_attached(),
auth_header_name: provider_auth.api_auth.auth_header_name(),
auth_header_attached: auth_telemetry.attached,
auth_header_name: auth_telemetry.name,
auth_env,
});
let client = ModelsClient::new(