Compare commits

...

1 Commits

Author SHA1 Message Date
gt-oai
e04cfc3601 Make cloud requirements load fail-closed 2026-01-31 00:28:03 +00:00
6 changed files with 221 additions and 108 deletions

2
codex-rs/Cargo.lock generated
View File

@@ -1297,6 +1297,7 @@ dependencies = [
name = "codex-cloud-requirements"
version = "0.0.0"
dependencies = [
"anyhow",
"async-trait",
"base64",
"codex-backend-client",
@@ -1306,6 +1307,7 @@ dependencies = [
"pretty_assertions",
"serde_json",
"tempfile",
"thiserror 2.0.17",
"tokio",
"toml 0.9.5",
"tracing",

View File

@@ -14,10 +14,12 @@ codex-core = { workspace = true }
codex-otel = { workspace = true }
codex-protocol = { workspace = true }
tokio = { workspace = true, features = ["sync", "time"] }
thiserror = { workspace = true }
toml = { workspace = true }
tracing = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
base64 = { workspace = true }
pretty_assertions = { workspace = true }
serde_json = { workspace = true }

View File

@@ -3,9 +3,7 @@
//! This crate fetches `requirements.toml` data from the backend as an alternative to loading it
//! from the local filesystem. It only applies to Enterprise ChatGPT customers.
//!
//! Today, fetching is best-effort: on error or timeout, Codex continues without cloud requirements.
//! We expect to tighten this so that Enterprise ChatGPT customers must successfully fetch these
//! requirements before Codex will run.
//! Enterprise ChatGPT customers must successfully fetch these requirements before Codex will run.
use async_trait::async_trait;
use codex_backend_client::Client as BackendClient;
@@ -14,21 +12,73 @@ use codex_core::auth::CodexAuth;
use codex_core::config_loader::CloudRequirementsLoader;
use codex_core::config_loader::ConfigRequirementsToml;
use codex_protocol::account::PlanType;
use std::io;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use thiserror::Error;
use tokio::time::timeout;
/// This blocks Codex startup, so it must be short.
const CLOUD_REQUIREMENTS_TIMEOUT: Duration = Duration::from_secs(5);
/// Top-level error for loading cloud requirements.
///
/// Every failure is sorted into one of two categories, each wrapping a more specific error
/// type. The `Display` output prefixes the inner error's message with the category name.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
enum CloudRequirementsError {
/// The requirements content itself was unusable (see [`CloudRequirementsUserError`]).
#[error("cloud requirements user error: {0}")]
User(CloudRequirementsUserError),
/// The requirements could not be obtained (see [`CloudRequirementsNetworkError`]).
#[error("cloud requirements network error: {0}")]
Network(CloudRequirementsNetworkError),
}
impl From<CloudRequirementsUserError> for CloudRequirementsError {
fn from(err: CloudRequirementsUserError) -> Self {
CloudRequirementsError::User(err)
}
}
impl From<CloudRequirementsNetworkError> for CloudRequirementsError {
fn from(err: CloudRequirementsNetworkError) -> Self {
CloudRequirementsError::Network(err)
}
}
/// Converts a cloud requirements error into an [`io::Error`], keeping the original error as the
/// payload.
///
/// Kind mapping:
/// - any user error -> [`io::ErrorKind::InvalidData`]
/// - network timeout -> [`io::ErrorKind::TimedOut`]
/// - every other network error -> [`io::ErrorKind::Other`]
impl From<CloudRequirementsError> for io::Error {
    fn from(source: CloudRequirementsError) -> Self {
        let kind = match &source {
            CloudRequirementsError::User(_) => io::ErrorKind::InvalidData,
            CloudRequirementsError::Network(network) => match network {
                CloudRequirementsNetworkError::Timeout { .. } => io::ErrorKind::TimedOut,
                CloudRequirementsNetworkError::BackendClient { .. }
                | CloudRequirementsNetworkError::Request { .. }
                | CloudRequirementsNetworkError::MissingContents
                | CloudRequirementsNetworkError::Task { .. } => io::ErrorKind::Other,
            },
        };
        io::Error::new(kind, source)
    }
}
/// Errors caused by the content of the fetched requirements rather than by obtaining it.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
enum CloudRequirementsUserError {
/// The requirements text could not be deserialized as TOML; `message` holds the parser's
/// error text.
#[error("failed to parse requirements TOML: {message}")]
InvalidToml { message: String },
}
/// Errors raised while obtaining the requirements from the backend.
///
/// The `message` fields carry the stringified underlying error.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
enum CloudRequirementsNetworkError {
/// Constructing the backend client from the current auth failed.
#[error("backend client initialization failed: {message}")]
BackendClient { message: String },
/// The request for the requirements file returned an error.
#[error("request failed: {message}")]
Request { message: String },
/// The backend responded, but the response carried no contents.
#[error("cloud requirements response missing contents")]
MissingContents,
/// The fetch did not finish within the configured timeout; `timeout_ms` is that timeout in
/// milliseconds.
#[error("timed out after {timeout_ms}ms")]
Timeout { timeout_ms: u64 },
/// Awaiting the spawned fetch task itself returned an error.
#[error("cloud requirements task failed: {message}")]
Task { message: String },
}
#[async_trait]
trait RequirementsFetcher: Send + Sync {
/// Returns requirements as a TOML string.
///
/// TODO(gt): For now, returns an Option. But when we want to make this fail-closed, return a
/// Result.
async fn fetch_requirements(&self, auth: &CodexAuth) -> Option<String>;
async fn fetch_requirements(&self, auth: &CodexAuth) -> Result<String, CloudRequirementsError>;
}
struct BackendRequirementsFetcher {
@@ -43,7 +93,7 @@ impl BackendRequirementsFetcher {
#[async_trait]
impl RequirementsFetcher for BackendRequirementsFetcher {
async fn fetch_requirements(&self, auth: &CodexAuth) -> Option<String> {
async fn fetch_requirements(&self, auth: &CodexAuth) -> Result<String, CloudRequirementsError> {
let client = BackendClient::from_auth(self.base_url.clone(), auth)
.inspect_err(|err| {
tracing::warn!(
@@ -51,20 +101,28 @@ impl RequirementsFetcher for BackendRequirementsFetcher {
"Failed to construct backend client for cloud requirements"
);
})
.ok()?;
.map_err(|err| CloudRequirementsNetworkError::BackendClient {
message: err.to_string(),
})
.map_err(CloudRequirementsError::from)?;
let response = client
.get_config_requirements_file()
.await
.inspect_err(|err| tracing::warn!(error = %err, "Failed to fetch cloud requirements"))
.ok()?;
.map_err(|err| CloudRequirementsNetworkError::Request {
message: err.to_string(),
})
.map_err(CloudRequirementsError::from)?;
let Some(contents) = response.contents else {
tracing::warn!("Cloud requirements response missing contents");
return None;
return Err(CloudRequirementsError::from(
CloudRequirementsNetworkError::MissingContents,
));
};
Some(contents)
Ok(contents)
}
}
@@ -87,47 +145,53 @@ impl CloudRequirementsService {
}
}
async fn fetch_with_timeout(&self) -> Option<ConfigRequirementsToml> {
async fn fetch_with_timeout(
&self,
) -> Result<Option<ConfigRequirementsToml>, CloudRequirementsError> {
let _timer =
codex_otel::start_global_timer("codex.cloud_requirements.fetch.duration_ms", &[]);
let started_at = Instant::now();
let result = timeout(self.timeout, self.fetch())
.await
.inspect_err(|_| {
tracing::warn!("Timed out waiting for cloud requirements; continuing without them");
})
.ok()?;
let result = timeout(self.timeout, self.fetch()).await.map_err(|_| {
CloudRequirementsNetworkError::Timeout {
timeout_ms: self.timeout.as_millis() as u64,
}
})?;
match result.as_ref() {
Some(requirements) => {
Ok(Some(requirements)) => {
tracing::info!(
elapsed_ms = started_at.elapsed().as_millis(),
requirements = ?requirements,
"Cloud requirements load completed"
);
}
None => {
Ok(None) => {
tracing::info!(
elapsed_ms = started_at.elapsed().as_millis(),
"Cloud requirements load completed (none)"
);
}
Err(err) => {
tracing::warn!(error = %err, "Cloud requirements load failed");
}
}
result
}
async fn fetch(&self) -> Option<ConfigRequirementsToml> {
let auth = self.auth_manager.auth().await?;
async fn fetch(&self) -> Result<Option<ConfigRequirementsToml>, CloudRequirementsError> {
let auth = match self.auth_manager.auth().await {
Some(auth) => auth,
None => return Ok(None),
};
if !(auth.is_chatgpt_auth() && auth.account_plan_type() == Some(PlanType::Enterprise)) {
return None;
return Ok(None);
}
let contents = self.fetcher.fetch_requirements(&auth).await?;
parse_cloud_requirements(&contents)
.inspect_err(|err| tracing::warn!(error = %err, "Failed to parse cloud requirements"))
.ok()
.flatten()
.map_err(CloudRequirementsError::from)
}
}
@@ -143,20 +207,28 @@ pub fn cloud_requirements_loader(
let task = tokio::spawn(async move { service.fetch_with_timeout().await });
CloudRequirementsLoader::new(async move {
task.await
.map_err(|err| {
CloudRequirementsError::from(CloudRequirementsNetworkError::Task {
message: err.to_string(),
})
})
.and_then(std::convert::identity)
.map_err(io::Error::from)
.inspect_err(|err| tracing::warn!(error = %err, "Cloud requirements task failed"))
.ok()
.flatten()
})
}
fn parse_cloud_requirements(
contents: &str,
) -> Result<Option<ConfigRequirementsToml>, toml::de::Error> {
) -> Result<Option<ConfigRequirementsToml>, CloudRequirementsUserError> {
if contents.trim().is_empty() {
return Ok(None);
}
let requirements: ConfigRequirementsToml = toml::from_str(contents)?;
let requirements: ConfigRequirementsToml =
toml::from_str(contents).map_err(|err| CloudRequirementsUserError::InvalidToml {
message: err.to_string(),
})?;
if requirements.is_empty() {
Ok(None)
} else {
@@ -167,6 +239,7 @@ fn parse_cloud_requirements(
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use codex_core::auth::AuthCredentialsStoreMode;
@@ -177,28 +250,28 @@ mod tests {
use std::path::Path;
use tempfile::tempdir;
fn write_auth_json(codex_home: &Path, value: serde_json::Value) -> std::io::Result<()> {
fn write_auth_json(codex_home: &Path, value: serde_json::Value) -> Result<()> {
std::fs::write(codex_home.join("auth.json"), serde_json::to_string(&value)?)?;
Ok(())
}
fn auth_manager_with_api_key() -> Arc<AuthManager> {
let tmp = tempdir().expect("tempdir");
fn auth_manager_with_api_key() -> Result<Arc<AuthManager>> {
let tmp = tempdir()?;
let auth_json = json!({
"OPENAI_API_KEY": "sk-test-key",
"tokens": null,
"last_refresh": null,
});
write_auth_json(tmp.path(), auth_json).expect("write auth");
Arc::new(AuthManager::new(
write_auth_json(tmp.path(), auth_json)?;
Ok(Arc::new(AuthManager::new(
tmp.path().to_path_buf(),
false,
AuthCredentialsStoreMode::File,
))
)))
}
fn auth_manager_with_plan(plan_type: &str) -> Arc<AuthManager> {
let tmp = tempdir().expect("tempdir");
fn auth_manager_with_plan(plan_type: &str) -> Result<Arc<AuthManager>> {
let tmp = tempdir()?;
let header = json!({ "alg": "none", "typ": "JWT" });
let auth_payload = json!({
"chatgpt_plan_type": plan_type,
@@ -209,8 +282,8 @@ mod tests {
"email": "user@example.com",
"https://api.openai.com/auth": auth_payload,
});
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&header).expect("header"));
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&payload).expect("payload"));
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&header)?);
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&payload)?);
let signature_b64 = URL_SAFE_NO_PAD.encode(b"sig");
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
@@ -223,26 +296,31 @@ mod tests {
},
"last_refresh": null,
});
write_auth_json(tmp.path(), auth_json).expect("write auth");
Arc::new(AuthManager::new(
write_auth_json(tmp.path(), auth_json)?;
Ok(Arc::new(AuthManager::new(
tmp.path().to_path_buf(),
false,
AuthCredentialsStoreMode::File,
))
)))
}
fn parse_for_fetch(contents: Option<&str>) -> Option<ConfigRequirementsToml> {
contents.and_then(|contents| parse_cloud_requirements(contents).ok().flatten())
fn parse_for_fetch(
contents: Option<&str>,
) -> Result<Option<ConfigRequirementsToml>, CloudRequirementsUserError> {
contents.map(parse_cloud_requirements).unwrap_or(Ok(None))
}
struct StaticFetcher {
contents: Option<String>,
result: Result<String, CloudRequirementsError>,
}
#[async_trait::async_trait]
impl RequirementsFetcher for StaticFetcher {
async fn fetch_requirements(&self, _auth: &CodexAuth) -> Option<String> {
self.contents.clone()
async fn fetch_requirements(
&self,
_auth: &CodexAuth,
) -> Result<String, CloudRequirementsError> {
self.result.clone()
}
}
@@ -250,88 +328,115 @@ mod tests {
#[async_trait::async_trait]
impl RequirementsFetcher for PendingFetcher {
async fn fetch_requirements(&self, _auth: &CodexAuth) -> Option<String> {
async fn fetch_requirements(
&self,
_auth: &CodexAuth,
) -> Result<String, CloudRequirementsError> {
pending::<()>().await;
None
Ok(String::new())
}
}
#[tokio::test]
async fn fetch_cloud_requirements_skips_non_chatgpt_auth() {
let auth_manager = auth_manager_with_api_key();
async fn fetch_cloud_requirements_skips_non_chatgpt_auth() -> Result<()> {
let service = CloudRequirementsService::new(
auth_manager,
Arc::new(StaticFetcher { contents: None }),
auth_manager_with_api_key()?,
Arc::new(StaticFetcher {
result: Ok(String::new()),
}),
CLOUD_REQUIREMENTS_TIMEOUT,
);
let result = service.fetch().await;
assert!(result.is_none());
assert_eq!(service.fetch().await, Ok(None));
Ok(())
}
#[tokio::test]
async fn fetch_cloud_requirements_skips_non_enterprise_plan() {
let auth_manager = auth_manager_with_plan("pro");
async fn fetch_cloud_requirements_skips_non_enterprise_plan() -> Result<()> {
let service = CloudRequirementsService::new(
auth_manager,
Arc::new(StaticFetcher { contents: None }),
auth_manager_with_plan("pro")?,
Arc::new(StaticFetcher {
result: Ok(String::new()),
}),
CLOUD_REQUIREMENTS_TIMEOUT,
);
let result = service.fetch().await;
assert!(result.is_none());
assert_eq!(service.fetch().await, Ok(None));
Ok(())
}
#[tokio::test]
async fn fetch_cloud_requirements_handles_missing_contents() {
let result = parse_for_fetch(None);
assert!(result.is_none());
}
#[tokio::test]
async fn fetch_cloud_requirements_handles_empty_contents() {
let result = parse_for_fetch(Some(" "));
assert!(result.is_none());
}
#[tokio::test]
async fn fetch_cloud_requirements_handles_invalid_toml() {
let result = parse_for_fetch(Some("not = ["));
assert!(result.is_none());
}
#[tokio::test]
async fn fetch_cloud_requirements_ignores_empty_requirements() {
let result = parse_for_fetch(Some("# comment"));
assert!(result.is_none());
}
#[tokio::test]
async fn fetch_cloud_requirements_parses_valid_toml() {
let result = parse_for_fetch(Some("allowed_approval_policies = [\"never\"]"));
async fn fetch_cloud_requirements_returns_missing_contents_error() -> Result<()> {
let service = CloudRequirementsService::new(
auth_manager_with_plan("enterprise")?,
Arc::new(StaticFetcher {
result: Err(CloudRequirementsError::Network(
CloudRequirementsNetworkError::MissingContents,
)),
}),
CLOUD_REQUIREMENTS_TIMEOUT,
);
assert_eq!(
result,
Some(ConfigRequirementsToml {
service.fetch().await,
Err(CloudRequirementsError::Network(
CloudRequirementsNetworkError::MissingContents
))
);
Ok(())
}
/// Whitespace-only contents are treated as "no requirements" rather than as an error.
#[tokio::test]
async fn fetch_cloud_requirements_handles_empty_contents() -> Result<()> {
    let parsed = parse_for_fetch(Some(" "));
    assert_eq!(parsed, Ok(None));
    Ok(())
}
/// Malformed TOML surfaces as `CloudRequirementsUserError::InvalidToml`.
#[tokio::test]
async fn fetch_cloud_requirements_handles_invalid_toml() -> Result<()> {
    let parsed = parse_for_fetch(Some("not = ["));
    assert!(matches!(
        parsed,
        Err(CloudRequirementsUserError::InvalidToml { .. })
    ));
    Ok(())
}
/// TOML that contains only a comment yields no requirements.
#[tokio::test]
async fn fetch_cloud_requirements_ignores_empty_requirements() -> Result<()> {
    let parsed = parse_for_fetch(Some("# comment"));
    assert_eq!(parsed, Ok(None));
    Ok(())
}
#[tokio::test]
async fn fetch_cloud_requirements_parses_valid_toml() -> Result<()> {
assert_eq!(
parse_for_fetch(Some("allowed_approval_policies = [\"never\"]")),
Ok(Some(ConfigRequirementsToml {
allowed_approval_policies: Some(vec![AskForApproval::Never]),
allowed_sandbox_modes: None,
mcp_servers: None,
rules: None,
enforce_residency: None,
})
}))
);
Ok(())
}
#[tokio::test(start_paused = true)]
async fn fetch_cloud_requirements_times_out() {
let auth_manager = auth_manager_with_plan("enterprise");
async fn fetch_cloud_requirements_times_out() -> Result<()> {
let service = CloudRequirementsService::new(
auth_manager,
auth_manager_with_plan("enterprise")?,
Arc::new(PendingFetcher),
CLOUD_REQUIREMENTS_TIMEOUT,
);
let handle = tokio::spawn(async move { service.fetch_with_timeout().await });
tokio::time::advance(CLOUD_REQUIREMENTS_TIMEOUT + Duration::from_millis(1)).await;
let result = handle.await.expect("cloud requirements task");
assert!(result.is_none());
assert_eq!(
handle.await?,
Err(CloudRequirementsError::Network(
CloudRequirementsNetworkError::Timeout {
timeout_ms: CLOUD_REQUIREMENTS_TIMEOUT.as_millis() as u64,
}
))
);
Ok(())
}
}

View File

@@ -4,25 +4,29 @@ use futures::future::FutureExt;
use futures::future::Shared;
use std::fmt;
use std::future::Future;
use std::io;
use std::sync::Arc;
#[derive(Clone)]
pub struct CloudRequirementsLoader {
// TODO(gt): This should return a Result once we can fail-closed.
fut: Shared<BoxFuture<'static, Option<ConfigRequirementsToml>>>,
fut: Shared<BoxFuture<'static, Arc<io::Result<Option<ConfigRequirementsToml>>>>>,
}
impl CloudRequirementsLoader {
pub fn new<F>(fut: F) -> Self
where
F: Future<Output = Option<ConfigRequirementsToml>> + Send + 'static,
F: Future<Output = io::Result<Option<ConfigRequirementsToml>>> + Send + 'static,
{
Self {
fut: fut.boxed().shared(),
fut: fut.map(Arc::new).boxed().shared(),
}
}
pub async fn get(&self) -> Option<ConfigRequirementsToml> {
self.fut.clone().await
pub async fn get(&self) -> io::Result<Option<ConfigRequirementsToml>> {
match self.fut.clone().await.as_ref() {
Ok(requirements) => Ok(requirements.clone()),
Err(err) => Err(io::Error::new(err.kind(), err.to_string())),
}
}
}
@@ -34,7 +38,7 @@ impl fmt::Debug for CloudRequirementsLoader {
impl Default for CloudRequirementsLoader {
fn default() -> Self {
Self::new(async { None })
Self::new(async { Ok(None) })
}
}
@@ -52,11 +56,11 @@ mod tests {
let counter_clone = Arc::clone(&counter);
let loader = CloudRequirementsLoader::new(async move {
counter_clone.fetch_add(1, Ordering::SeqCst);
Some(ConfigRequirementsToml::default())
Ok(Some(ConfigRequirementsToml::default()))
});
let (first, second) = tokio::join!(loader.get(), loader.get());
assert_eq!(first, second);
assert_eq!(first.as_ref().ok(), second.as_ref().ok());
assert_eq!(counter.load(Ordering::SeqCst), 1);
}
}

View File

@@ -115,7 +115,7 @@ pub async fn load_config_layers_state(
)
.await?;
if let Some(requirements) = cloud_requirements.get().await {
if let Some(requirements) = cloud_requirements.get().await? {
config_requirements_toml
.merge_unset_fields(RequirementSource::CloudRequirements, requirements);
}

View File

@@ -545,7 +545,7 @@ async fn load_config_layers_includes_cloud_requirements() -> anyhow::Result<()>
enforce_residency: None,
};
let expected = requirements.clone();
let cloud_requirements = CloudRequirementsLoader::new(async move { Some(requirements) });
let cloud_requirements = CloudRequirementsLoader::new(async move { Ok(Some(requirements)) });
let layers = load_config_layers_state(
&codex_home,