mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
140 lines
4.4 KiB
Rust
140 lines
4.4 KiB
Rust
use chrono::DateTime;
|
|
use chrono::Utc;
|
|
use codex_protocol::openai_models::ModelInfo;
|
|
use serde::Deserialize;
|
|
use serde::Serialize;
|
|
use std::io;
|
|
use std::io::ErrorKind;
|
|
use std::path::PathBuf;
|
|
use std::time::Duration;
|
|
use tokio::fs;
|
|
use tracing::error;
|
|
|
|
/// Manages loading and saving of models cache to disk.
|
|
#[derive(Debug)]
|
|
pub(crate) struct ModelsCacheManager {
|
|
cache_path: PathBuf,
|
|
cache_ttl: Duration,
|
|
}
|
|
|
|
impl ModelsCacheManager {
|
|
/// Create a new cache manager with the given path and TTL.
|
|
pub(crate) fn new(cache_path: PathBuf, cache_ttl: Duration) -> Self {
|
|
Self {
|
|
cache_path,
|
|
cache_ttl,
|
|
}
|
|
}
|
|
|
|
/// Attempt to load a fresh cache entry. Returns `None` if the cache doesn't exist or is stale.
|
|
pub(crate) async fn load_fresh(&self) -> Option<CacheData> {
|
|
let cache = match self.load().await {
|
|
Ok(cache) => cache?,
|
|
Err(err) => {
|
|
error!("failed to load models cache: {err}");
|
|
return None;
|
|
}
|
|
};
|
|
if !cache.is_fresh(self.cache_ttl) {
|
|
return None;
|
|
}
|
|
Some(CacheData {
|
|
models: cache.models,
|
|
etag: cache.etag,
|
|
})
|
|
}
|
|
|
|
/// Persist the cache to disk, creating parent directories as needed.
|
|
pub(crate) async fn persist_cache(&self, models: &[ModelInfo], etag: Option<String>) {
|
|
let cache = ModelsCache {
|
|
fetched_at: Utc::now(),
|
|
etag,
|
|
models: models.to_vec(),
|
|
};
|
|
if let Err(err) = self.save_internal(&cache).await {
|
|
error!("failed to write models cache: {err}");
|
|
}
|
|
}
|
|
|
|
/// Renew the cache TTL by updating the fetched_at timestamp to the current time.
|
|
/// This rewrites the cache with a newer fetched_at value, effectively resetting the TTL.
|
|
pub(crate) async fn renew_cache_ttl(&self) -> io::Result<()> {
|
|
let mut cache = match self.load().await? {
|
|
Some(cache) => cache,
|
|
None => return Err(io::Error::new(ErrorKind::NotFound, "cache not found")),
|
|
};
|
|
cache.fetched_at = Utc::now();
|
|
self.save_internal(&cache).await
|
|
}
|
|
|
|
async fn load(&self) -> io::Result<Option<ModelsCache>> {
|
|
match fs::read(&self.cache_path).await {
|
|
Ok(contents) => {
|
|
let cache = serde_json::from_slice(&contents)
|
|
.map_err(|err| io::Error::new(ErrorKind::InvalidData, err.to_string()))?;
|
|
Ok(Some(cache))
|
|
}
|
|
Err(err) if err.kind() == ErrorKind::NotFound => Ok(None),
|
|
Err(err) => Err(err),
|
|
}
|
|
}
|
|
|
|
async fn save_internal(&self, cache: &ModelsCache) -> io::Result<()> {
|
|
if let Some(parent) = self.cache_path.parent() {
|
|
fs::create_dir_all(parent).await?;
|
|
}
|
|
let json = serde_json::to_vec_pretty(cache)
|
|
.map_err(|err| io::Error::new(ErrorKind::InvalidData, err.to_string()))?;
|
|
fs::write(&self.cache_path, json).await
|
|
}
|
|
|
|
#[cfg(test)]
|
|
/// Set the cache TTL.
|
|
pub(crate) fn set_ttl(&mut self, ttl: Duration) {
|
|
self.cache_ttl = ttl;
|
|
}
|
|
|
|
#[cfg(test)]
|
|
/// Manipulate cache file for testing. Allows setting a custom fetched_at timestamp.
|
|
pub(crate) async fn manipulate_cache_for_test<F>(&self, f: F) -> io::Result<()>
|
|
where
|
|
F: FnOnce(&mut DateTime<Utc>),
|
|
{
|
|
let mut cache = match self.load().await? {
|
|
Some(cache) => cache,
|
|
None => return Err(io::Error::new(ErrorKind::NotFound, "cache not found")),
|
|
};
|
|
f(&mut cache.fetched_at);
|
|
self.save_internal(&cache).await
|
|
}
|
|
}
|
|
|
|
/// Data returned from loading a fresh cache entry.
|
|
pub(crate) struct CacheData {
|
|
pub(crate) models: Vec<ModelInfo>,
|
|
pub(crate) etag: Option<String>,
|
|
}
|
|
|
|
/// Serialized snapshot of models and metadata cached on disk.
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
struct ModelsCache {
|
|
fetched_at: DateTime<Utc>,
|
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
etag: Option<String>,
|
|
models: Vec<ModelInfo>,
|
|
}
|
|
|
|
impl ModelsCache {
|
|
/// Returns `true` when the cache entry has not exceeded the configured TTL.
|
|
fn is_fresh(&self, ttl: Duration) -> bool {
|
|
if ttl.is_zero() {
|
|
return false;
|
|
}
|
|
let Ok(ttl_duration) = chrono::Duration::from_std(ttl) else {
|
|
return false;
|
|
};
|
|
let age = Utc::now().signed_duration_since(self.fetched_at);
|
|
age <= ttl_duration
|
|
}
|
|
}
|