refactor: Pass and handle a dedicated timeout signal for streaming content generation and update mock to return an AsyncGenerator.

This commit is contained in:
kevin-ramdass
2026-01-31 18:23:15 -08:00
parent d03b9b95b3
commit b8ad178bee
2 changed files with 118 additions and 87 deletions

View File

@@ -1022,18 +1022,27 @@ describe('GeminiChat', () => {
// 2. Mock generateContentStream to hang UNTIL aborted // 2. Mock generateContentStream to hang UNTIL aborted
vi.mocked(mockContentGenerator.generateContentStream).mockImplementation( vi.mocked(mockContentGenerator.generateContentStream).mockImplementation(
(request) => new Promise((resolve, reject) => { async (request) => {
const config = request?.config; const signal = request.config?.abortSignal;
if (config?.abortSignal) { return {
if (config.abortSignal.aborted) { async *[Symbol.asyncIterator]() {
reject(new Error('Aborted')); if (signal) {
return; await new Promise((resolve, reject) => {
if (signal.aborted) {
reject(new Error('Aborted'));
return;
}
signal.addEventListener('abort', () => {
reject(new Error('Aborted'));
});
});
} else {
await new Promise(() => {}); // Hang indefinitely
} }
config.abortSignal.addEventListener('abort', () => { yield {} as GenerateContentResponse; // Dummy yield to satisfy require-yield lint rule
reject(new Error('Aborted')); },
}); } as AsyncGenerator<GenerateContentResponse>;
} },
}),
); );
// 3. Start the request // 3. Start the request

View File

@@ -592,14 +592,20 @@ export class GeminiChat {
lastContentsToUse = contentsToUse; lastContentsToUse = contentsToUse;
try { try {
return await this.config.getContentGenerator().generateContentStream( const stream = await this.config
{ .getContentGenerator()
model: modelToUse, .generateContentStream(
contents: contentsToUse, {
config, model: modelToUse,
}, contents: contentsToUse,
prompt_id, config,
); },
prompt_id,
);
return {
stream,
timeoutSignal,
};
} catch (error) { } catch (error) {
if (timeoutSignal.aborted) { if (timeoutSignal.aborted) {
const timeoutError = new Error( const timeoutError = new Error(
@@ -632,24 +638,27 @@ export class GeminiChat {
); );
}; };
const streamResponse = await retryWithBackoff(apiCall, { const { stream: streamResponse, timeoutSignal } = await retryWithBackoff(
onPersistent429: onPersistent429Callback, apiCall,
onValidationRequired: onValidationRequiredCallback, {
authType: this.config.getContentGeneratorConfig()?.authType, onPersistent429: onPersistent429Callback,
retryFetchErrors: this.config.getRetryFetchErrors(), onValidationRequired: onValidationRequiredCallback,
signal: abortSignal, authType: this.config.getContentGeneratorConfig()?.authType,
maxAttempts: availabilityMaxAttempts, retryFetchErrors: this.config.getRetryFetchErrors(),
getAvailabilityContext, signal: abortSignal,
onRetry: (attempt, error, delayMs) => { maxAttempts: availabilityMaxAttempts,
coreEvents.emitRetryAttempt({ getAvailabilityContext,
attempt, onRetry: (attempt, error, delayMs) => {
maxAttempts: availabilityMaxAttempts ?? 10, coreEvents.emitRetryAttempt({
delayMs, attempt,
error: error instanceof Error ? error.message : String(error), maxAttempts: availabilityMaxAttempts ?? 10,
model: lastModelToUse, delayMs,
}); error: error instanceof Error ? error.message : String(error),
model: lastModelToUse,
});
},
}, },
}); );
// Store the original request for AfterModel hooks // Store the original request for AfterModel hooks
const originalRequest: GenerateContentParameters = { const originalRequest: GenerateContentParameters = {
@@ -662,6 +671,7 @@ export class GeminiChat {
lastModelToUse, lastModelToUse,
streamResponse, streamResponse,
originalRequest, originalRequest,
timeoutSignal,
); );
} }
@@ -820,69 +830,81 @@ export class GeminiChat {
model: string, model: string,
streamResponse: AsyncGenerator<GenerateContentResponse>, streamResponse: AsyncGenerator<GenerateContentResponse>,
originalRequest: GenerateContentParameters, originalRequest: GenerateContentParameters,
timeoutSignal: AbortSignal,
): AsyncGenerator<GenerateContentResponse> { ): AsyncGenerator<GenerateContentResponse> {
const modelResponseParts: Part[] = []; const modelResponseParts: Part[] = [];
let hasToolCall = false; let hasToolCall = false;
let finishReason: FinishReason | undefined; let finishReason: FinishReason | undefined;
for await (const chunk of streamResponse) { try {
const candidateWithReason = chunk?.candidates?.find( for await (const chunk of streamResponse) {
(candidate) => candidate.finishReason, const candidateWithReason = chunk?.candidates?.find(
); (candidate) => candidate.finishReason,
if (candidateWithReason) {
finishReason = candidateWithReason.finishReason as FinishReason;
}
if (isValidResponse(chunk)) {
const content = chunk.candidates?.[0]?.content;
if (content?.parts) {
if (content.parts.some((part) => part.thought)) {
// Record thoughts
this.recordThoughtFromContent(content);
}
if (content.parts.some((part) => part.functionCall)) {
hasToolCall = true;
}
modelResponseParts.push(
...content.parts.filter((part) => !part.thought),
);
}
}
// Record token usage if this chunk has usageMetadata
if (chunk.usageMetadata) {
this.chatRecordingService.recordMessageTokens(chunk.usageMetadata);
if (chunk.usageMetadata.promptTokenCount !== undefined) {
this.lastPromptTokenCount = chunk.usageMetadata.promptTokenCount;
}
}
const hookSystem = this.config.getHookSystem();
if (originalRequest && chunk && hookSystem) {
const hookResult = await hookSystem.fireAfterModelEvent(
originalRequest,
chunk,
); );
if (candidateWithReason) {
if (hookResult.stopped) { finishReason = candidateWithReason.finishReason as FinishReason;
throw new AgentExecutionStoppedError(
hookResult.reason || 'Agent execution stopped by hook',
);
} }
if (hookResult.blocked) { if (isValidResponse(chunk)) {
throw new AgentExecutionBlockedError( const content = chunk.candidates?.[0]?.content;
hookResult.reason || 'Agent execution blocked by hook', if (content?.parts) {
hookResult.response, if (content.parts.some((part) => part.thought)) {
); // Record thoughts
this.recordThoughtFromContent(content);
}
if (content.parts.some((part) => part.functionCall)) {
hasToolCall = true;
}
modelResponseParts.push(
...content.parts.filter((part) => !part.thought),
);
}
} }
yield hookResult.response; // Record token usage if this chunk has usageMetadata
} else { if (chunk.usageMetadata) {
yield chunk; this.chatRecordingService.recordMessageTokens(chunk.usageMetadata);
if (chunk.usageMetadata.promptTokenCount !== undefined) {
this.lastPromptTokenCount = chunk.usageMetadata.promptTokenCount;
}
}
const hookSystem = this.config.getHookSystem();
if (originalRequest && chunk && hookSystem) {
const hookResult = await hookSystem.fireAfterModelEvent(
originalRequest,
chunk,
);
if (hookResult.stopped) {
throw new AgentExecutionStoppedError(
hookResult.reason || 'Agent execution stopped by hook',
);
}
if (hookResult.blocked) {
throw new AgentExecutionBlockedError(
hookResult.reason || 'Agent execution blocked by hook',
hookResult.response,
);
}
yield hookResult.response;
} else {
yield chunk;
}
} }
} catch (error) {
if (timeoutSignal.aborted) {
const timeoutError = new Error(
`Request timed out after ${TIMEOUT_MS}ms`,
);
(timeoutError as unknown as { code: string }).code = 'ETIMEDOUT';
throw timeoutError;
}
throw error;
} }
// String thoughts and consolidate text parts. // String thoughts and consolidate text parts.