perf(core): optimize token calculation and add support for multimodal tool responses (#17835)

Abhi
2026-01-30 10:15:00 -05:00
committed by GitHub
parent 2238802e97
commit f605628624
2 changed files with 353 additions and 188 deletions
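In short: text estimation now weights ASCII characters at 0.25 tokens and non-ASCII (e.g., CJK) characters at 1.3 tokens, media parts get fixed estimates (3,000 tokens per image, 25,800 per PDF), and tool responses are estimated from their payload alone, recursing into Gemini 3 nested parts. A minimal standalone sketch of the character-weighting idea — not the shipped code, though the constant names mirror the diff below, and the final floor matches the behavior the tests assert:

    // Standalone sketch of the weighting heuristic, not the shipped code.
    const ASCII_TOKENS_PER_CHAR = 0.25;
    const NON_ASCII_TOKENS_PER_CHAR = 1.3;

    function sketchEstimate(text: string): number {
      let tokens = 0;
      for (let i = 0; i < text.length; i++) {
        tokens +=
          text.charCodeAt(i) <= 127
            ? ASCII_TOKENS_PER_CHAR
            : NON_ASCII_TOKENS_PER_CHAR;
      }
      return Math.floor(tokens);
    }

    sketchEstimate('Hello world!'); // 12 * 0.25 = 3
    sketchEstimate('Hi你好'); // 2 * 0.25 + 2 * 1.3 = 3.1 -> 3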

View File

@@ -5,180 +5,285 @@
*/
import { describe, it, expect, vi } from 'vitest';
import {
  calculateRequestTokenCount,
  estimateTokenCountSync,
} from './tokenCalculation.js';
import type { ContentGenerator } from '../core/contentGenerator.js';
import type { Part } from '@google/genai';

describe('tokenCalculation', () => {
  describe('calculateRequestTokenCount', () => {
    const mockContentGenerator = {
      countTokens: vi.fn(),
    } as unknown as ContentGenerator;
    const model = 'gemini-pro';

    it('should use countTokens API for media requests (images/files)', async () => {
      vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
        totalTokens: 100,
      });
      const request = [{ inlineData: { mimeType: 'image/png', data: 'data' } }];
      const count = await calculateRequestTokenCount(
        request,
        mockContentGenerator,
        model,
      );
      expect(count).toBe(100);
      expect(mockContentGenerator.countTokens).toHaveBeenCalled();
    });

    it('should estimate tokens locally for tool calls', async () => {
      vi.mocked(mockContentGenerator.countTokens).mockClear();
      const request = [{ functionCall: { name: 'foo', args: { bar: 'baz' } } }];
      const count = await calculateRequestTokenCount(
        request,
        mockContentGenerator,
        model,
      );
      // Estimation logic: JSON.stringify(part).length / 4
      // JSON: {"functionCall":{"name":"foo","args":{"bar":"baz"}}}
      // Length: ~53 chars. 53 / 4 = 13.25 -> 13.
      expect(count).toBeGreaterThan(0);
      expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
    });

    it('should estimate tokens locally for simple ASCII text', async () => {
      vi.mocked(mockContentGenerator.countTokens).mockClear();
      // 12 chars. 12 * 0.25 = 3 tokens.
      const request = 'Hello world!';
      const count = await calculateRequestTokenCount(
        request,
        mockContentGenerator,
        model,
      );
      expect(count).toBe(3);
      expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
    });

    it('should estimate tokens locally for CJK text with higher weight', async () => {
      vi.mocked(mockContentGenerator.countTokens).mockClear();
      // 2 chars. 2 * 1.3 = 2.6 -> floor(2.6) = 2.
      const request = '你好';
      const count = await calculateRequestTokenCount(
        request,
        mockContentGenerator,
        model,
      );
      expect(count).toBeGreaterThanOrEqual(2);
      expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
    });

    it('should handle mixed content', async () => {
      vi.mocked(mockContentGenerator.countTokens).mockClear();
      // 'Hi': 2 * 0.25 = 0.5
      // '你好': 2 * 1.3 = 2.6
      // Total: 3.1 -> 3
      const request = 'Hi你好';
      const count = await calculateRequestTokenCount(
        request,
        mockContentGenerator,
        model,
      );
      expect(count).toBe(3);
      expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
    });

    it('should handle empty text', async () => {
      const request = '';
      const count = await calculateRequestTokenCount(
        request,
        mockContentGenerator,
        model,
      );
      expect(count).toBe(0);
    });

    it('should fallback to local estimation when countTokens API fails', async () => {
      vi.mocked(mockContentGenerator.countTokens).mockRejectedValue(
        new Error('API error'),
      );
      const request = [
        { text: 'Hello' },
        { inlineData: { mimeType: 'image/png', data: 'data' } },
      ];
      const count = await calculateRequestTokenCount(
        request,
        mockContentGenerator,
        model,
      );
      // Should fall back to estimation:
      // 'Hello': 5 chars * 0.25 = 1.25
      // inlineData: 3000
      // Total: 3001.25 -> 3001
      expect(count).toBe(3001);
      expect(mockContentGenerator.countTokens).toHaveBeenCalled();
    });

    it('should use fixed estimate for images in fallback', async () => {
      vi.mocked(mockContentGenerator.countTokens).mockRejectedValue(
        new Error('API error'),
      );
      const request = [
        { inlineData: { mimeType: 'image/png', data: 'large_data' } },
      ];
      const count = await calculateRequestTokenCount(
        request,
        mockContentGenerator,
        model,
      );
      expect(count).toBe(3000);
    });

    it('should use countTokens API for PDF requests', async () => {
      vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
        totalTokens: 5160,
      });
      const request = [
        { inlineData: { mimeType: 'application/pdf', data: 'pdf_data' } },
      ];
      const count = await calculateRequestTokenCount(
        request,
        mockContentGenerator,
        model,
      );
      expect(count).toBe(5160);
      expect(mockContentGenerator.countTokens).toHaveBeenCalled();
    });

    it('should use fixed estimate for PDFs in fallback', async () => {
      vi.mocked(mockContentGenerator.countTokens).mockRejectedValue(
        new Error('API error'),
      );
      const request = [
        { inlineData: { mimeType: 'application/pdf', data: 'large_pdf_data' } },
      ];
      const count = await calculateRequestTokenCount(
        request,
        mockContentGenerator,
        model,
      );
      // PDF estimate: 25800 tokens (~100 pages at 258 tokens/page)
      expect(count).toBe(25800);
    });
  });

  describe('estimateTokenCountSync', () => {
    it('should use fast heuristic for massive strings', () => {
      const massiveText = 'a'.repeat(200_000);
      // 200,000 / 4 = 50,000 tokens
      const parts: Part[] = [{ text: massiveText }];
      expect(estimateTokenCountSync(parts)).toBe(50000);
    });

    it('should estimate functionResponse without full stringification', () => {
      const toolResult = 'result'.repeat(1000); // 6000 chars
      const parts: Part[] = [
        {
          functionResponse: {
            name: 'my_tool',
            id: '123',
            response: { output: toolResult },
          },
        },
      ];
      const tokens = estimateTokenCountSync(parts);
      // payload ~6013 chars / 4 = 1503.25
      // name 7 / 4 = 1.75
      // total ~1505
      expect(tokens).toBeGreaterThan(1500);
      expect(tokens).toBeLessThan(1600);
    });

    it('should handle Gemini 3 multimodal nested parts in functionResponse', () => {
      const parts: Part[] = [
        {
          functionResponse: {
            name: 'multimodal_tool',
            id: '456',
            response: { status: 'success' },
            // Gemini 3 nested parts
            parts: [
              { inlineData: { mimeType: 'image/png', data: 'base64...' } },
              { text: 'Look at this image' },
            ] as Part[],
          },
        },
      ];
      const tokens = estimateTokenCountSync(parts);
      // image 3000 + text 4.5 + response 5 = ~3009.5
      expect(tokens).toBeGreaterThan(3000);
      expect(tokens).toBeLessThan(3100);
    });

    it('should respect the maximum recursion depth limit', () => {
      // Create a structure nested to depth 5 (exceeding limit of 3)
      const parts: Part[] = [
        {
          functionResponse: {
            name: 'd0',
            response: { val: 'a' }, // ~12 chars -> 3 tokens
            parts: [
              {
                functionResponse: {
                  name: 'd1',
                  response: { val: 'a' }, // ~12 chars -> 3 tokens
                  parts: [
                    {
                      functionResponse: {
                        name: 'd2',
                        response: { val: 'a' }, // ~12 chars -> 3 tokens
                        parts: [
                          {
                            functionResponse: {
                              name: 'd3',
                              response: { val: 'a' }, // ~12 chars -> 3 tokens
                              parts: [
                                {
                                  functionResponse: {
                                    name: 'd4',
                                    response: { val: 'a' },
                                  },
                                },
                              ] as Part[],
                            },
                          },
                        ] as Part[],
                      },
                    },
                  ] as Part[],
                },
              },
            ] as Part[],
          },
        },
      ];
      const tokens = estimateTokenCountSync(parts);
      // It should count d0, d1, d2, d3 (depth 0, 1, 2, 3) but NOT d4 (depth 4)
      // d0..d3: 4 * ~4 tokens = ~16
      expect(tokens).toBeGreaterThan(10);
      expect(tokens).toBeLessThan(30);
    });

    it('should handle empty or nullish inputs gracefully', () => {
      expect(estimateTokenCountSync([])).toBe(0);
      expect(estimateTokenCountSync([{ text: '' }])).toBe(0);
      expect(estimateTokenCountSync([{} as Part])).toBe(0);
    });
  });
});
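The nested-parts tests above target Gemini 3 style tool responses, where a functionResponse can itself carry multimodal parts. For orientation, a hedged sketch of such a payload as the estimator sees it — the tool name and data are illustrative only, and the cast mirrors the tests, since the typed FunctionResponse may not declare parts:

    import type { Part } from '@google/genai';

    // Hypothetical tool turn; names and data are made up for illustration.
    const toolTurn: Part[] = [
      {
        functionResponse: {
          name: 'capture_screenshot',
          id: 'call-1',
          response: { status: 'success' },
          // Nested multimodal parts; the estimator below recurses into
          // these, capped at MAX_RECURSION_DEPTH (3).
          parts: [
            { inlineData: { mimeType: 'image/png', data: '<base64>' } },
            { text: 'Screenshot captured.' },
          ] as Part[],
        },
      },
    ];
    // estimateTokenCountSync(toolTurn) ≈ 3000 (image) + small text/name terms.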

View File

@@ -24,44 +24,104 @@ const PDF_TOKEN_ESTIMATE = 25800;
// Above this, we use a faster approximation to avoid performance bottlenecks.
const MAX_CHARS_FOR_FULL_HEURISTIC = 100_000;
// Maximum depth for recursive token estimation to prevent stack overflow from
// malicious or buggy nested structures. A depth of 3 is sufficient given
// standard multimodal responses are typically depth 1.
const MAX_RECURSION_DEPTH = 3;
/**
 * Heuristic estimation of tokens for a text string.
 */
function estimateTextTokens(text: string): number {
  if (text.length > MAX_CHARS_FOR_FULL_HEURISTIC) {
    return text.length / 4;
  }
  let tokens = 0;
  // Optimized loop: charCodeAt is faster than for...of on large strings
  for (let i = 0; i < text.length; i++) {
    if (text.charCodeAt(i) <= 127) {
      tokens += ASCII_TOKENS_PER_CHAR;
    } else {
      tokens += NON_ASCII_TOKENS_PER_CHAR;
    }
  }
  return tokens;
}

/**
 * Heuristic estimation for media parts (images, PDFs) using fixed safe estimates.
 */
function estimateMediaTokens(part: Part): number | undefined {
  const inlineData = 'inlineData' in part ? part.inlineData : undefined;
  const fileData = 'fileData' in part ? part.fileData : undefined;
  const mimeType = inlineData?.mimeType || fileData?.mimeType;
  if (mimeType?.startsWith('image/')) {
    // Images: 3,000 tokens (covers up to 4K resolution on Gemini 3)
    // See: https://ai.google.dev/gemini-api/docs/vision#token_counting
    return IMAGE_TOKEN_ESTIMATE;
  } else if (mimeType?.startsWith('application/pdf')) {
    // PDFs: 25,800 tokens (~100 pages at 258 tokens/page)
    // See: https://ai.google.dev/gemini-api/docs/document-processing
    return PDF_TOKEN_ESTIMATE;
  }
  return undefined;
}

/**
 * Heuristic estimation for tool responses, avoiding massive string copies
 * and accounting for nested Gemini 3 multimodal parts.
 */
function estimateFunctionResponseTokens(part: Part, depth: number): number {
  const fr = part.functionResponse;
  if (!fr) return 0;
  let totalTokens = (fr.name?.length ?? 0) / 4;
  const response = fr.response as unknown;
  if (typeof response === 'string') {
    totalTokens += response.length / 4;
  } else if (response !== undefined && response !== null) {
    // For objects, stringify only the payload, not the whole Part object.
    totalTokens += JSON.stringify(response).length / 4;
  }
  // Gemini 3: Handle nested multimodal parts recursively.
  const nestedParts = (fr as unknown as { parts?: Part[] }).parts;
  if (nestedParts && nestedParts.length > 0) {
    totalTokens += estimateTokenCountSync(nestedParts, depth + 1);
  }
  return totalTokens;
}

/**
 * Estimates token count for parts synchronously using a heuristic.
 * - Text: character-based heuristic (ASCII vs CJK) for small strings, length/4 for massive ones.
 * - Non-text (Tools, etc): JSON string length / 4.
 */
export function estimateTokenCountSync(
  parts: Part[],
  depth: number = 0,
): number {
  if (depth > MAX_RECURSION_DEPTH) {
    return 0;
  }
  let totalTokens = 0;
  for (const part of parts) {
    if (typeof part.text === 'string') {
      totalTokens += estimateTextTokens(part.text);
    } else if (part.functionResponse) {
      totalTokens += estimateFunctionResponseTokens(part, depth);
    } else {
      const mediaEstimate = estimateMediaTokens(part);
      if (mediaEstimate !== undefined) {
        totalTokens += mediaEstimate;
      } else {
        // Fallback for other non-text parts (e.g., functionCall).
        // Note: JSON.stringify(part) here is safe as these parts are typically small.
        totalTokens += JSON.stringify(part).length / 4;
      }
    }
  }
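For completeness, a hedged call-site sketch based on the behavior the tests assert: calculateRequestTokenCount uses the countTokens API only for media requests, and otherwise (or when the API call fails) falls back to the local heuristic above. The contentGenerator variable is assumed to be in scope in an async context:

    // Assumed: `contentGenerator` is a ContentGenerator instance and this
    // runs inside an async function. Behavior per the tests above.
    const tokens = await calculateRequestTokenCount(
      'Hello world!', // plain text -> local heuristic, no API call
      contentGenerator,
      'gemini-pro',
    );
    // tokens === 3 (12 ASCII chars * 0.25, floored)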