mirror of
https://github.com/openai/codex.git
synced 2026-04-14 01:35:00 +00:00
Compare commits
5 Commits
dev/shaqay
...
dev/cc/pro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a033ca9721 | ||
|
|
936f574dd5 | ||
|
|
46aaf95a4f | ||
|
|
d16f64936e | ||
|
|
6b5dcb9afd |
16
codex-rs/Cargo.lock
generated
16
codex-rs/Cargo.lock
generated
@@ -1877,6 +1877,7 @@ dependencies = [
|
||||
"codex-instructions",
|
||||
"codex-login",
|
||||
"codex-mcp",
|
||||
"codex-model-provider",
|
||||
"codex-model-provider-info",
|
||||
"codex-models-manager",
|
||||
"codex-network-proxy",
|
||||
@@ -2265,10 +2266,12 @@ dependencies = [
|
||||
"codex-client",
|
||||
"codex-config",
|
||||
"codex-keyring-store",
|
||||
"codex-model-provider",
|
||||
"codex-model-provider-info",
|
||||
"codex-otel",
|
||||
"codex-protocol",
|
||||
"codex-terminal-detection",
|
||||
"codex-utils-absolute-path",
|
||||
"codex-utils-template",
|
||||
"core_test_support",
|
||||
"keyring",
|
||||
@@ -2354,6 +2357,19 @@ dependencies = [
|
||||
"wiremock",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-model-provider"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"codex-api",
|
||||
"codex-app-server-protocol",
|
||||
"codex-model-provider-info",
|
||||
"codex-protocol",
|
||||
"codex-utils-absolute-path",
|
||||
"pretty_assertions",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-model-provider-info"
|
||||
version = "0.0.0"
|
||||
|
||||
@@ -42,6 +42,7 @@ members = [
|
||||
"login",
|
||||
"codex-mcp",
|
||||
"mcp-server",
|
||||
"model-provider",
|
||||
"model-provider-info",
|
||||
"models-manager",
|
||||
"network-proxy",
|
||||
@@ -141,6 +142,7 @@ codex-lmstudio = { path = "lmstudio" }
|
||||
codex-login = { path = "login" }
|
||||
codex-mcp = { path = "codex-mcp" }
|
||||
codex-mcp-server = { path = "mcp-server" }
|
||||
codex-model-provider = { path = "model-provider" }
|
||||
codex-model-provider-info = { path = "model-provider-info" }
|
||||
codex-models-manager = { path = "models-manager" }
|
||||
codex-network-proxy = { path = "network-proxy" }
|
||||
|
||||
@@ -39,6 +39,7 @@ codex-features = { workspace = true }
|
||||
codex-feedback = { workspace = true }
|
||||
codex-login = { workspace = true }
|
||||
codex-mcp = { workspace = true }
|
||||
codex-model-provider = { workspace = true }
|
||||
codex-model-provider-info = { workspace = true }
|
||||
codex-models-manager = { workspace = true }
|
||||
codex-shell-command = { workspace = true }
|
||||
|
||||
@@ -103,9 +103,12 @@ use codex_api::map_api_error;
|
||||
use codex_feedback::FeedbackRequestTags;
|
||||
use codex_feedback::emit_feedback_request_tags_with_auth_env;
|
||||
use codex_login::api_bridge::auth_provider_from_auth;
|
||||
use codex_login::api_bridge::auth_provider_from_runtime;
|
||||
use codex_login::auth_env_telemetry::AuthEnvTelemetry;
|
||||
use codex_login::auth_env_telemetry::collect_auth_env_telemetry;
|
||||
use codex_login::provider_auth::auth_manager_for_provider;
|
||||
use codex_login::auth_env_telemetry::collect_auth_env_telemetry_for_runtime;
|
||||
use codex_login::provider_auth::auth_manager_for_provider_runtime;
|
||||
use codex_model_provider::ProviderRuntime;
|
||||
use codex_model_provider::ResolvedModelProvider;
|
||||
#[cfg(test)]
|
||||
use codex_model_provider_info::DEFAULT_WEBSOCKET_CONNECT_TIMEOUT_MS;
|
||||
use codex_model_provider_info::ModelProviderInfo;
|
||||
@@ -145,6 +148,7 @@ struct ModelClientState {
|
||||
window_generation: AtomicU64,
|
||||
installation_id: String,
|
||||
provider: ModelProviderInfo,
|
||||
provider_runtime: ProviderRuntime,
|
||||
auth_env_telemetry: AuthEnvTelemetry,
|
||||
session_source: SessionSource,
|
||||
model_verbosity: Option<VerbosityConfig>,
|
||||
@@ -267,17 +271,23 @@ impl ModelClient {
|
||||
conversation_id: ThreadId,
|
||||
installation_id: String,
|
||||
provider: ModelProviderInfo,
|
||||
provider_runtime: ProviderRuntime,
|
||||
session_source: SessionSource,
|
||||
model_verbosity: Option<VerbosityConfig>,
|
||||
enable_request_compression: bool,
|
||||
include_timing_metrics: bool,
|
||||
beta_features_header: Option<String>,
|
||||
) -> Self {
|
||||
let auth_manager = auth_manager_for_provider(auth_manager, &provider);
|
||||
let auth_manager =
|
||||
auth_manager_for_provider_runtime(auth_manager, &provider_runtime, &provider);
|
||||
let codex_api_key_env_enabled = auth_manager
|
||||
.as_ref()
|
||||
.is_some_and(|manager| manager.codex_api_key_env_enabled());
|
||||
let auth_env_telemetry = collect_auth_env_telemetry(&provider, codex_api_key_env_enabled);
|
||||
let auth_env_telemetry = collect_auth_env_telemetry_for_runtime(
|
||||
&provider_runtime,
|
||||
&provider,
|
||||
codex_api_key_env_enabled,
|
||||
);
|
||||
Self {
|
||||
state: Arc::new(ModelClientState {
|
||||
auth_manager,
|
||||
@@ -285,6 +295,7 @@ impl ModelClient {
|
||||
window_generation: AtomicU64::new(0),
|
||||
installation_id,
|
||||
provider,
|
||||
provider_runtime,
|
||||
auth_env_telemetry,
|
||||
session_source,
|
||||
model_verbosity,
|
||||
@@ -610,6 +621,16 @@ impl ModelClient {
|
||||
Some(manager) => manager.auth().await,
|
||||
None => None,
|
||||
};
|
||||
|
||||
match &self.state.provider_runtime {
|
||||
ProviderRuntime::Legacy => self.current_client_setup_legacy(auth),
|
||||
ProviderRuntime::Resolved(provider) => {
|
||||
self.current_client_setup_resolved(auth, provider)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn current_client_setup_legacy(&self, auth: Option<CodexAuth>) -> Result<CurrentClientSetup> {
|
||||
let api_provider = self
|
||||
.state
|
||||
.provider
|
||||
@@ -622,6 +643,25 @@ impl ModelClient {
|
||||
})
|
||||
}
|
||||
|
||||
fn current_client_setup_resolved(
|
||||
&self,
|
||||
auth: Option<CodexAuth>,
|
||||
provider: &ResolvedModelProvider,
|
||||
) -> Result<CurrentClientSetup> {
|
||||
let api_provider =
|
||||
provider.to_legacy_api_provider(auth.as_ref().map(CodexAuth::auth_mode))?;
|
||||
let api_auth = auth_provider_from_runtime(
|
||||
auth.clone(),
|
||||
&self.state.provider_runtime,
|
||||
&self.state.provider,
|
||||
)?;
|
||||
Ok(CurrentClientSetup {
|
||||
auth,
|
||||
api_provider,
|
||||
api_auth,
|
||||
})
|
||||
}
|
||||
|
||||
/// Opens a websocket connection using the same header and telemetry wiring as normal turns.
|
||||
///
|
||||
/// Both startup prewarm and in-turn `needs_new` reconnects call this path so handshake
|
||||
|
||||
@@ -9,6 +9,9 @@ use super::X_CODEX_WINDOW_ID_HEADER;
|
||||
use super::X_OPENAI_SUBAGENT_HEADER;
|
||||
use codex_api::CoreAuthProvider;
|
||||
use codex_app_server_protocol::AuthMode;
|
||||
use codex_model_provider::ProviderResolutionPolicy;
|
||||
use codex_model_provider::ProviderRuntime;
|
||||
use codex_model_provider::resolve_model_provider;
|
||||
use codex_model_provider_info::WireApi;
|
||||
use codex_model_provider_info::create_oss_provider_with_base_url;
|
||||
use codex_otel::SessionTelemetry;
|
||||
@@ -26,6 +29,7 @@ fn test_model_client(session_source: SessionSource) -> ModelClient {
|
||||
ThreadId::new(),
|
||||
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
|
||||
provider,
|
||||
ProviderRuntime::Legacy,
|
||||
session_source,
|
||||
/*model_verbosity*/ None,
|
||||
/*enable_request_compression*/ false,
|
||||
@@ -79,6 +83,77 @@ fn test_session_telemetry() -> SessionTelemetry {
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn current_client_setup_non_allowlisted_env_key_provider_uses_env_bearer() {
|
||||
let mut provider =
|
||||
create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses);
|
||||
provider.env_key = Some("PATH".to_string());
|
||||
let expected_token = std::env::var("PATH").expect("PATH should be set for tests");
|
||||
let client = ModelClient::new(
|
||||
/*auth_manager*/ None,
|
||||
ThreadId::new(),
|
||||
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
|
||||
provider,
|
||||
ProviderRuntime::Legacy,
|
||||
SessionSource::Cli,
|
||||
/*model_verbosity*/ None,
|
||||
/*enable_request_compression*/ false,
|
||||
/*include_timing_metrics*/ false,
|
||||
/*beta_features_header*/ None,
|
||||
);
|
||||
|
||||
let setup = client
|
||||
.current_client_setup()
|
||||
.await
|
||||
.expect("client setup should succeed");
|
||||
|
||||
assert_eq!(
|
||||
setup.api_auth.token.as_deref(),
|
||||
Some(expected_token.as_str())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn current_client_setup_uses_resolved_runtime_provider() {
|
||||
let legacy_provider =
|
||||
create_oss_provider_with_base_url("https://legacy.example.com/v1", WireApi::Responses);
|
||||
let mut resolved_provider =
|
||||
create_oss_provider_with_base_url("https://resolved.example.com/v1", WireApi::Responses);
|
||||
resolved_provider.experimental_bearer_token = Some("resolved-token".to_string());
|
||||
let mut runtime = resolve_model_provider(
|
||||
"custom",
|
||||
&resolved_provider,
|
||||
&ProviderResolutionPolicy::with_enabled_provider_ids([String::from("custom")]),
|
||||
);
|
||||
let ProviderRuntime::Resolved(provider) = &mut runtime else {
|
||||
panic!("enabled provider should resolve through the provider framework");
|
||||
};
|
||||
provider.info.experimental_bearer_token = None;
|
||||
let client = ModelClient::new(
|
||||
/*auth_manager*/ None,
|
||||
ThreadId::new(),
|
||||
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
|
||||
legacy_provider,
|
||||
runtime,
|
||||
SessionSource::Cli,
|
||||
/*model_verbosity*/ None,
|
||||
/*enable_request_compression*/ false,
|
||||
/*include_timing_metrics*/ false,
|
||||
/*beta_features_header*/ None,
|
||||
);
|
||||
|
||||
let setup = client
|
||||
.current_client_setup()
|
||||
.await
|
||||
.expect("client setup should succeed");
|
||||
|
||||
assert_eq!(
|
||||
setup.api_provider.base_url,
|
||||
"https://resolved.example.com/v1"
|
||||
);
|
||||
assert_eq!(setup.api_auth.token.as_deref(), Some("resolved-token"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_subagent_headers_sets_other_subagent_label() {
|
||||
let client = test_model_client(SessionSource::SubAgent(SubAgentSource::Other(
|
||||
|
||||
@@ -189,6 +189,7 @@ use crate::environment_context::EnvironmentContext;
|
||||
use codex_config::CONFIG_TOML_FILE;
|
||||
use codex_config::types::McpServerConfig;
|
||||
use codex_config::types::ShellEnvironmentPolicy;
|
||||
use codex_model_provider::ProviderRuntime;
|
||||
use codex_model_provider_info::ModelProviderInfo;
|
||||
use codex_protocol::error::CodexErr;
|
||||
use codex_protocol::error::Result as CodexResult;
|
||||
@@ -636,6 +637,7 @@ impl Codex {
|
||||
};
|
||||
let session_configuration = SessionConfiguration {
|
||||
provider: config.model_provider.clone(),
|
||||
provider_runtime: config.provider_runtime.clone(),
|
||||
collaboration_mode,
|
||||
model_reasoning_summary: config.model_reasoning_summary,
|
||||
service_tier: config.service_tier,
|
||||
@@ -1097,6 +1099,7 @@ fn local_time_context() -> (String, String) {
|
||||
pub(crate) struct SessionConfiguration {
|
||||
/// Provider identifier ("openai", "openrouter", ...).
|
||||
provider: ModelProviderInfo,
|
||||
provider_runtime: ProviderRuntime,
|
||||
|
||||
collaboration_mode: CollaborationMode,
|
||||
model_reasoning_summary: Option<ReasoningSummaryConfig>,
|
||||
@@ -1974,6 +1977,7 @@ impl Session {
|
||||
conversation_id,
|
||||
installation_id,
|
||||
session_configuration.provider.clone(),
|
||||
session_configuration.provider_runtime.clone(),
|
||||
session_configuration.session_source.clone(),
|
||||
config.model_verbosity,
|
||||
config.features.enabled(Feature::EnableRequestCompression),
|
||||
|
||||
@@ -257,6 +257,7 @@ fn test_model_client_session() -> crate::client::ModelClientSession {
|
||||
.expect("test thread id should be valid"),
|
||||
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
|
||||
ModelProviderInfo::create_openai_provider(/* base_url */ /*base_url*/ None),
|
||||
codex_model_provider::ProviderRuntime::Legacy,
|
||||
codex_protocol::protocol::SessionSource::Exec,
|
||||
/*model_verbosity*/ None,
|
||||
/*enable_request_compression*/ false,
|
||||
@@ -1851,6 +1852,7 @@ async fn set_rate_limits_retains_previous_credits() {
|
||||
};
|
||||
let session_configuration = SessionConfiguration {
|
||||
provider: config.model_provider.clone(),
|
||||
provider_runtime: config.provider_runtime.clone(),
|
||||
collaboration_mode,
|
||||
model_reasoning_summary: config.model_reasoning_summary,
|
||||
developer_instructions: config.developer_instructions.clone(),
|
||||
@@ -1954,6 +1956,7 @@ async fn set_rate_limits_updates_plan_type_when_present() {
|
||||
};
|
||||
let session_configuration = SessionConfiguration {
|
||||
provider: config.model_provider.clone(),
|
||||
provider_runtime: config.provider_runtime.clone(),
|
||||
collaboration_mode,
|
||||
model_reasoning_summary: config.model_reasoning_summary,
|
||||
developer_instructions: config.developer_instructions.clone(),
|
||||
@@ -2305,6 +2308,7 @@ pub(crate) async fn make_session_configuration_for_tests() -> SessionConfigurati
|
||||
|
||||
SessionConfiguration {
|
||||
provider: config.model_provider.clone(),
|
||||
provider_runtime: config.provider_runtime.clone(),
|
||||
collaboration_mode,
|
||||
model_reasoning_summary: config.model_reasoning_summary,
|
||||
developer_instructions: config.developer_instructions.clone(),
|
||||
@@ -2569,6 +2573,7 @@ async fn session_new_fails_when_zsh_fork_enabled_without_zsh_path() {
|
||||
};
|
||||
let session_configuration = SessionConfiguration {
|
||||
provider: config.model_provider.clone(),
|
||||
provider_runtime: config.provider_runtime.clone(),
|
||||
collaboration_mode,
|
||||
model_reasoning_summary: config.model_reasoning_summary,
|
||||
developer_instructions: config.developer_instructions.clone(),
|
||||
@@ -2673,6 +2678,7 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) {
|
||||
};
|
||||
let session_configuration = SessionConfiguration {
|
||||
provider: config.model_provider.clone(),
|
||||
provider_runtime: config.provider_runtime.clone(),
|
||||
collaboration_mode,
|
||||
model_reasoning_summary: config.model_reasoning_summary,
|
||||
developer_instructions: config.developer_instructions.clone(),
|
||||
@@ -2772,6 +2778,7 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) {
|
||||
conversation_id,
|
||||
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
|
||||
session_configuration.provider.clone(),
|
||||
session_configuration.provider_runtime.clone(),
|
||||
session_configuration.session_source.clone(),
|
||||
config.model_verbosity,
|
||||
config.features.enabled(Feature::EnableRequestCompression),
|
||||
@@ -3515,6 +3522,7 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx(
|
||||
};
|
||||
let session_configuration = SessionConfiguration {
|
||||
provider: config.model_provider.clone(),
|
||||
provider_runtime: config.provider_runtime.clone(),
|
||||
collaboration_mode,
|
||||
model_reasoning_summary: config.model_reasoning_summary,
|
||||
developer_instructions: config.developer_instructions.clone(),
|
||||
@@ -3614,6 +3622,7 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx(
|
||||
conversation_id,
|
||||
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
|
||||
session_configuration.provider.clone(),
|
||||
session_configuration.provider_runtime.clone(),
|
||||
session_configuration.session_source.clone(),
|
||||
config.model_verbosity,
|
||||
config.features.enabled(Feature::EnableRequestCompression),
|
||||
|
||||
@@ -290,6 +290,26 @@ command = "print-token"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_framework_config_allowlist_does_not_enable_runtime() -> std::io::Result<()> {
|
||||
let cfg = toml::from_str::<ConfigToml>(
|
||||
r#"
|
||||
[experimental_provider_framework]
|
||||
enabled_model_providers = ["openai"]
|
||||
"#,
|
||||
)
|
||||
.expect("unknown provider framework config should not fail to deserialize");
|
||||
|
||||
let config = Config::load_from_base_config_with_overrides(
|
||||
cfg,
|
||||
ConfigOverrides::default(),
|
||||
tempdir()?.abs().into_path_buf(),
|
||||
)?;
|
||||
|
||||
assert_eq!(config.provider_runtime, ProviderRuntime::Legacy);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_toml_deserializes_model_availability_nux() {
|
||||
let toml = r#"
|
||||
@@ -4441,6 +4461,7 @@ fn test_precedence_fixture_with_o3_profile() -> std::io::Result<()> {
|
||||
service_tier: None,
|
||||
model_provider_id: "openai".to_string(),
|
||||
model_provider: fixture.openai_provider.clone(),
|
||||
provider_runtime: ProviderRuntime::Legacy,
|
||||
permissions: Permissions {
|
||||
approval_policy: Constrained::allow_any(AskForApproval::Never),
|
||||
sandbox_policy: Constrained::allow_any(SandboxPolicy::new_read_only_policy()),
|
||||
@@ -4587,6 +4608,7 @@ fn test_precedence_fixture_with_gpt3_profile() -> std::io::Result<()> {
|
||||
service_tier: None,
|
||||
model_provider_id: "openai-custom".to_string(),
|
||||
model_provider: fixture.openai_custom_provider.clone(),
|
||||
provider_runtime: ProviderRuntime::Legacy,
|
||||
permissions: Permissions {
|
||||
approval_policy: Constrained::allow_any(AskForApproval::UnlessTrusted),
|
||||
sandbox_policy: Constrained::allow_any(SandboxPolicy::new_read_only_policy()),
|
||||
@@ -4731,6 +4753,7 @@ fn test_precedence_fixture_with_zdr_profile() -> std::io::Result<()> {
|
||||
service_tier: None,
|
||||
model_provider_id: "openai".to_string(),
|
||||
model_provider: fixture.openai_provider.clone(),
|
||||
provider_runtime: ProviderRuntime::Legacy,
|
||||
permissions: Permissions {
|
||||
approval_policy: Constrained::allow_any(AskForApproval::OnFailure),
|
||||
sandbox_policy: Constrained::allow_any(SandboxPolicy::new_read_only_policy()),
|
||||
@@ -4861,6 +4884,7 @@ fn test_precedence_fixture_with_gpt5_profile() -> std::io::Result<()> {
|
||||
service_tier: None,
|
||||
model_provider_id: "openai".to_string(),
|
||||
model_provider: fixture.openai_provider.clone(),
|
||||
provider_runtime: ProviderRuntime::Legacy,
|
||||
permissions: Permissions {
|
||||
approval_policy: Constrained::allow_any(AskForApproval::OnFailure),
|
||||
sandbox_policy: Constrained::allow_any(SandboxPolicy::new_read_only_policy()),
|
||||
|
||||
@@ -55,6 +55,9 @@ use codex_features::FeatureOverrides;
|
||||
use codex_features::Features;
|
||||
use codex_login::AuthManagerConfig;
|
||||
use codex_mcp::McpConfig;
|
||||
use codex_model_provider::ProviderRuntime;
|
||||
use codex_model_provider::production_provider_resolution_policy;
|
||||
use codex_model_provider::resolve_model_provider;
|
||||
use codex_model_provider_info::LEGACY_OLLAMA_CHAT_PROVIDER_ID;
|
||||
use codex_model_provider_info::ModelProviderInfo;
|
||||
use codex_model_provider_info::OLLAMA_CHAT_PROVIDER_REMOVED_ERROR;
|
||||
@@ -214,6 +217,9 @@ pub struct Config {
|
||||
/// Info needed to make an API request to the model.
|
||||
pub model_provider: ModelProviderInfo,
|
||||
|
||||
/// Runtime provider strategy. This remains legacy unless explicitly opted in.
|
||||
pub provider_runtime: ProviderRuntime,
|
||||
|
||||
/// Optionally specify the personality of the model
|
||||
pub personality: Option<Personality>,
|
||||
|
||||
@@ -1641,6 +1647,11 @@ impl Config {
|
||||
std::io::Error::new(std::io::ErrorKind::NotFound, message)
|
||||
})?
|
||||
.clone();
|
||||
let provider_runtime = resolve_model_provider(
|
||||
&model_provider_id,
|
||||
&model_provider,
|
||||
&production_provider_resolution_policy(),
|
||||
);
|
||||
|
||||
let shell_environment_policy = cfg.shell_environment_policy.into();
|
||||
let allow_login_shell = cfg.allow_login_shell.unwrap_or(true);
|
||||
@@ -1930,6 +1941,7 @@ impl Config {
|
||||
model_auto_compact_token_limit: cfg.model_auto_compact_token_limit,
|
||||
model_provider_id,
|
||||
model_provider,
|
||||
provider_runtime,
|
||||
cwd: resolved_cwd,
|
||||
startup_warnings,
|
||||
permissions: Permissions {
|
||||
|
||||
@@ -5,6 +5,7 @@ use codex_core::ModelClient;
|
||||
use codex_core::Prompt;
|
||||
use codex_core::ResponseEvent;
|
||||
use codex_login::CodexAuth;
|
||||
use codex_model_provider::ProviderRuntime;
|
||||
use codex_model_provider_info::ModelProviderInfo;
|
||||
use codex_model_provider_info::WireApi;
|
||||
use codex_otel::SessionTelemetry;
|
||||
@@ -102,6 +103,7 @@ async fn responses_stream_includes_subagent_header_on_review() {
|
||||
conversation_id,
|
||||
/*installation_id*/ TEST_INSTALLATION_ID.to_string(),
|
||||
provider.clone(),
|
||||
ProviderRuntime::Legacy,
|
||||
session_source,
|
||||
config.model_verbosity,
|
||||
/*enable_request_compression*/ false,
|
||||
@@ -227,6 +229,7 @@ async fn responses_stream_includes_subagent_header_on_other() {
|
||||
conversation_id,
|
||||
/*installation_id*/ TEST_INSTALLATION_ID.to_string(),
|
||||
provider.clone(),
|
||||
ProviderRuntime::Legacy,
|
||||
session_source,
|
||||
config.model_verbosity,
|
||||
/*enable_request_compression*/ false,
|
||||
@@ -341,6 +344,7 @@ async fn responses_respects_model_info_overrides_from_config() {
|
||||
conversation_id,
|
||||
/*installation_id*/ TEST_INSTALLATION_ID.to_string(),
|
||||
provider.clone(),
|
||||
ProviderRuntime::Legacy,
|
||||
session_source,
|
||||
config.model_verbosity,
|
||||
/*enable_request_compression*/ false,
|
||||
|
||||
@@ -8,6 +8,7 @@ use codex_features::Feature;
|
||||
use codex_login::AuthManager;
|
||||
use codex_login::CodexAuth;
|
||||
use codex_login::default_client::originator;
|
||||
use codex_model_provider::ProviderRuntime;
|
||||
use codex_model_provider_info::ModelProviderInfo;
|
||||
use codex_model_provider_info::WireApi;
|
||||
use codex_model_provider_info::built_in_model_providers;
|
||||
@@ -777,7 +778,7 @@ async fn includes_conversation_id_and_model_headers_in_request() {
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn provider_auth_command_supplies_bearer_token() {
|
||||
async fn non_allowlisted_provider_auth_command_supplies_bearer_token() {
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = MockServer::start().await;
|
||||
@@ -880,6 +881,7 @@ async fn send_provider_auth_request(server: &MockServer, auth: ModelProviderAuth
|
||||
conversation_id,
|
||||
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
|
||||
provider,
|
||||
ProviderRuntime::Legacy,
|
||||
SessionSource::Exec,
|
||||
config.model_verbosity,
|
||||
/*enable_request_compression*/ false,
|
||||
@@ -2199,6 +2201,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
|
||||
conversation_id,
|
||||
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
|
||||
provider.clone(),
|
||||
ProviderRuntime::Legacy,
|
||||
SessionSource::Exec,
|
||||
config.model_verbosity,
|
||||
/*enable_request_compression*/ false,
|
||||
|
||||
@@ -8,6 +8,7 @@ use codex_core::ResponseEvent;
|
||||
use codex_core::X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER;
|
||||
use codex_features::Feature;
|
||||
use codex_login::CodexAuth;
|
||||
use codex_model_provider::ProviderRuntime;
|
||||
use codex_model_provider_info::ModelProviderInfo;
|
||||
use codex_model_provider_info::WireApi;
|
||||
use codex_otel::MetricsClient;
|
||||
@@ -1763,6 +1764,7 @@ async fn websocket_harness_with_provider_options(
|
||||
conversation_id,
|
||||
/*installation_id*/ TEST_INSTALLATION_ID.to_string(),
|
||||
provider.clone(),
|
||||
ProviderRuntime::Legacy,
|
||||
SessionSource::Exec,
|
||||
config.model_verbosity,
|
||||
/*enable_request_compression*/ false,
|
||||
|
||||
@@ -16,6 +16,7 @@ codex-api = { workspace = true }
|
||||
codex-client = { workspace = true }
|
||||
codex-config = { workspace = true }
|
||||
codex-keyring-store = { workspace = true }
|
||||
codex-model-provider = { workspace = true }
|
||||
codex-model-provider-info = { workspace = true }
|
||||
codex-otel = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
@@ -44,6 +45,7 @@ webbrowser = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
codex-utils-absolute-path = { workspace = true }
|
||||
core_test_support = { workspace = true }
|
||||
keyring = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
use codex_api::CoreAuthProvider;
|
||||
use codex_model_provider::ProviderAuthKind;
|
||||
use codex_model_provider::ProviderRuntime;
|
||||
use codex_model_provider_info::ModelProviderInfo;
|
||||
use codex_protocol::error::CodexErr;
|
||||
use codex_protocol::error::EnvVarError;
|
||||
|
||||
use crate::CodexAuth;
|
||||
|
||||
@@ -34,3 +38,181 @@ pub fn auth_provider_from_auth(
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn auth_provider_from_runtime(
|
||||
auth: Option<CodexAuth>,
|
||||
provider_runtime: &ProviderRuntime,
|
||||
legacy_provider: &ModelProviderInfo,
|
||||
) -> codex_protocol::error::Result<CoreAuthProvider> {
|
||||
match provider_runtime {
|
||||
ProviderRuntime::Legacy => auth_provider_from_auth(auth, legacy_provider),
|
||||
ProviderRuntime::Resolved(provider) => {
|
||||
auth_provider_from_provider_auth(auth, &provider.auth)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn auth_provider_from_provider_auth(
|
||||
auth: Option<CodexAuth>,
|
||||
provider_auth: &ProviderAuthKind,
|
||||
) -> codex_protocol::error::Result<CoreAuthProvider> {
|
||||
match provider_auth {
|
||||
ProviderAuthKind::EnvBearer {
|
||||
env_key,
|
||||
instructions,
|
||||
} => {
|
||||
let token = std::env::var(env_key)
|
||||
.ok()
|
||||
.filter(|value| !value.trim().is_empty())
|
||||
.ok_or_else(|| {
|
||||
CodexErr::EnvVar(EnvVarError {
|
||||
var: env_key.clone(),
|
||||
instructions: instructions.clone(),
|
||||
})
|
||||
})?;
|
||||
Ok(CoreAuthProvider {
|
||||
token: Some(token),
|
||||
account_id: None,
|
||||
})
|
||||
}
|
||||
ProviderAuthKind::StaticBearer { token } => Ok(CoreAuthProvider {
|
||||
token: Some(token.clone()),
|
||||
account_id: None,
|
||||
}),
|
||||
ProviderAuthKind::CommandBearer { .. } | ProviderAuthKind::AuthManager => {
|
||||
if let Some(auth) = auth {
|
||||
let token = auth.get_token()?;
|
||||
Ok(CoreAuthProvider {
|
||||
token: Some(token),
|
||||
account_id: auth.get_account_id(),
|
||||
})
|
||||
} else {
|
||||
Ok(CoreAuthProvider {
|
||||
token: None,
|
||||
account_id: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use codex_api::AuthProvider;
|
||||
use codex_model_provider::ProviderResolutionPolicy;
|
||||
use codex_model_provider::resolve_model_provider;
|
||||
use codex_model_provider_info::ModelProviderInfo;
|
||||
use codex_model_provider_info::WireApi;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn bearer_provider() -> ModelProviderInfo {
|
||||
ModelProviderInfo {
|
||||
name: "custom".to_string(),
|
||||
base_url: Some("https://example.com/v1".to_string()),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: Some("token".to_string()),
|
||||
auth: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
websocket_connect_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
supports_websockets: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn missing_env_key_provider() -> ModelProviderInfo {
|
||||
ModelProviderInfo {
|
||||
name: "custom".to_string(),
|
||||
base_url: Some("https://example.com/v1".to_string()),
|
||||
env_key: Some("MISSING_CODEX_PROVIDER_TEST_KEY".to_string()),
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
websocket_connect_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
supports_websockets: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn runtime_auth_adapter_uses_resolved_static_bearer_auth() {
|
||||
let provider = bearer_provider();
|
||||
let mut runtime = resolve_model_provider(
|
||||
"custom",
|
||||
&provider,
|
||||
&ProviderResolutionPolicy::with_enabled_provider_ids(["custom".to_string()]),
|
||||
);
|
||||
let ProviderRuntime::Resolved(resolved) = &mut runtime else {
|
||||
panic!("enabled provider should resolve through the provider framework");
|
||||
};
|
||||
resolved.info.experimental_bearer_token = None;
|
||||
|
||||
let legacy = auth_provider_from_auth(None, &provider).expect("legacy auth");
|
||||
let resolved = auth_provider_from_runtime(None, &runtime, &provider).expect("runtime auth");
|
||||
|
||||
assert_eq!(resolved.bearer_token(), legacy.bearer_token());
|
||||
assert_eq!(resolved.account_id(), legacy.account_id());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn runtime_auth_adapter_uses_resolved_env_key_errors() {
|
||||
let provider = missing_env_key_provider();
|
||||
let mut runtime = resolve_model_provider(
|
||||
"custom",
|
||||
&provider,
|
||||
&ProviderResolutionPolicy::with_enabled_provider_ids(["custom".to_string()]),
|
||||
);
|
||||
let ProviderRuntime::Resolved(resolved) = &mut runtime else {
|
||||
panic!("enabled provider should resolve through the provider framework");
|
||||
};
|
||||
resolved.info.env_key = None;
|
||||
resolved.info.env_key_instructions = None;
|
||||
|
||||
let legacy = match auth_provider_from_auth(None, &provider) {
|
||||
Ok(_) => panic!("missing env key should fail"),
|
||||
Err(err) => err.to_string(),
|
||||
};
|
||||
let resolved = match auth_provider_from_runtime(None, &runtime, &provider) {
|
||||
Ok(_) => panic!("missing env key should fail"),
|
||||
Err(err) => err.to_string(),
|
||||
};
|
||||
|
||||
assert_eq!(resolved, legacy);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn runtime_auth_adapter_uses_auth_manager_fallback() {
|
||||
let provider = ModelProviderInfo {
|
||||
experimental_bearer_token: None,
|
||||
requires_openai_auth: false,
|
||||
..bearer_provider()
|
||||
};
|
||||
let runtime = resolve_model_provider(
|
||||
"custom",
|
||||
&provider,
|
||||
&ProviderResolutionPolicy::with_enabled_provider_ids(["custom".to_string()]),
|
||||
);
|
||||
let auth = Some(CodexAuth::from_api_key("auth-manager-token"));
|
||||
|
||||
let legacy = auth_provider_from_auth(auth.clone(), &provider).expect("legacy auth");
|
||||
let resolved = auth_provider_from_runtime(auth, &runtime, &provider).expect("runtime auth");
|
||||
|
||||
assert_eq!(resolved.bearer_token(), legacy.bearer_token());
|
||||
assert_eq!(resolved.account_id(), legacy.account_id());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use codex_model_provider::ProviderAuthKind;
|
||||
use codex_model_provider::ProviderRuntime;
|
||||
use codex_model_provider_info::ModelProviderInfo;
|
||||
use codex_otel::AuthEnvTelemetryMetadata;
|
||||
|
||||
@@ -31,13 +33,49 @@ impl AuthEnvTelemetry {
|
||||
pub fn collect_auth_env_telemetry(
|
||||
provider: &ModelProviderInfo,
|
||||
codex_api_key_env_enabled: bool,
|
||||
) -> AuthEnvTelemetry {
|
||||
let provider_env_key = provider.env_key.as_deref();
|
||||
collect_auth_env_telemetry_for_env_key(provider_env_key, codex_api_key_env_enabled)
|
||||
}
|
||||
|
||||
pub fn collect_auth_env_telemetry_for_runtime(
|
||||
provider_runtime: &ProviderRuntime,
|
||||
legacy_provider: &ModelProviderInfo,
|
||||
codex_api_key_env_enabled: bool,
|
||||
) -> AuthEnvTelemetry {
|
||||
match provider_runtime {
|
||||
ProviderRuntime::Legacy => {
|
||||
collect_auth_env_telemetry(legacy_provider, codex_api_key_env_enabled)
|
||||
}
|
||||
ProviderRuntime::Resolved(provider) => {
|
||||
collect_auth_env_telemetry_for_provider_auth(&provider.auth, codex_api_key_env_enabled)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn collect_auth_env_telemetry_for_provider_auth(
|
||||
provider_auth: &ProviderAuthKind,
|
||||
codex_api_key_env_enabled: bool,
|
||||
) -> AuthEnvTelemetry {
|
||||
let provider_env_key = match provider_auth {
|
||||
ProviderAuthKind::EnvBearer { env_key, .. } => Some(env_key.as_str()),
|
||||
ProviderAuthKind::StaticBearer { .. }
|
||||
| ProviderAuthKind::CommandBearer { .. }
|
||||
| ProviderAuthKind::AuthManager => None,
|
||||
};
|
||||
collect_auth_env_telemetry_for_env_key(provider_env_key, codex_api_key_env_enabled)
|
||||
}
|
||||
|
||||
fn collect_auth_env_telemetry_for_env_key(
|
||||
provider_env_key: Option<&str>,
|
||||
codex_api_key_env_enabled: bool,
|
||||
) -> AuthEnvTelemetry {
|
||||
AuthEnvTelemetry {
|
||||
openai_api_key_env_present: env_var_present(OPENAI_API_KEY_ENV_VAR),
|
||||
codex_api_key_env_present: env_var_present(CODEX_API_KEY_ENV_VAR),
|
||||
codex_api_key_env_enabled,
|
||||
provider_env_key_name: provider.env_key.as_ref().map(|_| "configured".to_string()),
|
||||
provider_env_key_present: provider.env_key.as_deref().map(env_var_present),
|
||||
provider_env_key_name: provider_env_key.map(|_| "configured".to_string()),
|
||||
provider_env_key_present: provider_env_key.map(env_var_present),
|
||||
refresh_token_url_override_present: env_var_present(REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR),
|
||||
}
|
||||
}
|
||||
@@ -53,6 +91,9 @@ fn env_var_present(name: &str) -> bool {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use codex_model_provider::ProviderResolutionPolicy;
|
||||
use codex_model_provider::ProviderRuntime;
|
||||
use codex_model_provider::resolve_model_provider;
|
||||
use codex_model_provider_info::WireApi;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
@@ -85,4 +126,45 @@ mod tests {
|
||||
Some("configured".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn runtime_auth_env_telemetry_uses_resolved_provider_auth() {
|
||||
let provider = ModelProviderInfo {
|
||||
name: "Custom".to_string(),
|
||||
base_url: None,
|
||||
env_key: Some("PATH".to_string()),
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
websocket_connect_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
supports_websockets: false,
|
||||
};
|
||||
let mut runtime = resolve_model_provider(
|
||||
"custom",
|
||||
&provider,
|
||||
&ProviderResolutionPolicy::with_enabled_provider_ids(["custom".to_string()]),
|
||||
);
|
||||
let ProviderRuntime::Resolved(resolved) = &mut runtime else {
|
||||
panic!("enabled provider should resolve through the provider framework");
|
||||
};
|
||||
resolved.info.env_key = None;
|
||||
|
||||
let telemetry = collect_auth_env_telemetry_for_runtime(
|
||||
&runtime, &provider, /*codex_api_key_env_enabled*/ false,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
telemetry.provider_env_key_name,
|
||||
Some("configured".to_string())
|
||||
);
|
||||
assert_eq!(telemetry.provider_env_key_present, Some(true));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_model_provider::ProviderAuthKind;
|
||||
use codex_model_provider::ProviderRuntime;
|
||||
use codex_model_provider_info::ModelProviderInfo;
|
||||
|
||||
use crate::AuthManager;
|
||||
@@ -17,6 +19,33 @@ pub fn auth_manager_for_provider(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn auth_manager_for_provider_runtime(
|
||||
auth_manager: Option<Arc<AuthManager>>,
|
||||
provider_runtime: &ProviderRuntime,
|
||||
legacy_provider: &ModelProviderInfo,
|
||||
) -> Option<Arc<AuthManager>> {
|
||||
match provider_runtime {
|
||||
ProviderRuntime::Legacy => auth_manager_for_provider(auth_manager, legacy_provider),
|
||||
ProviderRuntime::Resolved(provider) => {
|
||||
auth_manager_for_provider_auth(auth_manager, &provider.auth)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn auth_manager_for_provider_auth(
|
||||
auth_manager: Option<Arc<AuthManager>>,
|
||||
provider_auth: &ProviderAuthKind,
|
||||
) -> Option<Arc<AuthManager>> {
|
||||
match provider_auth {
|
||||
ProviderAuthKind::CommandBearer { config } => {
|
||||
Some(AuthManager::external_bearer_only(config.clone()))
|
||||
}
|
||||
ProviderAuthKind::EnvBearer { .. }
|
||||
| ProviderAuthKind::StaticBearer { .. }
|
||||
| ProviderAuthKind::AuthManager => auth_manager,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an auth manager for request paths that always require authentication.
|
||||
///
|
||||
/// Providers with command-backed auth get a bearer-only manager; otherwise the caller's manager
|
||||
@@ -30,3 +59,75 @@ pub fn required_auth_manager_for_provider(
|
||||
None => auth_manager,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use codex_model_provider::ProviderResolutionPolicy;
|
||||
use codex_model_provider::resolve_model_provider;
|
||||
use codex_model_provider_info::WireApi;
|
||||
use codex_protocol::config_types::ModelProviderAuthInfo;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use std::num::NonZeroU64;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn provider_with_command_auth() -> ModelProviderInfo {
|
||||
ModelProviderInfo {
|
||||
name: "custom".to_string(),
|
||||
base_url: Some("https://example.com/v1".to_string()),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: Some(ModelProviderAuthInfo {
|
||||
command: "print-token".to_string(),
|
||||
args: Vec::new(),
|
||||
timeout_ms: NonZeroU64::MIN,
|
||||
refresh_interval_ms: 0,
|
||||
cwd: AbsolutePathBuf::resolve_path_against_base(".", "/tmp"),
|
||||
}),
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
websocket_connect_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
supports_websockets: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn runtime_auth_manager_adapter_uses_resolved_command_auth() {
|
||||
let provider = provider_with_command_auth();
|
||||
let mut runtime = resolve_model_provider(
|
||||
"custom",
|
||||
&provider,
|
||||
&ProviderResolutionPolicy::with_enabled_provider_ids(["custom".to_string()]),
|
||||
);
|
||||
let ProviderRuntime::Resolved(resolved) = &mut runtime else {
|
||||
panic!("enabled provider should resolve through the provider framework");
|
||||
};
|
||||
resolved.info.auth = None;
|
||||
|
||||
assert!(auth_manager_for_provider(None, &provider).is_some());
|
||||
assert!(auth_manager_for_provider_runtime(None, &runtime, &provider).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn runtime_auth_manager_adapter_preserves_absent_auth_manager_for_plain_provider() {
|
||||
let provider = ModelProviderInfo {
|
||||
auth: None,
|
||||
..provider_with_command_auth()
|
||||
};
|
||||
let runtime = resolve_model_provider(
|
||||
"custom",
|
||||
&provider,
|
||||
&ProviderResolutionPolicy::with_enabled_provider_ids(["custom".to_string()]),
|
||||
);
|
||||
|
||||
assert!(auth_manager_for_provider(None, &provider).is_none());
|
||||
assert!(auth_manager_for_provider_runtime(None, &runtime, &provider).is_none());
|
||||
}
|
||||
}
|
||||
|
||||
6
codex-rs/model-provider/BUILD.bazel
Normal file
6
codex-rs/model-provider/BUILD.bazel
Normal file
@@ -0,0 +1,6 @@
|
||||
load("//:defs.bzl", "codex_rust_crate")
|
||||
|
||||
codex_rust_crate(
|
||||
name = "model-provider",
|
||||
crate_name = "codex_model_provider",
|
||||
)
|
||||
24
codex-rs/model-provider/Cargo.toml
Normal file
24
codex-rs/model-provider/Cargo.toml
Normal file
@@ -0,0 +1,24 @@
|
||||
[package]
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
name = "codex-model-provider"
|
||||
version.workspace = true
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "codex_model_provider"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
codex-api = { workspace = true }
|
||||
codex-app-server-protocol = { workspace = true }
|
||||
codex-model-provider-info = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
|
||||
[dev-dependencies]
|
||||
codex-utils-absolute-path = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
378
codex-rs/model-provider/src/lib.rs
Normal file
378
codex-rs/model-provider/src/lib.rs
Normal file
@@ -0,0 +1,378 @@
|
||||
//! Runtime provider abstraction for model-provider-specific behavior.
|
||||
//!
|
||||
//! The framework is intentionally opt-in. Providers continue through legacy
|
||||
//! auth, transport, model-listing, and capability paths unless a resolution
|
||||
//! policy explicitly enables the selected provider.
|
||||
|
||||
use codex_model_provider_info::ModelProviderInfo;
|
||||
use codex_protocol::config_types::ModelProviderAuthInfo;
|
||||
use codex_protocol::error::Result as CodexResult;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// Provider IDs opted in to the runtime provider framework in production.
|
||||
///
|
||||
/// This is intentionally code-owned rather than user-configurable. Leaving it
|
||||
/// empty preserves legacy behavior for every provider.
|
||||
pub const PROVIDER_FRAMEWORK_ENABLED_PROVIDER_IDS: &[&str] = &[];
|
||||
|
||||
/// Runtime strategy selected for the active model provider.
|
||||
#[derive(Debug, Clone, PartialEq, Default)]
|
||||
pub enum ProviderRuntime {
|
||||
/// Preserve the existing provider implementation.
|
||||
#[default]
|
||||
Legacy,
|
||||
/// Use provider-owned runtime strategies.
|
||||
Resolved(Box<ResolvedModelProvider>),
|
||||
}
|
||||
|
||||
/// Runtime provider object resolved from config-facing provider metadata.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct ResolvedModelProvider {
|
||||
pub id: String,
|
||||
pub info: ModelProviderInfo,
|
||||
pub auth: ProviderAuthKind,
|
||||
}
|
||||
|
||||
impl ResolvedModelProvider {
|
||||
/// Build the legacy API provider for resolved providers that deliberately
|
||||
/// mirror existing OpenAI-compatible behavior.
|
||||
pub fn to_legacy_api_provider(
|
||||
&self,
|
||||
auth_mode: Option<codex_app_server_protocol::AuthMode>,
|
||||
) -> CodexResult<codex_api::Provider> {
|
||||
self.info.to_api_provider(auth_mode)
|
||||
}
|
||||
}
|
||||
|
||||
/// Provider-owned authentication strategy.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "camelCase")]
|
||||
pub enum ProviderAuthKind {
|
||||
/// Bearer token read from a configured environment variable.
|
||||
EnvBearer {
|
||||
env_key: String,
|
||||
instructions: Option<String>,
|
||||
},
|
||||
/// Bearer token read from provider config.
|
||||
StaticBearer { token: String },
|
||||
/// Command-backed bearer token.
|
||||
CommandBearer { config: ModelProviderAuthInfo },
|
||||
/// Bearer token supplied by the session auth manager.
|
||||
AuthManager,
|
||||
}
|
||||
|
||||
/// Policy controlling which providers may use the new runtime framework.
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub struct ProviderResolutionPolicy {
|
||||
enabled_provider_ids: HashSet<String>,
|
||||
}
|
||||
|
||||
impl ProviderResolutionPolicy {
|
||||
pub fn disabled() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn with_enabled_provider_ids(provider_ids: impl IntoIterator<Item = String>) -> Self {
|
||||
Self {
|
||||
enabled_provider_ids: provider_ids.into_iter().collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_static_provider_ids(provider_ids: &'static [&'static str]) -> Self {
|
||||
Self {
|
||||
enabled_provider_ids: provider_ids.iter().map(|id| (*id).to_string()).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn enables_provider(&self, provider_id: &str) -> bool {
|
||||
self.enabled_provider_ids.contains(provider_id)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn production_provider_resolution_policy() -> ProviderResolutionPolicy {
|
||||
ProviderResolutionPolicy::from_static_provider_ids(PROVIDER_FRAMEWORK_ENABLED_PROVIDER_IDS)
|
||||
}
|
||||
|
||||
/// Resolve the config-facing provider into a runtime strategy.
|
||||
pub fn resolve_model_provider(
|
||||
provider_id: &str,
|
||||
provider: &ModelProviderInfo,
|
||||
policy: &ProviderResolutionPolicy,
|
||||
) -> ProviderRuntime {
|
||||
if !policy.enables_provider(provider_id) {
|
||||
return ProviderRuntime::Legacy;
|
||||
}
|
||||
|
||||
ProviderRuntime::Resolved(Box::new(resolve_generic_model_provider(
|
||||
provider_id,
|
||||
provider,
|
||||
)))
|
||||
}
|
||||
|
||||
fn resolve_generic_model_provider(
|
||||
provider_id: &str,
|
||||
provider: &ModelProviderInfo,
|
||||
) -> ResolvedModelProvider {
|
||||
ResolvedModelProvider {
|
||||
id: provider_id.to_string(),
|
||||
info: provider.clone(),
|
||||
auth: resolve_provider_auth(provider),
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_provider_auth(provider: &ModelProviderInfo) -> ProviderAuthKind {
|
||||
if let Some(env_key) = &provider.env_key {
|
||||
ProviderAuthKind::EnvBearer {
|
||||
env_key: env_key.clone(),
|
||||
instructions: provider.env_key_instructions.clone(),
|
||||
}
|
||||
} else if let Some(token) = &provider.experimental_bearer_token {
|
||||
ProviderAuthKind::StaticBearer {
|
||||
token: token.clone(),
|
||||
}
|
||||
} else if let Some(config) = &provider.auth {
|
||||
ProviderAuthKind::CommandBearer {
|
||||
config: config.clone(),
|
||||
}
|
||||
} else {
|
||||
ProviderAuthKind::AuthManager
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use codex_model_provider_info::LMSTUDIO_OSS_PROVIDER_ID;
|
||||
use codex_model_provider_info::OLLAMA_OSS_PROVIDER_ID;
|
||||
use codex_model_provider_info::OPENAI_PROVIDER_ID;
|
||||
use codex_model_provider_info::WireApi;
|
||||
use codex_protocol::config_types::ModelProviderAuthInfo;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::num::NonZeroU64;
|
||||
|
||||
#[test]
|
||||
fn empty_policy_keeps_known_providers_on_legacy_runtime() {
|
||||
let providers =
|
||||
codex_model_provider_info::built_in_model_providers(/*openai_base_url*/ None);
|
||||
let policy = ProviderResolutionPolicy::disabled();
|
||||
|
||||
for provider_id in [
|
||||
OPENAI_PROVIDER_ID,
|
||||
OLLAMA_OSS_PROVIDER_ID,
|
||||
LMSTUDIO_OSS_PROVIDER_ID,
|
||||
] {
|
||||
let provider = providers.get(provider_id).expect("provider should exist");
|
||||
assert_eq!(
|
||||
resolve_model_provider(provider_id, provider, &policy),
|
||||
ProviderRuntime::Legacy
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_policy_keeps_custom_env_key_provider_on_legacy_runtime() {
|
||||
let provider = ModelProviderInfo {
|
||||
name: "custom".to_string(),
|
||||
base_url: Some("https://example.com/v1".to_string()),
|
||||
env_key: Some("CUSTOM_API_KEY".to_string()),
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
websocket_connect_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
supports_websockets: false,
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
resolve_model_provider("custom", &provider, &ProviderResolutionPolicy::disabled()),
|
||||
ProviderRuntime::Legacy
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_policy_keeps_command_auth_provider_on_legacy_runtime() {
|
||||
let provider = ModelProviderInfo {
|
||||
name: "custom".to_string(),
|
||||
base_url: Some("https://example.com/v1".to_string()),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: Some(ModelProviderAuthInfo {
|
||||
command: "print-token".to_string(),
|
||||
args: Vec::new(),
|
||||
timeout_ms: NonZeroU64::MIN,
|
||||
refresh_interval_ms: 0,
|
||||
cwd: AbsolutePathBuf::resolve_path_against_base(".", "/tmp"),
|
||||
}),
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
websocket_connect_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
supports_websockets: false,
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
resolve_model_provider("custom", &provider, &ProviderResolutionPolicy::disabled()),
|
||||
ProviderRuntime::Legacy
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn production_policy_keeps_known_providers_on_legacy_runtime() {
|
||||
let providers =
|
||||
codex_model_provider_info::built_in_model_providers(/*openai_base_url*/ None);
|
||||
let policy = production_provider_resolution_policy();
|
||||
|
||||
for provider_id in [
|
||||
OPENAI_PROVIDER_ID,
|
||||
OLLAMA_OSS_PROVIDER_ID,
|
||||
LMSTUDIO_OSS_PROVIDER_ID,
|
||||
] {
|
||||
let provider = providers.get(provider_id).expect("provider should exist");
|
||||
assert_eq!(
|
||||
resolve_model_provider(provider_id, provider, &policy),
|
||||
ProviderRuntime::Legacy
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enabled_policy_activates_generic_resolved_provider() {
|
||||
let providers =
|
||||
codex_model_provider_info::built_in_model_providers(/*openai_base_url*/ None);
|
||||
let provider = providers
|
||||
.get(OPENAI_PROVIDER_ID)
|
||||
.expect("provider should exist");
|
||||
let policy =
|
||||
ProviderResolutionPolicy::with_enabled_provider_ids([OPENAI_PROVIDER_ID.to_string()]);
|
||||
|
||||
assert!(policy.enables_provider(OPENAI_PROVIDER_ID));
|
||||
let runtime = resolve_model_provider(OPENAI_PROVIDER_ID, provider, &policy);
|
||||
let ProviderRuntime::Resolved(resolved) = runtime else {
|
||||
panic!("enabled provider should resolve through the provider framework");
|
||||
};
|
||||
assert_eq!(resolved.id, OPENAI_PROVIDER_ID);
|
||||
assert_eq!(resolved.info, *provider);
|
||||
assert_eq!(resolved.auth, ProviderAuthKind::AuthManager);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enabled_policy_resolves_env_bearer_auth() {
|
||||
let provider = ModelProviderInfo {
|
||||
name: "custom".to_string(),
|
||||
base_url: Some("https://example.com/v1".to_string()),
|
||||
env_key: Some("CUSTOM_API_KEY".to_string()),
|
||||
env_key_instructions: Some("set CUSTOM_API_KEY".to_string()),
|
||||
experimental_bearer_token: None,
|
||||
auth: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
websocket_connect_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
supports_websockets: false,
|
||||
};
|
||||
let ProviderRuntime::Resolved(resolved) = resolve_model_provider(
|
||||
"custom",
|
||||
&provider,
|
||||
&ProviderResolutionPolicy::with_enabled_provider_ids(["custom".to_string()]),
|
||||
) else {
|
||||
panic!("enabled provider should resolve through the provider framework");
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
resolved.auth,
|
||||
ProviderAuthKind::EnvBearer {
|
||||
env_key: "CUSTOM_API_KEY".to_string(),
|
||||
instructions: Some("set CUSTOM_API_KEY".to_string()),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enabled_policy_resolves_command_bearer_auth() {
|
||||
let auth = ModelProviderAuthInfo {
|
||||
command: "print-token".to_string(),
|
||||
args: Vec::new(),
|
||||
timeout_ms: NonZeroU64::MIN,
|
||||
refresh_interval_ms: 0,
|
||||
cwd: AbsolutePathBuf::resolve_path_against_base(".", "/tmp"),
|
||||
};
|
||||
let provider = ModelProviderInfo {
|
||||
name: "custom".to_string(),
|
||||
base_url: Some("https://example.com/v1".to_string()),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: Some(auth.clone()),
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
websocket_connect_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
supports_websockets: false,
|
||||
};
|
||||
let ProviderRuntime::Resolved(resolved) = resolve_model_provider(
|
||||
"custom",
|
||||
&provider,
|
||||
&ProviderResolutionPolicy::with_enabled_provider_ids(["custom".to_string()]),
|
||||
) else {
|
||||
panic!("enabled provider should resolve through the provider framework");
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
resolved.auth,
|
||||
ProviderAuthKind::CommandBearer { config: auth }
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_resolved_provider_uses_legacy_api_provider_adapter() {
|
||||
let providers =
|
||||
codex_model_provider_info::built_in_model_providers(/*openai_base_url*/ None);
|
||||
let provider = providers
|
||||
.get(OPENAI_PROVIDER_ID)
|
||||
.expect("provider should exist");
|
||||
let policy =
|
||||
ProviderResolutionPolicy::with_enabled_provider_ids([OPENAI_PROVIDER_ID.to_string()]);
|
||||
let ProviderRuntime::Resolved(resolved) =
|
||||
resolve_model_provider(OPENAI_PROVIDER_ID, provider, &policy)
|
||||
else {
|
||||
panic!("enabled provider should resolve through the provider framework");
|
||||
};
|
||||
|
||||
let legacy = provider.to_api_provider(None).expect("legacy provider");
|
||||
let resolved = resolved
|
||||
.to_legacy_api_provider(None)
|
||||
.expect("resolved provider");
|
||||
|
||||
assert_eq!(resolved.name, legacy.name);
|
||||
assert_eq!(resolved.base_url, legacy.base_url);
|
||||
assert_eq!(resolved.query_params, legacy.query_params);
|
||||
assert_eq!(resolved.headers, legacy.headers);
|
||||
assert_eq!(resolved.stream_idle_timeout, legacy.stream_idle_timeout);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user