mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
fix: properly handle 401 error in clound requirement fetch. (#14049)
Handle cloud requirements 401s with the same auth recovery flow as normal requests, so permanent refresh failures surface the existing user-facing auth message instead of a generic workspace-config load error.
This commit is contained in:
@@ -10,6 +10,7 @@ use codex_protocol::account::PlanType as AccountPlanType;
|
||||
use codex_protocol::protocol::CreditsSnapshot;
|
||||
use codex_protocol::protocol::RateLimitSnapshot;
|
||||
use codex_protocol::protocol::RateLimitWindow;
|
||||
use reqwest::StatusCode;
|
||||
use reqwest::header::AUTHORIZATION;
|
||||
use reqwest::header::CONTENT_TYPE;
|
||||
use reqwest::header::HeaderMap;
|
||||
@@ -17,6 +18,65 @@ use reqwest::header::HeaderName;
|
||||
use reqwest::header::HeaderValue;
|
||||
use reqwest::header::USER_AGENT;
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum RequestError {
|
||||
UnexpectedStatus {
|
||||
method: String,
|
||||
url: String,
|
||||
status: StatusCode,
|
||||
content_type: String,
|
||||
body: String,
|
||||
},
|
||||
Other(anyhow::Error),
|
||||
}
|
||||
|
||||
impl RequestError {
|
||||
pub fn status(&self) -> Option<StatusCode> {
|
||||
match self {
|
||||
Self::UnexpectedStatus { status, .. } => Some(*status),
|
||||
Self::Other(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_unauthorized(&self) -> bool {
|
||||
self.status() == Some(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for RequestError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::UnexpectedStatus {
|
||||
method,
|
||||
url,
|
||||
status,
|
||||
content_type,
|
||||
body,
|
||||
} => write!(
|
||||
f,
|
||||
"{method} {url} failed: {status}; content-type={content_type}; body={body}"
|
||||
),
|
||||
Self::Other(err) => write!(f, "{err}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for RequestError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Self::UnexpectedStatus { .. } => None,
|
||||
Self::Other(err) => Some(err.as_ref()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<anyhow::Error> for RequestError {
|
||||
fn from(err: anyhow::Error) -> Self {
|
||||
Self::Other(err)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum PathStyle {
|
||||
@@ -148,6 +208,33 @@ impl Client {
|
||||
Ok((body, ct))
|
||||
}
|
||||
|
||||
async fn exec_request_detailed(
|
||||
&self,
|
||||
req: reqwest::RequestBuilder,
|
||||
method: &str,
|
||||
url: &str,
|
||||
) -> std::result::Result<(String, String), RequestError> {
|
||||
let res = req.send().await.map_err(anyhow::Error::from)?;
|
||||
let status = res.status();
|
||||
let content_type = res
|
||||
.headers()
|
||||
.get(CONTENT_TYPE)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let body = res.text().await.unwrap_or_default();
|
||||
if !status.is_success() {
|
||||
return Err(RequestError::UnexpectedStatus {
|
||||
method: method.to_string(),
|
||||
url: url.to_string(),
|
||||
status,
|
||||
content_type,
|
||||
body,
|
||||
});
|
||||
}
|
||||
Ok((body, content_type))
|
||||
}
|
||||
|
||||
fn decode_json<T: DeserializeOwned>(&self, url: &str, ct: &str, body: &str) -> Result<T> {
|
||||
match serde_json::from_str::<T>(body) {
|
||||
Ok(v) => Ok(v),
|
||||
@@ -256,14 +343,17 @@ impl Client {
|
||||
///
|
||||
/// `GET /api/codex/config/requirements` (Codex API style) or
|
||||
/// `GET /wham/config/requirements` (ChatGPT backend-api style).
|
||||
pub async fn get_config_requirements_file(&self) -> Result<ConfigFileResponse> {
|
||||
pub async fn get_config_requirements_file(
|
||||
&self,
|
||||
) -> std::result::Result<ConfigFileResponse, RequestError> {
|
||||
let url = match self.path_style {
|
||||
PathStyle::CodexApi => format!("{}/api/codex/config/requirements", self.base_url),
|
||||
PathStyle::ChatGptApi => format!("{}/wham/config/requirements", self.base_url),
|
||||
};
|
||||
let req = self.http.get(&url).headers(self.headers());
|
||||
let (body, ct) = self.exec_request(req, "GET", &url).await?;
|
||||
let (body, ct) = self.exec_request_detailed(req, "GET", &url).await?;
|
||||
self.decode_json::<ConfigFileResponse>(&url, &ct, &body)
|
||||
.map_err(RequestError::from)
|
||||
}
|
||||
|
||||
/// Create a new task (user turn) by POSTing to the appropriate backend path
|
||||
|
||||
@@ -2,6 +2,7 @@ mod client;
|
||||
pub mod types;
|
||||
|
||||
pub use client::Client;
|
||||
pub use client::RequestError;
|
||||
pub use types::CodeTaskDetailsResponse;
|
||||
pub use types::CodeTaskDetailsResponseExt;
|
||||
pub use types::ConfigFileResponse;
|
||||
|
||||
@@ -17,6 +17,7 @@ use chrono::Utc;
|
||||
use codex_backend_client::Client as BackendClient;
|
||||
use codex_core::AuthManager;
|
||||
use codex_core::auth::CodexAuth;
|
||||
use codex_core::auth::RefreshTokenError;
|
||||
use codex_core::config_loader::CloudRequirementsLoadError;
|
||||
use codex_core::config_loader::CloudRequirementsLoader;
|
||||
use codex_core::config_loader::ConfigRequirementsToml;
|
||||
@@ -44,6 +45,7 @@ const CLOUD_REQUIREMENTS_MAX_ATTEMPTS: usize = 5;
|
||||
const CLOUD_REQUIREMENTS_CACHE_FILENAME: &str = "cloud-requirements-cache.json";
|
||||
const CLOUD_REQUIREMENTS_CACHE_REFRESH_INTERVAL: Duration = Duration::from_secs(5 * 60);
|
||||
const CLOUD_REQUIREMENTS_CACHE_TTL: Duration = Duration::from_secs(30 * 60);
|
||||
const CLOUD_REQUIREMENTS_LOAD_FAILED_MESSAGE: &str = "failed to load your workspace-managed config";
|
||||
const CLOUD_REQUIREMENTS_CACHE_WRITE_HMAC_KEY: &[u8] =
|
||||
b"codex-cloud-requirements-cache-v3-064f8542-75b4-494c-a294-97d3ce597271";
|
||||
const CLOUD_REQUIREMENTS_CACHE_READ_HMAC_KEYS: &[&[u8]] =
|
||||
@@ -62,6 +64,12 @@ enum FetchCloudRequirementsStatus {
|
||||
Request,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
enum FetchCloudRequirementsError {
|
||||
Retryable(FetchCloudRequirementsStatus),
|
||||
Unauthorized(CloudRequirementsLoadError),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, Error, PartialEq)]
|
||||
enum CacheLoadStatus {
|
||||
#[error("Skipping cloud requirements cache read because auth identity is incomplete.")]
|
||||
@@ -141,6 +149,16 @@ fn verify_cache_signature(payload_bytes: &[u8], signature: &str) -> bool {
|
||||
.any(|key| verify_cache_signature_with_key(payload_bytes, &signature_bytes, key))
|
||||
}
|
||||
|
||||
fn auth_identity(auth: &CodexAuth) -> (Option<String>, Option<String>) {
|
||||
let token_data = auth.get_token_data().ok();
|
||||
let chatgpt_user_id = token_data
|
||||
.as_ref()
|
||||
.and_then(|token_data| token_data.id_token.chatgpt_user_id.as_deref())
|
||||
.map(str::to_owned);
|
||||
let account_id = auth.get_account_id();
|
||||
(chatgpt_user_id, account_id)
|
||||
}
|
||||
|
||||
fn cache_payload_bytes(payload: &CloudRequirementsCacheSignedPayload) -> Option<Vec<u8>> {
|
||||
serde_json::to_vec(&payload).ok()
|
||||
}
|
||||
@@ -153,7 +171,7 @@ trait RequirementsFetcher: Send + Sync {
|
||||
async fn fetch_requirements(
|
||||
&self,
|
||||
auth: &CodexAuth,
|
||||
) -> Result<Option<String>, FetchCloudRequirementsStatus>;
|
||||
) -> Result<Option<String>, FetchCloudRequirementsError>;
|
||||
}
|
||||
|
||||
struct BackendRequirementsFetcher {
|
||||
@@ -171,7 +189,7 @@ impl RequirementsFetcher for BackendRequirementsFetcher {
|
||||
async fn fetch_requirements(
|
||||
&self,
|
||||
auth: &CodexAuth,
|
||||
) -> Result<Option<String>, FetchCloudRequirementsStatus> {
|
||||
) -> Result<Option<String>, FetchCloudRequirementsError> {
|
||||
let client = BackendClient::from_auth(self.base_url.clone(), auth)
|
||||
.inspect_err(|err| {
|
||||
tracing::warn!(
|
||||
@@ -179,13 +197,25 @@ impl RequirementsFetcher for BackendRequirementsFetcher {
|
||||
"Failed to construct backend client for cloud requirements"
|
||||
);
|
||||
})
|
||||
.map_err(|_| FetchCloudRequirementsStatus::BackendClientInit)?;
|
||||
.map_err(|_| {
|
||||
FetchCloudRequirementsError::Retryable(
|
||||
FetchCloudRequirementsStatus::BackendClientInit,
|
||||
)
|
||||
})?;
|
||||
|
||||
let response = client
|
||||
.get_config_requirements_file()
|
||||
.await
|
||||
.inspect_err(|err| tracing::warn!(error = %err, "Failed to fetch cloud requirements"))
|
||||
.map_err(|_| FetchCloudRequirementsStatus::Request)?;
|
||||
.map_err(|err| {
|
||||
if err.is_unauthorized() {
|
||||
FetchCloudRequirementsError::Unauthorized(CloudRequirementsLoadError::new(
|
||||
err.to_string(),
|
||||
))
|
||||
} else {
|
||||
FetchCloudRequirementsError::Retryable(FetchCloudRequirementsStatus::Request)
|
||||
}
|
||||
})?;
|
||||
|
||||
let Some(contents) = response.contents else {
|
||||
tracing::info!(
|
||||
@@ -281,14 +311,12 @@ impl CloudRequirementsService {
|
||||
{
|
||||
return Ok(None);
|
||||
}
|
||||
let token_data = auth.get_token_data().ok();
|
||||
let chatgpt_user_id = token_data
|
||||
.as_ref()
|
||||
.and_then(|token_data| token_data.id_token.chatgpt_user_id.as_deref());
|
||||
let account_id = auth.get_account_id();
|
||||
let account_id = account_id.as_deref();
|
||||
let (chatgpt_user_id, account_id) = auth_identity(&auth);
|
||||
|
||||
match self.load_cache(chatgpt_user_id, account_id).await {
|
||||
match self
|
||||
.load_cache(chatgpt_user_id.as_deref(), account_id.as_deref())
|
||||
.await
|
||||
{
|
||||
Ok(signed_payload) => {
|
||||
tracing::info!(
|
||||
path = %self.cache_path.display(),
|
||||
@@ -301,28 +329,20 @@ impl CloudRequirementsService {
|
||||
}
|
||||
}
|
||||
|
||||
self.fetch_with_retries(&auth, chatgpt_user_id, account_id)
|
||||
.await
|
||||
.ok_or_else(|| {
|
||||
let message = "failed to load your workspace-managed config";
|
||||
tracing::error!(
|
||||
path = %self.cache_path.display(),
|
||||
"{message}"
|
||||
);
|
||||
CloudRequirementsLoadError::new(message)
|
||||
})
|
||||
self.fetch_with_retries(auth).await
|
||||
}
|
||||
|
||||
async fn fetch_with_retries(
|
||||
&self,
|
||||
auth: &CodexAuth,
|
||||
chatgpt_user_id: Option<&str>,
|
||||
account_id: Option<&str>,
|
||||
) -> Option<Option<ConfigRequirementsToml>> {
|
||||
for attempt in 1..=CLOUD_REQUIREMENTS_MAX_ATTEMPTS {
|
||||
let contents = match self.fetcher.fetch_requirements(auth).await {
|
||||
mut auth: CodexAuth,
|
||||
) -> Result<Option<ConfigRequirementsToml>, CloudRequirementsLoadError> {
|
||||
let mut attempt = 1;
|
||||
let mut auth_recovery = self.auth_manager.unauthorized_recovery();
|
||||
|
||||
while attempt <= CLOUD_REQUIREMENTS_MAX_ATTEMPTS {
|
||||
let contents = match self.fetcher.fetch_requirements(&auth).await {
|
||||
Ok(contents) => contents,
|
||||
Err(status) => {
|
||||
Err(FetchCloudRequirementsError::Retryable(status)) => {
|
||||
if attempt < CLOUD_REQUIREMENTS_MAX_ATTEMPTS {
|
||||
tracing::warn!(
|
||||
status = ?status,
|
||||
@@ -332,8 +352,60 @@ impl CloudRequirementsService {
|
||||
);
|
||||
sleep(backoff(attempt as u64)).await;
|
||||
}
|
||||
attempt += 1;
|
||||
continue;
|
||||
}
|
||||
Err(FetchCloudRequirementsError::Unauthorized(err)) => {
|
||||
if auth_recovery.has_next() {
|
||||
tracing::warn!(
|
||||
attempt,
|
||||
max_attempts = CLOUD_REQUIREMENTS_MAX_ATTEMPTS,
|
||||
"Cloud requirements request was unauthorized; attempting auth recovery"
|
||||
);
|
||||
match auth_recovery.next().await {
|
||||
Ok(()) => {
|
||||
let Some(refreshed_auth) = self.auth_manager.auth().await else {
|
||||
tracing::error!(
|
||||
"Auth recovery succeeded but no auth is available for cloud requirements"
|
||||
);
|
||||
return Err(CloudRequirementsLoadError::new(
|
||||
CLOUD_REQUIREMENTS_LOAD_FAILED_MESSAGE,
|
||||
));
|
||||
};
|
||||
auth = refreshed_auth;
|
||||
continue;
|
||||
}
|
||||
Err(RefreshTokenError::Permanent(failed)) => {
|
||||
tracing::warn!(
|
||||
error = %failed,
|
||||
"Failed to recover from unauthorized cloud requirements request"
|
||||
);
|
||||
return Err(CloudRequirementsLoadError::new(failed.message));
|
||||
}
|
||||
Err(RefreshTokenError::Transient(recovery_err)) => {
|
||||
if attempt < CLOUD_REQUIREMENTS_MAX_ATTEMPTS {
|
||||
tracing::warn!(
|
||||
error = %recovery_err,
|
||||
attempt,
|
||||
max_attempts = CLOUD_REQUIREMENTS_MAX_ATTEMPTS,
|
||||
"Failed to recover from unauthorized cloud requirements request; retrying"
|
||||
);
|
||||
sleep(backoff(attempt as u64)).await;
|
||||
}
|
||||
attempt += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::warn!(
|
||||
error = %err,
|
||||
"Cloud requirements request was unauthorized and no auth recovery is available"
|
||||
);
|
||||
return Err(CloudRequirementsLoadError::new(
|
||||
CLOUD_REQUIREMENTS_LOAD_FAILED_MESSAGE,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let requirements = match contents.as_deref() {
|
||||
@@ -341,27 +413,29 @@ impl CloudRequirementsService {
|
||||
Ok(requirements) => requirements,
|
||||
Err(err) => {
|
||||
tracing::error!(error = %err, "Failed to parse cloud requirements");
|
||||
return None;
|
||||
return Err(CloudRequirementsLoadError::new(
|
||||
CLOUD_REQUIREMENTS_LOAD_FAILED_MESSAGE,
|
||||
));
|
||||
}
|
||||
},
|
||||
None => None,
|
||||
};
|
||||
|
||||
if let Err(err) = self
|
||||
.save_cache(
|
||||
chatgpt_user_id.map(str::to_owned),
|
||||
account_id.map(str::to_owned),
|
||||
contents,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let (chatgpt_user_id, account_id) = auth_identity(&auth);
|
||||
if let Err(err) = self.save_cache(chatgpt_user_id, account_id, contents).await {
|
||||
tracing::warn!(error = %err, "Failed to write cloud requirements cache");
|
||||
}
|
||||
|
||||
return Some(requirements);
|
||||
return Ok(requirements);
|
||||
}
|
||||
|
||||
None
|
||||
tracing::error!(
|
||||
path = %self.cache_path.display(),
|
||||
"{CLOUD_REQUIREMENTS_LOAD_FAILED_MESSAGE}"
|
||||
);
|
||||
Err(CloudRequirementsLoadError::new(
|
||||
CLOUD_REQUIREMENTS_LOAD_FAILED_MESSAGE,
|
||||
))
|
||||
}
|
||||
|
||||
async fn refresh_cache_in_background(&self) {
|
||||
@@ -392,20 +466,10 @@ impl CloudRequirementsService {
|
||||
return false;
|
||||
}
|
||||
|
||||
let token_data = auth.get_token_data().ok();
|
||||
let chatgpt_user_id = token_data
|
||||
.as_ref()
|
||||
.and_then(|token_data| token_data.id_token.chatgpt_user_id.as_deref());
|
||||
let account_id = auth.get_account_id();
|
||||
let account_id = account_id.as_deref();
|
||||
|
||||
if self
|
||||
.fetch_with_retries(&auth, chatgpt_user_id, account_id)
|
||||
.await
|
||||
.is_none()
|
||||
{
|
||||
if let Err(err) = self.fetch_with_retries(auth).await {
|
||||
tracing::error!(
|
||||
path = %self.cache_path.display(),
|
||||
error = %err,
|
||||
"Failed to refresh cloud requirements cache from remote"
|
||||
);
|
||||
if let Some(metrics) = codex_otel::metrics::global() {
|
||||
@@ -594,6 +658,7 @@ mod tests {
|
||||
use std::path::Path;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tempfile::TempDir;
|
||||
use tempfile::tempdir;
|
||||
|
||||
fn write_auth_json(codex_home: &Path, value: serde_json::Value) -> std::io::Result<()> {
|
||||
@@ -622,6 +687,49 @@ mod tests {
|
||||
account_id: Option<&str>,
|
||||
) -> Arc<AuthManager> {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
write_auth_json(
|
||||
tmp.path(),
|
||||
chatgpt_auth_json(
|
||||
plan_type,
|
||||
chatgpt_user_id,
|
||||
account_id,
|
||||
"test-access-token",
|
||||
"test-refresh-token",
|
||||
),
|
||||
)
|
||||
.expect("write auth");
|
||||
Arc::new(AuthManager::new(
|
||||
tmp.path().to_path_buf(),
|
||||
false,
|
||||
AuthCredentialsStoreMode::File,
|
||||
))
|
||||
}
|
||||
|
||||
fn chatgpt_auth_json(
|
||||
plan_type: &str,
|
||||
chatgpt_user_id: Option<&str>,
|
||||
account_id: Option<&str>,
|
||||
access_token: &str,
|
||||
refresh_token: &str,
|
||||
) -> serde_json::Value {
|
||||
chatgpt_auth_json_with_mode(
|
||||
plan_type,
|
||||
chatgpt_user_id,
|
||||
account_id,
|
||||
access_token,
|
||||
refresh_token,
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
fn chatgpt_auth_json_with_mode(
|
||||
plan_type: &str,
|
||||
chatgpt_user_id: Option<&str>,
|
||||
account_id: Option<&str>,
|
||||
access_token: &str,
|
||||
refresh_token: &str,
|
||||
auth_mode: Option<&str>,
|
||||
) -> serde_json::Value {
|
||||
let header = json!({ "alg": "none", "typ": "JWT" });
|
||||
let auth_payload = json!({
|
||||
"chatgpt_plan_type": plan_type,
|
||||
@@ -637,22 +745,54 @@ mod tests {
|
||||
let signature_b64 = URL_SAFE_NO_PAD.encode(b"sig");
|
||||
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
|
||||
|
||||
let auth_json = json!({
|
||||
let mut auth_json = json!({
|
||||
"OPENAI_API_KEY": null,
|
||||
"tokens": {
|
||||
"id_token": fake_jwt,
|
||||
"access_token": "test-access-token",
|
||||
"refresh_token": "test-refresh-token",
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"account_id": account_id,
|
||||
},
|
||||
"last_refresh": "2025-01-01T00:00:00Z",
|
||||
});
|
||||
write_auth_json(tmp.path(), auth_json).expect("write auth");
|
||||
Arc::new(AuthManager::new(
|
||||
tmp.path().to_path_buf(),
|
||||
false,
|
||||
AuthCredentialsStoreMode::File,
|
||||
))
|
||||
if let Some(auth_mode) = auth_mode {
|
||||
auth_json["auth_mode"] = serde_json::Value::String(auth_mode.to_string());
|
||||
}
|
||||
auth_json
|
||||
}
|
||||
|
||||
struct ManagedAuthContext {
|
||||
_home: TempDir,
|
||||
manager: Arc<AuthManager>,
|
||||
}
|
||||
|
||||
fn managed_auth_context(
|
||||
plan_type: &str,
|
||||
chatgpt_user_id: Option<&str>,
|
||||
account_id: Option<&str>,
|
||||
access_token: &str,
|
||||
refresh_token: &str,
|
||||
) -> ManagedAuthContext {
|
||||
let home = tempdir().expect("tempdir");
|
||||
write_auth_json(
|
||||
home.path(),
|
||||
chatgpt_auth_json(
|
||||
plan_type,
|
||||
chatgpt_user_id,
|
||||
account_id,
|
||||
access_token,
|
||||
refresh_token,
|
||||
),
|
||||
)
|
||||
.expect("write auth");
|
||||
ManagedAuthContext {
|
||||
manager: Arc::new(AuthManager::new(
|
||||
home.path().to_path_buf(),
|
||||
false,
|
||||
AuthCredentialsStoreMode::File,
|
||||
)),
|
||||
_home: home,
|
||||
}
|
||||
}
|
||||
|
||||
fn auth_manager_with_plan(plan_type: &str) -> Arc<AuthManager> {
|
||||
@@ -663,6 +803,10 @@ mod tests {
|
||||
contents.and_then(|contents| parse_cloud_requirements(contents).ok().flatten())
|
||||
}
|
||||
|
||||
fn request_error() -> FetchCloudRequirementsError {
|
||||
FetchCloudRequirementsError::Retryable(FetchCloudRequirementsStatus::Request)
|
||||
}
|
||||
|
||||
struct StaticFetcher {
|
||||
contents: Option<String>,
|
||||
}
|
||||
@@ -672,7 +816,7 @@ mod tests {
|
||||
async fn fetch_requirements(
|
||||
&self,
|
||||
_auth: &CodexAuth,
|
||||
) -> Result<Option<String>, FetchCloudRequirementsStatus> {
|
||||
) -> Result<Option<String>, FetchCloudRequirementsError> {
|
||||
Ok(self.contents.clone())
|
||||
}
|
||||
}
|
||||
@@ -684,7 +828,7 @@ mod tests {
|
||||
async fn fetch_requirements(
|
||||
&self,
|
||||
_auth: &CodexAuth,
|
||||
) -> Result<Option<String>, FetchCloudRequirementsStatus> {
|
||||
) -> Result<Option<String>, FetchCloudRequirementsError> {
|
||||
pending::<()>().await;
|
||||
Ok(None)
|
||||
}
|
||||
@@ -692,12 +836,12 @@ mod tests {
|
||||
|
||||
struct SequenceFetcher {
|
||||
responses:
|
||||
tokio::sync::Mutex<VecDeque<Result<Option<String>, FetchCloudRequirementsStatus>>>,
|
||||
tokio::sync::Mutex<VecDeque<Result<Option<String>, FetchCloudRequirementsError>>>,
|
||||
request_count: AtomicUsize,
|
||||
}
|
||||
|
||||
impl SequenceFetcher {
|
||||
fn new(responses: Vec<Result<Option<String>, FetchCloudRequirementsStatus>>) -> Self {
|
||||
fn new(responses: Vec<Result<Option<String>, FetchCloudRequirementsError>>) -> Self {
|
||||
Self {
|
||||
responses: tokio::sync::Mutex::new(VecDeque::from(responses)),
|
||||
request_count: AtomicUsize::new(0),
|
||||
@@ -710,13 +854,57 @@ mod tests {
|
||||
async fn fetch_requirements(
|
||||
&self,
|
||||
_auth: &CodexAuth,
|
||||
) -> Result<Option<String>, FetchCloudRequirementsStatus> {
|
||||
) -> Result<Option<String>, FetchCloudRequirementsError> {
|
||||
self.request_count.fetch_add(1, Ordering::SeqCst);
|
||||
let mut responses = self.responses.lock().await;
|
||||
responses.pop_front().unwrap_or(Ok(None))
|
||||
}
|
||||
}
|
||||
|
||||
struct TokenFetcher {
|
||||
expected_token: String,
|
||||
contents: String,
|
||||
request_count: AtomicUsize,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl RequirementsFetcher for TokenFetcher {
|
||||
async fn fetch_requirements(
|
||||
&self,
|
||||
auth: &CodexAuth,
|
||||
) -> Result<Option<String>, FetchCloudRequirementsError> {
|
||||
self.request_count.fetch_add(1, Ordering::SeqCst);
|
||||
if matches!(
|
||||
auth.get_token().as_deref(),
|
||||
Ok(token) if token == self.expected_token.as_str()
|
||||
) {
|
||||
Ok(Some(self.contents.clone()))
|
||||
} else {
|
||||
Err(FetchCloudRequirementsError::Unauthorized(
|
||||
CloudRequirementsLoadError::new("GET /config/requirements failed: 401"),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct UnauthorizedFetcher {
|
||||
message: String,
|
||||
request_count: AtomicUsize,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl RequirementsFetcher for UnauthorizedFetcher {
|
||||
async fn fetch_requirements(
|
||||
&self,
|
||||
_auth: &CodexAuth,
|
||||
) -> Result<Option<String>, FetchCloudRequirementsError> {
|
||||
self.request_count.fetch_add(1, Ordering::SeqCst);
|
||||
Err(FetchCloudRequirementsError::Unauthorized(
|
||||
CloudRequirementsLoadError::new(self.message.clone()),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_cloud_requirements_skips_non_chatgpt_auth() {
|
||||
let auth_manager = auth_manager_with_api_key();
|
||||
@@ -837,7 +1025,7 @@ mod tests {
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn fetch_cloud_requirements_retries_until_success() {
|
||||
let fetcher = Arc::new(SequenceFetcher::new(vec![
|
||||
Err(FetchCloudRequirementsStatus::Request),
|
||||
Err(request_error()),
|
||||
Ok(Some("allowed_approval_policies = [\"never\"]".to_string())),
|
||||
]));
|
||||
let codex_home = tempdir().expect("tempdir");
|
||||
@@ -868,6 +1056,206 @@ mod tests {
|
||||
assert_eq!(fetcher.request_count.load(Ordering::SeqCst), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_cloud_requirements_recovers_after_unauthorized_reload() {
|
||||
let auth = managed_auth_context(
|
||||
"business",
|
||||
Some("user-12345"),
|
||||
Some("account-12345"),
|
||||
"stale-access-token",
|
||||
"test-refresh-token",
|
||||
);
|
||||
write_auth_json(
|
||||
auth._home.path(),
|
||||
chatgpt_auth_json(
|
||||
"business",
|
||||
Some("user-12345"),
|
||||
Some("account-12345"),
|
||||
"fresh-access-token",
|
||||
"test-refresh-token",
|
||||
),
|
||||
)
|
||||
.expect("write refreshed auth");
|
||||
|
||||
let fetcher = Arc::new(TokenFetcher {
|
||||
expected_token: "fresh-access-token".to_string(),
|
||||
contents: "allowed_approval_policies = [\"never\"]".to_string(),
|
||||
request_count: AtomicUsize::new(0),
|
||||
});
|
||||
let codex_home = tempdir().expect("tempdir");
|
||||
let service = CloudRequirementsService::new(
|
||||
Arc::clone(&auth.manager),
|
||||
fetcher.clone(),
|
||||
codex_home.path().to_path_buf(),
|
||||
CLOUD_REQUIREMENTS_TIMEOUT,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
service.fetch().await,
|
||||
Ok(Some(ConfigRequirementsToml {
|
||||
allowed_approval_policies: Some(vec![AskForApproval::Never]),
|
||||
allowed_sandbox_modes: None,
|
||||
allowed_web_search_modes: None,
|
||||
feature_requirements: None,
|
||||
mcp_servers: None,
|
||||
rules: None,
|
||||
enforce_residency: None,
|
||||
network: None,
|
||||
}))
|
||||
);
|
||||
assert_eq!(fetcher.request_count.load(Ordering::SeqCst), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_cloud_requirements_recovers_after_unauthorized_reload_updates_cache_identity() {
|
||||
let auth = managed_auth_context(
|
||||
"business",
|
||||
Some("user-12345"),
|
||||
Some("account-12345"),
|
||||
"stale-access-token",
|
||||
"test-refresh-token",
|
||||
);
|
||||
write_auth_json(
|
||||
auth._home.path(),
|
||||
chatgpt_auth_json(
|
||||
"business",
|
||||
Some("user-99999"),
|
||||
Some("account-12345"),
|
||||
"fresh-access-token",
|
||||
"test-refresh-token",
|
||||
),
|
||||
)
|
||||
.expect("write refreshed auth");
|
||||
|
||||
let fetcher = Arc::new(TokenFetcher {
|
||||
expected_token: "fresh-access-token".to_string(),
|
||||
contents: "allowed_approval_policies = [\"never\"]".to_string(),
|
||||
request_count: AtomicUsize::new(0),
|
||||
});
|
||||
let codex_home = tempdir().expect("tempdir");
|
||||
let service = CloudRequirementsService::new(
|
||||
Arc::clone(&auth.manager),
|
||||
fetcher.clone(),
|
||||
codex_home.path().to_path_buf(),
|
||||
CLOUD_REQUIREMENTS_TIMEOUT,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
service.fetch().await,
|
||||
Ok(Some(ConfigRequirementsToml {
|
||||
allowed_approval_policies: Some(vec![AskForApproval::Never]),
|
||||
allowed_sandbox_modes: None,
|
||||
allowed_web_search_modes: None,
|
||||
feature_requirements: None,
|
||||
mcp_servers: None,
|
||||
rules: None,
|
||||
enforce_residency: None,
|
||||
network: None,
|
||||
}))
|
||||
);
|
||||
|
||||
let path = codex_home.path().join(CLOUD_REQUIREMENTS_CACHE_FILENAME);
|
||||
let cache_file: CloudRequirementsCacheFile =
|
||||
serde_json::from_str(&std::fs::read_to_string(path).expect("read cache"))
|
||||
.expect("parse cache");
|
||||
assert_eq!(
|
||||
cache_file.signed_payload.chatgpt_user_id,
|
||||
Some("user-99999".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
cache_file.signed_payload.account_id,
|
||||
Some("account-12345".to_string())
|
||||
);
|
||||
assert_eq!(fetcher.request_count.load(Ordering::SeqCst), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_cloud_requirements_surfaces_auth_recovery_message() {
|
||||
let auth = managed_auth_context(
|
||||
"enterprise",
|
||||
Some("user-12345"),
|
||||
Some("account-12345"),
|
||||
"stale-access-token",
|
||||
"test-refresh-token",
|
||||
);
|
||||
write_auth_json(
|
||||
auth._home.path(),
|
||||
chatgpt_auth_json(
|
||||
"enterprise",
|
||||
Some("user-12345"),
|
||||
Some("account-99999"),
|
||||
"fresh-access-token",
|
||||
"test-refresh-token",
|
||||
),
|
||||
)
|
||||
.expect("write mismatched auth");
|
||||
|
||||
let fetcher = Arc::new(UnauthorizedFetcher {
|
||||
message: "GET /config/requirements failed: 401".to_string(),
|
||||
request_count: AtomicUsize::new(0),
|
||||
});
|
||||
let codex_home = tempdir().expect("tempdir");
|
||||
let service = CloudRequirementsService::new(
|
||||
Arc::clone(&auth.manager),
|
||||
fetcher.clone(),
|
||||
codex_home.path().to_path_buf(),
|
||||
CLOUD_REQUIREMENTS_TIMEOUT,
|
||||
);
|
||||
|
||||
let err = service
|
||||
.fetch()
|
||||
.await
|
||||
.expect_err("cloud requirements should surface auth recovery errors");
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"Your access token could not be refreshed because you have since logged out or signed in to another account. Please sign in again."
|
||||
);
|
||||
assert_eq!(fetcher.request_count.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_cloud_requirements_unauthorized_without_recovery_uses_generic_message() {
|
||||
let auth_home = tempdir().expect("tempdir");
|
||||
write_auth_json(
|
||||
auth_home.path(),
|
||||
chatgpt_auth_json_with_mode(
|
||||
"enterprise",
|
||||
Some("user-12345"),
|
||||
Some("account-12345"),
|
||||
"test-access-token",
|
||||
"test-refresh-token",
|
||||
Some("chatgptAuthTokens"),
|
||||
),
|
||||
)
|
||||
.expect("write auth");
|
||||
let auth_manager = Arc::new(AuthManager::new(
|
||||
auth_home.path().to_path_buf(),
|
||||
false,
|
||||
AuthCredentialsStoreMode::File,
|
||||
));
|
||||
|
||||
let fetcher = Arc::new(UnauthorizedFetcher {
|
||||
message:
|
||||
"GET https://chatgpt.com/backend-api/wham/config/requirements failed: 401; content-type=text/html; body=<html>nope</html>"
|
||||
.to_string(),
|
||||
request_count: AtomicUsize::new(0),
|
||||
});
|
||||
let codex_home = tempdir().expect("tempdir");
|
||||
let service = CloudRequirementsService::new(
|
||||
auth_manager,
|
||||
fetcher.clone(),
|
||||
codex_home.path().to_path_buf(),
|
||||
CLOUD_REQUIREMENTS_TIMEOUT,
|
||||
);
|
||||
|
||||
let err = service
|
||||
.fetch()
|
||||
.await
|
||||
.expect_err("cloud requirements should fail closed");
|
||||
assert_eq!(err.to_string(), CLOUD_REQUIREMENTS_LOAD_FAILED_MESSAGE);
|
||||
assert_eq!(fetcher.request_count.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_cloud_requirements_parse_error_does_not_retry() {
|
||||
let fetcher = Arc::new(SequenceFetcher::new(vec![
|
||||
@@ -899,9 +1287,7 @@ mod tests {
|
||||
);
|
||||
let _ = prime_service.fetch().await;
|
||||
|
||||
let fetcher = Arc::new(SequenceFetcher::new(vec![Err(
|
||||
FetchCloudRequirementsStatus::Request,
|
||||
)]));
|
||||
let fetcher = Arc::new(SequenceFetcher::new(vec![Err(request_error())]));
|
||||
let service = CloudRequirementsService::new(
|
||||
auth_manager_with_plan("business"),
|
||||
fetcher.clone(),
|
||||
@@ -1209,10 +1595,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_cloud_requirements_none_is_success_without_retry() {
|
||||
let fetcher = Arc::new(SequenceFetcher::new(vec![
|
||||
Ok(None),
|
||||
Err(FetchCloudRequirementsStatus::Request),
|
||||
]));
|
||||
let fetcher = Arc::new(SequenceFetcher::new(vec![Ok(None), Err(request_error())]));
|
||||
let codex_home = tempdir().expect("tempdir");
|
||||
let service = CloudRequirementsService::new(
|
||||
auth_manager_with_plan("enterprise"),
|
||||
@@ -1228,9 +1611,7 @@ mod tests {
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn fetch_cloud_requirements_stops_after_max_retries() {
|
||||
let fetcher = Arc::new(SequenceFetcher::new(vec![
|
||||
Err(
|
||||
FetchCloudRequirementsStatus::Request
|
||||
);
|
||||
Err(request_error());
|
||||
CLOUD_REQUIREMENTS_MAX_ATTEMPTS
|
||||
]));
|
||||
let codex_home = tempdir().expect("tempdir");
|
||||
|
||||
Reference in New Issue
Block a user