mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-02-01 22:48:03 +00:00
perf(core): optimize token calculation and add support for multimodal tool responses (#17835)
This commit is contained in:
@@ -5,180 +5,285 @@
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { calculateRequestTokenCount } from './tokenCalculation.js';
|
||||
import {
|
||||
calculateRequestTokenCount,
|
||||
estimateTokenCountSync,
|
||||
} from './tokenCalculation.js';
|
||||
import type { ContentGenerator } from '../core/contentGenerator.js';
|
||||
import type { Part } from '@google/genai';
|
||||
|
||||
describe('calculateRequestTokenCount', () => {
|
||||
const mockContentGenerator = {
|
||||
countTokens: vi.fn(),
|
||||
} as unknown as ContentGenerator;
|
||||
describe('tokenCalculation', () => {
|
||||
describe('calculateRequestTokenCount', () => {
|
||||
const mockContentGenerator = {
|
||||
countTokens: vi.fn(),
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
const model = 'gemini-pro';
|
||||
const model = 'gemini-pro';
|
||||
|
||||
it('should use countTokens API for media requests (images/files)', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
|
||||
totalTokens: 100,
|
||||
it('should use countTokens API for media requests (images/files)', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
|
||||
totalTokens: 100,
|
||||
});
|
||||
const request = [{ inlineData: { mimeType: 'image/png', data: 'data' } }];
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
|
||||
expect(count).toBe(100);
|
||||
expect(mockContentGenerator.countTokens).toHaveBeenCalled();
|
||||
});
|
||||
const request = [{ inlineData: { mimeType: 'image/png', data: 'data' } }];
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
it('should estimate tokens locally for tool calls', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockClear();
|
||||
const request = [{ functionCall: { name: 'foo', args: { bar: 'baz' } } }];
|
||||
|
||||
expect(count).toBe(100);
|
||||
expect(mockContentGenerator.countTokens).toHaveBeenCalled();
|
||||
});
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
|
||||
it('should estimate tokens locally for tool calls', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockClear();
|
||||
const request = [{ functionCall: { name: 'foo', args: { bar: 'baz' } } }];
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
|
||||
// Estimation logic: JSON.stringify(part).length / 4
|
||||
// JSON: {"functionCall":{"name":"foo","args":{"bar":"baz"}}}
|
||||
// Length: ~53 chars. 53 / 4 = 13.25 -> 13.
|
||||
expect(count).toBeGreaterThan(0);
|
||||
expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should estimate tokens locally for simple ASCII text', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockClear();
|
||||
// 12 chars. 12 * 0.25 = 3 tokens.
|
||||
const request = 'Hello world!';
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
|
||||
expect(count).toBe(3);
|
||||
expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should estimate tokens locally for CJK text with higher weight', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockClear();
|
||||
// 2 chars. 2 * 1.3 = 2.6 -> floor(2.6) = 2.
|
||||
// Old logic would be 2/4 = 0.5 -> 0.
|
||||
const request = '你好';
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
|
||||
expect(count).toBeGreaterThanOrEqual(2);
|
||||
expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle mixed content', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockClear();
|
||||
// 'Hi': 2 * 0.25 = 0.5
|
||||
// '你好': 2 * 1.3 = 2.6
|
||||
// Total: 3.1 -> 3
|
||||
const request = 'Hi你好';
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
|
||||
expect(count).toBe(3);
|
||||
expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle empty text', async () => {
|
||||
const request = '';
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
expect(count).toBe(0);
|
||||
});
|
||||
|
||||
it('should fallback to local estimation when countTokens API fails', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockRejectedValue(
|
||||
new Error('API error'),
|
||||
);
|
||||
const request = [
|
||||
{ text: 'Hello' },
|
||||
{ inlineData: { mimeType: 'image/png', data: 'data' } },
|
||||
];
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
|
||||
// Should fallback to estimation:
|
||||
// 'Hello': 5 chars * 0.25 = 1.25
|
||||
// inlineData: 3000
|
||||
// Total: 3001.25 -> 3001
|
||||
expect(count).toBe(3001);
|
||||
expect(mockContentGenerator.countTokens).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should use fixed estimate for images in fallback', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockRejectedValue(
|
||||
new Error('API error'),
|
||||
);
|
||||
const request = [
|
||||
{ inlineData: { mimeType: 'image/png', data: 'large_data' } },
|
||||
];
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
|
||||
expect(count).toBe(3000);
|
||||
});
|
||||
|
||||
it('should use countTokens API for PDF requests', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
|
||||
totalTokens: 5160,
|
||||
expect(count).toBeGreaterThan(0);
|
||||
expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
|
||||
});
|
||||
const request = [
|
||||
{ inlineData: { mimeType: 'application/pdf', data: 'pdf_data' } },
|
||||
];
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
it('should estimate tokens locally for simple ASCII text', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockClear();
|
||||
// 12 chars. 12 * 0.25 = 3 tokens.
|
||||
const request = 'Hello world!';
|
||||
|
||||
expect(count).toBe(5160);
|
||||
expect(mockContentGenerator.countTokens).toHaveBeenCalled();
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
|
||||
expect(count).toBe(3);
|
||||
expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should estimate tokens locally for CJK text with higher weight', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockClear();
|
||||
// 2 chars. 2 * 1.3 = 2.6 -> floor(2.6) = 2.
|
||||
const request = '你好';
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
|
||||
expect(count).toBeGreaterThanOrEqual(2);
|
||||
expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle mixed content', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockClear();
|
||||
// 'Hi': 2 * 0.25 = 0.5
|
||||
// '你好': 2 * 1.3 = 2.6
|
||||
// Total: 3.1 -> 3
|
||||
const request = 'Hi你好';
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
|
||||
expect(count).toBe(3);
|
||||
expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle empty text', async () => {
|
||||
const request = '';
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
expect(count).toBe(0);
|
||||
});
|
||||
|
||||
it('should fallback to local estimation when countTokens API fails', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockRejectedValue(
|
||||
new Error('API error'),
|
||||
);
|
||||
const request = [
|
||||
{ text: 'Hello' },
|
||||
{ inlineData: { mimeType: 'image/png', data: 'data' } },
|
||||
];
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
|
||||
expect(count).toBe(3001);
|
||||
expect(mockContentGenerator.countTokens).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should use fixed estimate for images in fallback', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockRejectedValue(
|
||||
new Error('API error'),
|
||||
);
|
||||
const request = [
|
||||
{ inlineData: { mimeType: 'image/png', data: 'large_data' } },
|
||||
];
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
|
||||
expect(count).toBe(3000);
|
||||
});
|
||||
|
||||
it('should use countTokens API for PDF requests', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
|
||||
totalTokens: 5160,
|
||||
});
|
||||
const request = [
|
||||
{ inlineData: { mimeType: 'application/pdf', data: 'pdf_data' } },
|
||||
];
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
|
||||
expect(count).toBe(5160);
|
||||
expect(mockContentGenerator.countTokens).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should use fixed estimate for PDFs in fallback', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockRejectedValue(
|
||||
new Error('API error'),
|
||||
);
|
||||
const request = [
|
||||
{ inlineData: { mimeType: 'application/pdf', data: 'large_pdf_data' } },
|
||||
];
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
|
||||
// PDF estimate: 25800 tokens (~100 pages at 258 tokens/page)
|
||||
expect(count).toBe(25800);
|
||||
});
|
||||
});
|
||||
|
||||
it('should use fixed estimate for PDFs in fallback', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockRejectedValue(
|
||||
new Error('API error'),
|
||||
);
|
||||
const request = [
|
||||
{ inlineData: { mimeType: 'application/pdf', data: 'large_pdf_data' } },
|
||||
];
|
||||
describe('estimateTokenCountSync', () => {
|
||||
it('should use fast heuristic for massive strings', () => {
|
||||
const massiveText = 'a'.repeat(200_000);
|
||||
// 200,000 / 4 = 50,000 tokens
|
||||
const parts: Part[] = [{ text: massiveText }];
|
||||
expect(estimateTokenCountSync(parts)).toBe(50000);
|
||||
});
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
it('should estimate functionResponse without full stringification', () => {
|
||||
const toolResult = 'result'.repeat(1000); // 6000 chars
|
||||
const parts: Part[] = [
|
||||
{
|
||||
functionResponse: {
|
||||
name: 'my_tool',
|
||||
id: '123',
|
||||
response: { output: toolResult },
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
// PDF estimate: 25800 tokens (~100 pages at 258 tokens/page)
|
||||
expect(count).toBe(25800);
|
||||
const tokens = estimateTokenCountSync(parts);
|
||||
// payload ~6013 chars / 4 = 1503.25
|
||||
// name 7 / 4 = 1.75
|
||||
// total ~1505
|
||||
expect(tokens).toBeGreaterThan(1500);
|
||||
expect(tokens).toBeLessThan(1600);
|
||||
});
|
||||
|
||||
it('should handle Gemini 3 multimodal nested parts in functionResponse', () => {
|
||||
const parts: Part[] = [
|
||||
{
|
||||
functionResponse: {
|
||||
name: 'multimodal_tool',
|
||||
id: '456',
|
||||
response: { status: 'success' },
|
||||
// Gemini 3 nested parts
|
||||
parts: [
|
||||
{ inlineData: { mimeType: 'image/png', data: 'base64...' } },
|
||||
{ text: 'Look at this image' },
|
||||
] as Part[],
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const tokens = estimateTokenCountSync(parts);
|
||||
// image 3000 + text 4.5 + response 5 = ~3009.5
|
||||
expect(tokens).toBeGreaterThan(3000);
|
||||
expect(tokens).toBeLessThan(3100);
|
||||
});
|
||||
|
||||
it('should respect the maximum recursion depth limit', () => {
|
||||
// Create a structure nested to depth 5 (exceeding limit of 3)
|
||||
const parts: Part[] = [
|
||||
{
|
||||
functionResponse: {
|
||||
name: 'd0',
|
||||
response: { val: 'a' }, // ~12 chars -> 3 tokens
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: 'd1',
|
||||
response: { val: 'a' }, // ~12 chars -> 3 tokens
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: 'd2',
|
||||
response: { val: 'a' }, // ~12 chars -> 3 tokens
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: 'd3',
|
||||
response: { val: 'a' }, // ~12 chars -> 3 tokens
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: 'd4',
|
||||
response: { val: 'a' },
|
||||
},
|
||||
},
|
||||
] as Part[],
|
||||
},
|
||||
},
|
||||
] as Part[],
|
||||
},
|
||||
},
|
||||
] as Part[],
|
||||
},
|
||||
},
|
||||
] as Part[],
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const tokens = estimateTokenCountSync(parts);
|
||||
// It should count d0, d1, d2, d3 (depth 0, 1, 2, 3) but NOT d4 (depth 4)
|
||||
// d0..d3: 4 * ~4 tokens = ~16
|
||||
expect(tokens).toBeGreaterThan(10);
|
||||
expect(tokens).toBeLessThan(30);
|
||||
});
|
||||
|
||||
it('should handle empty or nullish inputs gracefully', () => {
|
||||
expect(estimateTokenCountSync([])).toBe(0);
|
||||
expect(estimateTokenCountSync([{ text: '' }])).toBe(0);
|
||||
expect(estimateTokenCountSync([{} as Part])).toBe(0);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -24,44 +24,104 @@ const PDF_TOKEN_ESTIMATE = 25800;
|
||||
// Above this, we use a faster approximation to avoid performance bottlenecks.
|
||||
const MAX_CHARS_FOR_FULL_HEURISTIC = 100_000;
|
||||
|
||||
// Maximum depth for recursive token estimation to prevent stack overflow from
|
||||
// malicious or buggy nested structures. A depth of 3 is sufficient given
|
||||
// standard multimodal responses are typically depth 1.
|
||||
const MAX_RECURSION_DEPTH = 3;
|
||||
|
||||
/**
|
||||
* Heuristic estimation of tokens for a text string.
|
||||
*/
|
||||
function estimateTextTokens(text: string): number {
|
||||
if (text.length > MAX_CHARS_FOR_FULL_HEURISTIC) {
|
||||
return text.length / 4;
|
||||
}
|
||||
|
||||
let tokens = 0;
|
||||
// Optimized loop: charCodeAt is faster than for...of on large strings
|
||||
for (let i = 0; i < text.length; i++) {
|
||||
if (text.charCodeAt(i) <= 127) {
|
||||
tokens += ASCII_TOKENS_PER_CHAR;
|
||||
} else {
|
||||
tokens += NON_ASCII_TOKENS_PER_CHAR;
|
||||
}
|
||||
}
|
||||
return tokens;
|
||||
}
|
||||
|
||||
/**
|
||||
* Heuristic estimation for media parts (images, PDFs) using fixed safe estimates.
|
||||
*/
|
||||
function estimateMediaTokens(part: Part): number | undefined {
|
||||
const inlineData = 'inlineData' in part ? part.inlineData : undefined;
|
||||
const fileData = 'fileData' in part ? part.fileData : undefined;
|
||||
const mimeType = inlineData?.mimeType || fileData?.mimeType;
|
||||
|
||||
if (mimeType?.startsWith('image/')) {
|
||||
// Images: 3,000 tokens (covers up to 4K resolution on Gemini 3)
|
||||
// See: https://ai.google.dev/gemini-api/docs/vision#token_counting
|
||||
return IMAGE_TOKEN_ESTIMATE;
|
||||
} else if (mimeType?.startsWith('application/pdf')) {
|
||||
// PDFs: 25,800 tokens (~100 pages at 258 tokens/page)
|
||||
// See: https://ai.google.dev/gemini-api/docs/document-processing
|
||||
return PDF_TOKEN_ESTIMATE;
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Heuristic estimation for tool responses, avoiding massive string copies
|
||||
* and accounting for nested Gemini 3 multimodal parts.
|
||||
*/
|
||||
function estimateFunctionResponseTokens(part: Part, depth: number): number {
|
||||
const fr = part.functionResponse;
|
||||
if (!fr) return 0;
|
||||
|
||||
let totalTokens = (fr.name?.length ?? 0) / 4;
|
||||
const response = fr.response as unknown;
|
||||
|
||||
if (typeof response === 'string') {
|
||||
totalTokens += response.length / 4;
|
||||
} else if (response !== undefined && response !== null) {
|
||||
// For objects, stringify only the payload, not the whole Part object.
|
||||
totalTokens += JSON.stringify(response).length / 4;
|
||||
}
|
||||
|
||||
// Gemini 3: Handle nested multimodal parts recursively.
|
||||
const nestedParts = (fr as unknown as { parts?: Part[] }).parts;
|
||||
if (nestedParts && nestedParts.length > 0) {
|
||||
totalTokens += estimateTokenCountSync(nestedParts, depth + 1);
|
||||
}
|
||||
|
||||
return totalTokens;
|
||||
}
|
||||
|
||||
/**
|
||||
* Estimates token count for parts synchronously using a heuristic.
|
||||
* - Text: character-based heuristic (ASCII vs CJK) for small strings, length/4 for massive ones.
|
||||
* - Non-text (Tools, etc): JSON string length / 4.
|
||||
*/
|
||||
export function estimateTokenCountSync(parts: Part[]): number {
|
||||
export function estimateTokenCountSync(
|
||||
parts: Part[],
|
||||
depth: number = 0,
|
||||
): number {
|
||||
if (depth > MAX_RECURSION_DEPTH) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let totalTokens = 0;
|
||||
for (const part of parts) {
|
||||
if (typeof part.text === 'string') {
|
||||
if (part.text.length > MAX_CHARS_FOR_FULL_HEURISTIC) {
|
||||
totalTokens += part.text.length / 4;
|
||||
} else {
|
||||
for (const char of part.text) {
|
||||
if (char.codePointAt(0)! <= 127) {
|
||||
totalTokens += ASCII_TOKENS_PER_CHAR;
|
||||
} else {
|
||||
totalTokens += NON_ASCII_TOKENS_PER_CHAR;
|
||||
}
|
||||
}
|
||||
}
|
||||
totalTokens += estimateTextTokens(part.text);
|
||||
} else if (part.functionResponse) {
|
||||
totalTokens += estimateFunctionResponseTokens(part, depth);
|
||||
} else {
|
||||
// For images and PDFs, we use fixed safe estimates:
|
||||
// - Images: 3,000 tokens (covers up to 4K resolution on Gemini 3)
|
||||
// - PDFs: 25,800 tokens (~100 pages at 258 tokens/page)
|
||||
// See: https://ai.google.dev/gemini-api/docs/vision#token_counting
|
||||
// See: https://ai.google.dev/gemini-api/docs/document-processing
|
||||
const inlineData = 'inlineData' in part ? part.inlineData : undefined;
|
||||
const fileData = 'fileData' in part ? part.fileData : undefined;
|
||||
const mimeType = inlineData?.mimeType || fileData?.mimeType;
|
||||
|
||||
if (mimeType?.startsWith('image/')) {
|
||||
totalTokens += IMAGE_TOKEN_ESTIMATE;
|
||||
} else if (mimeType?.startsWith('application/pdf')) {
|
||||
totalTokens += PDF_TOKEN_ESTIMATE;
|
||||
const mediaEstimate = estimateMediaTokens(part);
|
||||
if (mediaEstimate !== undefined) {
|
||||
totalTokens += mediaEstimate;
|
||||
} else {
|
||||
// For other non-text parts (functionCall, functionResponse, etc.),
|
||||
// we fallback to the JSON string length heuristic.
|
||||
// Note: This is an approximation.
|
||||
// Fallback for other non-text parts (e.g., functionCall).
|
||||
// Note: JSON.stringify(part) here is safe as these parts are typically small.
|
||||
totalTokens += JSON.stringify(part).length / 4;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user