refactor(llm): normalize Usage as inclusive total + non-overlapping breakdown (#26735)

Kit Langton
2026-05-10 21:52:50 -04:00
committed by GitHub
parent 38e4540119
commit 2703eff2e2
12 changed files with 275 additions and 68 deletions

View File

@@ -404,34 +404,56 @@ const mapFinishReason = (reason: string | null | undefined): FinishReason => {
return "unknown"
}
// Anthropic reports the non-overlapping breakdown natively — its
// `input_tokens` is the *non-cached* count per the Messages API docs, with
// cache reads and writes as separate fields. We sum them to derive the
// inclusive `inputTokens` the rest of the contract expects. Extended
// thinking tokens are *not* broken out by Anthropic — they're billed as
// part of `output_tokens`, so `reasoningTokens` stays `undefined` and
// `outputTokens` carries the combined total.
const mapUsage = (usage: AnthropicUsage | undefined): Usage | undefined => {
if (!usage) return undefined
const nonCached = usage.input_tokens
const cacheRead = usage.cache_read_input_tokens ?? undefined
const cacheWrite = usage.cache_creation_input_tokens ?? undefined
const inputTokens = ProviderShared.sumTokens(nonCached, cacheRead, cacheWrite)
return new Usage({
inputTokens: usage.input_tokens,
inputTokens,
outputTokens: usage.output_tokens,
cacheReadInputTokens: usage.cache_read_input_tokens ?? undefined,
cacheWriteInputTokens: usage.cache_creation_input_tokens ?? undefined,
totalTokens: ProviderShared.totalTokens(usage.input_tokens, usage.output_tokens, undefined),
native: usage,
nonCachedInputTokens: nonCached,
cacheReadInputTokens: cacheRead,
cacheWriteInputTokens: cacheWrite,
totalTokens: ProviderShared.totalTokens(inputTokens, usage.output_tokens, undefined),
providerMetadata: { anthropic: usage },
})
}
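For orientation, a minimal sketch of the mapping; the payload and its counts below are hypothetical, chosen only to make the arithmetic visible:
// Hypothetical final-usage payload from the Messages API (illustrative values).
const sampleAnthropicUsage = {
  input_tokens: 5, // non-cached, per Anthropic semantics
  output_tokens: 7,
  cache_read_input_tokens: 1,
  cache_creation_input_tokens: 2,
}
// mapUsage would normalize this to roughly:
//   inputTokens: 8             // 5 + 1 + 2, derived inclusive total
//   nonCachedInputTokens: 5    // Anthropic's native input_tokens
//   cacheReadInputTokens: 1
//   cacheWriteInputTokens: 2
//   outputTokens: 7            // extended thinking not broken out
//   reasoningTokens: undefined
//   totalTokens: 15            // 8 + 7, no provider-supplied total
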
// Anthropic emits usage on `message_start` and again on `message_delta` — the
// final delta carries the authoritative totals. Right-biased merge: each
// field prefers `right` when defined, falls back to `left`. `totalTokens` is
// recomputed from the merged input/output to stay consistent.
// field prefers `right` when defined, falls back to `left`. `inputTokens` is
// recomputed from the merged breakdown so the inclusive total stays
// consistent with `nonCached + cacheRead + cacheWrite`.
const mergeUsage = (left: Usage | undefined, right: Usage | undefined) => {
if (!left) return right
if (!right) return left
const inputTokens = right.inputTokens ?? left.inputTokens
const nonCachedInputTokens = right.nonCachedInputTokens ?? left.nonCachedInputTokens
const cacheReadInputTokens = right.cacheReadInputTokens ?? left.cacheReadInputTokens
const cacheWriteInputTokens = right.cacheWriteInputTokens ?? left.cacheWriteInputTokens
const inputTokens = ProviderShared.sumTokens(nonCachedInputTokens, cacheReadInputTokens, cacheWriteInputTokens)
const outputTokens = right.outputTokens ?? left.outputTokens
return new Usage({
inputTokens,
outputTokens,
cacheReadInputTokens: right.cacheReadInputTokens ?? left.cacheReadInputTokens,
cacheWriteInputTokens: right.cacheWriteInputTokens ?? left.cacheWriteInputTokens,
nonCachedInputTokens,
cacheReadInputTokens,
cacheWriteInputTokens,
totalTokens: ProviderShared.totalTokens(inputTokens, outputTokens, undefined),
native: { ...left.native, ...right.native },
providerMetadata: {
anthropic: {
...(left.providerMetadata?.["anthropic"] ?? {}),
...(right.providerMetadata?.["anthropic"] ?? {}),
},
},
})
}
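A similarly hypothetical merge of the two partial usages a stream produces (values illustrative; `startUsage` and `deltaUsage` are names invented for this sketch):
// From message_start: input breakdown known, output not yet known.
const startUsage = new Usage({ inputTokens: 8, nonCachedInputTokens: 5, cacheReadInputTokens: 1, cacheWriteInputTokens: 2 })
// From the final message_delta: authoritative output count.
const deltaUsage = new Usage({ outputTokens: 7 })
// mergeUsage(startUsage, deltaUsage) keeps the right-biased fields and recomputes:
//   inputTokens: 8 (5 + 1 + 2), outputTokens: 7, totalTokens: 15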

View File

@@ -395,15 +395,22 @@ const mapFinishReason = (reason: string): FinishReason => {
return "unknown"
}
// AWS Bedrock Converse reports `inputTokens` (inclusive total) with
// `cacheReadInputTokens` and `cacheWriteInputTokens` as subsets. Pass
// the total through and derive the non-cached breakdown. Bedrock does
// not break reasoning out of `outputTokens` for any current model.
const mapUsage = (usage: BedrockUsageSchema | undefined): Usage | undefined => {
if (!usage) return undefined
const cacheTotal = (usage.cacheReadInputTokens ?? 0) + (usage.cacheWriteInputTokens ?? 0)
const nonCached = ProviderShared.subtractTokens(usage.inputTokens, cacheTotal)
return new Usage({
inputTokens: usage.inputTokens,
outputTokens: usage.outputTokens,
totalTokens: ProviderShared.totalTokens(usage.inputTokens, usage.outputTokens, usage.totalTokens),
nonCachedInputTokens: nonCached,
cacheReadInputTokens: usage.cacheReadInputTokens,
cacheWriteInputTokens: usage.cacheWriteInputTokens,
native: usage,
totalTokens: ProviderShared.totalTokens(usage.inputTokens, usage.outputTokens, usage.totalTokens),
providerMetadata: { bedrock: usage },
})
}
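A rough sketch of the Bedrock direction (hypothetical counts; `sampleBedrockUsage` exists only for illustration):
// Converse-style usage: inclusive inputTokens with cache subsets.
const sampleBedrockUsage = { inputTokens: 10, outputTokens: 4, totalTokens: 14, cacheReadInputTokens: 3, cacheWriteInputTokens: 2 }
// Expected normalization:
//   inputTokens: 10 (passed through as the inclusive total)
//   nonCachedInputTokens: 5 (10 - (3 + 2))
//   cacheReadInputTokens: 3, cacheWriteInputTokens: 2
//   outputTokens: 4, reasoningTokens: undefined, totalTokens: 14 (provider-supplied)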

View File

@@ -281,15 +281,30 @@ const fromRequest = Effect.fn("Gemini.fromRequest")(function* (request: LLMReque
// =============================================================================
// Stream Parsing
// =============================================================================
// Gemini reports `promptTokenCount` (inclusive total) with a
// `cachedContentTokenCount` subset. `candidatesTokenCount` is *exclusive*
// of `thoughtsTokenCount` — visible-only, not a total — so we sum the two
// to produce the inclusive `outputTokens` the rest of the contract expects.
const mapUsage = (usage: GeminiUsage | undefined) => {
if (!usage) return undefined
const cached = usage.cachedContentTokenCount
const nonCached = ProviderShared.subtractTokens(usage.promptTokenCount, cached)
// `candidatesTokenCount` is visible-only; sum with thoughts to produce the
// inclusive `outputTokens` the contract expects. Only compute the total
// when the visible component is reported — otherwise we'd fabricate an
// inclusive number from a partial breakdown.
const outputTokens =
usage.candidatesTokenCount !== undefined
? usage.candidatesTokenCount + (usage.thoughtsTokenCount ?? 0)
: undefined
return new Usage({
inputTokens: usage.promptTokenCount,
outputTokens: usage.candidatesTokenCount,
outputTokens,
nonCachedInputTokens: nonCached,
cacheReadInputTokens: cached,
reasoningTokens: usage.thoughtsTokenCount,
cacheReadInputTokens: usage.cachedContentTokenCount,
totalTokens: ProviderShared.totalTokens(usage.promptTokenCount, usage.candidatesTokenCount, usage.totalTokenCount),
native: usage,
totalTokens: ProviderShared.totalTokens(usage.promptTokenCount, outputTokens, usage.totalTokenCount),
providerMetadata: { google: usage },
})
}
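A rough sketch with hypothetical counts, showing the subtraction on the input side and the sum on the output side:
const sampleGeminiUsage = { promptTokenCount: 5, cachedContentTokenCount: 1, candidatesTokenCount: 2, thoughtsTokenCount: 1, totalTokenCount: 8 }
// Expected normalization:
//   inputTokens: 5 (inclusive promptTokenCount), nonCachedInputTokens: 4 (5 - 1)
//   cacheReadInputTokens: 1
//   outputTokens: 3 (2 visible + 1 thoughts), reasoningTokens: 1
//   totalTokens: 8 (provider-supplied total, preferred over the computed fallback)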

View File

@@ -290,15 +290,24 @@ const mapFinishReason = (reason: string | null | undefined): FinishReason => {
return "unknown"
}
// OpenAI Chat reports `prompt_tokens` (inclusive total) with a
// `cached_tokens` subset, and `completion_tokens` (inclusive total) with
// a `reasoning_tokens` subset. We pass the inclusive totals through and
// derive the non-cached breakdown so the `LLM.Usage` contract is
// satisfied on both sides.
const mapUsage = (usage: OpenAIChatEvent["usage"]): Usage | undefined => {
if (!usage) return undefined
const cached = usage.prompt_tokens_details?.cached_tokens
const reasoning = usage.completion_tokens_details?.reasoning_tokens
const nonCached = ProviderShared.subtractTokens(usage.prompt_tokens, cached)
return new Usage({
inputTokens: usage.prompt_tokens,
outputTokens: usage.completion_tokens,
reasoningTokens: usage.completion_tokens_details?.reasoning_tokens,
cacheReadInputTokens: usage.prompt_tokens_details?.cached_tokens,
nonCachedInputTokens: nonCached,
cacheReadInputTokens: cached,
reasoningTokens: reasoning,
totalTokens: ProviderShared.totalTokens(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens),
native: usage,
providerMetadata: { openai: usage },
})
}
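A rough sketch with hypothetical counts for the Chat Completions shape:
const sampleChatUsage = {
  prompt_tokens: 9,
  completion_tokens: 6,
  total_tokens: 15,
  prompt_tokens_details: { cached_tokens: 4 },
  completion_tokens_details: { reasoning_tokens: 2 },
}
// Expected normalization:
//   inputTokens: 9, nonCachedInputTokens: 5 (9 - 4), cacheReadInputTokens: 4
//   outputTokens: 6, reasoningTokens: 2 (visibleOutputTokens would be 4)
//   totalTokens: 15 (provider-supplied)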

View File

@@ -276,15 +276,23 @@ const fromRequest = Effect.fn("OpenAIResponses.fromRequest")(function* (request:
// =============================================================================
// Stream Parsing
// =============================================================================
// OpenAI Responses reports `input_tokens` (inclusive total) with a
// `cached_tokens` subset, and `output_tokens` (inclusive total) with a
// `reasoning_tokens` subset. Pass the totals through and derive the
// non-cached breakdown.
const mapUsage = (usage: OpenAIResponsesUsage | null | undefined) => {
if (!usage) return undefined
const cached = usage.input_tokens_details?.cached_tokens
const reasoning = usage.output_tokens_details?.reasoning_tokens
const nonCached = ProviderShared.subtractTokens(usage.input_tokens, cached)
return new Usage({
inputTokens: usage.input_tokens,
outputTokens: usage.output_tokens,
reasoningTokens: usage.output_tokens_details?.reasoning_tokens,
cacheReadInputTokens: usage.input_tokens_details?.cached_tokens,
nonCachedInputTokens: nonCached,
cacheReadInputTokens: cached,
reasoningTokens: reasoning,
totalTokens: ProviderShared.totalTokens(usage.input_tokens, usage.output_tokens, usage.total_tokens),
native: usage,
providerMetadata: { openai: usage },
})
}
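Same idea for the Responses shape (hypothetical counts):
const sampleResponsesUsage = {
  input_tokens: 12,
  output_tokens: 8,
  total_tokens: 20,
  input_tokens_details: { cached_tokens: 10 },
  output_tokens_details: { reasoning_tokens: 5 },
}
// Expected normalization:
//   inputTokens: 12, nonCachedInputTokens: 2 (12 - 10), cacheReadInputTokens: 10
//   outputTokens: 8, reasoningTokens: 5, totalTokens: 20 (provider-supplied)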

View File

@@ -42,6 +42,13 @@ export interface ToolAccumulator {
* supplied total; otherwise falls back to `inputTokens + outputTokens` only
* when at least one is defined. Returns `undefined` when neither input nor
* output is known so routes don't publish a misleading `0`.
*
* Under the normalized `LLM.Usage` contract, `inputTokens` and `outputTokens`
* are inclusive totals (cache and reasoning included), so the computed
* fallback is a sensible aggregate on the input + output axes. The
* provider-supplied `total` is the source of truth when present; the
* fallback exists mainly so Anthropic-style providers (which don't surface
* a total) still publish a usable `totalTokens`.
*/
export const totalTokens = (
inputTokens: number | undefined,
@@ -53,6 +60,38 @@ export const totalTokens = (
return (inputTokens ?? 0) + (outputTokens ?? 0)
}
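A few illustrative calls showing the precedence (values hypothetical):
// totalTokens(5, 7, 20)                          // 20, provider-supplied total wins
// totalTokens(5, 7, undefined)                   // 12, computed fallback
// totalTokens(undefined, 7, undefined)           // 7, one side known is enough
// totalTokens(undefined, undefined, undefined)   // undefined, never a misleading 0
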
/**
* Subtract `subtrahend` from `total`, clamping to zero if the provider
* reports a nonsensical breakdown (e.g. `cached_tokens > prompt_tokens`).
* Used by protocol mappers when deriving a non-overlapping breakdown field
* from a provider's inclusive total — `nonCachedInputTokens` from
* `inputTokens - cacheReadInputTokens - cacheWriteInputTokens`.
*
* If `total` is `undefined`, returns `undefined` (we don't fabricate
* counts). If `subtrahend` is `undefined`, returns `total` unchanged. The
* provider-native breakdown stays available on `Usage.native` for debugging.
*/
export const subtractTokens = (
total: number | undefined,
subtrahend: number | undefined,
): number | undefined => {
if (total === undefined) return undefined
if (subtrahend === undefined) return total
return Math.max(0, total - subtrahend)
}
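Illustrative calls (arithmetic only, mirroring the schema tests added below):
// subtractTokens(8, 3)           // 5, e.g. nonCached from an inclusive input total
// subtractTokens(8, 12)          // 0, clamped defense against provider bugs
// subtractTokens(undefined, 3)   // undefined, no count is fabricated
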
/**
* Sum a list of optional token counts, returning `undefined` only when
* every value is `undefined` (so we don't fabricate a `0`). Used by
* protocol mappers to derive the inclusive `inputTokens` total from a
* provider that natively reports a non-overlapping breakdown
* (e.g. Anthropic, whose `input_tokens` is already non-cached only).
*/
export const sumTokens = (...values: ReadonlyArray<number | undefined>): number | undefined => {
if (values.every((value) => value === undefined)) return undefined
return values.reduce<number>((acc, value) => acc + (value ?? 0), 0)
}
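Illustrative calls for the other direction (arithmetic only):
// sumTokens(5, 1, 2)                     // 8, inclusive inputTokens from a native breakdown
// sumTokens(5, undefined, undefined)     // 5, missing fields treated as 0
// sumTokens(undefined, undefined)        // undefined, never a fabricated 0
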
export const eventError = (route: string, message: string, raw?: string) =>
new LLMError({
module: "ProviderShared",

View File

@@ -3,15 +3,70 @@ import { ContentBlockID, FinishReason, ProtocolID, ProviderMetadata, ResponseID,
import { ModelRef } from "./options"
import { ToolResultValue } from "./messages"
/**
* Token usage reported by an LLM provider.
*
* **Inclusive totals** (match AI SDK / OpenAI / LangChain convention — a
* reader from any of those ecosystems sees the number they expect):
*
* - `inputTokens` — total prompt tokens, *including* cached reads/writes.
* - `outputTokens` — total output tokens, *including* reasoning.
* - `totalTokens` — provider-supplied total, or `inputTokens + outputTokens`.
*
* **Non-overlapping breakdown** (every field is independently meaningful;
* consumers never have to subtract):
*
* - `nonCachedInputTokens` — the "fresh" portion of the prompt.
* - `cacheReadInputTokens` — input tokens served from cache.
* - `cacheWriteInputTokens` — input tokens written to cache.
* - `reasoningTokens` — subset of `outputTokens` spent on hidden reasoning.
*
* **Invariant**: `nonCachedInputTokens + cacheReadInputTokens +
* cacheWriteInputTokens = inputTokens`, and `reasoningTokens ≤ outputTokens`.
* Each protocol mapper computes whichever side it doesn't get natively,
* with `Math.max(0, …)` clamping for defense against provider bugs. Because
* every breakdown field is stored independently, downstream consumers can
* read whatever they need (cost-by-category, context-pressure, AI-SDK-style
* inclusive total) without ever subtracting — eliminating the underflow
* class of bug where a clamped difference would silently store the wrong
* value.
*
* **Semantics by provider**:
*
* - OpenAI Chat / Responses / Bedrock: provider reports inclusive
* `inputTokens` and inclusive `outputTokens`; mapper subtracts to derive
* the breakdown. Gemini reports an inclusive `promptTokenCount` but a
* visible-only `candidatesTokenCount`, so its mapper subtracts on the
* input side and sums on the output side.
* - Anthropic: provider reports the breakdown natively (`input_tokens` is
* non-cached only); mapper sums to derive the inclusive `inputTokens`.
* Anthropic does *not* break extended-thinking out of `output_tokens`, so
* `reasoningTokens` is `undefined` and `outputTokens` carries the
* combined total — a documented limitation of the Anthropic API.
*
* `providerMetadata` always carries the provider's raw usage payload —
* keyed by provider name (`{ openai: ... }`, `{ anthropic: ... }`, etc.)
* — for fields we don't normalize and for billing-level audit trails.
* Matches the same escape-hatch field on `LLMEvent`.
*/
export class Usage extends Schema.Class<Usage>("LLM.Usage")({
inputTokens: Schema.optional(Schema.Number),
outputTokens: Schema.optional(Schema.Number),
reasoningTokens: Schema.optional(Schema.Number),
nonCachedInputTokens: Schema.optional(Schema.Number),
cacheReadInputTokens: Schema.optional(Schema.Number),
cacheWriteInputTokens: Schema.optional(Schema.Number),
reasoningTokens: Schema.optional(Schema.Number),
totalTokens: Schema.optional(Schema.Number),
native: Schema.optional(Schema.Record(Schema.String, Schema.Unknown)),
}) {}
providerMetadata: Schema.optional(ProviderMetadata),
}) {
/**
* Visible output tokens — `outputTokens` minus `reasoningTokens`, clamped
* to zero. The one place subtraction happens in this contract; the clamp
* means a provider reporting `reasoningTokens > outputTokens` produces a
* harmless zero rather than a negative that crashes downstream schemas.
*/
get visibleOutputTokens() {
return Math.max(0, (this.outputTokens ?? 0) - (this.reasoningTokens ?? 0))
}
}
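A small sketch of the contract from a consumer's point of view (hypothetical counts; the object is constructed only to show the invariant):
// nonCached + cacheRead + cacheWrite = 4 + 3 + 1 = 8 = inputTokens,
// and reasoningTokens (6) <= outputTokens (10).
const exampleUsage = new Usage({
  inputTokens: 8,
  outputTokens: 10,
  nonCachedInputTokens: 4,
  cacheReadInputTokens: 3,
  cacheWriteInputTokens: 1,
  reasoningTokens: 6,
  totalTokens: 18,
})
// exampleUsage.visibleOutputTokens === 4 (10 - 6), the one derived read consumers need.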
export const RequestStart = Schema.Struct({
type: Schema.tag("request-start"),

View File

@@ -1,6 +1,6 @@
import { describe, expect } from "bun:test"
import { Effect } from "effect"
import { CacheHint, LLM, LLMError } from "../../src"
import { CacheHint, LLM, LLMError, Usage } from "../../src"
import { LLMClient } from "../../src/route"
import * as AnthropicMessages from "../../src/protocols/anthropic-messages"
import { it } from "../lib/effect"
@@ -110,10 +110,11 @@ describe("Anthropic Messages route", () => {
expect(response.text).toBe("Hello!")
expect(response.reasoning).toBe("thinking")
expect(response.usage).toMatchObject({
inputTokens: 5,
inputTokens: 6,
outputTokens: 2,
nonCachedInputTokens: 5,
cacheReadInputTokens: 1,
totalTokens: 7,
totalTokens: 8,
})
expect(response.events.find((event) => event.type === "reasoning-end")).toMatchObject({
providerMetadata: { anthropic: { signature: "sig_1" } },
@@ -152,7 +153,13 @@ describe("Anthropic Messages route", () => {
{
type: "request-finish",
reason: "tool-calls",
usage: { inputTokens: 5, outputTokens: 1, totalTokens: 6, native: { input_tokens: 5, output_tokens: 1 } },
usage: new Usage({
inputTokens: 5,
outputTokens: 1,
nonCachedInputTokens: 5,
totalTokens: 6,
providerMetadata: { anthropic: { input_tokens: 5, output_tokens: 1 } },
}),
},
])
}),

View File

@@ -1,6 +1,6 @@
import { describe, expect } from "bun:test"
import { Effect } from "effect"
import { LLM, LLMError } from "../../src"
import { LLM, LLMError, Usage } from "../../src"
import { LLMClient } from "../../src/route"
import * as Gemini from "../../src/protocols/gemini"
import { it } from "../lib/effect"
@@ -198,9 +198,10 @@ describe("Gemini route", () => {
expect(response.reasoning).toBe("thinking")
expect(response.usage).toMatchObject({
inputTokens: 5,
outputTokens: 2,
reasoningTokens: 1,
outputTokens: 3,
nonCachedInputTokens: 4,
cacheReadInputTokens: 1,
reasoningTokens: 1,
totalTokens: 7,
})
expect(response.events).toEqual([
@@ -210,20 +211,23 @@ describe("Gemini route", () => {
{
type: "request-finish",
reason: "stop",
usage: {
usage: new Usage({
inputTokens: 5,
outputTokens: 2,
reasoningTokens: 1,
outputTokens: 3,
nonCachedInputTokens: 4,
cacheReadInputTokens: 1,
reasoningTokens: 1,
totalTokens: 7,
native: {
promptTokenCount: 5,
candidatesTokenCount: 2,
totalTokenCount: 7,
thoughtsTokenCount: 1,
cachedContentTokenCount: 1,
providerMetadata: {
google: {
promptTokenCount: 5,
candidatesTokenCount: 2,
totalTokenCount: 7,
thoughtsTokenCount: 1,
cachedContentTokenCount: 1,
},
},
},
}),
},
])
}),
@@ -257,12 +261,13 @@ describe("Gemini route", () => {
{
type: "request-finish",
reason: "tool-calls",
usage: {
usage: new Usage({
inputTokens: 5,
outputTokens: 1,
nonCachedInputTokens: 5,
totalTokens: 6,
native: { promptTokenCount: 5, candidatesTokenCount: 1 },
},
providerMetadata: { google: { promptTokenCount: 5, candidatesTokenCount: 1 } },
}),
},
])
}),

View File

@@ -1,7 +1,7 @@
import { describe, expect } from "bun:test"
import { Effect, Schema, Stream } from "effect"
import { HttpClientRequest } from "effect/unstable/http"
import { LLM, LLMError } from "../../src"
import { LLM, LLMError, Usage } from "../../src"
import * as Azure from "../../src/providers/azure"
import * as OpenAI from "../../src/providers/openai"
import * as OpenAIChat from "../../src/protocols/openai-chat"
@@ -230,20 +230,23 @@ describe("OpenAI Chat route", () => {
{
type: "request-finish",
reason: "stop",
usage: {
usage: new Usage({
inputTokens: 5,
outputTokens: 2,
reasoningTokens: 0,
nonCachedInputTokens: 4,
cacheReadInputTokens: 1,
reasoningTokens: 0,
totalTokens: 7,
native: {
prompt_tokens: 5,
completion_tokens: 2,
total_tokens: 7,
prompt_tokens_details: { cached_tokens: 1 },
completion_tokens_details: { reasoning_tokens: 0 },
providerMetadata: {
openai: {
prompt_tokens: 5,
completion_tokens: 2,
total_tokens: 7,
prompt_tokens_details: { cached_tokens: 1 },
completion_tokens_details: { reasoning_tokens: 0 },
},
},
},
}),
},
])
}),

View File

@@ -1,7 +1,7 @@
import { describe, expect } from "bun:test"
import { ConfigProvider, Effect, Layer, Stream } from "effect"
import { Headers, HttpClientRequest } from "effect/unstable/http"
import { LLM, LLMError } from "../../src"
import { LLM, LLMError, Usage } from "../../src"
import { Auth, LLMClient, RequestExecutor, WebSocketExecutor } from "../../src/route"
import * as Azure from "../../src/providers/azure"
import * as OpenAI from "../../src/providers/openai"
@@ -342,20 +342,23 @@ describe("OpenAI Responses route", () => {
type: "request-finish",
reason: "stop",
providerMetadata: { openai: { responseId: "resp_1", serviceTier: "default" } },
usage: {
usage: new Usage({
inputTokens: 5,
outputTokens: 2,
reasoningTokens: 0,
nonCachedInputTokens: 4,
cacheReadInputTokens: 1,
reasoningTokens: 0,
totalTokens: 7,
native: {
input_tokens: 5,
output_tokens: 2,
total_tokens: 7,
input_tokens_details: { cached_tokens: 1 },
output_tokens_details: { reasoning_tokens: 0 },
providerMetadata: {
openai: {
input_tokens: 5,
output_tokens: 2,
total_tokens: 7,
input_tokens_details: { cached_tokens: 1 },
output_tokens_details: { reasoning_tokens: 0 },
},
},
},
}),
},
])
}),
@@ -411,7 +414,13 @@ describe("OpenAI Responses route", () => {
{
type: "request-finish",
reason: "tool-calls",
usage: { inputTokens: 5, outputTokens: 1, totalTokens: 6, native: { input_tokens: 5, output_tokens: 1 } },
usage: new Usage({
inputTokens: 5,
outputTokens: 1,
nonCachedInputTokens: 5,
totalTokens: 6,
providerMetadata: { openai: { input_tokens: 5, output_tokens: 1 } },
}),
},
])
}),

View File

@@ -1,6 +1,7 @@
import { describe, expect, test } from "bun:test"
import { Schema } from "effect"
import { ContentPart, LLMEvent, LLMRequest, ModelID, ModelLimits, ModelRef, ProviderID } from "../src/schema"
import { ContentPart, LLMEvent, LLMRequest, ModelID, ModelLimits, ModelRef, ProviderID, Usage } from "../src/schema"
import { ProviderShared } from "../src/protocols/shared"
const model = new ModelRef({
id: ModelID.make("fake-model"),
@@ -48,3 +49,30 @@ describe("llm schema", () => {
expect(ContentPart.guards.media({ type: "text", text: "hi" })).toBe(false)
})
})
describe("LLM.Usage", () => {
test("subtractTokens clamps non-sensical breakdowns to zero", () => {
// Defense against a provider reporting cached_tokens > prompt_tokens or
// reasoning_tokens > completion_tokens — the negative would otherwise
// round-trip through the pipeline and crash strict downstream schemas.
expect(ProviderShared.subtractTokens(5, 3)).toBe(2)
expect(ProviderShared.subtractTokens(5, 10)).toBe(0)
expect(ProviderShared.subtractTokens(5, undefined)).toBe(5)
expect(ProviderShared.subtractTokens(undefined, 3)).toBeUndefined()
expect(ProviderShared.subtractTokens(undefined, undefined)).toBeUndefined()
})
test("sumTokens returns undefined only when every input is undefined", () => {
expect(ProviderShared.sumTokens(1, 2, 3)).toBe(6)
expect(ProviderShared.sumTokens(1, undefined, 3)).toBe(4)
expect(ProviderShared.sumTokens(undefined, undefined, undefined)).toBeUndefined()
expect(ProviderShared.sumTokens()).toBeUndefined()
})
test("visibleOutputTokens clamps reasoning > output to zero", () => {
expect(new Usage({ outputTokens: 10, reasoningTokens: 4 }).visibleOutputTokens).toBe(6)
expect(new Usage({ outputTokens: 10 }).visibleOutputTokens).toBe(10)
expect(new Usage({ outputTokens: 4, reasoningTokens: 10 }).visibleOutputTokens).toBe(0)
expect(new Usage({}).visibleOutputTokens).toBe(0)
})
})