From f605628624343ea6230892bf39ec6664f6993271 Mon Sep 17 00:00:00 2001
From: Abhi <43648792+abhipatel12@users.noreply.github.com>
Date: Fri, 30 Jan 2026 10:15:00 -0500
Subject: [PATCH] perf(core): optimize token calculation and add support for
 multimodal tool responses (#17835)

---
 .../core/src/utils/tokenCalculation.test.ts | 425 +++++++++++-------
 packages/core/src/utils/tokenCalculation.ts | 116 +++--
 2 files changed, 353 insertions(+), 188 deletions(-)

diff --git a/packages/core/src/utils/tokenCalculation.test.ts b/packages/core/src/utils/tokenCalculation.test.ts
index 126ef7bac2..e642669708 100644
--- a/packages/core/src/utils/tokenCalculation.test.ts
+++ b/packages/core/src/utils/tokenCalculation.test.ts
@@ -5,180 +5,285 @@
  */
 
 import { describe, it, expect, vi } from 'vitest';
-import { calculateRequestTokenCount } from './tokenCalculation.js';
+import {
+  calculateRequestTokenCount,
+  estimateTokenCountSync,
+} from './tokenCalculation.js';
 import type { ContentGenerator } from '../core/contentGenerator.js';
+import type { Part } from '@google/genai';
 
-describe('calculateRequestTokenCount', () => {
-  const mockContentGenerator = {
-    countTokens: vi.fn(),
-  } as unknown as ContentGenerator;
+describe('tokenCalculation', () => {
+  describe('calculateRequestTokenCount', () => {
+    const mockContentGenerator = {
+      countTokens: vi.fn(),
+    } as unknown as ContentGenerator;
 
-  const model = 'gemini-pro';
+    const model = 'gemini-pro';
 
-  it('should use countTokens API for media requests (images/files)', async () => {
-    vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
-      totalTokens: 100,
+    it('should use countTokens API for media requests (images/files)', async () => {
+      vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
+        totalTokens: 100,
+      });
+      const request = [{ inlineData: { mimeType: 'image/png', data: 'data' } }];
+
+      const count = await calculateRequestTokenCount(
+        request,
+        mockContentGenerator,
+        model,
+      );
+
+      expect(count).toBe(100);
+      expect(mockContentGenerator.countTokens).toHaveBeenCalled();
     });
-    const request = [{ inlineData: { mimeType: 'image/png', data: 'data' } }];
 
-    const count = await calculateRequestTokenCount(
-      request,
-      mockContentGenerator,
-      model,
-    );
+    it('should estimate tokens locally for tool calls', async () => {
+      vi.mocked(mockContentGenerator.countTokens).mockClear();
+      const request = [{ functionCall: { name: 'foo', args: { bar: 'baz' } } }];
 
-    expect(count).toBe(100);
-    expect(mockContentGenerator.countTokens).toHaveBeenCalled();
-  });
+      const count = await calculateRequestTokenCount(
+        request,
+        mockContentGenerator,
+        model,
+      );
 
-  it('should estimate tokens locally for tool calls', async () => {
-    vi.mocked(mockContentGenerator.countTokens).mockClear();
-    const request = [{ functionCall: { name: 'foo', args: { bar: 'baz' } } }];
-
-    const count = await calculateRequestTokenCount(
-      request,
-      mockContentGenerator,
-      model,
-    );
-
-    // Estimation logic: JSON.stringify(part).length / 4
-    // JSON: {"functionCall":{"name":"foo","args":{"bar":"baz"}}}
-    // Length: ~53 chars. 53 / 4 = 13.25 -> 13.
-    expect(count).toBeGreaterThan(0);
-    expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
-  });
-
-  it('should estimate tokens locally for simple ASCII text', async () => {
-    vi.mocked(mockContentGenerator.countTokens).mockClear();
-    // 12 chars. 12 * 0.25 = 3 tokens.
-    const request = 'Hello world!';
-
-    const count = await calculateRequestTokenCount(
-      request,
-      mockContentGenerator,
-      model,
-    );
-
-    expect(count).toBe(3);
-    expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
-  });
-
-  it('should estimate tokens locally for CJK text with higher weight', async () => {
-    vi.mocked(mockContentGenerator.countTokens).mockClear();
-    // 2 chars. 2 * 1.3 = 2.6 -> floor(2.6) = 2.
-    // Old logic would be 2/4 = 0.5 -> 0.
-    const request = '你好';
-
-    const count = await calculateRequestTokenCount(
-      request,
-      mockContentGenerator,
-      model,
-    );
-
-    expect(count).toBeGreaterThanOrEqual(2);
-    expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
-  });
-
-  it('should handle mixed content', async () => {
-    vi.mocked(mockContentGenerator.countTokens).mockClear();
-    // 'Hi': 2 * 0.25 = 0.5
-    // '你好': 2 * 1.3 = 2.6
-    // Total: 3.1 -> 3
-    const request = 'Hi你好';
-
-    const count = await calculateRequestTokenCount(
-      request,
-      mockContentGenerator,
-      model,
-    );
-
-    expect(count).toBe(3);
-    expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
-  });
-
-  it('should handle empty text', async () => {
-    const request = '';
-    const count = await calculateRequestTokenCount(
-      request,
-      mockContentGenerator,
-      model,
-    );
-    expect(count).toBe(0);
-  });
-
-  it('should fallback to local estimation when countTokens API fails', async () => {
-    vi.mocked(mockContentGenerator.countTokens).mockRejectedValue(
-      new Error('API error'),
-    );
-    const request = [
-      { text: 'Hello' },
-      { inlineData: { mimeType: 'image/png', data: 'data' } },
-    ];
-
-    const count = await calculateRequestTokenCount(
-      request,
-      mockContentGenerator,
-      model,
-    );
-
-    // Should fallback to estimation:
-    // 'Hello': 5 chars * 0.25 = 1.25
-    // inlineData: 3000
-    // Total: 3001.25 -> 3001
-    expect(count).toBe(3001);
-    expect(mockContentGenerator.countTokens).toHaveBeenCalled();
-  });
-
-  it('should use fixed estimate for images in fallback', async () => {
-    vi.mocked(mockContentGenerator.countTokens).mockRejectedValue(
-      new Error('API error'),
-    );
-    const request = [
-      { inlineData: { mimeType: 'image/png', data: 'large_data' } },
-    ];
-
-    const count = await calculateRequestTokenCount(
-      request,
-      mockContentGenerator,
-      model,
-    );
-
-    expect(count).toBe(3000);
-  });
-
-  it('should use countTokens API for PDF requests', async () => {
-    vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
-      totalTokens: 5160,
+      expect(count).toBeGreaterThan(0);
+      expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
     });
-    const request = [
-      { inlineData: { mimeType: 'application/pdf', data: 'pdf_data' } },
-    ];
-
-    const count = await calculateRequestTokenCount(
-      request,
-      mockContentGenerator,
-      model,
-    );
-
-    expect(count).toBe(5160);
-    expect(mockContentGenerator.countTokens).toHaveBeenCalled();
-  });
 
+    it('should estimate tokens locally for simple ASCII text', async () => {
+      vi.mocked(mockContentGenerator.countTokens).mockClear();
+      // 12 chars. 12 * 0.25 = 3 tokens.
+      const request = 'Hello world!';
+
+      const count = await calculateRequestTokenCount(
+        request,
+        mockContentGenerator,
+        model,
+      );
+
+      expect(count).toBe(3);
+      expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
+    });
+
+    it('should estimate tokens locally for CJK text with higher weight', async () => {
+      vi.mocked(mockContentGenerator.countTokens).mockClear();
+      // 2 chars. 2 * 1.3 = 2.6 -> floor(2.6) = 2.
+      const request = '你好';
+
+      const count = await calculateRequestTokenCount(
+        request,
+        mockContentGenerator,
+        model,
+      );
+
+      expect(count).toBeGreaterThanOrEqual(2);
+      expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
+    });
+
+    it('should handle mixed content', async () => {
+      vi.mocked(mockContentGenerator.countTokens).mockClear();
+      // 'Hi': 2 * 0.25 = 0.5
+      // '你好': 2 * 1.3 = 2.6
+      // Total: 3.1 -> 3
+      const request = 'Hi你好';
+
+      const count = await calculateRequestTokenCount(
+        request,
+        mockContentGenerator,
+        model,
+      );
+
+      expect(count).toBe(3);
+      expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
+    });
+
+    it('should handle empty text', async () => {
+      const request = '';
+      const count = await calculateRequestTokenCount(
+        request,
+        mockContentGenerator,
+        model,
+      );
+      expect(count).toBe(0);
+    });
+
+    it('should fallback to local estimation when countTokens API fails', async () => {
+      vi.mocked(mockContentGenerator.countTokens).mockRejectedValue(
+        new Error('API error'),
+      );
+      const request = [
+        { text: 'Hello' },
+        { inlineData: { mimeType: 'image/png', data: 'data' } },
+      ];
+
+      const count = await calculateRequestTokenCount(
+        request,
+        mockContentGenerator,
+        model,
+      );
+
+      expect(count).toBe(3001);
+      expect(mockContentGenerator.countTokens).toHaveBeenCalled();
+    });
+
+    it('should use fixed estimate for images in fallback', async () => {
+      vi.mocked(mockContentGenerator.countTokens).mockRejectedValue(
+        new Error('API error'),
+      );
+      const request = [
+        { inlineData: { mimeType: 'image/png', data: 'large_data' } },
+      ];
+
+      const count = await calculateRequestTokenCount(
+        request,
+        mockContentGenerator,
+        model,
+      );
+
+      expect(count).toBe(3000);
+    });
+
+    it('should use countTokens API for PDF requests', async () => {
+      vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
+        totalTokens: 5160,
+      });
+      const request = [
+        { inlineData: { mimeType: 'application/pdf', data: 'pdf_data' } },
+      ];
+
+      const count = await calculateRequestTokenCount(
+        request,
+        mockContentGenerator,
+        model,
+      );
+
+      expect(count).toBe(5160);
+      expect(mockContentGenerator.countTokens).toHaveBeenCalled();
+    });
+
+    it('should use fixed estimate for PDFs in fallback', async () => {
+      vi.mocked(mockContentGenerator.countTokens).mockRejectedValue(
+        new Error('API error'),
+      );
+      const request = [
+        { inlineData: { mimeType: 'application/pdf', data: 'large_pdf_data' } },
+      ];
+
+      const count = await calculateRequestTokenCount(
+        request,
+        mockContentGenerator,
+        model,
+      );
+
+      // PDF estimate: 25800 tokens (~100 pages at 258 tokens/page)
+      expect(count).toBe(25800);
+    });
+  });
 
-  it('should use fixed estimate for PDFs in fallback', async () => {
-    vi.mocked(mockContentGenerator.countTokens).mockRejectedValue(
-      new Error('API error'),
-    );
-    const request = [
-      { inlineData: { mimeType: 'application/pdf', data: 'large_pdf_data' } },
-    ];
+  describe('estimateTokenCountSync', () => {
+    it('should use fast heuristic for massive strings', () => {
+      const massiveText = 'a'.repeat(200_000);
+      // 200,000 / 4 = 50,000 tokens
+      const parts: Part[] = [{ text: massiveText }];
+      expect(estimateTokenCountSync(parts)).toBe(50000);
+    });
 
-    const count = await calculateRequestTokenCount(
-      request,
-      mockContentGenerator,
-      model,
-    );
+    it('should estimate functionResponse without full stringification', () => {
+      const toolResult = 'result'.repeat(1000); // 6000 chars
+      const parts: Part[] = [
+        {
+          functionResponse: {
+            name: 'my_tool',
+            id: '123',
+            response: { output: toolResult },
+          },
+        },
+      ];
+
+      const tokens = estimateTokenCountSync(parts);
+      // payload ~6013 chars / 4 = 1503.25
+      // name 7 / 4 = 1.75
+      // total ~1505
+      expect(tokens).toBeGreaterThan(1500);
+      expect(tokens).toBeLessThan(1600);
+    });
 
-    // PDF estimate: 25800 tokens (~100 pages at 258 tokens/page)
-    expect(count).toBe(25800);
-  });
+    it('should handle Gemini 3 multimodal nested parts in functionResponse', () => {
+      const parts: Part[] = [
+        {
+          functionResponse: {
+            name: 'multimodal_tool',
+            id: '456',
+            response: { status: 'success' },
+            // Gemini 3 nested parts
+            parts: [
+              { inlineData: { mimeType: 'image/png', data: 'base64...' } },
+              { text: 'Look at this image' },
+            ] as Part[],
+          },
+        },
+      ];
+
+      const tokens = estimateTokenCountSync(parts);
+      // image 3000 + text 4.5 + response 5 + name 3.75 = ~3013
+      expect(tokens).toBeGreaterThan(3000);
+      expect(tokens).toBeLessThan(3100);
+    });
+
+    it('should respect the maximum recursion depth limit', () => {
+      // Create a structure nested to depth 5 (exceeding limit of 3)
+      const parts: Part[] = [
+        {
+          functionResponse: {
+            name: 'd0',
+            response: { val: 'a' }, // ~12 chars -> 3 tokens
+            parts: [
+              {
+                functionResponse: {
+                  name: 'd1',
+                  response: { val: 'a' }, // ~12 chars -> 3 tokens
+                  parts: [
+                    {
+                      functionResponse: {
+                        name: 'd2',
+                        response: { val: 'a' }, // ~12 chars -> 3 tokens
+                        parts: [
+                          {
+                            functionResponse: {
+                              name: 'd3',
+                              response: { val: 'a' }, // ~12 chars -> 3 tokens
+                              parts: [
+                                {
+                                  functionResponse: {
+                                    name: 'd4',
+                                    response: { val: 'a' },
+                                  },
+                                },
+                              ] as Part[],
+                            },
+                          },
+                        ] as Part[],
+                      },
+                    },
+                  ] as Part[],
+                },
+              },
+            ] as Part[],
+          },
+        },
+      ];
+
+      const tokens = estimateTokenCountSync(parts);
+      // It should count d0, d1, d2, d3 (depth 0, 1, 2, 3) but NOT d4 (depth 4)
+      // d0..d3: 4 * ~3.25 tokens = ~13
+      expect(tokens).toBeGreaterThan(10);
+      expect(tokens).toBeLessThan(30);
+    });
+
+    it('should handle empty or nullish inputs gracefully', () => {
+      expect(estimateTokenCountSync([])).toBe(0);
+      expect(estimateTokenCountSync([{ text: '' }])).toBe(0);
+      expect(estimateTokenCountSync([{} as Part])).toBe(0);
+    });
+  });
 });
diff --git a/packages/core/src/utils/tokenCalculation.ts b/packages/core/src/utils/tokenCalculation.ts
index ba32a80a9e..447424531e 100644
--- a/packages/core/src/utils/tokenCalculation.ts
+++ b/packages/core/src/utils/tokenCalculation.ts
@@ -24,44 +24,104 @@ const PDF_TOKEN_ESTIMATE = 25800;
 // Above this, we use a faster approximation to avoid performance bottlenecks.
 const MAX_CHARS_FOR_FULL_HEURISTIC = 100_000;
 
+// Maximum depth for recursive token estimation to prevent stack overflow from
+// malicious or buggy nested structures. A depth of 3 is sufficient, since
+// standard multimodal responses are typically depth 1.
+const MAX_RECURSION_DEPTH = 3;
+
+/**
+ * Heuristic estimation of tokens for a text string.
+ */
+function estimateTextTokens(text: string): number {
+  if (text.length > MAX_CHARS_FOR_FULL_HEURISTIC) {
+    return text.length / 4;
+  }
+
+  let tokens = 0;
+  // Optimized loop: charCodeAt is faster than for...of on large strings
+  for (let i = 0; i < text.length; i++) {
+    if (text.charCodeAt(i) <= 127) {
+      tokens += ASCII_TOKENS_PER_CHAR;
+    } else {
+      tokens += NON_ASCII_TOKENS_PER_CHAR;
+    }
+  }
+  return tokens;
+}
+
+/**
+ * Heuristic estimation for media parts (images, PDFs) using fixed safe estimates.
+ */
+function estimateMediaTokens(part: Part): number | undefined {
+  const inlineData = 'inlineData' in part ? part.inlineData : undefined;
+  const fileData = 'fileData' in part ? part.fileData : undefined;
+  const mimeType = inlineData?.mimeType || fileData?.mimeType;
+
+  if (mimeType?.startsWith('image/')) {
+    // Images: 3,000 tokens (covers up to 4K resolution on Gemini 3)
+    // See: https://ai.google.dev/gemini-api/docs/vision#token_counting
+    return IMAGE_TOKEN_ESTIMATE;
+  } else if (mimeType?.startsWith('application/pdf')) {
+    // PDFs: 25,800 tokens (~100 pages at 258 tokens/page)
+    // See: https://ai.google.dev/gemini-api/docs/document-processing
+    return PDF_TOKEN_ESTIMATE;
+  }
+  return undefined;
+}
+
+/**
+ * Heuristic estimation for tool responses, avoiding massive string copies
+ * and accounting for nested Gemini 3 multimodal parts.
+ */
+function estimateFunctionResponseTokens(part: Part, depth: number): number {
+  const fr = part.functionResponse;
+  if (!fr) return 0;
+
+  let totalTokens = (fr.name?.length ?? 0) / 4;
+  const response = fr.response as unknown;
+
+  if (typeof response === 'string') {
+    totalTokens += response.length / 4;
+  } else if (response !== undefined && response !== null) {
+    // For objects, stringify only the payload, not the whole Part object.
+    totalTokens += JSON.stringify(response).length / 4;
+  }
+
+  // Gemini 3: Handle nested multimodal parts recursively.
+  const nestedParts = (fr as unknown as { parts?: Part[] }).parts;
+  if (nestedParts && nestedParts.length > 0) {
+    totalTokens += estimateTokenCountSync(nestedParts, depth + 1);
+  }
+
+  return totalTokens;
+}
+
 /**
  * Estimates token count for parts synchronously using a heuristic.
  * - Text: character-based heuristic (ASCII vs CJK) for small strings, length/4 for massive ones.
  * - Non-text (Tools, etc): JSON string length / 4.
  */
-export function estimateTokenCountSync(parts: Part[]): number {
+export function estimateTokenCountSync(
+  parts: Part[],
+  depth: number = 0,
+): number {
+  if (depth > MAX_RECURSION_DEPTH) {
+    return 0;
+  }
+
   let totalTokens = 0;
   for (const part of parts) {
     if (typeof part.text === 'string') {
-      if (part.text.length > MAX_CHARS_FOR_FULL_HEURISTIC) {
-        totalTokens += part.text.length / 4;
-      } else {
-        for (const char of part.text) {
-          if (char.codePointAt(0)! <= 127) {
-            totalTokens += ASCII_TOKENS_PER_CHAR;
-          } else {
-            totalTokens += NON_ASCII_TOKENS_PER_CHAR;
-          }
-        }
-      }
+      totalTokens += estimateTextTokens(part.text);
+    } else if (part.functionResponse) {
+      totalTokens += estimateFunctionResponseTokens(part, depth);
     } else {
-      // For images and PDFs, we use fixed safe estimates:
-      // - Images: 3,000 tokens (covers up to 4K resolution on Gemini 3)
-      // - PDFs: 25,800 tokens (~100 pages at 258 tokens/page)
-      // See: https://ai.google.dev/gemini-api/docs/vision#token_counting
-      // See: https://ai.google.dev/gemini-api/docs/document-processing
-      const inlineData = 'inlineData' in part ? part.inlineData : undefined;
-      const fileData = 'fileData' in part ? part.fileData : undefined;
-      const mimeType = inlineData?.mimeType || fileData?.mimeType;
-
-      if (mimeType?.startsWith('image/')) {
-        totalTokens += IMAGE_TOKEN_ESTIMATE;
-      } else if (mimeType?.startsWith('application/pdf')) {
-        totalTokens += PDF_TOKEN_ESTIMATE;
+      const mediaEstimate = estimateMediaTokens(part);
+      if (mediaEstimate !== undefined) {
+        totalTokens += mediaEstimate;
       } else {
-        // For other non-text parts (functionCall, functionResponse, etc.),
-        // we fallback to the JSON string length heuristic.
-        // Note: This is an approximation.
+        // Fallback for other non-text parts (e.g., functionCall).
+        // Note: JSON.stringify(part) here is safe as these parts are typically small.
         totalTokens += JSON.stringify(part).length / 4;
       }
     }
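
A minimal usage sketch of the new estimateTokenCountSync path (not part of the patch): it mirrors the nested-parts shape exercised by the Gemini 3 test above; the tool name 'capture_screen' and all literal values are hypothetical, assuming only the Part type from @google/genai and the export introduced in this diff.

    import type { Part } from '@google/genai';
    import { estimateTokenCountSync } from './tokenCalculation.js';

    // Hypothetical multimodal tool response: a small JSON payload plus a
    // nested image part and a short caption, same shape as the test case.
    const parts: Part[] = [
      {
        functionResponse: {
          name: 'capture_screen', // hypothetical tool name: 14 chars / 4 = 3.5 tokens
          id: 'call-1',
          response: { status: 'ok' }, // {"status":"ok"} -> 15 chars / 4 = 3.75 tokens
          parts: [
            { inlineData: { mimeType: 'image/png', data: '...' } }, // fixed 3,000-token estimate
            { text: 'Full-page screenshot' }, // 20 ASCII chars * 0.25 = 5 tokens
          ] as Part[],
        },
      },
    ];

    // Roughly 3000 + 3.75 + 5 + 3.5 = ~3012 tokens, without ever
    // JSON.stringify-ing the image data.
    console.log(estimateTokenCountSync(parts));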