Compare commits

...

4 Commits

Author SHA1 Message Date
Ryan Vogel
c92351ec52 Delete packages/console/app/src/routes/zen/util/handler.ts 2026-02-12 09:48:09 -05:00
Ryan Vogel
4bd25e0e33 feat(tool): return image attachments from webfetch 2026-02-12 09:45:37 -05:00
Frank
798f866d4c wip: zen 2026-02-12 09:28:53 -05:00
opencode-agent[bot]
838c9968ec chore: generate 2026-02-12 14:28:17 +00:00
5 changed files with 129 additions and 849 deletions

View File

@@ -1,839 +0,0 @@
import type { APIEvent } from "@solidjs/start/server"
import { and, Database, eq, isNull, lt, or, sql } from "@opencode-ai/console-core/drizzle/index.js"
import { KeyTable } from "@opencode-ai/console-core/schema/key.sql.js"
import { BillingTable, SubscriptionTable, UsageTable } from "@opencode-ai/console-core/schema/billing.sql.js"
import { centsToMicroCents } from "@opencode-ai/console-core/util/price.js"
import { getWeekBounds } from "@opencode-ai/console-core/util/date.js"
import { Identifier } from "@opencode-ai/console-core/identifier.js"
import { Billing } from "@opencode-ai/console-core/billing.js"
import { Actor } from "@opencode-ai/console-core/actor.js"
import { WorkspaceTable } from "@opencode-ai/console-core/schema/workspace.sql.js"
import { ZenData } from "@opencode-ai/console-core/model.js"
import { Black, BlackData } from "@opencode-ai/console-core/black.js"
import { UserTable } from "@opencode-ai/console-core/schema/user.sql.js"
import { ModelTable } from "@opencode-ai/console-core/schema/model.sql.js"
import { ProviderTable } from "@opencode-ai/console-core/schema/provider.sql.js"
import { logger } from "./logger"
import {
AuthError,
CreditsError,
MonthlyLimitError,
UserLimitError,
ModelError,
FreeUsageLimitError,
SubscriptionUsageLimitError,
} from "./error"
import { createBodyConverter, createStreamPartConverter, createResponseConverter, UsageInfo } from "./provider/provider"
import { anthropicHelper } from "./provider/anthropic"
import { googleHelper } from "./provider/google"
import { openaiHelper } from "./provider/openai"
import { oaCompatHelper } from "./provider/openai-compatible"
import { createRateLimiter } from "./rateLimiter"
import { createDataDumper } from "./dataDumper"
import { createTrialLimiter } from "./trialLimiter"
import { createStickyTracker } from "./stickyProviderTracker"
type ZenData = Awaited<ReturnType<typeof ZenData.list>>
type RetryOptions = {
excludeProviders: string[]
retryCount: number
}
type BillingSource = "anonymous" | "free" | "byok" | "subscription" | "balance"
export async function handler(
input: APIEvent,
opts: {
format: ZenData.Format
parseApiKey: (headers: Headers) => string | undefined
parseModel: (url: string, body: any) => string
parseIsStream: (url: string, body: any) => boolean
},
) {
type AuthInfo = Awaited<ReturnType<typeof authenticate>>
type ModelInfo = Awaited<ReturnType<typeof validateModel>>
type ProviderInfo = Awaited<ReturnType<typeof selectProvider>>
type CostInfo = ReturnType<typeof calculateCost>
const MAX_FAILOVER_RETRIES = 3
const MAX_429_RETRIES = 3
const FREE_WORKSPACES = [
"wrk_01K46JDFR0E75SG2Q8K172KF3Y", // frank
"wrk_01K6W1A3VE0KMNVSCQT43BG2SX", // opencode bench
]
try {
const url = input.request.url
const body = await input.request.json()
const model = opts.parseModel(url, body)
const isStream = opts.parseIsStream(url, body)
const ip = input.request.headers.get("x-real-ip") ?? ""
const sessionId = input.request.headers.get("x-opencode-session") ?? ""
const requestId = input.request.headers.get("x-opencode-request") ?? ""
const projectId = input.request.headers.get("x-opencode-project") ?? ""
const ocClient = input.request.headers.get("x-opencode-client") ?? ""
logger.metric({
is_tream: isStream,
session: sessionId,
request: requestId,
client: ocClient,
})
const zenData = ZenData.list()
const modelInfo = validateModel(zenData, model)
const dataDumper = createDataDumper(sessionId, requestId, projectId)
const trialLimiter = createTrialLimiter(modelInfo.trial, ip, ocClient)
const isTrial = await trialLimiter?.isTrial()
const rateLimiter = createRateLimiter(modelInfo.rateLimit, ip, input.request.headers)
await rateLimiter?.check()
const stickyTracker = createStickyTracker(modelInfo.stickyProvider, sessionId)
const stickyProvider = await stickyTracker?.get()
const authInfo = await authenticate(modelInfo)
const billingSource = validateBilling(authInfo, modelInfo)
const retriableRequest = async (retry: RetryOptions = { excludeProviders: [], retryCount: 0 }) => {
const providerInfo = selectProvider(
model,
zenData,
authInfo,
modelInfo,
sessionId,
isTrial ?? false,
retry,
stickyProvider,
)
validateModelSettings(authInfo)
updateProviderKey(authInfo, providerInfo)
logger.metric({ provider: providerInfo.id })
const startTimestamp = Date.now()
const reqUrl = providerInfo.modifyUrl(providerInfo.api, isStream)
const reqBody = JSON.stringify(
providerInfo.modifyBody({
...createBodyConverter(opts.format, providerInfo.format)(body),
model: providerInfo.model,
}),
)
logger.debug("REQUEST URL: " + reqUrl)
logger.debug("REQUEST: " + reqBody.substring(0, 300) + "...")
const res = await fetchWith429Retry(reqUrl, {
method: "POST",
headers: (() => {
const headers = new Headers(input.request.headers)
providerInfo.modifyHeaders(headers, body, providerInfo.apiKey)
Object.entries(providerInfo.headerMappings ?? {}).forEach(([k, v]) => {
headers.set(k, headers.get(v)!)
})
Object.entries(providerInfo.headers ?? {}).forEach(([k, v]) => {
headers.set(k, v)
})
headers.delete("host")
headers.delete("content-length")
headers.delete("x-opencode-request")
headers.delete("x-opencode-session")
headers.delete("x-opencode-project")
headers.delete("x-opencode-client")
return headers
})(),
body: reqBody,
})
if (res.status !== 200) {
logger.metric({
"llm.error.code": res.status,
"llm.error.message": res.statusText,
})
}
// Try another provider => stop retrying if using fallback provider
if (
res.status !== 200 &&
// ie. openai 404 error: Item with id 'msg_0ead8b004a3b165d0069436a6b6834819896da85b63b196a3f' not found.
res.status !== 404 &&
// ie. cannot change codex model providers mid-session
modelInfo.stickyProvider !== "strict" &&
modelInfo.fallbackProvider &&
providerInfo.id !== modelInfo.fallbackProvider
) {
return retriableRequest({
excludeProviders: [...retry.excludeProviders, providerInfo.id],
retryCount: retry.retryCount + 1,
})
}
return { providerInfo, reqBody, res, startTimestamp }
}
const { providerInfo, reqBody, res, startTimestamp } = await retriableRequest()
// Store model request
dataDumper?.provideModel(providerInfo.storeModel)
dataDumper?.provideRequest(reqBody)
// Store sticky provider
await stickyTracker?.set(providerInfo.id)
// Temporarily change 404 to 400 status code b/c solid start automatically override 404 response
const resStatus = res.status === 404 ? 400 : res.status
// Scrub response headers
const resHeaders = new Headers()
const keepHeaders = ["content-type", "cache-control"]
for (const [k, v] of res.headers.entries()) {
if (keepHeaders.includes(k.toLowerCase())) {
resHeaders.set(k, v)
}
}
logger.debug("STATUS: " + res.status + " " + res.statusText)
// Handle non-streaming response
if (!isStream) {
const json = await res.json()
const usageInfo = providerInfo.normalizeUsage(json.usage)
const costInfo = calculateCost(modelInfo, usageInfo)
await trialLimiter?.track(usageInfo)
await rateLimiter?.track()
await trackUsage(billingSource, authInfo, modelInfo, providerInfo, usageInfo, costInfo)
await reload(billingSource, authInfo, costInfo)
const responseConverter = createResponseConverter(providerInfo.format, opts.format)
const body = JSON.stringify(
responseConverter({
...json,
cost: calculateOccuredCost(billingSource, costInfo),
}),
)
logger.metric({ response_length: body.length })
logger.debug("RESPONSE: " + body)
dataDumper?.provideResponse(body)
dataDumper?.flush()
return new Response(body, {
status: resStatus,
statusText: res.statusText,
headers: resHeaders,
})
}
// Handle streaming response
const streamConverter = createStreamPartConverter(providerInfo.format, opts.format)
const usageParser = providerInfo.createUsageParser()
const binaryDecoder = providerInfo.createBinaryStreamDecoder()
const stream = new ReadableStream({
start(c) {
const reader = res.body?.getReader()
const decoder = new TextDecoder()
const encoder = new TextEncoder()
let buffer = ""
let responseLength = 0
function pump(): Promise<void> {
return (
reader?.read().then(async ({ done, value: rawValue }) => {
if (done) {
logger.metric({
response_length: responseLength,
"timestamp.last_byte": Date.now(),
})
dataDumper?.flush()
await rateLimiter?.track()
const usage = usageParser.retrieve()
let cost = "0"
if (usage) {
const usageInfo = providerInfo.normalizeUsage(usage)
const costInfo = calculateCost(modelInfo, usageInfo)
await trialLimiter?.track(usageInfo)
await trackUsage(billingSource, authInfo, modelInfo, providerInfo, usageInfo, costInfo)
await reload(billingSource, authInfo, costInfo)
cost = calculateOccuredCost(billingSource, costInfo)
}
c.enqueue(encoder.encode(usageParser.buidlCostChunk(cost)))
c.close()
return
}
if (responseLength === 0) {
const now = Date.now()
logger.metric({
time_to_first_byte: now - startTimestamp,
"timestamp.first_byte": now,
})
}
const value = binaryDecoder ? binaryDecoder(rawValue) : rawValue
if (!value) return
responseLength += value.length
buffer += decoder.decode(value, { stream: true })
dataDumper?.provideStream(buffer)
const parts = buffer.split(providerInfo.streamSeparator)
buffer = parts.pop() ?? ""
for (let part of parts) {
logger.debug("PART: " + part)
part = part.trim()
usageParser.parse(part)
if (providerInfo.bodyModifier) {
for (const [k, v] of Object.entries(providerInfo.bodyModifier)) {
part = part.replace(k, v)
}
c.enqueue(encoder.encode(part + "\n\n"))
} else if (providerInfo.format !== opts.format) {
part = streamConverter(part)
c.enqueue(encoder.encode(part + "\n\n"))
}
}
if (!providerInfo.bodyModifier && providerInfo.format === opts.format) {
c.enqueue(value)
}
return pump()
}) || Promise.resolve()
)
}
return pump()
},
})
return new Response(stream, {
status: resStatus,
statusText: res.statusText,
headers: resHeaders,
})
} catch (error: any) {
logger.metric({
"error.type": error.constructor.name,
"error.message": error.message,
})
// Note: both top level "type" and "error.type" fields are used by the @ai-sdk/anthropic client to render the error message.
if (
error instanceof AuthError ||
error instanceof CreditsError ||
error instanceof MonthlyLimitError ||
error instanceof UserLimitError ||
error instanceof ModelError
)
return new Response(
JSON.stringify({
type: "error",
error: { type: error.constructor.name, message: error.message },
}),
{ status: 401 },
)
if (error instanceof FreeUsageLimitError || error instanceof SubscriptionUsageLimitError) {
const headers = new Headers()
if (error.retryAfter) {
headers.set("retry-after", String(error.retryAfter))
}
return new Response(
JSON.stringify({
type: "error",
error: { type: error.constructor.name, message: error.message },
}),
{ status: 429, headers },
)
}
return new Response(
JSON.stringify({
type: "error",
error: {
type: "error",
message: error.message,
},
}),
{ status: 500 },
)
}
function validateModel(zenData: ZenData, reqModel: string) {
if (!(reqModel in zenData.models)) throw new ModelError(`Model ${reqModel} not supported`)
const modelId = reqModel as keyof typeof zenData.models
const modelData = Array.isArray(zenData.models[modelId])
? zenData.models[modelId].find((model) => opts.format === model.formatFilter)
: zenData.models[modelId]
if (!modelData) throw new ModelError(`Model ${reqModel} not supported for format ${opts.format}`)
logger.metric({ model: modelId })
return { id: modelId, ...modelData }
}
function selectProvider(
reqModel: string,
zenData: ZenData,
authInfo: AuthInfo,
modelInfo: ModelInfo,
sessionId: string,
isTrial: boolean,
retry: RetryOptions,
stickyProvider: string | undefined,
) {
const modelProvider = (() => {
if (authInfo?.provider?.credentials) {
return modelInfo.providers.find((provider) => provider.id === modelInfo.byokProvider)
}
if (isTrial) {
return modelInfo.providers.find((provider) => provider.id === modelInfo.trial!.provider)
}
if (stickyProvider) {
const provider = modelInfo.providers.find((provider) => provider.id === stickyProvider)
if (provider) return provider
}
if (retry.retryCount === MAX_FAILOVER_RETRIES) {
const provider = modelInfo.providers.find((provider) => provider.id === modelInfo.fallbackProvider)
if (provider) return provider
}
const providers = modelInfo.providers
.filter((provider) => !provider.disabled)
.filter((provider) => !retry.excludeProviders.includes(provider.id))
.flatMap((provider) => Array<typeof provider>(provider.weight ?? 1).fill(provider))
// Use the last 4 characters of session ID to select a provider
let h = 0
const l = sessionId.length
for (let i = l - 4; i < l; i++) {
h = (h * 31 + sessionId.charCodeAt(i)) | 0 // 32-bit int
}
const index = (h >>> 0) % providers.length // make unsigned + range 0..length-1
return providers[index || 0]
})()
if (!modelProvider) throw new ModelError("No provider available")
if (!(modelProvider.id in zenData.providers)) throw new ModelError(`Provider ${modelProvider.id} not supported`)
return {
...modelProvider,
...zenData.providers[modelProvider.id],
...(() => {
const format = zenData.providers[modelProvider.id].format
const providerModel = modelProvider.model
if (format === "anthropic") return anthropicHelper({ reqModel, providerModel })
if (format === "google") return googleHelper({ reqModel, providerModel })
if (format === "openai") return openaiHelper({ reqModel, providerModel })
return oaCompatHelper({ reqModel, providerModel })
})(),
}
}
async function authenticate(modelInfo: ModelInfo) {
const apiKey = opts.parseApiKey(input.request.headers)
if (!apiKey || apiKey === "public") {
if (modelInfo.allowAnonymous) return
throw new AuthError("Missing API key.")
}
const data = await Database.use((tx) =>
tx
.select({
apiKey: KeyTable.id,
workspaceID: KeyTable.workspaceID,
billing: {
balance: BillingTable.balance,
paymentMethodID: BillingTable.paymentMethodID,
monthlyLimit: BillingTable.monthlyLimit,
monthlyUsage: BillingTable.monthlyUsage,
timeMonthlyUsageUpdated: BillingTable.timeMonthlyUsageUpdated,
reloadTrigger: BillingTable.reloadTrigger,
timeReloadLockedTill: BillingTable.timeReloadLockedTill,
subscription: BillingTable.subscription,
},
user: {
id: UserTable.id,
monthlyLimit: UserTable.monthlyLimit,
monthlyUsage: UserTable.monthlyUsage,
timeMonthlyUsageUpdated: UserTable.timeMonthlyUsageUpdated,
},
subscription: {
id: SubscriptionTable.id,
rollingUsage: SubscriptionTable.rollingUsage,
fixedUsage: SubscriptionTable.fixedUsage,
timeRollingUpdated: SubscriptionTable.timeRollingUpdated,
timeFixedUpdated: SubscriptionTable.timeFixedUpdated,
},
provider: {
credentials: ProviderTable.credentials,
},
timeDisabled: ModelTable.timeCreated,
})
.from(KeyTable)
.innerJoin(WorkspaceTable, eq(WorkspaceTable.id, KeyTable.workspaceID))
.innerJoin(BillingTable, eq(BillingTable.workspaceID, KeyTable.workspaceID))
.innerJoin(UserTable, and(eq(UserTable.workspaceID, KeyTable.workspaceID), eq(UserTable.id, KeyTable.userID)))
.leftJoin(ModelTable, and(eq(ModelTable.workspaceID, KeyTable.workspaceID), eq(ModelTable.model, modelInfo.id)))
.leftJoin(
ProviderTable,
modelInfo.byokProvider
? and(
eq(ProviderTable.workspaceID, KeyTable.workspaceID),
eq(ProviderTable.provider, modelInfo.byokProvider),
)
: sql`false`,
)
.leftJoin(
SubscriptionTable,
and(
eq(SubscriptionTable.workspaceID, KeyTable.workspaceID),
eq(SubscriptionTable.userID, KeyTable.userID),
isNull(SubscriptionTable.timeDeleted),
),
)
.where(and(eq(KeyTable.key, apiKey), isNull(KeyTable.timeDeleted)))
.then((rows) => rows[0]),
)
if (!data) throw new AuthError("Invalid API key.")
logger.metric({
api_key: data.apiKey,
workspace: data.workspaceID,
isSubscription: data.subscription ? true : false,
subscription: data.billing.subscription?.plan,
})
return {
apiKeyId: data.apiKey,
workspaceID: data.workspaceID,
billing: data.billing,
user: data.user,
subscription: data.subscription,
provider: data.provider,
isFree: FREE_WORKSPACES.includes(data.workspaceID),
isDisabled: !!data.timeDisabled,
}
}
function validateBilling(authInfo: AuthInfo, modelInfo: ModelInfo): BillingSource {
if (!authInfo) return "anonymous"
if (authInfo.provider?.credentials) return "byok"
if (authInfo.isFree) return "free"
if (modelInfo.allowAnonymous) return "free"
// Validate subscription billing
if (authInfo.billing.subscription && authInfo.subscription) {
try {
const sub = authInfo.subscription
const plan = authInfo.billing.subscription.plan
const formatRetryTime = (seconds: number) => {
const days = Math.floor(seconds / 86400)
if (days >= 1) return `${days} day${days > 1 ? "s" : ""}`
const hours = Math.floor(seconds / 3600)
const minutes = Math.ceil((seconds % 3600) / 60)
if (hours >= 1) return `${hours}hr ${minutes}min`
return `${minutes}min`
}
// Check weekly limit
if (sub.fixedUsage && sub.timeFixedUpdated) {
const result = Black.analyzeWeeklyUsage({
plan,
usage: sub.fixedUsage,
timeUpdated: sub.timeFixedUpdated,
})
if (result.status === "rate-limited")
throw new SubscriptionUsageLimitError(
`Subscription quota exceeded. Retry in ${formatRetryTime(result.resetInSec)}.`,
result.resetInSec,
)
}
// Check rolling limit
if (sub.rollingUsage && sub.timeRollingUpdated) {
const result = Black.analyzeRollingUsage({
plan,
usage: sub.rollingUsage,
timeUpdated: sub.timeRollingUpdated,
})
if (result.status === "rate-limited")
throw new SubscriptionUsageLimitError(
`Subscription quota exceeded. Retry in ${formatRetryTime(result.resetInSec)}.`,
result.resetInSec,
)
}
return "subscription"
} catch (e) {
if (!authInfo.billing.subscription.useBalance) throw e
}
}
// Validate pay as you go billing
const billing = authInfo.billing
if (!billing.paymentMethodID)
throw new CreditsError(
`No payment method. Add a payment method here: https://opencode.ai/workspace/${authInfo.workspaceID}/billing`,
)
if (billing.balance <= 0)
throw new CreditsError(
`Insufficient balance. Manage your billing here: https://opencode.ai/workspace/${authInfo.workspaceID}/billing`,
)
const now = new Date()
const currentYear = now.getUTCFullYear()
const currentMonth = now.getUTCMonth()
if (
billing.monthlyLimit &&
billing.monthlyUsage &&
billing.timeMonthlyUsageUpdated &&
billing.monthlyUsage >= centsToMicroCents(billing.monthlyLimit * 100) &&
currentYear === billing.timeMonthlyUsageUpdated.getUTCFullYear() &&
currentMonth === billing.timeMonthlyUsageUpdated.getUTCMonth()
)
throw new MonthlyLimitError(
`Your workspace has reached its monthly spending limit of $${billing.monthlyLimit}. Manage your limits here: https://opencode.ai/workspace/${authInfo.workspaceID}/billing`,
)
if (
authInfo.user.monthlyLimit &&
authInfo.user.monthlyUsage &&
authInfo.user.timeMonthlyUsageUpdated &&
authInfo.user.monthlyUsage >= centsToMicroCents(authInfo.user.monthlyLimit * 100) &&
currentYear === authInfo.user.timeMonthlyUsageUpdated.getUTCFullYear() &&
currentMonth === authInfo.user.timeMonthlyUsageUpdated.getUTCMonth()
)
throw new UserLimitError(
`You have reached your monthly spending limit of $${authInfo.user.monthlyLimit}. Manage your limits here: https://opencode.ai/workspace/${authInfo.workspaceID}/members`,
)
return "balance"
}
function validateModelSettings(authInfo: AuthInfo) {
if (!authInfo) return
if (authInfo.isDisabled) throw new ModelError("Model is disabled")
}
function updateProviderKey(authInfo: AuthInfo, providerInfo: ProviderInfo) {
if (!authInfo?.provider?.credentials) return
providerInfo.apiKey = authInfo.provider.credentials
}
async function fetchWith429Retry(url: string, options: RequestInit, retry = { count: 0 }) {
const res = await fetch(url, options)
if (res.status === 429 && retry.count < MAX_429_RETRIES) {
await new Promise((resolve) => setTimeout(resolve, Math.pow(2, retry.count) * 500))
return fetchWith429Retry(url, options, { count: retry.count + 1 })
}
return res
}
function calculateCost(modelInfo: ModelInfo, usageInfo: UsageInfo) {
const { inputTokens, outputTokens, reasoningTokens, cacheReadTokens, cacheWrite5mTokens, cacheWrite1hTokens } =
usageInfo
const modelCost =
modelInfo.cost200K &&
inputTokens + (cacheReadTokens ?? 0) + (cacheWrite5mTokens ?? 0) + (cacheWrite1hTokens ?? 0) > 200_000
? modelInfo.cost200K
: modelInfo.cost
const inputCost = modelCost.input * inputTokens * 100
const outputCost = modelCost.output * outputTokens * 100
const reasoningCost = (() => {
if (!reasoningTokens) return undefined
return modelCost.output * reasoningTokens * 100
})()
const cacheReadCost = (() => {
if (!cacheReadTokens) return undefined
if (!modelCost.cacheRead) return undefined
return modelCost.cacheRead * cacheReadTokens * 100
})()
const cacheWrite5mCost = (() => {
if (!cacheWrite5mTokens) return undefined
if (!modelCost.cacheWrite5m) return undefined
return modelCost.cacheWrite5m * cacheWrite5mTokens * 100
})()
const cacheWrite1hCost = (() => {
if (!cacheWrite1hTokens) return undefined
if (!modelCost.cacheWrite1h) return undefined
return modelCost.cacheWrite1h * cacheWrite1hTokens * 100
})()
const totalCostInCent =
inputCost +
outputCost +
(reasoningCost ?? 0) +
(cacheReadCost ?? 0) +
(cacheWrite5mCost ?? 0) +
(cacheWrite1hCost ?? 0)
return {
totalCostInCent,
inputCost,
outputCost,
reasoningCost,
cacheReadCost,
cacheWrite5mCost,
cacheWrite1hCost,
}
}
function calculateOccuredCost(billingSource: BillingSource, costInfo: CostInfo) {
return billingSource === "balance" ? (costInfo.totalCostInCent / 100).toFixed(8) : "0"
}
async function trackUsage(
billingSource: BillingSource,
authInfo: AuthInfo,
modelInfo: ModelInfo,
providerInfo: ProviderInfo,
usageInfo: UsageInfo,
costInfo: CostInfo,
) {
const { inputTokens, outputTokens, reasoningTokens, cacheReadTokens, cacheWrite5mTokens, cacheWrite1hTokens } =
usageInfo
const { totalCostInCent, inputCost, outputCost, reasoningCost, cacheReadCost, cacheWrite5mCost, cacheWrite1hCost } =
costInfo
logger.metric({
"tokens.input": inputTokens,
"tokens.output": outputTokens,
"tokens.reasoning": reasoningTokens,
"tokens.cache_read": cacheReadTokens,
"tokens.cache_write_5m": cacheWrite5mTokens,
"tokens.cache_write_1h": cacheWrite1hTokens,
"cost.input": Math.round(inputCost),
"cost.output": Math.round(outputCost),
"cost.reasoning": reasoningCost ? Math.round(reasoningCost) : undefined,
"cost.cache_read": cacheReadCost ? Math.round(cacheReadCost) : undefined,
"cost.cache_write_5m": cacheWrite5mCost ? Math.round(cacheWrite5mCost) : undefined,
"cost.cache_write_1h": cacheWrite1hCost ? Math.round(cacheWrite1hCost) : undefined,
"cost.total": Math.round(totalCostInCent),
})
if (billingSource === "anonymous") return
authInfo = authInfo!
const cost = centsToMicroCents(totalCostInCent)
await Database.use((db) =>
Promise.all([
db.insert(UsageTable).values({
workspaceID: authInfo.workspaceID,
id: Identifier.create("usage"),
model: modelInfo.id,
provider: providerInfo.id,
inputTokens,
outputTokens,
reasoningTokens,
cacheReadTokens,
cacheWrite5mTokens,
cacheWrite1hTokens,
cost,
keyID: authInfo.apiKeyId,
enrichment: billingSource === "subscription" ? { plan: "sub" } : undefined,
}),
db
.update(KeyTable)
.set({ timeUsed: sql`now()` })
.where(and(eq(KeyTable.workspaceID, authInfo.workspaceID), eq(KeyTable.id, authInfo.apiKeyId))),
...(billingSource === "subscription"
? (() => {
const plan = authInfo.billing.subscription!.plan
const black = BlackData.getLimits({ plan })
const week = getWeekBounds(new Date())
const rollingWindowSeconds = black.rollingWindow * 3600
return [
db
.update(SubscriptionTable)
.set({
fixedUsage: sql`
CASE
WHEN ${SubscriptionTable.timeFixedUpdated} >= ${week.start} THEN ${SubscriptionTable.fixedUsage} + ${cost}
ELSE ${cost}
END
`,
timeFixedUpdated: sql`now()`,
rollingUsage: sql`
CASE
WHEN UNIX_TIMESTAMP(${SubscriptionTable.timeRollingUpdated}) >= UNIX_TIMESTAMP(now()) - ${rollingWindowSeconds} THEN ${SubscriptionTable.rollingUsage} + ${cost}
ELSE ${cost}
END
`,
timeRollingUpdated: sql`
CASE
WHEN UNIX_TIMESTAMP(${SubscriptionTable.timeRollingUpdated}) >= UNIX_TIMESTAMP(now()) - ${rollingWindowSeconds} THEN ${SubscriptionTable.timeRollingUpdated}
ELSE now()
END
`,
})
.where(
and(
eq(SubscriptionTable.workspaceID, authInfo.workspaceID),
eq(SubscriptionTable.userID, authInfo.user.id),
),
),
]
})()
: [
db
.update(BillingTable)
.set({
balance: authInfo.isFree
? sql`${BillingTable.balance} - ${0}`
: sql`${BillingTable.balance} - ${cost}`,
monthlyUsage: sql`
CASE
WHEN MONTH(${BillingTable.timeMonthlyUsageUpdated}) = MONTH(now()) AND YEAR(${BillingTable.timeMonthlyUsageUpdated}) = YEAR(now()) THEN ${BillingTable.monthlyUsage} + ${cost}
ELSE ${cost}
END
`,
timeMonthlyUsageUpdated: sql`now()`,
})
.where(eq(BillingTable.workspaceID, authInfo.workspaceID)),
db
.update(UserTable)
.set({
monthlyUsage: sql`
CASE
WHEN MONTH(${UserTable.timeMonthlyUsageUpdated}) = MONTH(now()) AND YEAR(${UserTable.timeMonthlyUsageUpdated}) = YEAR(now()) THEN ${UserTable.monthlyUsage} + ${cost}
ELSE ${cost}
END
`,
timeMonthlyUsageUpdated: sql`now()`,
})
.where(and(eq(UserTable.workspaceID, authInfo.workspaceID), eq(UserTable.id, authInfo.user.id))),
]),
]),
)
return { costInMicroCents: cost }
}
async function reload(billingSource: BillingSource, authInfo: AuthInfo, costInfo: CostInfo) {
if (billingSource !== "balance") return
authInfo = authInfo!
const reloadTrigger = centsToMicroCents((authInfo.billing.reloadTrigger ?? Billing.RELOAD_TRIGGER) * 100)
if (authInfo.billing.balance - costInfo.totalCostInCent >= reloadTrigger) return
if (authInfo.billing.timeReloadLockedTill && authInfo.billing.timeReloadLockedTill > new Date()) return
const lock = await Database.use((tx) =>
tx
.update(BillingTable)
.set({
timeReloadLockedTill: sql`now() + interval 1 minute`,
})
.where(
and(
eq(BillingTable.workspaceID, authInfo.workspaceID),
eq(BillingTable.reload, true),
lt(BillingTable.balance, reloadTrigger),
or(isNull(BillingTable.timeReloadLockedTill), lt(BillingTable.timeReloadLockedTill, sql`now()`)),
),
),
)
if (lock.rowsAffected === 0) return
await Actor.provide("system", { workspaceID: authInfo.workspaceID }, async () => {
await Billing.reload()
})
}
}

