mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-06-01 19:03:42 +00:00
Fix bulk of remaining issues with generalist profile (#26073)
This commit is contained in:
@@ -47,6 +47,7 @@ function resolveProcessorOptions<T>(
|
||||
}
|
||||
|
||||
export interface ContextProfile {
|
||||
name: string;
|
||||
config: ContextManagementConfig;
|
||||
buildPipelines: (
|
||||
env: ContextEnvironment,
|
||||
@@ -56,6 +57,10 @@ export interface ContextProfile {
|
||||
env: ContextEnvironment,
|
||||
config?: ContextManagementConfig,
|
||||
) => AsyncPipelineDef[];
|
||||
sentinels?: {
|
||||
continuation?: string;
|
||||
lostToolResponse?: string;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -63,6 +68,12 @@ export interface ContextProfile {
|
||||
* Optimized for safety, precision, and reliable summarization.
|
||||
*/
|
||||
export const generalistProfile: ContextProfile = {
|
||||
name: 'Generalist (Default)',
|
||||
sentinels: {
|
||||
continuation: '[Continuing from previous AI thoughts...]',
|
||||
lostToolResponse:
|
||||
'The tool execution result was lost due to context management truncation.',
|
||||
},
|
||||
config: {
|
||||
budget: {
|
||||
retainedTokens: 65000,
|
||||
@@ -106,14 +117,14 @@ export const generalistProfile: ContextProfile = {
|
||||
'NodeDistillation',
|
||||
env,
|
||||
resolveProcessorOptions(config, 'NodeDistillation', {
|
||||
nodeThresholdTokens: 3000,
|
||||
nodeThresholdTokens: 1000,
|
||||
}),
|
||||
),
|
||||
createNodeTruncationProcessor(
|
||||
'NodeTruncation',
|
||||
env,
|
||||
resolveProcessorOptions(config, 'NodeTruncation', {
|
||||
maxTokensPerNode: 2000,
|
||||
maxTokensPerNode: 1200,
|
||||
}),
|
||||
),
|
||||
],
|
||||
@@ -158,6 +169,7 @@ export const generalistProfile: ContextProfile = {
|
||||
* within a few conversational turns.
|
||||
*/
|
||||
export const stressTestProfile: ContextProfile = {
|
||||
name: 'Stress Test',
|
||||
config: {
|
||||
budget: {
|
||||
retainedTokens: 4000,
|
||||
|
||||
@@ -14,9 +14,13 @@ vi.mock('node:fs/promises', () => ({
|
||||
writeFile: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('node:fs', () => ({
|
||||
existsSync: vi.fn(),
|
||||
}));
|
||||
vi.mock('node:fs', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('node:fs')>();
|
||||
return {
|
||||
...actual,
|
||||
existsSync: vi.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
describe('ContextCompressionService', () => {
|
||||
let mockConfig: Partial<Config>;
|
||||
|
||||
@@ -51,17 +51,18 @@ describe('ContextManager Sync Pressure Barrier Tests', () => {
|
||||
const rawHistoryLength = chatHistory.get().length;
|
||||
|
||||
// 5. Project History (Triggers Sync Barrier)
|
||||
const projection = await contextManager.renderHistory();
|
||||
const { history: projection } = await contextManager.renderHistory();
|
||||
|
||||
// 6. Assertions
|
||||
// The barrier should have dropped several older episodes to get under 150k.
|
||||
|
||||
expect(projection.length).toBeLessThan(rawHistoryLength);
|
||||
|
||||
// Verify Episode 0 (System) is perfectly preserved at the front
|
||||
|
||||
// Verify Episode 0 (System) was pruned, so we now start with a sentinel due to role alternation
|
||||
expect(projection[0].role).toBe('user');
|
||||
expect(projection[0].parts![0].text).toBe('System prompt');
|
||||
expect(projection[0].parts![0].text).toBe(
|
||||
'[Continuing from previous AI thoughts...]',
|
||||
);
|
||||
|
||||
// Filter out synthetic Yield nodes (they are model responses without actual tool/text bodies)
|
||||
const contentNodes = projection.filter(
|
||||
@@ -70,8 +71,14 @@ describe('ContextManager Sync Pressure Barrier Tests', () => {
|
||||
);
|
||||
|
||||
// Verify the latest turn is perfectly preserved at the back
|
||||
const lastUser = contentNodes[contentNodes.length - 2];
|
||||
const lastModel = contentNodes[contentNodes.length - 1];
|
||||
// Note: The HistoryHardener appends a "Please continue." user turn if we end on model,
|
||||
// so we look at the turns before the sentinel.
|
||||
const lastSentinel = contentNodes[contentNodes.length - 1];
|
||||
const lastModel = contentNodes[contentNodes.length - 2];
|
||||
const lastUser = contentNodes[contentNodes.length - 3];
|
||||
|
||||
expect(lastSentinel.role).toBe('user');
|
||||
expect(lastSentinel.parts![0].text).toBe('Please continue.');
|
||||
|
||||
expect(lastUser.role).toBe('user');
|
||||
expect(lastUser.parts![0].text).toBe('Final question.');
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
import type { Content } from '@google/genai';
|
||||
import type { AgentChatHistory } from '../core/agentChatHistory.js';
|
||||
import type { ConcreteNode } from './graph/types.js';
|
||||
import { isToolExecution, type ConcreteNode } from './graph/types.js';
|
||||
import type { ContextEventBus } from './eventBus.js';
|
||||
import type { ContextTracer } from './tracer.js';
|
||||
import type { ContextEnvironment } from './pipeline/environment.js';
|
||||
@@ -15,6 +15,9 @@ import type { PipelineOrchestrator } from './pipeline/orchestrator.js';
|
||||
import { HistoryObserver } from './historyObserver.js';
|
||||
import { render } from './graph/render.js';
|
||||
import { ContextWorkingBufferImpl } from './pipeline/contextWorkingBuffer.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { hardenHistory } from '../utils/historyHardening.js';
|
||||
import { checkContextInvariants } from './utils/invariantChecker.js';
|
||||
|
||||
export class ContextManager {
|
||||
// The master state containing the pristine graph and current active graph.
|
||||
@@ -27,21 +30,30 @@ export class ContextManager {
|
||||
private readonly orchestrator: PipelineOrchestrator;
|
||||
private readonly historyObserver: HistoryObserver;
|
||||
|
||||
// Cache for Anomaly 3 (Redundant Renders)
|
||||
private lastRenderCache?: {
|
||||
nodesHash: string;
|
||||
result: { history: Content[]; didApplyManagement: boolean };
|
||||
};
|
||||
|
||||
constructor(
|
||||
private readonly sidecar: ContextProfile,
|
||||
private readonly env: ContextEnvironment,
|
||||
private readonly tracer: ContextTracer,
|
||||
orchestrator: PipelineOrchestrator,
|
||||
chatHistory: AgentChatHistory,
|
||||
private readonly headerProvider?: () => Promise<Content | undefined>,
|
||||
) {
|
||||
this.eventBus = env.eventBus;
|
||||
this.orchestrator = orchestrator;
|
||||
|
||||
// Provide the orchestrator with a way to fetch the latest nodes from the live buffer
|
||||
this.orchestrator.setNodeProvider(() => this.buffer.nodes);
|
||||
|
||||
this.historyObserver = new HistoryObserver(
|
||||
chatHistory,
|
||||
this.env.eventBus,
|
||||
this.tracer,
|
||||
this.env.tokenCalculator,
|
||||
this.env.graphMapper,
|
||||
);
|
||||
|
||||
@@ -69,6 +81,13 @@ export class ContextManager {
|
||||
this.historyObserver.start();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a promise that resolves when all currently executing async pipelines have finished.
|
||||
*/
|
||||
async waitForPipelines(): Promise<void> {
|
||||
return this.orchestrator.waitForPipelines();
|
||||
}
|
||||
|
||||
/**
|
||||
* Safely stops background async pipelines and clears event listeners.
|
||||
*/
|
||||
@@ -98,6 +117,15 @@ export class ContextManager {
|
||||
if (currentTokens > this.sidecar.config.budget.retainedTokens) {
|
||||
const agedOutNodes = new Set<string>();
|
||||
let rollingTokens = 0;
|
||||
|
||||
// Identify active tool calls that must NEVER be truncated
|
||||
const protectedIds = this.getProtectedNodeIds(this.buffer.nodes);
|
||||
if (protectedIds.size > 0) {
|
||||
debugLogger.log(
|
||||
`[ContextManager] Pinning ${protectedIds.size} active tool call nodes to prevent truncation.`,
|
||||
);
|
||||
}
|
||||
|
||||
// Walk backwards finding nodes that fall out of the retained budget
|
||||
for (let i = this.buffer.nodes.length - 1; i >= 0; i--) {
|
||||
const node = this.buffer.nodes[i];
|
||||
@@ -105,7 +133,10 @@ export class ContextManager {
|
||||
node,
|
||||
]);
|
||||
if (rollingTokens > this.sidecar.config.budget.retainedTokens) {
|
||||
agedOutNodes.add(node.id);
|
||||
// Only age out if not protected
|
||||
if (!protectedIds.has(node.id)) {
|
||||
agedOutNodes.add(node.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -123,6 +154,54 @@ export class ContextManager {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Identifies 'pinned' nodes that should not be truncated.
|
||||
* This includes:
|
||||
* 1. The entire last turn (Recent context).
|
||||
* 2. Active tool calls (calls without responses in the graph).
|
||||
*/
|
||||
private getProtectedNodeIds(
|
||||
nodes: readonly ConcreteNode[],
|
||||
extraProtectedIds: Set<string> = new Set(),
|
||||
): Map<string, string> {
|
||||
const protectionMap = new Map<string, string>();
|
||||
if (nodes.length === 0) return protectionMap;
|
||||
|
||||
// 1. Identify all nodes belonging to the last turn (Recent context)
|
||||
const lastNode = nodes[nodes.length - 1];
|
||||
const lastTurnId = lastNode.turnId;
|
||||
|
||||
for (const node of nodes) {
|
||||
if (node.turnId === lastTurnId) {
|
||||
protectionMap.set(node.id, 'recent_turn');
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Identify active tool calls that must NEVER be truncated
|
||||
const calls = nodes.filter((n) => isToolExecution(n) && n.role === 'model');
|
||||
const responses = new Set(
|
||||
nodes
|
||||
.filter((n) => isToolExecution(n) && n.role === 'user')
|
||||
.map((n) => n.payload.functionResponse?.id)
|
||||
.filter((id): id is string => !!id),
|
||||
);
|
||||
|
||||
for (const call of calls) {
|
||||
const id = call.payload.functionCall?.id;
|
||||
// If we have a call but no response in the current graph, it's 'in flight'
|
||||
if (id && !responses.has(id)) {
|
||||
protectionMap.set(call.id, 'in_flight_tool_call');
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Any externally requested protections
|
||||
for (const id of extraProtectedIds) {
|
||||
protectionMap.set(id, 'external_active_task');
|
||||
}
|
||||
|
||||
return protectionMap;
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the raw, uncompressed Episodic Context Graph graph.
|
||||
* Useful for internal tool rendering (like the trace viewer).
|
||||
@@ -157,22 +236,78 @@ export class ContextManager {
|
||||
* This is the primary method called by the agent framework before sending a request.
|
||||
*/
|
||||
async renderHistory(
|
||||
pendingRequest?: Content,
|
||||
activeTaskIds: Set<string> = new Set(),
|
||||
): Promise<Content[]> {
|
||||
): Promise<{ history: Content[]; didApplyManagement: boolean }> {
|
||||
this.tracer.logEvent('ContextManager', 'Starting rendering of LLM context');
|
||||
|
||||
// 1. Synchronous Pressure Barrier: Wait for background management pipelines to finish.
|
||||
// This ensures that the render sees the results of recent pushes (Anomaly 2).
|
||||
await this.orchestrator.waitForPipelines();
|
||||
|
||||
let nodes = this.buffer.nodes;
|
||||
|
||||
// If we have a pending request, we need to build a 'preview' graph for this render.
|
||||
if (pendingRequest) {
|
||||
const previewNodes = this.env.graphMapper.applyEvent({
|
||||
type: 'PUSH',
|
||||
payload: [pendingRequest],
|
||||
});
|
||||
nodes = [...nodes, ...previewNodes];
|
||||
}
|
||||
|
||||
// 2. Fetch Header and calculate tokens
|
||||
const header = this.headerProvider
|
||||
? await this.headerProvider()
|
||||
: undefined;
|
||||
const headerTokens = header
|
||||
? this.env.tokenCalculator.calculateContentTokens(header)
|
||||
: 0;
|
||||
|
||||
// 3. Cache Check (Anomaly 3): If nodes haven't changed, return previous result.
|
||||
// We combine the graph hash with a hash of the header to ensure total freshness.
|
||||
const graphHash = nodes.map((n) => n.id).join('|');
|
||||
const headerHash = header ? JSON.stringify(header.parts) : 'no-header';
|
||||
const totalHash = `${graphHash}::${headerHash}`;
|
||||
|
||||
if (this.lastRenderCache?.nodesHash === totalHash) {
|
||||
debugLogger.log(
|
||||
'[ContextManager] Render cache hit. Skipping redundant render.',
|
||||
);
|
||||
return this.lastRenderCache.result;
|
||||
}
|
||||
|
||||
const protectionReasons = this.getProtectedNodeIds(nodes, activeTaskIds);
|
||||
|
||||
// Apply final GC Backstop pressure barrier synchronously before mapping
|
||||
const finalHistory = await render(
|
||||
this.buffer.nodes,
|
||||
const { history: renderedHistory, didApplyManagement } = await render(
|
||||
nodes,
|
||||
this.orchestrator,
|
||||
this.sidecar,
|
||||
this.tracer,
|
||||
this.env,
|
||||
activeTaskIds,
|
||||
protectionReasons,
|
||||
headerTokens,
|
||||
);
|
||||
|
||||
// Structural validation in debug mode
|
||||
checkContextInvariants(this.buffer.nodes, 'RenderHistory');
|
||||
|
||||
this.tracer.logEvent('ContextManager', 'Finished rendering');
|
||||
|
||||
return finalHistory;
|
||||
const combinedHistory = header
|
||||
? [header, ...renderedHistory]
|
||||
: renderedHistory;
|
||||
|
||||
const result = {
|
||||
history: hardenHistory(combinedHistory, {
|
||||
sentinels: this.sidecar.sentinels,
|
||||
}),
|
||||
didApplyManagement,
|
||||
};
|
||||
|
||||
// Update cache
|
||||
this.lastRenderCache = { nodesHash: totalHash, result };
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,21 +3,11 @@
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
import type { Content, Part } from '@google/genai';
|
||||
import type { ConcreteNode } from './types.js';
|
||||
|
||||
export interface NodeSerializationWriter {
|
||||
appendContent(content: Content): void;
|
||||
appendModelPart(part: Part): void;
|
||||
appendUserPart(part: Part): void;
|
||||
flushModelParts(): void;
|
||||
}
|
||||
import type { Part } from '@google/genai';
|
||||
import type { ConcreteNode, NodeType } from './types.js';
|
||||
|
||||
export interface NodeBehavior<T extends ConcreteNode = ConcreteNode> {
|
||||
readonly type: T['type'];
|
||||
|
||||
/** Serializes the node into the Gemini Content structure. */
|
||||
serialize(node: T, writer: NodeSerializationWriter): void;
|
||||
readonly type: NodeType;
|
||||
|
||||
/**
|
||||
* Generates a structural representation of the node for the purpose
|
||||
@@ -27,13 +17,13 @@ export interface NodeBehavior<T extends ConcreteNode = ConcreteNode> {
|
||||
}
|
||||
|
||||
export class NodeBehaviorRegistry {
|
||||
private readonly behaviors = new Map<string, NodeBehavior<ConcreteNode>>();
|
||||
private readonly behaviors = new Map<NodeType, NodeBehavior<ConcreteNode>>();
|
||||
|
||||
register<T extends ConcreteNode>(behavior: NodeBehavior<T>) {
|
||||
this.behaviors.set(behavior.type, behavior);
|
||||
}
|
||||
|
||||
get(type: string): NodeBehavior<ConcreteNode> {
|
||||
get(type: NodeType): NodeBehavior<ConcreteNode> {
|
||||
const behavior = this.behaviors.get(type);
|
||||
if (!behavior) {
|
||||
throw new Error(`Unregistered Node type: ${type}`);
|
||||
|
||||
@@ -3,160 +3,72 @@
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
import type { Part } from '@google/genai';
|
||||
import type { NodeBehavior, NodeBehaviorRegistry } from './behaviorRegistry.js';
|
||||
import type {
|
||||
UserPrompt,
|
||||
AgentThought,
|
||||
ToolExecution,
|
||||
MaskedTool,
|
||||
AgentYield,
|
||||
Snapshot,
|
||||
RollingSummary,
|
||||
SystemEvent,
|
||||
import {
|
||||
type UserPrompt,
|
||||
type AgentThought,
|
||||
type ToolExecution,
|
||||
type MaskedTool,
|
||||
type AgentYield,
|
||||
type Snapshot,
|
||||
type RollingSummary,
|
||||
type SystemEvent,
|
||||
NodeType,
|
||||
} from './types.js';
|
||||
|
||||
export const UserPromptBehavior: NodeBehavior<UserPrompt> = {
|
||||
type: 'USER_PROMPT',
|
||||
getEstimatableParts(prompt) {
|
||||
const parts: Part[] = [];
|
||||
for (const sp of prompt.semanticParts) {
|
||||
switch (sp.type) {
|
||||
case 'text':
|
||||
parts.push({ text: sp.text });
|
||||
break;
|
||||
case 'inline_data':
|
||||
parts.push({ inlineData: { mimeType: sp.mimeType, data: sp.data } });
|
||||
break;
|
||||
case 'file_data':
|
||||
parts.push({
|
||||
fileData: { mimeType: sp.mimeType, fileUri: sp.fileUri },
|
||||
});
|
||||
break;
|
||||
case 'raw_part':
|
||||
parts.push(sp.part);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
return parts;
|
||||
},
|
||||
serialize(prompt, writer) {
|
||||
const parts = this.getEstimatableParts(prompt);
|
||||
if (parts.length > 0) {
|
||||
writer.flushModelParts();
|
||||
writer.appendContent({ role: 'user', parts });
|
||||
}
|
||||
type: NodeType.USER_PROMPT,
|
||||
getEstimatableParts(node) {
|
||||
return [node.payload];
|
||||
},
|
||||
};
|
||||
|
||||
export const AgentThoughtBehavior: NodeBehavior<AgentThought> = {
|
||||
type: 'AGENT_THOUGHT',
|
||||
getEstimatableParts(thought) {
|
||||
return [{ text: thought.text }];
|
||||
},
|
||||
serialize(thought, writer) {
|
||||
writer.appendModelPart({ text: thought.text });
|
||||
type: NodeType.AGENT_THOUGHT,
|
||||
getEstimatableParts(node) {
|
||||
return [node.payload];
|
||||
},
|
||||
};
|
||||
|
||||
export const ToolExecutionBehavior: NodeBehavior<ToolExecution> = {
|
||||
type: 'TOOL_EXECUTION',
|
||||
getEstimatableParts(tool) {
|
||||
return [
|
||||
{ functionCall: { id: tool.id, name: tool.toolName, args: tool.intent } },
|
||||
{
|
||||
functionResponse: {
|
||||
id: tool.id,
|
||||
name: tool.toolName,
|
||||
response:
|
||||
typeof tool.observation === 'string'
|
||||
? { message: tool.observation }
|
||||
: tool.observation,
|
||||
},
|
||||
},
|
||||
];
|
||||
},
|
||||
serialize(tool, writer) {
|
||||
const parts = this.getEstimatableParts(tool);
|
||||
writer.appendModelPart(parts[0]);
|
||||
writer.flushModelParts();
|
||||
writer.appendUserPart(parts[1]);
|
||||
type: NodeType.TOOL_EXECUTION,
|
||||
getEstimatableParts(node) {
|
||||
return [node.payload];
|
||||
},
|
||||
};
|
||||
|
||||
export const MaskedToolBehavior: NodeBehavior<MaskedTool> = {
|
||||
type: 'MASKED_TOOL',
|
||||
getEstimatableParts(tool) {
|
||||
return [
|
||||
{
|
||||
functionCall: {
|
||||
id: tool.id,
|
||||
name: tool.toolName,
|
||||
args: tool.intent ?? {},
|
||||
},
|
||||
},
|
||||
{
|
||||
functionResponse: {
|
||||
id: tool.id,
|
||||
name: tool.toolName,
|
||||
response:
|
||||
typeof tool.observation === 'string'
|
||||
? { message: tool.observation }
|
||||
: (tool.observation ?? {}),
|
||||
},
|
||||
},
|
||||
];
|
||||
},
|
||||
serialize(tool, writer) {
|
||||
const parts = this.getEstimatableParts(tool);
|
||||
writer.appendModelPart(parts[0]);
|
||||
writer.flushModelParts();
|
||||
writer.appendUserPart(parts[1]);
|
||||
type: NodeType.MASKED_TOOL,
|
||||
getEstimatableParts(node) {
|
||||
return [node.payload];
|
||||
},
|
||||
};
|
||||
|
||||
export const AgentYieldBehavior: NodeBehavior<AgentYield> = {
|
||||
type: 'AGENT_YIELD',
|
||||
getEstimatableParts(yieldNode) {
|
||||
return [{ text: yieldNode.text }];
|
||||
},
|
||||
serialize() {
|
||||
// AGENT_YIELD is a synthetic marker node used for internal graph tracking.
|
||||
// We intentionally do NOT serialize it to the LLM to prevent prompt corruption.
|
||||
type: NodeType.AGENT_YIELD,
|
||||
getEstimatableParts() {
|
||||
return [];
|
||||
},
|
||||
};
|
||||
|
||||
export const SystemEventBehavior: NodeBehavior<SystemEvent> = {
|
||||
type: 'SYSTEM_EVENT',
|
||||
getEstimatableParts() {
|
||||
return [];
|
||||
},
|
||||
serialize(node, writer) {
|
||||
writer.flushModelParts();
|
||||
type: NodeType.SYSTEM_EVENT,
|
||||
getEstimatableParts(node) {
|
||||
return [node.payload];
|
||||
},
|
||||
};
|
||||
|
||||
export const SnapshotBehavior: NodeBehavior<Snapshot> = {
|
||||
type: 'SNAPSHOT',
|
||||
type: NodeType.SNAPSHOT,
|
||||
getEstimatableParts(node) {
|
||||
return [{ text: node.text }];
|
||||
},
|
||||
serialize(node, writer) {
|
||||
writer.flushModelParts();
|
||||
writer.appendUserPart({ text: node.text });
|
||||
return [node.payload];
|
||||
},
|
||||
};
|
||||
|
||||
export const RollingSummaryBehavior: NodeBehavior<RollingSummary> = {
|
||||
type: 'ROLLING_SUMMARY',
|
||||
type: NodeType.ROLLING_SUMMARY,
|
||||
getEstimatableParts(node) {
|
||||
return [{ text: node.text }];
|
||||
},
|
||||
serialize(node, writer) {
|
||||
writer.flushModelParts();
|
||||
writer.appendUserPart({ text: node.text });
|
||||
return [node.payload];
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -3,52 +3,53 @@
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
import type { Content, Part } from '@google/genai';
|
||||
|
||||
import type { Content } from '@google/genai';
|
||||
import type { ConcreteNode } from './types.js';
|
||||
import type {
|
||||
NodeSerializationWriter,
|
||||
NodeBehaviorRegistry,
|
||||
} from './behaviorRegistry.js';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
|
||||
class NodeSerializer implements NodeSerializationWriter {
|
||||
private history: Content[] = [];
|
||||
private currentModelParts: Part[] = [];
|
||||
/**
|
||||
* Reconstructs a valid Gemini Chat History from a list of Concrete Nodes.
|
||||
* This process is "role-alternation-aware" and uses turnId to
|
||||
* preserve original turn boundaries even if multiple turns have the same role.
|
||||
*/
|
||||
export function fromGraph(nodes: readonly ConcreteNode[]): Content[] {
|
||||
debugLogger.log(
|
||||
`[fromGraph] Reconstructing history from ${nodes.length} nodes`,
|
||||
);
|
||||
|
||||
appendContent(content: Content) {
|
||||
this.flushModelParts();
|
||||
this.history.push(content);
|
||||
}
|
||||
const history: Content[] = [];
|
||||
let currentTurn: (Content & { _turnId?: string }) | null = null;
|
||||
|
||||
appendModelPart(part: Part) {
|
||||
this.currentModelParts.push(part);
|
||||
}
|
||||
for (const node of nodes) {
|
||||
const turnId = node.turnId;
|
||||
|
||||
appendUserPart(part: Part) {
|
||||
this.flushModelParts();
|
||||
this.history.push({ role: 'user', parts: [part] });
|
||||
}
|
||||
|
||||
flushModelParts() {
|
||||
if (this.currentModelParts.length > 0) {
|
||||
this.history.push({ role: 'model', parts: [...this.currentModelParts] });
|
||||
this.currentModelParts = [];
|
||||
// We start a new turn if:
|
||||
// 1. We don't have a current turn.
|
||||
// 2. The role changes (Standard alternation).
|
||||
// 3. The turnId changes (Preserving distinct turns of the same role).
|
||||
if (
|
||||
!currentTurn ||
|
||||
currentTurn.role !== node.role ||
|
||||
currentTurn._turnId !== turnId
|
||||
) {
|
||||
currentTurn = {
|
||||
role: node.role,
|
||||
parts: [node.payload],
|
||||
_turnId: turnId,
|
||||
};
|
||||
history.push(currentTurn);
|
||||
} else {
|
||||
currentTurn.parts = [...(currentTurn.parts || []), node.payload];
|
||||
}
|
||||
}
|
||||
|
||||
getContents(): Content[] {
|
||||
this.flushModelParts();
|
||||
return this.history;
|
||||
// Final cleanup: remove our internal tracking field
|
||||
for (const turn of history) {
|
||||
const t = turn as Content & { _turnId?: string };
|
||||
delete t._turnId;
|
||||
}
|
||||
}
|
||||
|
||||
export function fromGraph(
|
||||
nodes: readonly ConcreteNode[],
|
||||
registry: NodeBehaviorRegistry,
|
||||
): Content[] {
|
||||
const writer = new NodeSerializer();
|
||||
for (const node of nodes) {
|
||||
const behavior = registry.get(node.type);
|
||||
behavior.serialize(node, writer);
|
||||
}
|
||||
return writer.getContents();
|
||||
debugLogger.log(`[fromGraph] Reconstructed ${history.length} turns`);
|
||||
return history;
|
||||
}
|
||||
|
||||
@@ -8,41 +8,20 @@ import { ContextGraphBuilder } from './toGraph.js';
|
||||
import type { Content } from '@google/genai';
|
||||
import type { HistoryEvent } from '../../core/agentChatHistory.js';
|
||||
import { fromGraph } from './fromGraph.js';
|
||||
import type { ContextTokenCalculator } from '../utils/contextTokenCalculator.js';
|
||||
import type { NodeBehaviorRegistry } from './behaviorRegistry.js';
|
||||
|
||||
export class ContextGraphMapper {
|
||||
private readonly nodeIdentityMap = new WeakMap<object, string>();
|
||||
private readonly builder: ContextGraphBuilder;
|
||||
|
||||
constructor(private readonly registry: NodeBehaviorRegistry) {}
|
||||
constructor() {
|
||||
this.builder = new ContextGraphBuilder(this.nodeIdentityMap);
|
||||
}
|
||||
|
||||
private builder?: ContextGraphBuilder;
|
||||
|
||||
applyEvent(
|
||||
event: HistoryEvent,
|
||||
tokenCalculator: ContextTokenCalculator,
|
||||
): ConcreteNode[] {
|
||||
if (!this.builder) {
|
||||
this.builder = new ContextGraphBuilder(
|
||||
tokenCalculator,
|
||||
this.nodeIdentityMap,
|
||||
);
|
||||
}
|
||||
|
||||
if (event.type === 'CLEAR') {
|
||||
this.builder.clear();
|
||||
return [];
|
||||
}
|
||||
|
||||
if (event.type === 'SYNC_FULL') {
|
||||
this.builder.clear();
|
||||
}
|
||||
|
||||
this.builder.processHistory(event.payload);
|
||||
return this.builder.getNodes();
|
||||
applyEvent(event: HistoryEvent): ConcreteNode[] {
|
||||
return this.builder.processHistory(event.payload);
|
||||
}
|
||||
|
||||
fromGraph(nodes: readonly ConcreteNode[]): Content[] {
|
||||
return fromGraph(nodes, this.registry);
|
||||
return fromGraph(nodes);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,17 +6,14 @@
|
||||
|
||||
import type { Content } from '@google/genai';
|
||||
import type { ConcreteNode } from './types.js';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
import type {
|
||||
ContextEnvironment,
|
||||
ContextTracer,
|
||||
} from '../pipeline/environment.js';
|
||||
import type { PipelineOrchestrator } from '../pipeline/orchestrator.js';
|
||||
import type { ContextTracer } from '../tracer.js';
|
||||
import type { ContextProfile } from '../config/profiles.js';
|
||||
import type { PipelineOrchestrator } from '../pipeline/orchestrator.js';
|
||||
import type { ContextEnvironment } from '../pipeline/environment.js';
|
||||
|
||||
/**
|
||||
* Orchestrates the final render: takes a working buffer view (The Nodes),
|
||||
* applies the Immediate Sanitization pipeline, and enforces token boundaries.
|
||||
* Maps the Episodic Context Graph back into a raw Gemini Content[] array for transmission.
|
||||
* It applies synchronous context management (GC backstop) if the budget is exceeded.
|
||||
*/
|
||||
export async function render(
|
||||
nodes: readonly ConcreteNode[],
|
||||
@@ -24,28 +21,40 @@ export async function render(
|
||||
sidecar: ContextProfile,
|
||||
tracer: ContextTracer,
|
||||
env: ContextEnvironment,
|
||||
protectedIds: Set<string>,
|
||||
): Promise<Content[]> {
|
||||
protectionReasons: Map<string, string> = new Map(),
|
||||
headerTokens: number = 0,
|
||||
): Promise<{ history: Content[]; didApplyManagement: boolean }> {
|
||||
if (!sidecar.config.budget) {
|
||||
const contents = env.graphMapper.fromGraph(nodes);
|
||||
tracer.logEvent('Render', 'Render Context to LLM (No Budget)', {
|
||||
renderedContext: contents,
|
||||
});
|
||||
return contents;
|
||||
return { history: contents, didApplyManagement: false };
|
||||
}
|
||||
|
||||
const maxTokens = sidecar.config.budget.maxTokens;
|
||||
const currentTokens = env.tokenCalculator.calculateConcreteListTokens(nodes);
|
||||
const graphTokens = env.tokenCalculator.calculateConcreteListTokens(nodes);
|
||||
const currentTokens = graphTokens + headerTokens;
|
||||
|
||||
// V0: Always protect the first node (System Prompt) and the last turn
|
||||
if (nodes.length > 0) {
|
||||
protectedIds.add(nodes[0].id);
|
||||
if (nodes[0].logicalParentId) protectedIds.add(nodes[0].logicalParentId);
|
||||
const protectedIds = new Set(protectionReasons.keys());
|
||||
|
||||
const lastNode = nodes[nodes.length - 1];
|
||||
protectedIds.add(lastNode.id);
|
||||
if (lastNode.logicalParentId) protectedIds.add(lastNode.logicalParentId);
|
||||
}
|
||||
tracer.logEvent('Render', 'Budget Audit', {
|
||||
maxTokens,
|
||||
retainedTokens: sidecar.config.budget.retainedTokens,
|
||||
graphTokens,
|
||||
headerTokens,
|
||||
currentTokens,
|
||||
pressure: (currentTokens / maxTokens).toFixed(2),
|
||||
isOverBudget: currentTokens > maxTokens,
|
||||
});
|
||||
|
||||
tracer.logEvent('Render', 'Estimation Calibration', {
|
||||
breakdown: env.tokenCalculator.calculateTokenBreakdown(nodes),
|
||||
});
|
||||
|
||||
tracer.logEvent('Render', 'Protection Audit', {
|
||||
reasons: Object.fromEntries(protectionReasons),
|
||||
});
|
||||
|
||||
if (currentTokens <= maxTokens) {
|
||||
tracer.logEvent(
|
||||
@@ -56,15 +65,14 @@ export async function render(
|
||||
tracer.logEvent('Render', 'Render Context for LLM', {
|
||||
renderedContext: contents,
|
||||
});
|
||||
return contents;
|
||||
return { history: contents, didApplyManagement: false };
|
||||
}
|
||||
|
||||
const targetDelta = currentTokens - sidecar.config.budget.retainedTokens;
|
||||
tracer.logEvent(
|
||||
'Render',
|
||||
`View exceeds maxTokens (${currentTokens} > ${maxTokens}). Hitting Synchronous Pressure Barrier.`,
|
||||
);
|
||||
debugLogger.log(
|
||||
`Context Manager Synchronous Barrier triggered: View at ${currentTokens} tokens (limit: ${maxTokens}).`,
|
||||
{ targetDelta },
|
||||
);
|
||||
|
||||
// Calculate exactly which nodes aged out of the retainedTokens budget to form our target delta
|
||||
@@ -87,16 +95,6 @@ export async function render(
|
||||
protectedIds,
|
||||
);
|
||||
|
||||
const finalTokens =
|
||||
env.tokenCalculator.calculateConcreteListTokens(processedNodes);
|
||||
tracer.logEvent(
|
||||
'Render',
|
||||
`Finished rendering. Final token count: ${finalTokens}.`,
|
||||
);
|
||||
debugLogger.log(
|
||||
`Context Manager finished. Final actual token count: ${finalTokens}.`,
|
||||
);
|
||||
|
||||
// Apply skipList logic to abstract over summarized nodes
|
||||
const skipList = new Set<string>();
|
||||
for (const node of processedNodes) {
|
||||
@@ -111,5 +109,5 @@ export async function render(
|
||||
tracer.logEvent('Render', 'Render Sanitized Context for LLM', {
|
||||
renderedContextSanitized: contents,
|
||||
});
|
||||
return contents;
|
||||
return { history: contents, didApplyManagement: true };
|
||||
}
|
||||
|
||||
@@ -5,294 +5,227 @@
|
||||
*/
|
||||
|
||||
import type { Content, Part } from '@google/genai';
|
||||
import type {
|
||||
ConcreteNode,
|
||||
Episode,
|
||||
SemanticPart,
|
||||
ToolExecution,
|
||||
AgentThought,
|
||||
AgentYield,
|
||||
UserPrompt,
|
||||
} from './types.js';
|
||||
import type { ContextTokenCalculator } from '../utils/contextTokenCalculator.js';
|
||||
import { randomUUID } from 'node:crypto';
|
||||
import { isRecord } from '../../utils/markdownUtils.js';
|
||||
import { type ConcreteNode, NodeType } from './types.js';
|
||||
import { randomUUID, createHash } from 'node:crypto';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
|
||||
// We remove the global nodeIdentityMap and instead rely on one passed from ContextGraphMapper
|
||||
export function getStableId(
|
||||
obj: object,
|
||||
nodeIdentityMap: WeakMap<object, string>,
|
||||
): string {
|
||||
let id = nodeIdentityMap.get(obj);
|
||||
if (!id) {
|
||||
id = randomUUID();
|
||||
nodeIdentityMap.set(obj, id);
|
||||
}
|
||||
return id;
|
||||
interface PartWithSynthId extends Part {
|
||||
_synthId?: string;
|
||||
}
|
||||
|
||||
function isCompleteEpisode(ep: Partial<Episode>): ep is Episode {
|
||||
// Global WeakMap to cache hashes for Part objects.
|
||||
// This optimizes getStableId by avoiding redundant stringify/hash operations
|
||||
// on the same object instances across multiple management passes.
|
||||
const PART_HASH_CACHE = new WeakMap<object, string>();
|
||||
|
||||
function isTextPart(part: Part): part is Part & { text: string } {
|
||||
return typeof part.text === 'string';
|
||||
}
|
||||
|
||||
function isInlineDataPart(
|
||||
part: Part,
|
||||
): part is Part & { inlineData: { data: string } } {
|
||||
return (
|
||||
typeof ep.id === 'string' &&
|
||||
Array.isArray(ep.concreteNodes) &&
|
||||
ep.concreteNodes.length > 0
|
||||
typeof part.inlineData === 'object' &&
|
||||
part.inlineData !== null &&
|
||||
typeof part.inlineData.data === 'string'
|
||||
);
|
||||
}
|
||||
|
||||
export class ContextGraphBuilder {
|
||||
private episodes: Episode[] = [];
|
||||
private currentEpisode: Partial<Episode> | null = null;
|
||||
private pendingCallParts: Map<string, Part> = new Map();
|
||||
private pendingCallPartsWithoutId: Part[] = [];
|
||||
function isFileDataPart(
|
||||
part: Part,
|
||||
): part is Part & { fileData: { fileUri: string } } {
|
||||
return (
|
||||
typeof part.fileData === 'object' &&
|
||||
part.fileData !== null &&
|
||||
typeof part.fileData.fileUri === 'string'
|
||||
);
|
||||
}
|
||||
|
||||
function isFunctionCallPart(
|
||||
part: Part,
|
||||
): part is Part & { functionCall: { id: string; name: string } } {
|
||||
return (
|
||||
typeof part.functionCall === 'object' &&
|
||||
part.functionCall !== null &&
|
||||
typeof part.functionCall.name === 'string'
|
||||
);
|
||||
}
|
||||
|
||||
function isFunctionResponsePart(
|
||||
part: Part,
|
||||
): part is Part & { functionResponse: { id: string; name: string } } {
|
||||
return (
|
||||
typeof part.functionResponse === 'object' &&
|
||||
part.functionResponse !== null &&
|
||||
typeof part.functionResponse.name === 'string'
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a stable ID for an object reference using a WeakMap.
|
||||
* Falls back to content-based hashing for Part-like objects to ensure
|
||||
* stability across object re-creations (e.g. during history mapping).
|
||||
*/
|
||||
export function getStableId(
|
||||
obj: object,
|
||||
nodeIdentityMap: WeakMap<object, string>,
|
||||
turnSalt: string = '',
|
||||
partIdx: number = 0,
|
||||
): string {
|
||||
let id = nodeIdentityMap.get(obj);
|
||||
if (id) return id;
|
||||
|
||||
const cachedHash = PART_HASH_CACHE.get(obj);
|
||||
if (cachedHash) {
|
||||
id = `${cachedHash}_${turnSalt}_${partIdx}`;
|
||||
nodeIdentityMap.set(obj, id);
|
||||
return id;
|
||||
}
|
||||
|
||||
const part = obj as PartWithSynthId;
|
||||
let contentHash: string | undefined;
|
||||
|
||||
// If the object already has a synthetic ID property, use it.
|
||||
if (typeof part._synthId === 'string') {
|
||||
id = part._synthId;
|
||||
} else if (isTextPart(part)) {
|
||||
contentHash = createHash('sha256').update(part.text).digest('hex');
|
||||
id = `text_${contentHash}_${turnSalt}_${partIdx}`;
|
||||
} else if (isInlineDataPart(part)) {
|
||||
contentHash = createHash('sha256')
|
||||
.update(part.inlineData.data)
|
||||
.digest('hex');
|
||||
id = `media_${contentHash}_${turnSalt}_${partIdx}`;
|
||||
} else if (isFileDataPart(part)) {
|
||||
contentHash = createHash('sha256')
|
||||
.update(part.fileData.fileUri)
|
||||
.digest('hex');
|
||||
id = `file_${contentHash}_${turnSalt}_${partIdx}`;
|
||||
} else if (isFunctionCallPart(part)) {
|
||||
contentHash = createHash('sha256')
|
||||
.update(
|
||||
`call:${part.functionCall.name}:${JSON.stringify(part.functionCall.args)}`,
|
||||
)
|
||||
.digest('hex');
|
||||
id = `call_h_${contentHash}_${turnSalt}_${partIdx}`;
|
||||
} else if (isFunctionResponsePart(part)) {
|
||||
contentHash = createHash('sha256')
|
||||
.update(
|
||||
`resp:${part.functionResponse.name}:${JSON.stringify(part.functionResponse.response)}`,
|
||||
)
|
||||
.digest('hex');
|
||||
id = `resp_h_${contentHash}_${turnSalt}_${partIdx}`;
|
||||
}
|
||||
|
||||
if (contentHash) {
|
||||
PART_HASH_CACHE.set(obj, contentHash);
|
||||
}
|
||||
|
||||
if (!id) {
|
||||
id = randomUUID();
|
||||
}
|
||||
|
||||
nodeIdentityMap.set(obj, id);
|
||||
return id;
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a 1:1 Mirror Graph from Chat History.
|
||||
* Every Part in history is mapped to exactly one ConcreteNode.
|
||||
*/
|
||||
export class ContextGraphBuilder {
|
||||
constructor(
|
||||
private readonly tokenCalculator: ContextTokenCalculator,
|
||||
private readonly nodeIdentityMap: WeakMap<object, string> = new WeakMap(),
|
||||
) {}
|
||||
|
||||
clear() {
|
||||
this.episodes = [];
|
||||
this.currentEpisode = null;
|
||||
this.pendingCallParts.clear();
|
||||
this.pendingCallPartsWithoutId = [];
|
||||
}
|
||||
processHistory(history: readonly Content[]): ConcreteNode[] {
|
||||
const nodes: ConcreteNode[] = [];
|
||||
|
||||
processHistory(history: readonly Content[]) {
|
||||
const finalizeEpisode = () => {
|
||||
if (this.currentEpisode && isCompleteEpisode(this.currentEpisode)) {
|
||||
this.episodes.push(this.currentEpisode);
|
||||
}
|
||||
this.currentEpisode = null;
|
||||
};
|
||||
// Tracks occurrences of identical turn content to ensure unique stable IDs
|
||||
const seenHashes = new Map<string, number>();
|
||||
|
||||
for (const msg of history) {
|
||||
for (let turnIdx = 0; turnIdx < history.length; turnIdx++) {
|
||||
const msg = history[turnIdx];
|
||||
if (!msg.parts) continue;
|
||||
|
||||
if (msg.role === 'user') {
|
||||
const hasToolResponses = msg.parts.some((p) => !!p.functionResponse);
|
||||
const hasUserParts = msg.parts.some(
|
||||
(p) => !!p.text || !!p.inlineData || !!p.fileData,
|
||||
);
|
||||
|
||||
if (hasToolResponses) {
|
||||
this.currentEpisode = parseToolResponses(
|
||||
msg,
|
||||
this.currentEpisode,
|
||||
this.pendingCallParts,
|
||||
this.pendingCallPartsWithoutId,
|
||||
this.tokenCalculator,
|
||||
this.nodeIdentityMap,
|
||||
// Defensive: Skip legacy environment header if it's the first turn.
|
||||
// We now manage this as an orthogonal late-addition header.
|
||||
if (turnIdx === 0 && msg.role === 'user' && msg.parts.length === 1) {
|
||||
const text = msg.parts[0].text;
|
||||
if (
|
||||
text?.startsWith('<session_context>') &&
|
||||
text?.includes('This is the Gemini CLI.')
|
||||
) {
|
||||
debugLogger.log(
|
||||
'[ContextGraphBuilder] Skipping legacy environment header turn from graph.',
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (hasUserParts) {
|
||||
finalizeEpisode();
|
||||
this.currentEpisode = parseUserParts(msg, this.nodeIdentityMap);
|
||||
// Generate a stable salt for this turn based on its role and content
|
||||
const turnContent = JSON.stringify(msg.parts);
|
||||
const h = createHash('md5')
|
||||
.update(`${msg.role}:${turnContent}`)
|
||||
.digest('hex');
|
||||
const occurrence = (seenHashes.get(h) || 0) + 1;
|
||||
seenHashes.set(h, occurrence);
|
||||
const turnSalt = `${h}_${occurrence}`;
|
||||
const turnId = getStableId(msg, this.nodeIdentityMap, turnSalt, -1);
|
||||
|
||||
if (msg.role === 'user') {
|
||||
for (let partIdx = 0; partIdx < msg.parts.length; partIdx++) {
|
||||
const part = msg.parts[partIdx];
|
||||
const apiId =
|
||||
isFunctionResponsePart(part) &&
|
||||
typeof part.functionResponse.id === 'string'
|
||||
? `resp_${part.functionResponse.id}_${turnSalt}_${partIdx}`
|
||||
: isFunctionCallPart(part) &&
|
||||
typeof part.functionCall.id === 'string'
|
||||
? `call_${part.functionCall.id}_${turnSalt}_${partIdx}`
|
||||
: undefined;
|
||||
const id =
|
||||
apiId || getStableId(part, this.nodeIdentityMap, turnSalt, partIdx);
|
||||
const node: ConcreteNode = {
|
||||
id,
|
||||
timestamp: Date.now(),
|
||||
type: isFunctionResponsePart(part)
|
||||
? NodeType.TOOL_EXECUTION
|
||||
: NodeType.USER_PROMPT,
|
||||
role: 'user',
|
||||
payload: part,
|
||||
turnId,
|
||||
};
|
||||
nodes.push(node);
|
||||
}
|
||||
} else if (msg.role === 'model') {
|
||||
this.currentEpisode = parseModelParts(
|
||||
msg,
|
||||
this.currentEpisode,
|
||||
this.pendingCallParts,
|
||||
this.pendingCallPartsWithoutId,
|
||||
this.nodeIdentityMap,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
getNodes(): ConcreteNode[] {
|
||||
const copy = [...this.episodes];
|
||||
if (this.currentEpisode) {
|
||||
const activeEp = {
|
||||
...this.currentEpisode,
|
||||
concreteNodes: [...(this.currentEpisode.concreteNodes || [])],
|
||||
};
|
||||
finalizeYield(activeEp);
|
||||
if (isCompleteEpisode(activeEp)) {
|
||||
copy.push(activeEp);
|
||||
}
|
||||
}
|
||||
|
||||
const nodes: ConcreteNode[] = [];
|
||||
for (const ep of copy) {
|
||||
if (ep.concreteNodes) {
|
||||
for (const child of ep.concreteNodes) {
|
||||
nodes.push(child);
|
||||
for (let partIdx = 0; partIdx < msg.parts.length; partIdx++) {
|
||||
const part = msg.parts[partIdx];
|
||||
const apiId =
|
||||
isFunctionCallPart(part) && typeof part.functionCall.id === 'string'
|
||||
? `call_${part.functionCall.id}_${turnSalt}_${partIdx}`
|
||||
: undefined;
|
||||
const id =
|
||||
apiId || getStableId(part, this.nodeIdentityMap, turnSalt, partIdx);
|
||||
const node: ConcreteNode = {
|
||||
id,
|
||||
timestamp: Date.now(),
|
||||
type: isFunctionCallPart(part)
|
||||
? NodeType.TOOL_EXECUTION
|
||||
: NodeType.AGENT_THOUGHT,
|
||||
role: 'model',
|
||||
payload: part,
|
||||
turnId,
|
||||
};
|
||||
nodes.push(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
debugLogger.log(
|
||||
`[ContextGraphBuilder] Mirror Graph built with ${nodes.length} nodes.`,
|
||||
);
|
||||
return nodes;
|
||||
}
|
||||
}
|
||||
|
||||
function parseToolResponses(
|
||||
msg: Content,
|
||||
currentEpisode: Partial<Episode> | null,
|
||||
pendingCallParts: Map<string, Part>,
|
||||
pendingCallPartsWithoutId: Part[],
|
||||
tokenCalculator: ContextTokenCalculator,
|
||||
nodeIdentityMap: WeakMap<object, string>,
|
||||
): Partial<Episode> {
|
||||
if (!currentEpisode) {
|
||||
currentEpisode = {
|
||||
id: getStableId(msg, nodeIdentityMap),
|
||||
|
||||
concreteNodes: [],
|
||||
};
|
||||
}
|
||||
|
||||
const parts = msg.parts || [];
|
||||
for (const part of parts) {
|
||||
if (part.functionResponse) {
|
||||
const callId = part.functionResponse.id || '';
|
||||
let matchingCall = pendingCallParts.get(callId);
|
||||
|
||||
if (!matchingCall && pendingCallPartsWithoutId.length > 0) {
|
||||
const idx = pendingCallPartsWithoutId.findIndex(
|
||||
(p) => p.functionCall?.name === part.functionResponse!.name,
|
||||
);
|
||||
if (idx !== -1) {
|
||||
matchingCall = pendingCallPartsWithoutId[idx];
|
||||
pendingCallPartsWithoutId.splice(idx, 1);
|
||||
} else {
|
||||
matchingCall = pendingCallPartsWithoutId.shift();
|
||||
}
|
||||
}
|
||||
|
||||
const intentTokens = matchingCall
|
||||
? tokenCalculator.estimateTokensForParts([matchingCall])
|
||||
: 0;
|
||||
const obsTokens = tokenCalculator.estimateTokensForParts([part]);
|
||||
|
||||
const step: ToolExecution = {
|
||||
id: getStableId(part, nodeIdentityMap),
|
||||
timestamp: Date.now(),
|
||||
type: 'TOOL_EXECUTION',
|
||||
toolName: part.functionResponse.name || 'unknown',
|
||||
intent: isRecord(matchingCall?.functionCall?.args)
|
||||
? matchingCall.functionCall.args
|
||||
: {},
|
||||
observation: isRecord(part.functionResponse.response)
|
||||
? part.functionResponse.response
|
||||
: {},
|
||||
tokens: {
|
||||
intent: intentTokens,
|
||||
observation: obsTokens,
|
||||
},
|
||||
};
|
||||
|
||||
currentEpisode.concreteNodes = [
|
||||
...(currentEpisode.concreteNodes || []),
|
||||
step,
|
||||
];
|
||||
if (callId) pendingCallParts.delete(callId);
|
||||
}
|
||||
}
|
||||
return currentEpisode;
|
||||
}
|
||||
|
||||
function parseUserParts(
|
||||
msg: Content,
|
||||
nodeIdentityMap: WeakMap<object, string>,
|
||||
): Partial<Episode> {
|
||||
const semanticParts: SemanticPart[] = [];
|
||||
const parts = msg.parts || [];
|
||||
for (const p of parts) {
|
||||
if (p.text !== undefined)
|
||||
semanticParts.push({ type: 'text', text: p.text });
|
||||
else if (p.inlineData)
|
||||
semanticParts.push({
|
||||
type: 'inline_data',
|
||||
mimeType: p.inlineData.mimeType || '',
|
||||
data: p.inlineData.data || '',
|
||||
});
|
||||
else if (p.fileData)
|
||||
semanticParts.push({
|
||||
type: 'file_data',
|
||||
mimeType: p.fileData.mimeType || '',
|
||||
fileUri: p.fileData.fileUri || '',
|
||||
});
|
||||
else if (!p.functionResponse)
|
||||
semanticParts.push({ type: 'raw_part', part: p }); // Preserve unknowns
|
||||
}
|
||||
|
||||
const baseObj = parts.length > 0 ? parts[0] : msg;
|
||||
const trigger: UserPrompt = {
|
||||
id: getStableId(baseObj, nodeIdentityMap),
|
||||
timestamp: Date.now(),
|
||||
type: 'USER_PROMPT',
|
||||
semanticParts,
|
||||
};
|
||||
return {
|
||||
id: getStableId(msg, nodeIdentityMap),
|
||||
|
||||
concreteNodes: [trigger],
|
||||
};
|
||||
}
|
||||
|
||||
function parseModelParts(
|
||||
msg: Content,
|
||||
currentEpisode: Partial<Episode> | null,
|
||||
pendingCallParts: Map<string, Part>,
|
||||
pendingCallPartsWithoutId: Part[],
|
||||
nodeIdentityMap: WeakMap<object, string>,
|
||||
): Partial<Episode> {
|
||||
if (!currentEpisode) {
|
||||
currentEpisode = {
|
||||
id: getStableId(msg, nodeIdentityMap),
|
||||
|
||||
concreteNodes: [],
|
||||
};
|
||||
}
|
||||
|
||||
const parts = msg.parts || [];
|
||||
for (const part of parts) {
|
||||
if (part.functionCall) {
|
||||
const callId = part.functionCall.id || '';
|
||||
if (callId) {
|
||||
pendingCallParts.set(callId, part);
|
||||
} else {
|
||||
const lastIdx = pendingCallPartsWithoutId.length - 1;
|
||||
const lastPart = pendingCallPartsWithoutId[lastIdx];
|
||||
|
||||
if (
|
||||
lastPart &&
|
||||
lastPart.functionCall &&
|
||||
lastPart.functionCall.name === part.functionCall.name
|
||||
) {
|
||||
// Replace the previous chunk with the more complete one
|
||||
pendingCallPartsWithoutId[lastIdx] = part;
|
||||
} else {
|
||||
pendingCallPartsWithoutId.push(part);
|
||||
}
|
||||
}
|
||||
} else if (part.text) {
|
||||
const thought: AgentThought = {
|
||||
id: getStableId(part, nodeIdentityMap),
|
||||
timestamp: Date.now(),
|
||||
type: 'AGENT_THOUGHT',
|
||||
text: part.text,
|
||||
};
|
||||
|
||||
currentEpisode.concreteNodes = [
|
||||
...(currentEpisode.concreteNodes || []),
|
||||
thought,
|
||||
];
|
||||
}
|
||||
}
|
||||
return currentEpisode;
|
||||
}
|
||||
|
||||
function finalizeYield(currentEpisode: Partial<Episode>) {
|
||||
if (currentEpisode.concreteNodes && currentEpisode.concreteNodes.length > 0) {
|
||||
const yieldNode: AgentYield = {
|
||||
id: randomUUID(),
|
||||
timestamp: Date.now(),
|
||||
type: 'AGENT_YIELD',
|
||||
text: 'Yield', // Synthesized yield since we don't have the original concrete node
|
||||
};
|
||||
const existingNodes = currentEpisode.concreteNodes || [];
|
||||
currentEpisode.concreteNodes = [...existingNodes, yieldNode];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,24 +6,22 @@
|
||||
|
||||
import type { Part } from '@google/genai';
|
||||
|
||||
export type NodeType =
|
||||
// Organic Concrete Nodes
|
||||
| 'USER_PROMPT'
|
||||
| 'SYSTEM_EVENT'
|
||||
| 'AGENT_THOUGHT'
|
||||
| 'TOOL_EXECUTION'
|
||||
| 'AGENT_YIELD'
|
||||
/**
|
||||
* Basic Node Interface
|
||||
* Every element in the Context Graph is a Node.
|
||||
*/
|
||||
|
||||
// Synthetic Concrete Nodes
|
||||
| 'SNAPSHOT'
|
||||
| 'ROLLING_SUMMARY'
|
||||
| 'MASKED_TOOL'
|
||||
export enum NodeType {
|
||||
USER_PROMPT = 'USER_PROMPT',
|
||||
SYSTEM_EVENT = 'SYSTEM_EVENT',
|
||||
AGENT_THOUGHT = 'AGENT_THOUGHT',
|
||||
TOOL_EXECUTION = 'TOOL_EXECUTION',
|
||||
MASKED_TOOL = 'MASKED_TOOL',
|
||||
AGENT_YIELD = 'AGENT_YIELD',
|
||||
SNAPSHOT = 'SNAPSHOT',
|
||||
ROLLING_SUMMARY = 'ROLLING_SUMMARY',
|
||||
}
|
||||
|
||||
// Logical Nodes
|
||||
| 'TASK'
|
||||
| 'EPISODE';
|
||||
|
||||
/** Base interface for all nodes in the Episodic Context Graph */
|
||||
export interface Node {
|
||||
readonly id: string;
|
||||
readonly type: NodeType;
|
||||
@@ -32,11 +30,20 @@ export interface Node {
|
||||
/**
|
||||
* Concrete Nodes: The atomic, renderable pieces of data.
|
||||
* These are the actual "planks" of the Nodes of Theseus.
|
||||
*
|
||||
* Each ConcreteNode is now a 1:1 wrapper around a Gemini Part,
|
||||
* ensuring 100% fidelity during reconstruction.
|
||||
*/
|
||||
export interface BaseConcreteNode extends Node {
|
||||
readonly type: NodeType;
|
||||
readonly timestamp: number;
|
||||
/** The ID of the Logical Node (e.g., Episode) that structurally owns this node */
|
||||
readonly logicalParentId?: string;
|
||||
/** The role of the turn this part belongs to */
|
||||
readonly role: 'user' | 'model';
|
||||
/** The original, high-fidelity Part object from the API */
|
||||
readonly payload: Part;
|
||||
|
||||
/** The ID of the specific turn in history this node belongs to. Unique per turn. */
|
||||
readonly turnId: string;
|
||||
|
||||
/** If this node replaced a single node 1:1 (e.g., masking), this points to the original */
|
||||
readonly replacesId?: string;
|
||||
@@ -45,50 +52,19 @@ export interface BaseConcreteNode extends Node {
|
||||
readonly abstractsIds?: readonly string[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Semantic Parts for User Prompts
|
||||
*/
|
||||
export interface SemanticTextPart {
|
||||
readonly type: 'text';
|
||||
readonly text: string;
|
||||
}
|
||||
|
||||
export interface SemanticInlineDataPart {
|
||||
readonly type: 'inline_data';
|
||||
readonly mimeType: string;
|
||||
readonly data: string;
|
||||
}
|
||||
|
||||
export interface SemanticFileDataPart {
|
||||
readonly type: 'file_data';
|
||||
readonly mimeType: string;
|
||||
readonly fileUri: string;
|
||||
}
|
||||
|
||||
export interface SemanticRawPart {
|
||||
readonly type: 'raw_part';
|
||||
readonly part: Part;
|
||||
}
|
||||
|
||||
export type SemanticPart =
|
||||
| SemanticTextPart
|
||||
| SemanticInlineDataPart
|
||||
| SemanticFileDataPart
|
||||
| SemanticRawPart;
|
||||
|
||||
/**
|
||||
* Trigger Nodes
|
||||
* Events that wake the agent up and initiate an Episode.
|
||||
*/
|
||||
export interface UserPrompt extends BaseConcreteNode {
|
||||
readonly type: 'USER_PROMPT';
|
||||
readonly semanticParts: readonly SemanticPart[];
|
||||
readonly type: NodeType.USER_PROMPT;
|
||||
readonly role: 'user';
|
||||
}
|
||||
|
||||
export interface SystemEvent extends BaseConcreteNode {
|
||||
readonly type: 'SYSTEM_EVENT';
|
||||
readonly type: NodeType.SYSTEM_EVENT;
|
||||
readonly name: string;
|
||||
readonly payload: Record<string, unknown>;
|
||||
readonly payload: Part; // System events are usually injected as user text parts
|
||||
}
|
||||
|
||||
export type EpisodeTrigger = UserPrompt | SystemEvent;
|
||||
@@ -98,30 +74,16 @@ export type EpisodeTrigger = UserPrompt | SystemEvent;
|
||||
* The internal autonomous actions taken by the agent during its loop.
|
||||
*/
|
||||
export interface AgentThought extends BaseConcreteNode {
|
||||
readonly type: 'AGENT_THOUGHT';
|
||||
readonly text: string;
|
||||
readonly type: NodeType.AGENT_THOUGHT;
|
||||
readonly role: 'model';
|
||||
}
|
||||
|
||||
export interface ToolExecution extends BaseConcreteNode {
|
||||
readonly type: 'TOOL_EXECUTION';
|
||||
readonly toolName: string;
|
||||
readonly intent: Record<string, unknown>;
|
||||
readonly observation: string | Record<string, unknown>;
|
||||
readonly tokens: {
|
||||
readonly intent: number;
|
||||
readonly observation: number;
|
||||
};
|
||||
readonly type: NodeType.TOOL_EXECUTION;
|
||||
}
|
||||
|
||||
export interface MaskedTool extends BaseConcreteNode {
|
||||
readonly type: 'MASKED_TOOL';
|
||||
readonly toolName: string;
|
||||
readonly intent?: Record<string, unknown>;
|
||||
readonly observation?: string | Record<string, unknown>;
|
||||
readonly tokens: {
|
||||
readonly intent: number;
|
||||
readonly observation: number;
|
||||
};
|
||||
readonly type: NodeType.MASKED_TOOL;
|
||||
}
|
||||
|
||||
export type EpisodeStep = AgentThought | ToolExecution | MaskedTool;
|
||||
@@ -131,8 +93,8 @@ export type EpisodeStep = AgentThought | ToolExecution | MaskedTool;
|
||||
* The final message where the agent yields control back to the user.
|
||||
*/
|
||||
export interface AgentYield extends BaseConcreteNode {
|
||||
readonly type: 'AGENT_YIELD';
|
||||
readonly text: string;
|
||||
readonly type: NodeType.AGENT_YIELD;
|
||||
readonly role: 'model';
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -140,13 +102,11 @@ export interface AgentYield extends BaseConcreteNode {
|
||||
* Processors that generate summaries emit explicit synthetic nodes.
|
||||
*/
|
||||
export interface Snapshot extends BaseConcreteNode {
|
||||
readonly type: 'SNAPSHOT';
|
||||
readonly text: string;
|
||||
readonly type: NodeType.SNAPSHOT;
|
||||
}
|
||||
|
||||
export interface RollingSummary extends BaseConcreteNode {
|
||||
readonly type: 'ROLLING_SUMMARY';
|
||||
readonly text: string;
|
||||
readonly type: NodeType.ROLLING_SUMMARY;
|
||||
}
|
||||
|
||||
export type SyntheticLeaf = Snapshot | RollingSummary;
|
||||
@@ -161,62 +121,34 @@ export type ConcreteNode =
|
||||
| Snapshot
|
||||
| RollingSummary;
|
||||
|
||||
/**
|
||||
* Logical Nodes
|
||||
* These define hierarchy and grouping. They do not directly render to Gemini.
|
||||
*/
|
||||
export interface Episode extends Node {
|
||||
readonly type: 'EPISODE';
|
||||
/** References to the Concrete Node IDs that conceptually belong to this Episode. */
|
||||
concreteNodes: readonly ConcreteNode[];
|
||||
}
|
||||
|
||||
export interface Task extends Node {
|
||||
readonly type: 'TASK';
|
||||
readonly goal: string;
|
||||
readonly status: 'active' | 'completed' | 'failed';
|
||||
/** References to the Episode IDs that belong to this task */
|
||||
readonly episodeIds: readonly string[];
|
||||
}
|
||||
|
||||
export type LogicalNode = Task | Episode;
|
||||
|
||||
export function isEpisode(node: Node): node is Episode {
|
||||
return node.type === 'EPISODE';
|
||||
}
|
||||
|
||||
export function isTask(node: Node): node is Task {
|
||||
return node.type === 'TASK';
|
||||
}
|
||||
|
||||
export function isAgentThought(node: Node): node is AgentThought {
|
||||
return node.type === 'AGENT_THOUGHT';
|
||||
return node.type === NodeType.AGENT_THOUGHT;
|
||||
}
|
||||
|
||||
export function isAgentYield(node: Node): node is AgentYield {
|
||||
return node.type === 'AGENT_YIELD';
|
||||
return node.type === NodeType.AGENT_YIELD;
|
||||
}
|
||||
|
||||
export function isToolExecution(node: Node): node is ToolExecution {
|
||||
return node.type === 'TOOL_EXECUTION';
|
||||
return node.type === NodeType.TOOL_EXECUTION;
|
||||
}
|
||||
|
||||
export function isMaskedTool(node: Node): node is MaskedTool {
|
||||
return node.type === 'MASKED_TOOL';
|
||||
return node.type === NodeType.MASKED_TOOL;
|
||||
}
|
||||
|
||||
export function isUserPrompt(node: Node): node is UserPrompt {
|
||||
return node.type === 'USER_PROMPT';
|
||||
return node.type === NodeType.USER_PROMPT;
|
||||
}
|
||||
|
||||
export function isSystemEvent(node: Node): node is SystemEvent {
|
||||
return node.type === 'SYSTEM_EVENT';
|
||||
return node.type === NodeType.SYSTEM_EVENT;
|
||||
}
|
||||
|
||||
export function isSnapshot(node: Node): node is Snapshot {
|
||||
return node.type === 'SNAPSHOT';
|
||||
return node.type === NodeType.SNAPSHOT;
|
||||
}
|
||||
|
||||
export function isRollingSummary(node: Node): node is RollingSummary {
|
||||
return node.type === 'ROLLING_SUMMARY';
|
||||
return node.type === NodeType.ROLLING_SUMMARY;
|
||||
}
|
||||
|
||||
@@ -9,12 +9,9 @@ import type {
|
||||
HistoryEvent,
|
||||
} from '../core/agentChatHistory.js';
|
||||
import type { ContextGraphMapper } from './graph/mapper.js';
|
||||
import type { ContextTokenCalculator } from './utils/contextTokenCalculator.js';
|
||||
import type { ContextEventBus } from './eventBus.js';
|
||||
import type { ContextTracer } from './tracer.js';
|
||||
|
||||
import type { ConcreteNode } from './graph/types.js';
|
||||
|
||||
/**
|
||||
* Connects the raw AgentChatHistory to the ContextManager.
|
||||
* It maps raw messages into Episodic Intermediate Representation (Context Graph)
|
||||
@@ -29,18 +26,25 @@ export class HistoryObserver {
|
||||
private readonly chatHistory: AgentChatHistory,
|
||||
private readonly eventBus: ContextEventBus,
|
||||
private readonly tracer: ContextTracer,
|
||||
private readonly tokenCalculator: ContextTokenCalculator,
|
||||
private readonly graphMapper: ContextGraphMapper,
|
||||
) {}
|
||||
|
||||
private processEvent = (event: HistoryEvent) => {
|
||||
let nodes: ConcreteNode[] = [];
|
||||
|
||||
if (event.type === 'CLEAR') {
|
||||
this.seenNodeIds.clear();
|
||||
}
|
||||
|
||||
nodes = this.graphMapper.applyEvent(event, this.tokenCalculator);
|
||||
if (event.type === 'SILENT_SYNC') {
|
||||
return;
|
||||
}
|
||||
|
||||
// Always process the FULL history to provide a complete view to the ContextManager.
|
||||
// The ContextManager relies on the 'nodes' array to be the TOTAL set of valid pristine nodes.
|
||||
const fullHistory = this.chatHistory.get();
|
||||
const nodes = this.graphMapper.applyEvent({
|
||||
...event,
|
||||
payload: fullHistory,
|
||||
});
|
||||
|
||||
const newNodes = new Set<string>();
|
||||
for (const node of nodes) {
|
||||
|
||||
@@ -13,7 +13,7 @@ import { ContextEventBus } from './eventBus.js';
|
||||
import { ContextEnvironmentImpl } from './pipeline/environmentImpl.js';
|
||||
import { PipelineOrchestrator } from './pipeline/orchestrator.js';
|
||||
import { ContextManager } from './contextManager.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
// import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { NodeTruncationProcessorOptionsSchema } from './processors/nodeTruncationProcessor.js';
|
||||
import { ToolMaskingProcessorOptionsSchema } from './processors/toolMaskingProcessor.js';
|
||||
import { HistoryTruncationProcessorOptionsSchema } from './processors/historyTruncationProcessor.js';
|
||||
@@ -22,6 +22,7 @@ import { NodeDistillationProcessorOptionsSchema } from './processors/nodeDistill
|
||||
import { StateSnapshotProcessorOptionsSchema } from './processors/stateSnapshotProcessor.js';
|
||||
import { StateSnapshotAsyncProcessorOptionsSchema } from './processors/stateSnapshotAsyncProcessor.js';
|
||||
import { RollingSummaryProcessorOptionsSchema } from './processors/rollingSummaryProcessor.js';
|
||||
import { getEnvironmentContext } from '../utils/environmentContext.js';
|
||||
|
||||
export async function initializeContextManager(
|
||||
config: Config,
|
||||
@@ -29,10 +30,6 @@ export async function initializeContextManager(
|
||||
lastPromptId: string,
|
||||
): Promise<ContextManager | undefined> {
|
||||
const isV1Enabled = config.getContextManagementConfig().enabled;
|
||||
debugLogger.log(
|
||||
`[initializer] called with enabled=${isV1Enabled}, GEMINI_CONTEXT_TRACE_DIR=${process.env['GEMINI_CONTEXT_TRACE_DIR']}`,
|
||||
);
|
||||
|
||||
if (!isV1Enabled) {
|
||||
return undefined;
|
||||
}
|
||||
@@ -113,5 +110,9 @@ export async function initializeContextManager(
|
||||
tracer,
|
||||
orchestrator,
|
||||
chat.agentHistory,
|
||||
async () => {
|
||||
const parts = await getEnvironmentContext(config);
|
||||
return { role: 'user', parts };
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
@@ -28,7 +28,6 @@ export interface GraphMutation {
|
||||
export interface ContextWorkingBuffer {
|
||||
readonly nodes: readonly ConcreteNode[];
|
||||
getPristineNodes(id: string): readonly ConcreteNode[];
|
||||
getLineage(id: string): readonly ConcreteNode[];
|
||||
getAuditLog(): readonly GraphMutation[];
|
||||
}
|
||||
|
||||
|
||||
@@ -7,19 +7,20 @@
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { ContextWorkingBufferImpl } from './contextWorkingBuffer.js';
|
||||
import { createDummyNode } from '../testing/contextTestUtils.js';
|
||||
import { NodeType } from '../graph/types.js';
|
||||
|
||||
describe('ContextWorkingBufferImpl', () => {
|
||||
it('should initialize with a pristine graph correctly', () => {
|
||||
const pristine1 = createDummyNode(
|
||||
'ep1',
|
||||
'USER_PROMPT',
|
||||
NodeType.USER_PROMPT,
|
||||
10,
|
||||
undefined,
|
||||
'p1',
|
||||
);
|
||||
const pristine2 = createDummyNode(
|
||||
'ep1',
|
||||
'AGENT_THOUGHT',
|
||||
NodeType.AGENT_THOUGHT,
|
||||
10,
|
||||
undefined,
|
||||
'p2',
|
||||
@@ -38,7 +39,7 @@ describe('ContextWorkingBufferImpl', () => {
|
||||
it('should track 1:1 replacements (e.g., masking) and append to audit log', () => {
|
||||
const pristine1 = createDummyNode(
|
||||
'ep1',
|
||||
'USER_PROMPT',
|
||||
NodeType.USER_PROMPT,
|
||||
10,
|
||||
undefined,
|
||||
'p1',
|
||||
@@ -47,7 +48,7 @@ describe('ContextWorkingBufferImpl', () => {
|
||||
|
||||
const maskedNode = createDummyNode(
|
||||
'ep1',
|
||||
'USER_PROMPT',
|
||||
NodeType.USER_PROMPT,
|
||||
5,
|
||||
undefined,
|
||||
'm1',
|
||||
@@ -76,15 +77,33 @@ describe('ContextWorkingBufferImpl', () => {
|
||||
});
|
||||
|
||||
it('should track N:1 abstractions (e.g., rolling summaries)', () => {
|
||||
const p1 = createDummyNode('ep1', 'USER_PROMPT', 10, undefined, 'p1');
|
||||
const p2 = createDummyNode('ep1', 'AGENT_THOUGHT', 10, undefined, 'p2');
|
||||
const p3 = createDummyNode('ep1', 'USER_PROMPT', 10, undefined, 'p3');
|
||||
const p1 = createDummyNode(
|
||||
'ep1',
|
||||
NodeType.USER_PROMPT,
|
||||
10,
|
||||
undefined,
|
||||
'p1',
|
||||
);
|
||||
const p2 = createDummyNode(
|
||||
'ep1',
|
||||
NodeType.AGENT_THOUGHT,
|
||||
10,
|
||||
undefined,
|
||||
'p2',
|
||||
);
|
||||
const p3 = createDummyNode(
|
||||
'ep1',
|
||||
NodeType.USER_PROMPT,
|
||||
10,
|
||||
undefined,
|
||||
'p3',
|
||||
);
|
||||
|
||||
let buffer = ContextWorkingBufferImpl.initialize([p1, p2, p3]);
|
||||
|
||||
const summaryNode = createDummyNode(
|
||||
'ep1',
|
||||
'ROLLING_SUMMARY',
|
||||
NodeType.ROLLING_SUMMARY,
|
||||
15,
|
||||
undefined,
|
||||
's1',
|
||||
@@ -105,11 +124,23 @@ describe('ContextWorkingBufferImpl', () => {
|
||||
});
|
||||
|
||||
it('should track multi-generation provenance correctly', () => {
|
||||
const p1 = createDummyNode('ep1', 'USER_PROMPT', 10, undefined, 'p1');
|
||||
const p1 = createDummyNode(
|
||||
'ep1',
|
||||
NodeType.USER_PROMPT,
|
||||
10,
|
||||
undefined,
|
||||
'p1',
|
||||
);
|
||||
let buffer = ContextWorkingBufferImpl.initialize([p1]);
|
||||
|
||||
// Gen 1: Masked
|
||||
const gen1 = createDummyNode('ep1', 'USER_PROMPT', 8, undefined, 'gen1');
|
||||
const gen1 = createDummyNode(
|
||||
'ep1',
|
||||
NodeType.USER_PROMPT,
|
||||
8,
|
||||
undefined,
|
||||
'gen1',
|
||||
);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(gen1 as any).replacesId = 'p1';
|
||||
buffer = buffer.applyProcessorResult('Masking', [p1], [gen1]);
|
||||
@@ -117,7 +148,7 @@ describe('ContextWorkingBufferImpl', () => {
|
||||
// Gen 2: Summarized
|
||||
const gen2 = createDummyNode(
|
||||
'ep1',
|
||||
'ROLLING_SUMMARY',
|
||||
NodeType.ROLLING_SUMMARY,
|
||||
5,
|
||||
undefined,
|
||||
'gen2',
|
||||
@@ -140,12 +171,18 @@ describe('ContextWorkingBufferImpl', () => {
|
||||
});
|
||||
|
||||
it('should handle net-new injected nodes without throwing', () => {
|
||||
const p1 = createDummyNode('ep1', 'USER_PROMPT', 10, undefined, 'p1');
|
||||
const p1 = createDummyNode(
|
||||
'ep1',
|
||||
NodeType.USER_PROMPT,
|
||||
10,
|
||||
undefined,
|
||||
'p1',
|
||||
);
|
||||
let buffer = ContextWorkingBufferImpl.initialize([p1]);
|
||||
|
||||
const injected = createDummyNode(
|
||||
'ep1',
|
||||
'SYSTEM_EVENT',
|
||||
NodeType.SYSTEM_EVENT,
|
||||
5,
|
||||
undefined,
|
||||
'injected1',
|
||||
|
||||
@@ -66,14 +66,23 @@ export class ContextWorkingBufferImpl implements ContextWorkingBuffer {
|
||||
|
||||
const newPristineMap = new Map<string, ConcreteNode>(this.pristineNodesMap);
|
||||
const newProvenanceMap = new Map(this.provenanceMap);
|
||||
const existingIds = new Set(this.nodes.map((n) => n.id));
|
||||
|
||||
const nodesToAdd: ConcreteNode[] = [];
|
||||
const batchIds = new Set<string>();
|
||||
for (const node of newNodes) {
|
||||
newPristineMap.set(node.id, node);
|
||||
newProvenanceMap.set(node.id, new Set([node.id]));
|
||||
if (!existingIds.has(node.id) && !batchIds.has(node.id)) {
|
||||
newPristineMap.set(node.id, node);
|
||||
newProvenanceMap.set(node.id, new Set([node.id]));
|
||||
nodesToAdd.push(node);
|
||||
batchIds.add(node.id);
|
||||
}
|
||||
}
|
||||
|
||||
if (nodesToAdd.length === 0) return this;
|
||||
|
||||
return new ContextWorkingBufferImpl(
|
||||
[...this.nodes, ...newNodes],
|
||||
[...this.nodes, ...nodesToAdd],
|
||||
newPristineMap,
|
||||
newProvenanceMap,
|
||||
[...this.history],
|
||||
@@ -257,20 +266,4 @@ export class ContextWorkingBufferImpl implements ContextWorkingBuffer {
|
||||
getAuditLog(): readonly GraphMutation[] {
|
||||
return this.history;
|
||||
}
|
||||
|
||||
getLineage(id: string): readonly ConcreteNode[] {
|
||||
const lineage: ConcreteNode[] = [];
|
||||
const currentNodesMap = new Map(this.nodes.map((n) => [n.id, n]));
|
||||
|
||||
let current = currentNodesMap.get(id);
|
||||
while (current) {
|
||||
lineage.push(current);
|
||||
if (current.logicalParentId && current.logicalParentId !== current.id) {
|
||||
current = currentNodesMap.get(current.logicalParentId);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return lineage;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ export class ContextEnvironmentImpl implements ContextEnvironment {
|
||||
this.behaviorRegistry,
|
||||
);
|
||||
this.inbox = new LiveInbox();
|
||||
this.graphMapper = new ContextGraphMapper(this.behaviorRegistry);
|
||||
this.graphMapper = new ContextGraphMapper();
|
||||
}
|
||||
|
||||
get llmClient(): BaseLlmClient {
|
||||
|
||||
@@ -4,13 +4,13 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import assert from 'node:assert';
|
||||
import { describe, it, expect, beforeEach, vi, afterEach } from 'vitest';
|
||||
import { PipelineOrchestrator } from './orchestrator.js';
|
||||
import {
|
||||
createMockEnvironment,
|
||||
createDummyNode,
|
||||
} from '../testing/contextTestUtils.js';
|
||||
import { NodeType } from '../graph/types.js';
|
||||
import type { ContextEnvironment } from './environment.js';
|
||||
import type {
|
||||
ContextProcessor,
|
||||
@@ -28,21 +28,22 @@ function createModifyingProcessor(id: string): ContextProcessor {
|
||||
name: 'ModifyingProcessor',
|
||||
process: async (args: ProcessArgs) => {
|
||||
const newTargets = [...args.targets];
|
||||
if (newTargets.length > 0 && newTargets[0].type === 'USER_PROMPT') {
|
||||
if (
|
||||
newTargets.length > 0 &&
|
||||
newTargets[0].type === NodeType.USER_PROMPT
|
||||
) {
|
||||
const prompt = newTargets[0];
|
||||
const newParts = [...prompt.semanticParts];
|
||||
if (newParts.length > 0 && newParts[0].type === 'text') {
|
||||
newParts[0] = {
|
||||
...newParts[0],
|
||||
text: newParts[0].text + ' [modified]',
|
||||
if (prompt.payload.text) {
|
||||
newTargets[0] = {
|
||||
...prompt,
|
||||
id: prompt.id + '-modified',
|
||||
replacesId: prompt.id,
|
||||
payload: {
|
||||
...prompt.payload,
|
||||
text: prompt.payload.text + ' [modified]',
|
||||
},
|
||||
};
|
||||
}
|
||||
newTargets[0] = {
|
||||
...prompt,
|
||||
id: prompt.id + '-modified',
|
||||
replacesId: prompt.id,
|
||||
semanticParts: newParts,
|
||||
};
|
||||
}
|
||||
return newTargets;
|
||||
},
|
||||
@@ -112,8 +113,8 @@ describe('PipelineOrchestrator (Component)', () => {
|
||||
];
|
||||
|
||||
const orchestrator = setupOrchestrator(pipelines);
|
||||
const originalNode = createDummyNode('ep1', 'USER_PROMPT', 50, {
|
||||
semanticParts: [{ type: 'text', text: 'Original' }],
|
||||
const originalNode = createDummyNode('ep1', NodeType.USER_PROMPT, 50, {
|
||||
payload: { text: 'Original' },
|
||||
});
|
||||
|
||||
const processed = await orchestrator.executeTriggerSync(
|
||||
@@ -125,8 +126,7 @@ describe('PipelineOrchestrator (Component)', () => {
|
||||
|
||||
expect(processed.length).toBe(1);
|
||||
const resultingNode = processed[0] as UserPrompt;
|
||||
assert(resultingNode.semanticParts[0].type === 'text');
|
||||
expect(resultingNode.semanticParts[0].text).toBe('Original [modified]');
|
||||
expect(resultingNode.payload.text).toBe('Original [modified]');
|
||||
expect(resultingNode.replacesId).toBe(originalNode.id);
|
||||
});
|
||||
|
||||
@@ -140,8 +140,8 @@ describe('PipelineOrchestrator (Component)', () => {
|
||||
];
|
||||
|
||||
const orchestrator = setupOrchestrator(pipelines);
|
||||
const originalNode = createDummyNode('ep1', 'USER_PROMPT', 50, {
|
||||
semanticParts: [{ type: 'text', text: 'Original' }],
|
||||
const originalNode = createDummyNode('ep1', NodeType.USER_PROMPT, 50, {
|
||||
payload: { text: 'Original' },
|
||||
});
|
||||
|
||||
const processed = await orchestrator.executeTriggerSync(
|
||||
@@ -167,8 +167,8 @@ describe('PipelineOrchestrator (Component)', () => {
|
||||
];
|
||||
|
||||
const orchestrator = setupOrchestrator(pipelines);
|
||||
const originalNode = createDummyNode('ep1', 'USER_PROMPT', 50, {
|
||||
semanticParts: [{ type: 'text', text: 'Original' }],
|
||||
const originalNode = createDummyNode('ep1', NodeType.USER_PROMPT, 50, {
|
||||
payload: { text: 'Original' },
|
||||
});
|
||||
|
||||
// The throwing processor should be caught and logged, allowing Mod to still run.
|
||||
@@ -181,8 +181,7 @@ describe('PipelineOrchestrator (Component)', () => {
|
||||
|
||||
expect(processed.length).toBe(1);
|
||||
const resultingNode = processed[0] as UserPrompt;
|
||||
assert(resultingNode.semanticParts[0].type === 'text');
|
||||
expect(resultingNode.semanticParts[0].text).toBe('Original [modified]');
|
||||
expect(resultingNode.payload.text).toBe('Original [modified]');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -205,8 +204,8 @@ describe('PipelineOrchestrator (Component)', () => {
|
||||
],
|
||||
);
|
||||
|
||||
const node1 = createDummyNode('ep1', 'USER_PROMPT', 10);
|
||||
const node2 = createDummyNode('ep1', 'AGENT_THOUGHT', 20);
|
||||
const node1 = createDummyNode('ep1', NodeType.USER_PROMPT, 10);
|
||||
const node2 = createDummyNode('ep1', NodeType.AGENT_THOUGHT, 20);
|
||||
|
||||
eventBus.emitChunkReceived({
|
||||
nodes: [node1, node2],
|
||||
|
||||
@@ -21,6 +21,9 @@ import { ContextWorkingBufferImpl } from './contextWorkingBuffer.js';
|
||||
|
||||
export class PipelineOrchestrator {
|
||||
private activeTimers: NodeJS.Timeout[] = [];
|
||||
private readonly pendingPipelines = new Map<string, Promise<void>>();
|
||||
private readonly pipelineMutex = new Map<string, Promise<void>>();
|
||||
private nodeProvider: (() => readonly ConcreteNode[]) | undefined;
|
||||
|
||||
constructor(
|
||||
private readonly pipelines: PipelineDef[],
|
||||
@@ -32,15 +35,37 @@ export class PipelineOrchestrator {
|
||||
this.setupTriggers();
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the provider for the latest live nodes.
|
||||
* This is used by sequential pipeline runs to ensure they operate on current state.
|
||||
*/
|
||||
setNodeProvider(provider: () => readonly ConcreteNode[]) {
|
||||
this.nodeProvider = provider;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a promise that resolves when all currently executing async pipelines have finished.
|
||||
* This acts as a 'Pressure Barrier' for the ContextManager.
|
||||
*/
|
||||
async waitForPipelines(): Promise<void> {
|
||||
const pending = Array.from(this.pendingPipelines.values());
|
||||
if (pending.length > 0) {
|
||||
debugLogger.log(
|
||||
`[PipelineOrchestrator] Waiting for ${pending.length} pending async pipelines to complete...`,
|
||||
);
|
||||
await Promise.allSettled(pending);
|
||||
}
|
||||
}
|
||||
|
||||
private isNodeAllowed(
|
||||
node: ConcreteNode,
|
||||
triggerTargets: ReadonlySet<string>,
|
||||
protectedLogicalIds: ReadonlySet<string> = new Set(),
|
||||
protectedTurnIds: ReadonlySet<string> = new Set(),
|
||||
): boolean {
|
||||
return (
|
||||
triggerTargets.has(node.id) &&
|
||||
!protectedLogicalIds.has(node.id) &&
|
||||
(!node.logicalParentId || !protectedLogicalIds.has(node.logicalParentId))
|
||||
!protectedTurnIds.has(node.id) &&
|
||||
!protectedTurnIds.has(node.turnId)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -78,12 +103,42 @@ export class PipelineOrchestrator {
|
||||
};
|
||||
|
||||
bindTriggers(this.pipelines, (pipeline, nodes, targets, protectedIds) => {
|
||||
void this.executePipelineAsync(
|
||||
pipeline,
|
||||
nodes,
|
||||
new Set(targets),
|
||||
new Set(protectedIds),
|
||||
);
|
||||
// Fetch the tail of the current chain for this pipeline, or start a new one
|
||||
const existing =
|
||||
this.pipelineMutex.get(pipeline.name) || Promise.resolve();
|
||||
|
||||
const nextPromise = (async () => {
|
||||
try {
|
||||
// Wait for the previous run of THIS pipeline to complete
|
||||
await existing;
|
||||
|
||||
// We re-fetch the LATEST nodes from the environment's live buffer
|
||||
// to ensure this sequential run isn't operating on stale data from the trigger event.
|
||||
const latestNodes = this.nodeProvider!();
|
||||
|
||||
await this.executePipelineAsync(
|
||||
pipeline,
|
||||
latestNodes,
|
||||
new Set(targets),
|
||||
new Set(protectedIds),
|
||||
);
|
||||
} catch (e) {
|
||||
debugLogger.error(`Pipeline chain ${pipeline.name} failed:`, e);
|
||||
}
|
||||
})();
|
||||
|
||||
// Update the chain tail
|
||||
this.pipelineMutex.set(pipeline.name, nextPromise);
|
||||
|
||||
const pipelineId = `${pipeline.name}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
|
||||
this.pendingPipelines.set(pipelineId, nextPromise);
|
||||
void nextPromise.finally(() => {
|
||||
this.pendingPipelines.delete(pipelineId);
|
||||
// Only clear the mutex if we are still the tail of the chain
|
||||
if (this.pipelineMutex.get(pipeline.name) === nextPromise) {
|
||||
this.pipelineMutex.delete(pipeline.name);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
bindTriggers(this.asyncPipelines, (pipeline, nodes, targetIds) => {
|
||||
@@ -115,8 +170,13 @@ export class PipelineOrchestrator {
|
||||
trigger: PipelineTrigger,
|
||||
nodes: readonly ConcreteNode[],
|
||||
triggerTargets: ReadonlySet<string>,
|
||||
protectedLogicalIds: ReadonlySet<string> = new Set(),
|
||||
protectedTurnIds: ReadonlySet<string> = new Set(),
|
||||
): Promise<readonly ConcreteNode[]> {
|
||||
this.tracer.logEvent('Orchestrator', 'Strategy Intent', {
|
||||
trigger,
|
||||
totalNodes: nodes.length,
|
||||
targetNodes: triggerTargets.size,
|
||||
});
|
||||
let currentBuffer = ContextWorkingBufferImpl.initialize(nodes);
|
||||
const triggerPipelines = this.pipelines.filter((p) =>
|
||||
p.triggers.includes(trigger),
|
||||
@@ -133,10 +193,11 @@ export class PipelineOrchestrator {
|
||||
this.tracer.logEvent(
|
||||
'Orchestrator',
|
||||
`Executing processor synchronously: ${processor.id}`,
|
||||
{ nodeCountBefore: currentBuffer.nodes.length },
|
||||
);
|
||||
|
||||
const allowedTargets = currentBuffer.nodes.filter((n) =>
|
||||
this.isNodeAllowed(n, triggerTargets, protectedLogicalIds),
|
||||
this.isNodeAllowed(n, triggerTargets, protectedTurnIds),
|
||||
);
|
||||
|
||||
const returnedNodes = await processor.process({
|
||||
@@ -150,6 +211,27 @@ export class PipelineOrchestrator {
|
||||
allowedTargets,
|
||||
returnedNodes,
|
||||
);
|
||||
|
||||
const addedNodes = returnedNodes.filter(
|
||||
(n) => !allowedTargets.some((at) => at.id === n.id),
|
||||
);
|
||||
const removedNodes = allowedTargets.filter(
|
||||
(at) => !returnedNodes.some((n) => n.id === at.id),
|
||||
);
|
||||
|
||||
this.tracer.logEvent('Orchestrator', 'Transformation Lineage', {
|
||||
processorId: processor.id,
|
||||
inputNodeCount: allowedTargets.length,
|
||||
outputNodeCount: returnedNodes.length,
|
||||
removedNodeIds: removedNodes.map((n) => n.id),
|
||||
addedNodes: addedNodes.map((n) => ({
|
||||
id: n.id,
|
||||
replacesId: n.replacesId,
|
||||
abstractsIds: n.abstractsIds,
|
||||
approxTokens:
|
||||
this.env.tokenCalculator.calculateConcreteListTokens([n]),
|
||||
})),
|
||||
});
|
||||
} catch (error) {
|
||||
debugLogger.error(
|
||||
`Synchronous processor ${processor.id} failed:`,
|
||||
@@ -169,11 +251,15 @@ export class PipelineOrchestrator {
|
||||
pipeline: PipelineDef,
|
||||
nodes: readonly ConcreteNode[],
|
||||
triggerTargets: Set<string>,
|
||||
protectedLogicalIds: ReadonlySet<string> = new Set(),
|
||||
protectedTurnIds: ReadonlySet<string> = new Set(),
|
||||
) {
|
||||
this.tracer.logEvent(
|
||||
'Orchestrator',
|
||||
`Triggering async pipeline: ${pipeline.name}`,
|
||||
{
|
||||
triggerTargets: triggerTargets.size,
|
||||
totalNodes: nodes.length,
|
||||
},
|
||||
);
|
||||
if (!nodes || nodes.length === 0) return;
|
||||
|
||||
@@ -187,10 +273,11 @@ export class PipelineOrchestrator {
|
||||
this.tracer.logEvent(
|
||||
'Orchestrator',
|
||||
`Executing processor: ${processor.id} (async)`,
|
||||
{ nodeCountBefore: currentBuffer.nodes.length },
|
||||
);
|
||||
|
||||
const allowedTargets = currentBuffer.nodes.filter((n) =>
|
||||
this.isNodeAllowed(n, triggerTargets, protectedLogicalIds),
|
||||
this.isNodeAllowed(n, triggerTargets, protectedTurnIds),
|
||||
);
|
||||
|
||||
const returnedNodes = await processor.process({
|
||||
@@ -204,6 +291,29 @@ export class PipelineOrchestrator {
|
||||
allowedTargets,
|
||||
returnedNodes,
|
||||
);
|
||||
|
||||
const addedNodes = returnedNodes.filter(
|
||||
(n) => !allowedTargets.some((at) => at.id === n.id),
|
||||
);
|
||||
const removedNodes = allowedTargets.filter(
|
||||
(at) => !returnedNodes.some((n) => n.id === at.id),
|
||||
);
|
||||
|
||||
this.tracer.logEvent('Orchestrator', 'Transformation Lineage (Async)', {
|
||||
processorId: processor.id,
|
||||
inputNodeCount: allowedTargets.length,
|
||||
outputNodeCount: returnedNodes.length,
|
||||
removedNodeIds: removedNodes.map((n) => n.id),
|
||||
addedNodes: addedNodes.map((n) => ({
|
||||
id: n.id,
|
||||
replacesId: n.replacesId,
|
||||
abstractsIds: n.abstractsIds,
|
||||
approxTokens: this.env.tokenCalculator.calculateConcreteListTokens([
|
||||
n,
|
||||
]),
|
||||
})),
|
||||
});
|
||||
|
||||
this.eventBus.emitProcessorResult({
|
||||
processorId: processor.id,
|
||||
targets: allowedTargets,
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import assert from 'node:assert';
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { createBlobDegradationProcessor } from './blobDegradationProcessor.js';
|
||||
import {
|
||||
@@ -12,7 +11,7 @@ import {
|
||||
createMockEnvironment,
|
||||
createDummyNode,
|
||||
} from '../testing/contextTestUtils.js';
|
||||
import type { UserPrompt, SemanticPart, ConcreteNode } from '../graph/types.js';
|
||||
import { type ConcreteNode, NodeType } from '../graph/types.js';
|
||||
|
||||
describe('BlobDegradationProcessor', () => {
|
||||
it('should ignore text parts and only target inline_data and file_data', async () => {
|
||||
@@ -28,35 +27,31 @@ describe('BlobDegradationProcessor', () => {
|
||||
env,
|
||||
);
|
||||
|
||||
const parts: SemanticPart[] = [
|
||||
{ type: 'text', text: 'Hello' },
|
||||
{ type: 'inline_data', mimeType: 'image/png', data: fakeData },
|
||||
{ type: 'text', text: 'World' },
|
||||
];
|
||||
const node1 = createDummyNode('ep1', NodeType.USER_PROMPT, 10, {
|
||||
payload: { text: 'Hello' },
|
||||
});
|
||||
const node2 = createDummyNode('ep1', NodeType.USER_PROMPT, 100, {
|
||||
payload: { inlineData: { mimeType: 'image/png', data: fakeData } },
|
||||
});
|
||||
const node3 = createDummyNode('ep1', NodeType.USER_PROMPT, 10, {
|
||||
payload: { text: 'World' },
|
||||
});
|
||||
|
||||
const prompt = createDummyNode('ep1', 'USER_PROMPT', 100, {
|
||||
semanticParts: parts,
|
||||
}) as UserPrompt;
|
||||
|
||||
const targets = [prompt];
|
||||
const targets = [node1, node2, node3];
|
||||
|
||||
const result = await processor.process(createMockProcessArgs(targets));
|
||||
|
||||
expect(result.length).toBe(1);
|
||||
const modifiedPrompt = result[0] as UserPrompt;
|
||||
expect(result.length).toBe(3);
|
||||
|
||||
expect(modifiedPrompt.id).not.toBe(prompt.id);
|
||||
expect(modifiedPrompt.semanticParts.length).toBe(3);
|
||||
// Text nodes should be untouched
|
||||
expect(result[0]).toBe(node1);
|
||||
expect(result[2]).toBe(node3);
|
||||
|
||||
// Text parts should be untouched
|
||||
expect(modifiedPrompt.semanticParts[0]).toEqual(parts[0]);
|
||||
expect(modifiedPrompt.semanticParts[2]).toEqual(parts[2]);
|
||||
|
||||
// The inline_data part should be replaced with text
|
||||
const degradedPart = modifiedPrompt.semanticParts[1];
|
||||
expect(degradedPart.type).toBe('text');
|
||||
assert(degradedPart.type === 'text');
|
||||
expect(degradedPart.text).toContain(
|
||||
// The inline_data node should be replaced with text
|
||||
const degradedNode = result[1];
|
||||
expect(degradedNode.id).not.toBe(node2.id);
|
||||
expect(degradedNode.replacesId).toBe(node2.id);
|
||||
expect(degradedNode.payload.text).toContain(
|
||||
'[Multi-Modal Blob (image/png, 0.00MB) degraded to text',
|
||||
);
|
||||
});
|
||||
@@ -69,29 +64,26 @@ describe('BlobDegradationProcessor', () => {
|
||||
env,
|
||||
);
|
||||
|
||||
// Tokens for fileData = 258.
|
||||
// Degraded text = "[File Reference (video/mp4) degraded to text to preserve context window. Original URI: gs://test1]"
|
||||
// Degraded text length ~100 characters.
|
||||
// Since charsPerToken=1, degraded text = 100 tokens.
|
||||
// Tokens saved = 258 - 100 = 158. This is > 0, so it WILL degrade it!
|
||||
const node1 = createDummyNode('ep1', NodeType.USER_PROMPT, 100, {
|
||||
payload: {
|
||||
fileData: { mimeType: 'video/mp4', fileUri: 'gs://test1' },
|
||||
},
|
||||
});
|
||||
const node2 = createDummyNode('ep1', NodeType.USER_PROMPT, 100, {
|
||||
payload: {
|
||||
fileData: { mimeType: 'video/mp4', fileUri: 'gs://test2' },
|
||||
},
|
||||
});
|
||||
|
||||
const prompt = createDummyNode('ep1', 'USER_PROMPT', 100, {
|
||||
semanticParts: [
|
||||
{ type: 'file_data', mimeType: 'video/mp4', fileUri: 'gs://test1' },
|
||||
{ type: 'file_data', mimeType: 'video/mp4', fileUri: 'gs://test2' },
|
||||
],
|
||||
}) as UserPrompt;
|
||||
|
||||
const targets = [prompt];
|
||||
const targets = [node1, node2];
|
||||
|
||||
const result = await processor.process(createMockProcessArgs(targets));
|
||||
|
||||
const modifiedPrompt = result[0] as UserPrompt;
|
||||
expect(modifiedPrompt.semanticParts.length).toBe(2);
|
||||
expect(result.length).toBe(2);
|
||||
|
||||
// Both parts should be degraded
|
||||
expect(modifiedPrompt.semanticParts[0].type).toBe('text');
|
||||
expect(modifiedPrompt.semanticParts[1].type).toBe('text');
|
||||
// Both nodes should be degraded
|
||||
expect(result[0].payload.text).toContain('degraded to text');
|
||||
expect(result[1].payload.text).toContain('degraded to text');
|
||||
});
|
||||
|
||||
it('should return exactly the targets array if targets are empty', async () => {
|
||||
|
||||
@@ -8,7 +8,7 @@ import type { JSONSchemaType } from 'ajv';
|
||||
import type { ProcessArgs, ContextProcessor } from '../pipeline.js';
|
||||
import * as fs from 'node:fs/promises';
|
||||
import * as path from 'node:path';
|
||||
import type { ConcreteNode, UserPrompt } from '../graph/types.js';
|
||||
import type { ConcreteNode } from '../graph/types.js';
|
||||
import type { ContextEnvironment } from '../pipeline/environment.js';
|
||||
import { sanitizeFilenamePart } from '../../utils/fileUtils.js';
|
||||
|
||||
@@ -55,95 +55,50 @@ export function createBlobDegradationProcessor(
|
||||
|
||||
// Forward scan, looking for bloated non-text parts to degrade
|
||||
for (const node of targets) {
|
||||
switch (node.type) {
|
||||
case 'USER_PROMPT': {
|
||||
let modified = false;
|
||||
const newParts = [...node.semanticParts];
|
||||
const payload = node.payload;
|
||||
let newText = '';
|
||||
let tokensSaved = 0;
|
||||
|
||||
for (let j = 0; j < node.semanticParts.length; j++) {
|
||||
const part = node.semanticParts[j];
|
||||
if (part.type === 'text') continue;
|
||||
if (payload.inlineData?.data && payload.inlineData?.mimeType) {
|
||||
await ensureDir();
|
||||
const ext = payload.inlineData.mimeType.split('/')[1] || 'bin';
|
||||
const fileName = `blob_${Date.now()}_${randomUUID()}.${ext}`;
|
||||
const filePath = path.join(blobOutputsDir, fileName);
|
||||
|
||||
let newText = '';
|
||||
let tokensSaved = 0;
|
||||
const buffer = Buffer.from(payload.inlineData.data, 'base64');
|
||||
await fs.writeFile(filePath, buffer);
|
||||
|
||||
switch (part.type) {
|
||||
case 'inline_data': {
|
||||
await ensureDir();
|
||||
const ext = part.mimeType.split('/')[1] || 'bin';
|
||||
const fileName = `blob_${Date.now()}_${randomUUID()}.${ext}`;
|
||||
const filePath = path.join(blobOutputsDir, fileName);
|
||||
const mb = (buffer.byteLength / 1024 / 1024).toFixed(2);
|
||||
newText = `[Multi-Modal Blob (${payload.inlineData.mimeType}, ${mb}MB) degraded to text to preserve context window. Saved to: ${filePath}]`;
|
||||
|
||||
const buffer = Buffer.from(part.data, 'base64');
|
||||
await fs.writeFile(filePath, buffer);
|
||||
const oldTokens = env.tokenCalculator.estimateTokensForParts([
|
||||
payload,
|
||||
]);
|
||||
const newTokens = env.tokenCalculator.estimateTokensForParts([
|
||||
{ text: newText },
|
||||
]);
|
||||
tokensSaved = oldTokens - newTokens;
|
||||
} else if (payload.fileData?.mimeType && payload.fileData?.fileUri) {
|
||||
newText = `[File Reference (${payload.fileData.mimeType}) degraded to text to preserve context window. Original URI: ${payload.fileData.fileUri}]`;
|
||||
const oldTokens = env.tokenCalculator.estimateTokensForParts([
|
||||
payload,
|
||||
]);
|
||||
const newTokens = env.tokenCalculator.estimateTokensForParts([
|
||||
{ text: newText },
|
||||
]);
|
||||
tokensSaved = oldTokens - newTokens;
|
||||
}
|
||||
|
||||
const mb = (buffer.byteLength / 1024 / 1024).toFixed(2);
|
||||
newText = `[Multi-Modal Blob (${part.mimeType}, ${mb}MB) degraded to text to preserve context window. Saved to: ${filePath}]`;
|
||||
|
||||
const oldTokens = env.tokenCalculator.estimateTokensForParts([
|
||||
{
|
||||
inlineData: { mimeType: part.mimeType, data: part.data },
|
||||
},
|
||||
]);
|
||||
const newTokens = env.tokenCalculator.estimateTokensForParts([
|
||||
{ text: newText },
|
||||
]);
|
||||
tokensSaved = oldTokens - newTokens;
|
||||
break;
|
||||
}
|
||||
case 'file_data': {
|
||||
newText = `[File Reference (${part.mimeType}) degraded to text to preserve context window. Original URI: ${part.fileUri}]`;
|
||||
const oldTokens = env.tokenCalculator.estimateTokensForParts([
|
||||
{
|
||||
fileData: {
|
||||
mimeType: part.mimeType,
|
||||
fileUri: part.fileUri,
|
||||
},
|
||||
},
|
||||
]);
|
||||
const newTokens = env.tokenCalculator.estimateTokensForParts([
|
||||
{ text: newText },
|
||||
]);
|
||||
tokensSaved = oldTokens - newTokens;
|
||||
break;
|
||||
}
|
||||
case 'raw_part': {
|
||||
newText = `[Unknown Part degraded to text to preserve context window.]`;
|
||||
const oldTokens = env.tokenCalculator.estimateTokensForParts([
|
||||
part.part,
|
||||
]);
|
||||
const newTokens = env.tokenCalculator.estimateTokensForParts([
|
||||
{ text: newText },
|
||||
]);
|
||||
tokensSaved = oldTokens - newTokens;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
if (newText && tokensSaved > 0) {
|
||||
newParts[j] = { type: 'text', text: newText };
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (modified) {
|
||||
const degradedNode: UserPrompt = {
|
||||
...node,
|
||||
id: randomUUID(),
|
||||
semanticParts: newParts,
|
||||
replacesId: node.id,
|
||||
};
|
||||
returnedNodes.push(degradedNode);
|
||||
} else {
|
||||
returnedNodes.push(node);
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
returnedNodes.push(node);
|
||||
break;
|
||||
if (newText && tokensSaved > 0) {
|
||||
returnedNodes.push({
|
||||
...node,
|
||||
id: randomUUID(),
|
||||
payload: { text: newText },
|
||||
replacesId: node.id,
|
||||
turnId: node.turnId,
|
||||
});
|
||||
} else {
|
||||
returnedNodes.push(node);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import assert from 'node:assert';
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { createNodeDistillationProcessor } from './nodeDistillationProcessor.js';
|
||||
import {
|
||||
@@ -14,6 +13,7 @@ import {
|
||||
createDummyToolNode,
|
||||
createMockLlmClient,
|
||||
} from '../testing/contextTestUtils.js';
|
||||
import { NodeType } from '../graph/types.js';
|
||||
import type {
|
||||
UserPrompt,
|
||||
AgentThought,
|
||||
@@ -41,20 +41,20 @@ describe('NodeDistillationProcessor', () => {
|
||||
|
||||
const prompt = createDummyNode(
|
||||
'ep1',
|
||||
'USER_PROMPT',
|
||||
NodeType.USER_PROMPT,
|
||||
50,
|
||||
{
|
||||
semanticParts: [{ type: 'text', text: longText }],
|
||||
payload: { text: longText },
|
||||
},
|
||||
'prompt-id',
|
||||
) as UserPrompt;
|
||||
|
||||
const thought = createDummyNode(
|
||||
'ep1',
|
||||
'AGENT_THOUGHT',
|
||||
NodeType.AGENT_THOUGHT,
|
||||
50,
|
||||
{
|
||||
text: longText,
|
||||
payload: { text: longText },
|
||||
},
|
||||
'thought-id',
|
||||
) as AgentThought;
|
||||
@@ -64,7 +64,13 @@ describe('NodeDistillationProcessor', () => {
|
||||
5,
|
||||
500,
|
||||
{
|
||||
observation: { result: 'A'.repeat(500) },
|
||||
role: 'user',
|
||||
payload: {
|
||||
functionResponse: {
|
||||
name: 'dummy_tool',
|
||||
response: { result: 'A'.repeat(500) },
|
||||
},
|
||||
},
|
||||
},
|
||||
'tool-id',
|
||||
);
|
||||
@@ -78,19 +84,19 @@ describe('NodeDistillationProcessor', () => {
|
||||
// 1. User Prompt
|
||||
const compressedPrompt = result[0] as UserPrompt;
|
||||
expect(compressedPrompt.id).not.toBe(prompt.id);
|
||||
expect(compressedPrompt.semanticParts[0].type).toBe('text');
|
||||
assert(compressedPrompt.semanticParts[0].type === 'text');
|
||||
expect(compressedPrompt.semanticParts[0].text).toBe('Mocked Summary!');
|
||||
expect(compressedPrompt.payload.text).toBe('Mocked Summary!');
|
||||
|
||||
// 2. Agent Thought
|
||||
const compressedThought = result[1] as AgentThought;
|
||||
expect(compressedThought.id).not.toBe(thought.id);
|
||||
expect(compressedThought.text).toBe('Mocked Summary!');
|
||||
expect(compressedThought.payload.text).toBe('Mocked Summary!');
|
||||
|
||||
// 3. Tool Execution
|
||||
const compressedTool = result[2] as ToolExecution;
|
||||
expect(compressedTool.id).not.toBe(tool.id);
|
||||
expect(compressedTool.observation).toEqual({ summary: 'Mocked Summary!' });
|
||||
expect(compressedTool.payload.functionResponse?.response).toEqual({
|
||||
summary: 'Mocked Summary!',
|
||||
});
|
||||
|
||||
expect(mockLlmClient.generateContent).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
@@ -114,20 +120,20 @@ describe('NodeDistillationProcessor', () => {
|
||||
|
||||
const prompt = createDummyNode(
|
||||
'ep1',
|
||||
'USER_PROMPT',
|
||||
NodeType.USER_PROMPT,
|
||||
10,
|
||||
{
|
||||
semanticParts: [{ type: 'text', text: shortText }],
|
||||
payload: { text: shortText },
|
||||
},
|
||||
'prompt-id',
|
||||
) as UserPrompt;
|
||||
|
||||
const thought = createDummyNode(
|
||||
'ep1',
|
||||
'AGENT_THOUGHT',
|
||||
NodeType.AGENT_THOUGHT,
|
||||
13,
|
||||
{
|
||||
text: 'Short thought',
|
||||
payload: { text: 'Short thought' },
|
||||
},
|
||||
'thought-id',
|
||||
) as AgentThought;
|
||||
|
||||
@@ -6,10 +6,14 @@
|
||||
import { randomUUID } from 'node:crypto';
|
||||
import type { JSONSchemaType } from 'ajv';
|
||||
import type { ContextProcessor, ProcessArgs } from '../pipeline.js';
|
||||
import type { ConcreteNode } from '../graph/types.js';
|
||||
import { type ConcreteNode, NodeType } from '../graph/types.js';
|
||||
import type { ContextEnvironment } from '../pipeline/environment.js';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
import { getResponseText } from '../../utils/partUtils.js';
|
||||
import {
|
||||
getResponseText,
|
||||
updatePart,
|
||||
cloneFunctionResponse,
|
||||
} from '../../utils/partUtils.js';
|
||||
import { LlmRole } from '../../telemetry/llmRole.js';
|
||||
|
||||
export interface NodeDistillationProcessorOptions {
|
||||
@@ -56,7 +60,7 @@ export function createNodeDistillationProcessor(
|
||||
},
|
||||
});
|
||||
return getResponseText(response) || text;
|
||||
} catch (e) {
|
||||
} catch (e: unknown) {
|
||||
debugLogger.warn(
|
||||
`NodeDistillationProcessor failed to summarize ${contextInfo}`,
|
||||
e,
|
||||
@@ -77,58 +81,31 @@ export function createNodeDistillationProcessor(
|
||||
|
||||
// Scan the target working buffer and unconditionally apply the configured hyperparameter threshold
|
||||
for (const node of targets) {
|
||||
const payload = node.payload;
|
||||
|
||||
switch (node.type) {
|
||||
case 'USER_PROMPT': {
|
||||
let modified = false;
|
||||
const newParts = [...node.semanticParts];
|
||||
|
||||
for (let j = 0; j < node.semanticParts.length; j++) {
|
||||
const part = node.semanticParts[j];
|
||||
if (part.type !== 'text') continue;
|
||||
|
||||
if (part.text.length > thresholdChars) {
|
||||
const summary = await generateSummary(part.text, 'User Prompt');
|
||||
const newTokens = env.tokenCalculator.estimateTokensForParts([
|
||||
{ text: summary },
|
||||
]);
|
||||
const oldTokens = env.tokenCalculator.estimateTokensForParts([
|
||||
{ text: part.text },
|
||||
]);
|
||||
|
||||
if (newTokens < oldTokens) {
|
||||
newParts[j] = { type: 'text', text: summary };
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (modified) {
|
||||
returnedNodes.push({
|
||||
...node,
|
||||
id: randomUUID(),
|
||||
semanticParts: newParts,
|
||||
replacesId: node.id,
|
||||
});
|
||||
} else {
|
||||
returnedNodes.push(node);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case 'AGENT_THOUGHT': {
|
||||
if (node.text.length > thresholdChars) {
|
||||
const summary = await generateSummary(node.text, 'Agent Thought');
|
||||
case NodeType.USER_PROMPT:
|
||||
case NodeType.AGENT_THOUGHT: {
|
||||
const text = payload.text;
|
||||
if (text && text.length > thresholdChars) {
|
||||
const summary = await generateSummary(text, node.type);
|
||||
const newTokens = env.tokenCalculator.estimateTokensForParts([
|
||||
{ text: summary },
|
||||
]);
|
||||
const oldTokens = env.tokenCalculator.getTokenCost(node);
|
||||
const oldTokens = env.tokenCalculator.estimateTokensForParts([
|
||||
{ text },
|
||||
]);
|
||||
|
||||
if (newTokens < oldTokens) {
|
||||
const distilledPayload = updatePart(payload, { text: summary });
|
||||
|
||||
returnedNodes.push({
|
||||
...node,
|
||||
id: randomUUID(),
|
||||
text: summary,
|
||||
payload: distilledPayload,
|
||||
replacesId: node.id,
|
||||
timestamp: node.timestamp,
|
||||
turnId: node.turnId,
|
||||
});
|
||||
break;
|
||||
}
|
||||
@@ -137,54 +114,60 @@ export function createNodeDistillationProcessor(
|
||||
break;
|
||||
}
|
||||
|
||||
case 'TOOL_EXECUTION': {
|
||||
const rawObs = node.observation;
|
||||
|
||||
let stringifiedObs = '';
|
||||
if (typeof rawObs === 'string') {
|
||||
stringifiedObs = rawObs;
|
||||
} else {
|
||||
try {
|
||||
stringifiedObs = JSON.stringify(rawObs);
|
||||
} catch {
|
||||
stringifiedObs = String(rawObs);
|
||||
case NodeType.TOOL_EXECUTION: {
|
||||
if (payload.functionResponse) {
|
||||
const rawObs = payload.functionResponse.response;
|
||||
let stringifiedObs = '';
|
||||
if (typeof rawObs === 'string') {
|
||||
stringifiedObs = rawObs;
|
||||
} else {
|
||||
try {
|
||||
stringifiedObs = JSON.stringify(rawObs);
|
||||
} catch {
|
||||
stringifiedObs = String(rawObs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (stringifiedObs.length > thresholdChars) {
|
||||
const summary = await generateSummary(
|
||||
stringifiedObs,
|
||||
node.toolName || 'unknown',
|
||||
);
|
||||
const newObsObject = { summary };
|
||||
if (stringifiedObs.length > thresholdChars) {
|
||||
const summary = await generateSummary(
|
||||
stringifiedObs,
|
||||
payload.functionResponse.name || 'unknown',
|
||||
);
|
||||
const newObsObject = { summary };
|
||||
|
||||
const newObsTokens = env.tokenCalculator.estimateTokensForParts([
|
||||
{
|
||||
functionResponse: {
|
||||
name: node.toolName || 'unknown',
|
||||
response: newObsObject,
|
||||
id: node.id,
|
||||
},
|
||||
},
|
||||
]);
|
||||
const newFR = cloneFunctionResponse(payload.functionResponse);
|
||||
newFR.response = newObsObject;
|
||||
|
||||
const oldObsTokens =
|
||||
node.tokens?.observation ??
|
||||
env.tokenCalculator.getTokenCost(node);
|
||||
const intentTokens = node.tokens?.intent ?? 0;
|
||||
const newObsTokens = env.tokenCalculator.estimateTokensForParts(
|
||||
[
|
||||
{
|
||||
functionResponse: newFR,
|
||||
},
|
||||
],
|
||||
);
|
||||
|
||||
if (newObsTokens < oldObsTokens) {
|
||||
returnedNodes.push({
|
||||
...node,
|
||||
id: randomUUID(),
|
||||
observation: newObsObject,
|
||||
tokens: {
|
||||
intent: intentTokens,
|
||||
observation: newObsTokens,
|
||||
},
|
||||
replacesId: node.id,
|
||||
});
|
||||
break;
|
||||
const oldObsTokens = env.tokenCalculator.estimateTokensForParts(
|
||||
[payload],
|
||||
);
|
||||
|
||||
if (newObsTokens < oldObsTokens) {
|
||||
const newFR = cloneFunctionResponse(payload.functionResponse);
|
||||
newFR.response = newObsObject;
|
||||
|
||||
const distilledPayload = updatePart(payload, {
|
||||
functionResponse: newFR,
|
||||
});
|
||||
|
||||
returnedNodes.push({
|
||||
...node,
|
||||
id: randomUUID(),
|
||||
payload: distilledPayload,
|
||||
replacesId: node.id,
|
||||
timestamp: node.timestamp,
|
||||
turnId: node.turnId,
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
returnedNodes.push(node);
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import assert from 'node:assert';
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { createNodeTruncationProcessor } from './nodeTruncationProcessor.js';
|
||||
import {
|
||||
@@ -12,7 +11,12 @@ import {
|
||||
createMockEnvironment,
|
||||
createDummyNode,
|
||||
} from '../testing/contextTestUtils.js';
|
||||
import type { UserPrompt, AgentThought, AgentYield } from '../graph/types.js';
|
||||
import {
|
||||
NodeType,
|
||||
type UserPrompt,
|
||||
type AgentThought,
|
||||
type AgentYield,
|
||||
} from '../graph/types.js';
|
||||
|
||||
describe('NodeTruncationProcessor', () => {
|
||||
it('should truncate nodes that exceed maxTokensPerNode', async () => {
|
||||
@@ -31,30 +35,30 @@ describe('NodeTruncationProcessor', () => {
|
||||
|
||||
const prompt = createDummyNode(
|
||||
'ep1',
|
||||
'USER_PROMPT',
|
||||
NodeType.USER_PROMPT,
|
||||
50,
|
||||
{
|
||||
semanticParts: [{ type: 'text', text: longText }],
|
||||
payload: { text: longText },
|
||||
},
|
||||
'prompt-id',
|
||||
) as UserPrompt;
|
||||
|
||||
const thought = createDummyNode(
|
||||
'ep1',
|
||||
'AGENT_THOUGHT',
|
||||
NodeType.AGENT_THOUGHT,
|
||||
50,
|
||||
{
|
||||
text: longText,
|
||||
payload: { text: longText },
|
||||
},
|
||||
'thought-id',
|
||||
) as AgentThought;
|
||||
|
||||
const yieldNode = createDummyNode(
|
||||
'ep1',
|
||||
'AGENT_YIELD',
|
||||
NodeType.AGENT_YIELD,
|
||||
50,
|
||||
{
|
||||
text: longText,
|
||||
payload: { text: longText },
|
||||
},
|
||||
'yield-id',
|
||||
) as AgentYield;
|
||||
@@ -68,19 +72,17 @@ describe('NodeTruncationProcessor', () => {
|
||||
// 1. User Prompt
|
||||
const squashedPrompt = result[0] as UserPrompt;
|
||||
expect(squashedPrompt.id).not.toBe(prompt.id);
|
||||
expect(squashedPrompt.semanticParts[0].type).toBe('text');
|
||||
assert(squashedPrompt.semanticParts[0].type === 'text');
|
||||
expect(squashedPrompt.semanticParts[0].text).toContain('[... OMITTED');
|
||||
expect(squashedPrompt.payload.text).toContain('[... OMITTED');
|
||||
|
||||
// 2. Agent Thought
|
||||
const squashedThought = result[1] as AgentThought;
|
||||
expect(squashedThought.id).not.toBe(thought.id);
|
||||
expect(squashedThought.text).toContain('[... OMITTED');
|
||||
expect(squashedThought.payload.text).toContain('[... OMITTED');
|
||||
|
||||
// 3. Agent Yield
|
||||
const squashedYield = result[2] as AgentYield;
|
||||
expect(squashedYield.id).not.toBe(yieldNode.id);
|
||||
expect(squashedYield.text).toContain('[... OMITTED');
|
||||
expect(squashedYield.payload.text).toContain('[... OMITTED');
|
||||
});
|
||||
|
||||
it('should ignore nodes that are below maxTokensPerNode', async () => {
|
||||
@@ -98,20 +100,20 @@ describe('NodeTruncationProcessor', () => {
|
||||
|
||||
const prompt = createDummyNode(
|
||||
'ep1',
|
||||
'USER_PROMPT',
|
||||
NodeType.USER_PROMPT,
|
||||
10,
|
||||
{
|
||||
semanticParts: [{ type: 'text', text: shortText }],
|
||||
payload: { text: shortText },
|
||||
},
|
||||
'prompt-id',
|
||||
) as UserPrompt;
|
||||
|
||||
const thought = createDummyNode(
|
||||
'ep1',
|
||||
'AGENT_THOUGHT',
|
||||
NodeType.AGENT_THOUGHT,
|
||||
13,
|
||||
{
|
||||
text: 'Short thought', // 13 chars
|
||||
payload: { text: 'Short thought' }, // 13 chars
|
||||
},
|
||||
'thought-id',
|
||||
) as AgentThought;
|
||||
@@ -125,12 +127,11 @@ describe('NodeTruncationProcessor', () => {
|
||||
// 1. User Prompt (untouched)
|
||||
const squashedPrompt = result[0] as UserPrompt;
|
||||
expect(squashedPrompt.id).toBe(prompt.id);
|
||||
assert(squashedPrompt.semanticParts[0].type === 'text');
|
||||
expect(squashedPrompt.semanticParts[0].text).not.toContain('[... OMITTED');
|
||||
expect(squashedPrompt.payload.text).not.toContain('[... OMITTED');
|
||||
|
||||
// 2. Agent Thought (untouched)
|
||||
const untouchedThought = result[1] as AgentThought;
|
||||
expect(untouchedThought.id).toBe(thought.id);
|
||||
expect(untouchedThought.text).not.toContain('[... OMITTED');
|
||||
expect(untouchedThought.payload.text).not.toContain('[... OMITTED');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -73,69 +73,24 @@ export function createNodeTruncationProcessor(
|
||||
const returnedNodes: ConcreteNode[] = [];
|
||||
|
||||
for (const node of targets) {
|
||||
switch (node.type) {
|
||||
case 'USER_PROMPT': {
|
||||
let modified = false;
|
||||
const newParts = [...node.semanticParts];
|
||||
const payload = node.payload;
|
||||
const text = payload.text;
|
||||
|
||||
for (let j = 0; j < node.semanticParts.length; j++) {
|
||||
const part = node.semanticParts[j];
|
||||
if (part.type === 'text') {
|
||||
const squashResult = tryApplySquash(part.text, limitChars);
|
||||
if (squashResult) {
|
||||
newParts[j] = { type: 'text', text: squashResult.text };
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (modified) {
|
||||
returnedNodes.push({
|
||||
...node,
|
||||
id: randomUUID(),
|
||||
semanticParts: newParts,
|
||||
replacesId: node.id,
|
||||
});
|
||||
} else {
|
||||
returnedNodes.push(node);
|
||||
}
|
||||
break;
|
||||
if (text) {
|
||||
const squashResult = tryApplySquash(text, limitChars);
|
||||
if (squashResult) {
|
||||
returnedNodes.push({
|
||||
...node,
|
||||
id: randomUUID(),
|
||||
payload: { ...payload, text: squashResult.text },
|
||||
replacesId: node.id,
|
||||
turnId: node.turnId,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
case 'AGENT_THOUGHT': {
|
||||
const squashResult = tryApplySquash(node.text, limitChars);
|
||||
if (squashResult) {
|
||||
returnedNodes.push({
|
||||
...node,
|
||||
id: randomUUID(),
|
||||
text: squashResult.text,
|
||||
replacesId: node.id,
|
||||
});
|
||||
} else {
|
||||
returnedNodes.push(node);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case 'AGENT_YIELD': {
|
||||
const squashResult = tryApplySquash(node.text, limitChars);
|
||||
if (squashResult) {
|
||||
returnedNodes.push({
|
||||
...node,
|
||||
id: randomUUID(),
|
||||
text: squashResult.text,
|
||||
replacesId: node.id,
|
||||
});
|
||||
} else {
|
||||
returnedNodes.push(node);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
returnedNodes.push(node);
|
||||
break;
|
||||
}
|
||||
|
||||
returnedNodes.push(node);
|
||||
}
|
||||
|
||||
return returnedNodes;
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
createMockEnvironment,
|
||||
createDummyNode,
|
||||
} from '../testing/contextTestUtils.js';
|
||||
import { NodeType } from '../graph/types.js';
|
||||
|
||||
describe('RollingSummaryProcessor', () => {
|
||||
it('should initialize with correct default options', () => {
|
||||
@@ -43,13 +44,25 @@ describe('RollingSummaryProcessor', () => {
|
||||
const targets = [
|
||||
createDummyNode(
|
||||
'ep1',
|
||||
'USER_PROMPT',
|
||||
NodeType.USER_PROMPT,
|
||||
50,
|
||||
{ semanticParts: [{ type: 'text', text: text50 }] },
|
||||
{ payload: { text: text50 } },
|
||||
'id1',
|
||||
),
|
||||
createDummyNode('ep1', 'AGENT_THOUGHT', 50, { text: text50 }, 'id2'),
|
||||
createDummyNode('ep1', 'AGENT_YIELD', 50, { text: text50 }, 'id3'),
|
||||
createDummyNode(
|
||||
'ep1',
|
||||
NodeType.AGENT_THOUGHT,
|
||||
50,
|
||||
{ payload: { text: text50 } },
|
||||
'id2',
|
||||
),
|
||||
createDummyNode(
|
||||
'ep1',
|
||||
NodeType.AGENT_YIELD,
|
||||
50,
|
||||
{ payload: { text: text50 } },
|
||||
'id3',
|
||||
),
|
||||
];
|
||||
|
||||
const result = await processor.process(createMockProcessArgs(targets));
|
||||
@@ -59,8 +72,8 @@ describe('RollingSummaryProcessor', () => {
|
||||
// Node id2 adds 50 deficit. Node id3 adds 50 deficit. Total = 100 deficit, which hits the target break point.
|
||||
// Thus, id2 and id3 are summarized into a new ROLLING_SUMMARY node.
|
||||
expect(result.length).toBe(2);
|
||||
expect(result[0].type).toBe('USER_PROMPT');
|
||||
expect(result[1].type).toBe('ROLLING_SUMMARY');
|
||||
expect(result[0].type).toBe(NodeType.USER_PROMPT);
|
||||
expect(result[1].type).toBe(NodeType.ROLLING_SUMMARY);
|
||||
});
|
||||
|
||||
it('should preserve targets if deficit does not trigger summary', async () => {
|
||||
@@ -80,19 +93,25 @@ describe('RollingSummaryProcessor', () => {
|
||||
const targets = [
|
||||
createDummyNode(
|
||||
'ep1',
|
||||
'USER_PROMPT',
|
||||
NodeType.USER_PROMPT,
|
||||
10,
|
||||
{ semanticParts: [{ type: 'text', text: text10 }] },
|
||||
{ payload: { text: text10 } },
|
||||
'id1',
|
||||
),
|
||||
createDummyNode('ep1', 'AGENT_THOUGHT', 10, { text: text10 }, 'id2'),
|
||||
createDummyNode(
|
||||
'ep1',
|
||||
NodeType.AGENT_THOUGHT,
|
||||
10,
|
||||
{ payload: { text: text10 } },
|
||||
'id2',
|
||||
),
|
||||
];
|
||||
|
||||
const result = await processor.process(createMockProcessArgs(targets));
|
||||
|
||||
// Deficit accumulator reaches 10. This is < 100 limit, and total summarizable nodes < 2 anyway.
|
||||
expect(result.length).toBe(2);
|
||||
expect(result[0].type).toBe('USER_PROMPT');
|
||||
expect(result[1].type).toBe('AGENT_THOUGHT');
|
||||
expect(result[0].type).toBe(NodeType.USER_PROMPT);
|
||||
expect(result[1].type).toBe(NodeType.AGENT_THOUGHT);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -11,7 +11,11 @@ import type {
|
||||
BackstopTargetOptions,
|
||||
} from '../pipeline.js';
|
||||
import type { ContextEnvironment } from '../pipeline/environment.js';
|
||||
import type { ConcreteNode, RollingSummary } from '../graph/types.js';
|
||||
import {
|
||||
type ConcreteNode,
|
||||
type RollingSummary,
|
||||
NodeType,
|
||||
} from '../graph/types.js';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
import { LlmRole } from '../../telemetry/llmRole.js';
|
||||
|
||||
@@ -45,16 +49,14 @@ export function createRollingSummaryProcessor(
|
||||
): Promise<string> => {
|
||||
let transcript = '';
|
||||
for (const node of nodes) {
|
||||
const payload = node.payload;
|
||||
let nodeContent = '';
|
||||
if ('text' in node && typeof node.text === 'string') {
|
||||
nodeContent = node.text;
|
||||
} else if ('semanticParts' in node) {
|
||||
nodeContent = JSON.stringify(node.semanticParts);
|
||||
} else if ('observation' in node) {
|
||||
nodeContent =
|
||||
typeof node.observation === 'string'
|
||||
? node.observation
|
||||
: JSON.stringify(node.observation);
|
||||
if (payload.text) {
|
||||
nodeContent = payload.text;
|
||||
} else if (payload.functionCall) {
|
||||
nodeContent = `CALL: ${payload.functionCall.name}(${JSON.stringify(payload.functionCall.args)})`;
|
||||
} else if (payload.functionResponse) {
|
||||
nodeContent = `RESPONSE: ${JSON.stringify(payload.functionResponse.response)}`;
|
||||
}
|
||||
transcript += `[${node.type}]: ${nodeContent}\n`;
|
||||
}
|
||||
@@ -125,10 +127,11 @@ export function createRollingSummaryProcessor(
|
||||
|
||||
const summaryNode: RollingSummary = {
|
||||
id: newId,
|
||||
logicalParentId: newId,
|
||||
type: 'ROLLING_SUMMARY',
|
||||
timestamp: Date.now(),
|
||||
text: snapshotText,
|
||||
turnId: newId,
|
||||
type: NodeType.ROLLING_SUMMARY,
|
||||
timestamp: nodesToSummarize[nodesToSummarize.length - 1].timestamp,
|
||||
role: 'user',
|
||||
payload: { text: snapshotText },
|
||||
abstractsIds: nodesToSummarize.map((n) => n.id),
|
||||
};
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
createDummyNode,
|
||||
createMockProcessArgs,
|
||||
} from '../testing/contextTestUtils.js';
|
||||
import { NodeType } from '../graph/types.js';
|
||||
import type { InboxMessage } from '../pipeline.js';
|
||||
import type { InboxSnapshotImpl } from '../pipeline/inbox.js';
|
||||
|
||||
@@ -25,8 +26,20 @@ describe('StateSnapshotAsyncProcessor', () => {
|
||||
{ type: 'point-in-time' },
|
||||
);
|
||||
|
||||
const nodeA = createDummyNode('ep1', 'USER_PROMPT', 50, {}, 'node-A');
|
||||
const nodeB = createDummyNode('ep1', 'AGENT_THOUGHT', 60, {}, 'node-B');
|
||||
const nodeA = createDummyNode(
|
||||
'ep1',
|
||||
NodeType.USER_PROMPT,
|
||||
50,
|
||||
{},
|
||||
'node-A',
|
||||
);
|
||||
const nodeB = createDummyNode(
|
||||
'ep1',
|
||||
NodeType.AGENT_THOUGHT,
|
||||
60,
|
||||
{},
|
||||
'node-B',
|
||||
);
|
||||
|
||||
const targets = [nodeA, nodeB];
|
||||
await worker.process(createMockProcessArgs(targets, targets, []));
|
||||
@@ -56,7 +69,13 @@ describe('StateSnapshotAsyncProcessor', () => {
|
||||
{ type: 'accumulate' },
|
||||
);
|
||||
|
||||
const nodeC = createDummyNode('ep2', 'USER_PROMPT', 50, {}, 'node-C');
|
||||
const nodeC = createDummyNode(
|
||||
'ep2',
|
||||
NodeType.USER_PROMPT,
|
||||
50,
|
||||
{},
|
||||
'node-C',
|
||||
);
|
||||
const targets = [nodeC];
|
||||
|
||||
const inboxMessages: InboxMessage[] = [
|
||||
|
||||
@@ -7,7 +7,7 @@ import { randomUUID } from 'node:crypto';
|
||||
import type { JSONSchemaType } from 'ajv';
|
||||
import type { AsyncContextProcessor, ProcessArgs } from '../pipeline.js';
|
||||
import type { ContextEnvironment } from '../pipeline/environment.js';
|
||||
import type { ConcreteNode } from '../graph/types.js';
|
||||
import { type ConcreteNode, NodeType } from '../graph/types.js';
|
||||
import { SnapshotGenerator } from '../utils/snapshotGenerator.js';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
|
||||
@@ -73,13 +73,14 @@ export function createStateSnapshotAsyncProcessor(
|
||||
|
||||
previousConsumedIds = latest.payload.consumedIds;
|
||||
|
||||
// Prepend a synthetic node representing the previous rolling state
|
||||
const snapshotId = randomUUID();
|
||||
const previousStateNode: ConcreteNode = {
|
||||
id: randomUUID(),
|
||||
logicalParentId: '',
|
||||
type: 'SNAPSHOT',
|
||||
id: snapshotId,
|
||||
turnId: snapshotId,
|
||||
type: NodeType.SNAPSHOT,
|
||||
timestamp: latest.timestamp,
|
||||
text: latest.payload.newText,
|
||||
role: 'user',
|
||||
payload: { text: latest.payload.newText },
|
||||
};
|
||||
|
||||
nodesToSummarize = [previousStateNode, ...targets];
|
||||
@@ -101,6 +102,7 @@ export function createStateSnapshotAsyncProcessor(
|
||||
newText: snapshotText,
|
||||
consumedIds: newConsumedIds,
|
||||
type: processorType,
|
||||
timestamp: targets[targets.length - 1].timestamp,
|
||||
});
|
||||
} catch (e) {
|
||||
debugLogger.error(
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
createDummyNode,
|
||||
createMockProcessArgs,
|
||||
} from '../testing/contextTestUtils.js';
|
||||
import { NodeType } from '../graph/types.js';
|
||||
import type { InboxSnapshotImpl } from '../pipeline/inbox.js';
|
||||
|
||||
describe('StateSnapshotProcessor', () => {
|
||||
@@ -22,7 +23,7 @@ describe('StateSnapshotProcessor', () => {
|
||||
target: 'incremental',
|
||||
},
|
||||
);
|
||||
const targets = [createDummyNode('ep1', 'USER_PROMPT')];
|
||||
const targets = [createDummyNode('ep1', NodeType.USER_PROMPT)];
|
||||
const result = await processor.process(createMockProcessArgs(targets));
|
||||
expect(result).toBe(targets); // Strict equality
|
||||
});
|
||||
@@ -37,9 +38,27 @@ describe('StateSnapshotProcessor', () => {
|
||||
},
|
||||
);
|
||||
|
||||
const nodeA = createDummyNode('ep1', 'USER_PROMPT', 50, {}, 'node-A');
|
||||
const nodeB = createDummyNode('ep1', 'AGENT_THOUGHT', 60, {}, 'node-B');
|
||||
const nodeC = createDummyNode('ep2', 'USER_PROMPT', 50, {}, 'node-C');
|
||||
const nodeA = createDummyNode(
|
||||
'ep1',
|
||||
NodeType.USER_PROMPT,
|
||||
50,
|
||||
{},
|
||||
'node-A',
|
||||
);
|
||||
const nodeB = createDummyNode(
|
||||
'ep1',
|
||||
NodeType.AGENT_THOUGHT,
|
||||
60,
|
||||
{},
|
||||
'node-B',
|
||||
);
|
||||
const nodeC = createDummyNode(
|
||||
'ep2',
|
||||
NodeType.USER_PROMPT,
|
||||
50,
|
||||
{},
|
||||
'node-C',
|
||||
);
|
||||
|
||||
const targets = [nodeA, nodeB, nodeC];
|
||||
|
||||
@@ -62,7 +81,7 @@ describe('StateSnapshotProcessor', () => {
|
||||
|
||||
// Should remove A and B, insert Snapshot, keep C
|
||||
expect(result.length).toBe(2);
|
||||
expect(result[0].type).toBe('SNAPSHOT');
|
||||
expect(result[0].type).toBe(NodeType.SNAPSHOT);
|
||||
expect(result[1].id).toBe('node-C');
|
||||
|
||||
// Should consume the message
|
||||
@@ -83,7 +102,13 @@ describe('StateSnapshotProcessor', () => {
|
||||
// Make deficit 0 so we don't fall through to the sync backstop and fail the test that way
|
||||
|
||||
// node-A is MISSING (user deleted it)
|
||||
const nodeB = createDummyNode('ep1', 'AGENT_THOUGHT', 60, {}, 'node-B');
|
||||
const nodeB = createDummyNode(
|
||||
'ep1',
|
||||
NodeType.AGENT_THOUGHT,
|
||||
60,
|
||||
{},
|
||||
'node-B',
|
||||
);
|
||||
const targets = [nodeB];
|
||||
|
||||
const messages = [
|
||||
@@ -117,15 +142,33 @@ describe('StateSnapshotProcessor', () => {
|
||||
{ target: 'max' },
|
||||
); // Summarize all
|
||||
|
||||
const nodeA = createDummyNode('ep1', 'USER_PROMPT', 50, {}, 'node-A');
|
||||
const nodeB = createDummyNode('ep1', 'AGENT_THOUGHT', 60, {}, 'node-B');
|
||||
const nodeC = createDummyNode('ep2', 'USER_PROMPT', 50, {}, 'node-C');
|
||||
const nodeA = createDummyNode(
|
||||
'ep1',
|
||||
NodeType.USER_PROMPT,
|
||||
50,
|
||||
{},
|
||||
'node-A',
|
||||
);
|
||||
const nodeB = createDummyNode(
|
||||
'ep1',
|
||||
NodeType.AGENT_THOUGHT,
|
||||
60,
|
||||
{},
|
||||
'node-B',
|
||||
);
|
||||
const nodeC = createDummyNode(
|
||||
'ep2',
|
||||
NodeType.USER_PROMPT,
|
||||
50,
|
||||
{},
|
||||
'node-C',
|
||||
);
|
||||
const targets = [nodeA, nodeB, nodeC];
|
||||
const result = await processor.process(createMockProcessArgs(targets));
|
||||
|
||||
// Should synthesize a new snapshot synchronously
|
||||
expect(env.llmClient.generateContent).toHaveBeenCalled();
|
||||
expect(result.length).toBe(2); // nodeA is skipped as "system prompt", snapshot + nodeA
|
||||
expect(result[1].type).toBe('SNAPSHOT');
|
||||
expect(result.length).toBe(1); // nodeA is no longer protected, so everything is snapshotted
|
||||
expect(result[0].type).toBe(NodeType.SNAPSHOT);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -11,7 +11,7 @@ import type {
|
||||
BackstopTargetOptions,
|
||||
} from '../pipeline.js';
|
||||
import type { ContextEnvironment } from '../pipeline/environment.js';
|
||||
import type { ConcreteNode, Snapshot } from '../graph/types.js';
|
||||
import { type ConcreteNode, type Snapshot, NodeType } from '../graph/types.js';
|
||||
import { SnapshotGenerator } from '../utils/snapshotGenerator.js';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
|
||||
@@ -61,6 +61,7 @@ export function createStateSnapshotProcessor(
|
||||
newText: string;
|
||||
consumedIds: string[];
|
||||
type: string;
|
||||
timestamp: number;
|
||||
}>('PROPOSED_SNAPSHOT');
|
||||
|
||||
if (proposedSnapshots.length > 0) {
|
||||
@@ -75,7 +76,7 @@ export function createStateSnapshotProcessor(
|
||||
);
|
||||
|
||||
for (const proposed of sorted) {
|
||||
const { consumedIds, newText } = proposed.payload;
|
||||
const { consumedIds, newText, timestamp } = proposed.payload;
|
||||
|
||||
// Verify all consumed IDs still exist sequentially in targets
|
||||
const targetIds = new Set(targets.map((t) => t.id));
|
||||
@@ -87,10 +88,11 @@ export function createStateSnapshotProcessor(
|
||||
|
||||
const snapshotNode: Snapshot = {
|
||||
id: newId,
|
||||
logicalParentId: newId,
|
||||
type: 'SNAPSHOT',
|
||||
timestamp: Date.now(),
|
||||
text: newText,
|
||||
turnId: newId,
|
||||
type: NodeType.SNAPSHOT,
|
||||
timestamp: timestamp ?? Date.now(),
|
||||
role: 'user',
|
||||
payload: { text: newText },
|
||||
abstractsIds: consumedIds,
|
||||
};
|
||||
|
||||
@@ -131,12 +133,6 @@ export function createStateSnapshotProcessor(
|
||||
|
||||
// Scan oldest to newest
|
||||
for (const node of targets) {
|
||||
if (node.id === targets[0].id && node.type === 'USER_PROMPT') {
|
||||
// Keep system prompt if it's the very first node
|
||||
// In a real system, system prompt is protected, but we double check
|
||||
continue;
|
||||
}
|
||||
|
||||
nodesToSummarize.push(node);
|
||||
deficitAccumulator += env.tokenCalculator.getTokenCost(node);
|
||||
|
||||
@@ -153,10 +149,11 @@ export function createStateSnapshotProcessor(
|
||||
const newId = randomUUID();
|
||||
const snapshotNode: Snapshot = {
|
||||
id: newId,
|
||||
logicalParentId: newId,
|
||||
type: 'SNAPSHOT',
|
||||
timestamp: Date.now(),
|
||||
text: snapshotText,
|
||||
turnId: newId,
|
||||
type: NodeType.SNAPSHOT,
|
||||
timestamp: nodesToSummarize[nodesToSummarize.length - 1].timestamp,
|
||||
role: 'user',
|
||||
payload: { text: snapshotText },
|
||||
abstractsIds: nodesToSummarize.map((n) => n.id),
|
||||
};
|
||||
|
||||
|
||||
@@ -25,9 +25,16 @@ describe('ToolMaskingProcessor', () => {
|
||||
const longString = 'A'.repeat(500); // 500 chars
|
||||
|
||||
const toolStep = createDummyToolNode('ep1', 50, 500, {
|
||||
observation: {
|
||||
result: longString,
|
||||
metadata: 'short', // 5 chars, will not be masked
|
||||
role: 'model',
|
||||
payload: {
|
||||
functionResponse: {
|
||||
name: 'dummy_tool',
|
||||
id: 'dummy_id',
|
||||
response: {
|
||||
result: longString,
|
||||
metadata: 'short', // 5 chars, will not be masked
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
@@ -40,7 +47,10 @@ describe('ToolMaskingProcessor', () => {
|
||||
expect(masked.id).not.toBe(toolStep.id);
|
||||
|
||||
// It should have masked the observation
|
||||
const obs = masked.observation as { result: string; metadata: string };
|
||||
const obs = masked.payload.functionResponse?.response as {
|
||||
result: string;
|
||||
metadata: string;
|
||||
};
|
||||
expect(obs.result).toContain('<tool_output_masked>');
|
||||
expect(obs.metadata).toBe('short'); // Untouched
|
||||
});
|
||||
@@ -53,10 +63,15 @@ describe('ToolMaskingProcessor', () => {
|
||||
});
|
||||
|
||||
const toolStep = createDummyToolNode('ep1', 10, 10, {
|
||||
toolName: 'activate_skill',
|
||||
observation: {
|
||||
result:
|
||||
'this is a really long string that normally would get masked but wont because of the tool name',
|
||||
payload: {
|
||||
functionCall: {
|
||||
name: 'activate_skill',
|
||||
id: 'dummy_id',
|
||||
args: {
|
||||
result:
|
||||
'this is a really long string that normally would get masked but wont because of the tool name',
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
@@ -76,23 +91,49 @@ describe('ToolMaskingProcessor', () => {
|
||||
const longString = 'A'.repeat(500);
|
||||
|
||||
const toolStep = createDummyToolNode('ep1', 50, 500, {
|
||||
intent: originalIntent,
|
||||
observation: {
|
||||
result: longString,
|
||||
payload: {
|
||||
functionCall: {
|
||||
name: 'ls',
|
||||
id: 'call_123',
|
||||
args: originalIntent,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const result = await processor.process(createMockProcessArgs([toolStep]));
|
||||
// We also need a response node if we want to test "observation is masked"
|
||||
// Wait, the test says "strictly preserve the original intent args when only the observation is masked"
|
||||
// But ToolMaskingProcessor processes nodes individually now.
|
||||
// If we have a ToolExecution node with a functionCall, it masks the args.
|
||||
// If we have a ToolExecution node with a functionResponse, it masks the response.
|
||||
|
||||
expect(result.length).toBe(1);
|
||||
const masked = result[0] as ToolExecution;
|
||||
const responseStep = createDummyToolNode('ep1', 50, 500, {
|
||||
payload: {
|
||||
functionResponse: {
|
||||
name: 'ls',
|
||||
id: 'call_123',
|
||||
response: {
|
||||
result: longString,
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
expect(masked.id).not.toBe(toolStep.id);
|
||||
const result = await processor.process(
|
||||
createMockProcessArgs([toolStep, responseStep]),
|
||||
);
|
||||
|
||||
const obs = masked.observation as { result: string };
|
||||
expect(result.length).toBe(2);
|
||||
const maskedCall = result[0] as ToolExecution;
|
||||
const maskedObs = result[1] as ToolExecution;
|
||||
|
||||
// Intent was short, so it should be the same node (or at least same content)
|
||||
expect(maskedCall.payload.functionCall?.args).toEqual(originalIntent);
|
||||
|
||||
// Observation was long, so it should be masked
|
||||
expect(maskedObs.id).not.toBe(responseStep.id);
|
||||
const obs = maskedObs.payload.functionResponse?.response as {
|
||||
result: string;
|
||||
};
|
||||
expect(obs.result).toContain('<tool_output_masked>');
|
||||
|
||||
// The intent MUST be perfectly preserved and not fall back to {} or undefined incorrectly
|
||||
expect(masked.intent).toEqual(originalIntent);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -8,7 +8,7 @@ import type { JSONSchemaType } from 'ajv';
|
||||
import type { ContextProcessor, ProcessArgs } from '../pipeline.js';
|
||||
import * as fs from 'node:fs/promises';
|
||||
import * as path from 'node:path';
|
||||
import type { ConcreteNode, ToolExecution } from '../graph/types.js';
|
||||
import type { ConcreteNode } from '../graph/types.js';
|
||||
import type { ContextEnvironment } from '../pipeline/environment.js';
|
||||
import { sanitizeFilenamePart } from '../../utils/fileUtils.js';
|
||||
import {
|
||||
@@ -18,7 +18,11 @@ import {
|
||||
ENTER_PLAN_MODE_TOOL_NAME,
|
||||
EXIT_PLAN_MODE_TOOL_NAME,
|
||||
} from '../../tools/tool-names.js';
|
||||
import type { Part } from '@google/genai';
|
||||
import {
|
||||
updatePart,
|
||||
cloneFunctionCall,
|
||||
cloneFunctionResponse,
|
||||
} from '../../utils/partUtils.js';
|
||||
|
||||
export interface ToolMaskingProcessorOptions {
|
||||
stringLengthThresholdTokens: number;
|
||||
@@ -138,149 +142,121 @@ export function createToolMaskingProcessor(
|
||||
const returnedNodes: ConcreteNode[] = [];
|
||||
|
||||
for (const node of targets) {
|
||||
switch (node.type) {
|
||||
case 'TOOL_EXECUTION': {
|
||||
const toolName = node.toolName;
|
||||
if (toolName && UNMASKABLE_TOOLS.has(toolName)) {
|
||||
returnedNodes.push(node);
|
||||
break;
|
||||
}
|
||||
|
||||
const callId = node.id || Date.now().toString();
|
||||
|
||||
const maskAsync = async (
|
||||
obj: MaskableValue,
|
||||
nodeType: string,
|
||||
): Promise<{ masked: MaskableValue; changed: boolean }> => {
|
||||
if (typeof obj === 'string') {
|
||||
if (obj.length > limitChars && !isAlreadyMasked(obj)) {
|
||||
const newString = await handleMasking(
|
||||
obj,
|
||||
toolName || 'unknown',
|
||||
callId,
|
||||
nodeType,
|
||||
);
|
||||
return { masked: newString, changed: true };
|
||||
}
|
||||
return { masked: obj, changed: false };
|
||||
}
|
||||
if (Array.isArray(obj)) {
|
||||
let changed = false;
|
||||
const masked: MaskableValue[] = [];
|
||||
for (const item of obj) {
|
||||
const res = await maskAsync(item, nodeType);
|
||||
if (res.changed) changed = true;
|
||||
masked.push(res.masked);
|
||||
}
|
||||
return { masked, changed };
|
||||
}
|
||||
if (typeof obj === 'object' && obj !== null) {
|
||||
let changed = false;
|
||||
const masked: Record<string, MaskableValue> = {};
|
||||
for (const [key, value] of Object.entries(obj)) {
|
||||
const res = await maskAsync(value, nodeType);
|
||||
if (res.changed) changed = true;
|
||||
masked[key] = res.masked;
|
||||
}
|
||||
return { masked, changed };
|
||||
}
|
||||
return { masked: obj, changed: false };
|
||||
};
|
||||
|
||||
const rawIntent = node.intent;
|
||||
const rawObs = node.observation;
|
||||
|
||||
if (!isMaskableRecord(rawIntent) || !isMaskableValue(rawObs)) {
|
||||
returnedNodes.push(node);
|
||||
break;
|
||||
}
|
||||
|
||||
const intentRes = await maskAsync(rawIntent, 'intent');
|
||||
const obsRes = await maskAsync(rawObs, 'observation');
|
||||
|
||||
if (intentRes.changed || obsRes.changed) {
|
||||
const maskedIntent = isMaskableRecord(intentRes.masked)
|
||||
? (intentRes.masked as Record<string, unknown>)
|
||||
: undefined;
|
||||
// Ensure we strictly preserve the original intent if it was unchanged and is a record
|
||||
const finalIntent = intentRes.changed
|
||||
? maskedIntent
|
||||
: isMaskableRecord(rawIntent)
|
||||
? (rawIntent as Record<string, unknown>)
|
||||
: undefined;
|
||||
|
||||
// Handle observation explicitly as string vs object
|
||||
const maskedObs =
|
||||
typeof obsRes.masked === 'string'
|
||||
? ({ message: obsRes.masked } as Record<string, unknown>)
|
||||
: isMaskableRecord(obsRes.masked)
|
||||
? (obsRes.masked as Record<string, unknown>)
|
||||
: undefined;
|
||||
// Ensure we strictly preserve the original observation if it was unchanged
|
||||
const finalObs = obsRes.changed
|
||||
? maskedObs
|
||||
: typeof rawObs === 'string'
|
||||
? ({ message: rawObs } as Record<string, unknown>)
|
||||
: isMaskableRecord(rawObs)
|
||||
? (rawObs as Record<string, unknown>)
|
||||
: undefined;
|
||||
|
||||
const newIntentTokens =
|
||||
env.tokenCalculator.estimateTokensForParts([
|
||||
{
|
||||
functionCall: {
|
||||
name: toolName || 'unknown',
|
||||
args: finalIntent,
|
||||
id: callId,
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
let obsPart: Record<string, unknown> = {};
|
||||
if (maskedObs) {
|
||||
obsPart = {
|
||||
functionResponse: {
|
||||
name: toolName || 'unknown',
|
||||
response: finalObs,
|
||||
id: callId,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const newObsTokens = env.tokenCalculator.estimateTokensForParts([
|
||||
obsPart as Part,
|
||||
]);
|
||||
|
||||
const tokensSaved =
|
||||
env.tokenCalculator.getTokenCost(node) -
|
||||
(newIntentTokens + newObsTokens);
|
||||
|
||||
if (tokensSaved > 0) {
|
||||
const maskedNode: ToolExecution = {
|
||||
...node,
|
||||
id: randomUUID(), // Modified, so generate new ID
|
||||
intent: finalIntent ?? node.intent,
|
||||
observation: finalObs ?? node.observation,
|
||||
tokens: {
|
||||
intent: newIntentTokens,
|
||||
observation: newObsTokens,
|
||||
},
|
||||
replacesId: node.id,
|
||||
};
|
||||
|
||||
returnedNodes.push(maskedNode);
|
||||
} else {
|
||||
returnedNodes.push(node);
|
||||
}
|
||||
} else {
|
||||
returnedNodes.push(node);
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
returnedNodes.push(node);
|
||||
break;
|
||||
if (node.type !== 'TOOL_EXECUTION') {
|
||||
returnedNodes.push(node);
|
||||
continue;
|
||||
}
|
||||
|
||||
const payload = node.payload;
|
||||
const toolName =
|
||||
payload.functionCall?.name || payload.functionResponse?.name;
|
||||
|
||||
if (toolName && UNMASKABLE_TOOLS.has(toolName)) {
|
||||
returnedNodes.push(node);
|
||||
continue;
|
||||
}
|
||||
|
||||
const callId =
|
||||
payload.functionCall?.id || payload.functionResponse?.id || 'unknown';
|
||||
|
||||
const maskAsync = async (
|
||||
obj: MaskableValue,
|
||||
nodeType: string,
|
||||
): Promise<{ masked: MaskableValue; changed: boolean }> => {
|
||||
if (typeof obj === 'string') {
|
||||
if (obj.length > limitChars && !isAlreadyMasked(obj)) {
|
||||
const newString = await handleMasking(
|
||||
obj,
|
||||
toolName || 'unknown',
|
||||
callId,
|
||||
nodeType,
|
||||
);
|
||||
return { masked: newString, changed: true };
|
||||
}
|
||||
return { masked: obj, changed: false };
|
||||
}
|
||||
if (Array.isArray(obj)) {
|
||||
let changed = false;
|
||||
const masked: MaskableValue[] = [];
|
||||
for (const item of obj) {
|
||||
const res = await maskAsync(item, nodeType);
|
||||
if (res.changed) changed = true;
|
||||
masked.push(res.masked);
|
||||
}
|
||||
return { masked, changed };
|
||||
}
|
||||
if (typeof obj === 'object' && obj !== null) {
|
||||
let changed = false;
|
||||
const masked: Record<string, MaskableValue> = {};
|
||||
for (const [key, value] of Object.entries(obj)) {
|
||||
const res = await maskAsync(value, nodeType);
|
||||
if (res.changed) changed = true;
|
||||
masked[key] = res.masked;
|
||||
}
|
||||
return { masked, changed };
|
||||
}
|
||||
return { masked: obj, changed: false };
|
||||
};
|
||||
|
||||
if (payload.functionCall) {
|
||||
const rawIntent = payload.functionCall.args;
|
||||
if (isMaskableRecord(rawIntent)) {
|
||||
const res = await maskAsync(rawIntent, 'intent');
|
||||
if (res.changed) {
|
||||
const newFC = cloneFunctionCall(payload.functionCall);
|
||||
let maskedRecord: Record<string, unknown>;
|
||||
if (isMaskableRecord(res.masked)) {
|
||||
maskedRecord = res.masked;
|
||||
} else {
|
||||
maskedRecord = { message: String(res.masked) };
|
||||
}
|
||||
newFC.args = maskedRecord;
|
||||
|
||||
const maskedPart = updatePart(payload, {
|
||||
functionCall: newFC,
|
||||
});
|
||||
|
||||
returnedNodes.push({
|
||||
...node,
|
||||
id: randomUUID(),
|
||||
payload: maskedPart,
|
||||
replacesId: node.id,
|
||||
turnId: node.turnId,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
}
|
||||
} else if (payload.functionResponse) {
|
||||
const rawObs = payload.functionResponse.response;
|
||||
if (isMaskableValue(rawObs)) {
|
||||
const res = await maskAsync(rawObs, 'observation');
|
||||
if (res.changed) {
|
||||
const newFR = cloneFunctionResponse(payload.functionResponse);
|
||||
let maskedRecord: Record<string, unknown>;
|
||||
if (isMaskableRecord(res.masked)) {
|
||||
maskedRecord = res.masked;
|
||||
} else {
|
||||
maskedRecord = { message: String(res.masked) };
|
||||
}
|
||||
newFR.response = maskedRecord;
|
||||
|
||||
const maskedPart = updatePart(payload, {
|
||||
functionResponse: newFR,
|
||||
});
|
||||
|
||||
returnedNodes.push({
|
||||
...node,
|
||||
id: randomUUID(),
|
||||
payload: maskedPart,
|
||||
replacesId: node.id,
|
||||
turnId: node.turnId,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
returnedNodes.push(node);
|
||||
}
|
||||
|
||||
return returnedNodes;
|
||||
|
||||
@@ -3,6 +3,14 @@
|
||||
exports[`System Lifecycle Golden Tests > Scenario 1: Organic Growth with Huge Tool Output & Images 1`] = `
|
||||
{
|
||||
"finalProjection": [
|
||||
{
|
||||
"parts": [
|
||||
{
|
||||
"text": "[Continuing from previous AI thoughts...]",
|
||||
},
|
||||
],
|
||||
"role": "user",
|
||||
},
|
||||
{
|
||||
"parts": [
|
||||
{
|
||||
@@ -27,31 +35,39 @@ exports[`System Lifecycle Golden Tests > Scenario 1: Organic Growth with Huge To
|
||||
],
|
||||
"role": "model",
|
||||
},
|
||||
{
|
||||
"parts": [
|
||||
{
|
||||
"text": "Please continue.",
|
||||
},
|
||||
],
|
||||
"role": "user",
|
||||
},
|
||||
],
|
||||
"tokenTrajectory": [
|
||||
{
|
||||
"tokensAfterBackground": 6,
|
||||
"tokensBeforeBackground": 6,
|
||||
"tokensAfterBackground": 17,
|
||||
"tokensBeforeBackground": 17,
|
||||
"turnIndex": 0,
|
||||
},
|
||||
{
|
||||
"tokensAfterBackground": 11,
|
||||
"tokensBeforeBackground": 11,
|
||||
"tokensAfterBackground": 34,
|
||||
"tokensBeforeBackground": 34,
|
||||
"turnIndex": 1,
|
||||
},
|
||||
{
|
||||
"tokensAfterBackground": 458,
|
||||
"tokensBeforeBackground": 20170,
|
||||
"tokensAfterBackground": 327,
|
||||
"tokensBeforeBackground": 20172,
|
||||
"turnIndex": 2,
|
||||
},
|
||||
{
|
||||
"tokensAfterBackground": 61,
|
||||
"tokensBeforeBackground": 3017,
|
||||
"tokensAfterBackground": 93,
|
||||
"tokensBeforeBackground": 3037,
|
||||
"turnIndex": 3,
|
||||
},
|
||||
{
|
||||
"tokensAfterBackground": 10,
|
||||
"tokensBeforeBackground": 10,
|
||||
"tokensAfterBackground": 27,
|
||||
"tokensBeforeBackground": 27,
|
||||
"turnIndex": 4,
|
||||
},
|
||||
],
|
||||
@@ -93,16 +109,24 @@ exports[`System Lifecycle Golden Tests > Scenario 2: Under Budget (No Modificati
|
||||
],
|
||||
"role": "model",
|
||||
},
|
||||
{
|
||||
"parts": [
|
||||
{
|
||||
"text": "Please continue.",
|
||||
},
|
||||
],
|
||||
"role": "user",
|
||||
},
|
||||
],
|
||||
"tokenTrajectory": [
|
||||
{
|
||||
"tokensAfterBackground": 6,
|
||||
"tokensBeforeBackground": 6,
|
||||
"tokensAfterBackground": 17,
|
||||
"tokensBeforeBackground": 17,
|
||||
"turnIndex": 0,
|
||||
},
|
||||
{
|
||||
"tokensAfterBackground": 11,
|
||||
"tokensBeforeBackground": 11,
|
||||
"tokensAfterBackground": 34,
|
||||
"tokensBeforeBackground": 34,
|
||||
"turnIndex": 1,
|
||||
},
|
||||
],
|
||||
@@ -160,21 +184,29 @@ exports[`System Lifecycle Golden Tests > Scenario 3: Async-Driven Background GC
|
||||
],
|
||||
"role": "model",
|
||||
},
|
||||
{
|
||||
"parts": [
|
||||
{
|
||||
"text": "Please continue.",
|
||||
},
|
||||
],
|
||||
"role": "user",
|
||||
},
|
||||
],
|
||||
"tokenTrajectory": [
|
||||
{
|
||||
"tokensAfterBackground": 25,
|
||||
"tokensBeforeBackground": 25,
|
||||
"tokensAfterBackground": 42,
|
||||
"tokensBeforeBackground": 42,
|
||||
"turnIndex": 0,
|
||||
},
|
||||
{
|
||||
"tokensAfterBackground": 49,
|
||||
"tokensBeforeBackground": 49,
|
||||
"tokensAfterBackground": 84,
|
||||
"tokensBeforeBackground": 84,
|
||||
"turnIndex": 1,
|
||||
},
|
||||
{
|
||||
"tokensAfterBackground": 73,
|
||||
"tokensBeforeBackground": 73,
|
||||
"tokensAfterBackground": 126,
|
||||
"tokensBeforeBackground": 126,
|
||||
"turnIndex": 2,
|
||||
},
|
||||
],
|
||||
|
||||
@@ -18,14 +18,24 @@ import { createStateSnapshotAsyncProcessor } from '../processors/stateSnapshotAs
|
||||
expect.addSnapshotSerializer({
|
||||
test: (val) =>
|
||||
typeof val === 'string' &&
|
||||
(/^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i.test(
|
||||
(/[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}/i.test(
|
||||
val,
|
||||
) ||
|
||||
/^\/tmp\/sim/.test(val)), // Mask temp directories and UUIDs
|
||||
print: (val) =>
|
||||
typeof val === 'string' && /^\/tmp\/sim/.test(val)
|
||||
? '"<MOCKED_DIR>"'
|
||||
: '"<UUID>"',
|
||||
/[\\/]tmp[\\/]sim/.test(val)),
|
||||
print: (val) => {
|
||||
if (typeof val !== 'string') return `"${val}"`;
|
||||
let scrubbed = val
|
||||
.replace(
|
||||
/[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}/gi,
|
||||
'<UUID>',
|
||||
)
|
||||
.replace(/[\\/]tmp[\\/]sim[^\s"'\]]*/g, '<MOCKED_DIR>');
|
||||
|
||||
// Also scrub timestamps in filenames like blob_1234567890_...
|
||||
scrubbed = scrubbed.replace(/blob_\d+_/g, 'blob_<TIMESTAMP>_');
|
||||
|
||||
return `"${scrubbed}"`;
|
||||
},
|
||||
});
|
||||
|
||||
describe('System Lifecycle Golden Tests', () => {
|
||||
@@ -43,6 +53,7 @@ describe('System Lifecycle Golden Tests', () => {
|
||||
});
|
||||
|
||||
const getAggressiveConfig = (): ContextProfile => ({
|
||||
name: 'Aggressive Test',
|
||||
config: {
|
||||
budget: { maxTokens: 1000, retainedTokens: 500 }, // Extremely tight limits
|
||||
},
|
||||
@@ -170,6 +181,7 @@ describe('System Lifecycle Golden Tests', () => {
|
||||
|
||||
it('Scenario 2: Under Budget (No Modifications)', async () => {
|
||||
const generousConfig: ContextProfile = {
|
||||
name: 'Generous Config',
|
||||
config: {
|
||||
budget: { maxTokens: 100000, retainedTokens: 50000 },
|
||||
},
|
||||
@@ -202,6 +214,7 @@ describe('System Lifecycle Golden Tests', () => {
|
||||
|
||||
it('Scenario 3: Async-Driven Background GC', async () => {
|
||||
const gcConfig: ContextProfile = {
|
||||
name: 'GC Test Config',
|
||||
config: {
|
||||
budget: { maxTokens: 200, retainedTokens: 100 },
|
||||
},
|
||||
|
||||
@@ -148,7 +148,8 @@ export class SimulationHarness {
|
||||
}
|
||||
|
||||
async getGoldenState() {
|
||||
const finalProjection = await this.contextManager.renderHistory();
|
||||
const { history: finalProjection } =
|
||||
await this.contextManager.renderHistory();
|
||||
return {
|
||||
tokenTrajectory: this.tokenTrajectory,
|
||||
finalProjection,
|
||||
|
||||
@@ -12,7 +12,11 @@ import { ContextTracer } from '../tracer.js';
|
||||
import { ContextEnvironmentImpl } from '../pipeline/environmentImpl.js';
|
||||
import { ContextEventBus } from '../eventBus.js';
|
||||
import { PipelineOrchestrator } from '../pipeline/orchestrator.js';
|
||||
import type { ConcreteNode, ToolExecution } from '../graph/types.js';
|
||||
import {
|
||||
type ConcreteNode,
|
||||
type ToolExecution,
|
||||
NodeType,
|
||||
} from '../graph/types.js';
|
||||
import type { ContextEnvironment } from '../pipeline/environment.js';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
@@ -37,57 +41,56 @@ export const createMockGenerateContentResponse = (
|
||||
}) as GenerateContentResponse;
|
||||
|
||||
export function createDummyNode(
|
||||
logicalParentId: string,
|
||||
type: ConcreteNode['type'],
|
||||
tokens = 100,
|
||||
turnId: string,
|
||||
type: NodeType,
|
||||
_tokens = 100,
|
||||
overrides?: Partial<ConcreteNode>,
|
||||
id?: string,
|
||||
): ConcreteNode {
|
||||
const role =
|
||||
type === NodeType.USER_PROMPT ||
|
||||
type === NodeType.SYSTEM_EVENT ||
|
||||
type === NodeType.SNAPSHOT ||
|
||||
type === NodeType.ROLLING_SUMMARY
|
||||
? 'user'
|
||||
: 'model';
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
return {
|
||||
id: id || randomUUID(),
|
||||
episodeId: logicalParentId,
|
||||
logicalParentId,
|
||||
turnId,
|
||||
type,
|
||||
timestamp: Date.now(),
|
||||
text: `Dummy ${type}`,
|
||||
name: type === 'SYSTEM_EVENT' ? 'dummy_event' : undefined,
|
||||
payload: type === 'SYSTEM_EVENT' ? {} : undefined,
|
||||
semanticParts: [],
|
||||
metadata: {
|
||||
originalTokens: tokens,
|
||||
currentTokens: tokens,
|
||||
transformations: [],
|
||||
},
|
||||
role,
|
||||
payload: { text: `Dummy ${type}` },
|
||||
...overrides,
|
||||
} as unknown as ConcreteNode;
|
||||
}
|
||||
|
||||
export function createDummyToolNode(
|
||||
logicalParentId: string,
|
||||
intentTokens = 100,
|
||||
obsTokens = 200,
|
||||
turnId: string,
|
||||
_intentTokens = 100,
|
||||
_obsTokens = 200,
|
||||
overrides?: Partial<ToolExecution>,
|
||||
id?: string,
|
||||
): ToolExecution {
|
||||
// We don't distinguish between call and response here, but ToolExecution nodes in 1:1 map to ONE part.
|
||||
// Tests using this usually want to simulate a tool interaction.
|
||||
// For simplicity, we'll make this a 'model' tool call by default.
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
return {
|
||||
id: id || randomUUID(),
|
||||
episodeId: logicalParentId,
|
||||
logicalParentId,
|
||||
type: 'TOOL_EXECUTION',
|
||||
turnId,
|
||||
type: NodeType.TOOL_EXECUTION,
|
||||
timestamp: Date.now(),
|
||||
toolName: 'dummy_tool',
|
||||
intent: { action: 'test' },
|
||||
observation: { result: 'ok' },
|
||||
tokens: {
|
||||
intent: intentTokens,
|
||||
observation: obsTokens,
|
||||
},
|
||||
metadata: {
|
||||
originalTokens: intentTokens + obsTokens,
|
||||
currentTokens: intentTokens + obsTokens,
|
||||
transformations: [],
|
||||
role: 'model',
|
||||
payload: {
|
||||
functionCall: {
|
||||
name: 'dummy_tool',
|
||||
args: { action: 'test' },
|
||||
id: id || 'dummy_id',
|
||||
},
|
||||
},
|
||||
...overrides,
|
||||
} as unknown as ToolExecution;
|
||||
|
||||
@@ -9,6 +9,7 @@ import type { ContextEnvironment } from '../pipeline/environment.js';
|
||||
import { createHistoryTruncationProcessor } from '../processors/historyTruncationProcessor.js';
|
||||
|
||||
export const testTruncateProfile: ContextProfile = {
|
||||
name: 'Test Truncate',
|
||||
config: {
|
||||
budget: {
|
||||
retainedTokens: 65000,
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { ContextTokenCalculator } from './contextTokenCalculator.js';
|
||||
import { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
|
||||
import { registerBuiltInBehaviors } from '../graph/builtinBehaviors.js';
|
||||
import { createDummyNode } from '../testing/contextTestUtils.js';
|
||||
import { MSG_OVERHEAD_TOKENS } from '../../utils/tokenCalculation.js';
|
||||
import { NodeType } from '../graph/types.js';
|
||||
|
||||
describe('ContextTokenCalculator', () => {
|
||||
const registry = new NodeBehaviorRegistry();
|
||||
registerBuiltInBehaviors(registry);
|
||||
const charsPerToken = 1; // Simplifies math for text nodes in tests
|
||||
const calculator = new ContextTokenCalculator(charsPerToken, registry);
|
||||
|
||||
it('should include structural overhead for each unique turn', () => {
|
||||
const turn1Id = 'turn-1';
|
||||
const turn2Id = 'turn-2';
|
||||
|
||||
const node1 = createDummyNode(turn1Id, NodeType.USER_PROMPT);
|
||||
const node2 = createDummyNode(turn1Id, NodeType.USER_PROMPT); // Same turn
|
||||
const node3 = createDummyNode(turn2Id, NodeType.AGENT_THOUGHT); // Different turn
|
||||
|
||||
const nodes = [node1, node2, node3];
|
||||
|
||||
// Estimated tokens (using 0.33 per ASCII char heuristic):
|
||||
// node1: floor(17 chars * 0.33) = 5 tokens
|
||||
// node2: floor(17 chars * 0.33) = 5 tokens
|
||||
// node3: floor(19 chars * 0.33) = 6 tokens
|
||||
// Turn 1 overhead: 5 tokens
|
||||
// Turn 2 overhead: 5 tokens
|
||||
// Total: 5 + 5 + 6 + 5 + 5 = 26
|
||||
|
||||
const total = calculator.calculateConcreteListTokens(nodes);
|
||||
expect(total).toBe(26);
|
||||
});
|
||||
|
||||
it('should handle categorical breakdown with overhead', () => {
|
||||
const turn1Id = 'turn-1';
|
||||
const node = createDummyNode(turn1Id, NodeType.USER_PROMPT);
|
||||
|
||||
const breakdown = calculator.calculateTokenBreakdown([node]);
|
||||
|
||||
expect(breakdown.overhead).toBe(MSG_OVERHEAD_TOKENS);
|
||||
expect(breakdown.total).toBe(
|
||||
calculator.getTokenCost(node) + MSG_OVERHEAD_TOKENS,
|
||||
);
|
||||
});
|
||||
|
||||
it('should not double-count overhead for duplicate turn IDs in separate nodes', () => {
|
||||
const turn1Id = 'turn-1';
|
||||
const node1 = createDummyNode(turn1Id, NodeType.USER_PROMPT);
|
||||
const node2 = createDummyNode(turn1Id, NodeType.USER_PROMPT);
|
||||
|
||||
const total = calculator.calculateConcreteListTokens([node1, node2]);
|
||||
|
||||
// cost(node1) + cost(node2) + 1 * overhead
|
||||
const expected =
|
||||
calculator.getTokenCost(node1) +
|
||||
calculator.getTokenCost(node2) +
|
||||
MSG_OVERHEAD_TOKENS;
|
||||
expect(total).toBe(expected);
|
||||
});
|
||||
});
|
||||
@@ -4,8 +4,11 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Part } from '@google/genai';
|
||||
import { estimateTokenCountSync } from '../../utils/tokenCalculation.js';
|
||||
import type { Part, Content } from '@google/genai';
|
||||
import {
|
||||
estimateTokenCountSync,
|
||||
MSG_OVERHEAD_TOKENS,
|
||||
} from '../../utils/tokenCalculation.js';
|
||||
import type { ConcreteNode } from '../graph/types.js';
|
||||
import type { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
|
||||
|
||||
@@ -73,18 +76,107 @@ export class ContextTokenCalculator {
|
||||
return this.cacheNodeTokens(node);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates a detailed breakdown of tokens by category for a list of nodes.
|
||||
* Useful for calibration tracing and debugging overestimation.
|
||||
*/
|
||||
calculateTokenBreakdown(nodes: readonly ConcreteNode[]): {
|
||||
total: number;
|
||||
text: number;
|
||||
media: number;
|
||||
tool: number;
|
||||
overhead: number;
|
||||
} {
|
||||
const breakdown = { total: 0, text: 0, media: 0, tool: 0, overhead: 0 };
|
||||
const seenIds = new Set<string>();
|
||||
const seenTurnIds = new Set<string>();
|
||||
|
||||
for (const node of nodes) {
|
||||
if (seenIds.has(node.id)) continue;
|
||||
seenIds.add(node.id);
|
||||
|
||||
if (node.turnId) {
|
||||
if (!seenTurnIds.has(node.turnId)) {
|
||||
seenTurnIds.add(node.turnId);
|
||||
breakdown.overhead += MSG_OVERHEAD_TOKENS;
|
||||
breakdown.total += MSG_OVERHEAD_TOKENS;
|
||||
}
|
||||
}
|
||||
|
||||
const cost = this.getTokenCost(node);
|
||||
breakdown.total += cost;
|
||||
|
||||
const behavior = this.registry.get(node.type);
|
||||
const parts = behavior.getEstimatableParts(node);
|
||||
|
||||
for (const part of parts) {
|
||||
if (typeof part.text === 'string') {
|
||||
breakdown.text += estimateTokenCountSync(
|
||||
[part],
|
||||
0,
|
||||
this.charsPerToken,
|
||||
);
|
||||
} else if (
|
||||
part.inlineData?.mimeType?.startsWith('image/') ||
|
||||
part.fileData?.mimeType?.startsWith('image/')
|
||||
) {
|
||||
breakdown.media += estimateTokenCountSync(
|
||||
[part],
|
||||
0,
|
||||
this.charsPerToken,
|
||||
);
|
||||
} else if (part.functionCall || part.functionResponse) {
|
||||
breakdown.tool += estimateTokenCountSync(
|
||||
[part],
|
||||
0,
|
||||
this.charsPerToken,
|
||||
);
|
||||
} else {
|
||||
breakdown.overhead += estimateTokenCountSync(
|
||||
[part],
|
||||
0,
|
||||
this.charsPerToken,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
return breakdown;
|
||||
}
|
||||
|
||||
/**
|
||||
* Fast calculation for a flat array of ConcreteNodes (The Nodes).
|
||||
* It relies entirely on the O(1) sidecar token cache.
|
||||
*/
|
||||
calculateConcreteListTokens(nodes: readonly ConcreteNode[]): number {
|
||||
let tokens = 0;
|
||||
const seenIds = new Set<string>();
|
||||
const seenTurnIds = new Set<string>();
|
||||
|
||||
for (const node of nodes) {
|
||||
tokens += this.getTokenCost(node);
|
||||
if (!seenIds.has(node.id)) {
|
||||
seenIds.add(node.id);
|
||||
tokens += this.getTokenCost(node);
|
||||
|
||||
if (node.turnId) {
|
||||
if (!seenTurnIds.has(node.turnId)) {
|
||||
seenTurnIds.add(node.turnId);
|
||||
tokens += MSG_OVERHEAD_TOKENS;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return tokens;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates the token cost for a single Gemini Content object.
|
||||
*/
|
||||
calculateContentTokens(content: Content): number {
|
||||
return (
|
||||
this.estimateTokensForParts(content.parts || []) + MSG_OVERHEAD_TOKENS
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Slower, precise estimation for a Gemini Content/Part graph.
|
||||
* Deeply inspects the nested structure and uses the base tokenization math.
|
||||
|
||||
51
packages/core/src/context/utils/invariantChecker.ts
Normal file
51
packages/core/src/context/utils/invariantChecker.ts
Normal file
@@ -0,0 +1,51 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { ConcreteNode } from '../graph/types.js';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
|
||||
/**
|
||||
* Validates structural and logical invariants of the Episodic Context Graph.
|
||||
* Primarily used in debug mode to identify "smelly" states before they reach the LLM.
|
||||
*/
|
||||
export function checkContextInvariants(
|
||||
nodes: readonly ConcreteNode[],
|
||||
context: string,
|
||||
): void {
|
||||
const seenIds = new Set<string>();
|
||||
const duplicates = new Set<string>();
|
||||
|
||||
for (const node of nodes) {
|
||||
if (seenIds.has(node.id)) {
|
||||
duplicates.add(node.id);
|
||||
}
|
||||
seenIds.add(node.id);
|
||||
}
|
||||
|
||||
if (duplicates.size > 0) {
|
||||
debugLogger.warn(
|
||||
`[InvariantCheck][${context}] Detected ${duplicates.size} duplicate nodes by ID: ${Array.from(duplicates).join(', ')}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Check for orphan logic (nodes without turn association)
|
||||
const orphans = nodes.filter((n) => !n.turnId);
|
||||
if (orphans.length > 0) {
|
||||
debugLogger.warn(
|
||||
`[InvariantCheck][${context}] Detected ${orphans.length} nodes without turnId.`,
|
||||
);
|
||||
}
|
||||
|
||||
// Check for timestamp linearity
|
||||
for (let i = 1; i < nodes.length; i++) {
|
||||
if (nodes[i].timestamp < nodes[i - 1].timestamp) {
|
||||
debugLogger.warn(
|
||||
`[InvariantCheck][${context}] Non-linear timestamps detected at index ${i}.`,
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -23,16 +23,14 @@ Output ONLY the raw factual snapshot, formatted compactly. Do not include markdo
|
||||
|
||||
let userPromptText = 'TRANSCRIPT TO SNAPSHOT:\n\n';
|
||||
for (const node of nodes) {
|
||||
const payload = node.payload;
|
||||
let nodeContent = '';
|
||||
if ('text' in node && typeof node.text === 'string') {
|
||||
nodeContent = node.text;
|
||||
} else if ('semanticParts' in node) {
|
||||
nodeContent = JSON.stringify(node.semanticParts);
|
||||
} else if ('observation' in node) {
|
||||
nodeContent =
|
||||
typeof node.observation === 'string'
|
||||
? node.observation
|
||||
: JSON.stringify(node.observation);
|
||||
if (payload.text) {
|
||||
nodeContent = payload.text;
|
||||
} else if (payload.functionCall) {
|
||||
nodeContent = `CALL: ${payload.functionCall.name}(${JSON.stringify(payload.functionCall.args)})`;
|
||||
} else if (payload.functionResponse) {
|
||||
nodeContent = `RESPONSE: ${JSON.stringify(payload.functionResponse.response)}`;
|
||||
}
|
||||
|
||||
userPromptText += `[${node.type}]: ${nodeContent}\n`;
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
import type { Content } from '@google/genai';
|
||||
|
||||
export type HistoryEventType = 'PUSH' | 'SYNC_FULL' | 'CLEAR';
|
||||
export type HistoryEventType = 'PUSH' | 'SYNC_FULL' | 'CLEAR' | 'SILENT_SYNC';
|
||||
|
||||
export interface HistoryEvent {
|
||||
type: HistoryEventType;
|
||||
@@ -42,9 +42,9 @@ export class AgentChatHistory {
|
||||
this.notify('PUSH', [content]);
|
||||
}
|
||||
|
||||
set(history: readonly Content[]) {
|
||||
set(history: readonly Content[], options: { silent?: boolean } = {}) {
|
||||
this.history = [...history];
|
||||
this.notify('SYNC_FULL', this.history);
|
||||
this.notify(options.silent ? 'SILENT_SYNC' : 'SYNC_FULL', this.history);
|
||||
}
|
||||
|
||||
clear() {
|
||||
|
||||
@@ -1517,7 +1517,7 @@ ${JSON.stringify(
|
||||
const longText = 'a'.repeat(404);
|
||||
const request: Part[] = [{ text: longText }];
|
||||
// estimateTextOnlyLength counts only text content (400 chars), not JSON structure
|
||||
const estimatedRequestTokenCount = Math.floor(longText.length / 4);
|
||||
const estimatedRequestTokenCount = Math.floor(longText.length * 0.33);
|
||||
const remainingTokenCount = MOCKED_TOKEN_LIMIT - lastPromptTokenCount;
|
||||
|
||||
// Mock tryCompressChat to not compress
|
||||
@@ -1577,7 +1577,7 @@ ${JSON.stringify(
|
||||
const longText = 'a'.repeat(404);
|
||||
const request: Part[] = [{ text: longText }];
|
||||
// estimateTextOnlyLength counts only text content (400 chars), not JSON structure
|
||||
const estimatedRequestTokenCount = Math.floor(longText.length / 4);
|
||||
const estimatedRequestTokenCount = Math.floor(longText.length * 0.33);
|
||||
const remainingTokenCount = STICKY_MODEL_LIMIT - lastPromptTokenCount;
|
||||
|
||||
vi.spyOn(client, 'tryCompressChat').mockResolvedValue({
|
||||
|
||||
@@ -369,7 +369,9 @@ export class GeminiClient {
|
||||
const toolDeclarations = toolRegistry.getFunctionDeclarations();
|
||||
const tools: Tool[] = [{ functionDeclarations: toolDeclarations }];
|
||||
|
||||
const history = await getInitialChatHistory(this.config, extraHistory);
|
||||
const history = this.config.getContextManagementConfig().enabled
|
||||
? (extraHistory ?? [])
|
||||
: await getInitialChatHistory(this.config, extraHistory);
|
||||
|
||||
try {
|
||||
const systemMemory = this.config.getSystemInstructionMemory();
|
||||
@@ -618,14 +620,25 @@ export class GeminiClient {
|
||||
const modelForLimitCheck = this._getActiveModelForCurrentTurn();
|
||||
|
||||
if (this.config.getContextManagementConfig().enabled) {
|
||||
const newHistory = this.contextManager
|
||||
? await this.contextManager.renderHistory()
|
||||
: await this.agentHistoryProvider.manageHistory(
|
||||
this.getHistory(),
|
||||
signal,
|
||||
);
|
||||
if (newHistory.length !== this.getHistory().length) {
|
||||
this.getChat().setHistory(newHistory);
|
||||
if (this.contextManager) {
|
||||
const pendingRequest = createUserContent(request);
|
||||
const { history: newHistory, didApplyManagement } =
|
||||
await this.contextManager.renderHistory(pendingRequest);
|
||||
|
||||
if (didApplyManagement) {
|
||||
// If the manager pruned history, we update the chat before continuing.
|
||||
// Note: we don't include the pendingRequest in this setHistory,
|
||||
// because Turn.run will add it normally.
|
||||
this.getChat().setHistory(newHistory, { silent: true });
|
||||
}
|
||||
} else {
|
||||
const newHistory = await this.agentHistoryProvider.manageHistory(
|
||||
this.getHistory(),
|
||||
signal,
|
||||
);
|
||||
if (newHistory.length !== this.getHistory().length) {
|
||||
this.getChat().setHistory(newHistory);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const compressed = await this.tryCompressChat(prompt_id, false, signal);
|
||||
|
||||
@@ -240,7 +240,7 @@ describe('GeminiChat', () => {
|
||||
// 'Hello': 5 chars * 0.25 = 1.25
|
||||
// 'Hi there': 8 chars * 0.25 = 2.0
|
||||
// Total: 3.25 -> floor(3.25) = 3
|
||||
expect(chatWithHistory.getLastPromptTokenCount()).toBe(3);
|
||||
expect(chatWithHistory.getLastPromptTokenCount()).toBe(4);
|
||||
});
|
||||
|
||||
it('should initialize lastPromptTokenCount for empty history', () => {
|
||||
|
||||
@@ -48,6 +48,7 @@ import {
|
||||
} from '../telemetry/types.js';
|
||||
import { handleFallback } from '../fallback/handler.js';
|
||||
import { isFunctionResponse } from '../utils/messageInspectors.js';
|
||||
import { scrubHistory } from '../utils/historyHardening.js';
|
||||
import { partListUnionToString } from './geminiRequest.js';
|
||||
import type { ModelConfigKey } from '../services/modelConfigService.js';
|
||||
import { estimateTokenCountSync } from '../utils/tokenCalculation.js';
|
||||
@@ -57,6 +58,7 @@ import {
|
||||
} from '../availability/policyHelpers.js';
|
||||
import { coreEvents } from '../utils/events.js';
|
||||
import type { AgentLoopContext } from '../config/agent-loop-context.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
export enum StreamEventType {
|
||||
/** A regular content chunk from the API. */
|
||||
@@ -96,6 +98,18 @@ const MID_STREAM_RETRY_OPTIONS: MidStreamRetryOptions = {
|
||||
|
||||
export const SYNTHETIC_THOUGHT_SIGNATURE = 'skip_thought_signature_validator';
|
||||
|
||||
/**
|
||||
* Internal interface for parts that carry the magic 'callIndex' property
|
||||
* used during model response consolidation.
|
||||
*/
|
||||
interface IndexedPart extends Part {
|
||||
callIndex?: number;
|
||||
}
|
||||
|
||||
function isIndexedPart(part: Part): part is IndexedPart {
|
||||
return 'callIndex' in part;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if the response is valid, false otherwise.
|
||||
*/
|
||||
@@ -250,10 +264,11 @@ export class GeminiChat {
|
||||
private sendPromise: Promise<void> = Promise.resolve();
|
||||
private readonly chatRecordingService: ChatRecordingService;
|
||||
private lastPromptTokenCount: number;
|
||||
private callCounter = 0;
|
||||
agentHistory: AgentChatHistory;
|
||||
|
||||
constructor(
|
||||
private readonly context: AgentLoopContext,
|
||||
readonly context: AgentLoopContext,
|
||||
private systemInstruction: string = '',
|
||||
private tools: Tool[] = [],
|
||||
history: Content[] = [],
|
||||
@@ -502,8 +517,14 @@ export class GeminiChat {
|
||||
abortSignal: AbortSignal,
|
||||
role: LlmRole,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
// Last mile scrubbing to remove internal tracking properties (e.g. callIndex)
|
||||
// before sending to the Gemini API. This whitelists only standard Gemini fields.
|
||||
const scrubbedContents = this.context.config.isContextManagementEnabled()
|
||||
? scrubHistory([...requestContents])
|
||||
: [...requestContents];
|
||||
|
||||
const contentsForPreviewModel =
|
||||
this.ensureActiveLoopHasThoughtSignatures(requestContents);
|
||||
this.ensureActiveLoopHasThoughtSignatures(scrubbedContents);
|
||||
|
||||
// Track final request parameters for AfterModel hooks
|
||||
const {
|
||||
@@ -772,8 +793,11 @@ export class GeminiChat {
|
||||
this.agentHistory.push(content);
|
||||
}
|
||||
|
||||
setHistory(history: readonly Content[]): void {
|
||||
this.agentHistory.set(history);
|
||||
setHistory(
|
||||
history: readonly Content[],
|
||||
options: { silent?: boolean } = {},
|
||||
): void {
|
||||
this.agentHistory.set(history, options);
|
||||
this.lastPromptTokenCount = estimateTokenCountSync(
|
||||
this.agentHistory.flatMap((c) => c.parts || []),
|
||||
);
|
||||
@@ -892,7 +916,12 @@ export class GeminiChat {
|
||||
let finishReason: FinishReason | undefined;
|
||||
|
||||
// The SDK provides fully assembled FunctionCall objects in chunk.functionCalls
|
||||
const finalFunctionCalls: FunctionCall[] = [];
|
||||
// We use a Map to ensure we only keep the latest version of each call (by ID)
|
||||
const finalFunctionCallsMap = new Map<string, FunctionCall>();
|
||||
const legacyFunctionCalls: FunctionCall[] = [];
|
||||
|
||||
// Map to track synthetic IDs assigned to each call index across chunks
|
||||
const callIndexToId = new Map<number, string>();
|
||||
|
||||
for await (const chunk of streamResponse) {
|
||||
const candidateWithReason = chunk?.candidates?.find(
|
||||
@@ -904,9 +933,26 @@ export class GeminiChat {
|
||||
}
|
||||
|
||||
if (chunk.functionCalls && chunk.functionCalls.length > 0) {
|
||||
finalFunctionCalls.push(...chunk.functionCalls);
|
||||
if (this.context.config.isContextManagementEnabled()) {
|
||||
for (let i = 0; i < chunk.functionCalls.length; i++) {
|
||||
const fnCall = chunk.functionCalls[i];
|
||||
if (!fnCall.id) {
|
||||
let id = callIndexToId.get(i);
|
||||
if (!id) {
|
||||
id = `synth_${this.context.promptId}_${Date.now()}_${this.callCounter++}`;
|
||||
callIndexToId.set(i, id);
|
||||
debugLogger.log(
|
||||
`[GeminiChat] Assigned synthetic ID: ${id} to tool at index ${i}: ${fnCall.name}`,
|
||||
);
|
||||
}
|
||||
fnCall.id = id;
|
||||
}
|
||||
finalFunctionCallsMap.set(fnCall.id, fnCall);
|
||||
}
|
||||
} else {
|
||||
legacyFunctionCalls.push(...chunk.functionCalls);
|
||||
}
|
||||
}
|
||||
|
||||
if (isValidResponse(chunk)) {
|
||||
const content = chunk.candidates?.[0]?.content;
|
||||
if (content?.parts) {
|
||||
@@ -920,7 +966,19 @@ export class GeminiChat {
|
||||
}
|
||||
|
||||
modelResponseParts.push(
|
||||
...content.parts.filter((part) => !part.thought),
|
||||
...content.parts
|
||||
.filter((part) => !part.thought)
|
||||
.map((part) => {
|
||||
if (!this.context.config.isContextManagementEnabled()) {
|
||||
return part;
|
||||
}
|
||||
return {
|
||||
...part,
|
||||
callIndex: chunk.functionCalls?.findIndex(
|
||||
(fc) => fc.name === part.functionCall?.name,
|
||||
),
|
||||
};
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -961,27 +1019,23 @@ export class GeminiChat {
|
||||
|
||||
// String thoughts and consolidate text parts.
|
||||
const consolidatedParts: Part[] = [];
|
||||
const finalFunctionCalls = this.context.config.isContextManagementEnabled()
|
||||
? Array.from(finalFunctionCallsMap.values())
|
||||
: legacyFunctionCalls;
|
||||
|
||||
let currentCallSourceIndex = -1;
|
||||
if (this.context.config.isContextManagementEnabled()) {
|
||||
debugLogger.log(
|
||||
`[GeminiChat] Starting consolidation for ${modelResponseParts.length} raw parts and ${finalFunctionCalls.length} assembled function calls.`,
|
||||
);
|
||||
for (const part of modelResponseParts) {
|
||||
if (part.functionCall) {
|
||||
// Skip partial functionCall stream chunks! We will replace them
|
||||
// entirely with the pristine, fully assembled objects from the SDK
|
||||
// (finalFunctionCalls) immediately below. We only push the very first
|
||||
// partial chunk of a sequence as a placeholder so we know *where*
|
||||
// in the sequence of parts the tool call happened.
|
||||
const lastPart = consolidatedParts[consolidatedParts.length - 1];
|
||||
const currentId = part.functionCall.id;
|
||||
const lastId = lastPart?.functionCall?.id;
|
||||
|
||||
const partIndex = isIndexedPart(part) ? part.callIndex : undefined;
|
||||
const isNewCall =
|
||||
!lastPart?.functionCall ||
|
||||
(currentId !== undefined &&
|
||||
lastId !== undefined &&
|
||||
currentId !== lastId) ||
|
||||
lastPart.functionCall.name !== part.functionCall.name;
|
||||
partIndex !== undefined && partIndex > currentCallSourceIndex;
|
||||
|
||||
if (isNewCall) {
|
||||
currentCallSourceIndex = partIndex;
|
||||
consolidatedParts.push({ ...part }); // Push placeholder
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -48,6 +48,7 @@ describe('Turn', () => {
|
||||
sendMessageStream: typeof mockSendMessageStream;
|
||||
getHistory: typeof mockGetHistory;
|
||||
maybeIncludeSchemaDepthContext: typeof mockMaybeIncludeSchemaDepthContext;
|
||||
context: { config: { isContextManagementEnabled: () => boolean } };
|
||||
};
|
||||
let mockChatInstance: MockedChatInstance;
|
||||
|
||||
@@ -57,6 +58,11 @@ describe('Turn', () => {
|
||||
sendMessageStream: mockSendMessageStream,
|
||||
getHistory: mockGetHistory,
|
||||
maybeIncludeSchemaDepthContext: mockMaybeIncludeSchemaDepthContext,
|
||||
context: {
|
||||
config: {
|
||||
isContextManagementEnabled: () => false,
|
||||
},
|
||||
},
|
||||
};
|
||||
turn = new Turn(mockChatInstance as unknown as GeminiChat, 'prompt-id-1');
|
||||
mockGetHistory.mockReturnValue([]);
|
||||
|
||||
@@ -409,7 +409,11 @@ export class Turn {
|
||||
): ServerGeminiStreamEvent | null {
|
||||
const name = fnCall.name || 'undefined_tool_name';
|
||||
const args = fnCall.args || {};
|
||||
const callId = fnCall.id ?? `${name}_${Date.now()}_${this.callCounter++}`;
|
||||
const callId =
|
||||
fnCall.id ??
|
||||
(this.chat.context.config.isContextManagementEnabled()
|
||||
? `synth_${this.prompt_id}_${Date.now()}_${this.callCounter++}`
|
||||
: `${name}_${Date.now()}_${this.callCounter++}`);
|
||||
|
||||
const toolCallRequest: ToolCallRequestInfo = {
|
||||
callId,
|
||||
|
||||
355
packages/core/src/utils/historyHardening.ts
Normal file
355
packages/core/src/utils/historyHardening.ts
Normal file
@@ -0,0 +1,355 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Content, Part } from '@google/genai';
|
||||
import { debugLogger } from './debugLogger.js';
|
||||
|
||||
export const SYNTHETIC_THOUGHT_SIGNATURE = 'skip_thought_signature_validator';
|
||||
|
||||
export interface HardeningOptions {
|
||||
sentinels?: {
|
||||
continuation?: string;
|
||||
lostToolResponse?: string;
|
||||
};
|
||||
}
|
||||
|
||||
const DEFAULT_SENTINELS = {
|
||||
continuation: '[Continuing from previous AI thoughts...]',
|
||||
lostToolResponse:
|
||||
'The tool execution result was lost due to context management truncation.',
|
||||
};
|
||||
|
||||
/**
|
||||
* Hardens a chat history to ensure it strictly adheres to Gemini API invariants.
|
||||
* This is a defensive post-processing pass that patches violations using
|
||||
* sentinel messages rather than failing.
|
||||
*
|
||||
* Invariants enforced:
|
||||
* 1. Role Alternation: user -> model -> user -> model
|
||||
* 2. Start Constraint: Must start with a 'user' turn.
|
||||
* 3. End Constraint: Must end with a 'user' turn (usually for follow-up prompts).
|
||||
* 4. Tool Pairing: Every model functionCall must be followed by a user functionResponse.
|
||||
* 5. Signatures: The first functionCall in a model turn must have a thoughtSignature.
|
||||
*/
|
||||
export function hardenHistory(
|
||||
history: Content[],
|
||||
options: HardeningOptions = {},
|
||||
): Content[] {
|
||||
if (history.length === 0) return history;
|
||||
|
||||
const sentinels = { ...DEFAULT_SENTINELS, ...options.sentinels };
|
||||
|
||||
// Pass 1: Initial Coalesce & Empty Turn Removal
|
||||
let coalesced = coalesce(history);
|
||||
|
||||
// Pass 2: Tool Pairing & Signatures (The semantic layer)
|
||||
coalesced = pairToolsAndEnforceSignatures(coalesced, sentinels);
|
||||
|
||||
// Pass 3: Structural Refinement (Hoisting & Re-ordering of tool responses)
|
||||
coalesced = refineToolResponses(coalesced);
|
||||
|
||||
// Pass 4: Enforce Structural Invariants (Start/End/Alternation)
|
||||
let final = enforceRoleConstraints(coalesced, sentinels);
|
||||
|
||||
// Pass 5: Final Scrubbing (Remove custom/non-standard properties for API compatibility)
|
||||
final = scrubHistory(final);
|
||||
|
||||
return final;
|
||||
}
|
||||
|
||||
/**
|
||||
* Combines adjacent turns with the same role and removes empty turns.
|
||||
*/
|
||||
function coalesce(history: Content[]): Content[] {
|
||||
const result: Content[] = [];
|
||||
for (const turn of history) {
|
||||
if (!turn.parts || turn.parts.length === 0) continue;
|
||||
|
||||
const last = result[result.length - 1];
|
||||
if (last && last.role === turn.role) {
|
||||
last.parts = [...(last.parts || []), ...(turn.parts || [])];
|
||||
} else {
|
||||
// Shallow clone the turn so we don't mutate the original history array structure
|
||||
result.push({ ...turn });
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures tool calls have matching responses and model turns have required signatures.
|
||||
*/
|
||||
function pairToolsAndEnforceSignatures(
|
||||
history: Content[],
|
||||
sentinels: Required<NonNullable<HardeningOptions['sentinels']>>,
|
||||
): Content[] {
|
||||
const result: Content[] = [];
|
||||
|
||||
// We work on a copy to allow splicing in sentinel turns
|
||||
const work = [...history];
|
||||
|
||||
for (let i = 0; i < work.length; i++) {
|
||||
const turn = work[i];
|
||||
|
||||
if (turn.role === 'model') {
|
||||
const parts = turn.parts || [];
|
||||
|
||||
// A. Signatures
|
||||
let foundCall = false;
|
||||
for (let j = 0; j < parts.length; j++) {
|
||||
const p = parts[j];
|
||||
if (p.functionCall) {
|
||||
if (!foundCall && !p.thoughtSignature) {
|
||||
debugLogger.warn(
|
||||
`[HistoryHardener] Missing thought signature on first function call in model turn. Injecting synthetic signature.`,
|
||||
);
|
||||
parts[j] = { ...p, thoughtSignature: SYNTHETIC_THOUGHT_SIGNATURE };
|
||||
}
|
||||
foundCall = true;
|
||||
}
|
||||
}
|
||||
|
||||
// B. Pairing
|
||||
const callParts = parts.filter((p) => !!p.functionCall);
|
||||
if (callParts.length > 0) {
|
||||
const nextTurn = work[i + 1];
|
||||
const missing: Array<{ id: string; name: string }> = [];
|
||||
|
||||
for (const call of callParts) {
|
||||
const id = call.functionCall!.id || 'undefined';
|
||||
const name = call.functionCall!.name || 'unknown';
|
||||
|
||||
const hasResponse =
|
||||
nextTurn?.role === 'user' &&
|
||||
nextTurn.parts?.some(
|
||||
(p) =>
|
||||
p.functionResponse?.id === id &&
|
||||
p.functionResponse?.name === name,
|
||||
);
|
||||
|
||||
if (!hasResponse) {
|
||||
debugLogger.log(
|
||||
`[HistoryHardener] Call id='${id}' (name='${name}') has no matching response in next turn.`,
|
||||
);
|
||||
missing.push({ id, name });
|
||||
}
|
||||
}
|
||||
|
||||
if (missing.length > 0) {
|
||||
debugLogger.log(
|
||||
`[HistoryHardener] Detected ${missing.length} tool calls without responses. Injecting sentinel responses.`,
|
||||
);
|
||||
|
||||
let targetUserTurn: Content;
|
||||
if (nextTurn?.role === 'user') {
|
||||
targetUserTurn = nextTurn;
|
||||
} else {
|
||||
targetUserTurn = { role: 'user', parts: [] };
|
||||
work.splice(i + 1, 0, targetUserTurn);
|
||||
}
|
||||
|
||||
for (const m of missing) {
|
||||
targetUserTurn.parts = targetUserTurn.parts || [];
|
||||
targetUserTurn.parts.push({
|
||||
functionResponse: {
|
||||
name: m.name,
|
||||
id: m.id,
|
||||
response: {
|
||||
error: sentinels.lostToolResponse,
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (turn.role === 'user') {
|
||||
// C. Orphaned Responses
|
||||
// A user response MUST follow a model call.
|
||||
const prevTurn = result[result.length - 1];
|
||||
const parts = turn.parts || [];
|
||||
const validParts: Part[] = [];
|
||||
|
||||
for (const p of parts) {
|
||||
if (p.functionResponse) {
|
||||
const id = p.functionResponse.id;
|
||||
const name = p.functionResponse.name;
|
||||
const hasCall =
|
||||
prevTurn?.role === 'model' &&
|
||||
prevTurn.parts?.some(
|
||||
(cp) =>
|
||||
cp.functionCall?.id === id && cp.functionCall?.name === name,
|
||||
);
|
||||
|
||||
if (hasCall) {
|
||||
validParts.push(p);
|
||||
} else {
|
||||
debugLogger.log(
|
||||
`[HistoryHardener] Dropping orphaned functionResponse id='${id}' (name='${name}')`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
validParts.push(p);
|
||||
}
|
||||
}
|
||||
turn.parts = validParts;
|
||||
}
|
||||
|
||||
if (turn.parts && turn.parts.length > 0) {
|
||||
result.push(turn);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hoists and re-orders tool responses within user turns to match preceding model turns.
|
||||
*/
|
||||
function refineToolResponses(history: Content[]): Content[] {
|
||||
for (let i = 1; i < history.length; i++) {
|
||||
const turn = history[i];
|
||||
const prev = history[i - 1];
|
||||
|
||||
if (turn.role === 'user' && prev.role === 'model') {
|
||||
const callOrder =
|
||||
prev.parts
|
||||
?.filter((p) => !!p.functionCall)
|
||||
.map((p) => p.functionCall!.id) || [];
|
||||
|
||||
if (callOrder.length > 0) {
|
||||
const responseParts =
|
||||
turn.parts?.filter((p) => !!p.functionResponse) || [];
|
||||
const otherParts = turn.parts?.filter((p) => !p.functionResponse) || [];
|
||||
|
||||
if (responseParts.length > 0) {
|
||||
// 1. Re-order: Sort responses to match the model's call order
|
||||
responseParts.sort((a, b) => {
|
||||
const idA = a.functionResponse!.id;
|
||||
const idB = b.functionResponse!.id;
|
||||
const idxA = callOrder.indexOf(idA);
|
||||
const idxB = callOrder.indexOf(idB);
|
||||
|
||||
// If an ID isn't found in the preceding turn (should be rare after pairing),
|
||||
// move it to the end.
|
||||
if (idxA === -1) return 1;
|
||||
if (idxB === -1) return -1;
|
||||
return idxA - idxB;
|
||||
});
|
||||
|
||||
// 2. Hoisting: Place all sorted responses BEFORE text or other parts
|
||||
turn.parts = [...responseParts, ...otherParts];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return history;
|
||||
}
|
||||
|
||||
/**
|
||||
* Final pass to ensure start/end roles and alternation are correct.
|
||||
*/
|
||||
function enforceRoleConstraints(
|
||||
history: Content[],
|
||||
sentinels: Required<NonNullable<HardeningOptions['sentinels']>>,
|
||||
): Content[] {
|
||||
if (history.length === 0) return [];
|
||||
|
||||
// Re-coalesce first to catch any empty turns or adjacent roles introduced by pairing
|
||||
const base = coalesce(history);
|
||||
if (base.length === 0) return [];
|
||||
|
||||
const result: Content[] = [...base];
|
||||
|
||||
// 1. Ensure starts with user
|
||||
if (result[0].role === 'model') {
|
||||
debugLogger.log(
|
||||
'[HistoryHardener] Final history starts with model role. Prepending sentinel user turn.',
|
||||
);
|
||||
result.unshift({
|
||||
role: 'user',
|
||||
parts: [{ text: sentinels.continuation }],
|
||||
});
|
||||
}
|
||||
|
||||
// 2. Ensure ends with user
|
||||
if (result[result.length - 1].role === 'model') {
|
||||
debugLogger.log(
|
||||
'[HistoryHardener] Final history ends with model role. Appending sentinel user turn.',
|
||||
);
|
||||
result.push({
|
||||
role: 'user',
|
||||
parts: [{ text: 'Please continue.' }],
|
||||
});
|
||||
}
|
||||
|
||||
// 3. Final Alternation Check (redundant if coalesce works, but safe)
|
||||
return coalesce(result);
|
||||
}
|
||||
|
||||
/**
|
||||
* Deep-scrubs the history to remove any non-standard properties from Content and Part objects.
|
||||
* This ensures compatibility with strict APIs (like Vertex AI) that reject unknown fields.
|
||||
*/
|
||||
export function scrubHistory(history: Content[]): Content[] {
|
||||
return history.map((content) => ({
|
||||
role: content.role,
|
||||
parts: (content.parts || []).map(scrubPart),
|
||||
}));
|
||||
}
|
||||
|
||||
interface ThoughtPart extends Part {
|
||||
thoughtSignature?: string;
|
||||
}
|
||||
|
||||
function isThoughtPart(part: Part): part is ThoughtPart {
|
||||
return 'thoughtSignature' in part;
|
||||
}
|
||||
|
||||
function scrubPart(part: Part): Part {
|
||||
const scrubbed: Record<string, unknown> = {};
|
||||
|
||||
if ('text' in part && typeof part.text === 'string') {
|
||||
scrubbed['text'] = part.text;
|
||||
}
|
||||
if ('inlineData' in part) {
|
||||
scrubbed['inlineData'] = part.inlineData;
|
||||
}
|
||||
if ('functionCall' in part && part.functionCall) {
|
||||
const scrubbedCall: Record<string, unknown> = {
|
||||
name: part.functionCall.name,
|
||||
args: part.functionCall.args,
|
||||
};
|
||||
if (part.functionCall.id) {
|
||||
scrubbedCall['id'] = part.functionCall.id;
|
||||
}
|
||||
scrubbed['functionCall'] = scrubbedCall;
|
||||
}
|
||||
if (isThoughtPart(part)) {
|
||||
scrubbed['thoughtSignature'] = part.thoughtSignature;
|
||||
}
|
||||
if ('functionResponse' in part && part.functionResponse) {
|
||||
const scrubbedResp: Record<string, unknown> = {
|
||||
name: part.functionResponse.name,
|
||||
response: part.functionResponse.response,
|
||||
};
|
||||
if (part.functionResponse.id) {
|
||||
scrubbedResp['id'] = part.functionResponse.id;
|
||||
}
|
||||
scrubbed['functionResponse'] = scrubbedResp;
|
||||
}
|
||||
if ('fileData' in part) {
|
||||
scrubbed['fileData'] = part.fileData;
|
||||
}
|
||||
if ('executableCode' in part) {
|
||||
scrubbed['executableCode'] = part.executableCode;
|
||||
}
|
||||
if ('codeExecutionResult' in part) {
|
||||
scrubbed['codeExecutionResult'] = part.codeExecutionResult;
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
return scrubbed as unknown as Part;
|
||||
}
|
||||
@@ -81,6 +81,42 @@ export function partToString(
|
||||
return part.text ?? '';
|
||||
}
|
||||
|
||||
/**
|
||||
* Safely clones a Part object.
|
||||
* We use a local eslint-disable because the linter incorrectly identifies Part
|
||||
* as a class instance and warns about losing the prototype during spread.
|
||||
* In reality, Parts in the GenAI SDK are plain data objects.
|
||||
*/
|
||||
export function clonePart(part: Part): Part {
|
||||
return { ...part };
|
||||
}
|
||||
|
||||
/**
|
||||
* Safely updates a Part object with new fields.
|
||||
*/
|
||||
export function updatePart(part: Part, updates: Partial<Part>): Part {
|
||||
return { ...part, ...updates };
|
||||
}
|
||||
|
||||
/**
|
||||
* Safely clones a FunctionResponse object.
|
||||
*/
|
||||
export function cloneFunctionResponse(
|
||||
resp: NonNullable<Part['functionResponse']>,
|
||||
): NonNullable<Part['functionResponse']> {
|
||||
// eslint-disable-next-line @typescript-eslint/no-misused-spread
|
||||
return { ...resp };
|
||||
}
|
||||
|
||||
/**
|
||||
* Safely clones a FunctionCall object.
|
||||
*/
|
||||
export function cloneFunctionCall(
|
||||
call: NonNullable<Part['functionCall']>,
|
||||
): NonNullable<Part['functionCall']> {
|
||||
return { ...call };
|
||||
}
|
||||
|
||||
export function getResponseText(
|
||||
response: GenerateContentResponse,
|
||||
): string | null {
|
||||
|
||||
@@ -9,11 +9,14 @@ import type { ContentGenerator } from '../core/contentGenerator.js';
|
||||
import { debugLogger } from './debugLogger.js';
|
||||
|
||||
// Token estimation constants
|
||||
// ASCII characters (0-127) are roughly 4 chars per token
|
||||
export const ASCII_TOKENS_PER_CHAR = 0.25;
|
||||
// ASCII characters (0-127) are roughly 3-4 chars per token.
|
||||
// We use 0.33 (~3 chars/token) as a conservative baseline for mixed text and code.
|
||||
export const ASCII_TOKENS_PER_CHAR = 0.33;
|
||||
// Non-ASCII characters (including CJK) are often 1-2 tokens per char.
|
||||
// We use 1.3 as a conservative estimate to avoid underestimation.
|
||||
export const NON_ASCII_TOKENS_PER_CHAR = 1.3;
|
||||
// We use 1.5 as a conservative estimate to avoid underestimation.
|
||||
export const NON_ASCII_TOKENS_PER_CHAR = 1.5;
|
||||
// Structural overhead per Content turn (role prefixes, separators).
|
||||
export const MSG_OVERHEAD_TOKENS = 5;
|
||||
// Fixed token estimate for images
|
||||
const IMAGE_TOKEN_ESTIMATE = 3000;
|
||||
// Fixed token estimate for PDFs (~100 pages at 258 tokens/page)
|
||||
|
||||
Reference in New Issue
Block a user