feat: cache tokenizer (#6609)

This commit is contained in:
jif-oai
2025-11-14 17:05:00 +01:00
committed by GitHub
parent 63c8c01f40
commit f17b392470
8 changed files with 112 additions and 45 deletions

View File

@@ -1,7 +1,9 @@
use std::fmt;
use std::num::NonZeroUsize;
use std::sync::OnceLock;
use anyhow::Context;
use anyhow::Error as AnyhowError;
use codex_utils_cache::BlockingLruCache;
use thiserror::Error;
use tiktoken_rs::CoreBPE;
@@ -37,6 +39,26 @@ pub enum TokenizerError {
},
}
/// Returns the process-wide LRU cache mapping model names to their loaded
/// `CoreBPE` tokenizers. The cache is created lazily on first access.
fn model_cache() -> &'static BlockingLruCache<String, CoreBPE> {
    // Room for 64 distinct model entries; `unwrap_or(MIN)` keeps this
    // infallible even though 64 is trivially non-zero.
    const CAPACITY: usize = 64;
    static MODEL_CACHE: OnceLock<BlockingLruCache<String, CoreBPE>> = OnceLock::new();
    MODEL_CACHE.get_or_init(|| {
        let cap = NonZeroUsize::new(CAPACITY).unwrap_or(NonZeroUsize::MIN);
        BlockingLruCache::new(cap)
    })
}
/// Fire-and-forget pre-warm of the model tokenizer cache.
///
/// Best effort only: no guarantee is made about the cache state before or
/// after the call, and any load error is discarded. Outside a Tokio runtime
/// this is a no-op.
pub fn warm_model_cache(model: &str) {
    // Spawning a task requires an active Tokio runtime; bail out quietly
    // when called from a non-async context.
    let Ok(handle) = tokio::runtime::Handle::try_current() else {
        return;
    };
    let owned = model.to_owned();
    handle.spawn(async move {
        // Result intentionally ignored — warming is best-effort.
        let _ = Tokenizer::for_model(&owned);
    });
}
/// Thin wrapper around a `tiktoken_rs::CoreBPE` tokenizer.
#[derive(Clone)]
pub struct Tokenizer {
@@ -63,20 +85,13 @@ impl Tokenizer {
/// Build a tokenizer using an `OpenAI` model name (maps to an encoding).
/// Falls back to the `O200kBase` encoding when the model is unknown.
pub fn for_model(model: &str) -> Result<Self, TokenizerError> {
match tiktoken_rs::get_bpe_from_model(model) {
Ok(inner) => Ok(Self { inner }),
Err(model_error) => {
let inner = tiktoken_rs::o200k_base()
.with_context(|| {
format!("fallback after model lookup failure for {model}: {model_error}")
})
.map_err(|source| TokenizerError::LoadEncoding {
kind: EncodingKind::O200kBase,
source,
})?;
Ok(Self { inner })
let inner = model_cache().get_or_try_insert_with(model.to_owned(), || {
match tiktoken_rs::get_bpe_from_model(model) {
Ok(inner) => Ok(inner),
Err(_model_error) => Tokenizer::new(EncodingKind::O200kBase).map(|e| e.inner),
}
}
})?;
Ok(Self { inner })
}
/// Encode text to token IDs. If `with_special_tokens` is true, special
@@ -158,4 +173,9 @@ mod tests {
assert_eq!(tok.encode(text, false), fallback.encode(text, false));
Ok(())
}
#[test]
// Outside a Tokio runtime, warming must be a silent no-op — this call
// returning without panicking is the entire assertion.
fn warm_model_cache_without_runtime_is_noop() {
warm_model_cache("gpt-5");
}
}