Add a background job to refresh the requirements local cache

This commit is contained in:
alexsong-oai
2026-02-26 11:46:21 -08:00
parent eb77db2957
commit f1c3c4adc5

View File

@@ -38,7 +38,8 @@ use tokio::time::timeout;
const CLOUD_REQUIREMENTS_TIMEOUT: Duration = Duration::from_secs(15);
const CLOUD_REQUIREMENTS_MAX_ATTEMPTS: usize = 5;
const CLOUD_REQUIREMENTS_CACHE_FILENAME: &str = "cloud-requirements-cache.json";
const CLOUD_REQUIREMENTS_CACHE_TTL: Duration = Duration::from_secs(60 * 60);
const CLOUD_REQUIREMENTS_CACHE_REFRESH_INTERVAL: Duration = Duration::from_secs(5 * 60);
const CLOUD_REQUIREMENTS_CACHE_TTL: Duration = Duration::from_secs(30 * 60);
const CLOUD_REQUIREMENTS_CACHE_WRITE_HMAC_KEY: &[u8] =
b"codex-cloud-requirements-cache-v3-064f8542-75b4-494c-a294-97d3ce597271";
const CLOUD_REQUIREMENTS_CACHE_READ_HMAC_KEYS: &[&[u8]] =
@@ -49,6 +50,7 @@ type HmacSha256 = Hmac<Sha256>;
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum FetchCloudRequirementsStatus {
BackendClientInit,
Parse,
Request,
}
@@ -188,6 +190,7 @@ impl RequirementsFetcher for BackendRequirementsFetcher {
}
}
#[derive(Clone)]
struct CloudRequirementsService {
auth_manager: Arc<AuthManager>,
fetcher: Arc<dyn RequirementsFetcher>,
@@ -271,7 +274,9 @@ impl CloudRequirementsService {
}
self.fetch_with_retries(&auth, chatgpt_user_id, account_id)
.await?
.await
.ok()
.flatten()
}
async fn fetch_with_retries(
@@ -279,11 +284,13 @@ impl CloudRequirementsService {
auth: &CodexAuth,
chatgpt_user_id: Option<&str>,
account_id: Option<&str>,
) -> Option<Option<ConfigRequirementsToml>> {
) -> Result<Option<ConfigRequirementsToml>, FetchCloudRequirementsStatus> {
let mut last_status = FetchCloudRequirementsStatus::Request;
for attempt in 1..=CLOUD_REQUIREMENTS_MAX_ATTEMPTS {
let contents = match self.fetcher.fetch_requirements(auth).await {
Ok(contents) => contents,
Err(status) => {
last_status = status;
if attempt < CLOUD_REQUIREMENTS_MAX_ATTEMPTS {
tracing::warn!(
status = ?status,
@@ -302,7 +309,7 @@ impl CloudRequirementsService {
Ok(requirements) => requirements,
Err(err) => {
tracing::warn!(error = %err, "Failed to parse cloud requirements");
return None;
return Err(FetchCloudRequirementsStatus::Parse);
}
},
None => None,
@@ -319,10 +326,55 @@ impl CloudRequirementsService {
tracing::warn!(error = %err, "Failed to write cloud requirements cache");
}
return Some(requirements);
return Ok(requirements);
}
None
Err(last_status)
}
async fn refresh_cache_in_background(&self) {
loop {
sleep(CLOUD_REQUIREMENTS_CACHE_REFRESH_INTERVAL).await;
let _ = timeout(self.timeout, self.refresh_cache())
.await
.inspect_err(|_| {
tracing::warn!(
"Timed out refreshing cloud requirements cache from remote; keeping existing cache"
);
});
}
}
async fn refresh_cache(&self) {
let Some(auth) = self.auth_manager.auth().await else {
return;
};
if !auth.is_chatgpt_auth()
|| !matches!(
auth.account_plan_type(),
Some(PlanType::Business | PlanType::Enterprise)
)
{
return;
}
let token_data = auth.get_token_data().ok();
let chatgpt_user_id = token_data
.as_ref()
.and_then(|token_data| token_data.id_token.chatgpt_user_id.as_deref());
let account_id = auth.get_account_id();
let account_id = account_id.as_deref();
if let Err(status) = self
.fetch_with_retries(&auth, chatgpt_user_id, account_id)
.await
{
tracing::warn!(
status = ?status,
path = %self.cache_path.display(),
"Failed to refresh cloud requirements cache from remote"
);
}
}
async fn load_cache(
@@ -452,7 +504,9 @@ pub fn cloud_requirements_loader(
codex_home,
CLOUD_REQUIREMENTS_TIMEOUT,
);
let refresh_service = service.clone();
let task = tokio::spawn(async move { service.fetch_with_timeout().await });
tokio::spawn(async move { refresh_service.refresh_cache_in_background().await });
CloudRequirementsLoader::new(async move {
task.await
.inspect_err(|err| tracing::warn!(error = %err, "Cloud requirements task failed"))
@@ -1052,7 +1106,11 @@ mod tests {
let cache_file: CloudRequirementsCacheFile =
serde_json::from_str(&std::fs::read_to_string(path).expect("read cache"))
.expect("parse cache");
assert!(cache_file.signed_payload.expires_at > Utc::now());
assert!(
cache_file.signed_payload.expires_at
<= cache_file.signed_payload.cached_at + ChronoDuration::minutes(30)
);
assert!(cache_file.signed_payload.expires_at > cache_file.signed_payload.cached_at);
assert!(cache_file.signed_payload.cached_at <= Utc::now());
assert_eq!(
cache_file.signed_payload.chatgpt_user_id,
@@ -1130,4 +1188,57 @@ mod tests {
CLOUD_REQUIREMENTS_MAX_ATTEMPTS
);
}
#[tokio::test]
async fn refresh_from_remote_updates_cached_cloud_requirements() {
let codex_home = tempdir().expect("tempdir");
let fetcher = Arc::new(SequenceFetcher::new(vec![
Ok(Some("allowed_approval_policies = [\"never\"]".to_string())),
Ok(Some(
"allowed_approval_policies = [\"on-request\"]".to_string(),
)),
]));
let service = CloudRequirementsService::new(
auth_manager_with_plan("business"),
fetcher,
codex_home.path().to_path_buf(),
CLOUD_REQUIREMENTS_TIMEOUT,
);
assert_eq!(
service.fetch().await,
Some(ConfigRequirementsToml {
allowed_approval_policies: Some(vec![AskForApproval::Never]),
allowed_sandbox_modes: None,
allowed_web_search_modes: None,
mcp_servers: None,
rules: None,
enforce_residency: None,
network: None,
})
);
service.refresh_cache().await;
let path = codex_home.path().join(CLOUD_REQUIREMENTS_CACHE_FILENAME);
let cache_file: CloudRequirementsCacheFile =
serde_json::from_str(&std::fs::read_to_string(path).expect("read cache"))
.expect("parse cache");
assert_eq!(
cache_file
.signed_payload
.contents
.as_deref()
.and_then(|contents| parse_cloud_requirements(contents).ok().flatten()),
Some(ConfigRequirementsToml {
allowed_approval_policies: Some(vec![AskForApproval::OnRequest]),
allowed_sandbox_modes: None,
allowed_web_search_modes: None,
mcp_servers: None,
rules: None,
enforce_residency: None,
network: None,
})
);
}
}