mirror of https://github.com/openai/codex.git (synced 2026-02-02 06:57:03 +00:00)

Compare commits: dev/cc/rel...token-usag (3 commits)
| Author | SHA1 | Date |
|---|---|---|
| | eb77fc79ed | |
| | 8835b955fb | |
| | 470b13c26f | |
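The commits wire a tokenizer-based usage heuristic into the Chat Completions streaming path so that `Completed` events carry an estimated `TokenUsage` instead of `None`, normalize model slugs (stripping an `openai/` prefix and mapping `:` to `-`) before the model-info lookup, and add a test asserting the estimate is populated.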
```diff
@@ -1,3 +1,4 @@
+use std::sync::Arc;
 use std::time::Duration;
 
 use crate::ModelProviderInfo;
@@ -12,6 +13,7 @@ use crate::error::Result;
 use crate::error::RetryLimitReachedError;
 use crate::error::UnexpectedResponseError;
 use crate::model_family::ModelFamily;
+use crate::protocol::TokenUsage;
 use crate::tools::spec::create_tools_json_for_chat_completions_api;
 use crate::util::backoff;
 use bytes::Bytes;
@@ -20,6 +22,7 @@ use codex_protocol::models::ContentItem;
 use codex_protocol::models::FunctionCallOutputContentItem;
 use codex_protocol::models::ReasoningItemContent;
 use codex_protocol::models::ResponseItem;
+use codex_utils_tokenizer::Tokenizer;
 use eventsource_stream::Eventsource;
 use futures::Stream;
 use futures::StreamExt;
@@ -34,6 +37,102 @@ use tokio::time::timeout;
 use tracing::debug;
 use tracing::trace;
 
+struct ChatUsageHeuristic {
+    tokenizer: Arc<Tokenizer>,
+    input_tokens: i64,
+    output_tokens: i64,
+    reasoning_tokens: i64,
+}
+
+impl ChatUsageHeuristic {
+    fn new(model: &str, messages: &[serde_json::Value]) -> Option<Self> {
+        let tokenizer = match Tokenizer::for_model(model) {
+            Ok(tok) => tok,
+            Err(err) => {
+                debug!(
+                    "failed to build tokenizer for model {model}; falling back to default: {err:?}"
+                );
+                match Tokenizer::try_default() {
+                    Ok(tok) => tok,
+                    Err(fallback_err) => {
+                        debug!(
+                            "failed to fall back to default tokenizer for model {model}: {fallback_err:?}"
+                        );
+                        return None;
+                    }
+                }
+            }
+        };
+
+        let tokenizer = Arc::new(tokenizer);
+        let mut input_tokens =
+            4_i64.saturating_mul(i64::try_from(messages.len()).unwrap_or(i64::MAX));
+
+        for message in messages {
+            input_tokens =
+                input_tokens.saturating_add(Self::count_value_tokens(tokenizer.as_ref(), message));
+
+            if let Some(tool_calls) = message.get("tool_calls").and_then(|v| v.as_array()) {
+                input_tokens = input_tokens.saturating_add(
+                    8_i64.saturating_mul(i64::try_from(tool_calls.len()).unwrap_or(i64::MAX)),
+                );
+            }
+        }
+
+        Some(Self {
+            tokenizer,
+            input_tokens,
+            output_tokens: 0,
+            reasoning_tokens: 0,
+        })
+    }
+
+    fn record_output(&mut self, text: &str) {
+        if text.is_empty() {
+            return;
+        }
+        self.output_tokens = self
+            .output_tokens
+            .saturating_add(self.tokenizer.count(text));
+    }
+
+    fn record_reasoning(&mut self, text: &str) {
+        if text.is_empty() {
+            return;
+        }
+        self.reasoning_tokens = self
+            .reasoning_tokens
+            .saturating_add(self.tokenizer.count(text));
+    }
+
+    fn to_usage(&self) -> TokenUsage {
+        let total = self
+            .input_tokens
+            .saturating_add(self.output_tokens)
+            .saturating_add(self.reasoning_tokens);
+        TokenUsage {
+            input_tokens: self.input_tokens,
+            cached_input_tokens: 0,
+            output_tokens: self.output_tokens,
+            reasoning_output_tokens: self.reasoning_tokens,
+            total_tokens: total,
+        }
+    }
+
+    fn count_value_tokens(tokenizer: &Tokenizer, value: &serde_json::Value) -> i64 {
+        match value {
+            serde_json::Value::String(s) => tokenizer.count(s),
+            serde_json::Value::Array(items) => items.iter().fold(0_i64, |acc, item| {
+                acc.saturating_add(Self::count_value_tokens(tokenizer, item))
+            }),
+            serde_json::Value::Object(map) => map.values().fold(0_i64, |acc, item| {
+                acc.saturating_add(Self::count_value_tokens(tokenizer, item))
+            }),
+            _ => 0,
+        }
+    }
+}
+
 /// Implementation for the classic Chat Completions API.
 pub(crate) async fn stream_chat_completions(
     prompt: &Prompt,
```
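To make the accounting concrete: the sketch below mirrors the same bookkeeping (a flat 4-token overhead per message, 8 per tool call, plus tokenized text, all with saturating arithmetic) in a self-contained form. The whitespace word counter `Tok`, the `UsageEstimate` type, and the example messages are illustrative assumptions, not the real `codex_utils_tokenizer`; only the overhead constants are taken from the diff.

```rust
// Stand-in tokenizer: counts whitespace-separated words, not BPE tokens.
struct Tok;

impl Tok {
    fn count(&self, text: &str) -> i64 {
        text.split_whitespace().count() as i64
    }
}

struct UsageEstimate {
    tok: Tok,
    input_tokens: i64,
    output_tokens: i64,
}

impl UsageEstimate {
    // Mirrors ChatUsageHeuristic::new: 4 tokens of overhead per message,
    // 8 per tool call, plus the tokenized message text.
    fn new(messages: &[&str], tool_calls: usize) -> Self {
        let tok = Tok;
        let mut input_tokens = 4_i64.saturating_mul(messages.len() as i64);
        for message in messages {
            input_tokens = input_tokens.saturating_add(tok.count(message));
        }
        input_tokens = input_tokens.saturating_add(8_i64.saturating_mul(tool_calls as i64));
        Self { tok, input_tokens, output_tokens: 0 }
    }

    // Mirrors record_output: accumulate tokens for each streamed delta.
    fn record_output(&mut self, text: &str) {
        self.output_tokens = self.output_tokens.saturating_add(self.tok.count(text));
    }

    fn total(&self) -> i64 {
        self.input_tokens.saturating_add(self.output_tokens)
    }
}

fn main() {
    let mut usage = UsageEstimate::new(&["You are helpful.", "Summarize this file."], 0);
    usage.record_output("Here is a short");
    usage.record_output("summary.");
    assert_eq!(usage.input_tokens, 14); // 2 messages * 4 overhead + 3 + 3 words
    assert_eq!(usage.output_tokens, 5); // 4 + 1 words across two deltas
    println!("estimated total tokens: {}", usage.total());
}
```

The saturating arithmetic matches the diff's choice of a clamped estimate over wraparound when message counts or text sizes are extreme.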
```diff
@@ -325,6 +424,8 @@ pub(crate) async fn stream_chat_completions(
     }
 
     let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
+    let usage_heuristic = ChatUsageHeuristic::new(model_family.slug.as_str(), &messages);
+
     let payload = json!({
         "model": model_family.slug,
         "messages": messages,
@@ -368,6 +469,7 @@
         tx_event,
         provider.stream_idle_timeout(),
         otel_event_manager.clone(),
+        usage_heuristic,
     ));
     return Ok(ResponseStream { rx_event });
 }
```
```diff
@@ -421,6 +523,7 @@ async fn process_chat_sse<S>(
     tx_event: mpsc::Sender<Result<ResponseEvent>>,
     idle_timeout: Duration,
     otel_event_manager: OtelEventManager,
+    mut usage_heuristic: Option<ChatUsageHeuristic>,
 ) where
     S: Stream<Item = Result<Bytes>> + Unpin,
 {
@@ -459,10 +562,11 @@
             }
             Ok(None) => {
                 // Stream closed gracefully – emit Completed with dummy id.
+                let token_usage = usage_heuristic.as_ref().map(ChatUsageHeuristic::to_usage);
                 let _ = tx_event
                     .send(Ok(ResponseEvent::Completed {
                         response_id: String::new(),
-                        token_usage: None,
+                        token_usage,
                     }))
                     .await;
                 return;
@@ -505,10 +609,11 @@
             let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
         }
 
+        let token_usage = usage_heuristic.as_ref().map(ChatUsageHeuristic::to_usage);
         let _ = tx_event
             .send(Ok(ResponseEvent::Completed {
                 response_id: String::new(),
-                token_usage: None,
+                token_usage,
             }))
             .await;
         return;
@@ -532,6 +637,9 @@
             && !content.is_empty()
         {
             assistant_text.push_str(content);
+            if let Some(usage) = usage_heuristic.as_mut() {
+                usage.record_output(content);
+            }
             let _ = tx_event
                 .send(Ok(ResponseEvent::OutputTextDelta(content.to_string())))
                 .await;
@@ -565,6 +673,9 @@
         if let Some(reasoning) = maybe_text {
             // Accumulate so we can emit a terminal Reasoning item at the end.
             reasoning_text.push_str(&reasoning);
+            if let Some(usage) = usage_heuristic.as_mut() {
+                usage.record_reasoning(&reasoning);
+            }
             let _ = tx_event
                 .send(Ok(ResponseEvent::ReasoningContentDelta(reasoning)))
                 .await;
@@ -578,6 +689,9 @@
         if let Some(s) = message_reasoning.as_str() {
             if !s.is_empty() {
                 reasoning_text.push_str(s);
+                if let Some(usage) = usage_heuristic.as_mut() {
+                    usage.record_reasoning(s);
+                }
                 let _ = tx_event
                     .send(Ok(ResponseEvent::ReasoningContentDelta(s.to_string())))
                     .await;
@@ -590,6 +704,9 @@
             && !s.is_empty()
         {
             reasoning_text.push_str(s);
+            if let Some(usage) = usage_heuristic.as_mut() {
+                usage.record_reasoning(s);
+            }
             let _ = tx_event
                 .send(Ok(ResponseEvent::ReasoningContentDelta(s.to_string())))
                 .await;
```
```diff
@@ -608,18 +725,31 @@
 
                 // Extract call_id if present.
                 if let Some(id) = tool_call.get("id").and_then(|v| v.as_str()) {
-                    fn_call_state.call_id.get_or_insert_with(|| id.to_string());
+                    if fn_call_state.call_id.is_none() {
+                        if let Some(usage) = usage_heuristic.as_mut() {
+                            usage.record_output(id);
+                        }
+                        fn_call_state.call_id = Some(id.to_string());
+                    }
                 }
 
                 // Extract function details if present.
                 if let Some(function) = tool_call.get("function") {
                     if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
-                        fn_call_state.name.get_or_insert_with(|| name.to_string());
+                        if fn_call_state.name.is_none() {
+                            if let Some(usage) = usage_heuristic.as_mut() {
+                                usage.record_output(name);
+                            }
+                            fn_call_state.name = Some(name.to_string());
+                        }
                     }
 
                     if let Some(args_fragment) = function.get("arguments").and_then(|a| a.as_str())
                     {
                         fn_call_state.arguments.push_str(args_fragment);
+                        if let Some(usage) = usage_heuristic.as_mut() {
+                            usage.record_output(args_fragment);
+                        }
                     }
                 }
             }
```
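The point of replacing `get_or_insert_with` here is that a tool call's id and name can be repeated verbatim in every streamed chunk, while argument fragments arrive as disjoint pieces: the heuristic should count the former once and the latter every time. A minimal sketch of that invariant (this `FnCallState` is a simplified stand-in for the real state, and byte length stands in for token counting):

```rust
#[derive(Default)]
struct FnCallState {
    call_id: Option<String>,
    arguments: String,
}

fn main() {
    let mut recorded = 0usize;
    let mut state = FnCallState::default();
    // Two chunks of one tool call: the id repeats, the arguments are split.
    let chunks = [("call_1", "{\"path\":"), ("call_1", "\"a.txt\"}")];

    for (id, args_fragment) in chunks {
        if state.call_id.is_none() {
            recorded += id.len(); // counted once, like usage.record_output(id)
            state.call_id = Some(id.to_string());
        }
        recorded += args_fragment.len(); // fragments never repeat: count each
        state.arguments.push_str(args_fragment);
    }

    assert_eq!(state.arguments, "{\"path\":\"a.txt\"}");
    assert_eq!(recorded, "call_1".len() + state.arguments.len());
}
```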
```diff
@@ -682,10 +812,11 @@
     }
 
     // Emit Completed regardless of reason so the agent can advance.
+    let token_usage = usage_heuristic.as_ref().map(ChatUsageHeuristic::to_usage);
     let _ = tx_event
         .send(Ok(ResponseEvent::Completed {
             response_id: String::new(),
-            token_usage: None,
+            token_usage,
         }))
         .await;
```
```diff
@@ -37,8 +37,10 @@ impl ModelInfo {
 }
 
 pub(crate) fn get_model_info(model_family: &ModelFamily) -> Option<ModelInfo> {
-    let slug = model_family.slug.as_str();
-    match slug {
+    let raw_slug = model_family.slug.as_str();
+    let slug = raw_slug.strip_prefix("openai/").unwrap_or(raw_slug);
+    let normalized_slug = slug.replace(':', "-");
+    match normalized_slug.as_str() {
         // OSS models have a 128k shared token pool.
         // Arbitrarily splitting it: 3/4 input context, 1/4 output.
         // https://openai.com/index/gpt-oss-model-card/
```
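A sketch of the normalization this hunk introduces, with the comment's 3/4 input, 1/4 output split worked out. The example slugs are illustrative assumptions (not necessarily keys in the real lookup table), and "128k" is taken at face value as 128,000:

```rust
/// Strip an optional "openai/" prefix, then map ':' to '-' so
/// provider-qualified slugs hit the same match arms as bare model names.
fn normalize_slug(raw: &str) -> String {
    let slug = raw.strip_prefix("openai/").unwrap_or(raw);
    slug.replace(':', "-")
}

fn main() {
    // Illustrative slugs, assumed for the example.
    assert_eq!(normalize_slug("openai/gpt-oss:120b"), "gpt-oss-120b");
    assert_eq!(normalize_slug("gpt-4.1"), "gpt-4.1");

    // The comment's arbitrary split of the shared pool:
    let pool: u64 = 128_000;
    assert_eq!(pool * 3 / 4, 96_000); // input context
    assert_eq!(pool / 4, 32_000); // max output tokens
}
```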
```diff
@@ -185,6 +185,49 @@ async fn streams_text_without_reasoning() {
     assert_matches!(events[2], ResponseEvent::Completed { .. });
 }
 
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn completed_event_includes_usage_estimate() {
+    if network_disabled() {
+        println!(
+            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
+        );
+        return;
+    }
+
+    let sse = concat!(
+        "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n",
+        "data: {\"choices\":[{\"delta\":{}}]}\n\n",
+        "data: [DONE]\n\n",
+    );
+
+    let events = run_stream(sse).await;
+    assert_eq!(events.len(), 3, "unexpected events: {events:?}");
+
+    let usage = events
+        .iter()
+        .find_map(|event| match event {
+            ResponseEvent::Completed {
+                token_usage: Some(usage),
+                ..
+            } => Some(usage.clone()),
+            _ => None,
+        })
+        .expect("missing usage estimate on Completed event");
+
+    assert!(
+        usage.input_tokens > 0,
+        "expected input tokens > 0, got {usage:?}"
+    );
+    assert!(
+        usage.output_tokens > 0,
+        "expected output tokens > 0, got {usage:?}"
+    );
+    assert!(
+        usage.total_tokens >= usage.input_tokens + usage.output_tokens,
+        "expected total tokens to cover input + output, got {usage:?}"
+    );
+}
+
 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
 async fn streams_reasoning_from_string_delta() {
     if network_disabled() {
```
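For intuition on why these assertions hold under the heuristic, here is the arithmetic as a tiny sketch (that the harness sends at least one request message is an assumption about `run_stream`, not something the hunk shows):

```rust
fn main() {
    // The 4-token-per-message overhead alone makes input positive for any
    // nonempty request, "hi" tokenizes to at least one output token, and
    // no reasoning deltas are streamed, so reasoning stays at zero.
    let (input, output, reasoning) = (4_i64, 1_i64, 0_i64);
    let total = input + output + reasoning;
    assert!(input > 0 && output > 0);
    assert!(total >= input + output);
    println!("minimal estimate: {input} in, {output} out, {total} total");
}
```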