Compare commits

8 Commits

Author SHA1 Message Date
Ahmed Ibrahim 787774c63d Merge branch 'main' into start-tokenizer-on-start 2025-11-13 22:56:10 -08:00
Ahmed Ibrahim 285c825c93 Merge branch 'main' into start-tokenizer-on-start 2025-11-13 17:43:13 -08:00
Ahmed Ibrahim df640801f2 move 2025-11-13 17:39:07 -08:00
Ahmed Ibrahim ab1287cdb0 move 2025-11-13 17:31:18 -08:00
Ahmed Ibrahim 3b98b6e98b move 2025-11-13 17:31:02 -08:00
Ahmed Ibrahim 01ba9446ac move 2025-11-13 17:30:51 -08:00
Ahmed Ibrahim 090ffa4b0f move 2025-11-13 17:30:20 -08:00
Ahmed Ibrahim 109f00da0a start 2025-11-13 16:59:00 -08:00
6 changed files with 51 additions and 5 deletions

codex-rs/Cargo.lock generated
View File

@@ -1553,6 +1553,8 @@ dependencies = [
"pretty_assertions",
"thiserror 2.0.17",
"tiktoken-rs",
"tokio",
"tracing",
]
[[package]]

View File

@@ -7,7 +7,7 @@ use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::TokenUsage;
use codex_protocol::protocol::TokenUsageInfo;
use codex_utils_tokenizer::Tokenizer;
use codex_utils_tokenizer::shared_default_tokenizer;
use std::ops::Deref;
const CONTEXT_WINDOW_HARD_LIMIT_FACTOR: f64 = 1.1;
@@ -85,8 +85,7 @@ impl ContextManager {
// /!\ The value is a lower bound estimate and does not represent the exact
// context length.
pub(crate) fn estimate_token_count(&self, turn_context: &TurnContext) -> Option<i64> {
let model = turn_context.client.get_model();
let tokenizer = Tokenizer::for_model(model.as_str()).ok()?;
let tokenizer = shared_default_tokenizer()?;
let model_family = turn_context.client.get_model_family();
Some(
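
The hunk above replaces the per-model Tokenizer::for_model(...) lookup with the process-wide cached tokenizer. A minimal sketch of that counting path, not taken from the diff: estimate_tokens and history_text are hypothetical names, and the real estimate_token_count works over serialized ResponseItems rather than a plain string.

// Sketch only: counting with the shared tokenizer.
// Returns None when the default BPE failed to initialize.
fn estimate_tokens(history_text: &str) -> Option<i64> {
    let tokenizer = codex_utils_tokenizer::shared_default_tokenizer()?;
    Some(tokenizer.count(history_text) as i64)
}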

View File

@@ -17,6 +17,7 @@ use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::InitialHistory;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::SessionSource;
use codex_utils_tokenizer::warm_up_default_tokenizer;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
@@ -40,6 +41,7 @@ pub struct ConversationManager {
impl ConversationManager {
pub fn new(auth_manager: Arc<AuthManager>, session_source: SessionSource) -> Self {
warm_up_default_tokenizer();
Self {
conversations: Arc::new(RwLock::new(HashMap::new())),
auth_manager,
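
With the call added above, every ConversationManager::new kicks off the tokenizer load in the background. Because warm_up_default_tokenizer spawns a Tokio task, it must run inside a Tokio runtime; a hedged sketch of such a call site (the main below is illustrative, not from this PR):

// Illustrative only: warm the tokenizer at startup so the first token
// estimate does not pay the BPE construction cost on the hot path.
#[tokio::main]
async fn main() {
    codex_utils_tokenizer::warm_up_default_tokenizer();
    // Later calls hit the cached OnceLock and return immediately.
    let _tok = codex_utils_tokenizer::shared_default_tokenizer();
}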

View File

@@ -1,7 +1,7 @@
//! Utilities for truncating large chunks of output while preserving a prefix
//! and suffix on UTF-8 boundaries.
use codex_utils_tokenizer::Tokenizer;
use codex_utils_tokenizer::shared_default_tokenizer;
/// Truncate the middle of a UTF-8 string to at most `max_bytes` bytes,
/// preserving the beginning and the end. Returns the possibly truncated
@@ -15,7 +15,7 @@ pub(crate) fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>
// Build a tokenizer for counting (default to o200k_base; fall back to cl100k_base).
// If both fail, fall back to a 4-bytes-per-token estimate.
let tok = Tokenizer::try_default().ok();
let tok = shared_default_tokenizer();
let token_count = |text: &str| -> u64 {
if let Some(ref t) = tok {
t.count(text) as u64
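
Here the previous per-call Tokenizer::try_default() is replaced by the shared instance, and the surrounding comment still applies: when no tokenizer is available, the byte length is used as a rough estimate. A sketch of that counting helper, assuming the usual 4-bytes-per-token heuristic (the exact rounding used in the real code is not visible in this hunk):

// Sketch: count with the shared tokenizer when it loaded, otherwise
// approximate with ~4 bytes per token (assumed rounding, not from the diff).
fn token_count(text: &str) -> u64 {
    match codex_utils_tokenizer::shared_default_tokenizer() {
        Some(t) => t.count(text) as u64,
        None => (text.len() as u64).div_ceil(4),
    }
}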

View File

@@ -10,6 +10,8 @@ workspace = true
anyhow = { workspace = true }
thiserror = { workspace = true }
tiktoken-rs = "0.7"
tokio.workspace = true
tracing = { workspace = true }
[dev-dependencies]
pretty_assertions = { workspace = true }

View File

@@ -1,9 +1,15 @@
use std::fmt;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use anyhow::Context;
use anyhow::Error as AnyhowError;
use thiserror::Error;
use tiktoken_rs::CoreBPE;
use tracing::error;
static DEFAULT_TOKENIZER: OnceLock<Result<Arc<Tokenizer>, TokenizerError>> = OnceLock::new();
/// Supported local encodings.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
@@ -107,6 +113,31 @@ impl Tokenizer {
}
}
pub fn warm_up_default_tokenizer() {
tokio::spawn(tokio::time::timeout(Duration::from_secs(5), async {
let _ = shared_default_tokenizer();
}));
}
/// Return a shared default tokenizer (`O200kBase`), loading it once per process.
/// Returns `None` if initialization fails.
#[must_use]
pub fn shared_default_tokenizer() -> Option<Arc<Tokenizer>> {
DEFAULT_TOKENIZER
.get_or_init(init_default_tokenizer)
.as_ref()
.ok()
.cloned()
}
fn init_default_tokenizer() -> Result<Arc<Tokenizer>, TokenizerError> {
let result = Tokenizer::try_default().map(Arc::new);
if let Err(ref error) = result {
error!("failed to initialize default tokenizer: {error}");
}
result
}
#[cfg(test)]
mod tests {
use super::*;
@@ -158,4 +189,14 @@ mod tests {
assert_eq!(tok.encode(text, false), fallback.encode(text, false));
Ok(())
}
#[test]
fn shared_default_tokenizer_is_cached() {
let first = shared_default_tokenizer().expect("default tokenizer");
let second = shared_default_tokenizer().expect("default tokenizer reused");
let ptr1 = Arc::as_ptr(&first);
let ptr2 = Arc::as_ptr(&second);
assert_eq!(ptr1, ptr2);
}
}
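
The new test checks that both calls hand back the same Arc, which follows from the OnceLock pattern introduced above. One consequence worth noting: because the OnceLock stores the Result, a failed initialization is cached as well, so later callers get None without a retry. A standalone sketch of the same caching shape, with generic placeholder types rather than the crate's API:

use std::sync::{Arc, OnceLock};

// Same caching shape as DEFAULT_TOKENIZER, with placeholder types.
static CACHE: OnceLock<Result<Arc<String>, String>> = OnceLock::new();

fn shared_value() -> Option<Arc<String>> {
    CACHE
        .get_or_init(|| Ok(Arc::new("expensive resource".to_string())))
        .as_ref()
        .ok()
        .cloned()
}

fn main() {
    let a = shared_value().expect("initialized");
    let b = shared_value().expect("cached");
    // Both calls observe the same allocation: the init closure ran once.
    assert!(Arc::ptr_eq(&a, &b));
}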