View File

@@ -3,6 +3,7 @@ import { Tool } from "./tool"
import TurndownService from "turndown"
import DESCRIPTION from "./webfetch.txt"
import { abortAfterAny } from "../util/abort"
import { Identifier } from "../id/id"
const MAX_RESPONSE_SIZE = 5 * 1024 * 1024 // 5MB
const DEFAULT_TIMEOUT = 30 * 1000 // 30 seconds
@@ -87,11 +88,34 @@ export const WebFetchTool = Tool.define("webfetch", {
throw new Error("Response too large (exceeds 5MB limit)")
}
const content = new TextDecoder().decode(arrayBuffer)
const contentType = response.headers.get("content-type") || ""
const mime = contentType.split(";")[0]?.trim().toLowerCase() || ""
const title = `${params.url} (${contentType})`
// Check if response is an image
const isImage = mime.startsWith("image/") && mime !== "image/svg+xml" && mime !== "image/vnd.fastbidsheet"
if (isImage) {
const base64Content = Buffer.from(arrayBuffer).toString("base64")
return {
title,
output: "Image fetched successfully",
metadata: {},
attachments: [
{
id: Identifier.ascending("part"),
sessionID: ctx.sessionID,
messageID: ctx.messageID,
type: "file",
mime,
url: `data:${mime};base64,${base64Content}`,
},
],
}
}
const content = new TextDecoder().decode(arrayBuffer)
// Handle content based on requested format and actual content type
switch (params.format) {
case "markdown":

View File

@@ -0,0 +1,97 @@
import { describe, expect, test } from "bun:test"
import path from "path"
import { Instance } from "../../src/project/instance"
import { WebFetchTool } from "../../src/tool/webfetch"
const projectRoot = path.join(import.meta.dir, "../..")
const ctx = {
sessionID: "test",
messageID: "message",
callID: "",
agent: "build",
abort: AbortSignal.any([]),
messages: [],
metadata: () => {},
ask: async () => {},
}
async function withFetch(
mockFetch: (input: string | URL | Request, init?: RequestInit) => Promise<Response>,
fn: () => Promise<void>,
) {
const originalFetch = globalThis.fetch
globalThis.fetch = mockFetch as unknown as typeof fetch
try {
await fn()
} finally {
globalThis.fetch = originalFetch
}
}
describe("tool.webfetch", () => {
test("returns image responses as file attachments", async () => {
const bytes = new Uint8Array([137, 80, 78, 71, 13, 10, 26, 10])
await withFetch(
async () => new Response(bytes, { status: 200, headers: { "content-type": "IMAGE/PNG; charset=binary" } }),
async () => {
await Instance.provide({
directory: projectRoot,
fn: async () => {
const webfetch = await WebFetchTool.init()
const result = await webfetch.execute({ url: "https://example.com/image.png", format: "markdown" }, ctx)
expect(result.output).toBe("Image fetched successfully")
expect(result.attachments).toBeDefined()
expect(result.attachments?.length).toBe(1)
expect(result.attachments?.[0].type).toBe("file")
expect(result.attachments?.[0].mime).toBe("image/png")
expect(result.attachments?.[0].url.startsWith("data:image/png;base64,")).toBe(true)
},
})
},
)
})
test("keeps svg as text output", async () => {
const svg = '<svg xmlns="http://www.w3.org/2000/svg"><text>hello</text></svg>'
await withFetch(
async () =>
new Response(svg, {
status: 200,
headers: { "content-type": "image/svg+xml; charset=UTF-8" },
}),
async () => {
await Instance.provide({
directory: projectRoot,
fn: async () => {
const webfetch = await WebFetchTool.init()
const result = await webfetch.execute({ url: "https://example.com/image.svg", format: "html" }, ctx)
expect(result.output).toContain("<svg")
expect(result.attachments).toBeUndefined()
},
})
},
)
})
test("keeps text responses as text output", async () => {
await withFetch(
async () =>
new Response("hello from webfetch", {
status: 200,
headers: { "content-type": "text/plain; charset=utf-8" },
}),
async () => {
await Instance.provide({
directory: projectRoot,
fn: async () => {
const webfetch = await WebFetchTool.init()
const result = await webfetch.execute({ url: "https://example.com/file.txt", format: "text" }, ctx)
expect(result.output).toBe("hello from webfetch")
expect(result.attachments).toBeUndefined()
},
})
},
)
})
})

View File

@@ -1548,8 +1548,8 @@ export type ProviderConfig = {
[key: string]: string
}
provider?: {
npm: string
api: string
npm?: string
api?: string
}
/**
* Variant-specific configuration
@@ -4068,8 +4068,8 @@ export type ProviderListResponses = {
[key: string]: string
}
provider?: {
npm: string
api: string
npm?: string
api?: string
}
variants?: {
[key: string]: {

View File

@@ -3800,8 +3800,7 @@
"api": {
"type": "string"
}
},
"required": ["npm", "api"]
}
},
"variants": {
"type": "object",
@@ -9405,8 +9404,7 @@
"api": {
"type": "string"
}
},
"required": ["npm", "api"]
}
},
"variants": {
"description": "Variant-specific configuration",