core: support dynamic auth tokens for model providers (#16288)

## Summary

Fixes #15189.

Custom model providers that set `requires_openai_auth = false` could
only use static credentials via `env_key` or
`experimental_bearer_token`. That is not enough for providers that mint
short-lived bearer tokens, because Codex had no way to run a command to
obtain a bearer token, cache it briefly in memory, and retry with a
refreshed token after a `401`.

This PR adds that provider config and wires it through the existing auth
design: request paths still go through `AuthManager.auth()` and
`UnauthorizedRecovery`, with `core` only choosing when to use a
provider-backed bearer-only `AuthManager`.

## Scope

To keep this PR reviewable, `/models` only uses provider auth for the
initial request in this change. It does **not** add a dedicated `401`
retry path for `/models`; that can be follow-up work if we still need it
after landing the main provider-token support.

## Example Usage

```toml
model_provider = "corp-openai"

[model_providers.corp-openai]
name = "Corp OpenAI"
base_url = "https://gateway.example.com/openai"
requires_openai_auth = false

[model_providers.corp-openai.auth]
command = "gcloud"
args = ["auth", "print-access-token"]
timeout_ms = 5000
refresh_interval_ms = 300000
```

The command contract is intentionally small:

- write the bearer token to `stdout`
- exit `0`
- any leading or trailing whitespace is trimmed before the token is used

## What Changed

- add `model_providers.<id>.auth` to the config model and generated
schema
- validate that command-backed provider auth is mutually exclusive with
`env_key`, `experimental_bearer_token`, and `requires_openai_auth`
- build a bearer-only `AuthManager` for `ModelClient` and
`ModelsManager` when a provider configures `auth`
- let normal Responses requests and realtime websocket connects use the
provider-backed bearer source through the same `AuthManager.auth()` path
- allow `/models` online refresh for command-auth providers and attach
the provider token to the initial `/models` request
- keep `auth.cwd` available as an advanced escape hatch and include it
in the generated config schema

## Testing

- `cargo test -p codex-core provider_auth_command`
- `cargo test -p codex-core
refresh_available_models_uses_provider_auth_token`
- `cargo test -p codex-core
test_deserialize_provider_auth_config_defaults`

## Docs

- `developers.openai.com/codex` should document the new
`[model_providers.<id>.auth]` block and the token-command contract
This commit is contained in:
Michael Bolin
2026-03-31 01:37:27 -07:00
committed by GitHub
parent 0071968829
commit 20f43c1e05
17 changed files with 598 additions and 4 deletions

View File

@@ -1,5 +1,7 @@
use codex_core::AuthManager;
use codex_core::CodexAuth;
use codex_core::ModelClient;
use codex_core::ModelProviderAuthInfo;
use codex_core::ModelProviderInfo;
use codex_core::NewThread;
use codex_core::Prompt;
@@ -64,6 +66,7 @@ use futures::StreamExt;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::io::Write;
use std::num::NonZeroU64;
use std::sync::Arc;
use tempfile::TempDir;
use uuid::Uuid;
@@ -71,6 +74,7 @@ use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::body_string_contains;
use wiremock::matchers::header;
use wiremock::matchers::header_regex;
use wiremock::matchers::method;
use wiremock::matchers::path;
@@ -143,6 +147,95 @@ fn write_auth_json(
fake_jwt
}
/// Test fixture backing a command-based provider `auth` configuration.
///
/// Bundles a temporary directory with the command line of a script that lives in it; the
/// script prints one token from `tokens.txt` per invocation.
struct ProviderAuthCommandFixture {
/// Temporary directory that holds `tokens.txt` and the token-printing script; `auth()` also
/// passes this path as the command's `cwd`.
tempdir: TempDir,
/// Program to execute: the script itself on Unix, `powershell` on Windows.
command: String,
/// Arguments passed to `command` (empty on Unix, the PowerShell flags and script on Windows).
args: Vec<String>,
}
impl ProviderAuthCommandFixture {
    /// Builds a fixture whose command emits `tokens` one at a time, in order.
    ///
    /// A fresh temp directory receives `tokens.txt` (one newline-terminated token per line) and
    /// a platform-specific script. Each run of the script prints the first line of `tokens.txt`
    /// and rewrites the file without that line, so consecutive runs yield consecutive tokens.
    /// The script refers to `tokens.txt` by relative path, so it must run with the temp
    /// directory as its working directory (see `auth`).
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while creating the directory or writing its files.
    fn new(tokens: &[&str]) -> std::io::Result<Self> {
        let tempdir = tempfile::tempdir()?;

        // Every token, including the last, is followed by a newline.
        let token_lines: String = tokens.iter().flat_map(|token| [*token, "\n"]).collect();
        std::fs::write(tempdir.path().join("tokens.txt"), token_lines)?;

        #[cfg(unix)]
        let (command, args) = {
            use std::os::unix::fs::PermissionsExt;

            let script_body = r#"#!/bin/sh
first_line=$(sed -n '1p' tokens.txt)
printf '%s\n' "$first_line"
tail -n +2 tokens.txt > tokens.next
mv tokens.next tokens.txt
"#;
            let script = tempdir.path().join("print-token.sh");
            std::fs::write(&script, script_body)?;

            // The script is invoked directly as `./print-token.sh`, so it must be executable.
            let mut script_permissions = std::fs::metadata(&script)?.permissions();
            script_permissions.set_mode(0o755);
            std::fs::set_permissions(&script, script_permissions)?;

            (String::from("./print-token.sh"), Vec::<String>::new())
        };

        #[cfg(windows)]
        let (command, args) = {
            let script_body = r#"$lines = Get-Content -Path tokens.txt
if ($lines.Count -eq 0) { exit 1 }
Write-Output $lines[0]
$lines | Select-Object -Skip 1 | Set-Content -Path tokens.txt
"#;
            let script = tempdir.path().join("print-token.ps1");
            std::fs::write(&script, script_body)?;

            let powershell_args: Vec<String> = [
                "-NoProfile",
                "-ExecutionPolicy",
                "Bypass",
                "-File",
                ".\\print-token.ps1",
            ]
            .iter()
            .map(|arg| (*arg).to_string())
            .collect();

            (String::from("powershell"), powershell_args)
        };

        Ok(Self {
            tempdir,
            command,
            args,
        })
    }

    /// Returns the provider auth config that runs this fixture's command from its temp
    /// directory, with a 1 000 ms timeout and a 60 000 ms refresh interval.
    ///
    /// # Panics
    ///
    /// Panics if the temp directory path cannot be converted to an `AbsolutePathBuf`.
    fn auth(&self) -> ModelProviderAuthInfo {
        let cwd = codex_utils_absolute_path::AbsolutePathBuf::try_from(self.tempdir.path())
            .unwrap_or_else(|err| panic!("tempdir should be absolute: {err}"));

        ModelProviderAuthInfo {
            command: self.command.clone(),
            args: self.args.clone(),
            timeout_ms: non_zero_u64(/*value*/ 1_000),
            refresh_interval_ms: non_zero_u64(/*value*/ 60_000),
            cwd,
        }
    }
}
/// Converts `value` into a `NonZeroU64` for test configuration literals.
///
/// # Panics
///
/// Panics with `expected non-zero value: 0` when `value` is zero.
fn non_zero_u64(value: u64) -> NonZeroU64 {
    let Some(non_zero) = NonZeroU64::new(value) else {
        panic!("expected non-zero value: {value}");
    };
    non_zero
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn resume_includes_initial_messages_and_sends_prior_items() {
skip_if_no_network!();
@@ -659,6 +752,146 @@ async fn includes_conversation_id_and_model_headers_in_request() {
assert_eq!(request_authorization, "Bearer Test API Key");
}
/// The token printed by the provider's auth command must be sent as the request's bearer token.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn provider_auth_command_supplies_bearer_token() {
    skip_if_no_network!();

    let fixture = ProviderAuthCommandFixture::new(&["command-token"]).unwrap();
    let mock_server = MockServer::start().await;

    // Only a request carrying the command-produced token matches the mounted SSE response.
    let expected_authorization = header("authorization", "Bearer command-token");
    let sse_body = sse(vec![ev_response_created("resp1"), ev_completed("resp1")]);
    mount_sse_once_match(&mock_server, expected_authorization, sse_body).await;

    send_provider_auth_request(&mock_server, fixture.auth()).await;
}
/// After a `401` for the first token, the client must re-run the auth command and retry with
/// the second token.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn provider_auth_command_refreshes_after_401() {
    skip_if_no_network!();

    let server = MockServer::start().await;
    let auth_fixture = ProviderAuthCommandFixture::new(&["first-token", "second-token"]).unwrap();

    // First attempt: the stale token is rejected exactly once.
    let unauthorized = ResponseTemplate::new(401).set_body_string("unauthorized");
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(header_regex("Authorization", "Bearer first-token"))
        .respond_with(unauthorized)
        .expect(1)
        .mount(&server)
        .await;

    // Retry: the refreshed token is accepted exactly once and streams a completed response.
    let completed_stream = ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
        .set_body_raw(
            sse(vec![ev_response_created("resp1"), ev_completed("resp1")]),
            "text/event-stream",
        );
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(header_regex("Authorization", "Bearer second-token"))
        .respond_with(completed_stream)
        .expect(1)
        .mount(&server)
        .await;

    send_provider_auth_request(&server, auth_fixture.auth()).await;
}
/// Issues one streamed Responses request through a provider configured with command-backed auth.
///
/// The caller owns the server-side assertions (which bearer token each request must carry).
/// This helper asserts the client-side outcome: the stream must start, must not yield an error
/// event, and must reach `Completed`.
///
/// # Panics
///
/// Panics if the stream fails to start, yields an `Err` event, or ends before `Completed`.
async fn send_provider_auth_request(server: &MockServer, auth: ModelProviderAuthInfo) {
    let provider = ModelProviderInfo {
        name: "corp".into(),
        base_url: Some(format!("{}/v1", server.uri())),
        env_key: None,
        env_key_instructions: None,
        experimental_bearer_token: None,
        auth: Some(auth),
        wire_api: WireApi::Responses,
        query_params: None,
        http_headers: None,
        env_http_headers: None,
        // Disable retries so each mounted mock sees a predictable number of requests.
        request_max_retries: Some(0),
        stream_max_retries: Some(0),
        stream_idle_timeout_ms: Some(5_000),
        websocket_connect_timeout_ms: None,
        requires_openai_auth: false,
        supports_websockets: false,
    };

    let codex_home = TempDir::new().unwrap();
    let mut config = load_default_config_for_test(&codex_home).await;
    config.model_provider_id = provider.name.clone();
    config.model_provider = provider.clone();
    let effort = config.model_reasoning_effort;
    let summary = config.model_reasoning_summary;
    let model = codex_core::test_support::get_model_offline(config.model.as_deref());
    config.model = Some(model.clone());
    let config = Arc::new(config);
    let model_info =
        codex_core::test_support::construct_model_info_offline(model.as_str(), &config);

    let conversation_id = ThreadId::new();
    let session_telemetry = SessionTelemetry::new(
        conversation_id,
        model.as_str(),
        model_info.slug.as_str(),
        /*account_id*/ None,
        Some("test@test.com".to_string()),
        /*auth_mode*/ None,
        "test_originator".to_string(),
        /*log_user_prompts*/ false,
        "test".to_string(),
        SessionSource::Exec,
    );

    // The API key handed to the client here is a decoy: the callers' mocks only match requests
    // whose Authorization header carries the token printed by the provider's auth command.
    let client = ModelClient::new(
        Some(AuthManager::from_auth_for_testing(CodexAuth::from_api_key(
            "unused-api-key",
        ))),
        conversation_id,
        provider,
        SessionSource::Exec,
        config.model_verbosity,
        /*enable_request_compression*/ false,
        /*include_timing_metrics*/ false,
        /*beta_features_header*/ None,
    );
    let mut client_session = client.new_session();

    let mut prompt = Prompt::default();
    prompt.input.push(ResponseItem::Message {
        id: None,
        role: "user".to_string(),
        content: vec![ContentItem::InputText {
            text: "hello".to_string(),
        }],
        end_turn: None,
        phase: None,
    });

    let mut stream = client_session
        .stream(
            &prompt,
            &model_info,
            &session_telemetry,
            effort,
            summary.unwrap_or(ReasoningSummary::Auto),
            /*service_tier*/ None,
            /*turn_metadata_header*/ None,
        )
        .await
        .expect("responses stream to start");

    // Drain the stream until `Completed`. A stream error or an early end of stream is a test
    // failure rather than something to skip over, so a broken auth path cannot pass silently.
    let completed = loop {
        match stream.next().await {
            Some(Ok(ResponseEvent::Completed { .. })) => break true,
            Some(Ok(_)) => continue,
            Some(Err(err)) => {
                panic!("responses stream yielded an error before `Completed`: {err:?}")
            }
            None => break false,
        }
    };
    assert!(completed, "responses stream ended before `Completed`");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn includes_base_instructions_override_in_request() {
skip_if_no_network!();
@@ -1796,6 +2029,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
auth: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
@@ -2396,6 +2630,7 @@ async fn azure_overrides_assign_properties_used_for_responses_url() {
// Reuse the existing environment variable to avoid using unsafe code
env_key: Some(existing_env_var_with_random_value.to_string()),
experimental_bearer_token: None,
auth: None,
query_params: Some(std::collections::HashMap::from([(
"api-version".to_string(),
"2025-04-01-preview".to_string(),
@@ -2486,6 +2721,7 @@ async fn env_var_overrides_loaded_auth() {
)])),
env_key_instructions: None,
experimental_bearer_token: None,
auth: None,
wire_api: WireApi::Responses,
http_headers: Some(std::collections::HashMap::from([(
"Custom-Header".to_string(),