Compare commits

...

1 Commit

Author SHA1 Message Date
pakrym-oai
26c9558ba8 Fix OPENAI_API_KEY fallback after websocket upgrade failure 2026-03-23 11:34:05 -07:00
4 changed files with 165 additions and 1 deletion

View File

@@ -11,6 +11,7 @@ use serde::Deserialize;
use serde_json::Value;
use crate::auth::CodexAuth;
use crate::auth::read_openai_api_key_from_env;
use crate::error::CodexErr;
use crate::error::RetryLimitReachedError;
use crate::error::UnexpectedResponseError;
@@ -196,6 +197,21 @@ pub(crate) fn auth_provider_from_auth(
}
}
/// Resolves the auth to use for `provider`.
///
/// If auth was already supplied, or the provider has some other credential
/// source (doesn't require OpenAI auth, has an `env_key` override, or carries
/// an experimental bearer token), the supplied auth is returned unchanged.
/// Only when none of those hold do we fall back to reading the
/// `OPENAI_API_KEY` environment variable.
pub(crate) fn resolve_auth_for_provider(
    auth: Option<CodexAuth>,
    provider: &ModelProviderInfo,
) -> Option<CodexAuth> {
    let should_fall_back = auth.is_none()
        && provider.requires_openai_auth
        && provider.env_key.is_none()
        && provider.experimental_bearer_token.is_none();
    if !should_fall_back {
        return auth;
    }
    // `?` short-circuits to `None` when no key is present in the environment.
    let api_key = read_openai_api_key_from_env()?;
    Some(CodexAuth::from_api_key(&api_key))
}
#[derive(Debug, Deserialize)]
struct UsageErrorResponse {
error: UsageErrorBody,

View File

@@ -33,6 +33,7 @@ use std::sync::atomic::Ordering;
use crate::api_bridge::CoreAuthProvider;
use crate::api_bridge::auth_provider_from_auth;
use crate::api_bridge::map_api_error;
use crate::api_bridge::resolve_auth_for_provider;
use crate::auth::UnauthorizedRecovery;
use crate::auth_env_telemetry::AuthEnvTelemetry;
use crate::auth_env_telemetry::collect_auth_env_telemetry;
@@ -528,6 +529,7 @@ impl ModelClient {
Some(manager) => manager.auth().await,
None => None,
};
let auth = resolve_auth_for_provider(auth, &self.state.provider);
let api_provider = self
.state
.provider

View File

@@ -1,6 +1,7 @@
use super::cache::ModelsCacheManager;
use crate::api_bridge::auth_provider_from_auth;
use crate::api_bridge::map_api_error;
use crate::api_bridge::resolve_auth_for_provider;
use crate::auth::AuthManager;
use crate::auth::AuthMode;
use crate::auth::CodexAuth;
@@ -431,7 +432,7 @@ impl ModelsManager {
async fn fetch_and_update_models(&self) -> CoreResult<()> {
let _timer =
codex_otel::start_global_timer("codex.remote_models.fetch_update.duration_ms", &[]);
let auth = self.auth_manager.auth().await;
let auth = resolve_auth_for_provider(self.auth_manager.auth().await, &self.provider);
let auth_mode = auth.as_ref().map(CodexAuth::auth_mode);
let api_provider = self.provider.to_api_provider(auth_mode)?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider)?;

View File

@@ -1,19 +1,41 @@
use anyhow::Result;
use codex_core::AuthManager;
use codex_core::ModelClient;
use codex_core::Prompt;
use codex_core::ResponseEvent;
use codex_core::auth::AuthCredentialsStoreMode;
use codex_core::auth::OPENAI_API_KEY_ENV_VAR;
use codex_core::built_in_model_providers;
use codex_otel::SessionTelemetry;
use codex_otel::TelemetryAuthMode;
use codex_protocol::ThreadId;
use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::openai_models::ReasoningEffort;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::Op;
use codex_protocol::protocol::SandboxPolicy;
use codex_protocol::protocol::SessionSource;
use codex_protocol::user_input::UserInput;
use core_test_support::load_default_config_for_test;
use core_test_support::responses;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::mount_sse_once;
use core_test_support::responses::mount_sse_once_match;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use futures::StreamExt;
use pretty_assertions::assert_eq;
use std::ffi::OsStr;
use std::ffi::OsString;
use std::sync::Arc;
use tempfile::TempDir;
use tokio::time::Duration;
use tokio::time::timeout;
use wiremock::Mock;
@@ -22,6 +44,34 @@ use wiremock::http::Method;
use wiremock::matchers::method;
use wiremock::matchers::path_regex;
const MODEL: &str = "gpt-5.3-codex";
/// RAII guard that overrides a process environment variable for the duration
/// of a scope and restores the prior state (previous value, or absence) when
/// dropped.
struct EnvVarGuard {
    key: &'static str,
    original: Option<OsString>,
}

impl EnvVarGuard {
    /// Sets `key` to `value`, recording whatever was there before so `Drop`
    /// can put it back.
    fn set(key: &'static str, value: &OsStr) -> Self {
        let saved = std::env::var_os(key);
        // SAFETY: test-only mutation of the process environment.
        unsafe {
            std::env::set_var(key, value);
        }
        Self {
            key,
            original: saved,
        }
    }
}

impl Drop for EnvVarGuard {
    fn drop(&mut self) {
        // SAFETY: restores the environment state recorded at construction.
        unsafe {
            if let Some(previous) = self.original.take() {
                std::env::set_var(self.key, previous);
            } else {
                std::env::remove_var(self.key);
            }
        }
    }
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_fallback_switches_to_http_on_upgrade_required_connect() -> Result<()> {
skip_if_no_network!(Ok(()));
@@ -74,6 +124,101 @@ async fn websocket_fallback_switches_to_http_on_upgrade_required_connect() -> Re
Ok(())
}
// Regression test: when the websocket upgrade attempt is rejected with
// 426 (Upgrade Required) and no stored auth is available, the HTTP replay
// must fall back to the OPENAI_API_KEY environment variable for its bearer
// token.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_fallback_uses_openai_api_key_env_for_http_replay() -> Result<()> {
    skip_if_no_network!(Ok(()));
    // Inject the fallback API key; the guard restores the prior value on drop.
    let _api_key_guard = EnvVarGuard::set(OPENAI_API_KEY_ENV_VAR, OsStr::new("dummy"));
    let home = TempDir::new()?;
    let server = responses::start_mock_server().await;
    // Reject the websocket connect (a GET on /responses) with 426 exactly
    // once, forcing the client onto the plain-HTTP streaming path.
    Mock::given(method("GET"))
        .and(path_regex(".*/responses$"))
        .respond_with(ResponseTemplate::new(426))
        .expect(1)
        .mount(&server)
        .await;
    // The HTTP replay only matches if it authenticates with the env-provided
    // key ("Bearer dummy"); it then streams a minimal created/completed pair.
    let response_mock = mount_sse_once_match(
        &server,
        wiremock::matchers::header("authorization", "Bearer dummy"),
        sse(vec![ev_response_created("resp-1"), ev_completed("resp-1")]),
    )
    .await;
    let mut config = load_default_config_for_test(&home).await;
    config.model = Some(MODEL.to_string());
    let model_info = codex_core::test_support::construct_model_info_offline(MODEL, &config);
    // Point the built-in OpenAI provider at the mock server, enable websockets
    // so the upgrade path is tried first, and disable retries so the 426
    // immediately triggers the HTTP fallback.
    let mut provider = built_in_model_providers(/* openai_base_url */ None)["openai"].clone();
    provider.base_url = Some(format!("{}/v1", server.uri()));
    provider.supports_websockets = true;
    provider.request_max_retries = Some(0);
    provider.stream_max_retries = Some(0);
    // Fresh CODEX_HOME with no stored credentials, so the only usable auth
    // should be the OPENAI_API_KEY environment fallback under test.
    let auth_manager = Arc::new(AuthManager::new(
        home.path().to_path_buf(),
        /*enable_codex_api_key_env*/ false,
        AuthCredentialsStoreMode::File,
    ));
    let conversation_id = ThreadId::new();
    let session_telemetry = SessionTelemetry::new(
        conversation_id,
        MODEL,
        model_info.slug.as_str(),
        None,
        Some("test@test.com".to_string()),
        auth_manager.auth_mode().map(TelemetryAuthMode::from),
        "test_originator".to_string(),
        false,
        "test".to_string(),
        SessionSource::Exec,
    );
    let client = ModelClient::new(
        Some(auth_manager),
        conversation_id,
        provider,
        SessionSource::Exec,
        config.model_verbosity,
        false,
        false,
        None,
    );
    let mut client_session = client.new_session();
    let mut prompt = Prompt::default();
    prompt.input = vec![ResponseItem::Message {
        id: None,
        role: "user".to_string(),
        content: vec![ContentItem::InputText {
            text: "hello".to_string(),
        }],
        end_turn: None,
        phase: None,
    }];
    let mut stream = client_session
        .stream(
            &prompt,
            &model_info,
            &session_telemetry,
            /*effort*/ None::<ReasoningEffort>,
            ReasoningSummary::Auto,
            None,
            None,
        )
        .await?;
    // Drain events until the replayed HTTP response reports completion.
    while let Some(event) = stream.next().await {
        if matches!(event?, ResponseEvent::Completed { .. }) {
            break;
        }
    }
    // Exactly one request reached the SSE mock, on the expected path; the
    // header matcher above already proved it carried "Bearer dummy".
    let request = response_mock.single_request();
    assert_eq!(request.path(), "/v1/responses");
    Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_fallback_switches_to_http_after_retries_exhausted() -> Result<()> {
skip_if_no_network!(Ok(()));