From dd14413a642759dfdecb2adee42f56b4b314d08f Mon Sep 17 00:00:00 2001
From: Kit Langton
Date: Tue, 12 May 2026 16:16:58 -0400
Subject: [PATCH] Preserve native LLM tool context (#27116)

---
 packages/llm/example/tutorial.ts              |  4 +-
 packages/llm/src/index.ts                     |  1 +
 .../llm/src/protocols/openai-responses.ts     |  2 +-
 packages/llm/src/protocols/utils/lifecycle.ts |  2 +-
 packages/llm/src/schema/events.ts             | 44 +++++----
 packages/llm/src/tool-runtime.ts              | 96 ++++++++++++++++---
 packages/llm/src/tool.ts                      | 11 ++-
 packages/llm/test/adapter.test.ts             |  6 +-
 packages/llm/test/llm.test.ts                 |  2 +-
 .../test/provider/anthropic-messages.test.ts  |  6 +-
 .../test/provider/bedrock-converse.test.ts    |  8 +-
 packages/llm/test/provider/gemini.test.ts     | 14 +--
 .../llm/test/provider/openai-chat.test.ts     |  4 +-
 .../provider/openai-compatible-chat.test.ts   |  2 +-
 .../test/provider/openai-responses.test.ts    |  4 +-
 packages/llm/test/recorded-scenarios.ts       | 16 ++--
 packages/llm/test/schema.test.ts              |  5 +
 packages/llm/test/tool-runtime.test.ts        | 92 ++++++++++++++++--
 18 files changed, 244 insertions(+), 75 deletions(-)

diff --git a/packages/llm/example/tutorial.ts b/packages/llm/example/tutorial.ts
index a9adecf369..429ac4824b 100644
--- a/packages/llm/example/tutorial.ts
+++ b/packages/llm/example/tutorial.ts
@@ -78,7 +78,7 @@ const streamText = LLM.stream(request).pipe(
   Stream.tap((event) =>
     Effect.sync(() => {
       if (event.type === "text-delta") process.stdout.write(`\ntext: ${event.text}`)
-      if (event.type === "request-finish") process.stdout.write(`\nfinish: ${event.reason}\n`)
+      if (event.type === "finish") process.stdout.write(`\nfinish: ${event.reason}\n`)
     }),
   ),
   Stream.runDrain,
@@ -185,7 +185,7 @@ const FakeProtocol = Protocol.make({
     event: Schema.String,
     initial: () => undefined,
     step: (_, frame) => Effect.succeed([undefined, [{ type: "text-delta", id: "text-0", text: frame }]] as const),
-    onHalt: () => [{ type: "request-finish", reason: "stop" }],
+    onHalt: () => [{ type: "finish", reason: "stop" }],
   },
 })
diff --git a/packages/llm/src/index.ts b/packages/llm/src/index.ts
index f4adf4859a..acf73b360e 100644
--- a/packages/llm/src/index.ts
+++ b/packages/llm/src/index.ts
@@ -17,6 +17,7 @@ export type {
   ExecutableTools,
   Tool as ToolShape,
   ToolExecute,
+  ToolExecuteContext,
   Tools,
   ToolSchema,
 } from "./tool"
diff --git a/packages/llm/src/protocols/openai-responses.ts b/packages/llm/src/protocols/openai-responses.ts
index e31a42cd5a..7cf734f027 100644
--- a/packages/llm/src/protocols/openai-responses.ts
+++ b/packages/llm/src/protocols/openai-responses.ts
@@ -380,7 +380,7 @@ type StepResult = readonly [ParserState, ReadonlyArray<LLMEvent>]
 const NO_EVENTS: StepResult["1"] = []

 // `response.completed` / `response.incomplete` are clean finishes that emit a
-// `request-finish` event; `response.failed` is a hard failure that emits a
+// `finish` event; `response.failed` is a hard failure that emits a
 // `provider-error`. All three end the stream — kept in one set so `step` and
 // the protocol's `terminal` predicate stay in sync.
 const TERMINAL_TYPES = new Set(["response.completed", "response.incomplete", "response.failed"])
diff --git a/packages/llm/src/protocols/utils/lifecycle.ts b/packages/llm/src/protocols/utils/lifecycle.ts
index 67039b137a..c249d75cee 100644
--- a/packages/llm/src/protocols/utils/lifecycle.ts
+++ b/packages/llm/src/protocols/utils/lifecycle.ts
@@ -80,7 +80,7 @@ export const finish = (
       usage: input.usage,
       providerMetadata: input.providerMetadata,
     }),
-    LLMEvent.requestFinish(input),
+    LLMEvent.finish(input),
   )
   return { ...stepped, stepStarted: false }
 }
diff --git a/packages/llm/src/schema/events.ts b/packages/llm/src/schema/events.ts
index 6e6bb1541b..6a088dc873 100644
--- a/packages/llm/src/schema/events.ts
+++ b/packages/llm/src/schema/events.ts
@@ -1,5 +1,5 @@
 import { Schema } from "effect"
-import { ContentBlockID, FinishReason, ProtocolID, ProviderMetadata, ResponseID, RouteID, ToolCallID } from "./ids"
+import { ContentBlockID, FinishReason, ProtocolID, ProviderMetadata, RouteID, ToolCallID } from "./ids"
 import { ModelRef } from "./options"
 import { ToolResultValue } from "./messages"

@@ -66,14 +66,13 @@ export class Usage extends Schema.Class("LLM.Usage")({
   get visibleOutputTokens() {
     return Math.max(0, (this.outputTokens ?? 0) - (this.reasoningTokens ?? 0))
   }
+
+  static from(input: UsageInput) {
+    return input instanceof Usage ? input : new Usage(input)
+  }
 }

-export const RequestStart = Schema.Struct({
-  type: Schema.tag("request-start"),
-  id: ResponseID,
-  model: ModelRef,
-}).annotate({ identifier: "LLM.Event.RequestStart" })
-export type RequestStart = Schema.Schema.Type<typeof RequestStart>
+export type UsageInput = Usage | ConstructorParameters<typeof Usage>[0]

 export const StepStart = Schema.Struct({
   type: Schema.tag("step-start"),
@@ -185,13 +184,13 @@ export const StepFinish = Schema.Struct({
 }).annotate({ identifier: "LLM.Event.StepFinish" })
 export type StepFinish = Schema.Schema.Type<typeof StepFinish>

-export const RequestFinish = Schema.Struct({
-  type: Schema.tag("request-finish"),
+export const Finish = Schema.Struct({
+  type: Schema.tag("finish"),
   reason: FinishReason,
   usage: Schema.optional(Usage),
   providerMetadata: Schema.optional(ProviderMetadata),
-}).annotate({ identifier: "LLM.Event.RequestFinish" })
-export type RequestFinish = Schema.Schema.Type<typeof RequestFinish>
+}).annotate({ identifier: "LLM.Event.Finish" })
+export type Finish = Schema.Schema.Type<typeof Finish>

 export const ProviderErrorEvent = Schema.Struct({
   type: Schema.tag("provider-error"),
@@ -202,7 +201,6 @@
 export type ProviderErrorEvent = Schema.Schema.Type<typeof ProviderErrorEvent>

 const llmEventTagged = Schema.Union([
-  RequestStart,
   StepStart,
   TextStart,
   TextDelta,
@@ -217,13 +215,15 @@
   ToolResult,
   ToolError,
   StepFinish,
-  RequestFinish,
+  Finish,
   ProviderErrorEvent,
 ]).pipe(Schema.toTaggedUnion("type"))

 type WithID<T, ID> = Omit<T, "id"> & { readonly id: ID | string }
+type WithUsage<T> = Omit<T, "usage"> & {
+  readonly usage?: UsageInput
+}

-const responseID = (value: ResponseID | string) => ResponseID.make(value)
 const contentBlockID = (value: ContentBlockID | string) => ContentBlockID.make(value)
 const toolCallID = (value: ToolCallID | string) => ToolCallID.make(value)

@@ -233,7 +233,6 @@
  * `events.filter(LLMEvent.guards["tool-call"])`.
  */
 export const LLMEvent = Object.assign(llmEventTagged, {
-  requestStart: (input: WithID<RequestStart, ResponseID>) => RequestStart.make({ ...input, id: responseID(input.id) }),
   stepStart: StepStart.make,
   textStart: (input: WithID<TextStart, ContentBlockID>) => TextStart.make({ ...input, id: contentBlockID(input.id) }),
   textDelta: (input: WithID<TextDelta, ContentBlockID>) => TextDelta.make({ ...input, id: contentBlockID(input.id) }),
@@ -252,11 +251,18 @@
   toolCall: (input: WithID<ToolCall, ToolCallID>) => ToolCall.make({ ...input, id: toolCallID(input.id) }),
   toolResult: (input: WithID<ToolResult, ToolCallID>) => ToolResult.make({ ...input, id: toolCallID(input.id) }),
   toolError: (input: WithID<ToolError, ToolCallID>) => ToolError.make({ ...input, id: toolCallID(input.id) }),
-  stepFinish: StepFinish.make,
-  requestFinish: RequestFinish.make,
+  stepFinish: (input: WithUsage<StepFinish>) =>
+    StepFinish.make({
+      ...input,
+      usage: input.usage === undefined ? undefined : Usage.from(input.usage),
+    }),
+  finish: (input: WithUsage<Finish>) =>
+    Finish.make({
+      ...input,
+      usage: input.usage === undefined ? undefined : Usage.from(input.usage),
+    }),
   providerError: ProviderErrorEvent.make,
   is: {
-    requestStart: llmEventTagged.guards["request-start"],
     stepStart: llmEventTagged.guards["step-start"],
     textStart: llmEventTagged.guards["text-start"],
     textDelta: llmEventTagged.guards["text-delta"],
@@ -271,7 +277,7 @@
     toolResult: llmEventTagged.guards["tool-result"],
     toolError: llmEventTagged.guards["tool-error"],
     stepFinish: llmEventTagged.guards["step-finish"],
-    requestFinish: llmEventTagged.guards["request-finish"],
+    finish: llmEventTagged.guards.finish,
     providerError: llmEventTagged.guards["provider-error"],
   },
 })
diff --git a/packages/llm/src/tool-runtime.ts b/packages/llm/src/tool-runtime.ts
index f464525827..d83dcc67ad 100644
--- a/packages/llm/src/tool-runtime.ts
+++ b/packages/llm/src/tool-runtime.ts
@@ -12,6 +12,7 @@ import {
   ToolFailure,
   ToolResultPart,
   type ToolResultValue,
+  Usage,
 } from "./schema"
 import { type AnyTool, type ExecutableTools, type Tools, toDefinitions } from "./tool"

@@ -72,19 +73,42 @@ export const stream = (options: StreamOptions): Stream.Strea
     tools: [...options.request.tools.filter((tool) => !runtimeToolNames.has(tool.name)), ...runtimeTools],
   })

-  const loop = (request: LLMRequest, step: number): Stream.Stream<LLMEvent> =>
+  const loop = (
+    request: LLMRequest,
+    step: number,
+    usage: Usage | undefined,
+    providerMetadata: ProviderMetadata | undefined,
+  ): Stream.Stream<LLMEvent> =>
     Stream.unwrap(
       Effect.gen(function* () {
-        const state: StepState = { assistantContent: [], toolCalls: [], finishReason: undefined }
+        const state: StepState = {
+          assistantContent: [],
+          toolCalls: [],
+          finishReason: undefined,
+          usage: undefined,
+          providerMetadata: undefined,
+        }
         const modelStream = options
           .stream(request)
+          .pipe(Stream.map((event) => indexStep(event, step)))
           .pipe(Stream.tap((event) => Effect.sync(() => accumulate(state, event))))
+          .pipe(Stream.filter((event) => event.type !== "finish"))

         const continuation = Stream.unwrap(
           Effect.gen(function* () {
-            if (state.finishReason !== "tool-calls" || state.toolCalls.length === 0) return Stream.empty
-            if (options.toolExecution === "none") return Stream.empty
+            const totalUsage = addUsage(usage, state.usage)
+            const totalProviderMetadata = mergeProviderMetadata(providerMetadata, state.providerMetadata)
+            const finishStream = Stream.fromIterable([
+              LLMEvent.finish({
+                reason: state.finishReason ?? "unknown",
+                usage: totalUsage,
+                providerMetadata: totalProviderMetadata,
+              }),
+            ])
+
+            if (state.finishReason !== "tool-calls" || state.toolCalls.length === 0) return finishStream
+            if (options.toolExecution === "none") return finishStream

             const dispatched = yield* Effect.forEach(
               state.toolCalls,
               (call) => dispatch(options.tools, call).pipe(Effect.map((result) => [call, result] as const)),
             )
             const resultStream = Stream.fromIterable(dispatched.flatMap(([call, result]) => emitEvents(call, result)))

-            if (!options.stopWhen) return resultStream
-            if (options.stopWhen({ step, request })) return resultStream
+            if (!options.stopWhen) return resultStream.pipe(Stream.concat(finishStream))
+            if (options.stopWhen({ step, request })) return resultStream.pipe(Stream.concat(finishStream))

-            return resultStream.pipe(Stream.concat(loop(followUpRequest(request, state, dispatched), step + 1)))
+            return resultStream.pipe(
+              Stream.concat(
+                loop(followUpRequest(request, state, dispatched), step + 1, totalUsage, totalProviderMetadata),
+              ),
+            )
           }),
         )
@@ -104,13 +132,21 @@
       }),
     )

-  return loop(initialRequest, 0)
+  return loop(initialRequest, 0, undefined, undefined)
+}
+
+const indexStep = (event: LLMEvent, index: number): LLMEvent => {
+  if (event.type === "step-start") return LLMEvent.stepStart({ index })
+  if (event.type === "step-finish") return LLMEvent.stepFinish({ ...event, index })
+  return event
 }

 interface StepState {
   assistantContent: ContentPart[]
   toolCalls: ToolCallPart[]
   finishReason: FinishReason | undefined
+  usage: Usage | undefined
+  providerMetadata: ProviderMetadata | undefined
 }

 const accumulate = (state: StepState, event: LLMEvent) => {
@@ -154,9 +190,43 @@
     )
     return
   }
-  if (event.type === "step-finish" || event.type === "request-finish") {
+  if (event.type === "step-finish") {
     state.finishReason = event.reason === "stop" && state.toolCalls.length > 0 ? "tool-calls" : event.reason
+    state.usage = addUsage(state.usage, event.usage)
+    state.providerMetadata = mergeProviderMetadata(state.providerMetadata, event.providerMetadata)
+    return
   }
+  if (event.type === "finish") {
+    state.finishReason ??= event.reason
+    state.usage ??= event.usage
+    state.providerMetadata = mergeProviderMetadata(state.providerMetadata, event.providerMetadata)
+  }
+}
+
+const addUsage = (left: Usage | undefined, right: Usage | undefined) => {
+  if (!left) return right
+  if (!right) return left
+  type UsageKey =
+    | "inputTokens"
+    | "outputTokens"
+    | "nonCachedInputTokens"
+    | "cacheReadInputTokens"
+    | "cacheWriteInputTokens"
+    | "reasoningTokens"
+    | "totalTokens"
+  const sum = (key: UsageKey) =>
+    left[key] === undefined && right[key] === undefined ? undefined : Number(left[key] ?? 0) + Number(right[key] ?? 0)
+
+  return new Usage({
+    inputTokens: sum("inputTokens"),
+    outputTokens: sum("outputTokens"),
+    nonCachedInputTokens: sum("nonCachedInputTokens"),
+    cacheReadInputTokens: sum("cacheReadInputTokens"),
+    cacheWriteInputTokens: sum("cacheWriteInputTokens"),
+    reasoningTokens: sum("reasoningTokens"),
+    totalTokens: sum("totalTokens"),
+    providerMetadata: mergeProviderMetadata(left.providerMetadata, right.providerMetadata),
+  })
 }

 const sameProviderMetadata = (left: ProviderMetadata | undefined, right: ProviderMetadata | undefined) =>
@@ -200,17 +270,17 @@ const dispatch = (tools: Tools, call: ToolCallPart): Effect.Effect
         Effect.succeed({ type: "error" as const, value: failure.message } satisfies ToolResultValue),
       ),
     )
 }

-const decodeAndExecute = (tool: AnyTool, input: unknown): Effect.Effect<ToolResultValue, ToolFailure> =>
-  tool._decode(input).pipe(
+const decodeAndExecute = (tool: AnyTool, call: ToolCallPart): Effect.Effect<ToolResultValue, ToolFailure> =>
+  tool._decode(call.input).pipe(
     Effect.mapError((error) => new ToolFailure({ message: `Invalid tool input: ${error.message}` })),
-    Effect.flatMap((decoded) => tool.execute!(decoded)),
+    Effect.flatMap((decoded) => tool.execute!(decoded, { id: call.id, name: call.name })),
     Effect.flatMap((value) =>
       tool._encode(value).pipe(
         Effect.mapError(
diff --git a/packages/llm/src/tool.ts b/packages/llm/src/tool.ts
index 311c8798b6..df0a1cd3d3 100644
--- a/packages/llm/src/tool.ts
+++ b/packages/llm/src/tool.ts
@@ -1,5 +1,5 @@
 import { Effect, JsonSchema, Schema } from "effect"
-import type { ToolDefinition as ToolDefinitionClass } from "./schema"
+import type { ToolCallPart, ToolDefinition as ToolDefinitionClass } from "./schema"
 import { ToolDefinition, ToolFailure } from "./schema"

 /**
@@ -8,9 +8,14 @@ import { ToolDefinition, ToolFailure } from "./schema"
  * beyond pure data conversion belongs in the handler closure.
  */
 export type ToolSchema = Schema.Codec<any, any>
+export interface ToolExecuteContext {
+  readonly id: ToolCallPart["id"]
+  readonly name: ToolCallPart["name"]
+}

 export type ToolExecute<Params extends ToolSchema, Success extends ToolSchema> = (
   params: Schema.Schema.Type<Params>,
+  context?: ToolExecuteContext,
 ) => Effect.Effect<Schema.Schema.Type<Success>, ToolFailure>

 /**
@@ -61,7 +66,7 @@ type TypedToolConfig = {
 type DynamicToolConfig = {
   readonly description: string
   readonly jsonSchema: JsonSchema.JsonSchema
-  readonly execute?: (params: unknown) => Effect.Effect<unknown, ToolFailure>
+  readonly execute?: (params: unknown, context?: ToolExecuteContext) => Effect.Effect<unknown, ToolFailure>
 }

 /**
@@ -110,7 +115,7 @@ export function make<Params extends ToolSchema, Success extends ToolSch
 export function make(config: {
   readonly description: string
   readonly jsonSchema: JsonSchema.JsonSchema
-  readonly execute: (params: unknown) => Effect.Effect<unknown, ToolFailure>
+  readonly execute: (params: unknown, context?: ToolExecuteContext) => Effect.Effect<unknown, ToolFailure>
 }): AnyExecutableTool
 export function make(config: {
   readonly description: string
diff --git a/packages/llm/test/adapter.test.ts b/packages/llm/test/adapter.test.ts
index 5ac8b9d818..80349a5ae5 100644
--- a/packages/llm/test/adapter.test.ts
+++ b/packages/llm/test/adapter.test.ts
@@ -51,7 +51,7 @@ const request = LLM.request({

 const raiseEvent = (event: FakeEvent): import("../src/schema").LLMEvent =>
   event.type === "finish"
-    ? { type: "request-finish", reason: event.reason }
+    ? { type: "finish", reason: event.reason }
     : { type: "text-delta", id: "text-0", text: event.text }

 const fakeProtocol = Protocol.make({
@@ -112,8 +112,8 @@ describe("llm route", () => {
       const events = Array.from(yield* llm.stream(request).pipe(Stream.runCollect))
       const response = yield* llm.generate(request)

-      expect(events.map((event) => event.type)).toEqual(["text-delta", "request-finish"])
-      expect(response.events.map((event) => event.type)).toEqual(["text-delta", "request-finish"])
+      expect(events.map((event) => event.type)).toEqual(["text-delta", "finish"])
+      expect(response.events.map((event) => event.type)).toEqual(["text-delta", "finish"])
     }),
   )
diff --git a/packages/llm/test/llm.test.ts b/packages/llm/test/llm.test.ts
index c01fe33b29..a20c48411e 100644
--- a/packages/llm/test/llm.test.ts
+++ b/packages/llm/test/llm.test.ts
@@ -127,7 +127,7 @@ describe("llm constructors", () => {
       LLMResponse.text({
         events: [
           { type: "text-delta", id: "text-0", text: "hi" },
-          { type: "request-finish", reason: "stop" },
+          { type: "finish", reason: "stop" },
         ],
       }),
     ).toBe("hi")
diff --git a/packages/llm/test/provider/anthropic-messages.test.ts b/packages/llm/test/provider/anthropic-messages.test.ts
index 6417f73c2b..71204bcd63 100644
--- a/packages/llm/test/provider/anthropic-messages.test.ts
+++ b/packages/llm/test/provider/anthropic-messages.test.ts
@@ -124,7 +124,7 @@ describe("Anthropic Messages route", () => {
         providerMetadata: { anthropic: { signature: "sig_1" } },
       })
       expect(response.events.at(-1)).toMatchObject({
-        type: "request-finish",
+        type: "finish",
         reason: "stop",
         providerMetadata: { anthropic: { stopSequence: "\n\nHuman:" } },
       })
@@ -182,7 +182,7 @@
         },
         { type: "step-finish", index: 0, reason: "tool-calls", usage, providerMetadata: undefined },
         {
-          type: "request-finish",
+          type: "finish",
           reason: "tool-calls",
           providerMetadata: undefined,
           usage,
@@ -275,7 +275,7 @@
         providerMetadata: { anthropic: { blockType: "web_search_tool_result" } },
       })
       expect(response.text).toBe("Found it.")
-      expect(response.events.at(-1)).toMatchObject({ type: "request-finish", reason: "stop" })
+      expect(response.events.at(-1)).toMatchObject({ type: "finish", reason: "stop" })
     }),
   )
diff --git a/packages/llm/test/provider/bedrock-converse.test.ts b/packages/llm/test/provider/bedrock-converse.test.ts
index 7d1ad3f309..ffdd6e8008 100644
--- a/packages/llm/test/provider/bedrock-converse.test.ts
+++ b/packages/llm/test/provider/bedrock-converse.test.ts
@@ -169,12 +169,12 @@ describe("Bedrock Converse route", () => {
       const response = yield* LLMClient.generate(baseRequest).pipe(Effect.provide(fixedBytes(body)))

       expect(response.text).toBe("Hello!")
-      const finishes = response.events.filter((event) => event.type === "request-finish")
+      const finishes = response.events.filter((event) => event.type === "finish")
       // Bedrock splits the finish across `messageStop` (carries reason) and
       // `metadata` (carries usage). We consolidate them into a single
-      // terminal `request-finish` event with both.
+      // terminal `finish` event with both.
       expect(finishes).toHaveLength(1)
-      expect(finishes[0]).toMatchObject({ type: "request-finish", reason: "stop" })
+      expect(finishes[0]).toMatchObject({ type: "finish", reason: "stop" })
       expect(response.usage).toMatchObject({
         inputTokens: 5,
         outputTokens: 2,
@@ -213,7 +213,7 @@
         { type: "tool-input-delta", id: "tool_1", name: "lookup", text: '{"query"' },
         { type: "tool-input-delta", id: "tool_1", name: "lookup", text: ':"weather"}' },
       ])
-      expect(response.events.at(-1)).toMatchObject({ type: "request-finish", reason: "tool-calls" })
+      expect(response.events.at(-1)).toMatchObject({ type: "finish", reason: "tool-calls" })
     }),
   )
diff --git a/packages/llm/test/provider/gemini.test.ts b/packages/llm/test/provider/gemini.test.ts
index 80c32c58b3..7e6bbc8466 100644
--- a/packages/llm/test/provider/gemini.test.ts
+++ b/packages/llm/test/provider/gemini.test.ts
@@ -232,7 +232,7 @@ describe("Gemini route", () => {
         { type: "text-end", id: "text-0" },
         { type: "step-finish", index: 0, reason: "stop", usage, providerMetadata: undefined },
         {
-          type: "request-finish",
+          type: "finish",
           reason: "stop",
           usage,
         },
@@ -291,7 +291,7 @@
         },
         { type: "step-finish", index: 0, reason: "tool-calls", usage, providerMetadata: undefined },
         {
-          type: "request-finish",
+          type: "finish",
           reason: "tool-calls",
           usage,
         },
@@ -325,7 +325,7 @@
         { type: "tool-call", id: "tool_0", name: "lookup", input: { query: "weather" } },
         { type: "tool-call", id: "tool_1", name: "lookup", input: { query: "news" } },
       ])
-      expect(response.events.at(-1)).toMatchObject({ type: "request-finish", reason: "tool-calls" })
+      expect(response.events.at(-1)).toMatchObject({ type: "finish", reason: "tool-calls" })
     }),
   )

@@ -344,10 +344,10 @@
         ),
       )

-      expect(length.events.map((event) => event.type)).toEqual(["step-start", "step-finish", "request-finish"])
-      expect(length.events.at(-1)).toMatchObject({ type: "request-finish", reason: "length" })
-      expect(filtered.events.map((event) => event.type)).toEqual(["step-start", "step-finish", "request-finish"])
-      expect(filtered.events.at(-1)).toMatchObject({ type: "request-finish", reason: "content-filter" })
+      expect(length.events.map((event) => event.type)).toEqual(["step-start", "step-finish", "finish"])
+      expect(length.events.at(-1)).toMatchObject({ type: "finish", reason: "length" })
+      expect(filtered.events.map((event) => event.type)).toEqual(["step-start", "step-finish", "finish"])
+      expect(filtered.events.at(-1)).toMatchObject({ type: "finish", reason: "content-filter" })
     }),
   )
diff --git a/packages/llm/test/provider/openai-chat.test.ts b/packages/llm/test/provider/openai-chat.test.ts
index 115c58849c..4303a69ffa 100644
--- a/packages/llm/test/provider/openai-chat.test.ts
+++ b/packages/llm/test/provider/openai-chat.test.ts
@@ -249,7 +249,7 @@ describe("OpenAI Chat route", () => {
         { type: "text-end", id: "text-0" },
         { type: "step-finish", index: 0, reason: "stop", usage, providerMetadata: undefined },
         {
-          type: "request-finish",
+          type: "finish",
           reason: "stop",
           usage,
         },
@@ -288,7 +288,7 @@
           providerMetadata: undefined,
         },
         { type: "step-finish", index: 0, reason: "tool-calls", usage: undefined, providerMetadata: undefined },
-        { type: "request-finish", reason: "tool-calls", usage: undefined },
+        { type: "finish", reason: "tool-calls", usage: undefined },
       ])
     }),
   )
diff --git a/packages/llm/test/provider/openai-compatible-chat.test.ts b/packages/llm/test/provider/openai-compatible-chat.test.ts
index 7759ff7202..50aac41091 100644
--- a/packages/llm/test/provider/openai-compatible-chat.test.ts
+++ b/packages/llm/test/provider/openai-compatible-chat.test.ts
@@ -231,7 +231,7 @@ describe("OpenAI-compatible Chat route", () => {

       expect(response.text).toBe("Hello!")
       expect(response.usage).toMatchObject({ inputTokens: 5, outputTokens: 2, totalTokens: 7 })
-      expect(response.events.at(-1)).toMatchObject({ type: "request-finish", reason: "stop" })
+      expect(response.events.at(-1)).toMatchObject({ type: "finish", reason: "stop" })
     }),
   )
 })
diff --git a/packages/llm/test/provider/openai-responses.test.ts b/packages/llm/test/provider/openai-responses.test.ts
index 8b4469f4ed..63452f61b0 100644
--- a/packages/llm/test/provider/openai-responses.test.ts
+++ b/packages/llm/test/provider/openai-responses.test.ts
@@ -366,7 +366,7 @@ describe("OpenAI Responses route", () => {
           usage,
         },
         {
-          type: "request-finish",
+          type: "finish",
           reason: "stop",
           providerMetadata: { openai: { responseId: "resp_1", serviceTier: "default" } },
           usage,
@@ -447,7 +447,7 @@
         },
         { type: "step-finish", index: 0, reason: "tool-calls", usage, providerMetadata: undefined },
         {
-          type: "request-finish",
+          type: "finish",
           reason: "tool-calls",
           providerMetadata: undefined,
           usage,
diff --git a/packages/llm/test/recorded-scenarios.ts b/packages/llm/test/recorded-scenarios.ts
index bdba8580fd..3af7a77608 100644
--- a/packages/llm/test/recorded-scenarios.ts
+++ b/packages/llm/test/recorded-scenarios.ts
@@ -120,8 +120,8 @@ export const runWeatherToolLoop = (request: LLMRequest) =>

 export const expectFinish = (
   events: ReadonlyArray<LLMEvent>,
-  reason: Extract<LLMEvent, { type: "request-finish" }>["reason"],
-) => expect(events.at(-1)).toMatchObject({ type: "request-finish", reason })
+  reason: Extract<LLMEvent, { type: "finish" }>["reason"],
+) => expect(events.at(-1)).toMatchObject({ type: "finish", reason })

 export const expectWeatherToolCall = (response: LLMResponse) =>
   expect(response.toolCalls).toMatchObject([
@@ -129,10 +129,12 @@
   ])

 export const expectWeatherToolLoop = (events: ReadonlyArray<LLMEvent>) => {
-  const finishes = events.filter(LLMEvent.is.requestFinish)
-  expect(finishes).toHaveLength(2)
-  expect(finishes[0]?.reason).toBe("tool-calls")
-  expect(finishes.at(-1)?.reason).toBe("stop")
+  const finishes = events.filter(LLMEvent.is.finish)
+  expect(finishes).toHaveLength(1)
+  expect(finishes[0]?.reason).toBe("stop")
+
+  const stepFinishes = events.filter(LLMEvent.is.stepFinish)
+  expect(stepFinishes.map((event) => event.reason)).toEqual(["tool-calls", "stop"])

   const toolCalls = events.filter(LLMEvent.is.toolCall)
   expect(toolCalls).toHaveLength(1)
@@ -272,7 +274,7 @@ export const eventSummary = (events: ReadonlyArray<LLMEvent>) => {
       summary.push({ type: "tool-error", name: event.name, message: event.message })
       continue
     }
-    if (event.type === "request-finish") {
+    if (event.type === "finish") {
       summary.push({ type: "finish", reason: event.reason, usage: usageSummary(event.usage) })
     }
   }
diff --git a/packages/llm/test/schema.test.ts b/packages/llm/test/schema.test.ts
index 23bd9fd9bb..01d6fadd9f 100644
--- a/packages/llm/test/schema.test.ts
+++ b/packages/llm/test/schema.test.ts
@@ -44,6 +44,11 @@ describe("llm schema", () => {
     expect(() => Schema.decodeUnknownSync(LLMEvent)({ type: "bogus" })).toThrow()
   })

+  test("finish constructors accept usage input", () => {
+    expect(LLMEvent.stepFinish({ index: 0, reason: "stop", usage: { inputTokens: 1 } }).usage).toBeInstanceOf(Usage)
+    expect(LLMEvent.finish({ reason: "stop", usage: { outputTokens: 2 } }).usage).toBeInstanceOf(Usage)
+  })
+
   test("content part tagged union exposes guards", () => {
     expect(ContentPart.guards.text({ type: "text", text: "hi" })).toBe(true)
     expect(ContentPart.guards.media({ type: "text", text: "hi" })).toBe(false)
diff --git a/packages/llm/test/tool-runtime.test.ts b/packages/llm/test/tool-runtime.test.ts
index 040a11fb68..573021c4c2 100644
--- a/packages/llm/test/tool-runtime.test.ts
+++ b/packages/llm/test/tool-runtime.test.ts
@@ -4,7 +4,8 @@ import { GenerationOptions, LLM, LLMEvent, LLMRequest, LLMResponse, ToolChoice }
 import { LLMClient } from "../src/route"
 import * as AnthropicMessages from "../src/protocols/anthropic-messages"
 import * as OpenAIChat from "../src/protocols/openai-chat"
-import { tool, ToolFailure } from "../src/tool"
+import { tool, ToolFailure, type ToolExecuteContext } from "../src/tool"
+import { ToolRuntime } from "../src/tool-runtime"
 import { it } from "./lib/effect"
 import * as TestToolRuntime from "./lib/tool-runtime"
 import { dynamicResponse, scriptedResponses } from "./lib/http"
@@ -129,7 +130,7 @@ describe("LLMClient tools", () => {
         name: "get_weather",
         result: { type: "json", value: { temperature: 22, condition: "sunny" } },
       })
-      expect(events.at(-1)?.type).toBe("request-finish")
+      expect(events.at(-1)?.type).toBe("finish")
       expect(LLMResponse.text({ events })).toBe("It's sunny in Paris.")
     }),
   )
@@ -148,11 +149,40 @@
         ),
       )

-      expect(events.filter(LLMEvent.is.requestFinish)).toHaveLength(1)
+      expect(events.filter(LLMEvent.is.finish)).toHaveLength(1)
       expect(events.find(LLMEvent.is.toolResult)).toMatchObject({ type: "tool-result", id: "call_1" })
     }),
   )

+  it.effect("passes tool call context to execute", () =>
+    Effect.gen(function* () {
+      let context: ToolExecuteContext | undefined
+      const contextual = tool({
+        description: "Capture tool context.",
+        parameters: Schema.Struct({ value: Schema.String }),
+        success: Schema.Struct({ ok: Schema.Boolean }),
+        execute: (_params, ctx) =>
+          Effect.sync(() => {
+            context = ctx
+            return { ok: true }
+          }),
+      })
+      const events = Array.from(
+        yield* TestToolRuntime.runTools({ request: baseRequest, tools: { contextual } }).pipe(
+          Stream.runCollect,
+          Effect.provide(
+            scriptedResponses([
+              sseEvents(toolCallChunk("call_ctx", "contextual", '{"value":"x"}'), finishChunk("tool_calls")),
+            ]),
+          ),
+        ),
+      )
+
+      expect(events.some(LLMEvent.is.toolResult)).toBe(true)
+      expect(context).toEqual({ id: "call_ctx", name: "contextual" })
+    }),
+  )
+
   it.effect("can expose tool schemas without executing tool calls", () =>
     Effect.gen(function* () {
       const layer = scriptedResponses([
@@ -319,7 +349,7 @@
         "text-delta",
         "text-end",
         "step-finish",
-        "request-finish",
+        "finish",
       ])
       expect(LLMResponse.text({ events })).toBe("Done.")
     }),
@@ -343,7 +373,57 @@
         ),
       )

-      expect(events.filter(LLMEvent.is.requestFinish)).toHaveLength(2)
+      expect(events.filter(LLMEvent.is.finish)).toHaveLength(1)
+      expect(events.filter(LLMEvent.is.stepStart).map((event) => event.index)).toEqual([0, 1])
+      expect(events.filter(LLMEvent.is.stepFinish).map((event) => event.index)).toEqual([0, 1])
+    }),
+  )
+
+  it.effect("emits one final finish with aggregate usage", () =>
+    Effect.gen(function* () {
+      let calls = 0
+      const events = Array.from(
+        yield* ToolRuntime.stream({
+          request: baseRequest,
+          tools: { get_weather },
+          stopWhen: ToolRuntime.stepCountIs(2),
+          stream: () =>
+            Stream.fromIterable(
+              calls++ === 0
+                ? [
+                    LLMEvent.stepStart({ index: 0 }),
+                    LLMEvent.toolCall({ id: "call_1", name: "get_weather", input: { city: "Paris" } }),
+                    LLMEvent.stepFinish({
+                      index: 0,
+                      reason: "tool-calls",
+                      usage: { inputTokens: 1, outputTokens: 2, totalTokens: 3 },
+                    }),
+                    LLMEvent.finish({
+                      reason: "tool-calls",
+                      usage: { inputTokens: 1, outputTokens: 2, totalTokens: 3 },
+                    }),
+                  ]
+                : [
+                    LLMEvent.stepStart({ index: 0 }),
+                    LLMEvent.textDelta({ id: "text_1", text: "Done." }),
+                    LLMEvent.stepFinish({
+                      index: 0,
+                      reason: "stop",
+                      usage: { inputTokens: 4, outputTokens: 5, totalTokens: 9 },
+                    }),
+                    LLMEvent.finish({ reason: "stop", usage: { inputTokens: 4, outputTokens: 5, totalTokens: 9 } }),
+                  ],
+            ),
+        }).pipe(Stream.runCollect),
+      )

+      expect(events.filter(LLMEvent.is.stepFinish).map((event) => event.index)).toEqual([0, 1])
+      expect(events.filter(LLMEvent.is.finish)).toHaveLength(1)
+      expect(events.find(LLMEvent.is.finish)?.usage).toMatchObject({
+        inputTokens: 5,
+        outputTokens: 7,
+        totalTokens: 12,
+      })
     }),
   )

@@ -362,7 +442,7 @@
         }).pipe(Stream.runCollect, Effect.provide(layer)),
       )

-      expect(events.filter(LLMEvent.is.requestFinish)).toHaveLength(1)
+      expect(events.filter(LLMEvent.is.finish)).toHaveLength(1)
       expect(events.find(LLMEvent.is.toolResult)).toMatchObject({ type: "tool-result", id: "call_1" })
     }),
   )
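Editor's note, not part of the applied diff: after this change a multi-step tool loop emits `step-start`/`step-finish` pairs with re-indexed `index` values and exactly one terminal `finish` event carrying the summed usage. A minimal consumer sketch under those assumptions; the package import path is hypothetical, and `request` is assumed to be built as in example/tutorial.ts:

    import { Stream } from "effect"
    import { LLM, LLMEvent } from "@your-org/llm" // hypothetical import path

    // Fold the event stream into a step count plus the single aggregate finish reason.
    const summary = LLM.stream(request).pipe(
      Stream.runFold(
        { steps: 0, reason: undefined as string | undefined },
        (acc, event) =>
          LLMEvent.is.stepFinish(event)
            ? { ...acc, steps: acc.steps + 1 }
            : LLMEvent.is.finish(event)
              ? { ...acc, reason: event.reason }
              : acc,
      ),
    )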
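Also illustrative: the optional second argument to `execute` is how the provider's native tool-call context (the call id and tool name) reaches handlers, which is what the subject line preserves. A sketch using the same `tool` config shape as the new "passes tool call context to execute" test; again the import path is assumed:

    import { Effect, Schema } from "effect"
    import { tool, type ToolExecuteContext } from "@your-org/llm" // hypothetical import path

    const get_time = tool({
      description: "Report the server time and echo the native call id.",
      parameters: Schema.Struct({ zone: Schema.String }),
      success: Schema.Struct({ note: Schema.String }),
      // `ctx` is optional: it is only supplied when the runtime dispatches a model tool call.
      execute: (params, ctx?: ToolExecuteContext) =>
        Effect.sync(() => ({ note: `${params.zone} @ ${Date.now()} (call ${ctx?.id ?? "direct"})` })),
    })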
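Finally, the arithmetic behind "emits one final finish with aggregate usage": `addUsage` sums field by field but keeps a field undefined when neither side reports it, so absent metrics are never coerced to zero. Step usages of { inputTokens: 1, outputTokens: 2, totalTokens: 3 } and { inputTokens: 4, outputTokens: 5, totalTokens: 9 } therefore aggregate to { inputTokens: 5, outputTokens: 7, totalTokens: 12 }, matching the test. A standalone sketch of that rule:

    // Mirrors `sum` inside `addUsage`: undefined on both sides stays undefined,
    // otherwise a missing side counts as 0.
    const sum = (a?: number, b?: number) =>
      a === undefined && b === undefined ? undefined : (a ?? 0) + (b ?? 0)

    sum(1, 4) // 5
    sum(undefined, 9) // 9
    sum(undefined, undefined) // undefined, e.g. reasoningTokens never reported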