Compare commits

...

1 Commit

Author SHA1 Message Date
Kit Langton
73ea130d16 fix(session): preserve tool metadata during pending transition 2026-03-31 15:07:31 -04:00
6 changed files with 218 additions and 38 deletions

View File

@@ -1,7 +1,6 @@
import { Provider } from "@/provider/provider"
import { Log } from "@/util/log"
import { Cause, Effect, Layer, Record, ServiceMap } from "effect"
import * as Queue from "effect/Queue"
import { Effect, Layer, Record, ServiceMap } from "effect"
import * as Stream from "effect/Stream"
import { streamText, wrapLanguageModel, type ModelMessage, type Tool, tool, jsonSchema } from "ai"
import { mergeDeep, pipe } from "remeda"
@@ -60,21 +59,8 @@ export namespace LLM {
Effect.sync(() => new AbortController()),
(ctrl) => Effect.sync(() => ctrl.abort()),
)
const queue = yield* Queue.unbounded<Event, unknown | Cause.Done>()
yield* Effect.promise(async () => {
const result = await LLM.stream({ ...input, abort: ctrl.signal })
for await (const event of result.fullStream) {
if (!Queue.offerUnsafe(queue, event)) break
}
Queue.endUnsafe(queue)
}).pipe(
Effect.catchCause((cause) => Effect.sync(() => void Queue.failCauseUnsafe(queue, cause))),
Effect.onInterrupt(() => Effect.sync(() => ctrl.abort())),
Effect.forkScoped,
)
return Stream.fromQueue(queue)
const result = yield* Effect.promise(() => LLM.stream({ ...input, abort: ctrl.signal }))
return Stream.fromAsyncIterable(result.fullStream, (err) => err)
}),
),
)

View File

@@ -30,6 +30,10 @@ export namespace SessionProcessor {
// Handle exposed by a session processor run: read access to the assistant
// message, tool-part lookup by tool call ID, metadata patching, and
// lifecycle control (abort / process).
export interface Handle {
readonly message: MessageV2.Assistant
readonly partFromToolCall: (toolCallID: string) => MessageV2.ToolPart | undefined
// Attach or refresh a tool call's title/metadata. Safe to call before the
// part reaches "running" — values are buffered and merged in when the call
// transitions to running (see the processor's toolmeta handling).
readonly metadata: (
toolCallID: string,
input: { title?: string; metadata?: Record<string, any> },
) => Effect.Effect<void>
readonly abort: () => Effect.Effect<void>
readonly process: (streamInput: LLM.StreamInput) => Effect.Effect<Result>
}
@@ -46,6 +50,7 @@ export namespace SessionProcessor {
interface ProcessorContext extends Input {
toolcalls: Record<string, MessageV2.ToolPart>
toolmeta: Record<string, { title?: string; metadata?: Record<string, any> }>
shouldBreak: boolean
snapshot: string | undefined
blocked: boolean
@@ -89,6 +94,7 @@ export namespace SessionProcessor {
sessionID: input.sessionID,
model: input.model,
toolcalls: {},
toolmeta: {},
shouldBreak: false,
snapshot: undefined,
blocked: false,
@@ -172,13 +178,21 @@ export namespace SessionProcessor {
throw new Error(`Tool call not allowed while generating summary: ${value.toolName}`)
}
const match = ctx.toolcalls[value.toolCallId]
const meta = ctx.toolmeta[value.toolCallId]
if (!match) return
ctx.toolcalls[value.toolCallId] = yield* session.updatePart({
...match,
tool: value.toolName,
state: { status: "running", input: value.input, time: { start: Date.now() } },
state: {
status: "running",
input: value.input,
title: meta?.title,
metadata: meta?.metadata,
time: { start: Date.now() },
},
metadata: value.providerMetadata,
} satisfies MessageV2.ToolPart)
delete ctx.toolmeta[value.toolCallId]
const parts = yield* Effect.promise(() => MessageV2.parts(ctx.assistantMessage.id))
const recentParts = parts.slice(-DOOM_LOOP_THRESHOLD)
@@ -224,6 +238,7 @@ export namespace SessionProcessor {
},
})
delete ctx.toolcalls[value.toolCallId]
delete ctx.toolmeta[value.toolCallId]
return
}
@@ -243,6 +258,7 @@ export namespace SessionProcessor {
ctx.blocked = ctx.shouldBreak
}
delete ctx.toolcalls[value.toolCallId]
delete ctx.toolmeta[value.toolCallId]
return
}
@@ -494,6 +510,24 @@ export namespace SessionProcessor {
partFromToolCall(toolCallID: string) {
return ctx.toolcalls[toolCallID]
},
// Update a tool part's title/metadata by tool call ID. If the part is not
// yet in the "running" state, stash the values in ctx.toolmeta so the
// tool-call handler can merge them when the part transitions to running;
// otherwise patch the running part in place.
metadata: Effect.fn("SessionProcessor.metadata")(function* (toolCallID, input) {
const match = ctx.toolcalls[toolCallID]
if (!match || match.state.status !== "running") {
// Buffer for later: merged into the part when the tool call starts running.
ctx.toolmeta[toolCallID] = {
...ctx.toolmeta[toolCallID],
...input,
}
return
}
ctx.toolcalls[toolCallID] = yield* session.updatePart({
...match,
state: {
...match.state,
// ?? keeps previously persisted values when the caller omits a field
title: input.title ?? match.state.title,
metadata: input.metadata ?? match.state.metadata,
},
})
}),
abort,
process,
} satisfies Handle

