feat(core): sort session diffs with small model

This commit is contained in:
Adam
2026-01-05 15:44:16 -06:00
parent 48d14d4cac
commit 7139e4d075

View File

@@ -17,8 +17,27 @@ import { Bus } from "@/bus"
import { LLM } from "./llm"
import { Agent } from "@/agent/agent"
const DIFF_ORDER_PROMPT = `You help order file diffs for code review.
Given a list of changed files, return the best order to review them so a developer can understand the changes quickly.
Guidelines:
- Prefer dependency order (types/config before usage)
- Group related files together
- Put core logic and shared utilities before UI
- Put tests and docs after code
Rules:
- Output ONLY file paths, one per line
- Use ONLY the file paths provided in the input
- Include every provided file path exactly once
- No numbering, bullets, headings, or commentary
`
export namespace SessionSummary {
const log = Log.create({ service: "session.summary" })
const seq = new Map<string, number>()
const aborts = new Map<string, AbortController>()
export const summarize = fn(
z.object({
@@ -35,6 +54,14 @@ export namespace SessionSummary {
)
async function summarizeSession(input: { sessionID: string; messages: MessageV2.WithParts[] }) {
const id = input.sessionID
const version = (seq.get(id) ?? 0) + 1
seq.set(id, version)
const ctrl = aborts.get(id)
if (ctrl) ctrl.abort()
aborts.delete(id)
const files = new Set(
input.messages
.flatMap((x) => x.parts)
@@ -42,25 +69,171 @@ export namespace SessionSummary {
.flatMap((x) => x.files)
.map((x) => path.relative(Instance.worktree, x)),
)
const diffs = await computeDiff({ messages: input.messages }).then((x) =>
x.filter((x) => {
return files.has(x.file)
}),
)
await Session.update(input.sessionID, (draft) => {
const diffs = await computeDiff({ messages: input.messages }).then((x) => x.filter((x) => files.has(x.file)))
if (seq.get(id) !== version) return
const ordered = await orderDiffs({ sessionID: id, diffs, messages: input.messages })
if (seq.get(id) !== version) return
const safe = (n: number) => (Number.isFinite(n) ? n : 0)
await Session.update(id, (draft) => {
draft.summary = {
additions: diffs.reduce((sum, x) => sum + x.additions, 0),
deletions: diffs.reduce((sum, x) => sum + x.deletions, 0),
files: diffs.length,
additions: ordered.reduce((sum, x) => sum + safe(x.additions), 0),
deletions: ordered.reduce((sum, x) => sum + safe(x.deletions), 0),
files: ordered.length,
}
})
await Storage.write(["session_diff", input.sessionID], diffs)
await Storage.write(["session_diff", id], ordered)
Bus.publish(Session.Event.Diff, {
sessionID: input.sessionID,
diff: diffs,
sessionID: id,
diff: ordered,
})
}
async function orderDiffs(input: { sessionID: string; diffs: Snapshot.FileDiff[]; messages: MessageV2.WithParts[] }) {
if (input.diffs.length <= 1) return input.diffs
const safe = (n: number) => (Number.isFinite(n) ? n : 0)
const sig = (diffs: Snapshot.FileDiff[]) =>
diffs
.map((d) => `${d.file}:${safe(d.additions)}:${safe(d.deletions)}:${d.before.length}:${d.after.length}`)
.sort()
.join("\n")
const prev = await Storage.read<Snapshot.FileDiff[]>(["session_diff", input.sessionID]).catch(() => [])
const map = new Map(input.diffs.map((d) => [d.file, d]))
const cached = prev.map((d) => map.get(d.file)).filter((d): d is Snapshot.FileDiff => !!d)
const stable = cached.length === input.diffs.length && cached.length === prev.length ? cached : undefined
if (stable && sig(prev) === sig(input.diffs)) return stable
const fallback = stable ?? input.diffs
const user = input.messages
.slice()
.reverse()
.find((m) => m.info.role === "user")?.info as MessageV2.User | undefined
if (!user) return fallback
const model = await sortModel(input.messages)
if (!model) return fallback
const base = await Agent.get("summary")
const agent: Agent.Info = {
...base,
name: "diff-order",
prompt: DIFF_ORDER_PROMPT,
temperature: 0.2,
}
const items = input.diffs
.map((d) => {
const additions = safe(d.additions)
const deletions = safe(d.deletions)
const ext = path.extname(d.file) || "none"
const kind =
d.before === "" && d.after === ""
? "binary"
: d.before === ""
? "added"
: d.after === ""
? "deleted"
: "modified"
return `${d.file}\t${kind}\t+${additions}\t-${deletions}\text:${ext}`
})
.join("\n")
const abort = new AbortController()
aborts.set(input.sessionID, abort)
const timer = setTimeout(() => abort.abort(), 8000)
const clean = () => {
clearTimeout(timer)
if (aborts.get(input.sessionID) === abort) aborts.delete(input.sessionID)
}
const stream = await LLM.stream({
agent,
user,
tools: {},
model,
small: true,
messages: [
{
role: "user" as const,
content: `Order these files for review.\n\nFiles (tab-separated: path\tkind\t+adds\t-dels\text:ext):\n${items}`,
},
],
abort: abort.signal,
sessionID: user.sessionID,
system: [],
retries: 1,
}).catch(() => undefined)
if (!stream) {
clean()
return fallback
}
const text = await stream.text.catch(() => "").finally(clean)
const files = new Set(input.diffs.map((d) => d.file))
const seen = new Set<string>()
const order = text
.split("\n")
.map((line) => line.trim())
.filter(Boolean)
.map((line) =>
line
.replace(/^[-*]\s+/, "")
.replace(/^\d+[.)]\s+/, "")
.replace(/^['"`]/, "")
.replace(/['"`,]$/, "")
.trim(),
)
.map((line) => (line.includes("\t") ? (line.split("\t")[0] ?? "") : line).trim())
.filter((line) => files.has(line))
.filter((line) => {
if (seen.has(line)) return false
seen.add(line)
return true
})
if (order.length === 0) return fallback
const sorted = order.map((file) => map.get(file)).filter((d): d is Snapshot.FileDiff => !!d)
const rest = input.diffs.filter((d) => !seen.has(d.file))
const result = [...sorted, ...rest]
if (result.length !== input.diffs.length) return fallback
log.debug("diff order", {
sessionID: input.sessionID,
ordered: result.map((d) => d.file),
})
return result
}
async function sortModel(messages: MessageV2.WithParts[]) {
const assistant = messages
.slice()
.reverse()
.find((m) => m.info.role === "assistant")?.info as MessageV2.Assistant | undefined
if (assistant) {
const small = await Provider.getSmallModel(assistant.providerID).catch(() => undefined)
if (small) return small
return Provider.getModel(assistant.providerID, assistant.modelID).catch(() => undefined)
}
const defaultModel = await Provider.defaultModel().catch(() => undefined)
if (!defaultModel) return undefined
const small = await Provider.getSmallModel(defaultModel.providerID).catch(() => undefined)
if (small) return small
return Provider.getModel(defaultModel.providerID, defaultModel.modelID).catch(() => undefined)
}
async function summarizeMessage(input: { messageID: string; messages: MessageV2.WithParts[] }) {
const messages = input.messages.filter(
(m) => m.info.id === input.messageID || (m.info.role === "assistant" && m.info.parentID === input.messageID),