mirror of
https://github.com/anomalyco/opencode.git
synced 2026-03-10 08:34:10 +00:00
Compare commits
8 Commits
refactor/n
...
add-api-sh
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b82787e270 | ||
|
|
bb88363a76 | ||
|
|
d63f4bcd5f | ||
|
|
065c2a1e6e | ||
|
|
4f982ddb94 | ||
|
|
ff3bb7424d | ||
|
|
e6a49ed85c | ||
|
|
77cdfcdb64 |
@@ -81,8 +81,6 @@
|
||||
"@gitlab/gitlab-ai-provider": "3.6.0",
|
||||
"@gitlab/opencode-gitlab-auth": "1.3.3",
|
||||
"@hono/standard-validator": "0.1.5",
|
||||
"@hono/node-server": "1.19.11",
|
||||
"@hono/node-ws": "1.3.0",
|
||||
"@hono/zod-validator": "catalog:",
|
||||
"@modelcontextprotocol/sdk": "1.25.2",
|
||||
"@octokit/graphql": "9.0.2",
|
||||
|
||||
@@ -23,7 +23,7 @@ export const AcpCommand = cmd({
|
||||
process.env.OPENCODE_CLIENT = "acp"
|
||||
await bootstrap(process.cwd(), async () => {
|
||||
const opts = await resolveNetworkOptions(args)
|
||||
const server = await Server.listen(opts)
|
||||
const server = Server.listen(opts)
|
||||
|
||||
const sdk = createOpencodeClient({
|
||||
baseUrl: `http://${server.hostname}:${server.port}`,
|
||||
|
||||
@@ -15,7 +15,7 @@ export const ServeCommand = cmd({
|
||||
console.log("Warning: OPENCODE_SERVER_PASSWORD is not set; server is unsecured.")
|
||||
}
|
||||
const opts = await resolveNetworkOptions(args)
|
||||
const server = await Server.listen(opts)
|
||||
const server = Server.listen(opts)
|
||||
console.log(`opencode server listening on http://${server.hostname}:${server.port}`)
|
||||
|
||||
await new Promise(() => {})
|
||||
|
||||
@@ -8,6 +8,7 @@ import { upgrade } from "@/cli/upgrade"
|
||||
import { Config } from "@/config/config"
|
||||
import { GlobalBus } from "@/bus/global"
|
||||
import { createOpencodeClient, type Event } from "@opencode-ai/sdk/v2"
|
||||
import type { BunWebSocketData } from "hono/bun"
|
||||
import { Flag } from "@/flag/flag"
|
||||
import { setTimeout as sleep } from "node:timers/promises"
|
||||
|
||||
@@ -37,7 +38,7 @@ GlobalBus.on("event", (event) => {
|
||||
Rpc.emit("global.event", event)
|
||||
})
|
||||
|
||||
let server: Awaited<ReturnType<typeof Server.listen>> | undefined
|
||||
let server: Bun.Server<BunWebSocketData> | undefined
|
||||
|
||||
const eventStream = {
|
||||
abort: undefined as AbortController | undefined,
|
||||
@@ -119,7 +120,7 @@ export const rpc = {
|
||||
},
|
||||
async server(input: { port: number; hostname: string; mdns?: boolean; cors?: string[] }) {
|
||||
if (server) await server.stop(true)
|
||||
server = await Server.listen(input)
|
||||
server = Server.listen(input)
|
||||
return { url: server.url.toString() }
|
||||
},
|
||||
async checkUpgrade(input: { directory: string }) {
|
||||
@@ -142,7 +143,7 @@ export const rpc = {
|
||||
Log.Default.info("worker shutting down")
|
||||
if (eventStream.abort) eventStream.abort.abort()
|
||||
await Instance.disposeAll()
|
||||
if (server) await server.stop(true)
|
||||
if (server) server.stop(true)
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ export const WebCommand = cmd({
|
||||
UI.println(UI.Style.TEXT_WARNING_BOLD + "! " + "OPENCODE_SERVER_PASSWORD is not set; server is unsecured.")
|
||||
}
|
||||
const opts = await resolveNetworkOptions(args)
|
||||
const server = await Server.listen(opts)
|
||||
const server = Server.listen(opts)
|
||||
UI.empty()
|
||||
UI.println(UI.logo(" "))
|
||||
UI.empty()
|
||||
|
||||
@@ -396,8 +396,14 @@ export namespace MCP {
|
||||
} catch (error) {
|
||||
lastError = error instanceof Error ? error : new Error(String(error))
|
||||
|
||||
// Handle OAuth-specific errors
|
||||
if (error instanceof UnauthorizedError) {
|
||||
// Handle OAuth-specific errors.
|
||||
// The SDK throws UnauthorizedError when auth() returns 'REDIRECT',
|
||||
// but may also throw plain Errors when auth() fails internally
|
||||
// (e.g. during discovery, registration, or state generation).
|
||||
// When an authProvider is attached, treat both cases as auth-related.
|
||||
const isAuthError =
|
||||
error instanceof UnauthorizedError || (authProvider && lastError.message.includes("OAuth"))
|
||||
if (isAuthError) {
|
||||
log.info("mcp server requires authentication", { key, transport: name })
|
||||
|
||||
// Check if this is a "needs registration" error
|
||||
|
||||
@@ -144,10 +144,19 @@ export class McpOAuthProvider implements OAuthClientProvider {
|
||||
|
||||
async state(): Promise<string> {
|
||||
const entry = await McpAuth.get(this.mcpName)
|
||||
if (!entry?.oauthState) {
|
||||
throw new Error(`No OAuth state saved for MCP server: ${this.mcpName}`)
|
||||
if (entry?.oauthState) {
|
||||
return entry.oauthState
|
||||
}
|
||||
return entry.oauthState
|
||||
|
||||
// Generate a new state if none exists — the SDK calls state() as a
|
||||
// generator, not just a reader, so we need to produce a value even when
|
||||
// startAuth() hasn't pre-saved one (e.g. during automatic auth on first
|
||||
// connect).
|
||||
const newState = Array.from(crypto.getRandomValues(new Uint8Array(32)))
|
||||
.map((b) => b.toString(16).padStart(2, "0"))
|
||||
.join("")
|
||||
await McpAuth.updateOAuthState(this.mcpName, newState)
|
||||
return newState
|
||||
}
|
||||
|
||||
async invalidateCredentials(type: "all" | "client" | "tokens"): Promise<void> {
|
||||
|
||||
@@ -65,7 +65,13 @@ export namespace ModelsDev {
|
||||
status: z.enum(["alpha", "beta", "deprecated"]).optional(),
|
||||
options: z.record(z.string(), z.any()),
|
||||
headers: z.record(z.string(), z.string()).optional(),
|
||||
provider: z.object({ npm: z.string().optional(), api: z.string().optional() }).optional(),
|
||||
provider: z
|
||||
.object({
|
||||
npm: z.string().optional(),
|
||||
api: z.string().optional(),
|
||||
shape: z.string().optional(),
|
||||
})
|
||||
.optional(),
|
||||
variants: z.record(z.string(), z.record(z.string(), z.any())).optional(),
|
||||
})
|
||||
export type Model = z.infer<typeof Model>
|
||||
|
||||
@@ -110,7 +110,7 @@ export namespace Provider {
|
||||
"@ai-sdk/github-copilot": createGitHubCopilotOpenAICompatible,
|
||||
}
|
||||
|
||||
type CustomModelLoader = (sdk: any, modelID: string, options?: Record<string, any>) => Promise<any>
|
||||
type CustomModelLoader = (sdk: any, model: Model, options?: Record<string, any>) => Promise<any>
|
||||
type CustomLoader = (provider: Info) => Promise<{
|
||||
autoload: boolean
|
||||
getModel?: CustomModelLoader
|
||||
@@ -154,8 +154,8 @@ export namespace Provider {
|
||||
openai: async () => {
|
||||
return {
|
||||
autoload: false,
|
||||
async getModel(sdk: any, modelID: string, _options?: Record<string, any>) {
|
||||
return sdk.responses(modelID)
|
||||
async getModel(sdk: any, model: Model, _options?: Record<string, any>) {
|
||||
return sdk.responses(model.api.id)
|
||||
},
|
||||
options: {},
|
||||
}
|
||||
@@ -163,9 +163,8 @@ export namespace Provider {
|
||||
"github-copilot": async () => {
|
||||
return {
|
||||
autoload: false,
|
||||
async getModel(sdk: any, modelID: string, _options?: Record<string, any>) {
|
||||
if (sdk.responses === undefined && sdk.chat === undefined) return sdk.languageModel(modelID)
|
||||
return shouldUseCopilotResponsesApi(modelID) ? sdk.responses(modelID) : sdk.chat(modelID)
|
||||
async getModel(sdk: any, model: Model, _options?: Record<string, any>) {
|
||||
return shouldUseCopilotResponsesApi(model.api.id) ? sdk.responses(model.api.id) : sdk.chat(model.api.id)
|
||||
},
|
||||
options: {},
|
||||
}
|
||||
@@ -173,9 +172,8 @@ export namespace Provider {
|
||||
"github-copilot-enterprise": async () => {
|
||||
return {
|
||||
autoload: false,
|
||||
async getModel(sdk: any, modelID: string, _options?: Record<string, any>) {
|
||||
if (sdk.responses === undefined && sdk.chat === undefined) return sdk.languageModel(modelID)
|
||||
return shouldUseCopilotResponsesApi(modelID) ? sdk.responses(modelID) : sdk.chat(modelID)
|
||||
async getModel(sdk: any, model: Model, _options?: Record<string, any>) {
|
||||
return shouldUseCopilotResponsesApi(model.api.id) ? sdk.responses(model.api.id) : sdk.chat(model.api.id)
|
||||
},
|
||||
options: {},
|
||||
}
|
||||
@@ -183,12 +181,9 @@ export namespace Provider {
|
||||
azure: async () => {
|
||||
return {
|
||||
autoload: false,
|
||||
async getModel(sdk: any, modelID: string, options?: Record<string, any>) {
|
||||
if (options?.["useCompletionUrls"]) {
|
||||
return sdk.chat(modelID)
|
||||
} else {
|
||||
return sdk.responses(modelID)
|
||||
}
|
||||
async getModel(sdk: any, model: Model, options?: Record<string, any>) {
|
||||
if (options?.["useCompletionUrls"]) return sdk.chat(model.api.id)
|
||||
return sdk.responses(model.api.id)
|
||||
},
|
||||
options: {},
|
||||
}
|
||||
@@ -197,12 +192,9 @@ export namespace Provider {
|
||||
const resourceName = Env.get("AZURE_COGNITIVE_SERVICES_RESOURCE_NAME")
|
||||
return {
|
||||
autoload: false,
|
||||
async getModel(sdk: any, modelID: string, options?: Record<string, any>) {
|
||||
if (options?.["useCompletionUrls"]) {
|
||||
return sdk.chat(modelID)
|
||||
} else {
|
||||
return sdk.responses(modelID)
|
||||
}
|
||||
async getModel(sdk: any, model: Model, options?: Record<string, any>) {
|
||||
if (options?.["useCompletionUrls"]) return sdk.chat(model.api.id)
|
||||
return sdk.responses(model.api.id)
|
||||
},
|
||||
options: {
|
||||
baseURL: resourceName ? `https://${resourceName}.cognitiveservices.azure.com/openai` : undefined,
|
||||
@@ -270,7 +262,8 @@ export namespace Provider {
|
||||
return {
|
||||
autoload: true,
|
||||
options: providerOptions,
|
||||
async getModel(sdk: any, modelID: string, options?: Record<string, any>) {
|
||||
async getModel(sdk: any, model: Model, options?: Record<string, any>) {
|
||||
let modelID = model.api.id
|
||||
// Skip region prefixing if model already has a cross-region inference profile prefix
|
||||
// Models from models.dev may already include prefixes like us., eu., global., etc.
|
||||
const crossRegionPrefixes = ["global.", "us.", "eu.", "jp.", "apac.", "au."]
|
||||
@@ -407,8 +400,8 @@ export namespace Provider {
|
||||
return fetch(input, { ...init, headers })
|
||||
},
|
||||
},
|
||||
async getModel(sdk: any, modelID: string) {
|
||||
const id = String(modelID).trim()
|
||||
async getModel(sdk: any, model: Model) {
|
||||
const id = String(model.api.id).trim()
|
||||
return sdk.languageModel(id)
|
||||
},
|
||||
}
|
||||
@@ -424,8 +417,8 @@ export namespace Provider {
|
||||
project,
|
||||
location,
|
||||
},
|
||||
async getModel(sdk: any, modelID) {
|
||||
const id = String(modelID).trim()
|
||||
async getModel(sdk: any, model: Model) {
|
||||
const id = String(model.api.id).trim()
|
||||
return sdk.languageModel(id)
|
||||
},
|
||||
}
|
||||
@@ -449,8 +442,8 @@ export namespace Provider {
|
||||
return {
|
||||
autoload: !!envServiceKey,
|
||||
options: envServiceKey ? { deploymentId, resourceGroup } : {},
|
||||
async getModel(sdk: any, modelID: string) {
|
||||
return sdk(modelID)
|
||||
async getModel(sdk: any, model: Model) {
|
||||
return sdk(model.api.id)
|
||||
},
|
||||
}
|
||||
},
|
||||
@@ -496,8 +489,8 @@ export namespace Provider {
|
||||
...(providerConfig?.options?.featureFlags || {}),
|
||||
},
|
||||
},
|
||||
async getModel(sdk: ReturnType<typeof createGitLab>, modelID: string) {
|
||||
return sdk.agenticChat(modelID, {
|
||||
async getModel(sdk: ReturnType<typeof createGitLab>, model: Model) {
|
||||
return sdk.agenticChat(model.api.id, {
|
||||
aiGatewayHeaders,
|
||||
featureFlags: {
|
||||
duo_agent_platform_agentic_chat: true,
|
||||
@@ -526,8 +519,8 @@ export namespace Provider {
|
||||
apiKey,
|
||||
baseURL: `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/v1`,
|
||||
},
|
||||
async getModel(sdk: any, modelID: string) {
|
||||
return sdk.languageModel(modelID)
|
||||
async getModel(sdk: any, model: Model) {
|
||||
return sdk.languageModel(model.api.id)
|
||||
},
|
||||
}
|
||||
},
|
||||
@@ -583,9 +576,9 @@ export namespace Provider {
|
||||
|
||||
return {
|
||||
autoload: true,
|
||||
async getModel(_sdk: any, modelID: string, _options?: Record<string, any>) {
|
||||
async getModel(_sdk: any, model: Model, _options?: Record<string, any>) {
|
||||
// Model IDs use Unified API format: provider/model (e.g., "anthropic/claude-sonnet-4-5")
|
||||
return aigateway(unified(modelID))
|
||||
return aigateway(unified(model.api.id))
|
||||
},
|
||||
options: {},
|
||||
}
|
||||
@@ -621,6 +614,7 @@ export namespace Provider {
|
||||
id: z.string(),
|
||||
url: z.string(),
|
||||
npm: z.string(),
|
||||
shape: z.string().optional(),
|
||||
}),
|
||||
name: z.string(),
|
||||
family: z.string().optional(),
|
||||
@@ -709,6 +703,7 @@ export namespace Provider {
|
||||
id: model.id,
|
||||
url: model.provider?.api ?? provider.api!,
|
||||
npm: model.provider?.npm ?? provider.npm ?? "@ai-sdk/openai-compatible",
|
||||
shape: model.provider?.shape,
|
||||
},
|
||||
status: model.status ?? "active",
|
||||
headers: model.headers ?? {},
|
||||
@@ -859,6 +854,7 @@ export namespace Provider {
|
||||
existingModel?.api.npm ??
|
||||
modelsDev[providerID]?.npm ??
|
||||
"@ai-sdk/openai-compatible",
|
||||
shape: model.provider?.shape ?? existingModel?.api.shape,
|
||||
url: model.provider?.api ?? provider?.api ?? existingModel?.api.url ?? modelsDev[providerID]?.api,
|
||||
},
|
||||
status: model.status ?? existingModel?.status ?? "active",
|
||||
@@ -1199,9 +1195,20 @@ export namespace Provider {
|
||||
const sdk = await getSDK(model)
|
||||
|
||||
try {
|
||||
const language = s.modelLoaders[model.providerID]
|
||||
? await s.modelLoaders[model.providerID](sdk, model.api.id, provider.options)
|
||||
: sdk.languageModel(model.api.id)
|
||||
const language = await iife(async () => {
|
||||
// SDK the type doesn't have .responses or .chat on it, but it CAN when using certain npm packages
|
||||
const looselyTyped = sdk as any
|
||||
if (looselyTyped.responses !== undefined || looselyTyped.chat !== undefined) {
|
||||
if (model.api.shape === "responses") return looselyTyped.responses(model.api.id)
|
||||
if (model.api.shape === "completions") return looselyTyped.chat(model.api.id)
|
||||
}
|
||||
|
||||
const ml = s.modelLoaders[model.providerID]
|
||||
if (ml) {
|
||||
return await ml(sdk, model, provider.options)
|
||||
}
|
||||
return sdk.languageModel(model.api.id)
|
||||
})
|
||||
s.models.set(key, language)
|
||||
return language
|
||||
} catch (e) {
|
||||
|
||||
@@ -23,8 +23,6 @@ export namespace Pty {
|
||||
close: (code?: number, reason?: string) => void
|
||||
}
|
||||
|
||||
const key = (ws: Socket) => (ws.data && typeof ws.data === "object" ? ws.data : ws)
|
||||
|
||||
// WebSocket control frame: 0x00 + UTF-8 JSON.
|
||||
const meta = (cursor: number) => {
|
||||
const json = JSON.stringify({ cursor })
|
||||
@@ -99,9 +97,9 @@ export namespace Pty {
|
||||
try {
|
||||
session.process.kill()
|
||||
} catch {}
|
||||
for (const [id, ws] of session.subscribers.entries()) {
|
||||
for (const [key, ws] of session.subscribers.entries()) {
|
||||
try {
|
||||
if (key(ws) === id) ws.close()
|
||||
if (ws.data === key) ws.close()
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
@@ -172,21 +170,21 @@ export namespace Pty {
|
||||
ptyProcess.onData((chunk) => {
|
||||
session.cursor += chunk.length
|
||||
|
||||
for (const [id, ws] of session.subscribers.entries()) {
|
||||
for (const [key, ws] of session.subscribers.entries()) {
|
||||
if (ws.readyState !== 1) {
|
||||
session.subscribers.delete(id)
|
||||
session.subscribers.delete(key)
|
||||
continue
|
||||
}
|
||||
|
||||
if (key(ws) !== id) {
|
||||
session.subscribers.delete(id)
|
||||
if (ws.data !== key) {
|
||||
session.subscribers.delete(key)
|
||||
continue
|
||||
}
|
||||
|
||||
try {
|
||||
ws.send(chunk)
|
||||
} catch {
|
||||
session.subscribers.delete(id)
|
||||
session.subscribers.delete(key)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -228,9 +226,9 @@ export namespace Pty {
|
||||
try {
|
||||
session.process.kill()
|
||||
} catch {}
|
||||
for (const [id, ws] of session.subscribers.entries()) {
|
||||
for (const [key, ws] of session.subscribers.entries()) {
|
||||
try {
|
||||
if (key(ws) === id) ws.close()
|
||||
if (ws.data === key) ws.close()
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
@@ -261,13 +259,16 @@ export namespace Pty {
|
||||
}
|
||||
log.info("client connected to session", { id })
|
||||
|
||||
const sub = key(ws)
|
||||
// Use ws.data as the unique key for this connection lifecycle.
|
||||
// If ws.data is undefined, fallback to ws object.
|
||||
const connectionKey = ws.data && typeof ws.data === "object" ? ws.data : ws
|
||||
|
||||
session.subscribers.delete(sub)
|
||||
session.subscribers.set(sub, ws)
|
||||
// Optionally cleanup if the key somehow exists
|
||||
session.subscribers.delete(connectionKey)
|
||||
session.subscribers.set(connectionKey, ws)
|
||||
|
||||
const cleanup = () => {
|
||||
session.subscribers.delete(sub)
|
||||
session.subscribers.delete(connectionKey)
|
||||
}
|
||||
|
||||
const start = session.bufferCursor
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import { Hono } from "hono"
|
||||
import { describeRoute, validator, resolver } from "hono-openapi"
|
||||
import type { UpgradeWebSocket } from "hono/ws"
|
||||
import { upgradeWebSocket } from "hono/bun"
|
||||
import z from "zod"
|
||||
import { Pty } from "@/pty"
|
||||
import { NotFoundError } from "../../storage/db"
|
||||
import { errors } from "../error"
|
||||
import { lazy } from "../../util/lazy"
|
||||
|
||||
export function PtyRoutes(upgradeWebSocket: UpgradeWebSocket) {
|
||||
return new Hono()
|
||||
export const PtyRoutes = lazy(() =>
|
||||
new Hono()
|
||||
.get(
|
||||
"/",
|
||||
describeRoute({
|
||||
@@ -195,5 +196,5 @@ export function PtyRoutes(upgradeWebSocket: UpgradeWebSocket) {
|
||||
},
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
@@ -34,8 +34,7 @@ import { ProviderRoutes } from "./routes/provider"
|
||||
import { InstanceBootstrap } from "../project/bootstrap"
|
||||
import { NotFoundError } from "../storage/db"
|
||||
import type { ContentfulStatusCode } from "hono/utils/http-status"
|
||||
import { createAdaptorServer, type ServerType } from "@hono/node-server"
|
||||
import { createNodeWebSocket } from "@hono/node-ws"
|
||||
import { websocket } from "hono/bun"
|
||||
import { HTTPException } from "hono/http-exception"
|
||||
import { errors } from "./error"
|
||||
import { Filesystem } from "@/util/filesystem"
|
||||
@@ -49,20 +48,13 @@ import { lazy } from "@/util/lazy"
|
||||
globalThis.AI_SDK_LOG_WARNINGS = false
|
||||
|
||||
export namespace Server {
|
||||
export type Listener = {
|
||||
hostname: string
|
||||
port: number
|
||||
url: URL
|
||||
stop: (close?: boolean) => Promise<void>
|
||||
}
|
||||
const log = Log.create({ service: "server" })
|
||||
|
||||
export const Default = lazy(() => create({}).app)
|
||||
export const Default = lazy(() => createApp({}))
|
||||
|
||||
function create(opts: { cors?: string[] }) {
|
||||
const log = Log.create({ service: "server" })
|
||||
export const createApp = (opts: { cors?: string[] }): Hono => {
|
||||
const app = new Hono()
|
||||
const ws = createNodeWebSocket({ app })
|
||||
const route = app
|
||||
return app
|
||||
.onError((err, c) => {
|
||||
log.error("failed", {
|
||||
error: err,
|
||||
@@ -247,6 +239,7 @@ export namespace Server {
|
||||
),
|
||||
)
|
||||
.route("/project", ProjectRoutes())
|
||||
.route("/pty", PtyRoutes())
|
||||
.route("/config", ConfigRoutes())
|
||||
.route("/experimental", ExperimentalRoutes())
|
||||
.route("/session", SessionRoutes())
|
||||
@@ -559,7 +552,6 @@ export namespace Server {
|
||||
})
|
||||
},
|
||||
)
|
||||
.route("/pty", PtyRoutes(ws.upgradeWebSocket))
|
||||
.all("/*", async (c) => {
|
||||
const path = c.req.path
|
||||
|
||||
@@ -576,11 +568,6 @@ export namespace Server {
|
||||
)
|
||||
return response
|
||||
})
|
||||
|
||||
return {
|
||||
app: route as Hono,
|
||||
ws,
|
||||
}
|
||||
}
|
||||
|
||||
export async function openapi() {
|
||||
@@ -598,86 +585,48 @@ export namespace Server {
|
||||
return result
|
||||
}
|
||||
|
||||
export async function listen(opts: {
|
||||
export function listen(opts: {
|
||||
port: number
|
||||
hostname: string
|
||||
mdns?: boolean
|
||||
mdnsDomain?: string
|
||||
cors?: string[]
|
||||
}): Promise<Listener> {
|
||||
const log = Log.create({ service: "server" })
|
||||
const built = create({
|
||||
...opts,
|
||||
})
|
||||
const start = (port: number) =>
|
||||
new Promise<ServerType>((resolve, reject) => {
|
||||
const server = createAdaptorServer({ fetch: built.app.fetch })
|
||||
built.ws.injectWebSocket(server)
|
||||
const fail = (err: Error) => {
|
||||
cleanup()
|
||||
reject(err)
|
||||
}
|
||||
const ready = () => {
|
||||
cleanup()
|
||||
resolve(server)
|
||||
}
|
||||
const cleanup = () => {
|
||||
server.off("error", fail)
|
||||
server.off("listening", ready)
|
||||
}
|
||||
server.once("error", fail)
|
||||
server.once("listening", ready)
|
||||
server.listen(port, opts.hostname)
|
||||
})
|
||||
|
||||
const server = opts.port === 0 ? await start(4096).catch(() => start(0)) : await start(opts.port)
|
||||
const addr = server.address()
|
||||
if (!addr || typeof addr === "string") {
|
||||
throw new Error(`Failed to resolve server address for port ${opts.port}`)
|
||||
}) {
|
||||
const app = createApp(opts)
|
||||
const args = {
|
||||
hostname: opts.hostname,
|
||||
idleTimeout: 0,
|
||||
fetch: app.fetch,
|
||||
websocket: websocket,
|
||||
} as const
|
||||
const tryServe = (port: number) => {
|
||||
try {
|
||||
return Bun.serve({ ...args, port })
|
||||
} catch {
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
const url = new URL("http://localhost")
|
||||
url.hostname = opts.hostname
|
||||
url.port = String(addr.port)
|
||||
const server = opts.port === 0 ? (tryServe(4096) ?? tryServe(0)) : tryServe(opts.port)
|
||||
if (!server) throw new Error(`Failed to start server on port ${opts.port}`)
|
||||
|
||||
const shouldPublishMDNS =
|
||||
opts.mdns &&
|
||||
addr.port &&
|
||||
server.port &&
|
||||
opts.hostname !== "127.0.0.1" &&
|
||||
opts.hostname !== "localhost" &&
|
||||
opts.hostname !== "::1"
|
||||
if (shouldPublishMDNS) {
|
||||
MDNS.publish(addr.port, opts.mdnsDomain)
|
||||
MDNS.publish(server.port!, opts.mdnsDomain)
|
||||
} else if (opts.mdns) {
|
||||
log.warn("mDNS enabled but hostname is loopback; skipping mDNS publish")
|
||||
}
|
||||
|
||||
let closing: Promise<void> | undefined
|
||||
return {
|
||||
hostname: opts.hostname,
|
||||
port: addr.port,
|
||||
url,
|
||||
stop(close?: boolean) {
|
||||
closing ??= new Promise((resolve, reject) => {
|
||||
if (shouldPublishMDNS) MDNS.unpublish()
|
||||
server.close((err) => {
|
||||
if (err) {
|
||||
reject(err)
|
||||
return
|
||||
}
|
||||
resolve()
|
||||
})
|
||||
if (close) {
|
||||
if ("closeAllConnections" in server && typeof server.closeAllConnections === "function") {
|
||||
server.closeAllConnections()
|
||||
}
|
||||
if ("closeIdleConnections" in server && typeof server.closeIdleConnections === "function") {
|
||||
server.closeIdleConnections()
|
||||
}
|
||||
}
|
||||
})
|
||||
return closing
|
||||
},
|
||||
const originalStop = server.stop.bind(server)
|
||||
server.stop = async (closeActiveConnections?: boolean) => {
|
||||
if (shouldPublishMDNS) MDNS.unpublish()
|
||||
return originalStop(closeActiveConnections)
|
||||
}
|
||||
|
||||
return server
|
||||
}
|
||||
}
|
||||
|
||||
199
packages/opencode/test/mcp/oauth-auto-connect.test.ts
Normal file
199
packages/opencode/test/mcp/oauth-auto-connect.test.ts
Normal file
@@ -0,0 +1,199 @@
|
||||
import { test, expect, mock, beforeEach } from "bun:test"
|
||||
|
||||
// Mock UnauthorizedError to match the SDK's class
|
||||
class MockUnauthorizedError extends Error {
|
||||
constructor(message?: string) {
|
||||
super(message ?? "Unauthorized")
|
||||
this.name = "UnauthorizedError"
|
||||
}
|
||||
}
|
||||
|
||||
// Track what options were passed to each transport constructor
|
||||
const transportCalls: Array<{
|
||||
type: "streamable" | "sse"
|
||||
url: string
|
||||
options: { authProvider?: unknown }
|
||||
}> = []
|
||||
|
||||
// Controls whether the mock transport simulates a 401 that triggers the SDK
|
||||
// auth flow (which calls provider.state()) or a simple UnauthorizedError.
|
||||
let simulateAuthFlow = true
|
||||
|
||||
// Mock the transport constructors to simulate OAuth auto-auth on 401
|
||||
mock.module("@modelcontextprotocol/sdk/client/streamableHttp.js", () => ({
|
||||
StreamableHTTPClientTransport: class MockStreamableHTTP {
|
||||
authProvider:
|
||||
| {
|
||||
state?: () => Promise<string>
|
||||
redirectToAuthorization?: (url: URL) => Promise<void>
|
||||
saveCodeVerifier?: (v: string) => Promise<void>
|
||||
}
|
||||
| undefined
|
||||
constructor(url: URL, options?: { authProvider?: unknown }) {
|
||||
this.authProvider = options?.authProvider as typeof this.authProvider
|
||||
transportCalls.push({
|
||||
type: "streamable",
|
||||
url: url.toString(),
|
||||
options: options ?? {},
|
||||
})
|
||||
}
|
||||
async start() {
|
||||
// Simulate what the real SDK transport does on 401:
|
||||
// It calls auth() which eventually calls provider.state(), then
|
||||
// provider.redirectToAuthorization(), then throws UnauthorizedError.
|
||||
if (simulateAuthFlow && this.authProvider) {
|
||||
// The SDK calls provider.state() to get the OAuth state parameter
|
||||
if (this.authProvider.state) {
|
||||
await this.authProvider.state()
|
||||
}
|
||||
// The SDK calls saveCodeVerifier before redirecting
|
||||
if (this.authProvider.saveCodeVerifier) {
|
||||
await this.authProvider.saveCodeVerifier("test-verifier")
|
||||
}
|
||||
// The SDK calls redirectToAuthorization to redirect the user
|
||||
if (this.authProvider.redirectToAuthorization) {
|
||||
await this.authProvider.redirectToAuthorization(new URL("https://auth.example.com/authorize?state=test"))
|
||||
}
|
||||
throw new MockUnauthorizedError()
|
||||
}
|
||||
throw new MockUnauthorizedError()
|
||||
}
|
||||
async finishAuth(_code: string) {}
|
||||
},
|
||||
}))
|
||||
|
||||
mock.module("@modelcontextprotocol/sdk/client/sse.js", () => ({
|
||||
SSEClientTransport: class MockSSE {
|
||||
constructor(url: URL, options?: { authProvider?: unknown }) {
|
||||
transportCalls.push({
|
||||
type: "sse",
|
||||
url: url.toString(),
|
||||
options: options ?? {},
|
||||
})
|
||||
}
|
||||
async start() {
|
||||
throw new Error("Mock SSE transport cannot connect")
|
||||
}
|
||||
},
|
||||
}))
|
||||
|
||||
// Mock the MCP SDK Client
|
||||
mock.module("@modelcontextprotocol/sdk/client/index.js", () => ({
|
||||
Client: class MockClient {
|
||||
async connect(transport: { start: () => Promise<void> }) {
|
||||
await transport.start()
|
||||
}
|
||||
},
|
||||
}))
|
||||
|
||||
// Mock UnauthorizedError in the auth module so instanceof checks work
|
||||
mock.module("@modelcontextprotocol/sdk/client/auth.js", () => ({
|
||||
UnauthorizedError: MockUnauthorizedError,
|
||||
}))
|
||||
|
||||
beforeEach(() => {
|
||||
transportCalls.length = 0
|
||||
simulateAuthFlow = true
|
||||
})
|
||||
|
||||
// Import modules after mocking
|
||||
const { MCP } = await import("../../src/mcp/index")
|
||||
const { Instance } = await import("../../src/project/instance")
|
||||
const { tmpdir } = await import("../fixture/fixture")
|
||||
|
||||
test("first connect to OAuth server shows needs_auth instead of failed", async () => {
|
||||
await using tmp = await tmpdir({
|
||||
init: async (dir) => {
|
||||
await Bun.write(
|
||||
`${dir}/opencode.json`,
|
||||
JSON.stringify({
|
||||
$schema: "https://opencode.ai/config.json",
|
||||
mcp: {
|
||||
"test-oauth": {
|
||||
type: "remote",
|
||||
url: "https://example.com/mcp",
|
||||
},
|
||||
},
|
||||
}),
|
||||
)
|
||||
},
|
||||
})
|
||||
|
||||
await Instance.provide({
|
||||
directory: tmp.path,
|
||||
fn: async () => {
|
||||
const result = await MCP.add("test-oauth", {
|
||||
type: "remote",
|
||||
url: "https://example.com/mcp",
|
||||
})
|
||||
|
||||
const serverStatus = result.status as Record<string, { status: string; error?: string }>
|
||||
|
||||
// The server should be detected as needing auth, NOT as failed.
|
||||
// Before the fix, provider.state() would throw a plain Error
|
||||
// ("No OAuth state saved for MCP server: test-oauth") which was
|
||||
// not caught as UnauthorizedError, causing status to be "failed".
|
||||
expect(serverStatus["test-oauth"]).toBeDefined()
|
||||
expect(serverStatus["test-oauth"].status).toBe("needs_auth")
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
test("state() generates a new state when none is saved", async () => {
|
||||
const { McpOAuthProvider } = await import("../../src/mcp/oauth-provider")
|
||||
const { McpAuth } = await import("../../src/mcp/auth")
|
||||
|
||||
await using tmp = await tmpdir()
|
||||
|
||||
await Instance.provide({
|
||||
directory: tmp.path,
|
||||
fn: async () => {
|
||||
const provider = new McpOAuthProvider(
|
||||
"test-state-gen",
|
||||
"https://example.com/mcp",
|
||||
{},
|
||||
{ onRedirect: async () => {} },
|
||||
)
|
||||
|
||||
// Ensure no state exists
|
||||
const entryBefore = await McpAuth.get("test-state-gen")
|
||||
expect(entryBefore?.oauthState).toBeUndefined()
|
||||
|
||||
// state() should generate and return a new state, not throw
|
||||
const state = await provider.state()
|
||||
expect(typeof state).toBe("string")
|
||||
expect(state.length).toBe(64) // 32 bytes as hex
|
||||
|
||||
// The generated state should be persisted
|
||||
const entryAfter = await McpAuth.get("test-state-gen")
|
||||
expect(entryAfter?.oauthState).toBe(state)
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
test("state() returns existing state when one is saved", async () => {
|
||||
const { McpOAuthProvider } = await import("../../src/mcp/oauth-provider")
|
||||
const { McpAuth } = await import("../../src/mcp/auth")
|
||||
|
||||
await using tmp = await tmpdir()
|
||||
|
||||
await Instance.provide({
|
||||
directory: tmp.path,
|
||||
fn: async () => {
|
||||
const provider = new McpOAuthProvider(
|
||||
"test-state-existing",
|
||||
"https://example.com/mcp",
|
||||
{},
|
||||
{ onRedirect: async () => {} },
|
||||
)
|
||||
|
||||
// Pre-save a state
|
||||
const existingState = "pre-saved-state-value"
|
||||
await McpAuth.updateOAuthState("test-state-existing", existingState)
|
||||
|
||||
// state() should return the existing state
|
||||
const state = await provider.state()
|
||||
expect(state).toBe(existingState)
|
||||
},
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user