View File

@@ -384,7 +384,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
model: Provider.Model
session: Session.Info
tools?: Record<string, boolean>
processor: Pick<SessionProcessor.Handle, "message" | "partFromToolCall">
processor: Pick<SessionProcessor.Handle, "message" | "partFromToolCall" | "metadata">
bypassAgentCheck: boolean
messages: MessageV2.WithParts[]
}) {
@@ -399,23 +399,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
extra: { model: input.model, bypassAgentCheck: input.bypassAgentCheck },
agent: input.agent.name,
messages: input.messages,
metadata: (val) =>
Effect.runPromise(
Effect.gen(function* () {
const match = input.processor.partFromToolCall(options.toolCallId)
if (!match || !["running", "pending"].includes(match.state.status)) return
yield* sessions.updatePart({
...match,
state: {
title: val.title,
metadata: val.metadata,
status: "running",
input: args,
time: { start: Date.now() },
},
})
}),
),
metadata: (val) => Effect.runPromise(input.processor.metadata(options.toolCallId, val)),
ask: (req) =>
Effect.runPromise(
permission.ask({

View File

@@ -149,6 +149,7 @@ function fake(
state: { status: "pending", input: {}, raw: "" },
}
},
metadata: Effect.fn("TestSessionProcessor.metadata")(() => Effect.void),
process: Effect.fn("TestSessionProcessor.process")(() => Effect.succeed(result)),
} satisfies SessionProcessorModule.SessionProcessor.Handle
}

View File

@@ -1,7 +1,7 @@
import { afterAll, beforeAll, beforeEach, describe, expect, test } from "bun:test"
import { afterAll, beforeAll, beforeEach, describe, expect, spyOn, test } from "bun:test"
import path from "path"
import { tool, type ModelMessage } from "ai"
import { Cause, Exit, Stream } from "effect"
import { Cause, Effect, Exit, Stream } from "effect"
import z from "zod"
import { makeRuntime } from "../../src/effect/run-service"
import { LLM } from "../../src/session/llm"
@@ -541,6 +541,94 @@ describe("session.llm.stream", () => {
})
})
// Regression test: the service stream must pull from the provider's
// fullStream lazily (with backpressure) rather than eagerly draining it
// into a buffer. The mocked iterator blocks on `release` before yielding
// its second event; if the stream were eagerly drained, `pulled` would be
// true before the consumer ever observed the first "start" event.
test("service stream preserves fullStream backpressure", async () => {
const release = deferred<void>()
let pulled = false
// Mock LLM.stream with a hand-rolled async iterator so we control exactly
// when each event becomes available.
const mock = spyOn(LLM, "stream").mockResolvedValue({
fullStream: {
[Symbol.asyncIterator]() {
let i = 0
return {
next: async () => {
if (i === 0) {
i += 1
return { done: false, value: { type: "start" } as LLM.Event }
}
if (i === 1) {
// Mark that the second event was requested, then park until the
// consumer has seen "start" and resolved `release`.
pulled = true
await release.promise
i += 1
return {
done: false,
value: {
type: "finish",
finishReason: "stop",
rawFinishReason: "stop",
totalUsage: {
inputTokens: 0,
outputTokens: 0,
totalTokens: 0,
},
} as LLM.Event,
}
}
return { done: true, value: undefined }
},
return: async () => ({ done: true, value: undefined }),
}
},
},
} as Awaited<ReturnType<typeof LLM.stream>>)
await using tmp = await tmpdir()
try {
await Instance.provide({
directory: tmp.path,
fn: async () => {
const sessionID = SessionID.make("session-test-service-backpressure")
const { runPromise } = makeRuntime(LLM.Service, LLM.defaultLayer)
await runPromise((svc) =>
svc
.stream({
user: {
id: MessageID.make("user-service-backpressure"),
sessionID,
role: "user",
time: { created: Date.now() },
agent: "test",
model: { providerID: ProviderID.make("test"), modelID: ModelID.make("test") },
} satisfies MessageV2.User,
sessionID,
model: {} as Provider.Model,
agent: {
name: "test",
mode: "primary",
options: {},
permission: [{ permission: "*", pattern: "*", action: "allow" }],
} satisfies Agent.Info,
system: [],
messages: [],
tools: {},
})
.pipe(
Stream.tap((event) =>
// On the first event, assert the iterator has NOT yet been asked
// for event #2 — proof of backpressure — then unblock it.
event.type === "start"
? Effect.sync(() => {
expect(pulled).toBe(false)
release.resolve()
})
: Effect.void,
),
Stream.runDrain,
),
)
},
})
} finally {
mock.mockRestore()
}
})
test("keeps tools enabled by prompt permissions", async () => {
const server = state.server
if (!server) {

View File

@@ -532,6 +532,93 @@ it.effect("failed subtask preserves metadata on error tool state", () =>
),
)
// Verifies that metadata reported by a still-running task tool (via
// ctx.metadata) is persisted onto the tool part once it transitions to
// "running", instead of being dropped. The fake tool call never completes
// (the stream ends in Effect.never), so we poll the stored message parts
// until a running tool part appears and assert its metadata survived.
it.effect(
"task tool preserves session metadata while still running",
() =>
provideTmpdirInstance(
(dir) =>
Effect.gen(function* () {
const child = SessionID.make("task-child")
// Replace the real task tool with one whose execute() reports
// metadata immediately and then returns; the surrounding stream keeps
// the tool call in-flight.
const init = spyOn(TaskTool, "init").mockResolvedValue({
description: "task",
parameters: z.object({
description: z.string(),
prompt: z.string(),
subagent_type: z.string(),
task_id: z.string().optional(),
command: z.string().optional(),
}),
execute: async (_args, ctx) => {
ctx.metadata({
title: "inspect bug",
metadata: {
sessionId: child,
model: ref,
},
})
return {
title: "inspect bug",
metadata: {
sessionId: child,
model: ref,
},
output: "",
}
},
})
yield* Effect.addFinalizer(() => Effect.sync(() => init.mockRestore()))
const { test, prompt, chat } = yield* boot({ title: "Pinned" })
yield* test.push((input) => {
const args = {
description: "inspect bug",
prompt: "look into the cache key path",
subagent_type: "general",
}
const exec = input.tools.task?.execute
if (!exec) throw new Error("task tool missing execute")
// Emit start + tool-input-start, kick off execute() without awaiting
// it, emit the tool call, then hang the stream so the call stays open.
return stream(start(), toolInputStart("task-1", "task")).pipe(
Stream.concat(
Stream.fromEffect(
Effect.promise(async () => {
void exec(args, {
toolCallId: "task-1",
abortSignal: new AbortController().signal,
messages: input.messages,
})
return toolCall("task-1", "task", args)
}),
),
),
Stream.concat(Stream.fromEffect(Effect.never)),
)
})
yield* user(chat.id, "launch a subagent")
const fiber = yield* prompt.loop({ sessionID: chat.id }).pipe(Effect.forkChild)
// Poll (up to ~2s) for the assistant message's tool part to reach
// "running" — the state where buffered metadata should be applied.
const tool = yield* Effect.promise(async () => {
const end = Date.now() + 2_000
for (;;) {
const msgs = await MessageV2.filterCompacted(MessageV2.stream(chat.id))
const msg = msgs.findLast((item) => item.info.role === "assistant")
const part = msg?.parts.find((item): item is MessageV2.ToolPart => item.type === "tool")
if (part?.state.status === "running") return part
if (Date.now() > end) throw new Error("timed out waiting for running task tool")
await Bun.sleep(10)
}
})
if (tool.state.status !== "running") throw new Error("expected running task tool")
expect(tool.state.metadata?.sessionId).toBe(child)
yield* Fiber.interrupt(fiber)
}),
{ git: true, config: cfg },
),
30_000,
)
it.effect("loop sets status to busy then idle", () =>
provideTmpdirInstance(
(dir) =>