Compare commits

...

103 Commits

Author SHA1 Message Date
adamdottv
c8f8d67a88 feat: codeAction tool 2025-05-14 14:57:47 -05:00
Dax Raad
182e32e4f7 add screenshot 2025-05-14 15:50:25 -04:00
adamdottv
5ea989fb74 feat: docSymbols and workspaceSymbols tools 2025-05-14 14:40:45 -05:00
adamdottv
45c778b90d feat: definition and references tools 2025-05-14 14:25:15 -05:00
Dax Raad
47cbb650a0 ci 2025-05-14 14:30:47 -04:00
Dax Raad
e91371c6a5 ci 2025-05-14 14:27:18 -04:00
adamdottv
9d17314309 fix: tweak 2025-05-14 13:15:39 -05:00
adamdottv
3982be4310 feat: session specific logs 2025-05-14 13:06:09 -05:00
adamdottv
4c998d4f4f chore: remove sourcegraph tool 2025-05-14 10:42:16 -05:00
adamdottv
f7849c2d59 fix: log display 2025-05-14 07:47:15 -05:00
adamdottv
463002185b fix: log perf 2025-05-14 07:38:21 -05:00
adamdottv
53a80eac1e chore: cleanup 2025-05-14 06:00:34 -05:00
adamdottv
01b6bf5bb7 chore: refactor db 2025-05-13 13:08:43 -05:00
adamdottv
d8f3b60625 chore: refactoring 2025-05-13 11:34:48 -05:00
adamdottv
cf8e16018d chore: refactoring 2025-05-13 11:15:14 -05:00
adamdottv
674797bd48 chore: refactoring 2025-05-13 11:07:34 -05:00
adamdottv
1f9610e266 chore: refactoring 2025-05-13 10:45:58 -05:00
adamdottv
ae86ef519c chore: refactoring 2025-05-13 10:27:09 -05:00
adamdottv
2391e338b4 chore: rename 2025-05-13 10:02:39 -05:00
adamdottv
1e9399fbee chore: cleanup 2025-05-13 09:26:54 -05:00
adamdottv
e9f74b867f chore: cleanup 2025-05-13 07:15:58 -05:00
adamdottv
5079556896 chore: cleanup 2025-05-13 07:04:27 -05:00
adamdottv
7f0e68b933 chore: cleanup 2025-05-13 06:55:18 -05:00
adamdottv
0c21ca5318 chore: cleanup 2025-05-13 06:51:28 -05:00
adamdottv
0117c72a2c chore: refactor diff 2025-05-12 15:17:50 -05:00
adamdottv
e3eb9e5435 chore: refactor diff 2025-05-12 14:48:13 -05:00
adamdottv
d941be3f1f chore: refactor agent.go 2025-05-12 14:43:12 -05:00
adamdottv
36e5ae804e chore: rename coder -> primary 2025-05-12 14:32:27 -05:00
adamdottv
c9b90dd184 fix: show context % 2025-05-12 14:08:57 -05:00
adamdottv
8270a1e4b1 chore: cleanup 2025-05-12 13:01:59 -05:00
adamdottv
7f9c992993 fix: log status messages 2025-05-12 12:53:20 -05:00
adamdottv
b6524c0982 chore: cleanup 2025-05-12 11:38:07 -05:00
adamdottv
425c0f1bab fix: timestamp formatting 2025-05-12 11:33:52 -05:00
adamdottv
d20d0c5a95 chore: cleanup 2025-05-12 11:22:19 -05:00
adamdottv
5af3c05d41 chore: cleanup 2025-05-12 10:53:13 -05:00
adamdottv
df4a9295c0 chore: cleanup 2025-05-12 10:46:14 -05:00
adamdottv
8cbfc581b5 chore: cleanup 2025-05-12 10:45:03 -05:00
Nicholas Hamilton
4bb350a09b Fix filepicker manual input (#146)
* fix: allows to type i while manual inputting filepath

* fix: file selection in filepicker focus mode

* remove duplicate code
2025-05-12 10:01:56 -05:00
adamdottv
17c5b9c12c fix: build 2025-05-12 09:59:57 -05:00
Ed Zynda
1f8580553c feat: custom commands (#133)
* Implement custom commands

* Add User: prefix

* Reuse var

* Check if the agent is busy and if so report a warning

* Update README

* fix typo

* Implement user and project scoped custom commands

* Allow for $ARGUMENTS

* UI tweaks

* Update internal/tui/components/dialog/arguments.go

Co-authored-by: Kujtim Hoxha <kujtimii.h@gmail.com>

* Also search in $HOME/.opencode/commands

---------

Co-authored-by: Kujtim Hoxha <kujtimii.h@gmail.com>
2025-05-12 09:58:59 -05:00
mineo
f92b2b76dc replace github.com/google/generative-ai-go with github.com/googleapis/go-genai (#138)
* replace to github.com/googleapis/go-genai

* fix history logic

* small fixes

---------

Co-authored-by: Kujtim Hoxha <kujtimii.h@gmail.com>
2025-05-12 09:56:30 -05:00
adamdottv
1d1a1ddcbf fix: visual tweaks 2025-05-12 09:54:20 -05:00
adamdottv
dfe5fd8d97 wip: refactoring 2025-05-12 09:44:56 -05:00
adamdottv
ed9fba99c9 wip: refactoring 2025-05-12 08:43:34 -05:00
adamdottv
f100777199 wip: logging improvements 2025-05-09 13:37:13 -05:00
adamdottv
f41b7bbd0a chore: refactoring status updates 2025-05-08 12:03:59 -05:00
adamdottv
e35ea2d448 fix: log page nav 2025-05-08 07:59:15 -05:00
adamdottv
bab17d7520 feat: session manager 2025-05-08 07:58:37 -05:00
adamdottv
051d7d7936 chore: logging token usage 2025-05-06 14:40:00 -05:00
adamdottv
b638dafe5f feat: better logs page 2025-05-06 14:22:37 -05:00
adamdottv
e387b1f16c fix: openrouter require_parameters 2025-05-06 11:17:32 -05:00
adamdottv
71a68dd56d feat: add qwen3 models 2025-05-06 10:19:07 -05:00
adamdottv
3ee8ebd3d3 fix: auto-compact logic 2025-05-06 10:03:21 -05:00
adamdottv
ef298b2f18 fix: hide empty messages 2025-05-05 14:38:33 -05:00
adamdottv
3cc08494a5 fix: pubsub leak and shutdown seq 2025-05-05 14:23:29 -05:00
adamdottv
afcdabd095 fix: anthropic non-empty blocks 2025-05-05 12:00:09 -05:00
adamdottv
efaba6c5b8 feat: hide tool calls 2025-05-05 11:25:34 -05:00
adamdottv
874715838a feat: show sender name and timestamp 2025-05-05 11:02:02 -05:00
adamdottv
167eb9ddfa fix: diagnostics visual in status bar 2025-05-05 07:04:33 -05:00
Joshua LaMorey-Salzmann
fba344718f Config fix correcting loose viper string check, default model now set correctly (#147) 2025-05-05 06:56:10 -05:00
adamdottv
cdd906e32e fix: bedrock supports attachments 2025-05-02 15:35:24 -05:00
phantomreactor
ff0ef3bb43 feat: add support for images 2025-05-02 15:29:46 -05:00
adamdottv
0095832be3 chore: cleanup and logging 2025-05-02 15:24:47 -05:00
adamdottv
406ccf9b87 fix: diagnostics tool init 2025-05-02 15:24:47 -05:00
adamdottv
f90d6238ed fix: bedrock context window 2025-05-02 15:24:47 -05:00
adamdottv
f004a0b8c3 fix: anthropic non-empty blocks 2025-05-02 15:24:47 -05:00
adamdottv
49423da081 feat: compact command with auto-compact 2025-05-02 15:24:47 -05:00
adamdottv
364cf5b429 feat: write to context.md by default 2025-05-02 15:24:47 -05:00
adamdottv
b2f24e38ed feat: better diagnostic visuals in status bar 2025-05-02 15:24:47 -05:00
adamdottv
49037e7b28 feat: better logs page 2025-05-02 15:24:47 -05:00
adamdottv
c66832d299 fix: wording 2025-05-01 11:54:35 -05:00
adamdottv
7398b4ce70 fix: don't truncate task tool output 2025-05-01 11:35:54 -05:00
Kujtim Hoxha
a61b2026eb add xai support (#135) 2025-05-01 11:08:26 -05:00
Aiden Cline
69ade34c2c fix: tweak the logic in config to ensure that env vs file configurations merge properly (#115) 2025-05-01 11:08:17 -05:00
Garrett Ladley
fbca5441f6 feat: test for getContextFromPaths (#105)
* feat: test for getContextFromPaths

* fix: use testify
2025-05-01 11:08:06 -05:00
Kujtim Hoxha
e4680caebb some small fixes 2025-05-01 11:07:47 -05:00
adamdottv
e760d28c5a feat: show hunk headers 2025-05-01 09:02:14 -05:00
adamdottv
7d5f0f9d18 fix: pass input to EDITOR 2025-05-01 07:58:27 -05:00
adamdottv
515f4e8642 fix: visual tweaks 2025-05-01 07:32:04 -05:00
adamdottv
f2b36b9234 fix: remove lsp tool 2025-05-01 07:28:37 -05:00
adamdottv
f224978bbc fix: remove manual lsp definition 2025-05-01 07:24:53 -05:00
adamdottv
8819a37a05 fix: logo 2025-05-01 06:43:59 -05:00
adamdottv
769dff00ba fix: don't mark as init 2025-05-01 06:43:30 -05:00
adamdottv
d1be7a984e fix: logo 2025-05-01 06:36:30 -05:00
adamdottv
3e30607a6d fix: minor prompt fix 2025-05-01 06:35:47 -05:00
adamdottv
d08e58279d feat: lsp discovery 2025-05-01 06:26:20 -05:00
adamdottv
7bc542abff fix: better diagnostics visual 2025-04-30 15:23:19 -05:00
adamdottv
ed50c36789 fix: lsp issues with tmp and deleted files 2025-04-30 12:20:51 -05:00
adamdottv
98cf65b425 fix: more intuitive keybinds 2025-04-30 11:34:21 -05:00
adamdottv
5406083850 fix: minor icon tweak 2025-04-30 11:33:30 -05:00
adamdottv
91ae9b33d3 feat: custom themes 2025-04-30 11:05:59 -05:00
adamdottv
a42175c067 fix: info and hint icons 2025-04-30 07:47:14 -05:00
adamdottv
8497145db2 fix: status sizing 2025-04-30 07:47:10 -05:00
adamdottv
89544fad61 feat: tron theme 2025-04-30 07:46:35 -05:00
adamdottv
1151accf4b feat: tokyonight theme 2025-04-30 07:46:35 -05:00
adamdottv
1ae3f1830b feat: dracula theme 2025-04-30 07:46:35 -05:00
adamdottv
1e958b62ad feat: opencode theme (default) 2025-04-30 07:46:35 -05:00
adamdottv
fdf5367f4f feat: monokai pro theme 2025-04-30 07:46:34 -05:00
adamdottv
0e8842a007 feat: onedark theme 2025-04-30 07:46:34 -05:00
adamdottv
060994f393 feat: flexoki theme 2025-04-30 07:46:34 -05:00
adamdottv
61b605e724 feat: themes 2025-04-30 07:46:34 -05:00
Adam
61d9dc9511 fix: allow text selection (#127) 2025-04-30 12:52:30 +02:00
Hunter Casten
76275e533e fix(openrouter): set api key from env (#129) 2025-04-30 12:50:57 +02:00
144 changed files with 13138 additions and 5018 deletions

View File

@@ -4,7 +4,7 @@ on:
workflow_dispatch:
push:
branches:
- main
- dev
concurrency: ${{ github.workflow }}-${{ github.ref }}

View File

@@ -36,5 +36,5 @@ jobs:
version: latest
args: release --clean
env:
GITHUB_TOKEN: ${{ secrets.HOMEBREW_GITHUB_TOKEN }}
GITHUB_TOKEN: ${{ secrets.SST_GITHUB_TOKEN }}
AUR_KEY: ${{ secrets.AUR_KEY }}

View File

@@ -12,7 +12,7 @@ builds:
- amd64
- arm64
ldflags:
- -s -w -X github.com/opencode-ai/opencode/internal/version.Version={{.Version}}
- -s -w -X github.com/sst/opencode/internal/version.Version={{.Version}}
main: ./main.go
archives:
@@ -35,10 +35,11 @@ snapshot:
name_template: "0.0.0-{{ .Timestamp }}"
aurs:
- name: opencode
homepage: "https://github.com/opencode-ai/opencode"
homepage: "https://github.com/sst/opencode"
description: "terminal based agent that can build anything"
maintainers:
- "kujtimiihoxha <kujtimii.h@gmail.com>"
- "dax"
- "adam"
license: "MIT"
private_key: "{{ .Env.AUR_KEY }}"
git_url: "ssh://aur@aur.archlinux.org/opencode-bin.git"
@@ -50,7 +51,7 @@ aurs:
install -Dm755 ./opencode "${pkgdir}/usr/bin/opencode"
brews:
- repository:
owner: opencode-ai
owner: sst
name: homebrew-tap
nfpms:
- maintainer: kujtimiihoxha

View File

@@ -1,8 +1,3 @@
{
"$schema": "./opencode-schema.json",
"lsp": {
"gopls": {
"command": "gopls"
}
}
"$schema": "./opencode-schema.json"
}

24
CONTEXT.md Normal file
View File

@@ -0,0 +1,24 @@
# OpenCode Development Context
## Build Commands
- Build: `go build`
- Run: `go run main.go`
- Test: `go test ./...`
- Test single package: `go test ./internal/package/...`
- Test single test: `go test ./internal/package -run TestName`
- Verbose test: `go test -v ./...`
- Coverage: `go test -cover ./...`
- Lint: `go vet ./...`
- Format: `go fmt ./...`
- Build snapshot: `./scripts/snapshot`
## Code Style
- Use Go 1.24+ features
- Follow standard Go formatting (gofmt)
- Use table-driven tests with t.Parallel() when possible
- Error handling: check errors immediately, return early
- Naming: CamelCase for exported, camelCase for unexported
- Imports: standard library first, then external, then internal
- Use context.Context for cancellation and timeouts
- Prefer interfaces for dependencies to enable testing
- Use testify for assertions in tests

160
README.md
View File

@@ -1,5 +1,7 @@
# ⌬ OpenCode
![OpenCode Terminal UI](screenshot.png)
> **⚠️ Early Development Notice:** This project is in early development and is not yet ready for production use. Features may change, break, or be incomplete. Use at your own risk.
A powerful terminal-based AI assistant for developers, providing intelligent coding assistance directly in your terminal.
@@ -35,7 +37,7 @@ curl -fsSL https://opencode.ai/install | VERSION=0.1.0 bash
### Using Homebrew (macOS and Linux)
```bash
brew install opencode-ai/tap/opencode
brew install sst/tap/opencode
```
### Using AUR (Arch Linux)
@@ -51,7 +53,7 @@ paru -S opencode-bin
### Using Go
```bash
go install github.com/opencode-ai/opencode@latest
go install github.com/sst/opencode@latest
```
## Configuration
@@ -67,7 +69,7 @@ OpenCode looks for configuration in the following locations:
You can configure OpenCode using environment variables:
| Environment Variable | Purpose |
|----------------------------|--------------------------------------------------------|
| -------------------------- | ------------------------------------------------------ |
| `ANTHROPIC_API_KEY` | For Claude models |
| `OPENAI_API_KEY` | For OpenAI models |
| `GEMINI_API_KEY` | For Google Gemini models |
@@ -79,7 +81,6 @@ You can configure OpenCode using environment variables:
| `AZURE_OPENAI_API_KEY` | For Azure OpenAI models (optional when using Entra ID) |
| `AZURE_OPENAI_API_VERSION` | For Azure OpenAI models |
### Configuration File Structure
```json
@@ -106,7 +107,7 @@ You can configure OpenCode using environment variables:
}
},
"agents": {
"coder": {
"primary": {
"model": "claude-3.7-sonnet",
"maxTokens": 5000
},
@@ -296,12 +297,81 @@ OpenCode's AI assistant has access to various tools to help with coding tasks:
### Other Tools
| Tool | Description | Parameters |
| ------------- | -------------------------------------- | ----------------------------------------------------------------------------------------- |
| `bash` | Execute shell commands | `command` (required), `timeout` (optional) |
| `fetch` | Fetch data from URLs | `url` (required), `format` (required), `timeout` (optional) |
| `sourcegraph` | Search code across public repositories | `query` (required), `count` (optional), `context_window` (optional), `timeout` (optional) |
| `agent` | Run sub-tasks with the AI agent | `prompt` (required) |
| Tool | Description | Parameters |
| ------- | ------------------------------- | ----------------------------------------------------------- |
| `bash` | Execute shell commands | `command` (required), `timeout` (optional) |
| `fetch` | Fetch data from URLs | `url` (required), `format` (required), `timeout` (optional) |
| `agent` | Run sub-tasks with the AI agent | `prompt` (required) |
## Theming
OpenCode supports multiple themes for customizing the appearance of the terminal interface.
### Available Themes
The following predefined themes are available:
- `opencode` (default)
- `catppuccin`
- `dracula`
- `flexoki`
- `gruvbox`
- `monokai`
- `onedark`
- `tokyonight`
- `tron`
- `custom` (user-defined)
### Setting a Theme
You can set a theme in your `.opencode.json` configuration file:
```json
{
"tui": {
"theme": "monokai"
}
}
```
### Custom Themes
You can define your own custom theme by setting the `theme` to `"custom"` and providing color definitions in the `customTheme` map:
```json
{
"tui": {
"theme": "custom",
"customTheme": {
"primary": "#ffcc00",
"secondary": "#00ccff",
"accent": { "dark": "#aa00ff", "light": "#ddccff" },
"error": "#ff0000"
}
}
}
```
#### Color Definition Formats
Custom theme colors support two formats:
1. **Simple Hex String**: A single hex color string (e.g., `"#aabbcc"`) that will be used for both light and dark terminal backgrounds.
2. **Adaptive Object**: An object with `dark` and `light` keys, each holding a hex color string. This allows for adaptive colors based on the terminal's background.
#### Available Color Keys
You can define any of the following color keys in your `customTheme`:
- Base colors: `primary`, `secondary`, `accent`
- Status colors: `error`, `warning`, `success`, `info`
- Text colors: `text`, `textMuted`, `textEmphasized`
- Background colors: `background`, `backgroundSecondary`, `backgroundDarker`
- Border colors: `borderNormal`, `borderFocused`, `borderDim`
- Diff view colors: `diffAdded`, `diffRemoved`, `diffContext`, etc.
You don't need to define all colors. Any undefined colors will fall back to the default "opencode" theme colors.
## Architecture
@@ -318,6 +388,72 @@ OpenCode is built with a modular architecture:
- **internal/session**: Session management
- **internal/lsp**: Language Server Protocol integration
## Custom Commands
OpenCode supports custom commands that can be created by users to quickly send predefined prompts to the AI assistant.
### Creating Custom Commands
Custom commands are predefined prompts stored as Markdown files in one of three locations:
1. **User Commands** (prefixed with `user:`):
```
$XDG_CONFIG_HOME/opencode/commands/
```
(typically `~/.config/opencode/commands/` on Linux/macOS)
or
```
$HOME/.opencode/commands/
```
2. **Project Commands** (prefixed with `project:`):
```
<PROJECT DIR>/.opencode/commands/
```
Each `.md` file in these directories becomes a custom command. The file name (without extension) becomes the command ID.
For example, creating a file at `~/.config/opencode/commands/prime-context.md` with content:
```markdown
RUN git ls-files
READ README.md
```
This creates a command called `user:prime-context`.
### Command Arguments
You can create commands that accept arguments by including the `$ARGUMENTS` placeholder in your command file:
```markdown
RUN git show $ARGUMENTS
```
When you run this command, OpenCode will prompt you to enter the text that should replace `$ARGUMENTS`.
### Organizing Commands
You can organize commands in subdirectories:
```
~/.config/opencode/commands/git/commit.md
```
This creates a command with ID `user:git:commit`.
### Using Custom Commands
1. Press `Ctrl+K` to open the command dialog
2. Select your custom command (prefixed with either `user:` or `project:`)
3. Press Enter to execute the command
The content of the command file will be sent as a message to the AI assistant.
## MCP (Model Context Protocol)
OpenCode implements the Model Context Protocol (MCP) to extend its capabilities through external tools. MCP provides a standardized way for the AI assistant to interact with external services and tools.
@@ -408,7 +544,7 @@ While the LSP client implementation supports the full LSP protocol (including co
```bash
# Clone the repository
git clone https://github.com/opencode-ai/opencode.git
git clone https://github.com/sst/opencode.git
cd opencode
# Build

View File

@@ -7,19 +7,42 @@ import (
"sync"
"time"
"log/slog"
tea "github.com/charmbracelet/bubbletea"
"github.com/opencode-ai/opencode/internal/app"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/db"
"github.com/opencode-ai/opencode/internal/llm/agent"
"github.com/opencode-ai/opencode/internal/logging"
"github.com/opencode-ai/opencode/internal/pubsub"
"github.com/opencode-ai/opencode/internal/tui"
"github.com/opencode-ai/opencode/internal/version"
zone "github.com/lrstanley/bubblezone"
"github.com/spf13/cobra"
"github.com/sst/opencode/internal/app"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/db"
"github.com/sst/opencode/internal/llm/agent"
"github.com/sst/opencode/internal/logging"
"github.com/sst/opencode/internal/lsp/discovery"
"github.com/sst/opencode/internal/pubsub"
"github.com/sst/opencode/internal/tui"
"github.com/sst/opencode/internal/version"
)
type SessionIDHandler struct {
slog.Handler
app *app.App
}
func (h *SessionIDHandler) Handle(ctx context.Context, r slog.Record) error {
if h.app != nil {
sessionID := h.app.CurrentSession.ID
if sessionID != "" {
r.AddAttrs(slog.String("session_id", sessionID))
}
}
return h.Handler.Handle(ctx, r)
}
func (h *SessionIDHandler) WithApp(app *app.App) *SessionIDHandler {
h.app = app
return h
}
var rootCmd = &cobra.Command{
Use: "OpenCode",
Short: "A terminal AI assistant for software development",
@@ -37,6 +60,13 @@ to assist developers in writing, debugging, and understanding code directly from
return nil
}
// Setup logging
lvl := new(slog.LevelVar)
textHandler := slog.NewTextHandler(logging.NewSlogWriter(), &slog.HandlerOptions{Level: lvl})
sessionAwareHandler := &SessionIDHandler{Handler: textHandler}
logger := slog.New(sessionAwareHandler)
slog.SetDefault(logger)
// Load the config
debug, _ := cmd.Flags().GetBool("debug")
cwd, _ := cmd.Flags().GetString("cwd")
@@ -53,11 +83,17 @@ to assist developers in writing, debugging, and understanding code directly from
}
cwd = c
}
_, err := config.Load(cwd, debug)
_, err := config.Load(cwd, debug, lvl)
if err != nil {
return err
}
// Run LSP auto-discovery
if err := discovery.IntegrateLSPServers(cwd); err != nil {
slog.Warn("Failed to auto-discover LSP servers", "error", err)
// Continue anyway, this is not a fatal error
}
// Connect DB, this will also run migrations
conn, err := db.Connect()
if err != nil {
@@ -70,16 +106,16 @@ to assist developers in writing, debugging, and understanding code directly from
app, err := app.New(ctx, conn)
if err != nil {
logging.Error("Failed to create app: %v", err)
slog.Error("Failed to create app", "error", err)
return err
}
sessionAwareHandler.WithApp(app)
// Set up the TUI
zone.NewGlobal()
program := tea.NewProgram(
tui.New(app),
tea.WithAltScreen(),
tea.WithMouseCellMotion(),
)
// Initialize MCP tools in the background
@@ -103,11 +139,11 @@ to assist developers in writing, debugging, and understanding code directly from
for {
select {
case <-tuiCtx.Done():
logging.Info("TUI message handler shutting down")
slog.Info("TUI message handler shutting down")
return
case msg, ok := <-ch:
if !ok {
logging.Info("TUI message channel closed")
slog.Info("TUI message channel closed")
return
}
program.Send(msg)
@@ -117,19 +153,19 @@ to assist developers in writing, debugging, and understanding code directly from
// Cleanup function for when the program exits
cleanup := func() {
// Shutdown the app
app.Shutdown()
// Cancel subscriptions first
cancelSubs()
// Then shutdown the app
app.Shutdown()
// Then cancel TUI message handler
tuiCancel()
// Wait for TUI message handler to finish
tuiWg.Wait()
logging.Info("All goroutines cleaned up")
slog.Info("All goroutines cleaned up")
}
// Run the TUI
@@ -137,18 +173,18 @@ to assist developers in writing, debugging, and understanding code directly from
cleanup()
if err != nil {
logging.Error("TUI error: %v", err)
slog.Error("TUI error", "error", err)
return fmt.Errorf("TUI error: %v", err)
}
logging.Info("TUI exited with result: %v", result)
slog.Info("TUI exited", "result", result)
return nil
},
}
// attemptTUIRecovery tries to recover the TUI after a panic
func attemptTUIRecovery(program *tea.Program) {
logging.Info("Attempting to recover TUI after panic")
slog.Info("Attempting to recover TUI after panic")
// We could try to restart the TUI or gracefully exit
// For now, we'll just quit the program to avoid further issues
@@ -165,7 +201,7 @@ func initMCPTools(ctx context.Context, app *app.App) {
// Set this up once with proper error handling
agent.GetMcpTools(ctxWithTimeout, app.Permissions)
logging.Info("MCP message handling goroutine exiting")
slog.Info("MCP message handling goroutine exiting")
}()
}
@@ -182,12 +218,16 @@ func setupSubscriber[T any](
defer logging.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil)
subCh := subscriber(ctx)
if subCh == nil {
slog.Warn("subscription channel is nil", "name", name)
return
}
for {
select {
case event, ok := <-subCh:
if !ok {
logging.Info("subscription channel closed", "name", name)
slog.Info("subscription channel closed", "name", name)
return
}
@@ -196,13 +236,13 @@ func setupSubscriber[T any](
select {
case outputCh <- msg:
case <-time.After(2 * time.Second):
logging.Warn("message dropped due to slow consumer", "name", name)
slog.Warn("message dropped due to slow consumer", "name", name)
case <-ctx.Done():
logging.Info("subscription cancelled", "name", name)
slog.Info("subscription cancelled", "name", name)
return
}
case <-ctx.Done():
logging.Info("subscription cancelled", "name", name)
slog.Info("subscription cancelled", "name", name)
return
}
}
@@ -215,13 +255,14 @@ func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg,
wg := sync.WaitGroup{}
ctx, cancel := context.WithCancel(parentCtx) // Inherit from parent context
setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch)
setupSubscriber(ctx, &wg, "logging", app.Logs.Subscribe, ch)
setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
setupSubscriber(ctx, &wg, "status", app.Status.Subscribe, ch)
cleanupFunc := func() {
logging.Info("Cancelling all subscriptions")
slog.Info("Cancelling all subscriptions")
cancel() // Signal all goroutines to stop
waitCh := make(chan struct{})
@@ -233,10 +274,10 @@ func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg,
select {
case <-waitCh:
logging.Info("All subscription goroutines completed successfully")
slog.Info("All subscription goroutines completed successfully")
close(ch) // Only close after all writers are confirmed done
case <-time.After(5 * time.Second):
logging.Warn("Timed out waiting for some subscription goroutines to complete")
slog.Warn("Timed out waiting for some subscription goroutines to complete")
close(ch)
}
}

View File

@@ -46,7 +46,7 @@ Here's an example configuration that conforms to the schema:
}
},
"agents": {
"coder": {
"primary": {
"model": "claude-3.7-sonnet",
"maxTokens": 5000,
"reasoningEffort": "medium"
@@ -61,4 +61,5 @@ Here's an example configuration that conforms to the schema:
}
}
}
```
```

View File

@@ -5,8 +5,8 @@ import (
"fmt"
"os"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/llm/models"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/llm/models"
)
// JSONSchemaType represents a JSON Schema type
@@ -98,6 +98,57 @@ func generateSchema() map[string]any {
},
}
schema["properties"].(map[string]any)["tui"] = map[string]any{
"type": "object",
"description": "Terminal User Interface configuration",
"properties": map[string]any{
"theme": map[string]any{
"type": "string",
"description": "TUI theme name",
"default": "opencode",
"enum": []string{
"opencode",
"catppuccin",
"dracula",
"flexoki",
"gruvbox",
"monokai",
"onedark",
"tokyonight",
"tron",
"custom",
},
},
"customTheme": map[string]any{
"type": "object",
"description": "Custom theme color definitions",
"additionalProperties": map[string]any{
"oneOf": []map[string]any{
{
"type": "string",
"pattern": "^#[0-9a-fA-F]{6}$",
},
{
"type": "object",
"properties": map[string]any{
"dark": map[string]any{
"type": "string",
"pattern": "^#[0-9a-fA-F]{6}$",
},
"light": map[string]any{
"type": "string",
"pattern": "^#[0-9a-fA-F]{6}$",
},
},
"required": []string{"dark", "light"},
"additionalProperties": false,
},
},
},
},
},
}
// Add MCP servers
schema["properties"].(map[string]any)["mcpServers"] = map[string]any{
"type": "object",
@@ -223,7 +274,7 @@ func generateSchema() map[string]any {
// Add specific agent properties
agentProperties := map[string]any{}
knownAgents := []string{
string(config.AgentCoder),
string(config.AgentPrimary),
string(config.AgentTask),
string(config.AgentTitle),
}

24
go.mod
View File

@@ -1,9 +1,7 @@
module github.com/opencode-ai/opencode
module github.com/sst/opencode
go 1.24.0
toolchain go1.24.2
require (
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0
github.com/JohannesKaufmann/html-to-markdown v1.6.0
@@ -16,12 +14,10 @@ require (
github.com/charmbracelet/bubbles v0.20.0
github.com/charmbracelet/bubbletea v1.3.4
github.com/charmbracelet/glamour v0.9.1
github.com/charmbracelet/huh v0.6.0
github.com/charmbracelet/lipgloss v1.1.0
github.com/charmbracelet/x/ansi v0.8.0
github.com/fsnotify/fsnotify v1.8.0
github.com/go-logfmt/logfmt v0.6.0
github.com/google/generative-ai-go v0.19.0
github.com/google/uuid v1.6.0
github.com/lrstanley/bubblezone v0.0.0-20250315020633-c249a3fe1231
github.com/mark3labs/mcp-go v0.17.0
@@ -35,16 +31,12 @@ require (
github.com/spf13/cobra v1.9.1
github.com/spf13/viper v1.20.0
github.com/stretchr/testify v1.10.0
google.golang.org/api v0.215.0
)
require (
cloud.google.com/go v0.116.0 // indirect
cloud.google.com/go/ai v0.8.0 // indirect
cloud.google.com/go/auth v0.13.0 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect
cloud.google.com/go/compute/metadata v0.6.0 // indirect
cloud.google.com/go/longrunning v0.5.7 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect
@@ -68,30 +60,30 @@ require (
github.com/aymerick/douceur v0.2.0 // indirect
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 // indirect
github.com/charmbracelet/x/term v0.2.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/disintegration/imaging v1.6.2
github.com/dlclark/regexp2 v1.11.4 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
github.com/golang-jwt/jwt/v5 v5.2.2 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/s2a-go v0.1.8 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
github.com/googleapis/gax-go/v2 v2.14.1 // indirect
github.com/gorilla/css v1.0.1 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/kylelemons/godebug v1.1.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/mfridman/interpolate v0.0.2 // indirect
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/ncruces/julianday v1.0.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
@@ -116,21 +108,19 @@ require (
github.com/yuin/goldmark v1.7.8 // indirect
github.com/yuin/goldmark-emoji v1.0.5 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect
go.opentelemetry.io/otel v1.35.0 // indirect
go.opentelemetry.io/otel/metric v1.35.0 // indirect
go.opentelemetry.io/otel/trace v1.35.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/crypto v0.37.0 // indirect
golang.org/x/image v0.26.0 // indirect
golang.org/x/net v0.39.0 // indirect
golang.org/x/oauth2 v0.25.0 // indirect
golang.org/x/sync v0.13.0 // indirect
golang.org/x/sys v0.32.0 // indirect
golang.org/x/term v0.31.0 // indirect
golang.org/x/text v0.24.0 // indirect
golang.org/x/time v0.8.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 // indirect
google.golang.org/genai v1.3.0
google.golang.org/genproto/googleapis/rpc v0.0.0-20250324211829-b45e905df463 // indirect
google.golang.org/grpc v1.71.0 // indirect
google.golang.org/protobuf v1.36.6 // indirect

33
go.sum
View File

@@ -1,15 +1,9 @@
cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE=
cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U=
cloud.google.com/go/ai v0.8.0 h1:rXUEz8Wp2OlrM8r1bfmpF2+VKqc1VJpafE3HgzRnD/w=
cloud.google.com/go/ai v0.8.0/go.mod h1:t3Dfk4cM61sytiggo2UyGsDVW3RF1qGZaUKDrZFyqkE=
cloud.google.com/go/auth v0.13.0 h1:8Fu8TZy167JkW8Tj3q7dIkr2v4cndv41ouecJx0PAHs=
cloud.google.com/go/auth v0.13.0/go.mod h1:COOjD9gwfKNKz+IIduatIhYJQIc0mG3H102r/EMxX6Q=
cloud.google.com/go/auth/oauth2adapt v0.2.6 h1:V6a6XDu2lTwPZWOawrAa9HUK+DB2zfJyTuciBG5hFkU=
cloud.google.com/go/auth/oauth2adapt v0.2.6/go.mod h1:AlmsELtlEBnaNTL7jCj8VQFLy6mbZv0s4Q7NGBeQ5E8=
cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
cloud.google.com/go/longrunning v0.5.7 h1:WLbHekDbjK1fVFD3ibpFFVoyizlLRl73I7YKuAKilhU=
cloud.google.com/go/longrunning v0.5.7/go.mod h1:8GClkudohy1Fxm3owmBGid8W0pSgodEMwEAztp38Xng=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 h1:g0EZJwz7xkXQiZAI5xi9f3WWFYBlX1CPTrR+NDToRkQ=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0/go.mod h1:XCW7KnZet0Opnr7HccfUw1PLc4CjHqpcaxW8DHklNkQ=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 h1:tfLQ34V6F7tVSwoTf/4lH5sE0o6eCJuNDTmH09nDpbc=
@@ -82,8 +76,6 @@ github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4p
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
github.com/charmbracelet/glamour v0.9.1 h1:11dEfiGP8q1BEqvGoIjivuc2rBk+5qEXdPtaQ2WoiCM=
github.com/charmbracelet/glamour v0.9.1/go.mod h1:+SHvIS8qnwhgTpVMiXwn7OfGomSqff1cHBCI8jLOetk=
github.com/charmbracelet/huh v0.6.0 h1:mZM8VvZGuE0hoDXq6XLxRtgfWyTI3b2jZNKh0xWmax8=
github.com/charmbracelet/huh v0.6.0/go.mod h1:GGNKeWCeNzKpEOh/OJD8WBwTQjV3prFAtQPpLv+AVwU=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.8.0 h1:9GTq3xq9caJW8ZrBTe0LIe2fvfLR/bYXKTx2llXn7xE=
@@ -92,14 +84,14 @@ github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0G
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b h1:MnAMdlwSltxJyULnrYbkZpp4k58Co7Tah3ciKhSNo0Q=
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 h1:qko3AQ4gK1MTS/de7F5hPGx6/k1u0w4TeYmBFwzYVP4=
github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0/go.mod h1:pBhA0ybfXv6hDjQUZ7hk1lVxBiUbupdw5R31yPUViVQ=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/disintegration/imaging v1.6.2 h1:w1LecBlG2Lnp8B3jk5zSuNqd7b4DXhcjwek1ei82L+c=
github.com/disintegration/imaging v1.6.2/go.mod h1:44/5580QXChDfwIclfc/PCwrr44amcmDAg8hxG0Ewe4=
github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo=
github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
@@ -125,8 +117,6 @@ github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeD
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/generative-ai-go v0.19.0 h1:R71szggh8wHMCUlEMsW2A/3T+5LdEIkiaHSYgSpUgdg=
github.com/google/generative-ai-go v0.19.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM=
@@ -139,6 +129,8 @@ github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrk
github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA=
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
@@ -169,8 +161,6 @@ github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6B
github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg=
github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk=
github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA=
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
@@ -258,8 +248,6 @@ github.com/yuin/goldmark-emoji v1.0.5 h1:EMVWyCGPlXJfUXBXpuMu+ii3TIaxbVBnEX9uaDC
github.com/yuin/goldmark-emoji v1.0.5/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 h1:r6I7RJCN86bpD/FQwedZ0vSixDpwuWREjW9oRMsmqDc=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0/go.mod h1:B9yO6b04uB80CzjedvewuqDhxJxi11s7/GtiGa8bAjI=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8=
go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ=
@@ -283,6 +271,9 @@ golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 h1:nDVHiLt8aIbd/VzvPWN6kSOPE7+F/fNFDSXLVYkE/Iw=
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394/go.mod h1:sIifuuw/Yco/y6yb6+bDNfyeQ/MdPUy/hKEMYQV17cM=
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/image v0.26.0 h1:4XjIFEZWQmCZi6Wv8BoxsDhRU3RVnLX04dToTDAEPlY=
golang.org/x/image v0.26.0/go.mod h1:lcxbMFAovzpnJxzXS3nyL83K27tmqtKzIJpctK8YO5c=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@@ -296,8 +287,6 @@ golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70=
golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -338,17 +327,13 @@ golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg=
golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.215.0 h1:jdYF4qnyczlEz2ReWIsosNLDuzXyvFHJtI5gcr0J7t0=
google.golang.org/api v0.215.0/go.mod h1:fta3CVtuJYOEdugLNWm6WodzOS8KdFckABwN4I40hzY=
google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 h1:GVIKPyP/kLIyVOgOnTwFOrvQaQUzOzGMCxgFUOEmm24=
google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422/go.mod h1:b6h1vNKhxaSoEI+5jc3PJUCustfli/mRab7295pY7rw=
google.golang.org/genai v1.3.0 h1:tXhPJF30skOjnnDY7ZnjK3q7IKy4PuAlEA0fk7uEaEI=
google.golang.org/genai v1.3.0/go.mod h1:TyfOKRz/QyCaj6f/ZDt505x+YreXnY40l2I6k8TvgqY=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250324211829-b45e905df463 h1:e0AIkUUhxyBKh6ssZNrAMeqhA7RKUj42346d1y02i2g=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250324211829-b45e905df463/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
google.golang.org/grpc v1.71.0 h1:kF77BGdPTQ4/JZWMlb9VpJ5pa25aqvVqogsxNHHdeBg=

View File

@@ -40,15 +40,15 @@ INSTALL_DIR=$HOME/.opencode/bin
mkdir -p "$INSTALL_DIR"
if [ -z "$requested_version" ]; then
url="https://github.com/opencode-ai/opencode/releases/latest/download/$filename"
specific_version=$(curl -s https://api.github.com/repos/opencode-ai/opencode/releases/latest | awk -F'"' '/"tag_name": "/ {gsub(/^v/, "", $4); print $4}')
url="https://github.com/sst/opencode/releases/latest/download/$filename"
specific_version=$(curl -s https://api.github.com/repos/sst/opencode/releases/latest | awk -F'"' '/"tag_name": "/ {gsub(/^v/, "", $4); print $4}')
if [[ $? -ne 0 ]]; then
echo "${RED}Failed to fetch version information${NC}"
exit 1
fi
else
url="https://github.com/opencode-ai/opencode/releases/download/v${requested_version}/$filename"
url="https://github.com/sst/opencode/releases/download/v${requested_version}/$filename"
specific_version=$requested_version
fi

View File

@@ -7,24 +7,30 @@ import (
"sync"
"time"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/db"
"github.com/opencode-ai/opencode/internal/history"
"github.com/opencode-ai/opencode/internal/llm/agent"
"github.com/opencode-ai/opencode/internal/logging"
"github.com/opencode-ai/opencode/internal/lsp"
"github.com/opencode-ai/opencode/internal/message"
"github.com/opencode-ai/opencode/internal/permission"
"github.com/opencode-ai/opencode/internal/session"
"log/slog"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/history"
"github.com/sst/opencode/internal/llm/agent"
"github.com/sst/opencode/internal/logging"
"github.com/sst/opencode/internal/lsp"
"github.com/sst/opencode/internal/message"
"github.com/sst/opencode/internal/permission"
"github.com/sst/opencode/internal/session"
"github.com/sst/opencode/internal/status"
"github.com/sst/opencode/internal/tui/theme"
)
type App struct {
Sessions session.Service
Messages message.Service
History history.Service
Permissions permission.Service
CurrentSession *session.Session
Logs logging.Service
Sessions session.Service
Messages message.Service
History history.Service
Permissions permission.Service
Status status.Service
CoderAgent agent.Service
PrimaryAgent agent.Service
LSPClients map[string]*lsp.Client
@@ -36,28 +42,59 @@ type App struct {
}
func New(ctx context.Context, conn *sql.DB) (*App, error) {
q := db.New(conn)
sessions := session.NewService(q)
messages := message.NewService(q)
files := history.NewService(q, conn)
err := logging.InitService(conn)
if err != nil {
slog.Error("Failed to initialize logging service", "error", err)
return nil, err
}
err = session.InitService(conn)
if err != nil {
slog.Error("Failed to initialize session service", "error", err)
return nil, err
}
err = message.InitService(conn)
if err != nil {
slog.Error("Failed to initialize message service", "error", err)
return nil, err
}
err = history.InitService(conn)
if err != nil {
slog.Error("Failed to initialize history service", "error", err)
return nil, err
}
err = permission.InitService()
if err != nil {
slog.Error("Failed to initialize permission service", "error", err)
return nil, err
}
err = status.InitService()
if err != nil {
slog.Error("Failed to initialize status service", "error", err)
return nil, err
}
app := &App{
Sessions: sessions,
Messages: messages,
History: files,
Permissions: permission.NewPermissionService(),
LSPClients: make(map[string]*lsp.Client),
CurrentSession: &session.Session{},
Logs: logging.GetService(),
Sessions: session.GetService(),
Messages: message.GetService(),
History: history.GetService(),
Permissions: permission.GetService(),
Status: status.GetService(),
LSPClients: make(map[string]*lsp.Client),
}
// Initialize theme based on configuration
app.initTheme()
// Initialize LSP clients in the background
go app.initLSPClients(ctx)
var err error
app.CoderAgent, err = agent.NewAgent(
config.AgentCoder,
app.PrimaryAgent, err = agent.NewAgent(
config.AgentPrimary,
app.Sessions,
app.Messages,
agent.CoderAgentTools(
agent.PrimaryAgentTools(
app.Permissions,
app.Sessions,
app.Messages,
@@ -66,13 +103,28 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) {
),
)
if err != nil {
logging.Error("Failed to create coder agent", err)
slog.Error("Failed to create primary agent", "error", err)
return nil, err
}
return app, nil
}
// initTheme sets the application theme based on the configuration
func (app *App) initTheme() {
cfg := config.Get()
if cfg == nil || cfg.TUI.Theme == "" {
return // Use default theme
}
// Try to set the theme from config
err := theme.SetTheme(cfg.TUI.Theme)
if err != nil {
slog.Warn("Failed to set theme from config, using default theme", "theme", cfg.TUI.Theme, "error", err)
} else {
slog.Debug("Set theme from config", "theme", cfg.TUI.Theme)
}
}
// Shutdown performs a clean shutdown of the application
func (app *App) Shutdown() {
@@ -93,7 +145,7 @@ func (app *App) Shutdown() {
for name, client := range clients {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
if err := client.Shutdown(shutdownCtx); err != nil {
logging.Error("Failed to shutdown LSP client", "name", name, "error", err)
slog.Error("Failed to shutdown LSP client", "name", name, "error", err)
}
cancel()
}

View File

@@ -4,10 +4,12 @@ import (
"context"
"time"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/logging"
"github.com/opencode-ai/opencode/internal/lsp"
"github.com/opencode-ai/opencode/internal/lsp/watcher"
"log/slog"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/logging"
"github.com/sst/opencode/internal/lsp"
"github.com/sst/opencode/internal/lsp/watcher"
)
func (app *App) initLSPClients(ctx context.Context) {
@@ -18,29 +20,29 @@ func (app *App) initLSPClients(ctx context.Context) {
// Start each client initialization in its own goroutine
go app.createAndStartLSPClient(ctx, name, clientConfig.Command, clientConfig.Args...)
}
logging.Info("LSP clients initialization started in background")
slog.Info("LSP clients initialization started in background")
}
// createAndStartLSPClient creates a new LSP client, initializes it, and starts its workspace watcher
func (app *App) createAndStartLSPClient(ctx context.Context, name string, command string, args ...string) {
// Create a specific context for initialization with a timeout
logging.Info("Creating LSP client", "name", name, "command", command, "args", args)
slog.Info("Creating LSP client", "name", name, "command", command, "args", args)
// Create the LSP client
lspClient, err := lsp.NewClient(ctx, command, args...)
if err != nil {
logging.Error("Failed to create LSP client for", name, err)
slog.Error("Failed to create LSP client for", name, err)
return
}
// Create a longer timeout for initialization (some servers take time to start)
initCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
// Initialize with the initialization context
_, err = lspClient.InitializeLSPClient(initCtx, config.WorkingDirectory())
if err != nil {
logging.Error("Initialize failed", "name", name, "error", err)
slog.Error("Initialize failed", "name", name, "error", err)
// Clean up the client to prevent resource leaks
lspClient.Close()
return
@@ -48,22 +50,22 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, comman
// Wait for the server to be ready
if err := lspClient.WaitForServerReady(initCtx); err != nil {
logging.Error("Server failed to become ready", "name", name, "error", err)
slog.Error("Server failed to become ready", "name", name, "error", err)
// We'll continue anyway, as some functionality might still work
lspClient.SetServerState(lsp.StateError)
} else {
logging.Info("LSP server is ready", "name", name)
slog.Info("LSP server is ready", "name", name)
lspClient.SetServerState(lsp.StateReady)
}
logging.Info("LSP client initialized", "name", name)
slog.Info("LSP client initialized", "name", name)
// Create a child context that can be canceled when the app is shutting down
watchCtx, cancelFunc := context.WithCancel(ctx)
// Create a context with the server name for better identification
watchCtx = context.WithValue(watchCtx, "serverName", name)
// Create the workspace watcher
workspaceWatcher := watcher.NewWorkspaceWatcher(lspClient)
@@ -92,7 +94,7 @@ func (app *App) runWorkspaceWatcher(ctx context.Context, name string, workspaceW
})
workspaceWatcher.WatchWorkspace(ctx, config.WorkingDirectory())
logging.Info("Workspace watcher stopped", "client", name)
slog.Info("Workspace watcher stopped", "client", name)
}
// restartLSPClient attempts to restart a crashed or failed LSP client
@@ -101,7 +103,7 @@ func (app *App) restartLSPClient(ctx context.Context, name string) {
cfg := config.Get()
clientConfig, exists := cfg.LSP[name]
if !exists {
logging.Error("Cannot restart client, configuration not found", "client", name)
slog.Error("Cannot restart client, configuration not found", "client", name)
return
}
@@ -118,9 +120,15 @@ func (app *App) restartLSPClient(ctx context.Context, name string) {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_ = oldClient.Shutdown(shutdownCtx)
cancel()
// Ensure we close the client to free resources
_ = oldClient.Close()
}
// Wait a moment before restarting to avoid rapid restart cycles
time.Sleep(1 * time.Second)
// Create a new client using the shared function
app.createAndStartLSPClient(ctx, name, clientConfig.Command, clientConfig.Args...)
logging.Info("Successfully restarted LSP client", "client", name)
slog.Info("Successfully restarted LSP client", "client", name)
}

View File

@@ -2,14 +2,16 @@
package config
import (
"encoding/json"
"fmt"
"log/slog"
"os"
"os/user"
"path/filepath"
"strings"
"github.com/opencode-ai/opencode/internal/llm/models"
"github.com/opencode-ai/opencode/internal/logging"
"github.com/spf13/viper"
"github.com/sst/opencode/internal/llm/models"
)
// MCPType defines the type of MCP (Model Control Protocol) server.
@@ -34,9 +36,9 @@ type MCPServer struct {
type AgentName string
const (
AgentCoder AgentName = "coder"
AgentTask AgentName = "task"
AgentTitle AgentName = "title"
AgentPrimary AgentName = "primary"
AgentTask AgentName = "task"
AgentTitle AgentName = "title"
)
// Agent defines configuration for different LLM models and their token limits.
@@ -65,6 +67,12 @@ type LSPConfig struct {
Options any `json:"options"`
}
// TUIConfig defines the configuration for the Terminal User Interface.
type TUIConfig struct {
Theme string `json:"theme,omitempty"`
CustomTheme map[string]any `json:"customTheme,omitempty"`
}
// Config is the main configuration structure for the application.
type Config struct {
Data Data `json:"data"`
@@ -76,6 +84,7 @@ type Config struct {
Debug bool `json:"debug,omitempty"`
DebugLSP bool `json:"debugLSP,omitempty"`
ContextPaths []string `json:"contextPaths,omitempty"`
TUI TUIConfig `json:"tui"`
}
// Application constants
@@ -93,6 +102,8 @@ var defaultContextPaths = []string{
".cursor/rules/",
"CLAUDE.md",
"CLAUDE.local.md",
"CONTEXT.md",
"CONTEXT.local.md",
"opencode.md",
"opencode.local.md",
"OpenCode.md",
@@ -107,7 +118,7 @@ var cfg *Config
// Load initializes the configuration from environment variables and config files.
// If debug is true, debug mode is enabled and log level is set to debug.
// It returns an error if configuration loading fails.
func Load(workingDir string, debug bool) (*Config, error) {
func Load(workingDir string, debug bool, lvl *slog.LevelVar) (*Config, error) {
if cfg != nil {
return cfg, nil
}
@@ -121,7 +132,6 @@ func Load(workingDir string, debug bool) (*Config, error) {
configureViper()
setDefaults(debug)
setProviderDefaults()
// Read global config
if err := readConfig(viper.ReadInConfig()); err != nil {
@@ -131,45 +141,21 @@ func Load(workingDir string, debug bool) (*Config, error) {
// Load and merge local config
mergeLocalConfig(workingDir)
setProviderDefaults()
// Apply configuration to the struct
if err := viper.Unmarshal(cfg); err != nil {
return cfg, fmt.Errorf("failed to unmarshal config: %w", err)
}
applyDefaultValues()
defaultLevel := slog.LevelInfo
if cfg.Debug {
defaultLevel = slog.LevelDebug
}
if os.Getenv("OPENCODE_DEV_DEBUG") == "true" {
loggingFile := fmt.Sprintf("%s/%s", cfg.Data.Directory, "debug.log")
// if file does not exist create it
if _, err := os.Stat(loggingFile); os.IsNotExist(err) {
if err := os.MkdirAll(cfg.Data.Directory, 0o755); err != nil {
return cfg, fmt.Errorf("failed to create directory: %w", err)
}
if _, err := os.Create(loggingFile); err != nil {
return cfg, fmt.Errorf("failed to create log file: %w", err)
}
}
sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666)
if err != nil {
return cfg, fmt.Errorf("failed to open log file: %w", err)
}
// Configure logger
logger := slog.New(slog.NewTextHandler(sloggingFileWriter, &slog.HandlerOptions{
Level: defaultLevel,
}))
slog.SetDefault(logger)
} else {
// Configure logger
logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{
Level: defaultLevel,
}))
slog.SetDefault(logger)
}
lvl.Set(defaultLevel)
slog.SetLogLoggerLevel(defaultLevel)
// Validate configuration
if err := Validate(); err != nil {
@@ -203,6 +189,7 @@ func configureViper() {
func setDefaults(debug bool) {
viper.SetDefault("data.directory", defaultDataDirectory)
viper.SetDefault("contextPaths", defaultContextPaths)
viper.SetDefault("tui.theme", "opencode")
if debug {
viper.SetDefault("debug", true)
@@ -213,7 +200,8 @@ func setDefaults(debug bool) {
}
}
// setProviderDefaults configures LLM provider defaults based on environment variables.
// setProviderDefaults configures LLM provider defaults based on provider provided by
// environment variables and configuration file.
func setProviderDefaults() {
// Set all API keys we can find in the environment
if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" {
@@ -228,66 +216,85 @@ func setProviderDefaults() {
if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" {
viper.SetDefault("providers.groq.apiKey", apiKey)
}
if apiKey := os.Getenv("OPENROUTER_API_KEY"); apiKey != "" {
viper.SetDefault("providers.openrouter.apiKey", apiKey)
}
if apiKey := os.Getenv("XAI_API_KEY"); apiKey != "" {
viper.SetDefault("providers.xai.apiKey", apiKey)
}
if apiKey := os.Getenv("AZURE_OPENAI_ENDPOINT"); apiKey != "" {
// api-key may be empty when using Entra ID credentials that's okay
viper.SetDefault("providers.azure.apiKey", os.Getenv("AZURE_OPENAI_API_KEY"))
}
// Use this order to set the default models
// 1. Anthropic
// 2. OpenAI
// 3. Google Gemini
// 4. Groq
// 5. AWS Bedrock
// 5. OpenRouter
// 6. AWS Bedrock
// 7. Azure
// Anthropic configuration
if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" {
viper.SetDefault("agents.coder.model", models.Claude37Sonnet)
if key := viper.GetString("providers.anthropic.apiKey"); strings.TrimSpace(key) != "" {
viper.SetDefault("agents.primary.model", models.Claude37Sonnet)
viper.SetDefault("agents.task.model", models.Claude37Sonnet)
viper.SetDefault("agents.title.model", models.Claude37Sonnet)
return
}
// OpenAI configuration
if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" {
viper.SetDefault("agents.coder.model", models.GPT41)
if key := viper.GetString("providers.openai.apiKey"); strings.TrimSpace(key) != "" {
viper.SetDefault("agents.primary.model", models.GPT41)
viper.SetDefault("agents.task.model", models.GPT41Mini)
viper.SetDefault("agents.title.model", models.GPT41Mini)
return
}
// Google Gemini configuration
if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" {
viper.SetDefault("agents.coder.model", models.Gemini25)
if key := viper.GetString("providers.gemini.apiKey"); strings.TrimSpace(key) != "" {
viper.SetDefault("agents.primary.model", models.Gemini25)
viper.SetDefault("agents.task.model", models.Gemini25Flash)
viper.SetDefault("agents.title.model", models.Gemini25Flash)
return
}
// Groq configuration
if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" {
viper.SetDefault("agents.coder.model", models.QWENQwq)
if key := viper.GetString("providers.groq.apiKey"); strings.TrimSpace(key) != "" {
viper.SetDefault("agents.primary.model", models.QWENQwq)
viper.SetDefault("agents.task.model", models.QWENQwq)
viper.SetDefault("agents.title.model", models.QWENQwq)
return
}
// OpenRouter configuration
if apiKey := os.Getenv("OPENROUTER_API_KEY"); apiKey != "" {
viper.SetDefault("providers.openrouter.apiKey", apiKey)
viper.SetDefault("agents.coder.model", models.OpenRouterClaude37Sonnet)
if key := viper.GetString("providers.openrouter.apiKey"); strings.TrimSpace(key) != "" {
viper.SetDefault("agents.primary.model", models.OpenRouterClaude37Sonnet)
viper.SetDefault("agents.task.model", models.OpenRouterClaude37Sonnet)
viper.SetDefault("agents.title.model", models.OpenRouterClaude35Haiku)
return
}
// XAI configuration
if key := viper.GetString("providers.xai.apiKey"); strings.TrimSpace(key) != "" {
viper.SetDefault("agents.primary.model", models.XAIGrok3Beta)
viper.SetDefault("agents.task.model", models.XAIGrok3Beta)
viper.SetDefault("agents.title.model", models.XAiGrok3MiniFastBeta)
return
}
// AWS Bedrock configuration
if hasAWSCredentials() {
viper.SetDefault("agents.coder.model", models.BedrockClaude37Sonnet)
viper.SetDefault("agents.primary.model", models.BedrockClaude37Sonnet)
viper.SetDefault("agents.task.model", models.BedrockClaude37Sonnet)
viper.SetDefault("agents.title.model", models.BedrockClaude37Sonnet)
return
}
// Azure OpenAI configuration
if os.Getenv("AZURE_OPENAI_ENDPOINT") != "" {
// api-key may be empty when using Entra ID credentials that's okay
viper.SetDefault("providers.azure.apiKey", os.Getenv("AZURE_OPENAI_API_KEY"))
viper.SetDefault("agents.coder.model", models.AzureGPT41)
viper.SetDefault("agents.primary.model", models.AzureGPT41)
viper.SetDefault("agents.task.model", models.AzureGPT41Mini)
viper.SetDefault("agents.title.model", models.AzureGPT41Mini)
return
@@ -363,13 +370,13 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error {
// Check if model exists
model, modelExists := models.SupportedModels[agent.Model]
if !modelExists {
logging.Warn("unsupported model configured, reverting to default",
slog.Warn("unsupported model configured, reverting to default",
"agent", name,
"configured_model", agent.Model)
// Set default model based on available providers
if setDefaultModelForAgent(name) {
logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
slog.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
} else {
return fmt.Errorf("no valid provider available for agent %s", name)
}
@@ -384,14 +391,14 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error {
// Provider not configured, check if we have environment variables
apiKey := getProviderAPIKey(provider)
if apiKey == "" {
logging.Warn("provider not configured for model, reverting to default",
slog.Warn("provider not configured for model, reverting to default",
"agent", name,
"model", agent.Model,
"provider", provider)
// Set default model based on available providers
if setDefaultModelForAgent(name) {
logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
slog.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
} else {
return fmt.Errorf("no valid provider available for agent %s", name)
}
@@ -400,18 +407,18 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error {
cfg.Providers[provider] = Provider{
APIKey: apiKey,
}
logging.Info("added provider from environment", "provider", provider)
slog.Info("added provider from environment", "provider", provider)
}
} else if providerCfg.Disabled || providerCfg.APIKey == "" {
// Provider is disabled or has no API key
logging.Warn("provider is disabled or has no API key, reverting to default",
slog.Warn("provider is disabled or has no API key, reverting to default",
"agent", name,
"model", agent.Model,
"provider", provider)
// Set default model based on available providers
if setDefaultModelForAgent(name) {
logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
slog.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
} else {
return fmt.Errorf("no valid provider available for agent %s", name)
}
@@ -419,7 +426,7 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error {
// Validate max tokens
if agent.MaxTokens <= 0 {
logging.Warn("invalid max tokens, setting to default",
slog.Warn("invalid max tokens, setting to default",
"agent", name,
"model", agent.Model,
"max_tokens", agent.MaxTokens)
@@ -434,7 +441,7 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error {
cfg.Agents[name] = updatedAgent
} else if model.ContextWindow > 0 && agent.MaxTokens > model.ContextWindow/2 {
// Ensure max tokens doesn't exceed half the context window (reasonable limit)
logging.Warn("max tokens exceeds half the context window, adjusting",
slog.Warn("max tokens exceeds half the context window, adjusting",
"agent", name,
"model", agent.Model,
"max_tokens", agent.MaxTokens,
@@ -450,7 +457,7 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error {
if model.CanReason && provider == models.ProviderOpenAI {
if agent.ReasoningEffort == "" {
// Set default reasoning effort for models that support it
logging.Info("setting default reasoning effort for model that supports reasoning",
slog.Info("setting default reasoning effort for model that supports reasoning",
"agent", name,
"model", agent.Model)
@@ -462,7 +469,7 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error {
// Check if reasoning effort is valid (low, medium, high)
effort := strings.ToLower(agent.ReasoningEffort)
if effort != "low" && effort != "medium" && effort != "high" {
logging.Warn("invalid reasoning effort, setting to medium",
slog.Warn("invalid reasoning effort, setting to medium",
"agent", name,
"model", agent.Model,
"reasoning_effort", agent.ReasoningEffort)
@@ -475,7 +482,7 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error {
}
} else if !model.CanReason && agent.ReasoningEffort != "" {
// Model doesn't support reasoning but reasoning effort is set
logging.Warn("model doesn't support reasoning but reasoning effort is set, ignoring",
slog.Warn("model doesn't support reasoning but reasoning effort is set, ignoring",
"agent", name,
"model", agent.Model,
"reasoning_effort", agent.ReasoningEffort)
@@ -505,7 +512,7 @@ func Validate() error {
// Validate providers
for provider, providerCfg := range cfg.Providers {
if providerCfg.APIKey == "" && !providerCfg.Disabled {
logging.Warn("provider has no API key, marking as disabled", "provider", provider)
slog.Warn("provider has no API key, marking as disabled", "provider", provider)
providerCfg.Disabled = true
cfg.Providers[provider] = providerCfg
}
@@ -514,7 +521,7 @@ func Validate() error {
// Validate LSP configurations
for language, lspConfig := range cfg.LSP {
if lspConfig.Command == "" && !lspConfig.Disabled {
logging.Warn("LSP configuration has no command, marking as disabled", "language", language)
slog.Warn("LSP configuration has no command, marking as disabled", "language", language)
lspConfig.Disabled = true
cfg.LSP[language] = lspConfig
}
@@ -679,6 +686,24 @@ func WorkingDirectory() string {
return cfg.WorkingDir
}
// GetHostname reports the system hostname. When the lookup fails it
// returns the fallback "User" together with the underlying error, so
// callers can ignore the error and still get a usable display name.
func GetHostname() (string, error) {
	name, err := os.Hostname()
	if err == nil {
		return name, nil
	}
	return "User", err
}
// GetUsername reports the username of the account running the process.
// When the lookup fails it returns the fallback "User" together with the
// underlying error, mirroring GetHostname's contract.
func GetUsername() (string, error) {
	u, err := user.Current()
	if err == nil {
		return u.Username, nil
	}
	return "User", err
}
func UpdateAgentModel(agentName AgentName, modelID models.ModelID) error {
if cfg == nil {
panic("config not loaded")
@@ -711,3 +736,62 @@ func UpdateAgentModel(agentName AgentName, modelID models.ModelID) error {
return nil
}
// UpdateTheme updates the theme in the configuration and writes it to the config file.
// It patches only the "tui.theme" key, preserving every other key already
// present in the JSON file, and creates the file (starting from an empty
// object) when viper has no config file loaded.
func UpdateTheme(themeName string) error {
	if cfg == nil {
		return fmt.Errorf("config not loaded")
	}

	// Update the in-memory config first; disk is synced below.
	cfg.TUI.Theme = themeName

	// Resolve the config file; fall back to ~/.<app>.json with empty JSON
	// content when none is in use yet.
	configFile := viper.ConfigFileUsed()
	configData := []byte(`{}`)
	if configFile == "" {
		homeDir, err := os.UserHomeDir()
		if err != nil {
			return fmt.Errorf("failed to get home directory: %w", err)
		}
		configFile = filepath.Join(homeDir, fmt.Sprintf(".%s.json", appName))
		slog.Info("config file not found, creating new one", "path", configFile)
	} else {
		data, err := os.ReadFile(configFile)
		if err != nil {
			return fmt.Errorf("failed to read config file: %w", err)
		}
		configData = data
	}

	// Decode into a generic map so unknown keys survive the round trip.
	var configMap map[string]any
	if err := json.Unmarshal(configData, &configMap); err != nil {
		return fmt.Errorf("failed to parse config file: %w", err)
	}

	// Merge the theme into the "tui" object, creating it when absent or
	// when the existing value is not an object.
	if tuiConfig, ok := configMap["tui"].(map[string]any); ok {
		tuiConfig["theme"] = themeName
		configMap["tui"] = tuiConfig
	} else {
		configMap["tui"] = map[string]any{"theme": themeName}
	}

	updatedData, err := json.MarshalIndent(configMap, "", "  ")
	if err != nil {
		return fmt.Errorf("failed to marshal config: %w", err)
	}
	if err := os.WriteFile(configFile, updatedData, 0o644); err != nil {
		return fmt.Errorf("failed to write config file: %w", err)
	}
	return nil
}

View File

@@ -58,4 +58,3 @@ func MarkProjectInitialized() error {
return nil
}

View File

@@ -9,8 +9,8 @@ import (
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/logging"
"github.com/sst/opencode/internal/config"
"log/slog"
"github.com/pressly/goose/v3"
)
@@ -47,21 +47,21 @@ func Connect() (*sql.DB, error) {
for _, pragma := range pragmas {
if _, err = db.Exec(pragma); err != nil {
logging.Error("Failed to set pragma", pragma, err)
slog.Error("Failed to set pragma", pragma, err)
} else {
logging.Debug("Set pragma", "pragma", pragma)
slog.Debug("Set pragma", "pragma", pragma)
}
}
goose.SetBaseFS(FS)
if err := goose.SetDialect("sqlite3"); err != nil {
logging.Error("Failed to set dialect", "error", err)
slog.Error("Failed to set dialect", "error", err)
return nil, fmt.Errorf("failed to set dialect: %w", err)
}
if err := goose.Up(db, "migrations"); err != nil {
logging.Error("Failed to apply migrations", "error", err)
slog.Error("Failed to apply migrations", "error", err)
return nil, fmt.Errorf("failed to apply migrations: %w", err)
}
return db, nil

View File

@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.27.0
// sqlc v1.29.0
package db
@@ -27,6 +27,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
if q.createFileStmt, err = db.PrepareContext(ctx, createFile); err != nil {
return nil, fmt.Errorf("error preparing query CreateFile: %w", err)
}
if q.createLogStmt, err = db.PrepareContext(ctx, createLog); err != nil {
return nil, fmt.Errorf("error preparing query CreateLog: %w", err)
}
if q.createMessageStmt, err = db.PrepareContext(ctx, createMessage); err != nil {
return nil, fmt.Errorf("error preparing query CreateMessage: %w", err)
}
@@ -60,6 +63,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
if q.getSessionByIDStmt, err = db.PrepareContext(ctx, getSessionByID); err != nil {
return nil, fmt.Errorf("error preparing query GetSessionByID: %w", err)
}
if q.listAllLogsStmt, err = db.PrepareContext(ctx, listAllLogs); err != nil {
return nil, fmt.Errorf("error preparing query ListAllLogs: %w", err)
}
if q.listFilesByPathStmt, err = db.PrepareContext(ctx, listFilesByPath); err != nil {
return nil, fmt.Errorf("error preparing query ListFilesByPath: %w", err)
}
@@ -69,9 +75,15 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
if q.listLatestSessionFilesStmt, err = db.PrepareContext(ctx, listLatestSessionFiles); err != nil {
return nil, fmt.Errorf("error preparing query ListLatestSessionFiles: %w", err)
}
if q.listLogsBySessionStmt, err = db.PrepareContext(ctx, listLogsBySession); err != nil {
return nil, fmt.Errorf("error preparing query ListLogsBySession: %w", err)
}
if q.listMessagesBySessionStmt, err = db.PrepareContext(ctx, listMessagesBySession); err != nil {
return nil, fmt.Errorf("error preparing query ListMessagesBySession: %w", err)
}
if q.listMessagesBySessionAfterStmt, err = db.PrepareContext(ctx, listMessagesBySessionAfter); err != nil {
return nil, fmt.Errorf("error preparing query ListMessagesBySessionAfter: %w", err)
}
if q.listNewFilesStmt, err = db.PrepareContext(ctx, listNewFiles); err != nil {
return nil, fmt.Errorf("error preparing query ListNewFiles: %w", err)
}
@@ -97,6 +109,11 @@ func (q *Queries) Close() error {
err = fmt.Errorf("error closing createFileStmt: %w", cerr)
}
}
if q.createLogStmt != nil {
if cerr := q.createLogStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing createLogStmt: %w", cerr)
}
}
if q.createMessageStmt != nil {
if cerr := q.createMessageStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing createMessageStmt: %w", cerr)
@@ -152,6 +169,11 @@ func (q *Queries) Close() error {
err = fmt.Errorf("error closing getSessionByIDStmt: %w", cerr)
}
}
if q.listAllLogsStmt != nil {
if cerr := q.listAllLogsStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing listAllLogsStmt: %w", cerr)
}
}
if q.listFilesByPathStmt != nil {
if cerr := q.listFilesByPathStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing listFilesByPathStmt: %w", cerr)
@@ -167,11 +189,21 @@ func (q *Queries) Close() error {
err = fmt.Errorf("error closing listLatestSessionFilesStmt: %w", cerr)
}
}
if q.listLogsBySessionStmt != nil {
if cerr := q.listLogsBySessionStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing listLogsBySessionStmt: %w", cerr)
}
}
if q.listMessagesBySessionStmt != nil {
if cerr := q.listMessagesBySessionStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing listMessagesBySessionStmt: %w", cerr)
}
}
if q.listMessagesBySessionAfterStmt != nil {
if cerr := q.listMessagesBySessionAfterStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing listMessagesBySessionAfterStmt: %w", cerr)
}
}
if q.listNewFilesStmt != nil {
if cerr := q.listNewFilesStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing listNewFilesStmt: %w", cerr)
@@ -234,55 +266,63 @@ func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, ar
}
type Queries struct {
db DBTX
tx *sql.Tx
createFileStmt *sql.Stmt
createMessageStmt *sql.Stmt
createSessionStmt *sql.Stmt
deleteFileStmt *sql.Stmt
deleteMessageStmt *sql.Stmt
deleteSessionStmt *sql.Stmt
deleteSessionFilesStmt *sql.Stmt
deleteSessionMessagesStmt *sql.Stmt
getFileStmt *sql.Stmt
getFileByPathAndSessionStmt *sql.Stmt
getMessageStmt *sql.Stmt
getSessionByIDStmt *sql.Stmt
listFilesByPathStmt *sql.Stmt
listFilesBySessionStmt *sql.Stmt
listLatestSessionFilesStmt *sql.Stmt
listMessagesBySessionStmt *sql.Stmt
listNewFilesStmt *sql.Stmt
listSessionsStmt *sql.Stmt
updateFileStmt *sql.Stmt
updateMessageStmt *sql.Stmt
updateSessionStmt *sql.Stmt
db DBTX
tx *sql.Tx
createFileStmt *sql.Stmt
createLogStmt *sql.Stmt
createMessageStmt *sql.Stmt
createSessionStmt *sql.Stmt
deleteFileStmt *sql.Stmt
deleteMessageStmt *sql.Stmt
deleteSessionStmt *sql.Stmt
deleteSessionFilesStmt *sql.Stmt
deleteSessionMessagesStmt *sql.Stmt
getFileStmt *sql.Stmt
getFileByPathAndSessionStmt *sql.Stmt
getMessageStmt *sql.Stmt
getSessionByIDStmt *sql.Stmt
listAllLogsStmt *sql.Stmt
listFilesByPathStmt *sql.Stmt
listFilesBySessionStmt *sql.Stmt
listLatestSessionFilesStmt *sql.Stmt
listLogsBySessionStmt *sql.Stmt
listMessagesBySessionStmt *sql.Stmt
listMessagesBySessionAfterStmt *sql.Stmt
listNewFilesStmt *sql.Stmt
listSessionsStmt *sql.Stmt
updateFileStmt *sql.Stmt
updateMessageStmt *sql.Stmt
updateSessionStmt *sql.Stmt
}
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
return &Queries{
db: tx,
tx: tx,
createFileStmt: q.createFileStmt,
createMessageStmt: q.createMessageStmt,
createSessionStmt: q.createSessionStmt,
deleteFileStmt: q.deleteFileStmt,
deleteMessageStmt: q.deleteMessageStmt,
deleteSessionStmt: q.deleteSessionStmt,
deleteSessionFilesStmt: q.deleteSessionFilesStmt,
deleteSessionMessagesStmt: q.deleteSessionMessagesStmt,
getFileStmt: q.getFileStmt,
getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt,
getMessageStmt: q.getMessageStmt,
getSessionByIDStmt: q.getSessionByIDStmt,
listFilesByPathStmt: q.listFilesByPathStmt,
listFilesBySessionStmt: q.listFilesBySessionStmt,
listLatestSessionFilesStmt: q.listLatestSessionFilesStmt,
listMessagesBySessionStmt: q.listMessagesBySessionStmt,
listNewFilesStmt: q.listNewFilesStmt,
listSessionsStmt: q.listSessionsStmt,
updateFileStmt: q.updateFileStmt,
updateMessageStmt: q.updateMessageStmt,
updateSessionStmt: q.updateSessionStmt,
db: tx,
tx: tx,
createFileStmt: q.createFileStmt,
createLogStmt: q.createLogStmt,
createMessageStmt: q.createMessageStmt,
createSessionStmt: q.createSessionStmt,
deleteFileStmt: q.deleteFileStmt,
deleteMessageStmt: q.deleteMessageStmt,
deleteSessionStmt: q.deleteSessionStmt,
deleteSessionFilesStmt: q.deleteSessionFilesStmt,
deleteSessionMessagesStmt: q.deleteSessionMessagesStmt,
getFileStmt: q.getFileStmt,
getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt,
getMessageStmt: q.getMessageStmt,
getSessionByIDStmt: q.getSessionByIDStmt,
listAllLogsStmt: q.listAllLogsStmt,
listFilesByPathStmt: q.listFilesByPathStmt,
listFilesBySessionStmt: q.listFilesBySessionStmt,
listLatestSessionFilesStmt: q.listLatestSessionFilesStmt,
listLogsBySessionStmt: q.listLogsBySessionStmt,
listMessagesBySessionStmt: q.listMessagesBySessionStmt,
listMessagesBySessionAfterStmt: q.listMessagesBySessionAfterStmt,
listNewFilesStmt: q.listNewFilesStmt,
listSessionsStmt: q.listSessionsStmt,
updateFileStmt: q.updateFileStmt,
updateMessageStmt: q.updateMessageStmt,
updateSessionStmt: q.updateSessionStmt,
}
}

View File

@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.27.0
// sqlc v1.29.0
// source: files.sql
package db
@@ -15,13 +15,11 @@ INSERT INTO files (
session_id,
path,
content,
version,
created_at,
updated_at
version
) VALUES (
?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
?, ?, ?, ?, ?
)
RETURNING id, session_id, path, content, version, created_at, updated_at
RETURNING id, session_id, path, content, version, is_new, created_at, updated_at
`
type CreateFileParams struct {
@@ -47,6 +45,7 @@ func (q *Queries) CreateFile(ctx context.Context, arg CreateFileParams) (File, e
&i.Path,
&i.Content,
&i.Version,
&i.IsNew,
&i.CreatedAt,
&i.UpdatedAt,
)
@@ -74,7 +73,7 @@ func (q *Queries) DeleteSessionFiles(ctx context.Context, sessionID string) erro
}
const getFile = `-- name: GetFile :one
SELECT id, session_id, path, content, version, created_at, updated_at
SELECT id, session_id, path, content, version, is_new, created_at, updated_at
FROM files
WHERE id = ? LIMIT 1
`
@@ -88,6 +87,7 @@ func (q *Queries) GetFile(ctx context.Context, id string) (File, error) {
&i.Path,
&i.Content,
&i.Version,
&i.IsNew,
&i.CreatedAt,
&i.UpdatedAt,
)
@@ -95,7 +95,7 @@ func (q *Queries) GetFile(ctx context.Context, id string) (File, error) {
}
const getFileByPathAndSession = `-- name: GetFileByPathAndSession :one
SELECT id, session_id, path, content, version, created_at, updated_at
SELECT id, session_id, path, content, version, is_new, created_at, updated_at
FROM files
WHERE path = ? AND session_id = ?
ORDER BY created_at DESC
@@ -116,6 +116,7 @@ func (q *Queries) GetFileByPathAndSession(ctx context.Context, arg GetFileByPath
&i.Path,
&i.Content,
&i.Version,
&i.IsNew,
&i.CreatedAt,
&i.UpdatedAt,
)
@@ -123,7 +124,7 @@ func (q *Queries) GetFileByPathAndSession(ctx context.Context, arg GetFileByPath
}
const listFilesByPath = `-- name: ListFilesByPath :many
SELECT id, session_id, path, content, version, created_at, updated_at
SELECT id, session_id, path, content, version, is_new, created_at, updated_at
FROM files
WHERE path = ?
ORDER BY created_at DESC
@@ -144,6 +145,7 @@ func (q *Queries) ListFilesByPath(ctx context.Context, path string) ([]File, err
&i.Path,
&i.Content,
&i.Version,
&i.IsNew,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
@@ -161,7 +163,7 @@ func (q *Queries) ListFilesByPath(ctx context.Context, path string) ([]File, err
}
const listFilesBySession = `-- name: ListFilesBySession :many
SELECT id, session_id, path, content, version, created_at, updated_at
SELECT id, session_id, path, content, version, is_new, created_at, updated_at
FROM files
WHERE session_id = ?
ORDER BY created_at ASC
@@ -182,6 +184,7 @@ func (q *Queries) ListFilesBySession(ctx context.Context, sessionID string) ([]F
&i.Path,
&i.Content,
&i.Version,
&i.IsNew,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
@@ -199,7 +202,7 @@ func (q *Queries) ListFilesBySession(ctx context.Context, sessionID string) ([]F
}
const listLatestSessionFiles = `-- name: ListLatestSessionFiles :many
SELECT f.id, f.session_id, f.path, f.content, f.version, f.created_at, f.updated_at
SELECT f.id, f.session_id, f.path, f.content, f.version, f.is_new, f.created_at, f.updated_at
FROM files f
INNER JOIN (
SELECT path, MAX(created_at) as max_created_at
@@ -225,6 +228,7 @@ func (q *Queries) ListLatestSessionFiles(ctx context.Context, sessionID string)
&i.Path,
&i.Content,
&i.Version,
&i.IsNew,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
@@ -242,7 +246,7 @@ func (q *Queries) ListLatestSessionFiles(ctx context.Context, sessionID string)
}
const listNewFiles = `-- name: ListNewFiles :many
SELECT id, session_id, path, content, version, created_at, updated_at
SELECT id, session_id, path, content, version, is_new, created_at, updated_at
FROM files
WHERE is_new = 1
ORDER BY created_at DESC
@@ -263,6 +267,7 @@ func (q *Queries) ListNewFiles(ctx context.Context) ([]File, error) {
&i.Path,
&i.Content,
&i.Version,
&i.IsNew,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
@@ -284,9 +289,9 @@ UPDATE files
SET
content = ?,
version = ?,
updated_at = strftime('%s', 'now')
updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
WHERE id = ?
RETURNING id, session_id, path, content, version, created_at, updated_at
RETURNING id, session_id, path, content, version, is_new, created_at, updated_at
`
type UpdateFileParams struct {
@@ -304,6 +309,7 @@ func (q *Queries) UpdateFile(ctx context.Context, arg UpdateFileParams) (File, e
&i.Path,
&i.Content,
&i.Version,
&i.IsNew,
&i.CreatedAt,
&i.UpdatedAt,
)

137
internal/db/logs.sql.go Normal file
View File

@@ -0,0 +1,137 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.29.0
// source: logs.sql
package db
import (
"context"
"database/sql"
)
// createLog inserts one row into logs; RETURNING hands back the stored row
// so the caller also sees the DB-side created_at/updated_at defaults.
const createLog = `-- name: CreateLog :one
INSERT INTO logs (
id,
session_id,
timestamp,
level,
message,
attributes
) VALUES (
?,
?,
?,
?,
?,
?
) RETURNING id, session_id, timestamp, level, message, attributes, created_at, updated_at
`

// CreateLogParams carries the column values for CreateLog. SessionID and
// Attributes use sql.Null* because the corresponding columns are nullable.
type CreateLogParams struct {
	ID         string         `json:"id"`
	SessionID  sql.NullString `json:"session_id"`
	Timestamp  string         `json:"timestamp"`
	Level      string         `json:"level"`
	Message    string         `json:"message"`
	Attributes sql.NullString `json:"attributes"`
}

// CreateLog executes the createLog statement (using the prepared
// q.createLogStmt when available) and scans the inserted row into a Log.
func (q *Queries) CreateLog(ctx context.Context, arg CreateLogParams) (Log, error) {
	row := q.queryRow(ctx, q.createLogStmt, createLog,
		arg.ID,
		arg.SessionID,
		arg.Timestamp,
		arg.Level,
		arg.Message,
		arg.Attributes,
	)
	var i Log
	err := row.Scan(
		&i.ID,
		&i.SessionID,
		&i.Timestamp,
		&i.Level,
		&i.Message,
		&i.Attributes,
		&i.CreatedAt,
		&i.UpdatedAt,
	)
	return i, err
}
// listAllLogs selects the newest `LIMIT ?` rows across all sessions,
// ordered by the application-supplied timestamp column (not created_at).
const listAllLogs = `-- name: ListAllLogs :many
SELECT id, session_id, timestamp, level, message, attributes, created_at, updated_at FROM logs
ORDER BY timestamp DESC
LIMIT ?
`

// ListAllLogs returns up to limit log rows, newest first.
func (q *Queries) ListAllLogs(ctx context.Context, limit int64) ([]Log, error) {
	rows, err := q.query(ctx, q.listAllLogsStmt, listAllLogs, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	// Non-nil empty slice so an empty result encodes as [] rather than null.
	items := []Log{}
	for rows.Next() {
		var i Log
		if err := rows.Scan(
			&i.ID,
			&i.SessionID,
			&i.Timestamp,
			&i.Level,
			&i.Message,
			&i.Attributes,
			&i.CreatedAt,
			&i.UpdatedAt,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	// Explicit Close before Err surfaces close failures; the deferred
	// Close above then becomes a no-op.
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
// listLogsBySession selects all rows for one session, newest first by the
// application-supplied timestamp column.
const listLogsBySession = `-- name: ListLogsBySession :many
SELECT id, session_id, timestamp, level, message, attributes, created_at, updated_at FROM logs
WHERE session_id = ?
ORDER BY timestamp DESC
`

// ListLogsBySession returns every log row belonging to sessionID, newest
// first. sessionID is sql.NullString because the column is nullable.
func (q *Queries) ListLogsBySession(ctx context.Context, sessionID sql.NullString) ([]Log, error) {
	rows, err := q.query(ctx, q.listLogsBySessionStmt, listLogsBySession, sessionID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	// Non-nil empty slice so an empty result encodes as [] rather than null.
	items := []Log{}
	for rows.Next() {
		var i Log
		if err := rows.Scan(
			&i.ID,
			&i.SessionID,
			&i.Timestamp,
			&i.Level,
			&i.Message,
			&i.Attributes,
			&i.CreatedAt,
			&i.UpdatedAt,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	// Explicit Close before Err surfaces close failures; the deferred
	// Close above then becomes a no-op.
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

View File

@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.27.0
// sqlc v1.29.0
// source: messages.sql
package db
@@ -16,11 +16,9 @@ INSERT INTO messages (
session_id,
role,
parts,
model,
created_at,
updated_at
model
) VALUES (
?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
?, ?, ?, ?, ?
)
RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at
`
@@ -136,19 +134,63 @@ func (q *Queries) ListMessagesBySession(ctx context.Context, sessionID string) (
return items, nil
}
// listMessagesBySessionAfter selects a session's messages strictly newer
// than a given created_at value, oldest first — used for incremental reads.
const listMessagesBySessionAfter = `-- name: ListMessagesBySessionAfter :many
SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at
FROM messages
WHERE session_id = ? AND created_at > ?
ORDER BY created_at ASC
`

// ListMessagesBySessionAfterParams carries the two query arguments.
// CreatedAt is a string: timestamps are stored as TEXT (ISO-8601-style),
// so string comparison matches the SQL `created_at > ?` ordering.
type ListMessagesBySessionAfterParams struct {
	SessionID string `json:"session_id"`
	CreatedAt string `json:"created_at"`
}

// ListMessagesBySessionAfter returns all messages in the session created
// after arg.CreatedAt, in ascending creation order.
func (q *Queries) ListMessagesBySessionAfter(ctx context.Context, arg ListMessagesBySessionAfterParams) ([]Message, error) {
	rows, err := q.query(ctx, q.listMessagesBySessionAfterStmt, listMessagesBySessionAfter, arg.SessionID, arg.CreatedAt)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	// Non-nil empty slice so an empty result encodes as [] rather than null.
	items := []Message{}
	for rows.Next() {
		var i Message
		if err := rows.Scan(
			&i.ID,
			&i.SessionID,
			&i.Role,
			&i.Parts,
			&i.Model,
			&i.CreatedAt,
			&i.UpdatedAt,
			&i.FinishedAt,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	// Explicit Close before Err surfaces close failures; the deferred
	// Close above then becomes a no-op.
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
const updateMessage = `-- name: UpdateMessage :exec
UPDATE messages
SET
parts = ?,
finished_at = ?,
updated_at = strftime('%s', 'now')
updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
WHERE id = ?
`
type UpdateMessageParams struct {
Parts string `json:"parts"`
FinishedAt sql.NullInt64 `json:"finished_at"`
ID string `json:"id"`
Parts string `json:"parts"`
FinishedAt sql.NullString `json:"finished_at"`
ID string `json:"id"`
}
func (q *Queries) UpdateMessage(ctx context.Context, arg UpdateMessageParams) error {

View File

@@ -6,17 +6,19 @@ CREATE TABLE IF NOT EXISTS sessions (
parent_session_id TEXT,
title TEXT NOT NULL,
message_count INTEGER NOT NULL DEFAULT 0 CHECK (message_count >= 0),
prompt_tokens INTEGER NOT NULL DEFAULT 0 CHECK (prompt_tokens >= 0),
completion_tokens INTEGER NOT NULL DEFAULT 0 CHECK (completion_tokens>= 0),
prompt_tokens INTEGER NOT NULL DEFAULT 0 CHECK (prompt_tokens >= 0),
completion_tokens INTEGER NOT NULL DEFAULT 0 CHECK (completion_tokens >= 0),
cost REAL NOT NULL DEFAULT 0.0 CHECK (cost >= 0.0),
updated_at INTEGER NOT NULL, -- Unix timestamp in milliseconds
created_at INTEGER NOT NULL -- Unix timestamp in milliseconds
summary TEXT,
summarized_at TEXT,
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')),
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now'))
);
CREATE TRIGGER IF NOT EXISTS update_sessions_updated_at
AFTER UPDATE ON sessions
BEGIN
UPDATE sessions SET updated_at = strftime('%s', 'now')
UPDATE sessions SET updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
WHERE id = new.id;
END;
@@ -27,8 +29,9 @@ CREATE TABLE IF NOT EXISTS files (
path TEXT NOT NULL,
content TEXT NOT NULL,
version TEXT NOT NULL,
created_at INTEGER NOT NULL, -- Unix timestamp in milliseconds
updated_at INTEGER NOT NULL, -- Unix timestamp in milliseconds
is_new INTEGER DEFAULT 0,
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')),
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')),
FOREIGN KEY (session_id) REFERENCES sessions (id) ON DELETE CASCADE,
UNIQUE(path, session_id, version)
);
@@ -39,7 +42,7 @@ CREATE INDEX IF NOT EXISTS idx_files_path ON files (path);
CREATE TRIGGER IF NOT EXISTS update_files_updated_at
AFTER UPDATE ON files
BEGIN
UPDATE files SET updated_at = strftime('%s', 'now')
UPDATE files SET updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
WHERE id = new.id;
END;
@@ -50,9 +53,9 @@ CREATE TABLE IF NOT EXISTS messages (
role TEXT NOT NULL,
parts TEXT NOT NULL default '[]',
model TEXT,
created_at INTEGER NOT NULL, -- Unix timestamp in milliseconds
updated_at INTEGER NOT NULL, -- Unix timestamp in milliseconds
finished_at INTEGER, -- Unix timestamp in milliseconds
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')),
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')),
finished_at TEXT,
FOREIGN KEY (session_id) REFERENCES sessions (id) ON DELETE CASCADE
);
@@ -61,7 +64,7 @@ CREATE INDEX IF NOT EXISTS idx_messages_session_id ON messages (session_id);
CREATE TRIGGER IF NOT EXISTS update_messages_updated_at
AFTER UPDATE ON messages
BEGIN
UPDATE messages SET updated_at = strftime('%s', 'now')
UPDATE messages SET updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
WHERE id = new.id;
END;
@@ -81,6 +84,28 @@ UPDATE sessions SET
WHERE id = old.session_id;
END;
-- Logs: one row per application log record, optionally scoped to a
-- session (and cascade-deleted with it). Timestamps are TEXT in the
-- same strftime format used by the other tables in this migration.
CREATE TABLE IF NOT EXISTS logs (
    id TEXT PRIMARY KEY,
    session_id TEXT REFERENCES sessions(id) ON DELETE CASCADE,
    timestamp TEXT NOT NULL,
    level TEXT NOT NULL,
    message TEXT NOT NULL,
    attributes TEXT,
    created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')),
    updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now'))
);

-- IF NOT EXISTS keeps these idempotent, matching every other CREATE
-- statement in this migration (without it a re-run fails on the index).
CREATE INDEX IF NOT EXISTS logs_session_id_idx ON logs(session_id);
CREATE INDEX IF NOT EXISTS logs_timestamp_idx ON logs(timestamp);

-- Keep updated_at current on every row update, same pattern as the
-- sessions/files/messages triggers above.
CREATE TRIGGER IF NOT EXISTS update_logs_updated_at
AFTER UPDATE ON logs
BEGIN
    UPDATE logs SET updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
    WHERE id = new.id;
END;
-- +goose StatementEnd
-- +goose Down
@@ -88,11 +113,13 @@ END;
DROP TRIGGER IF EXISTS update_sessions_updated_at;
DROP TRIGGER IF EXISTS update_messages_updated_at;
DROP TRIGGER IF EXISTS update_files_updated_at;
DROP TRIGGER IF EXISTS update_logs_updated_at;
DROP TRIGGER IF EXISTS update_session_message_count_on_delete;
DROP TRIGGER IF EXISTS update_session_message_count_on_insert;
DROP TABLE IF EXISTS sessions;
DROP TABLE IF EXISTS logs;
DROP TABLE IF EXISTS messages;
DROP TABLE IF EXISTS files;
DROP TABLE IF EXISTS sessions;
-- +goose StatementEnd

View File

@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.27.0
// sqlc v1.29.0
package db
@@ -9,13 +9,25 @@ import (
)
type File struct {
ID string `json:"id"`
SessionID string `json:"session_id"`
Path string `json:"path"`
Content string `json:"content"`
Version string `json:"version"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
ID string `json:"id"`
SessionID string `json:"session_id"`
Path string `json:"path"`
Content string `json:"content"`
Version string `json:"version"`
IsNew sql.NullInt64 `json:"is_new"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
// Log mirrors one row of the logs table. Timestamp and the audit columns
// are stored as TEXT in SQLite, hence the string fields; SessionID and
// Attributes map nullable columns via sql.NullString.
type Log struct {
	ID         string         `json:"id"`
	SessionID  sql.NullString `json:"session_id"` // null for logs not tied to a session
	Timestamp  string         `json:"timestamp"`
	Level      string         `json:"level"`
	Message    string         `json:"message"`
	Attributes sql.NullString `json:"attributes"` // optional structured attributes; presumably JSON — TODO confirm against writer
	CreatedAt  string         `json:"created_at"`
	UpdatedAt  string         `json:"updated_at"`
}
type Message struct {
@@ -24,9 +36,9 @@ type Message struct {
Role string `json:"role"`
Parts string `json:"parts"`
Model sql.NullString `json:"model"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
FinishedAt sql.NullInt64 `json:"finished_at"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
FinishedAt sql.NullString `json:"finished_at"`
}
type Session struct {
@@ -37,6 +49,8 @@ type Session struct {
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
Cost float64 `json:"cost"`
UpdatedAt int64 `json:"updated_at"`
CreatedAt int64 `json:"created_at"`
Summary sql.NullString `json:"summary"`
SummarizedAt sql.NullString `json:"summarized_at"`
UpdatedAt string `json:"updated_at"`
CreatedAt string `json:"created_at"`
}

View File

@@ -1,15 +1,17 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.27.0
// sqlc v1.29.0
package db
import (
"context"
"database/sql"
)
type Querier interface {
CreateFile(ctx context.Context, arg CreateFileParams) (File, error)
CreateLog(ctx context.Context, arg CreateLogParams) (Log, error)
CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error)
CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error)
DeleteFile(ctx context.Context, id string) error
@@ -21,10 +23,13 @@ type Querier interface {
GetFileByPathAndSession(ctx context.Context, arg GetFileByPathAndSessionParams) (File, error)
GetMessage(ctx context.Context, id string) (Message, error)
GetSessionByID(ctx context.Context, id string) (Session, error)
ListAllLogs(ctx context.Context, limit int64) ([]Log, error)
ListFilesByPath(ctx context.Context, path string) ([]File, error)
ListFilesBySession(ctx context.Context, sessionID string) ([]File, error)
ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error)
ListLogsBySession(ctx context.Context, sessionID sql.NullString) ([]Log, error)
ListMessagesBySession(ctx context.Context, sessionID string) ([]Message, error)
ListMessagesBySessionAfter(ctx context.Context, arg ListMessagesBySessionAfterParams) ([]Message, error)
ListNewFiles(ctx context.Context) ([]File, error)
ListSessions(ctx context.Context) ([]Session, error)
UpdateFile(ctx context.Context, arg UpdateFileParams) (File, error)

View File

@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.27.0
// sqlc v1.29.0
// source: sessions.sql
package db
@@ -19,8 +19,8 @@ INSERT INTO sessions (
prompt_tokens,
completion_tokens,
cost,
updated_at,
created_at
summary,
summarized_at
) VALUES (
?,
?,
@@ -29,9 +29,9 @@ INSERT INTO sessions (
?,
?,
?,
strftime('%s', 'now'),
strftime('%s', 'now')
) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
?,
?
) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, summary, summarized_at, updated_at, created_at
`
type CreateSessionParams struct {
@@ -42,6 +42,8 @@ type CreateSessionParams struct {
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
Cost float64 `json:"cost"`
Summary sql.NullString `json:"summary"`
SummarizedAt sql.NullString `json:"summarized_at"`
}
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) {
@@ -53,6 +55,8 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
arg.PromptTokens,
arg.CompletionTokens,
arg.Cost,
arg.Summary,
arg.SummarizedAt,
)
var i Session
err := row.Scan(
@@ -63,6 +67,8 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
&i.PromptTokens,
&i.CompletionTokens,
&i.Cost,
&i.Summary,
&i.SummarizedAt,
&i.UpdatedAt,
&i.CreatedAt,
)
@@ -80,7 +86,7 @@ func (q *Queries) DeleteSession(ctx context.Context, id string) error {
}
const getSessionByID = `-- name: GetSessionByID :one
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, summary, summarized_at, updated_at, created_at
FROM sessions
WHERE id = ? LIMIT 1
`
@@ -96,6 +102,8 @@ func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error
&i.PromptTokens,
&i.CompletionTokens,
&i.Cost,
&i.Summary,
&i.SummarizedAt,
&i.UpdatedAt,
&i.CreatedAt,
)
@@ -103,7 +111,7 @@ func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error
}
const listSessions = `-- name: ListSessions :many
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, summary, summarized_at, updated_at, created_at
FROM sessions
WHERE parent_session_id is NULL
ORDER BY created_at DESC
@@ -126,6 +134,8 @@ func (q *Queries) ListSessions(ctx context.Context) ([]Session, error) {
&i.PromptTokens,
&i.CompletionTokens,
&i.Cost,
&i.Summary,
&i.SummarizedAt,
&i.UpdatedAt,
&i.CreatedAt,
); err != nil {
@@ -148,17 +158,21 @@ SET
title = ?,
prompt_tokens = ?,
completion_tokens = ?,
cost = ?
cost = ?,
summary = ?,
summarized_at = ?
WHERE id = ?
RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, summary, summarized_at, updated_at, created_at
`
type UpdateSessionParams struct {
Title string `json:"title"`
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
Cost float64 `json:"cost"`
ID string `json:"id"`
Title string `json:"title"`
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
Cost float64 `json:"cost"`
Summary sql.NullString `json:"summary"`
SummarizedAt sql.NullString `json:"summarized_at"`
ID string `json:"id"`
}
func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) {
@@ -167,6 +181,8 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
arg.PromptTokens,
arg.CompletionTokens,
arg.Cost,
arg.Summary,
arg.SummarizedAt,
arg.ID,
)
var i Session
@@ -178,6 +194,8 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
&i.PromptTokens,
&i.CompletionTokens,
&i.Cost,
&i.Summary,
&i.SummarizedAt,
&i.UpdatedAt,
&i.CreatedAt,
)

View File

@@ -28,11 +28,9 @@ INSERT INTO files (
session_id,
path,
content,
version,
created_at,
updated_at
version
) VALUES (
?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
?, ?, ?, ?, ?
)
RETURNING *;
@@ -41,7 +39,7 @@ UPDATE files
SET
content = ?,
version = ?,
updated_at = strftime('%s', 'now')
updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
WHERE id = ?
RETURNING *;

26
internal/db/sql/logs.sql Normal file
View File

@@ -0,0 +1,26 @@
-- name: CreateLog :one
INSERT INTO logs (
id,
session_id,
timestamp,
level,
message,
attributes
) VALUES (
?,
?,
?,
?,
?,
?
) RETURNING *;
-- name: ListLogsBySession :many
SELECT * FROM logs
WHERE session_id = ?
ORDER BY timestamp DESC;
-- name: ListAllLogs :many
SELECT * FROM logs
ORDER BY timestamp DESC
LIMIT ?;

View File

@@ -9,17 +9,21 @@ FROM messages
WHERE session_id = ?
ORDER BY created_at ASC;
-- name: ListMessagesBySessionAfter :many
SELECT *
FROM messages
WHERE session_id = ? AND created_at > ?
ORDER BY created_at ASC;
-- name: CreateMessage :one
INSERT INTO messages (
id,
session_id,
role,
parts,
model,
created_at,
updated_at
model
) VALUES (
?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
?, ?, ?, ?, ?
)
RETURNING *;
@@ -28,7 +32,7 @@ UPDATE messages
SET
parts = ?,
finished_at = ?,
updated_at = strftime('%s', 'now')
updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
WHERE id = ?;

View File

@@ -7,8 +7,8 @@ INSERT INTO sessions (
prompt_tokens,
completion_tokens,
cost,
updated_at,
created_at
summary,
summarized_at
) VALUES (
?,
?,
@@ -17,8 +17,8 @@ INSERT INTO sessions (
?,
?,
?,
strftime('%s', 'now'),
strftime('%s', 'now')
?,
?
) RETURNING *;
-- name: GetSessionByID :one
@@ -38,7 +38,9 @@ SET
title = ?,
prompt_tokens = ?,
completion_tokens = ?,
cost = ?
cost = ?,
summary = ?,
summarized_at = ?
WHERE id = ?
RETURNING *;

View File

@@ -15,8 +15,9 @@ import (
"github.com/aymanbagabas/go-udiff"
"github.com/charmbracelet/lipgloss"
"github.com/charmbracelet/x/ansi"
"github.com/opencode-ai/opencode/internal/config"
"github.com/sergi/go-diff/diffmatchpatch"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/tui/theme"
)
// -------------------------------------------------------------------------
@@ -68,143 +69,6 @@ type linePair struct {
right *DiffLine
}
// -------------------------------------------------------------------------
// Style Configuration
// -------------------------------------------------------------------------
// StyleConfig defines styling for diff rendering
type StyleConfig struct {
ShowHeader bool
ShowHunkHeader bool
FileNameFg lipgloss.Color
// Background colors
RemovedLineBg lipgloss.Color
AddedLineBg lipgloss.Color
ContextLineBg lipgloss.Color
HunkLineBg lipgloss.Color
RemovedLineNumberBg lipgloss.Color
AddedLineNamerBg lipgloss.Color
// Foreground colors
HunkLineFg lipgloss.Color
RemovedFg lipgloss.Color
AddedFg lipgloss.Color
LineNumberFg lipgloss.Color
RemovedHighlightFg lipgloss.Color
AddedHighlightFg lipgloss.Color
// Highlight settings
HighlightStyle string
RemovedHighlightBg lipgloss.Color
AddedHighlightBg lipgloss.Color
}
// StyleOption is a function that modifies a StyleConfig
type StyleOption func(*StyleConfig)
// NewStyleConfig creates a StyleConfig with default values
func NewStyleConfig(opts ...StyleOption) StyleConfig {
// Default color scheme
config := StyleConfig{
ShowHeader: true,
ShowHunkHeader: true,
FileNameFg: lipgloss.Color("#a0a0a0"),
RemovedLineBg: lipgloss.Color("#3A3030"),
AddedLineBg: lipgloss.Color("#303A30"),
ContextLineBg: lipgloss.Color("#212121"),
HunkLineBg: lipgloss.Color("#212121"),
HunkLineFg: lipgloss.Color("#a0a0a0"),
RemovedFg: lipgloss.Color("#7C4444"),
AddedFg: lipgloss.Color("#478247"),
LineNumberFg: lipgloss.Color("#888888"),
HighlightStyle: "dracula",
RemovedHighlightBg: lipgloss.Color("#612726"),
AddedHighlightBg: lipgloss.Color("#256125"),
RemovedLineNumberBg: lipgloss.Color("#332929"),
AddedLineNamerBg: lipgloss.Color("#293229"),
RemovedHighlightFg: lipgloss.Color("#FADADD"),
AddedHighlightFg: lipgloss.Color("#DAFADA"),
}
// Apply all provided options
for _, opt := range opts {
opt(&config)
}
return config
}
// Style option functions
func WithFileNameFg(color lipgloss.Color) StyleOption {
return func(s *StyleConfig) { s.FileNameFg = color }
}
func WithRemovedLineBg(color lipgloss.Color) StyleOption {
return func(s *StyleConfig) { s.RemovedLineBg = color }
}
func WithAddedLineBg(color lipgloss.Color) StyleOption {
return func(s *StyleConfig) { s.AddedLineBg = color }
}
func WithContextLineBg(color lipgloss.Color) StyleOption {
return func(s *StyleConfig) { s.ContextLineBg = color }
}
func WithRemovedFg(color lipgloss.Color) StyleOption {
return func(s *StyleConfig) { s.RemovedFg = color }
}
func WithAddedFg(color lipgloss.Color) StyleOption {
return func(s *StyleConfig) { s.AddedFg = color }
}
func WithLineNumberFg(color lipgloss.Color) StyleOption {
return func(s *StyleConfig) { s.LineNumberFg = color }
}
func WithHighlightStyle(style string) StyleOption {
return func(s *StyleConfig) { s.HighlightStyle = style }
}
func WithRemovedHighlightColors(bg, fg lipgloss.Color) StyleOption {
return func(s *StyleConfig) {
s.RemovedHighlightBg = bg
s.RemovedHighlightFg = fg
}
}
func WithAddedHighlightColors(bg, fg lipgloss.Color) StyleOption {
return func(s *StyleConfig) {
s.AddedHighlightBg = bg
s.AddedHighlightFg = fg
}
}
func WithRemovedLineNumberBg(color lipgloss.Color) StyleOption {
return func(s *StyleConfig) { s.RemovedLineNumberBg = color }
}
func WithAddedLineNumberBg(color lipgloss.Color) StyleOption {
return func(s *StyleConfig) { s.AddedLineNamerBg = color }
}
func WithHunkLineBg(color lipgloss.Color) StyleOption {
return func(s *StyleConfig) { s.HunkLineBg = color }
}
func WithHunkLineFg(color lipgloss.Color) StyleOption {
return func(s *StyleConfig) { s.HunkLineFg = color }
}
func WithShowHeader(show bool) StyleOption {
return func(s *StyleConfig) { s.ShowHeader = show }
}
func WithShowHunkHeader(show bool) StyleOption {
return func(s *StyleConfig) { s.ShowHunkHeader = show }
}
// -------------------------------------------------------------------------
// Parse Configuration
// -------------------------------------------------------------------------
@@ -233,7 +97,6 @@ func WithContextSize(size int) ParseOption {
// SideBySideConfig configures the rendering of side-by-side diffs
type SideBySideConfig struct {
TotalWidth int
Style StyleConfig
}
// SideBySideOption modifies a SideBySideConfig
@@ -243,7 +106,6 @@ type SideBySideOption func(*SideBySideConfig)
func NewSideBySideConfig(opts ...SideBySideOption) SideBySideConfig {
config := SideBySideConfig{
TotalWidth: 160, // Default width for side-by-side view
Style: NewStyleConfig(),
}
for _, opt := range opts {
@@ -262,20 +124,6 @@ func WithTotalWidth(width int) SideBySideOption {
}
}
// WithStyle sets the styling configuration
func WithStyle(style StyleConfig) SideBySideOption {
return func(s *SideBySideConfig) {
s.Style = style
}
}
// WithStyleOptions applies the specified style options
func WithStyleOptions(opts ...StyleOption) SideBySideOption {
return func(s *SideBySideConfig) {
s.Style = NewStyleConfig(opts...)
}
}
// -------------------------------------------------------------------------
// Diff Parsing
// -------------------------------------------------------------------------
@@ -382,7 +230,7 @@ func ParseUnifiedDiff(diff string) (DiffResult, error) {
}
// HighlightIntralineChanges updates lines in a hunk to show character-level differences
func HighlightIntralineChanges(h *Hunk, style StyleConfig) {
func HighlightIntralineChanges(h *Hunk) {
var updated []DiffLine
dmp := diffmatchpatch.New()
@@ -476,6 +324,8 @@ func pairLines(lines []DiffLine) []linePair {
// SyntaxHighlight applies syntax highlighting to text based on file extension
func SyntaxHighlight(w io.Writer, source, fileName, formatter string, bg lipgloss.TerminalColor) error {
t := theme.CurrentTheme()
// Determine the language lexer to use
l := lexers.Match(fileName)
if l == nil {
@@ -491,93 +341,175 @@ func SyntaxHighlight(w io.Writer, source, fileName, formatter string, bg lipglos
if f == nil {
f = formatters.Fallback
}
theme := `
<style name="vscode-dark-plus">
<!-- Base colors -->
<entry type="Background" style="bg:#1E1E1E"/>
<entry type="Text" style="#D4D4D4"/>
<entry type="Other" style="#D4D4D4"/>
<entry type="Error" style="#F44747"/>
<!-- Keywords - using the Control flow / Special keywords color -->
<entry type="Keyword" style="#C586C0"/>
<entry type="KeywordConstant" style="#4FC1FF"/>
<entry type="KeywordDeclaration" style="#C586C0"/>
<entry type="KeywordNamespace" style="#C586C0"/>
<entry type="KeywordPseudo" style="#C586C0"/>
<entry type="KeywordReserved" style="#C586C0"/>
<entry type="KeywordType" style="#4EC9B0"/>
<!-- Names -->
<entry type="Name" style="#D4D4D4"/>
<entry type="NameAttribute" style="#9CDCFE"/>
<entry type="NameBuiltin" style="#4EC9B0"/>
<entry type="NameBuiltinPseudo" style="#9CDCFE"/>
<entry type="NameClass" style="#4EC9B0"/>
<entry type="NameConstant" style="#4FC1FF"/>
<entry type="NameDecorator" style="#DCDCAA"/>
<entry type="NameEntity" style="#9CDCFE"/>
<entry type="NameException" style="#4EC9B0"/>
<entry type="NameFunction" style="#DCDCAA"/>
<entry type="NameLabel" style="#C8C8C8"/>
<entry type="NameNamespace" style="#4EC9B0"/>
<entry type="NameOther" style="#9CDCFE"/>
<entry type="NameTag" style="#569CD6"/>
<entry type="NameVariable" style="#9CDCFE"/>
<entry type="NameVariableClass" style="#9CDCFE"/>
<entry type="NameVariableGlobal" style="#9CDCFE"/>
<entry type="NameVariableInstance" style="#9CDCFE"/>
<!-- Literals -->
<entry type="Literal" style="#CE9178"/>
<entry type="LiteralDate" style="#CE9178"/>
<entry type="LiteralString" style="#CE9178"/>
<entry type="LiteralStringBacktick" style="#CE9178"/>
<entry type="LiteralStringChar" style="#CE9178"/>
<entry type="LiteralStringDoc" style="#CE9178"/>
<entry type="LiteralStringDouble" style="#CE9178"/>
<entry type="LiteralStringEscape" style="#d7ba7d"/>
<entry type="LiteralStringHeredoc" style="#CE9178"/>
<entry type="LiteralStringInterpol" style="#CE9178"/>
<entry type="LiteralStringOther" style="#CE9178"/>
<entry type="LiteralStringRegex" style="#d16969"/>
<entry type="LiteralStringSingle" style="#CE9178"/>
<entry type="LiteralStringSymbol" style="#CE9178"/>
<!-- Numbers - using the numberLiteral color -->
<entry type="LiteralNumber" style="#b5cea8"/>
<entry type="LiteralNumberBin" style="#b5cea8"/>
<entry type="LiteralNumberFloat" style="#b5cea8"/>
<entry type="LiteralNumberHex" style="#b5cea8"/>
<entry type="LiteralNumberInteger" style="#b5cea8"/>
<entry type="LiteralNumberIntegerLong" style="#b5cea8"/>
<entry type="LiteralNumberOct" style="#b5cea8"/>
<!-- Operators -->
<entry type="Operator" style="#D4D4D4"/>
<entry type="OperatorWord" style="#C586C0"/>
<entry type="Punctuation" style="#D4D4D4"/>
<!-- Comments - standard VSCode Dark+ comment color -->
<entry type="Comment" style="#6A9955"/>
<entry type="CommentHashbang" style="#6A9955"/>
<entry type="CommentMultiline" style="#6A9955"/>
<entry type="CommentSingle" style="#6A9955"/>
<entry type="CommentSpecial" style="#6A9955"/>
<entry type="CommentPreproc" style="#C586C0"/>
<!-- Generic styles -->
<entry type="Generic" style="#D4D4D4"/>
<entry type="GenericDeleted" style="#F44747"/>
<entry type="GenericEmph" style="italic #D4D4D4"/>
<entry type="GenericError" style="#F44747"/>
<entry type="GenericHeading" style="bold #D4D4D4"/>
<entry type="GenericInserted" style="#b5cea8"/>
<entry type="GenericOutput" style="#808080"/>
<entry type="GenericPrompt" style="#D4D4D4"/>
<entry type="GenericStrong" style="bold #D4D4D4"/>
<entry type="GenericSubheading" style="bold #D4D4D4"/>
<entry type="GenericTraceback" style="#F44747"/>
<entry type="GenericUnderline" style="underline"/>
<entry type="TextWhitespace" style="#D4D4D4"/>
</style>
`
r := strings.NewReader(theme)
// Dynamic theme based on current theme values
syntaxThemeXml := fmt.Sprintf(`
<style name="opencode-theme">
<!-- Base colors -->
<entry type="Background" style="bg:%s"/>
<entry type="Text" style="%s"/>
<entry type="Other" style="%s"/>
<entry type="Error" style="%s"/>
<!-- Keywords -->
<entry type="Keyword" style="%s"/>
<entry type="KeywordConstant" style="%s"/>
<entry type="KeywordDeclaration" style="%s"/>
<entry type="KeywordNamespace" style="%s"/>
<entry type="KeywordPseudo" style="%s"/>
<entry type="KeywordReserved" style="%s"/>
<entry type="KeywordType" style="%s"/>
<!-- Names -->
<entry type="Name" style="%s"/>
<entry type="NameAttribute" style="%s"/>
<entry type="NameBuiltin" style="%s"/>
<entry type="NameBuiltinPseudo" style="%s"/>
<entry type="NameClass" style="%s"/>
<entry type="NameConstant" style="%s"/>
<entry type="NameDecorator" style="%s"/>
<entry type="NameEntity" style="%s"/>
<entry type="NameException" style="%s"/>
<entry type="NameFunction" style="%s"/>
<entry type="NameLabel" style="%s"/>
<entry type="NameNamespace" style="%s"/>
<entry type="NameOther" style="%s"/>
<entry type="NameTag" style="%s"/>
<entry type="NameVariable" style="%s"/>
<entry type="NameVariableClass" style="%s"/>
<entry type="NameVariableGlobal" style="%s"/>
<entry type="NameVariableInstance" style="%s"/>
<!-- Literals -->
<entry type="Literal" style="%s"/>
<entry type="LiteralDate" style="%s"/>
<entry type="LiteralString" style="%s"/>
<entry type="LiteralStringBacktick" style="%s"/>
<entry type="LiteralStringChar" style="%s"/>
<entry type="LiteralStringDoc" style="%s"/>
<entry type="LiteralStringDouble" style="%s"/>
<entry type="LiteralStringEscape" style="%s"/>
<entry type="LiteralStringHeredoc" style="%s"/>
<entry type="LiteralStringInterpol" style="%s"/>
<entry type="LiteralStringOther" style="%s"/>
<entry type="LiteralStringRegex" style="%s"/>
<entry type="LiteralStringSingle" style="%s"/>
<entry type="LiteralStringSymbol" style="%s"/>
<!-- Numbers -->
<entry type="LiteralNumber" style="%s"/>
<entry type="LiteralNumberBin" style="%s"/>
<entry type="LiteralNumberFloat" style="%s"/>
<entry type="LiteralNumberHex" style="%s"/>
<entry type="LiteralNumberInteger" style="%s"/>
<entry type="LiteralNumberIntegerLong" style="%s"/>
<entry type="LiteralNumberOct" style="%s"/>
<!-- Operators -->
<entry type="Operator" style="%s"/>
<entry type="OperatorWord" style="%s"/>
<entry type="Punctuation" style="%s"/>
<!-- Comments -->
<entry type="Comment" style="%s"/>
<entry type="CommentHashbang" style="%s"/>
<entry type="CommentMultiline" style="%s"/>
<entry type="CommentSingle" style="%s"/>
<entry type="CommentSpecial" style="%s"/>
<entry type="CommentPreproc" style="%s"/>
<!-- Generic styles -->
<entry type="Generic" style="%s"/>
<entry type="GenericDeleted" style="%s"/>
<entry type="GenericEmph" style="italic %s"/>
<entry type="GenericError" style="%s"/>
<entry type="GenericHeading" style="bold %s"/>
<entry type="GenericInserted" style="%s"/>
<entry type="GenericOutput" style="%s"/>
<entry type="GenericPrompt" style="%s"/>
<entry type="GenericStrong" style="bold %s"/>
<entry type="GenericSubheading" style="bold %s"/>
<entry type="GenericTraceback" style="%s"/>
<entry type="GenericUnderline" style="underline"/>
<entry type="TextWhitespace" style="%s"/>
</style>
`,
getColor(t.Background()), // Background
getColor(t.Text()), // Text
getColor(t.Text()), // Other
getColor(t.Error()), // Error
getColor(t.SyntaxKeyword()), // Keyword
getColor(t.SyntaxKeyword()), // KeywordConstant
getColor(t.SyntaxKeyword()), // KeywordDeclaration
getColor(t.SyntaxKeyword()), // KeywordNamespace
getColor(t.SyntaxKeyword()), // KeywordPseudo
getColor(t.SyntaxKeyword()), // KeywordReserved
getColor(t.SyntaxType()), // KeywordType
getColor(t.Text()), // Name
getColor(t.SyntaxVariable()), // NameAttribute
getColor(t.SyntaxType()), // NameBuiltin
getColor(t.SyntaxVariable()), // NameBuiltinPseudo
getColor(t.SyntaxType()), // NameClass
getColor(t.SyntaxVariable()), // NameConstant
getColor(t.SyntaxFunction()), // NameDecorator
getColor(t.SyntaxVariable()), // NameEntity
getColor(t.SyntaxType()), // NameException
getColor(t.SyntaxFunction()), // NameFunction
getColor(t.Text()), // NameLabel
getColor(t.SyntaxType()), // NameNamespace
getColor(t.SyntaxVariable()), // NameOther
getColor(t.SyntaxKeyword()), // NameTag
getColor(t.SyntaxVariable()), // NameVariable
getColor(t.SyntaxVariable()), // NameVariableClass
getColor(t.SyntaxVariable()), // NameVariableGlobal
getColor(t.SyntaxVariable()), // NameVariableInstance
getColor(t.SyntaxString()), // Literal
getColor(t.SyntaxString()), // LiteralDate
getColor(t.SyntaxString()), // LiteralString
getColor(t.SyntaxString()), // LiteralStringBacktick
getColor(t.SyntaxString()), // LiteralStringChar
getColor(t.SyntaxString()), // LiteralStringDoc
getColor(t.SyntaxString()), // LiteralStringDouble
getColor(t.SyntaxString()), // LiteralStringEscape
getColor(t.SyntaxString()), // LiteralStringHeredoc
getColor(t.SyntaxString()), // LiteralStringInterpol
getColor(t.SyntaxString()), // LiteralStringOther
getColor(t.SyntaxString()), // LiteralStringRegex
getColor(t.SyntaxString()), // LiteralStringSingle
getColor(t.SyntaxString()), // LiteralStringSymbol
getColor(t.SyntaxNumber()), // LiteralNumber
getColor(t.SyntaxNumber()), // LiteralNumberBin
getColor(t.SyntaxNumber()), // LiteralNumberFloat
getColor(t.SyntaxNumber()), // LiteralNumberHex
getColor(t.SyntaxNumber()), // LiteralNumberInteger
getColor(t.SyntaxNumber()), // LiteralNumberIntegerLong
getColor(t.SyntaxNumber()), // LiteralNumberOct
getColor(t.SyntaxOperator()), // Operator
getColor(t.SyntaxKeyword()), // OperatorWord
getColor(t.SyntaxPunctuation()), // Punctuation
getColor(t.SyntaxComment()), // Comment
getColor(t.SyntaxComment()), // CommentHashbang
getColor(t.SyntaxComment()), // CommentMultiline
getColor(t.SyntaxComment()), // CommentSingle
getColor(t.SyntaxComment()), // CommentSpecial
getColor(t.SyntaxKeyword()), // CommentPreproc
getColor(t.Text()), // Generic
getColor(t.Error()), // GenericDeleted
getColor(t.Text()), // GenericEmph
getColor(t.Error()), // GenericError
getColor(t.Text()), // GenericHeading
getColor(t.Success()), // GenericInserted
getColor(t.TextMuted()), // GenericOutput
getColor(t.Text()), // GenericPrompt
getColor(t.Text()), // GenericStrong
getColor(t.Text()), // GenericSubheading
getColor(t.Error()), // GenericTraceback
getColor(t.Text()), // TextWhitespace
)
r := strings.NewReader(syntaxThemeXml)
style := chroma.MustNewXMLStyle(r)
// Modify the style to use the provided background
s, err := style.Builder().Transform(
func(t chroma.StyleEntry) chroma.StyleEntry {
@@ -599,6 +531,14 @@ func SyntaxHighlight(w io.Writer, source, fileName, formatter string, bg lipglos
return f.Format(w, s, it)
}
// getColor returns the appropriate hex color string based on terminal background
func getColor(adaptiveColor lipgloss.AdaptiveColor) string {
if lipgloss.HasDarkBackground() {
return adaptiveColor.Dark
}
return adaptiveColor.Light
}
// highlightLine applies syntax highlighting to a single line
func highlightLine(fileName string, line string, bg lipgloss.TerminalColor) string {
var buf bytes.Buffer
@@ -610,11 +550,11 @@ func highlightLine(fileName string, line string, bg lipgloss.TerminalColor) stri
}
// createStyles generates the lipgloss styles needed for rendering diffs
func createStyles(config StyleConfig) (removedLineStyle, addedLineStyle, contextLineStyle, lineNumberStyle lipgloss.Style) {
removedLineStyle = lipgloss.NewStyle().Background(config.RemovedLineBg)
addedLineStyle = lipgloss.NewStyle().Background(config.AddedLineBg)
contextLineStyle = lipgloss.NewStyle().Background(config.ContextLineBg)
lineNumberStyle = lipgloss.NewStyle().Foreground(config.LineNumberFg)
func createStyles(t theme.Theme) (removedLineStyle, addedLineStyle, contextLineStyle, lineNumberStyle lipgloss.Style) {
removedLineStyle = lipgloss.NewStyle().Background(t.DiffRemovedBg())
addedLineStyle = lipgloss.NewStyle().Background(t.DiffAddedBg())
contextLineStyle = lipgloss.NewStyle().Background(t.DiffContextBg())
lineNumberStyle = lipgloss.NewStyle().Foreground(t.DiffLineNumber())
return
}
@@ -624,8 +564,7 @@ func createStyles(config StyleConfig) (removedLineStyle, addedLineStyle, context
// -------------------------------------------------------------------------
// applyHighlighting applies intra-line highlighting to a piece of text
func applyHighlighting(content string, segments []Segment, segmentType LineType, highlightBg lipgloss.Color,
) string {
func applyHighlighting(content string, segments []Segment, segmentType LineType, highlightBg lipgloss.AdaptiveColor) string {
// Find all ANSI sequences in the content
ansiRegex := regexp.MustCompile(`\x1b(?:[@-Z\\-_]|\[[0-9?]*(?:;[0-9?]*)*[@-~])`)
ansiMatches := ansiRegex.FindAllStringIndex(content, -1)
@@ -663,6 +602,10 @@ func applyHighlighting(content string, segments []Segment, segmentType LineType,
inSelection := false
currentPos := 0
// Get the appropriate color based on terminal background
bgColor := lipgloss.Color(getColor(highlightBg))
fgColor := lipgloss.Color(getColor(theme.CurrentTheme().Background()))
for i := 0; i < len(content); {
// Check if we're at an ANSI sequence
isAnsi := false
@@ -697,12 +640,17 @@ func applyHighlighting(content string, segments []Segment, segmentType LineType,
// Get the current styling
currentStyle := ansiSequences[currentPos]
// Apply background highlight
// Apply foreground and background highlight
sb.WriteString("\x1b[38;2;")
r, g, b, _ := fgColor.RGBA()
sb.WriteString(fmt.Sprintf("%d;%d;%dm", r>>8, g>>8, b>>8))
sb.WriteString("\x1b[48;2;")
r, g, b, _ := highlightBg.RGBA()
r, g, b, _ = bgColor.RGBA()
sb.WriteString(fmt.Sprintf("%d;%d;%dm", r>>8, g>>8, b>>8))
sb.WriteString(char)
sb.WriteString("\x1b[49m") // Reset only background
// Full reset of all attributes to ensure clean state
sb.WriteString("\x1b[0m")
// Reapply the original ANSI sequence
sb.WriteString(currentStyle)
@@ -718,50 +666,98 @@ func applyHighlighting(content string, segments []Segment, segmentType LineType,
return sb.String()
}
// renderLeftColumn formats the left side of a side-by-side diff
func renderLeftColumn(fileName string, dl *DiffLine, colWidth int, styles StyleConfig) string {
// renderDiffColumnLine is a helper function that handles the common logic for rendering diff columns
func renderDiffColumnLine(
fileName string,
dl *DiffLine,
colWidth int,
isLeftColumn bool,
t theme.Theme,
) string {
if dl == nil {
contextLineStyle := lipgloss.NewStyle().Background(styles.ContextLineBg)
contextLineStyle := lipgloss.NewStyle().Background(t.DiffContextBg())
return contextLineStyle.Width(colWidth).Render("")
}
removedLineStyle, _, contextLineStyle, lineNumberStyle := createStyles(styles)
removedLineStyle, addedLineStyle, contextLineStyle, lineNumberStyle := createStyles(t)
// Determine line style based on line type
// Determine line style based on line type and column
var marker string
var bgStyle lipgloss.Style
switch dl.Kind {
case LineRemoved:
marker = removedLineStyle.Foreground(styles.RemovedFg).Render("-")
bgStyle = removedLineStyle
lineNumberStyle = lineNumberStyle.Foreground(styles.RemovedFg).Background(styles.RemovedLineNumberBg)
case LineAdded:
marker = "?"
bgStyle = contextLineStyle
case LineContext:
marker = contextLineStyle.Render(" ")
bgStyle = contextLineStyle
var lineNum string
var highlightType LineType
var highlightColor lipgloss.AdaptiveColor
if isLeftColumn {
// Left column logic
switch dl.Kind {
case LineRemoved:
marker = "-"
bgStyle = removedLineStyle
lineNumberStyle = lineNumberStyle.Foreground(t.DiffRemoved()).Background(t.DiffRemovedLineNumberBg())
highlightType = LineRemoved
highlightColor = t.DiffHighlightRemoved()
case LineAdded:
marker = "?"
bgStyle = contextLineStyle
case LineContext:
marker = " "
bgStyle = contextLineStyle
}
// Format line number for left column
if dl.OldLineNo > 0 {
lineNum = fmt.Sprintf("%6d", dl.OldLineNo)
}
} else {
// Right column logic
switch dl.Kind {
case LineAdded:
marker = "+"
bgStyle = addedLineStyle
lineNumberStyle = lineNumberStyle.Foreground(t.DiffAdded()).Background(t.DiffAddedLineNumberBg())
highlightType = LineAdded
highlightColor = t.DiffHighlightAdded()
case LineRemoved:
marker = "?"
bgStyle = contextLineStyle
case LineContext:
marker = " "
bgStyle = contextLineStyle
}
// Format line number for right column
if dl.NewLineNo > 0 {
lineNum = fmt.Sprintf("%6d", dl.NewLineNo)
}
}
// Format line number
lineNum := ""
if dl.OldLineNo > 0 {
lineNum = fmt.Sprintf("%6d", dl.OldLineNo)
// Style the marker based on line type
var styledMarker string
switch dl.Kind {
case LineRemoved:
styledMarker = removedLineStyle.Foreground(t.DiffRemoved()).Render(marker)
case LineAdded:
styledMarker = addedLineStyle.Foreground(t.DiffAdded()).Render(marker)
case LineContext:
styledMarker = contextLineStyle.Foreground(t.TextMuted()).Render(marker)
default:
styledMarker = marker
}
// Create the line prefix
prefix := lineNumberStyle.Render(lineNum + " " + marker)
prefix := lineNumberStyle.Render(lineNum + " " + styledMarker)
// Apply syntax highlighting
content := highlightLine(fileName, dl.Content, bgStyle.GetBackground())
// Apply intra-line highlighting for removed lines
if dl.Kind == LineRemoved && len(dl.Segments) > 0 {
content = applyHighlighting(content, dl.Segments, LineRemoved, styles.RemovedHighlightBg)
// Apply intra-line highlighting if needed
if (dl.Kind == LineRemoved && isLeftColumn || dl.Kind == LineAdded && !isLeftColumn) && len(dl.Segments) > 0 {
content = applyHighlighting(content, dl.Segments, highlightType, highlightColor)
}
// Add a padding space for removed lines
if dl.Kind == LineRemoved {
// Add a padding space for added/removed lines
if (dl.Kind == LineRemoved && isLeftColumn) || (dl.Kind == LineAdded && !isLeftColumn) {
content = bgStyle.Render(" ") + content
}
@@ -771,67 +767,19 @@ func renderLeftColumn(fileName string, dl *DiffLine, colWidth int, styles StyleC
ansi.Truncate(
lineText,
colWidth,
lipgloss.NewStyle().Background(styles.HunkLineBg).Foreground(styles.HunkLineFg).Render("..."),
lipgloss.NewStyle().Background(bgStyle.GetBackground()).Foreground(t.TextMuted()).Render("..."),
),
)
}
// renderLeftColumn formats the left side of a side-by-side diff
func renderLeftColumn(fileName string, dl *DiffLine, colWidth int) string {
return renderDiffColumnLine(fileName, dl, colWidth, true, theme.CurrentTheme())
}
// renderRightColumn formats the right side of a side-by-side diff
func renderRightColumn(fileName string, dl *DiffLine, colWidth int, styles StyleConfig) string {
if dl == nil {
contextLineStyle := lipgloss.NewStyle().Background(styles.ContextLineBg)
return contextLineStyle.Width(colWidth).Render("")
}
_, addedLineStyle, contextLineStyle, lineNumberStyle := createStyles(styles)
// Determine line style based on line type
var marker string
var bgStyle lipgloss.Style
switch dl.Kind {
case LineAdded:
marker = addedLineStyle.Foreground(styles.AddedFg).Render("+")
bgStyle = addedLineStyle
lineNumberStyle = lineNumberStyle.Foreground(styles.AddedFg).Background(styles.AddedLineNamerBg)
case LineRemoved:
marker = "?"
bgStyle = contextLineStyle
case LineContext:
marker = contextLineStyle.Render(" ")
bgStyle = contextLineStyle
}
// Format line number
lineNum := ""
if dl.NewLineNo > 0 {
lineNum = fmt.Sprintf("%6d", dl.NewLineNo)
}
// Create the line prefix
prefix := lineNumberStyle.Render(lineNum + " " + marker)
// Apply syntax highlighting
content := highlightLine(fileName, dl.Content, bgStyle.GetBackground())
// Apply intra-line highlighting for added lines
if dl.Kind == LineAdded && len(dl.Segments) > 0 {
content = applyHighlighting(content, dl.Segments, LineAdded, styles.AddedHighlightBg)
}
// Add a padding space for added lines
if dl.Kind == LineAdded {
content = bgStyle.Render(" ") + content
}
// Create the final line and truncate if needed
lineText := prefix + content
return bgStyle.MaxHeight(1).Width(colWidth).Render(
ansi.Truncate(
lineText,
colWidth,
lipgloss.NewStyle().Background(styles.HunkLineBg).Foreground(styles.HunkLineFg).Render("..."),
),
)
func renderRightColumn(fileName string, dl *DiffLine, colWidth int) string {
return renderDiffColumnLine(fileName, dl, colWidth, false, theme.CurrentTheme())
}
// -------------------------------------------------------------------------
@@ -848,7 +796,7 @@ func RenderSideBySideHunk(fileName string, h Hunk, opts ...SideBySideOption) str
copy(hunkCopy.Lines, h.Lines)
// Highlight changes within lines
HighlightIntralineChanges(&hunkCopy, config.Style)
HighlightIntralineChanges(&hunkCopy)
// Pair lines for side-by-side display
pairs := pairLines(hunkCopy.Lines)
@@ -860,8 +808,8 @@ func RenderSideBySideHunk(fileName string, h Hunk, opts ...SideBySideOption) str
rightWidth := config.TotalWidth - colWidth
var sb strings.Builder
for _, p := range pairs {
leftStr := renderLeftColumn(fileName, p.left, leftWidth, config.Style)
rightStr := renderRightColumn(fileName, p.right, rightWidth, config.Style)
leftStr := renderLeftColumn(fileName, p.left, leftWidth)
rightStr := renderRightColumn(fileName, p.right, rightWidth)
sb.WriteString(leftStr + rightStr + "\n")
}
@@ -870,6 +818,7 @@ func RenderSideBySideHunk(fileName string, h Hunk, opts ...SideBySideOption) str
// FormatDiff creates a side-by-side formatted view of a diff
func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) {
t := theme.CurrentTheme()
diffResult, err := ParseUnifiedDiff(diffText)
if err != nil {
return "", err
@@ -877,53 +826,14 @@ func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) {
var sb strings.Builder
config := NewSideBySideConfig(opts...)
if config.Style.ShowHeader {
removeIcon := lipgloss.NewStyle().
Background(config.Style.RemovedLineBg).
Foreground(config.Style.RemovedFg).
Render("⏹")
addIcon := lipgloss.NewStyle().
Background(config.Style.AddedLineBg).
Foreground(config.Style.AddedFg).
Render("⏹")
fileName := lipgloss.NewStyle().
Background(config.Style.ContextLineBg).
Foreground(config.Style.FileNameFg).
Render(" " + diffResult.OldFile)
for _, h := range diffResult.Hunks {
sb.WriteString(
lipgloss.NewStyle().
Background(config.Style.ContextLineBg).
Padding(0, 1, 0, 1).
Foreground(config.Style.FileNameFg).
BorderStyle(lipgloss.NormalBorder()).
BorderTop(true).
BorderBottom(true).
BorderForeground(config.Style.FileNameFg).
BorderBackground(config.Style.ContextLineBg).
Background(t.DiffHunkHeader()).
Foreground(t.Background()).
Width(config.TotalWidth).
Render(
lipgloss.JoinHorizontal(lipgloss.Top,
removeIcon,
addIcon,
fileName,
),
) + "\n",
Render(h.Header) + "\n",
)
}
for _, h := range diffResult.Hunks {
// Render hunk header
if config.Style.ShowHunkHeader {
sb.WriteString(
lipgloss.NewStyle().
Background(config.Style.HunkLineBg).
Foreground(config.Style.HunkLineFg).
Width(config.TotalWidth).
Render(h.Header) + "\n",
)
}
sb.WriteString(RenderSideBySideHunk(diffResult.OldFile, h, opts...))
}
@@ -938,14 +848,16 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in
fileName = strings.TrimPrefix(fileName, cwd)
fileName = strings.TrimPrefix(fileName, "/")
edits := udiff.Strings(beforeContent, afterContent)
unified, _ := udiff.ToUnified("a/"+fileName, "b/"+fileName, beforeContent, edits, 8)
var (
unified = udiff.Unified("a/"+fileName, "b/"+fileName, beforeContent, afterContent)
additions = 0
removals = 0
)
lines := strings.Split(unified, "\n")
for _, line := range lines {
lines := strings.SplitSeq(unified, "\n")
for line := range lines {
if strings.HasPrefix(line, "+") && !strings.HasPrefix(line, "+++") {
additions++
} else if strings.HasPrefix(line, "-") && !strings.HasPrefix(line, "---") {

103
internal/diff/diff_test.go Normal file
View File

@@ -0,0 +1,103 @@
package diff
import (
"fmt"
"testing"
"github.com/charmbracelet/lipgloss"
"github.com/stretchr/testify/assert"
)
// TestApplyHighlighting checks that applyHighlighting always emits a full
// ANSI reset sequence, regardless of any styling already embedded in the
// input content.
func TestApplyHighlighting(t *testing.T) {
	t.Parallel()

	// Highlight color shared by every case below.
	highlight := lipgloss.AdaptiveColor{
		Dark:  "#FF0000", // Red background for highlighting
		Light: "#FF0000",
	}

	cases := []struct {
		name           string
		content        string
		segments       []Segment
		segmentType    LineType
		expectContains string
	}{
		{
			name:           "Simple text with no ANSI",
			content:        "This is a test",
			segments:       []Segment{{Start: 0, End: 4, Type: LineAdded}},
			segmentType:    LineAdded,
			expectContains: "\x1b[0m", // full reset after highlighting
		},
		{
			name:           "Text with existing ANSI foreground",
			content:        "This \x1b[32mis\x1b[0m a test", // "is" in green
			segments:       []Segment{{Start: 5, End: 7, Type: LineAdded}},
			segmentType:    LineAdded,
			expectContains: "\x1b[0m",
		},
		{
			name:           "Text with existing ANSI background",
			content:        "This \x1b[42mis\x1b[0m a test", // "is" with green background
			segments:       []Segment{{Start: 5, End: 7, Type: LineAdded}},
			segmentType:    LineAdded,
			expectContains: "\x1b[0m",
		},
		{
			name:           "Text with complex ANSI styling",
			content:        "This \x1b[1;32;45mis\x1b[0m a test", // "is" bold green on magenta
			segments:       []Segment{{Start: 5, End: 7, Type: LineAdded}},
			segmentType:    LineAdded,
			expectContains: "\x1b[0m",
		},
	}

	for _, tc := range cases {
		tc := tc // Capture range variable for parallel testing
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			got := applyHighlighting(tc.content, tc.segments, tc.segmentType, highlight)
			assert.Contains(t, got, tc.expectContains,
				"Result should contain full reset sequence")
			// Dump input/output for manual inspection on failure.
			if t.Failed() {
				fmt.Printf("Original: %q\nResult: %q\n", tc.content, got)
			}
		})
	}
}
// TestApplyHighlightingWithMultipleSegments verifies that highlighting
// several disjoint segments in one line still yields a full reset sequence.
func TestApplyHighlightingWithMultipleSegments(t *testing.T) {
	t.Parallel()

	// Highlight color used for all segments.
	highlight := lipgloss.AdaptiveColor{
		Dark:  "#FF0000", // Red background for highlighting
		Light: "#FF0000",
	}

	const content = "This is a test with multiple segments to highlight"
	segments := []Segment{
		{Start: 0, End: 4, Type: LineAdded},   // "This"
		{Start: 8, End: 9, Type: LineAdded},   // "a"
		{Start: 15, End: 23, Type: LineAdded}, // "multiple"
	}

	got := applyHighlighting(content, segments, LineAdded, highlight)
	assert.Contains(t, got, "\x1b[0m",
		"Result should contain full reset sequence")
}

View File

@@ -1,252 +0,0 @@
package history
import (
"context"
"database/sql"
"fmt"
"strconv"
"strings"
"time"
"github.com/google/uuid"
"github.com/opencode-ai/opencode/internal/db"
"github.com/opencode-ai/opencode/internal/pubsub"
)
const (
InitialVersion = "initial"
)
type File struct {
ID string
SessionID string
Path string
Content string
Version string
CreatedAt int64
UpdatedAt int64
}
type Service interface {
pubsub.Suscriber[File]
Create(ctx context.Context, sessionID, path, content string) (File, error)
CreateVersion(ctx context.Context, sessionID, path, content string) (File, error)
Get(ctx context.Context, id string) (File, error)
GetByPathAndSession(ctx context.Context, path, sessionID string) (File, error)
ListBySession(ctx context.Context, sessionID string) ([]File, error)
ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error)
Update(ctx context.Context, file File) (File, error)
Delete(ctx context.Context, id string) error
DeleteSessionFiles(ctx context.Context, sessionID string) error
}
type service struct {
*pubsub.Broker[File]
db *sql.DB
q *db.Queries
}
func NewService(q *db.Queries, db *sql.DB) Service {
return &service{
Broker: pubsub.NewBroker[File](),
q: q,
db: db,
}
}
func (s *service) Create(ctx context.Context, sessionID, path, content string) (File, error) {
return s.createWithVersion(ctx, sessionID, path, content, InitialVersion)
}
func (s *service) CreateVersion(ctx context.Context, sessionID, path, content string) (File, error) {
// Get the latest version for this path
files, err := s.q.ListFilesByPath(ctx, path)
if err != nil {
return File{}, err
}
if len(files) == 0 {
// No previous versions, create initial
return s.Create(ctx, sessionID, path, content)
}
// Get the latest version
latestFile := files[0] // Files are ordered by created_at DESC
latestVersion := latestFile.Version
// Generate the next version
var nextVersion string
if latestVersion == InitialVersion {
nextVersion = "v1"
} else if strings.HasPrefix(latestVersion, "v") {
versionNum, err := strconv.Atoi(latestVersion[1:])
if err != nil {
// If we can't parse the version, just use a timestamp-based version
nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt)
} else {
nextVersion = fmt.Sprintf("v%d", versionNum+1)
}
} else {
// If the version format is unexpected, use a timestamp-based version
nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt)
}
return s.createWithVersion(ctx, sessionID, path, content, nextVersion)
}
func (s *service) createWithVersion(ctx context.Context, sessionID, path, content, version string) (File, error) {
// Maximum number of retries for transaction conflicts
const maxRetries = 3
var file File
var err error
// Retry loop for transaction conflicts
for attempt := range maxRetries {
// Start a transaction
tx, txErr := s.db.Begin()
if txErr != nil {
return File{}, fmt.Errorf("failed to begin transaction: %w", txErr)
}
// Create a new queries instance with the transaction
qtx := s.q.WithTx(tx)
// Try to create the file within the transaction
dbFile, txErr := qtx.CreateFile(ctx, db.CreateFileParams{
ID: uuid.New().String(),
SessionID: sessionID,
Path: path,
Content: content,
Version: version,
})
if txErr != nil {
// Rollback the transaction
tx.Rollback()
// Check if this is a uniqueness constraint violation
if strings.Contains(txErr.Error(), "UNIQUE constraint failed") {
if attempt < maxRetries-1 {
// If we have retries left, generate a new version and try again
if strings.HasPrefix(version, "v") {
versionNum, parseErr := strconv.Atoi(version[1:])
if parseErr == nil {
version = fmt.Sprintf("v%d", versionNum+1)
continue
}
}
// If we can't parse the version, use a timestamp-based version
version = fmt.Sprintf("v%d", time.Now().Unix())
continue
}
}
return File{}, txErr
}
// Commit the transaction
if txErr = tx.Commit(); txErr != nil {
return File{}, fmt.Errorf("failed to commit transaction: %w", txErr)
}
file = s.fromDBItem(dbFile)
s.Publish(pubsub.CreatedEvent, file)
return file, nil
}
return file, err
}
func (s *service) Get(ctx context.Context, id string) (File, error) {
dbFile, err := s.q.GetFile(ctx, id)
if err != nil {
return File{}, err
}
return s.fromDBItem(dbFile), nil
}
func (s *service) GetByPathAndSession(ctx context.Context, path, sessionID string) (File, error) {
dbFile, err := s.q.GetFileByPathAndSession(ctx, db.GetFileByPathAndSessionParams{
Path: path,
SessionID: sessionID,
})
if err != nil {
return File{}, err
}
return s.fromDBItem(dbFile), nil
}
func (s *service) ListBySession(ctx context.Context, sessionID string) ([]File, error) {
dbFiles, err := s.q.ListFilesBySession(ctx, sessionID)
if err != nil {
return nil, err
}
files := make([]File, len(dbFiles))
for i, dbFile := range dbFiles {
files[i] = s.fromDBItem(dbFile)
}
return files, nil
}
func (s *service) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) {
dbFiles, err := s.q.ListLatestSessionFiles(ctx, sessionID)
if err != nil {
return nil, err
}
files := make([]File, len(dbFiles))
for i, dbFile := range dbFiles {
files[i] = s.fromDBItem(dbFile)
}
return files, nil
}
func (s *service) Update(ctx context.Context, file File) (File, error) {
dbFile, err := s.q.UpdateFile(ctx, db.UpdateFileParams{
ID: file.ID,
Content: file.Content,
Version: file.Version,
})
if err != nil {
return File{}, err
}
updatedFile := s.fromDBItem(dbFile)
s.Publish(pubsub.UpdatedEvent, updatedFile)
return updatedFile, nil
}
func (s *service) Delete(ctx context.Context, id string) error {
file, err := s.Get(ctx, id)
if err != nil {
return err
}
err = s.q.DeleteFile(ctx, id)
if err != nil {
return err
}
s.Publish(pubsub.DeletedEvent, file)
return nil
}
func (s *service) DeleteSessionFiles(ctx context.Context, sessionID string) error {
files, err := s.ListBySession(ctx, sessionID)
if err != nil {
return err
}
for _, file := range files {
err = s.Delete(ctx, file.ID)
if err != nil {
return err
}
}
return nil
}
func (s *service) fromDBItem(item db.File) File {
return File{
ID: item.ID,
SessionID: item.SessionID,
Path: item.Path,
Content: item.Content,
Version: item.Version,
CreatedAt: item.CreatedAt,
UpdatedAt: item.UpdatedAt,
}
}

441
internal/history/history.go Normal file
View File

@@ -0,0 +1,441 @@
package history
import (
"context"
"database/sql"
"fmt"
"log/slog"
"slices"
"strconv"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/sst/opencode/internal/db"
"github.com/sst/opencode/internal/pubsub"
)
const (
InitialVersion = "initial"
)
// File is one versioned snapshot of a file's contents, tied to the session
// that produced it. Successive snapshots of the same path share Path but
// carry distinct Version strings.
type File struct {
	ID        string    // unique identifier (UUID, assigned at creation)
	SessionID string    // session this snapshot belongs to
	Path      string    // filesystem path of the tracked file
	Content   string    // full file contents at this version
	Version   string    // "initial" or "vN" (see Create / CreateVersion)
	CreatedAt time.Time // parsed from the DB's RFC3339Nano string
	UpdatedAt time.Time // parsed from the DB's RFC3339Nano string
}
// Event types published on the service's broker for history changes.
const (
	EventFileCreated        pubsub.EventType = "history_file_created"
	EventFileVersionCreated pubsub.EventType = "history_file_version_created"
	EventFileUpdated        pubsub.EventType = "history_file_updated"
	EventFileDeleted        pubsub.EventType = "history_file_deleted"
	// NOTE(review): EventSessionFilesDeleted is declared but never published
	// in this file — DeleteSessionFiles emits per-file EventFileDeleted instead.
	EventSessionFilesDeleted pubsub.EventType = "history_session_files_deleted"
)
// Service is the public API for file history: versioned snapshots of files,
// scoped to sessions, with pub/sub notifications on every mutation.
type Service interface {
	pubsub.Subscriber[File]
	// Create stores the first ("initial") snapshot of path for a session.
	Create(ctx context.Context, sessionID, path, content string) (File, error)
	// CreateVersion stores a new "vN" snapshot, deriving N from the latest
	// existing version recorded for the path.
	CreateVersion(ctx context.Context, sessionID, path, content string) (File, error)
	// Get fetches a single snapshot by ID.
	Get(ctx context.Context, id string) (File, error)
	// GetByPathAndVersion fetches the exact (session, path, version) snapshot.
	GetByPathAndVersion(ctx context.Context, sessionID, path, version string) (File, error)
	// GetLatestByPathAndSession fetches the newest snapshot of path in a session.
	GetLatestByPathAndSession(ctx context.Context, path, sessionID string) (File, error)
	// ListBySession returns all snapshots belonging to a session.
	ListBySession(ctx context.Context, sessionID string) ([]File, error)
	// ListLatestSessionFiles returns only the newest snapshot per path in a session.
	ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error)
	// ListVersionsByPath returns every snapshot of a path (across sessions).
	ListVersionsByPath(ctx context.Context, path string) ([]File, error)
	// Update rewrites the content/version of an existing snapshot.
	Update(ctx context.Context, file File) (File, error)
	// Delete removes a single snapshot by ID.
	Delete(ctx context.Context, id string) error
	// DeleteSessionFiles removes every snapshot for a session.
	DeleteSessionFiles(ctx context.Context, sessionID string) error
}
// service is the sole Service implementation, backed by sqlc-generated
// queries and guarded by an RWMutex (reads take RLock, writes take Lock).
type service struct {
	db     *db.Queries        // sqlc query layer
	sqlDB  *sql.DB            // raw handle, needed for transactions
	broker *pubsub.Broker[File] // event fan-out for Subscribe
	mu     sync.RWMutex       // protects DB access ordering across goroutines
}
// globalHistoryService holds the process-wide singleton; see InitService.
var globalHistoryService *service

// InitService wires the package-level history service to the given database.
// It must be called exactly once, before GetService or any of the
// package-level helper functions are used; a second call returns an error.
//
// NOTE(review): there is no synchronization around globalHistoryService, so
// this is only safe when called during single-threaded startup — confirm.
func InitService(sqlDatabase *sql.DB) error {
	if globalHistoryService != nil {
		return fmt.Errorf("history service already initialized")
	}
	queries := db.New(sqlDatabase)
	broker := pubsub.NewBroker[File]()
	globalHistoryService = &service{
		db:     queries,
		sqlDB:  sqlDatabase,
		broker: broker,
	}
	return nil
}
// GetService returns the process-wide history service. It panics when
// InitService has not been called yet — callers are expected to have
// completed startup wiring first.
func GetService() Service {
	svc := globalHistoryService
	if svc == nil {
		panic("history service not initialized. Call history.InitService() first.")
	}
	return svc
}
// Create records the first snapshot of path for the session, stamped with
// the sentinel InitialVersion and announced as EventFileCreated.
func (s *service) Create(ctx context.Context, sessionID, path, content string) (File, error) {
	const firstVersion = InitialVersion
	return s.createWithVersion(ctx, sessionID, path, content, firstVersion, EventFileCreated)
}
// CreateVersion stores a new snapshot of path, deriving the next version
// string ("vN") from the highest existing numeric version of that path.
//
// NOTE(review): ListFilesByPath is not filtered by session, so the version
// counter is effectively global per path across all sessions — confirm that
// is intended.
func (s *service) CreateVersion(ctx context.Context, sessionID, path, content string) (File, error) {
	s.mu.RLock()
	files, err := s.db.ListFilesByPath(ctx, path)
	s.mu.RUnlock()
	if err != nil && err != sql.ErrNoRows {
		return File{}, fmt.Errorf("db.ListFilesByPath for next version: %w", err)
	}
	latestVersionNumber := 0
	if len(files) > 0 {
		// Sort to be absolutely sure about the latest version globally for this path.
		// Ordering rules: numeric "vN" descending first; "initial" sorts after
		// any "vN"; otherwise newest created_at string first.
		slices.SortFunc(files, func(a, b db.File) int {
			if strings.HasPrefix(a.Version, "v") && strings.HasPrefix(b.Version, "v") {
				// NOTE(review): Atoi errors are ignored here, so a malformed
				// "vX" version compares as 0 — confirm acceptable.
				vA, _ := strconv.Atoi(a.Version[1:])
				vB, _ := strconv.Atoi(b.Version[1:])
				return vB - vA // Descending to get latest first
			}
			if a.Version == InitialVersion && b.Version != InitialVersion {
				return 1 // initial comes after vX
			}
			if b.Version == InitialVersion && a.Version != InitialVersion {
				return -1
			}
			// Compare timestamps as strings (ISO format sorts correctly)
			if b.CreatedAt > a.CreatedAt {
				return 1
			} else if a.CreatedAt > b.CreatedAt {
				return -1
			}
			return 0 // Equal timestamps
		})
		latestFile := files[0]
		if strings.HasPrefix(latestFile.Version, "v") {
			vNum, parseErr := strconv.Atoi(latestFile.Version[1:])
			if parseErr == nil {
				latestVersionNumber = vNum
			}
		}
	}
	// First explicit version after "initial" is v1.
	nextVersionStr := fmt.Sprintf("v%d", latestVersionNumber+1)
	return s.createWithVersion(ctx, sessionID, path, content, nextVersionStr, EventFileVersionCreated)
}
// createWithVersion inserts a new file row with the given version inside a
// transaction, retrying with a bumped version string when the
// (path, session_id, version) uniqueness constraint is violated. On success
// the resulting File is published on the broker under eventType.
//
// The write lock is held for the whole operation so concurrent creators
// cannot race on version numbers.
//
// Fix: the post-loop return previously wrapped a declared-but-never-assigned
// error variable (always nil); lastErr now carries the final failure.
func (s *service) createWithVersion(ctx context.Context, sessionID, path, content, version string, eventType pubsub.EventType) (File, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	const maxRetries = 3
	var lastErr error

	for attempt := range maxRetries {
		tx, txErr := s.sqlDB.BeginTx(ctx, nil)
		if txErr != nil {
			return File{}, fmt.Errorf("failed to begin transaction: %w", txErr)
		}
		qtx := s.db.WithTx(tx)
		dbFile, createErr := qtx.CreateFile(ctx, db.CreateFileParams{
			ID:        uuid.New().String(),
			SessionID: sessionID,
			Path:      path,
			Content:   content,
			Version:   version,
		})
		if createErr != nil {
			if rbErr := tx.Rollback(); rbErr != nil {
				slog.Error("Failed to rollback transaction on create error", "error", rbErr)
			}
			lastErr = createErr
			if strings.Contains(createErr.Error(), "UNIQUE constraint failed: files.path, files.session_id, files.version") &&
				attempt < maxRetries-1 {
				slog.Warn("Unique constraint violation for file version, retrying with incremented version", "path", path, "session", sessionID, "attempted_version", version, "attempt", attempt+1)
				// Bump "vN" -> "vN+1"; fall back to a suffixed retry marker
				// when the version is not parseable as "vN".
				if strings.HasPrefix(version, "v") {
					if num, parseErr := strconv.Atoi(version[1:]); parseErr == nil {
						version = fmt.Sprintf("v%d", num+1)
						continue // Retry with new version
					}
				}
				version = fmt.Sprintf("%s-retry%d", version, attempt+1)
				continue
			}
			return File{}, fmt.Errorf("db.CreateFile within transaction: %w", createErr)
		}
		if commitErr := tx.Commit(); commitErr != nil {
			return File{}, fmt.Errorf("failed to commit transaction: %w", commitErr)
		}
		file := s.fromDBItem(dbFile)
		s.broker.Publish(eventType, file)
		return file, nil // Success
	}
	return File{}, fmt.Errorf("failed to create file after %d retries due to version conflicts: %w", maxRetries, lastErr)
}
// Get fetches a single file snapshot by its ID. A missing row is reported
// as a descriptive "not found" error rather than raw sql.ErrNoRows.
func (s *service) Get(ctx context.Context, id string) (File, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	row, err := s.db.GetFile(ctx, id)
	switch {
	case err == sql.ErrNoRows:
		return File{}, fmt.Errorf("file with ID '%s' not found", id)
	case err != nil:
		return File{}, fmt.Errorf("db.GetFile: %w", err)
	}
	return s.fromDBItem(row), nil
}
// GetByPathAndVersion returns the snapshot matching the exact
// (session, path, version) triple. sqlc has no dedicated query for this
// combination, so every row for the path is listed and filtered in memory;
// a custom query would be the optimization if this ever becomes hot.
func (s *service) GetByPathAndVersion(ctx context.Context, sessionID, path, version string) (File, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	candidates, err := s.db.ListFilesByPath(ctx, path)
	if err != nil {
		return File{}, fmt.Errorf("db.ListFilesByPath for GetByPathAndVersion: %w", err)
	}
	for i := range candidates {
		c := candidates[i]
		if c.SessionID == sessionID && c.Version == version {
			return s.fromDBItem(c), nil
		}
	}
	return File{}, fmt.Errorf("file not found for session '%s', path '%s', version '%s'", sessionID, path, version)
}
// GetLatestByPathAndSession returns the most recent snapshot of path within
// sessionID. The underlying sqlc query already orders by created_at DESC
// with LIMIT 1, so no in-memory sorting is required.
func (s *service) GetLatestByPathAndSession(ctx context.Context, path, sessionID string) (File, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	params := db.GetFileByPathAndSessionParams{
		Path:      path,
		SessionID: sessionID,
	}
	row, err := s.db.GetFileByPathAndSession(ctx, params)
	switch {
	case err == sql.ErrNoRows:
		return File{}, fmt.Errorf("no file found for path '%s' in session '%s'", path, sessionID)
	case err != nil:
		return File{}, fmt.Errorf("db.GetFileByPathAndSession: %w", err)
	}
	return s.fromDBItem(row), nil
}
// ListBySession returns every snapshot recorded for sessionID, in the order
// produced by the underlying query (created_at ascending, per its definition).
func (s *service) ListBySession(ctx context.Context, sessionID string) ([]File, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	rows, err := s.db.ListFilesBySession(ctx, sessionID)
	if err != nil {
		return nil, fmt.Errorf("db.ListFilesBySession: %w", err)
	}
	out := make([]File, 0, len(rows))
	for _, row := range rows {
		out = append(out, s.fromDBItem(row))
	}
	return out, nil
}
// ListLatestSessionFiles returns only the newest snapshot of each path in
// sessionID, delegating the "latest per path" logic to the dedicated sqlc query.
func (s *service) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	rows, err := s.db.ListLatestSessionFiles(ctx, sessionID)
	if err != nil {
		return nil, fmt.Errorf("db.ListLatestSessionFiles: %w", err)
	}
	out := make([]File, 0, len(rows))
	for _, row := range rows {
		out = append(out, s.fromDBItem(row))
	}
	return out, nil
}
// ListVersionsByPath returns every snapshot of path across all sessions,
// in the order produced by the query (created_at descending, per its definition).
func (s *service) ListVersionsByPath(ctx context.Context, path string) ([]File, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	rows, err := s.db.ListFilesByPath(ctx, path)
	if err != nil {
		return nil, fmt.Errorf("db.ListFilesByPath: %w", err)
	}
	out := make([]File, 0, len(rows))
	for _, row := range rows {
		out = append(out, s.fromDBItem(row))
	}
	return out, nil
}
// Update persists new content/version for an existing snapshot and emits
// EventFileUpdated. The updated_at column is maintained by a DB trigger,
// so it is not set here.
func (s *service) Update(ctx context.Context, file File) (File, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if file.ID == "" {
		return File{}, fmt.Errorf("cannot update file with empty ID")
	}
	row, err := s.db.UpdateFile(ctx, db.UpdateFileParams{
		ID:      file.ID,
		Content: file.Content,
		Version: file.Version,
	})
	if err != nil {
		return File{}, fmt.Errorf("db.UpdateFile: %w", err)
	}
	result := s.fromDBItem(row)
	s.broker.Publish(EventFileUpdated, result)
	return result, nil
}
// Delete removes a single file snapshot by ID and publishes EventFileDeleted.
// Deleting an ID that does not exist is a logged no-op.
//
// Fixes two defects in the previous version: (1) the write lock was released
// and re-acquired between the lookup and the delete, leaving a race window;
// (2) the not-found check matched the substring "not found", but the lookup
// surfaced raw sql.ErrNoRows ("sql: no rows in result set"), so the no-op
// path could never trigger.
func (s *service) Delete(ctx context.Context, id string) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	dbFile, err := s.db.GetFile(ctx, id)
	if err != nil {
		if err == sql.ErrNoRows {
			slog.Warn("Attempted to delete non-existent file history", "id", id)
			return nil
		}
		return fmt.Errorf("db.GetFile for delete: %w", err)
	}
	if err := s.db.DeleteFile(ctx, id); err != nil {
		return fmt.Errorf("db.DeleteFile: %w", err)
	}
	s.broker.Publish(EventFileDeleted, s.fromDBItem(dbFile))
	return nil
}
// getServiceForPublish loads a single file row and converts it to the public
// File type, for use when publishing delete events.
//
// NOTE(review): this method does NO locking — the previous comment claimed
// GetFile "has its own RLock", but s.db.GetFile is a raw sqlc query; the
// caller is responsible for holding s.mu as needed. A missing row surfaces
// as raw sql.ErrNoRows, not as a "not found"-worded error.
func (s *service) getServiceForPublish(ctx context.Context, id string) (*File, error) {
	dbFile, err := s.db.GetFile(ctx, id)
	if err != nil {
		return nil, err
	}
	file := s.fromDBItem(dbFile)
	return &file, nil
}
// DeleteSessionFiles removes every snapshot belonging to sessionID with one
// bulk delete, then publishes EventFileDeleted for each removed row. The
// write lock is held across the read-then-delete so the published set
// matches what was actually removed.
func (s *service) DeleteSessionFiles(ctx context.Context, sessionID string) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Snapshot the rows first so we can emit per-file events afterwards.
	doomed, err := s.db.ListFilesBySession(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("db.ListFilesBySession for deletion: %w", err)
	}
	if err := s.db.DeleteSessionFiles(ctx, sessionID); err != nil {
		return fmt.Errorf("db.DeleteSessionFiles: %w", err)
	}
	for _, row := range doomed {
		s.broker.Publish(EventFileDeleted, s.fromDBItem(row))
	}
	return nil
}
// Subscribe exposes the service's event stream; the subscription's lifetime
// is tied to ctx (see pubsub.Broker.Subscribe for exact semantics).
func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[File] {
	return s.broker.Subscribe(ctx)
}
// fromDBItem converts a sqlc row into the public File type, parsing the
// RFC3339Nano timestamp strings stored by the database. Unparseable
// timestamps are logged and replaced with time.Now() so callers always get
// usable values.
//
// Fix: the updated_at branch previously logged "Failed to parse created_at"
// and printed item.CreatedAt instead of the UpdatedAt value it failed on.
func (s *service) fromDBItem(item db.File) File {
	createdAt, err := time.Parse(time.RFC3339Nano, item.CreatedAt)
	if err != nil {
		slog.Error("Failed to parse created_at", "value", item.CreatedAt, "error", err)
		createdAt = time.Now() // Fallback
	}
	updatedAt, err := time.Parse(time.RFC3339Nano, item.UpdatedAt)
	if err != nil {
		slog.Error("Failed to parse updated_at", "value", item.UpdatedAt, "error", err)
		updatedAt = time.Now() // Fallback
	}
	return File{
		ID:        item.ID,
		SessionID: item.SessionID,
		Path:      item.Path,
		Content:   item.Content,
		Version:   item.Version,
		CreatedAt: createdAt,
		UpdatedAt: updatedAt,
	}
}
// The functions below are package-level convenience wrappers that delegate
// to the global singleton (see InitService / GetService). Each one panics
// if the service has not been initialized.

func Create(ctx context.Context, sessionID, path, content string) (File, error) {
	return GetService().Create(ctx, sessionID, path, content)
}
func CreateVersion(ctx context.Context, sessionID, path, content string) (File, error) {
	return GetService().CreateVersion(ctx, sessionID, path, content)
}
func Get(ctx context.Context, id string) (File, error) {
	return GetService().Get(ctx, id)
}
func GetByPathAndVersion(ctx context.Context, sessionID, path, version string) (File, error) {
	return GetService().GetByPathAndVersion(ctx, sessionID, path, version)
}
func GetLatestByPathAndSession(ctx context.Context, path, sessionID string) (File, error) {
	return GetService().GetLatestByPathAndSession(ctx, path, sessionID)
}
func ListBySession(ctx context.Context, sessionID string) ([]File, error) {
	return GetService().ListBySession(ctx, sessionID)
}
func ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) {
	return GetService().ListLatestSessionFiles(ctx, sessionID)
}
func ListVersionsByPath(ctx context.Context, path string) ([]File, error) {
	return GetService().ListVersionsByPath(ctx, path)
}
func Update(ctx context.Context, file File) (File, error) {
	return GetService().Update(ctx, file)
}
func Delete(ctx context.Context, id string) error {
	return GetService().Delete(ctx, id)
}
func DeleteSessionFiles(ctx context.Context, sessionID string) error {
	return GetService().DeleteSessionFiles(ctx, sessionID)
}
func Subscribe(ctx context.Context) <-chan pubsub.Event[File] {
	return GetService().Subscribe(ctx)
}

View File

@@ -5,11 +5,11 @@ import (
"encoding/json"
"fmt"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/llm/tools"
"github.com/opencode-ai/opencode/internal/lsp"
"github.com/opencode-ai/opencode/internal/message"
"github.com/opencode-ai/opencode/internal/session"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/llm/tools"
"github.com/sst/opencode/internal/lsp"
"github.com/sst/opencode/internal/message"
"github.com/sst/opencode/internal/session"
)
type agentTool struct {
@@ -88,10 +88,8 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
}
parentSession.Cost += updatedSession.Cost
parentSession.PromptTokens += updatedSession.PromptTokens
parentSession.CompletionTokens += updatedSession.CompletionTokens
_, err = b.sessions.Save(ctx, parentSession)
_, err = b.sessions.Update(ctx, parentSession)
if err != nil {
return tools.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
}

View File

@@ -6,16 +6,20 @@ import (
"fmt"
"strings"
"sync"
"time"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/llm/models"
"github.com/opencode-ai/opencode/internal/llm/prompt"
"github.com/opencode-ai/opencode/internal/llm/provider"
"github.com/opencode-ai/opencode/internal/llm/tools"
"github.com/opencode-ai/opencode/internal/logging"
"github.com/opencode-ai/opencode/internal/message"
"github.com/opencode-ai/opencode/internal/permission"
"github.com/opencode-ai/opencode/internal/session"
"log/slog"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/llm/models"
"github.com/sst/opencode/internal/llm/prompt"
"github.com/sst/opencode/internal/llm/provider"
"github.com/sst/opencode/internal/llm/tools"
"github.com/sst/opencode/internal/logging"
"github.com/sst/opencode/internal/message"
"github.com/sst/opencode/internal/permission"
"github.com/sst/opencode/internal/session"
"github.com/sst/opencode/internal/status"
)
// Common errors
@@ -38,11 +42,14 @@ func (e *AgentEvent) Response() message.Message {
}
type Service interface {
Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error)
Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
Cancel(sessionID string)
IsSessionBusy(sessionID string) bool
IsBusy() bool
Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
CompactSession(ctx context.Context, sessionID string, force bool) error
GetUsage(ctx context.Context, sessionID string) (*int64, error)
EstimateContextWindowUsage(ctx context.Context, sessionID string) (float64, bool, error)
}
type agent struct {
@@ -68,8 +75,8 @@ func NewAgent(
return nil, err
}
var titleProvider provider.Provider
// Only generate titles for the coder agent
if agentName == config.AgentCoder {
// Only generate titles for the primary agent
if agentName == config.AgentPrimary {
titleProvider, err = createAgentProvider(config.AgentTitle)
if err != nil {
return nil, err
@@ -91,7 +98,7 @@ func NewAgent(
func (a *agent) Cancel(sessionID string) {
if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
if cancel, ok := cancelFunc.(context.CancelFunc); ok {
logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
status.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
cancel()
}
}
@@ -117,6 +124,9 @@ func (a *agent) IsSessionBusy(sessionID string) bool {
}
func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
if content == "" {
return nil
}
if a.titleProvider == nil {
return nil
}
@@ -124,16 +134,13 @@ func (a *agent) generateTitle(ctx context.Context, sessionID string, content str
if err != nil {
return err
}
parts := []message.ContentPart{message.TextContent{Text: content}}
response, err := a.titleProvider.SendMessages(
ctx,
[]message.Message{
{
Role: message.User,
Parts: []message.ContentPart{
message.TextContent{
Text: content,
},
},
Role: message.User,
Parts: parts,
},
},
make([]tools.BaseTool, 0),
@@ -148,7 +155,7 @@ func (a *agent) generateTitle(ctx context.Context, sessionID string, content str
}
session.Title = title
_, err = a.sessions.Save(ctx, session)
_, err = a.sessions.Update(ctx, session)
return err
}
@@ -158,7 +165,10 @@ func (a *agent) err(err error) AgentEvent {
}
}
func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) {
func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
if !a.provider.Model().SupportsAttachments && attachments != nil {
attachments = nil
}
events := make(chan AgentEvent)
if a.IsSessionBusy(sessionID) {
return nil, ErrSessionBusy
@@ -168,49 +178,98 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-ch
a.activeRequests.Store(sessionID, cancel)
go func() {
logging.Debug("Request started", "sessionID", sessionID)
slog.Debug("Request started", "sessionID", sessionID)
defer logging.RecoverPanic("agent.Run", func() {
events <- a.err(fmt.Errorf("panic while running the agent"))
})
result := a.processGeneration(genCtx, sessionID, content)
if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, result))
var attachmentParts []message.ContentPart
for _, attachment := range attachments {
attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
}
logging.Debug("Request completed", "sessionID", sessionID)
result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
status.Error(result.Err().Error())
}
slog.Debug("Request completed", "sessionID", sessionID)
a.activeRequests.Delete(sessionID)
cancel()
events <- result
close(events)
}()
return events, nil
}
func (a *agent) processGeneration(ctx context.Context, sessionID, content string) AgentEvent {
// List existing messages; if none, start title generation asynchronously.
msgs, err := a.messages.List(ctx, sessionID)
func (a *agent) prepareMessageHistory(ctx context.Context, sessionID string) (session.Session, []message.Message, error) {
currentSession, err := a.sessions.Get(ctx, sessionID)
if err != nil {
return a.err(fmt.Errorf("failed to list messages: %w", err))
}
if len(msgs) == 0 {
go func() {
defer logging.RecoverPanic("agent.Run", func() {
logging.ErrorPersist("panic while generating title")
})
titleErr := a.generateTitle(context.Background(), sessionID, content)
if titleErr != nil {
logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
}
}()
return currentSession, nil, fmt.Errorf("failed to get session: %w", err)
}
userMsg, err := a.createUserMessage(ctx, sessionID, content)
var sessionMessages []message.Message
if currentSession.Summary != "" && !currentSession.SummarizedAt.IsZero() {
// If summary exists, only fetch messages after the summarization timestamp
sessionMessages, err = a.messages.ListAfter(ctx, sessionID, currentSession.SummarizedAt)
if err != nil {
return currentSession, nil, fmt.Errorf("failed to list messages after summary: %w", err)
}
} else {
// If no summary, fetch all messages
sessionMessages, err = a.messages.List(ctx, sessionID)
if err != nil {
return currentSession, nil, fmt.Errorf("failed to list messages: %w", err)
}
}
var messages []message.Message
if currentSession.Summary != "" && !currentSession.SummarizedAt.IsZero() {
// If summary exists, create a temporary message for the summary
summaryMessage := message.Message{
Role: message.Assistant,
Parts: []message.ContentPart{
message.TextContent{Text: currentSession.Summary},
},
}
// Start with the summary, then add messages after the summary timestamp
messages = append([]message.Message{summaryMessage}, sessionMessages...)
} else {
// If no summary, just use all messages
messages = sessionMessages
}
return currentSession, messages, nil
}
func (a *agent) triggerTitleGeneration(sessionID string, content string) {
go func() {
defer logging.RecoverPanic("agent.Run", func() {
status.Error("panic while generating title")
})
titleErr := a.generateTitle(context.Background(), sessionID, content)
if titleErr != nil {
status.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
}
}()
}
func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
currentSession, sessionMessages, err := a.prepareMessageHistory(ctx, sessionID)
if err != nil {
return a.err(err)
}
// If this is a new session, start title generation asynchronously
if len(sessionMessages) == 0 && currentSession.Summary == "" {
a.triggerTitleGeneration(sessionID, content)
}
userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
if err != nil {
return a.err(fmt.Errorf("failed to create user message: %w", err))
}
// Append the new user message to the conversation history.
msgHistory := append(msgs, userMsg)
messages := append(sessionMessages, userMsg)
for {
// Check for cancellation before each iteration
select {
@@ -219,7 +278,42 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
default:
// Continue processing
}
agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
// Check if auto-compaction is needed before calling the provider
usagePercentage, needsCompaction, errEstimate := a.EstimateContextWindowUsage(ctx, sessionID)
if errEstimate != nil {
slog.Warn("Failed to estimate context window usage for auto-compaction", "error", errEstimate, "sessionID", sessionID)
} else if needsCompaction {
status.Info(fmt.Sprintf("Context window usage is at %.2f%%. Auto-compacting conversation...", usagePercentage))
// Run compaction synchronously
compactCtx, cancelCompact := context.WithTimeout(ctx, 30*time.Second) // Use appropriate context
errCompact := a.CompactSession(compactCtx, sessionID, true)
cancelCompact()
if errCompact != nil {
status.Warn(fmt.Sprintf("Auto-compaction failed: %v. Context window usage may continue to grow.", errCompact))
} else {
status.Info("Auto-compaction completed successfully.")
// After compaction, message history needs to be re-prepared.
// The 'messages' slice needs to be updated with the new summary and subsequent messages,
// ensuring the latest user message is correctly appended.
_, sessionMessagesFromCompact, errPrepare := a.prepareMessageHistory(ctx, sessionID)
if errPrepare != nil {
return a.err(fmt.Errorf("failed to re-prepare message history after compaction: %w", errPrepare))
}
messages = sessionMessagesFromCompact
// Ensure the user message that triggered this cycle is the last one.
// 'userMsg' was created before this loop using a.createUserMessage.
// It should be appended to the 'messages' slice if it's not already the last element.
if len(messages) == 0 || (len(messages) > 0 && messages[len(messages)-1].ID != userMsg.ID) {
messages = append(messages, userMsg)
}
}
}
agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, messages)
if err != nil {
if errors.Is(err, context.Canceled) {
agentMessage.AddFinish(message.FinishReasonCanceled)
@@ -228,10 +322,10 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
}
return a.err(fmt.Errorf("failed to process events: %w", err))
}
logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
// We are not done, we need to respond with the tool response
msgHistory = append(msgHistory, agentMessage, *toolResults)
messages = append(messages, agentMessage, *toolResults)
continue
}
return AgentEvent{
@@ -240,15 +334,36 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
}
}
func (a *agent) createUserMessage(ctx context.Context, sessionID, content string) (message.Message, error) {
func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
parts := []message.ContentPart{message.TextContent{Text: content}}
parts = append(parts, attachmentParts...)
return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
Role: message.User,
Parts: []message.ContentPart{
message.TextContent{Text: content},
},
Role: message.User,
Parts: parts,
})
}
func (a *agent) createToolResponseMessage(ctx context.Context, sessionID string, toolResults []message.ToolResult) (*message.Message, error) {
if len(toolResults) == 0 {
return nil, nil
}
parts := make([]message.ContentPart, 0, len(toolResults))
for _, tr := range toolResults {
parts = append(parts, tr)
}
msg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
Role: message.Tool,
Parts: parts,
})
if err != nil {
return nil, fmt.Errorf("failed to create tool response message: %w", err)
}
return &msg, nil
}
func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
@@ -277,12 +392,37 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
}
}
toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
toolCalls := assistantMsg.ToolCalls()
// If the assistant wants to use tools, execute them
if assistantMsg.FinishReason() == message.FinishReasonToolUse {
toolCalls := assistantMsg.ToolCalls()
if len(toolCalls) > 0 {
toolResults, err := a.executeToolCalls(ctx, toolCalls)
if err != nil {
if errors.Is(err, context.Canceled) {
a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
}
return assistantMsg, nil, err
}
// Create a message with the tool results
toolResponseMsg, err := a.createToolResponseMessage(ctx, sessionID, toolResults)
if err != nil {
return assistantMsg, nil, err
}
return assistantMsg, toolResponseMsg, nil
}
}
return assistantMsg, nil, nil
}
func (a *agent) executeToolCalls(ctx context.Context, toolCalls []message.ToolCall) ([]message.ToolResult, error) {
toolResults := make([]message.ToolResult, len(toolCalls))
for i, toolCall := range toolCalls {
select {
case <-ctx.Done():
a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
// Make all future tool calls cancelled
for j := i; j < len(toolCalls); j++ {
toolResults[j] = message.ToolResult{
@@ -291,7 +431,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
IsError: true,
}
}
goto out
return toolResults, ctx.Err()
default:
// Continue processing
var tool tools.BaseTool
@@ -316,6 +456,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
Name: toolCall.Name,
Input: toolCall.Input,
})
if toolErr != nil {
if errors.Is(toolErr, permission.ErrorPermissionDenied) {
toolResults[i] = message.ToolResult{
@@ -323,6 +464,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
Content: "Permission denied",
IsError: true,
}
// Cancel all remaining tool calls if permission is denied
for j := i + 1; j < len(toolCalls); j++ {
toolResults[j] = message.ToolResult{
ToolCallID: toolCalls[j].ID,
@@ -330,10 +472,18 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
IsError: true,
}
}
a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
break
return toolResults, nil
}
// Handle other errors
toolResults[i] = message.ToolResult{
ToolCallID: toolCall.ID,
Content: toolErr.Error(),
IsError: true,
}
continue
}
toolResults[i] = message.ToolResult{
ToolCallID: toolCall.ID,
Content: toolResult.Content,
@@ -342,28 +492,13 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
}
}
}
out:
if len(toolResults) == 0 {
return assistantMsg, nil, nil
}
parts := make([]message.ContentPart, 0)
for _, tr := range toolResults {
parts = append(parts, tr)
}
msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
Role: message.Tool,
Parts: parts,
})
if err != nil {
return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
}
return assistantMsg, &msg, err
return toolResults, nil
}
func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
msg.AddFinish(finishReson)
_ = a.messages.Update(ctx, *msg)
_, _ = a.messages.Update(ctx, *msg)
}
func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
@@ -371,19 +506,22 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
case <-ctx.Done():
return ctx.Err()
default:
// Continue processing.
// Continue processing
}
switch event.Type {
case provider.EventThinkingDelta:
assistantMsg.AppendReasoningContent(event.Content)
return a.messages.Update(ctx, *assistantMsg)
_, err := a.messages.Update(ctx, *assistantMsg)
return err
case provider.EventContentDelta:
assistantMsg.AppendContent(event.Content)
return a.messages.Update(ctx, *assistantMsg)
_, err := a.messages.Update(ctx, *assistantMsg)
return err
case provider.EventToolUseStart:
assistantMsg.AddToolCall(*event.ToolCall)
return a.messages.Update(ctx, *assistantMsg)
_, err := a.messages.Update(ctx, *assistantMsg)
return err
// TODO: see how to handle this
// case provider.EventToolUseDelta:
// tm := time.Unix(assistantMsg.UpdatedAt, 0)
@@ -395,18 +533,19 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
// }
case provider.EventToolUseStop:
assistantMsg.FinishToolCall(event.ToolCall.ID)
return a.messages.Update(ctx, *assistantMsg)
_, err := a.messages.Update(ctx, *assistantMsg)
return err
case provider.EventError:
if errors.Is(event.Error, context.Canceled) {
logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
status.Info(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
return context.Canceled
}
logging.ErrorPersist(event.Error.Error())
status.Error(event.Error.Error())
return event.Error
case provider.EventComplete:
assistantMsg.SetToolCalls(event.Response.ToolCalls)
assistantMsg.AddFinish(event.Response.FinishReason)
if err := a.messages.Update(ctx, *assistantMsg); err != nil {
if _, err := a.messages.Update(ctx, *assistantMsg); err != nil {
return fmt.Errorf("failed to update message: %w", err)
}
return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
@@ -415,6 +554,49 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
return nil
}
func (a *agent) GetUsage(ctx context.Context, sessionID string) (*int64, error) {
session, err := a.sessions.Get(ctx, sessionID)
if err != nil {
return nil, fmt.Errorf("failed to get session: %w", err)
}
usage := session.PromptTokens + session.CompletionTokens
return &usage, nil
}
func (a *agent) EstimateContextWindowUsage(ctx context.Context, sessionID string) (float64, bool, error) {
session, err := a.sessions.Get(ctx, sessionID)
if err != nil {
return 0, false, fmt.Errorf("failed to get session: %w", err)
}
// Get the model's context window size
model := a.provider.Model()
contextWindow := model.ContextWindow
if contextWindow <= 0 {
// Default to a reasonable size if not specified
contextWindow = 100000
}
// Calculate current token usage
currentTokens := session.PromptTokens + session.CompletionTokens
// Get the max tokens setting for the agent
maxTokens := a.provider.MaxTokens()
// Calculate percentage of context window used
usagePercentage := float64(currentTokens) / float64(contextWindow)
// Check if we need to auto-compact
// Auto-compact when:
// 1. Usage exceeds 90% of context window, OR
// 2. Current usage + maxTokens would exceed 100% of context window
needsCompaction := usagePercentage >= 0.9 ||
float64(currentTokens+maxTokens) > float64(contextWindow)
return usagePercentage * 100, needsCompaction, nil
}
func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
sess, err := a.sessions.Get(ctx, sessionID)
if err != nil {
@@ -427,10 +609,10 @@ func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.M
model.CostPer1MOut/1e6*float64(usage.OutputTokens)
sess.Cost += cost
sess.CompletionTokens += usage.OutputTokens
sess.PromptTokens += usage.InputTokens
sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
_, err = a.sessions.Save(ctx, sess)
_, err = a.sessions.Update(ctx, sess)
if err != nil {
return fmt.Errorf("failed to save session: %w", err)
}
@@ -456,6 +638,127 @@ func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (mode
return a.provider.Model(), nil
}
func (a *agent) CompactSession(ctx context.Context, sessionID string, force bool) error {
// Check if the session is busy
if a.IsSessionBusy(sessionID) && !force {
return ErrSessionBusy
}
// Create a cancellable context
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// Mark the session as busy during compaction
compactionCancelFunc := func() {}
a.activeRequests.Store(sessionID+"-compact", compactionCancelFunc)
defer a.activeRequests.Delete(sessionID + "-compact")
// Fetch the session
session, err := a.sessions.Get(ctx, sessionID)
if err != nil {
return fmt.Errorf("failed to get session: %w", err)
}
// Fetch all messages for the session
sessionMessages, err := a.messages.List(ctx, sessionID)
if err != nil {
return fmt.Errorf("failed to list messages: %w", err)
}
var existingSummary string
if session.Summary != "" && !session.SummarizedAt.IsZero() {
// Filter messages that were created after the last summarization
var newMessages []message.Message
for _, msg := range sessionMessages {
if msg.CreatedAt.After(session.SummarizedAt) {
newMessages = append(newMessages, msg)
}
}
sessionMessages = newMessages
existingSummary = session.Summary
}
// If there are no messages to summarize and no existing summary, return early
if len(sessionMessages) == 0 && existingSummary == "" {
return nil
}
messages := []message.Message{
message.Message{
Role: message.System,
Parts: []message.ContentPart{
message.TextContent{
Text: `You are a helpful AI assistant tasked with summarizing conversations.
When asked to summarize, provide a detailed but concise summary of the conversation.
Focus on information that would be helpful for continuing the conversation, including:
- What was done
- What is currently being worked on
- Which files are being modified
- What needs to be done next
Your summary should be comprehensive enough to provide context but concise enough to be quickly understood.`,
},
},
},
}
// If there's an existing summary, include it
if existingSummary != "" {
messages = append(messages, message.Message{
Role: message.Assistant,
Parts: []message.ContentPart{
message.TextContent{
Text: existingSummary,
},
},
})
}
// Add all messages since the last summarized message
messages = append(messages, sessionMessages...)
// Add a final user message requesting the summary
messages = append(messages, message.Message{
Role: message.User,
Parts: []message.ContentPart{
message.TextContent{
Text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
},
},
})
// Call provider to get the summary
response, err := a.provider.SendMessages(ctx, messages, a.tools)
if err != nil {
return fmt.Errorf("failed to get summary from the assistant: %w", err)
}
// Extract the summary text
summaryText := strings.TrimSpace(response.Content)
if summaryText == "" {
return fmt.Errorf("received empty summary from the assistant")
}
// Update the session with the new summary
session.Summary = summaryText
session.SummarizedAt = time.Now()
// Save the updated session
_, err = a.sessions.Update(ctx, session)
if err != nil {
return fmt.Errorf("failed to save session with summary: %w", err)
}
// Track token usage
err = a.TrackUsage(ctx, sessionID, a.provider.Model(), response.Usage)
if err != nil {
return fmt.Errorf("failed to track usage: %w", err)
}
return nil
}
func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
cfg := config.Get()
agentConfig, ok := cfg.Agents[agentName]
@@ -491,7 +794,7 @@ func createAgentProvider(agentName config.AgentName) (provider.Provider, error)
provider.WithReasoningEffort(agentConfig.ReasoningEffort),
),
)
} else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder {
} else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentPrimary {
opts = append(
opts,
provider.WithAnthropicOptions(

View File

@@ -5,11 +5,11 @@ import (
"encoding/json"
"fmt"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/llm/tools"
"github.com/opencode-ai/opencode/internal/logging"
"github.com/opencode-ai/opencode/internal/permission"
"github.com/opencode-ai/opencode/internal/version"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/llm/tools"
"github.com/sst/opencode/internal/permission"
"github.com/sst/opencode/internal/version"
"log/slog"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/mcp"
@@ -86,6 +86,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
}
permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
p := b.permissions.Request(
ctx,
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: config.WorkingDirectory(),
@@ -146,13 +147,13 @@ func getTools(ctx context.Context, name string, m config.MCPServer, permissions
_, err := c.Initialize(ctx, initRequest)
if err != nil {
logging.Error("error initializing mcp client", "error", err)
slog.Error("error initializing mcp client", "error", err)
return stdioTools
}
toolsRequest := mcp.ListToolsRequest{}
tools, err := c.ListTools(ctx, toolsRequest)
if err != nil {
logging.Error("error listing tools", "error", err)
slog.Error("error listing tools", "error", err)
return stdioTools
}
for _, t := range tools.Tools {
@@ -175,7 +176,7 @@ func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.Ba
m.Args...,
)
if err != nil {
logging.Error("error creating mcp client", "error", err)
slog.Error("error creating mcp client", "error", err)
continue
}
@@ -186,7 +187,7 @@ func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.Ba
client.WithHeaders(m.Headers),
)
if err != nil {
logging.Error("error creating mcp client", "error", err)
slog.Error("error creating mcp client", "error", err)
continue
}
mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)

View File

@@ -3,15 +3,15 @@ package agent
import (
"context"
"github.com/opencode-ai/opencode/internal/history"
"github.com/opencode-ai/opencode/internal/llm/tools"
"github.com/opencode-ai/opencode/internal/lsp"
"github.com/opencode-ai/opencode/internal/message"
"github.com/opencode-ai/opencode/internal/permission"
"github.com/opencode-ai/opencode/internal/session"
"github.com/sst/opencode/internal/history"
"github.com/sst/opencode/internal/llm/tools"
"github.com/sst/opencode/internal/lsp"
"github.com/sst/opencode/internal/message"
"github.com/sst/opencode/internal/permission"
"github.com/sst/opencode/internal/session"
)
func CoderAgentTools(
func PrimaryAgentTools(
permissions permission.Service,
sessions session.Service,
messages message.Service,
@@ -19,10 +19,8 @@ func CoderAgentTools(
lspClients map[string]*lsp.Client,
) []tools.BaseTool {
ctx := context.Background()
otherTools := GetMcpTools(ctx, permissions)
if len(lspClients) > 0 {
otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
}
mcpTools := GetMcpTools(ctx, permissions)
return append(
[]tools.BaseTool{
tools.NewBashTool(permissions),
@@ -31,12 +29,17 @@ func CoderAgentTools(
tools.NewGlobTool(),
tools.NewGrepTool(),
tools.NewLsTool(),
tools.NewSourcegraphTool(),
tools.NewViewTool(lspClients),
tools.NewPatchTool(lspClients, permissions, history),
tools.NewWriteTool(lspClients, permissions, history),
tools.NewDiagnosticsTool(lspClients),
tools.NewDefinitionTool(lspClients),
tools.NewReferencesTool(lspClients),
tools.NewDocSymbolsTool(lspClients),
tools.NewWorkspaceSymbolsTool(lspClients),
tools.NewCodeActionTool(lspClients),
NewAgentTool(sessions, messages, lspClients),
}, otherTools...,
}, mcpTools...,
)
}
@@ -45,7 +48,10 @@ func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool {
tools.NewGlobTool(),
tools.NewGrepTool(),
tools.NewLsTool(),
tools.NewSourcegraphTool(),
tools.NewViewTool(lspClients),
tools.NewDefinitionTool(lspClients),
tools.NewReferencesTool(lspClients),
tools.NewDocSymbolsTool(lspClients),
tools.NewWorkspaceSymbolsTool(lspClients),
}
}

View File

@@ -14,64 +14,69 @@ const (
// https://docs.anthropic.com/en/docs/about-claude/models/all-models
var AnthropicModels = map[ModelID]Model{
Claude35Sonnet: {
ID: Claude35Sonnet,
Name: "Claude 3.5 Sonnet",
Provider: ProviderAnthropic,
APIModel: "claude-3-5-sonnet-latest",
CostPer1MIn: 3.0,
CostPer1MInCached: 3.75,
CostPer1MOutCached: 0.30,
CostPer1MOut: 15.0,
ContextWindow: 200000,
DefaultMaxTokens: 5000,
ID: Claude35Sonnet,
Name: "Claude 3.5 Sonnet",
Provider: ProviderAnthropic,
APIModel: "claude-3-5-sonnet-latest",
CostPer1MIn: 3.0,
CostPer1MInCached: 3.75,
CostPer1MOutCached: 0.30,
CostPer1MOut: 15.0,
ContextWindow: 200000,
DefaultMaxTokens: 5000,
SupportsAttachments: true,
},
Claude3Haiku: {
ID: Claude3Haiku,
Name: "Claude 3 Haiku",
Provider: ProviderAnthropic,
APIModel: "claude-3-haiku-20240307", // doesn't support "-latest"
CostPer1MIn: 0.25,
CostPer1MInCached: 0.30,
CostPer1MOutCached: 0.03,
CostPer1MOut: 1.25,
ContextWindow: 200000,
DefaultMaxTokens: 4096,
ID: Claude3Haiku,
Name: "Claude 3 Haiku",
Provider: ProviderAnthropic,
APIModel: "claude-3-haiku-20240307", // doesn't support "-latest"
CostPer1MIn: 0.25,
CostPer1MInCached: 0.30,
CostPer1MOutCached: 0.03,
CostPer1MOut: 1.25,
ContextWindow: 200000,
DefaultMaxTokens: 4096,
SupportsAttachments: true,
},
Claude37Sonnet: {
ID: Claude37Sonnet,
Name: "Claude 3.7 Sonnet",
Provider: ProviderAnthropic,
APIModel: "claude-3-7-sonnet-latest",
CostPer1MIn: 3.0,
CostPer1MInCached: 3.75,
CostPer1MOutCached: 0.30,
CostPer1MOut: 15.0,
ContextWindow: 200000,
DefaultMaxTokens: 50000,
CanReason: true,
ID: Claude37Sonnet,
Name: "Claude 3.7 Sonnet",
Provider: ProviderAnthropic,
APIModel: "claude-3-7-sonnet-latest",
CostPer1MIn: 3.0,
CostPer1MInCached: 3.75,
CostPer1MOutCached: 0.30,
CostPer1MOut: 15.0,
ContextWindow: 200000,
DefaultMaxTokens: 50000,
CanReason: true,
SupportsAttachments: true,
},
Claude35Haiku: {
ID: Claude35Haiku,
Name: "Claude 3.5 Haiku",
Provider: ProviderAnthropic,
APIModel: "claude-3-5-haiku-latest",
CostPer1MIn: 0.80,
CostPer1MInCached: 1.0,
CostPer1MOutCached: 0.08,
CostPer1MOut: 4.0,
ContextWindow: 200000,
DefaultMaxTokens: 4096,
ID: Claude35Haiku,
Name: "Claude 3.5 Haiku",
Provider: ProviderAnthropic,
APIModel: "claude-3-5-haiku-latest",
CostPer1MIn: 0.80,
CostPer1MInCached: 1.0,
CostPer1MOutCached: 0.08,
CostPer1MOut: 4.0,
ContextWindow: 200000,
DefaultMaxTokens: 4096,
SupportsAttachments: true,
},
Claude3Opus: {
ID: Claude3Opus,
Name: "Claude 3 Opus",
Provider: ProviderAnthropic,
APIModel: "claude-3-opus-latest",
CostPer1MIn: 15.0,
CostPer1MInCached: 18.75,
CostPer1MOutCached: 1.50,
CostPer1MOut: 75.0,
ContextWindow: 200000,
DefaultMaxTokens: 4096,
ID: Claude3Opus,
Name: "Claude 3 Opus",
Provider: ProviderAnthropic,
APIModel: "claude-3-opus-latest",
CostPer1MIn: 15.0,
CostPer1MInCached: 18.75,
CostPer1MOutCached: 1.50,
CostPer1MOut: 75.0,
ContextWindow: 200000,
DefaultMaxTokens: 4096,
SupportsAttachments: true,
},
}

View File

@@ -18,140 +18,151 @@ const (
var AzureModels = map[ModelID]Model{
AzureGPT41: {
ID: AzureGPT41,
Name: "Azure OpenAI GPT 4.1",
Provider: ProviderAzure,
APIModel: "gpt-4.1",
CostPer1MIn: OpenAIModels[GPT41].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT41].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT41].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT41].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT41].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT41].DefaultMaxTokens,
ID: AzureGPT41,
Name: "Azure OpenAI GPT 4.1",
Provider: ProviderAzure,
APIModel: "gpt-4.1",
CostPer1MIn: OpenAIModels[GPT41].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT41].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT41].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT41].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT41].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT41].DefaultMaxTokens,
SupportsAttachments: true,
},
AzureGPT41Mini: {
ID: AzureGPT41Mini,
Name: "Azure OpenAI GPT 4.1 mini",
Provider: ProviderAzure,
APIModel: "gpt-4.1-mini",
CostPer1MIn: OpenAIModels[GPT41Mini].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT41Mini].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT41Mini].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT41Mini].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT41Mini].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT41Mini].DefaultMaxTokens,
ID: AzureGPT41Mini,
Name: "Azure OpenAI GPT 4.1 mini",
Provider: ProviderAzure,
APIModel: "gpt-4.1-mini",
CostPer1MIn: OpenAIModels[GPT41Mini].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT41Mini].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT41Mini].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT41Mini].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT41Mini].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT41Mini].DefaultMaxTokens,
SupportsAttachments: true,
},
AzureGPT41Nano: {
ID: AzureGPT41Nano,
Name: "Azure OpenAI GPT 4.1 nano",
Provider: ProviderAzure,
APIModel: "gpt-4.1-nano",
CostPer1MIn: OpenAIModels[GPT41Nano].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT41Nano].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT41Nano].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT41Nano].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT41Nano].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT41Nano].DefaultMaxTokens,
ID: AzureGPT41Nano,
Name: "Azure OpenAI GPT 4.1 nano",
Provider: ProviderAzure,
APIModel: "gpt-4.1-nano",
CostPer1MIn: OpenAIModels[GPT41Nano].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT41Nano].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT41Nano].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT41Nano].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT41Nano].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT41Nano].DefaultMaxTokens,
SupportsAttachments: true,
},
AzureGPT45Preview: {
ID: AzureGPT45Preview,
Name: "Azure OpenAI GPT 4.5 preview",
Provider: ProviderAzure,
APIModel: "gpt-4.5-preview",
CostPer1MIn: OpenAIModels[GPT45Preview].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT45Preview].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT45Preview].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT45Preview].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT45Preview].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT45Preview].DefaultMaxTokens,
ID: AzureGPT45Preview,
Name: "Azure OpenAI GPT 4.5 preview",
Provider: ProviderAzure,
APIModel: "gpt-4.5-preview",
CostPer1MIn: OpenAIModels[GPT45Preview].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT45Preview].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT45Preview].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT45Preview].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT45Preview].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT45Preview].DefaultMaxTokens,
SupportsAttachments: true,
},
AzureGPT4o: {
ID: AzureGPT4o,
Name: "Azure OpenAI GPT-4o",
Provider: ProviderAzure,
APIModel: "gpt-4o",
CostPer1MIn: OpenAIModels[GPT4o].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT4o].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT4o].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT4o].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT4o].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT4o].DefaultMaxTokens,
ID: AzureGPT4o,
Name: "Azure OpenAI GPT-4o",
Provider: ProviderAzure,
APIModel: "gpt-4o",
CostPer1MIn: OpenAIModels[GPT4o].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT4o].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT4o].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT4o].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT4o].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT4o].DefaultMaxTokens,
SupportsAttachments: true,
},
AzureGPT4oMini: {
ID: AzureGPT4oMini,
Name: "Azure OpenAI GPT-4o mini",
Provider: ProviderAzure,
APIModel: "gpt-4o-mini",
CostPer1MIn: OpenAIModels[GPT4oMini].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT4oMini].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT4oMini].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT4oMini].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT4oMini].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT4oMini].DefaultMaxTokens,
ID: AzureGPT4oMini,
Name: "Azure OpenAI GPT-4o mini",
Provider: ProviderAzure,
APIModel: "gpt-4o-mini",
CostPer1MIn: OpenAIModels[GPT4oMini].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT4oMini].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT4oMini].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT4oMini].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT4oMini].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT4oMini].DefaultMaxTokens,
SupportsAttachments: true,
},
AzureO1: {
ID: AzureO1,
Name: "Azure OpenAI O1",
Provider: ProviderAzure,
APIModel: "o1",
CostPer1MIn: OpenAIModels[O1].CostPer1MIn,
CostPer1MInCached: OpenAIModels[O1].CostPer1MInCached,
CostPer1MOut: OpenAIModels[O1].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[O1].CostPer1MOutCached,
ContextWindow: OpenAIModels[O1].ContextWindow,
DefaultMaxTokens: OpenAIModels[O1].DefaultMaxTokens,
CanReason: OpenAIModels[O1].CanReason,
ID: AzureO1,
Name: "Azure OpenAI O1",
Provider: ProviderAzure,
APIModel: "o1",
CostPer1MIn: OpenAIModels[O1].CostPer1MIn,
CostPer1MInCached: OpenAIModels[O1].CostPer1MInCached,
CostPer1MOut: OpenAIModels[O1].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[O1].CostPer1MOutCached,
ContextWindow: OpenAIModels[O1].ContextWindow,
DefaultMaxTokens: OpenAIModels[O1].DefaultMaxTokens,
CanReason: OpenAIModels[O1].CanReason,
SupportsAttachments: true,
},
AzureO1Mini: {
ID: AzureO1Mini,
Name: "Azure OpenAI O1 mini",
Provider: ProviderAzure,
APIModel: "o1-mini",
CostPer1MIn: OpenAIModels[O1Mini].CostPer1MIn,
CostPer1MInCached: OpenAIModels[O1Mini].CostPer1MInCached,
CostPer1MOut: OpenAIModels[O1Mini].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[O1Mini].CostPer1MOutCached,
ContextWindow: OpenAIModels[O1Mini].ContextWindow,
DefaultMaxTokens: OpenAIModels[O1Mini].DefaultMaxTokens,
CanReason: OpenAIModels[O1Mini].CanReason,
ID: AzureO1Mini,
Name: "Azure OpenAI O1 mini",
Provider: ProviderAzure,
APIModel: "o1-mini",
CostPer1MIn: OpenAIModels[O1Mini].CostPer1MIn,
CostPer1MInCached: OpenAIModels[O1Mini].CostPer1MInCached,
CostPer1MOut: OpenAIModels[O1Mini].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[O1Mini].CostPer1MOutCached,
ContextWindow: OpenAIModels[O1Mini].ContextWindow,
DefaultMaxTokens: OpenAIModels[O1Mini].DefaultMaxTokens,
CanReason: OpenAIModels[O1Mini].CanReason,
SupportsAttachments: true,
},
AzureO3: {
ID: AzureO3,
Name: "Azure OpenAI O3",
Provider: ProviderAzure,
APIModel: "o3",
CostPer1MIn: OpenAIModels[O3].CostPer1MIn,
CostPer1MInCached: OpenAIModels[O3].CostPer1MInCached,
CostPer1MOut: OpenAIModels[O3].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[O3].CostPer1MOutCached,
ContextWindow: OpenAIModels[O3].ContextWindow,
DefaultMaxTokens: OpenAIModels[O3].DefaultMaxTokens,
CanReason: OpenAIModels[O3].CanReason,
ID: AzureO3,
Name: "Azure OpenAI O3",
Provider: ProviderAzure,
APIModel: "o3",
CostPer1MIn: OpenAIModels[O3].CostPer1MIn,
CostPer1MInCached: OpenAIModels[O3].CostPer1MInCached,
CostPer1MOut: OpenAIModels[O3].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[O3].CostPer1MOutCached,
ContextWindow: OpenAIModels[O3].ContextWindow,
DefaultMaxTokens: OpenAIModels[O3].DefaultMaxTokens,
CanReason: OpenAIModels[O3].CanReason,
SupportsAttachments: true,
},
AzureO3Mini: {
ID: AzureO3Mini,
Name: "Azure OpenAI O3 mini",
Provider: ProviderAzure,
APIModel: "o3-mini",
CostPer1MIn: OpenAIModels[O3Mini].CostPer1MIn,
CostPer1MInCached: OpenAIModels[O3Mini].CostPer1MInCached,
CostPer1MOut: OpenAIModels[O3Mini].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[O3Mini].CostPer1MOutCached,
ContextWindow: OpenAIModels[O3Mini].ContextWindow,
DefaultMaxTokens: OpenAIModels[O3Mini].DefaultMaxTokens,
CanReason: OpenAIModels[O3Mini].CanReason,
ID: AzureO3Mini,
Name: "Azure OpenAI O3 mini",
Provider: ProviderAzure,
APIModel: "o3-mini",
CostPer1MIn: OpenAIModels[O3Mini].CostPer1MIn,
CostPer1MInCached: OpenAIModels[O3Mini].CostPer1MInCached,
CostPer1MOut: OpenAIModels[O3Mini].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[O3Mini].CostPer1MOutCached,
ContextWindow: OpenAIModels[O3Mini].ContextWindow,
DefaultMaxTokens: OpenAIModels[O3Mini].DefaultMaxTokens,
CanReason: OpenAIModels[O3Mini].CanReason,
SupportsAttachments: false,
},
AzureO4Mini: {
ID: AzureO4Mini,
Name: "Azure OpenAI O4 mini",
Provider: ProviderAzure,
APIModel: "o4-mini",
CostPer1MIn: OpenAIModels[O4Mini].CostPer1MIn,
CostPer1MInCached: OpenAIModels[O4Mini].CostPer1MInCached,
CostPer1MOut: OpenAIModels[O4Mini].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[O4Mini].CostPer1MOutCached,
ContextWindow: OpenAIModels[O4Mini].ContextWindow,
DefaultMaxTokens: OpenAIModels[O4Mini].DefaultMaxTokens,
CanReason: OpenAIModels[O4Mini].CanReason,
ID: AzureO4Mini,
Name: "Azure OpenAI O4 mini",
Provider: ProviderAzure,
APIModel: "o4-mini",
CostPer1MIn: OpenAIModels[O4Mini].CostPer1MIn,
CostPer1MInCached: OpenAIModels[O4Mini].CostPer1MInCached,
CostPer1MOut: OpenAIModels[O4Mini].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[O4Mini].CostPer1MOutCached,
ContextWindow: OpenAIModels[O4Mini].ContextWindow,
DefaultMaxTokens: OpenAIModels[O4Mini].DefaultMaxTokens,
CanReason: OpenAIModels[O4Mini].CanReason,
SupportsAttachments: true,
},
}

View File

@@ -12,52 +12,56 @@ const (
var GeminiModels = map[ModelID]Model{
Gemini25Flash: {
ID: Gemini25Flash,
Name: "Gemini 2.5 Flash",
Provider: ProviderGemini,
APIModel: "gemini-2.5-flash-preview-04-17",
CostPer1MIn: 0.15,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0.60,
ContextWindow: 1000000,
DefaultMaxTokens: 50000,
ID: Gemini25Flash,
Name: "Gemini 2.5 Flash",
Provider: ProviderGemini,
APIModel: "gemini-2.5-flash-preview-04-17",
CostPer1MIn: 0.15,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0.60,
ContextWindow: 1000000,
DefaultMaxTokens: 50000,
SupportsAttachments: true,
},
Gemini25: {
ID: Gemini25,
Name: "Gemini 2.5 Pro",
Provider: ProviderGemini,
APIModel: "gemini-2.5-pro-preview-03-25",
CostPer1MIn: 1.25,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 10,
ContextWindow: 1000000,
DefaultMaxTokens: 50000,
ID: Gemini25,
Name: "Gemini 2.5 Pro",
Provider: ProviderGemini,
APIModel: "gemini-2.5-pro-preview-03-25",
CostPer1MIn: 1.25,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 10,
ContextWindow: 1000000,
DefaultMaxTokens: 50000,
SupportsAttachments: true,
},
Gemini20Flash: {
ID: Gemini20Flash,
Name: "Gemini 2.0 Flash",
Provider: ProviderGemini,
APIModel: "gemini-2.0-flash",
CostPer1MIn: 0.10,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0.40,
ContextWindow: 1000000,
DefaultMaxTokens: 6000,
ID: Gemini20Flash,
Name: "Gemini 2.0 Flash",
Provider: ProviderGemini,
APIModel: "gemini-2.0-flash",
CostPer1MIn: 0.10,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0.40,
ContextWindow: 1000000,
DefaultMaxTokens: 6000,
SupportsAttachments: true,
},
Gemini20FlashLite: {
ID: Gemini20FlashLite,
Name: "Gemini 2.0 Flash Lite",
Provider: ProviderGemini,
APIModel: "gemini-2.0-flash-lite",
CostPer1MIn: 0.05,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0.30,
ContextWindow: 1000000,
DefaultMaxTokens: 6000,
ID: Gemini20FlashLite,
Name: "Gemini 2.0 Flash Lite",
Provider: ProviderGemini,
APIModel: "gemini-2.0-flash-lite",
CostPer1MIn: 0.05,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0.30,
ContextWindow: 1000000,
DefaultMaxTokens: 6000,
SupportsAttachments: true,
},
}

View File

@@ -28,55 +28,60 @@ var GroqModels = map[ModelID]Model{
ContextWindow: 128_000,
DefaultMaxTokens: 50000,
// for some reason, the groq api doesn't like the reasoningEffort parameter
CanReason: false,
CanReason: false,
SupportsAttachments: false,
},
Llama4Scout: {
ID: Llama4Scout,
Name: "Llama4Scout",
Provider: ProviderGROQ,
APIModel: "meta-llama/llama-4-scout-17b-16e-instruct",
CostPer1MIn: 0.11,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0.34,
ContextWindow: 128_000, // 10M when?
ID: Llama4Scout,
Name: "Llama4Scout",
Provider: ProviderGROQ,
APIModel: "meta-llama/llama-4-scout-17b-16e-instruct",
CostPer1MIn: 0.11,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0.34,
ContextWindow: 128_000, // 10M when?
SupportsAttachments: true,
},
Llama4Maverick: {
ID: Llama4Maverick,
Name: "Llama4Maverick",
Provider: ProviderGROQ,
APIModel: "meta-llama/llama-4-maverick-17b-128e-instruct",
CostPer1MIn: 0.20,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0.20,
ContextWindow: 128_000,
ID: Llama4Maverick,
Name: "Llama4Maverick",
Provider: ProviderGROQ,
APIModel: "meta-llama/llama-4-maverick-17b-128e-instruct",
CostPer1MIn: 0.20,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0.20,
ContextWindow: 128_000,
SupportsAttachments: true,
},
Llama3_3_70BVersatile: {
ID: Llama3_3_70BVersatile,
Name: "Llama3_3_70BVersatile",
Provider: ProviderGROQ,
APIModel: "llama-3.3-70b-versatile",
CostPer1MIn: 0.59,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0.79,
ContextWindow: 128_000,
ID: Llama3_3_70BVersatile,
Name: "Llama3_3_70BVersatile",
Provider: ProviderGROQ,
APIModel: "llama-3.3-70b-versatile",
CostPer1MIn: 0.59,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0.79,
ContextWindow: 128_000,
SupportsAttachments: false,
},
DeepseekR1DistillLlama70b: {
ID: DeepseekR1DistillLlama70b,
Name: "DeepseekR1DistillLlama70b",
Provider: ProviderGROQ,
APIModel: "deepseek-r1-distill-llama-70b",
CostPer1MIn: 0.75,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0.99,
ContextWindow: 128_000,
CanReason: true,
ID: DeepseekR1DistillLlama70b,
Name: "DeepseekR1DistillLlama70b",
Provider: ProviderGROQ,
APIModel: "deepseek-r1-distill-llama-70b",
CostPer1MIn: 0.75,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0.99,
ContextWindow: 128_000,
CanReason: true,
SupportsAttachments: false,
},
}

View File

@@ -8,17 +8,18 @@ type (
)
type Model struct {
ID ModelID `json:"id"`
Name string `json:"name"`
Provider ModelProvider `json:"provider"`
APIModel string `json:"api_model"`
CostPer1MIn float64 `json:"cost_per_1m_in"`
CostPer1MOut float64 `json:"cost_per_1m_out"`
CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
ContextWindow int64 `json:"context_window"`
DefaultMaxTokens int64 `json:"default_max_tokens"`
CanReason bool `json:"can_reason"`
ID ModelID `json:"id"`
Name string `json:"name"`
Provider ModelProvider `json:"provider"`
APIModel string `json:"api_model"`
CostPer1MIn float64 `json:"cost_per_1m_in"`
CostPer1MOut float64 `json:"cost_per_1m_out"`
CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
ContextWindow int64 `json:"context_window"`
DefaultMaxTokens int64 `json:"default_max_tokens"`
CanReason bool `json:"can_reason"`
SupportsAttachments bool `json:"supports_attachments"`
}
// Model IDs
@@ -35,11 +36,13 @@ const (
// Providers in order of popularity
var ProviderPopularity = map[ModelProvider]int{
ProviderAnthropic: 1,
ProviderOpenAI: 2,
ProviderGemini: 3,
ProviderGROQ: 4,
ProviderBedrock: 5,
ProviderAnthropic: 1,
ProviderOpenAI: 2,
ProviderGemini: 3,
ProviderGROQ: 4,
ProviderOpenRouter: 5,
ProviderBedrock: 6,
ProviderAzure: 7,
}
var SupportedModels = map[ModelID]Model{
@@ -69,14 +72,18 @@ var SupportedModels = map[ModelID]Model{
//
// // Bedrock
BedrockClaude37Sonnet: {
ID: BedrockClaude37Sonnet,
Name: "Bedrock: Claude 3.7 Sonnet",
Provider: ProviderBedrock,
APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0",
CostPer1MIn: 3.0,
CostPer1MInCached: 3.75,
CostPer1MOutCached: 0.30,
CostPer1MOut: 15.0,
ID: BedrockClaude37Sonnet,
Name: "Bedrock: Claude 3.7 Sonnet",
Provider: ProviderBedrock,
APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0",
CostPer1MIn: 3.0,
CostPer1MInCached: 3.75,
CostPer1MOutCached: 0.30,
CostPer1MOut: 15.0,
ContextWindow: 200_000,
DefaultMaxTokens: 50_000,
CanReason: true,
SupportsAttachments: true,
},
}
@@ -87,4 +94,5 @@ func init() {
maps.Copy(SupportedModels, GroqModels)
maps.Copy(SupportedModels, AzureModels)
maps.Copy(SupportedModels, OpenRouterModels)
maps.Copy(SupportedModels, XAIModels)
}

View File

@@ -19,151 +19,163 @@ const (
var OpenAIModels = map[ModelID]Model{
GPT41: {
ID: GPT41,
Name: "GPT 4.1",
Provider: ProviderOpenAI,
APIModel: "gpt-4.1",
CostPer1MIn: 2.00,
CostPer1MInCached: 0.50,
CostPer1MOutCached: 0.0,
CostPer1MOut: 8.00,
ContextWindow: 1_047_576,
DefaultMaxTokens: 20000,
ID: GPT41,
Name: "GPT 4.1",
Provider: ProviderOpenAI,
APIModel: "gpt-4.1",
CostPer1MIn: 2.00,
CostPer1MInCached: 0.50,
CostPer1MOutCached: 0.0,
CostPer1MOut: 8.00,
ContextWindow: 1_047_576,
DefaultMaxTokens: 20000,
SupportsAttachments: true,
},
GPT41Mini: {
ID: GPT41Mini,
Name: "GPT 4.1 mini",
Provider: ProviderOpenAI,
APIModel: "gpt-4.1",
CostPer1MIn: 0.40,
CostPer1MInCached: 0.10,
CostPer1MOutCached: 0.0,
CostPer1MOut: 1.60,
ContextWindow: 200_000,
DefaultMaxTokens: 20000,
ID: GPT41Mini,
Name: "GPT 4.1 mini",
Provider: ProviderOpenAI,
APIModel: "gpt-4.1",
CostPer1MIn: 0.40,
CostPer1MInCached: 0.10,
CostPer1MOutCached: 0.0,
CostPer1MOut: 1.60,
ContextWindow: 200_000,
DefaultMaxTokens: 20000,
SupportsAttachments: true,
},
GPT41Nano: {
ID: GPT41Nano,
Name: "GPT 4.1 nano",
Provider: ProviderOpenAI,
APIModel: "gpt-4.1-nano",
CostPer1MIn: 0.10,
CostPer1MInCached: 0.025,
CostPer1MOutCached: 0.0,
CostPer1MOut: 0.40,
ContextWindow: 1_047_576,
DefaultMaxTokens: 20000,
ID: GPT41Nano,
Name: "GPT 4.1 nano",
Provider: ProviderOpenAI,
APIModel: "gpt-4.1-nano",
CostPer1MIn: 0.10,
CostPer1MInCached: 0.025,
CostPer1MOutCached: 0.0,
CostPer1MOut: 0.40,
ContextWindow: 1_047_576,
DefaultMaxTokens: 20000,
SupportsAttachments: true,
},
GPT45Preview: {
ID: GPT45Preview,
Name: "GPT 4.5 preview",
Provider: ProviderOpenAI,
APIModel: "gpt-4.5-preview",
CostPer1MIn: 75.00,
CostPer1MInCached: 37.50,
CostPer1MOutCached: 0.0,
CostPer1MOut: 150.00,
ContextWindow: 128_000,
DefaultMaxTokens: 15000,
ID: GPT45Preview,
Name: "GPT 4.5 preview",
Provider: ProviderOpenAI,
APIModel: "gpt-4.5-preview",
CostPer1MIn: 75.00,
CostPer1MInCached: 37.50,
CostPer1MOutCached: 0.0,
CostPer1MOut: 150.00,
ContextWindow: 128_000,
DefaultMaxTokens: 15000,
SupportsAttachments: true,
},
GPT4o: {
ID: GPT4o,
Name: "GPT 4o",
Provider: ProviderOpenAI,
APIModel: "gpt-4o",
CostPer1MIn: 2.50,
CostPer1MInCached: 1.25,
CostPer1MOutCached: 0.0,
CostPer1MOut: 10.00,
ContextWindow: 128_000,
DefaultMaxTokens: 4096,
ID: GPT4o,
Name: "GPT 4o",
Provider: ProviderOpenAI,
APIModel: "gpt-4o",
CostPer1MIn: 2.50,
CostPer1MInCached: 1.25,
CostPer1MOutCached: 0.0,
CostPer1MOut: 10.00,
ContextWindow: 128_000,
DefaultMaxTokens: 4096,
SupportsAttachments: true,
},
GPT4oMini: {
ID: GPT4oMini,
Name: "GPT 4o mini",
Provider: ProviderOpenAI,
APIModel: "gpt-4o-mini",
CostPer1MIn: 0.15,
CostPer1MInCached: 0.075,
CostPer1MOutCached: 0.0,
CostPer1MOut: 0.60,
ContextWindow: 128_000,
ID: GPT4oMini,
Name: "GPT 4o mini",
Provider: ProviderOpenAI,
APIModel: "gpt-4o-mini",
CostPer1MIn: 0.15,
CostPer1MInCached: 0.075,
CostPer1MOutCached: 0.0,
CostPer1MOut: 0.60,
ContextWindow: 128_000,
SupportsAttachments: true,
},
O1: {
ID: O1,
Name: "O1",
Provider: ProviderOpenAI,
APIModel: "o1",
CostPer1MIn: 15.00,
CostPer1MInCached: 7.50,
CostPer1MOutCached: 0.0,
CostPer1MOut: 60.00,
ContextWindow: 200_000,
DefaultMaxTokens: 50000,
CanReason: true,
ID: O1,
Name: "O1",
Provider: ProviderOpenAI,
APIModel: "o1",
CostPer1MIn: 15.00,
CostPer1MInCached: 7.50,
CostPer1MOutCached: 0.0,
CostPer1MOut: 60.00,
ContextWindow: 200_000,
DefaultMaxTokens: 50000,
CanReason: true,
SupportsAttachments: true,
},
O1Pro: {
ID: O1Pro,
Name: "o1 pro",
Provider: ProviderOpenAI,
APIModel: "o1-pro",
CostPer1MIn: 150.00,
CostPer1MInCached: 0.0,
CostPer1MOutCached: 0.0,
CostPer1MOut: 600.00,
ContextWindow: 200_000,
DefaultMaxTokens: 50000,
CanReason: true,
ID: O1Pro,
Name: "o1 pro",
Provider: ProviderOpenAI,
APIModel: "o1-pro",
CostPer1MIn: 150.00,
CostPer1MInCached: 0.0,
CostPer1MOutCached: 0.0,
CostPer1MOut: 600.00,
ContextWindow: 200_000,
DefaultMaxTokens: 50000,
CanReason: true,
SupportsAttachments: true,
},
O1Mini: {
ID: O1Mini,
Name: "o1 mini",
Provider: ProviderOpenAI,
APIModel: "o1-mini",
CostPer1MIn: 1.10,
CostPer1MInCached: 0.55,
CostPer1MOutCached: 0.0,
CostPer1MOut: 4.40,
ContextWindow: 128_000,
DefaultMaxTokens: 50000,
CanReason: true,
ID: O1Mini,
Name: "o1 mini",
Provider: ProviderOpenAI,
APIModel: "o1-mini",
CostPer1MIn: 1.10,
CostPer1MInCached: 0.55,
CostPer1MOutCached: 0.0,
CostPer1MOut: 4.40,
ContextWindow: 128_000,
DefaultMaxTokens: 50000,
CanReason: true,
SupportsAttachments: true,
},
O3: {
ID: O3,
Name: "o3",
Provider: ProviderOpenAI,
APIModel: "o3",
CostPer1MIn: 10.00,
CostPer1MInCached: 2.50,
CostPer1MOutCached: 0.0,
CostPer1MOut: 40.00,
ContextWindow: 200_000,
CanReason: true,
ID: O3,
Name: "o3",
Provider: ProviderOpenAI,
APIModel: "o3",
CostPer1MIn: 10.00,
CostPer1MInCached: 2.50,
CostPer1MOutCached: 0.0,
CostPer1MOut: 40.00,
ContextWindow: 200_000,
CanReason: true,
SupportsAttachments: true,
},
O3Mini: {
ID: O3Mini,
Name: "o3 mini",
Provider: ProviderOpenAI,
APIModel: "o3-mini",
CostPer1MIn: 1.10,
CostPer1MInCached: 0.55,
CostPer1MOutCached: 0.0,
CostPer1MOut: 4.40,
ContextWindow: 200_000,
DefaultMaxTokens: 50000,
CanReason: true,
ID: O3Mini,
Name: "o3 mini",
Provider: ProviderOpenAI,
APIModel: "o3-mini",
CostPer1MIn: 1.10,
CostPer1MInCached: 0.55,
CostPer1MOutCached: 0.0,
CostPer1MOut: 4.40,
ContextWindow: 200_000,
DefaultMaxTokens: 50000,
CanReason: true,
SupportsAttachments: false,
},
O4Mini: {
ID: O4Mini,
Name: "o4 mini",
Provider: ProviderOpenAI,
APIModel: "o4-mini",
CostPer1MIn: 1.10,
CostPer1MInCached: 0.275,
CostPer1MOutCached: 0.0,
CostPer1MOut: 4.40,
ContextWindow: 128_000,
DefaultMaxTokens: 50000,
CanReason: true,
ID: O4Mini,
Name: "o4 mini",
Provider: ProviderOpenAI,
APIModel: "o4-mini",
CostPer1MIn: 1.10,
CostPer1MInCached: 0.275,
CostPer1MOutCached: 0.0,
CostPer1MOut: 4.40,
ContextWindow: 128_000,
DefaultMaxTokens: 50000,
CanReason: true,
SupportsAttachments: true,
},
}

View File

@@ -22,12 +22,17 @@ const (
OpenRouterClaude37Sonnet ModelID = "openrouter.claude-3.7-sonnet"
OpenRouterClaude35Haiku ModelID = "openrouter.claude-3.5-haiku"
OpenRouterClaude3Opus ModelID = "openrouter.claude-3-opus"
OpenRouterQwen235B ModelID = "openrouter.qwen-3-235b"
OpenRouterQwen32B ModelID = "openrouter.qwen-3-32b"
OpenRouterQwen30B ModelID = "openrouter.qwen-3-30b"
OpenRouterQwen14B ModelID = "openrouter.qwen-3-14b"
OpenRouterQwen8B ModelID = "openrouter.qwen-3-8b"
)
var OpenRouterModels = map[ModelID]Model{
OpenRouterGPT41: {
ID: OpenRouterGPT41,
Name: "OpenRouter GPT 4.1",
Name: "OpenRouter: GPT 4.1",
Provider: ProviderOpenRouter,
APIModel: "openai/gpt-4.1",
CostPer1MIn: OpenAIModels[GPT41].CostPer1MIn,
@@ -39,7 +44,7 @@ var OpenRouterModels = map[ModelID]Model{
},
OpenRouterGPT41Mini: {
ID: OpenRouterGPT41Mini,
Name: "OpenRouter GPT 4.1 mini",
Name: "OpenRouter: GPT 4.1 mini",
Provider: ProviderOpenRouter,
APIModel: "openai/gpt-4.1-mini",
CostPer1MIn: OpenAIModels[GPT41Mini].CostPer1MIn,
@@ -51,7 +56,7 @@ var OpenRouterModels = map[ModelID]Model{
},
OpenRouterGPT41Nano: {
ID: OpenRouterGPT41Nano,
Name: "OpenRouter GPT 4.1 nano",
Name: "OpenRouter: GPT 4.1 nano",
Provider: ProviderOpenRouter,
APIModel: "openai/gpt-4.1-nano",
CostPer1MIn: OpenAIModels[GPT41Nano].CostPer1MIn,
@@ -63,7 +68,7 @@ var OpenRouterModels = map[ModelID]Model{
},
OpenRouterGPT45Preview: {
ID: OpenRouterGPT45Preview,
Name: "OpenRouter GPT 4.5 preview",
Name: "OpenRouter: GPT 4.5 preview",
Provider: ProviderOpenRouter,
APIModel: "openai/gpt-4.5-preview",
CostPer1MIn: OpenAIModels[GPT45Preview].CostPer1MIn,
@@ -75,7 +80,7 @@ var OpenRouterModels = map[ModelID]Model{
},
OpenRouterGPT4o: {
ID: OpenRouterGPT4o,
Name: "OpenRouter GPT 4o",
Name: "OpenRouter: GPT 4o",
Provider: ProviderOpenRouter,
APIModel: "openai/gpt-4o",
CostPer1MIn: OpenAIModels[GPT4o].CostPer1MIn,
@@ -87,7 +92,7 @@ var OpenRouterModels = map[ModelID]Model{
},
OpenRouterGPT4oMini: {
ID: OpenRouterGPT4oMini,
Name: "OpenRouter GPT 4o mini",
Name: "OpenRouter: GPT 4o mini",
Provider: ProviderOpenRouter,
APIModel: "openai/gpt-4o-mini",
CostPer1MIn: OpenAIModels[GPT4oMini].CostPer1MIn,
@@ -98,7 +103,7 @@ var OpenRouterModels = map[ModelID]Model{
},
OpenRouterO1: {
ID: OpenRouterO1,
Name: "OpenRouter O1",
Name: "OpenRouter: O1",
Provider: ProviderOpenRouter,
APIModel: "openai/o1",
CostPer1MIn: OpenAIModels[O1].CostPer1MIn,
@@ -111,7 +116,7 @@ var OpenRouterModels = map[ModelID]Model{
},
OpenRouterO1Pro: {
ID: OpenRouterO1Pro,
Name: "OpenRouter o1 pro",
Name: "OpenRouter: o1 pro",
Provider: ProviderOpenRouter,
APIModel: "openai/o1-pro",
CostPer1MIn: OpenAIModels[O1Pro].CostPer1MIn,
@@ -124,7 +129,7 @@ var OpenRouterModels = map[ModelID]Model{
},
OpenRouterO1Mini: {
ID: OpenRouterO1Mini,
Name: "OpenRouter o1 mini",
Name: "OpenRouter: o1 mini",
Provider: ProviderOpenRouter,
APIModel: "openai/o1-mini",
CostPer1MIn: OpenAIModels[O1Mini].CostPer1MIn,
@@ -137,7 +142,7 @@ var OpenRouterModels = map[ModelID]Model{
},
OpenRouterO3: {
ID: OpenRouterO3,
Name: "OpenRouter o3",
Name: "OpenRouter: o3",
Provider: ProviderOpenRouter,
APIModel: "openai/o3",
CostPer1MIn: OpenAIModels[O3].CostPer1MIn,
@@ -150,7 +155,7 @@ var OpenRouterModels = map[ModelID]Model{
},
OpenRouterO3Mini: {
ID: OpenRouterO3Mini,
Name: "OpenRouter o3 mini",
Name: "OpenRouter: o3 mini",
Provider: ProviderOpenRouter,
APIModel: "openai/o3-mini-high",
CostPer1MIn: OpenAIModels[O3Mini].CostPer1MIn,
@@ -163,7 +168,7 @@ var OpenRouterModels = map[ModelID]Model{
},
OpenRouterO4Mini: {
ID: OpenRouterO4Mini,
Name: "OpenRouter o4 mini",
Name: "OpenRouter: o4 mini",
Provider: ProviderOpenRouter,
APIModel: "openai/o4-mini-high",
CostPer1MIn: OpenAIModels[O4Mini].CostPer1MIn,
@@ -176,7 +181,7 @@ var OpenRouterModels = map[ModelID]Model{
},
OpenRouterGemini25Flash: {
ID: OpenRouterGemini25Flash,
Name: "OpenRouter Gemini 2.5 Flash",
Name: "OpenRouter: Gemini 2.5 Flash",
Provider: ProviderOpenRouter,
APIModel: "google/gemini-2.5-flash-preview:thinking",
CostPer1MIn: GeminiModels[Gemini25Flash].CostPer1MIn,
@@ -188,7 +193,7 @@ var OpenRouterModels = map[ModelID]Model{
},
OpenRouterGemini25: {
ID: OpenRouterGemini25,
Name: "OpenRouter Gemini 2.5 Pro",
Name: "OpenRouter: Gemini 2.5 Pro",
Provider: ProviderOpenRouter,
APIModel: "google/gemini-2.5-pro-preview-03-25",
CostPer1MIn: GeminiModels[Gemini25].CostPer1MIn,
@@ -200,7 +205,7 @@ var OpenRouterModels = map[ModelID]Model{
},
OpenRouterClaude35Sonnet: {
ID: OpenRouterClaude35Sonnet,
Name: "OpenRouter Claude 3.5 Sonnet",
Name: "OpenRouter: Claude 3.5 Sonnet",
Provider: ProviderOpenRouter,
APIModel: "anthropic/claude-3.5-sonnet",
CostPer1MIn: AnthropicModels[Claude35Sonnet].CostPer1MIn,
@@ -212,7 +217,7 @@ var OpenRouterModels = map[ModelID]Model{
},
OpenRouterClaude3Haiku: {
ID: OpenRouterClaude3Haiku,
Name: "OpenRouter Claude 3 Haiku",
Name: "OpenRouter: Claude 3 Haiku",
Provider: ProviderOpenRouter,
APIModel: "anthropic/claude-3-haiku",
CostPer1MIn: AnthropicModels[Claude3Haiku].CostPer1MIn,
@@ -224,7 +229,7 @@ var OpenRouterModels = map[ModelID]Model{
},
OpenRouterClaude37Sonnet: {
ID: OpenRouterClaude37Sonnet,
Name: "OpenRouter Claude 3.7 Sonnet",
Name: "OpenRouter: Claude 3.7 Sonnet",
Provider: ProviderOpenRouter,
APIModel: "anthropic/claude-3.7-sonnet",
CostPer1MIn: AnthropicModels[Claude37Sonnet].CostPer1MIn,
@@ -237,7 +242,7 @@ var OpenRouterModels = map[ModelID]Model{
},
OpenRouterClaude35Haiku: {
ID: OpenRouterClaude35Haiku,
Name: "OpenRouter Claude 3.5 Haiku",
Name: "OpenRouter: Claude 3.5 Haiku",
Provider: ProviderOpenRouter,
APIModel: "anthropic/claude-3.5-haiku",
CostPer1MIn: AnthropicModels[Claude35Haiku].CostPer1MIn,
@@ -249,7 +254,7 @@ var OpenRouterModels = map[ModelID]Model{
},
OpenRouterClaude3Opus: {
ID: OpenRouterClaude3Opus,
Name: "OpenRouter Claude 3 Opus",
Name: "OpenRouter: Claude 3 Opus",
Provider: ProviderOpenRouter,
APIModel: "anthropic/claude-3-opus",
CostPer1MIn: AnthropicModels[Claude3Opus].CostPer1MIn,
@@ -259,4 +264,64 @@ var OpenRouterModels = map[ModelID]Model{
ContextWindow: AnthropicModels[Claude3Opus].ContextWindow,
DefaultMaxTokens: AnthropicModels[Claude3Opus].DefaultMaxTokens,
},
OpenRouterQwen235B: {
ID: OpenRouterQwen235B,
Name: "OpenRouter: Qwen3 235B A22B",
Provider: ProviderOpenRouter,
APIModel: "qwen/qwen3-235b-a22b",
CostPer1MIn: 0.1,
CostPer1MInCached: 0.1,
CostPer1MOut: 0.1,
CostPer1MOutCached: 0.1,
ContextWindow: 40960,
DefaultMaxTokens: 4096,
},
OpenRouterQwen32B: {
ID: OpenRouterQwen32B,
Name: "OpenRouter: Qwen3 32B",
Provider: ProviderOpenRouter,
APIModel: "qwen/qwen3-32b",
CostPer1MIn: 0.1,
CostPer1MInCached: 0.1,
CostPer1MOut: 0.3,
CostPer1MOutCached: 0.3,
ContextWindow: 40960,
DefaultMaxTokens: 4096,
},
OpenRouterQwen30B: {
ID: OpenRouterQwen30B,
Name: "OpenRouter: Qwen3 30B A3B",
Provider: ProviderOpenRouter,
APIModel: "qwen/qwen3-30b-a3b",
CostPer1MIn: 0.1,
CostPer1MInCached: 0.1,
CostPer1MOut: 0.3,
CostPer1MOutCached: 0.3,
ContextWindow: 40960,
DefaultMaxTokens: 4096,
},
OpenRouterQwen14B: {
ID: OpenRouterQwen14B,
Name: "OpenRouter: Qwen3 14B",
Provider: ProviderOpenRouter,
APIModel: "qwen/qwen3-14b",
CostPer1MIn: 0.7,
CostPer1MInCached: 0.7,
CostPer1MOut: 0.24,
CostPer1MOutCached: 0.24,
ContextWindow: 40960,
DefaultMaxTokens: 4096,
},
OpenRouterQwen8B: {
ID: OpenRouterQwen8B,
Name: "OpenRouter: Qwen3 8B",
Provider: ProviderOpenRouter,
APIModel: "qwen/qwen3-8b",
CostPer1MIn: 0.35,
CostPer1MInCached: 0.35,
CostPer1MOut: 0.138,
CostPer1MOutCached: 0.138,
ContextWindow: 128000,
DefaultMaxTokens: 4096,
},
}

View File

@@ -0,0 +1,61 @@
package models
const (
ProviderXAI ModelProvider = "xai"
XAIGrok3Beta ModelID = "grok-3-beta"
XAIGrok3MiniBeta ModelID = "grok-3-mini-beta"
XAIGrok3FastBeta ModelID = "grok-3-fast-beta"
XAiGrok3MiniFastBeta ModelID = "grok-3-mini-fast-beta"
)
var XAIModels = map[ModelID]Model{
XAIGrok3Beta: {
ID: XAIGrok3Beta,
Name: "Grok3 Beta",
Provider: ProviderXAI,
APIModel: "grok-3-beta",
CostPer1MIn: 3.0,
CostPer1MInCached: 0,
CostPer1MOut: 15,
CostPer1MOutCached: 0,
ContextWindow: 131_072,
DefaultMaxTokens: 20_000,
},
XAIGrok3MiniBeta: {
ID: XAIGrok3MiniBeta,
Name: "Grok3 Mini Beta",
Provider: ProviderXAI,
APIModel: "grok-3-mini-beta",
CostPer1MIn: 0.3,
CostPer1MInCached: 0,
CostPer1MOut: 0.5,
CostPer1MOutCached: 0,
ContextWindow: 131_072,
DefaultMaxTokens: 20_000,
},
XAIGrok3FastBeta: {
ID: XAIGrok3FastBeta,
Name: "Grok3 Fast Beta",
Provider: ProviderXAI,
APIModel: "grok-3-fast-beta",
CostPer1MIn: 5,
CostPer1MInCached: 0,
CostPer1MOut: 25,
CostPer1MOutCached: 0,
ContextWindow: 131_072,
DefaultMaxTokens: 20_000,
},
XAiGrok3MiniFastBeta: {
ID: XAiGrok3MiniFastBeta,
Name: "Grok3 Mini Fast Beta",
Provider: ProviderXAI,
APIModel: "grok-3-mini-fast-beta",
CostPer1MIn: 0.6,
CostPer1MInCached: 0,
CostPer1MOut: 4.0,
CostPer1MOutCached: 0,
ContextWindow: 131_072,
DefaultMaxTokens: 20_000,
},
}

View File

@@ -8,23 +8,23 @@ import (
"runtime"
"time"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/llm/models"
"github.com/opencode-ai/opencode/internal/llm/tools"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/llm/models"
"github.com/sst/opencode/internal/llm/tools"
)
func CoderPrompt(provider models.ModelProvider) string {
basePrompt := baseAnthropicCoderPrompt
func PrimaryPrompt(provider models.ModelProvider) string {
basePrompt := baseAnthropicPrimaryPrompt
switch provider {
case models.ProviderOpenAI:
basePrompt = baseOpenAICoderPrompt
basePrompt = baseOpenAIPrimaryPrompt
}
envInfo := getEnvironmentInfo()
return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
}
const baseOpenAICoderPrompt = `
const baseOpenAIPrimaryPrompt = `
You are operating as and within the OpenCode CLI, a terminal-based agentic coding assistant built by OpenAI. It wraps OpenAI models to enable natural language interaction with a local codebase. You are expected to be precise, safe, and helpful.
You can:
@@ -71,7 +71,7 @@ You MUST adhere to the following criteria when executing the task:
- Remember the user does not see the full output of tools
`
const baseAnthropicCoderPrompt = `You are OpenCode, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.
const baseAnthropicPrimaryPrompt = `You are OpenCode, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.
IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames directory structure.
@@ -81,7 +81,7 @@ If the current working directory contains a file called OpenCode.md, it will be
2. Recording the user's code style preferences (naming conventions, preferred libraries, etc.)
3. Maintaining useful information about the codebase structure and organization
When you spend time searching for commands to typecheck, lint, build, or test, you should ask the user if it's okay to add those commands to OpenCode.md. Similarly, when learning about code style preferences or important codebase information, ask if it's okay to add that to OpenCode.md so you can remember it for next time.
When you spend time searching for commands to typecheck, lint, build, or test, you should ask the user if it's okay to add those commands to CONTEXT.md. Similarly, when learning about code style preferences or important codebase information, ask if it's okay to add that to CONTEXT.md so you can remember it for next time.
# Tone and style
You should be concise, direct, and to the point. When you run a non-trivial bash command, you should explain what the command does and why you are running it, to make sure the user understands what you are doing (this is especially important when you are running a command that will make changes to the user's system).
@@ -103,7 +103,7 @@ assistant: 4
<example>
user: is 11 a prime number?
assistant: true
assistant: yes
</example>
<example>

View File

@@ -7,16 +7,16 @@ import (
"strings"
"sync"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/llm/models"
"github.com/opencode-ai/opencode/internal/logging"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/llm/models"
"log/slog"
)
func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) string {
basePrompt := ""
switch agentName {
case config.AgentCoder:
basePrompt = CoderPrompt(provider)
case config.AgentPrimary:
basePrompt = PrimaryPrompt(provider)
case config.AgentTitle:
basePrompt = TitlePrompt(provider)
case config.AgentTask:
@@ -25,10 +25,10 @@ func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) s
basePrompt = "You are a helpful assistant"
}
if agentName == config.AgentCoder || agentName == config.AgentTask {
if agentName == config.AgentPrimary || agentName == config.AgentTask {
// Add context from project-specific instruction files if they exist
contextContent := getContextFromPaths()
logging.Debug("Context content", "Context", contextContent)
slog.Debug("Context content", "Context", contextContent)
if contextContent != "" {
return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent)
}

View File

@@ -0,0 +1,61 @@
package prompt
import (
"fmt"
"log/slog"
"os"
"path/filepath"
"testing"
"github.com/sst/opencode/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetContextFromPaths(t *testing.T) {
t.Parallel()
lvl := new(slog.LevelVar)
lvl.Set(slog.LevelDebug)
tmpDir := t.TempDir()
_, err := config.Load(tmpDir, false, lvl)
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
cfg := config.Get()
cfg.WorkingDir = tmpDir
cfg.ContextPaths = []string{
"file.txt",
"directory/",
}
testFiles := []string{
"file.txt",
"directory/file_a.txt",
"directory/file_b.txt",
"directory/file_c.txt",
}
createTestFiles(t, tmpDir, testFiles)
context := getContextFromPaths()
expectedContext := fmt.Sprintf("# From:%s/file.txt\nfile.txt: test content\n# From:%s/directory/file_a.txt\ndirectory/file_a.txt: test content\n# From:%s/directory/file_b.txt\ndirectory/file_b.txt: test content\n# From:%s/directory/file_c.txt\ndirectory/file_c.txt: test content", tmpDir, tmpDir, tmpDir, tmpDir)
assert.Equal(t, expectedContext, context)
}
func createTestFiles(t *testing.T, tmpDir string, testFiles []string) {
t.Helper()
for _, path := range testFiles {
fullPath := filepath.Join(tmpDir, path)
if path[len(path)-1] == '/' {
err := os.MkdirAll(fullPath, 0755)
require.NoError(t, err)
} else {
dir := filepath.Dir(fullPath)
err := os.MkdirAll(dir, 0755)
require.NoError(t, err)
err = os.WriteFile(fullPath, []byte(path+": test content"), 0644)
require.NoError(t, err)
}
}
}

View File

@@ -3,7 +3,7 @@ package prompt
import (
"fmt"
"github.com/opencode-ai/opencode/internal/llm/models"
"github.com/sst/opencode/internal/llm/models"
)
func TaskPrompt(_ models.ModelProvider) string {

View File

@@ -1,6 +1,6 @@
package prompt
import "github.com/opencode-ai/opencode/internal/llm/models"
import "github.com/sst/opencode/internal/llm/models"
func TitlePrompt(_ models.ModelProvider) string {
return `you will generate a short title based on the first message a user begins a conversation with
@@ -8,5 +8,6 @@ func TitlePrompt(_ models.ModelProvider) string {
- the title should be a summary of the user's message
- it should be one line long
- do not use quotes or colons
- the entire text you return will be used as the title`
- the entire text you return will be used as the title
- never return anything that is more than one sentence (one line) long`
}

View File

@@ -12,10 +12,12 @@ import (
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/bedrock"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/llm/tools"
"github.com/opencode-ai/opencode/internal/logging"
"github.com/opencode-ai/opencode/internal/message"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/llm/models"
"github.com/sst/opencode/internal/llm/tools"
"github.com/sst/opencode/internal/message"
"github.com/sst/opencode/internal/status"
"log/slog"
)
type anthropicOptions struct {
@@ -70,18 +72,29 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
Type: "ephemeral",
}
}
anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))
var contentBlocks []anthropic.ContentBlockParamUnion
contentBlocks = append(contentBlocks, content)
for _, binaryContent := range msg.BinaryContent() {
base64Image := binaryContent.String(models.ProviderAnthropic)
imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
contentBlocks = append(contentBlocks, imageBlock)
}
anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(contentBlocks...))
case message.Assistant:
blocks := []anthropic.ContentBlockParamUnion{}
if msg.Content().String() != "" {
content := anthropic.NewTextBlock(msg.Content().String())
if cache && !a.options.disableCache {
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
if msg.Content() != nil {
content := msg.Content().String()
if strings.TrimSpace(content) != "" {
block := anthropic.NewTextBlock(content)
if cache && !a.options.disableCache {
block.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
}
blocks = append(blocks, block)
}
blocks = append(blocks, content)
}
for _, toolCall := range msg.ToolCalls() {
@@ -94,7 +107,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
}
if len(blocks) == 0 {
logging.Warn("There is a message without content, investigate, this should not happen")
slog.Warn("There is a message without content, investigate, this should not happen")
continue
}
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
@@ -196,9 +209,10 @@ func (a *anthropicClient) send(ctx context.Context, messages []message.Message,
preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
cfg := config.Get()
if cfg.Debug {
// jsonData, _ := json.Marshal(preparedMessages)
// logging.Debug("Prepared messages", "messages", string(jsonData))
jsonData, _ := json.Marshal(preparedMessages)
slog.Debug("Prepared messages", "messages", string(jsonData))
}
attempts := 0
for {
attempts++
@@ -208,12 +222,13 @@ func (a *anthropicClient) send(ctx context.Context, messages []message.Message,
)
// If there is an error we are going to see if we can retry the call
if err != nil {
slog.Error("Error in Anthropic API call", "error", err)
retry, after, retryErr := a.shouldRetry(attempts, err)
if retryErr != nil {
return nil, retryErr
}
if retry {
logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
select {
case <-ctx.Done():
return nil, ctx.Err()
@@ -243,8 +258,8 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
cfg := config.Get()
if cfg.Debug {
// jsonData, _ := json.Marshal(preparedMessages)
// logging.Debug("Prepared messages", "messages", string(jsonData))
jsonData, _ := json.Marshal(preparedMessages)
slog.Debug("Prepared messages", "messages", string(jsonData))
}
attempts := 0
eventChan := make(chan ProviderEvent)
@@ -262,7 +277,7 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
event := anthropicStream.Current()
err := accumulatedMessage.Accumulate(event)
if err != nil {
logging.Warn("Error accumulating message", "error", err)
slog.Warn("Error accumulating message", "error", err)
continue
}
@@ -351,7 +366,7 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
return
}
if retry {
logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
select {
case <-ctx.Done():
// context cancelled

View File

@@ -7,8 +7,8 @@ import (
"os"
"strings"
"github.com/opencode-ai/opencode/internal/llm/tools"
"github.com/opencode-ai/opencode/internal/message"
"github.com/sst/opencode/internal/llm/tools"
"github.com/sst/opencode/internal/message"
)
type bedrockOptions struct {
@@ -55,7 +55,7 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {
if strings.Contains(string(opts.model.APIModel), "anthropic") {
// Create Anthropic client with Bedrock configuration
anthropicOpts := opts
anthropicOpts.anthropicOptions = append(anthropicOpts.anthropicOptions,
anthropicOpts.anthropicOptions = append(anthropicOpts.anthropicOptions,
WithAnthropicBedrock(true),
WithAnthropicDisableCache(),
)
@@ -84,7 +84,7 @@ func (b *bedrockClient) send(ctx context.Context, messages []message.Message, to
func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
eventChan := make(chan ProviderEvent)
if b.childProvider == nil {
go func() {
eventChan <- ProviderEvent{
@@ -95,6 +95,6 @@ func (b *bedrockClient) stream(ctx context.Context, messages []message.Message,
}()
return eventChan
}
return b.childProvider.stream(ctx, messages, tools)
}
}

View File

@@ -9,14 +9,13 @@ import (
"strings"
"time"
"github.com/google/generative-ai-go/genai"
"github.com/google/uuid"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/llm/tools"
"github.com/opencode-ai/opencode/internal/logging"
"github.com/opencode-ai/opencode/internal/message"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/llm/tools"
"github.com/sst/opencode/internal/message"
"github.com/sst/opencode/internal/status"
"google.golang.org/genai"
"log/slog"
)
type geminiOptions struct {
@@ -39,9 +38,9 @@ func newGeminiClient(opts providerClientOptions) GeminiClient {
o(&geminiOpts)
}
client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey))
client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
if err != nil {
logging.Error("Failed to create Gemini client", "error", err)
slog.Error("Failed to create Gemini client", "error", err)
return nil
}
@@ -57,27 +56,37 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont
for _, msg := range messages {
switch msg.Role {
case message.User:
var parts []*genai.Part
parts = append(parts, &genai.Part{Text: msg.Content().String()})
for _, binaryContent := range msg.BinaryContent() {
imageFormat := strings.Split(binaryContent.MIMEType, "/")
parts = append(parts, &genai.Part{InlineData: &genai.Blob{
MIMEType: imageFormat[1],
Data: binaryContent.Data,
}})
}
history = append(history, &genai.Content{
Parts: []genai.Part{genai.Text(msg.Content().String())},
Parts: parts,
Role: "user",
})
case message.Assistant:
content := &genai.Content{
Role: "model",
Parts: []genai.Part{},
Parts: []*genai.Part{},
}
if msg.Content().String() != "" {
content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
content.Parts = append(content.Parts, &genai.Part{Text: msg.Content().String()})
}
if len(msg.ToolCalls()) > 0 {
for _, call := range msg.ToolCalls() {
args, _ := parseJsonToMap(call.Input)
content.Parts = append(content.Parts, genai.FunctionCall{
Name: call.Name,
Args: args,
content.Parts = append(content.Parts, &genai.Part{
FunctionCall: &genai.FunctionCall{
Name: call.Name,
Args: args,
},
})
}
}
@@ -105,10 +114,14 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont
}
history = append(history, &genai.Content{
Parts: []genai.Part{genai.FunctionResponse{
Name: toolCall.Name,
Response: response,
}},
Parts: []*genai.Part{
{
FunctionResponse: &genai.FunctionResponse{
Name: toolCall.Name,
Response: response,
},
},
},
Role: "function",
})
}
@@ -152,37 +165,35 @@ func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishRea
}
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{
genai.Text(g.providerOptions.systemMessage),
},
}
// Convert tools
if len(tools) > 0 {
model.Tools = g.convertTools(tools)
}
// Convert messages
geminiMessages := g.convertMessages(messages)
cfg := config.Get()
if cfg.Debug {
jsonData, _ := json.Marshal(geminiMessages)
logging.Debug("Prepared messages", "messages", string(jsonData))
slog.Debug("Prepared messages", "messages", string(jsonData))
}
history := geminiMessages[:len(geminiMessages)-1] // All but last message
lastMsg := geminiMessages[len(geminiMessages)-1]
chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{
MaxOutputTokens: int32(g.providerOptions.maxTokens),
SystemInstruction: &genai.Content{
Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
},
Tools: g.convertTools(tools),
}, history)
attempts := 0
for {
attempts++
var toolCalls []message.ToolCall
chat := model.StartChat()
chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
lastMsg := geminiMessages[len(geminiMessages)-1]
resp, err := chat.SendMessage(ctx, lastMsg.Parts...)
var lastMsgParts []genai.Part
for _, part := range lastMsg.Parts {
lastMsgParts = append(lastMsgParts, *part)
}
resp, err := chat.SendMessage(ctx, lastMsgParts...)
// If there is an error we are going to see if we can retry the call
if err != nil {
retry, after, retryErr := g.shouldRetry(attempts, err)
@@ -190,7 +201,7 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
return nil, retryErr
}
if retry {
logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
select {
case <-ctx.Done():
return nil, ctx.Err()
@@ -205,15 +216,15 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
for _, part := range resp.Candidates[0].Content.Parts {
switch p := part.(type) {
case genai.Text:
content = string(p)
case genai.FunctionCall:
switch {
case part.Text != "":
content = string(part.Text)
case part.FunctionCall != nil:
id := "call_" + uuid.New().String()
args, _ := json.Marshal(p.Args)
args, _ := json.Marshal(part.FunctionCall.Args)
toolCalls = append(toolCalls, message.ToolCall{
ID: id,
Name: p.Name,
Name: part.FunctionCall.Name,
Input: string(args),
Type: "function",
Finished: true,
@@ -239,27 +250,25 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
}
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{
genai.Text(g.providerOptions.systemMessage),
},
}
// Convert tools
if len(tools) > 0 {
model.Tools = g.convertTools(tools)
}
// Convert messages
geminiMessages := g.convertMessages(messages)
cfg := config.Get()
if cfg.Debug {
jsonData, _ := json.Marshal(geminiMessages)
logging.Debug("Prepared messages", "messages", string(jsonData))
slog.Debug("Prepared messages", "messages", string(jsonData))
}
history := geminiMessages[:len(geminiMessages)-1] // All but last message
lastMsg := geminiMessages[len(geminiMessages)-1]
chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{
MaxOutputTokens: int32(g.providerOptions.maxTokens),
SystemInstruction: &genai.Content{
Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
},
Tools: g.convertTools(tools),
}, history)
attempts := 0
eventChan := make(chan ProviderEvent)
@@ -268,11 +277,6 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
for {
attempts++
chat := model.StartChat()
chat.History = geminiMessages[:len(geminiMessages)-1]
lastMsg := geminiMessages[len(geminiMessages)-1]
iter := chat.SendMessageStream(ctx, lastMsg.Parts...)
currentContent := ""
toolCalls := []message.ToolCall{}
@@ -280,11 +284,12 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
eventChan <- ProviderEvent{Type: EventContentStart}
for {
resp, err := iter.Next()
if err == iterator.Done {
break
}
var lastMsgParts []genai.Part
for _, part := range lastMsg.Parts {
lastMsgParts = append(lastMsgParts, *part)
}
for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
if err != nil {
retry, after, retryErr := g.shouldRetry(attempts, err)
if retryErr != nil {
@@ -292,7 +297,7 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
return
}
if retry {
logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
select {
case <-ctx.Done():
if ctx.Err() != nil {
@@ -313,9 +318,9 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
for _, part := range resp.Candidates[0].Content.Parts {
switch p := part.(type) {
case genai.Text:
delta := string(p)
switch {
case part.Text != "":
delta := string(part.Text)
if delta != "" {
eventChan <- ProviderEvent{
Type: EventContentDelta,
@@ -323,12 +328,12 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
}
currentContent += delta
}
case genai.FunctionCall:
case part.FunctionCall != nil:
id := "call_" + uuid.New().String()
args, _ := json.Marshal(p.Args)
args, _ := json.Marshal(part.FunctionCall.Args)
newCall := message.ToolCall{
ID: id,
Name: p.Name,
Name: part.FunctionCall.Name,
Input: string(args),
Type: "function",
Finished: true,
@@ -416,12 +421,12 @@ func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
for _, part := range resp.Candidates[0].Content.Parts {
if funcCall, ok := part.(genai.FunctionCall); ok {
if part.FunctionCall != nil {
id := "call_" + uuid.New().String()
args, _ := json.Marshal(funcCall.Args)
args, _ := json.Marshal(part.FunctionCall.Args)
toolCalls = append(toolCalls, message.ToolCall{
ID: id,
Name: funcCall.Name,
Name: part.FunctionCall.Name,
Input: string(args),
Type: "function",
})

View File

@@ -11,10 +11,12 @@ import (
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/shared"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/llm/tools"
"github.com/opencode-ai/opencode/internal/logging"
"github.com/opencode-ai/opencode/internal/message"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/llm/models"
"github.com/sst/opencode/internal/llm/tools"
"github.com/sst/opencode/internal/message"
"github.com/sst/opencode/internal/status"
"log/slog"
)
type openaiOptions struct {
@@ -71,7 +73,17 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
for _, msg := range messages {
switch msg.Role {
case message.User:
openaiMessages = append(openaiMessages, openai.UserMessage(msg.Content().String()))
var content []openai.ChatCompletionContentPartUnionParam
textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
for _, binaryContent := range msg.BinaryContent() {
imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(models.ProviderOpenAI)}
imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
}
openaiMessages = append(openaiMessages, openai.UserMessage(content))
case message.Assistant:
assistantMsg := openai.ChatCompletionAssistantMessageParam{
@@ -171,6 +183,14 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
params.MaxTokens = openai.Int(o.providerOptions.maxTokens)
}
if o.providerOptions.model.Provider == models.ProviderOpenRouter {
params.WithExtraFields(map[string]any{
"provider": map[string]any{
"require_parameters": true,
},
})
}
return params
}
@@ -179,7 +199,7 @@ func (o *openaiClient) send(ctx context.Context, messages []message.Message, too
cfg := config.Get()
if cfg.Debug {
jsonData, _ := json.Marshal(params)
logging.Debug("Prepared messages", "messages", string(jsonData))
slog.Debug("Prepared messages", "messages", string(jsonData))
}
attempts := 0
for {
@@ -195,7 +215,7 @@ func (o *openaiClient) send(ctx context.Context, messages []message.Message, too
return nil, retryErr
}
if retry {
logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
select {
case <-ctx.Done():
return nil, ctx.Err()
@@ -236,7 +256,7 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
cfg := config.Get()
if cfg.Debug {
jsonData, _ := json.Marshal(params)
logging.Debug("Prepared messages", "messages", string(jsonData))
slog.Debug("Prepared messages", "messages", string(jsonData))
}
attempts := 0
@@ -258,15 +278,6 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
chunk := openaiStream.Current()
acc.AddChunk(chunk)
if tool, ok := acc.JustFinishedToolCall(); ok {
toolCalls = append(toolCalls, message.ToolCall{
ID: tool.Id,
Name: tool.Name,
Input: tool.Arguments,
Type: "function",
})
}
for _, choice := range chunk.Choices {
if choice.Delta.Content != "" {
eventChan <- ProviderEvent{
@@ -282,7 +293,9 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
if err == nil || errors.Is(err, io.EOF) {
// Stream completed successfully
finishReason := o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason))
if len(acc.ChatCompletion.Choices[0].Message.ToolCalls) > 0 {
toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
}
if len(toolCalls) > 0 {
finishReason = message.FinishReasonToolUse
}
@@ -308,7 +321,7 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
return
}
if retry {
logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
select {
case <-ctx.Done():
// context cancelled
@@ -414,7 +427,7 @@ func WithReasoningEffort(effort string) OpenAIOption {
case "low", "medium", "high":
defaultReasoningEffort = effort
default:
logging.Warn("Invalid reasoning effort, using default: medium")
slog.Warn("Invalid reasoning effort, using default: medium")
}
options.reasoningEffort = defaultReasoningEffort
}

View File

@@ -4,9 +4,10 @@ import (
"context"
"fmt"
"github.com/opencode-ai/opencode/internal/llm/models"
"github.com/opencode-ai/opencode/internal/llm/tools"
"github.com/opencode-ai/opencode/internal/message"
"github.com/sst/opencode/internal/llm/models"
"github.com/sst/opencode/internal/llm/tools"
"github.com/sst/opencode/internal/message"
"log/slog"
)
type EventType string
@@ -55,6 +56,8 @@ type Provider interface {
StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
Model() models.Model
MaxTokens() int64
}
type providerClientOptions struct {
@@ -132,6 +135,15 @@ func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption
options: clientOptions,
client: newOpenAIClient(clientOptions),
}, nil
case models.ProviderXAI:
clientOptions.openaiOptions = append(clientOptions.openaiOptions,
WithOpenAIBaseURL("https://api.x.ai/v1"),
)
return &baseProvider[OpenAIClient]{
options: clientOptions,
client: newOpenAIClient(clientOptions),
}, nil
case models.ProviderMock:
// TODO: implement mock client for test
panic("not implemented")
@@ -152,16 +164,55 @@ func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []m
func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
messages = p.cleanMessages(messages)
return p.client.send(ctx, messages, tools)
response, err := p.client.send(ctx, messages, tools)
if err == nil && response != nil {
slog.Debug("API request token usage",
"model", p.options.model.Name,
"input_tokens", response.Usage.InputTokens,
"output_tokens", response.Usage.OutputTokens,
"cache_creation_tokens", response.Usage.CacheCreationTokens,
"cache_read_tokens", response.Usage.CacheReadTokens,
"total_tokens", response.Usage.InputTokens+response.Usage.OutputTokens)
}
return response, err
}
func (p *baseProvider[C]) Model() models.Model {
return p.options.model
}
func (p *baseProvider[C]) MaxTokens() int64 {
return p.options.maxTokens
}
func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
messages = p.cleanMessages(messages)
return p.client.stream(ctx, messages, tools)
eventChan := p.client.stream(ctx, messages, tools)
// Create a new channel to intercept events
wrappedChan := make(chan ProviderEvent)
go func() {
defer close(wrappedChan)
for event := range eventChan {
// Pass the event through
wrappedChan <- event
// Log token usage when we get the complete event
if event.Type == EventComplete && event.Response != nil {
slog.Debug("API streaming request token usage",
"model", p.options.model.Name,
"input_tokens", event.Response.Usage.InputTokens,
"output_tokens", event.Response.Usage.OutputTokens,
"cache_creation_tokens", event.Response.Usage.CacheCreationTokens,
"cache_read_tokens", event.Response.Usage.CacheReadTokens,
"total_tokens", event.Response.Usage.InputTokens+event.Response.Usage.OutputTokens)
}
}
}()
return wrappedChan
}
func WithAPIKey(apiKey string) ProviderClientOption {

View File

@@ -7,9 +7,9 @@ import (
"strings"
"time"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/llm/tools/shell"
"github.com/opencode-ai/opencode/internal/permission"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/llm/tools/shell"
"github.com/sst/opencode/internal/permission"
)
type BashParams struct {
@@ -268,6 +268,7 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
}
if !isSafeReadOnly {
p := b.permissions.Request(
ctx,
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: config.WorkingDirectory(),

View File

@@ -9,12 +9,12 @@ import (
"strings"
"time"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/diff"
"github.com/opencode-ai/opencode/internal/history"
"github.com/opencode-ai/opencode/internal/logging"
"github.com/opencode-ai/opencode/internal/lsp"
"github.com/opencode-ai/opencode/internal/permission"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/diff"
"github.com/sst/opencode/internal/history"
"github.com/sst/opencode/internal/lsp"
"github.com/sst/opencode/internal/permission"
"log/slog"
)
type EditParams struct {
@@ -37,7 +37,7 @@ type EditResponseMetadata struct {
type editTool struct {
lspClients map[string]*lsp.Client
permissions permission.Service
files history.Service
history history.Service
}
const (
@@ -95,7 +95,7 @@ func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Servi
return &editTool{
lspClients: lspClients,
permissions: permissions,
files: files,
history: files,
}
}
@@ -202,6 +202,7 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string)
permissionPath = rootDir
}
p := e.permissions.Request(
ctx,
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: permissionPath,
@@ -224,17 +225,17 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string)
}
// File can't be in the history so we create a new file history
_, err = e.files.Create(ctx, sessionID, filePath, "")
_, err = e.history.Create(ctx, sessionID, filePath, "")
if err != nil {
// Log error but don't fail the operation
return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
}
// Add the new content to the file history
_, err = e.files.CreateVersion(ctx, sessionID, filePath, content)
_, err = e.history.CreateVersion(ctx, sessionID, filePath, content)
if err != nil {
// Log error but don't fail the operation
logging.Debug("Error creating file history version", "error", err)
slog.Debug("Error creating file history version", "error", err)
}
recordFileWrite(filePath)
@@ -313,6 +314,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string
permissionPath = rootDir
}
p := e.permissions.Request(
ctx,
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: permissionPath,
@@ -335,9 +337,9 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string
}
// Check if file exists in history
file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID)
file, err := e.history.GetLatestByPathAndSession(ctx, filePath, sessionID)
if err != nil {
_, err = e.files.Create(ctx, sessionID, filePath, oldContent)
_, err = e.history.Create(ctx, sessionID, filePath, oldContent)
if err != nil {
// Log error but don't fail the operation
return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
@@ -345,15 +347,15 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string
}
if file.Content != oldContent {
// User Manually changed the content store an intermediate version
_, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent)
_, err = e.history.CreateVersion(ctx, sessionID, filePath, oldContent)
if err != nil {
logging.Debug("Error creating file history version", "error", err)
slog.Debug("Error creating file history version", "error", err)
}
}
// Store the new version
_, err = e.files.CreateVersion(ctx, sessionID, filePath, "")
_, err = e.history.CreateVersion(ctx, sessionID, filePath, "")
if err != nil {
logging.Debug("Error creating file history version", "error", err)
slog.Debug("Error creating file history version", "error", err)
}
recordFileWrite(filePath)
@@ -433,6 +435,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
permissionPath = rootDir
}
p := e.permissions.Request(
ctx,
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: permissionPath,
@@ -455,9 +458,9 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
}
// Check if file exists in history
file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID)
file, err := e.history.GetLatestByPathAndSession(ctx, filePath, sessionID)
if err != nil {
_, err = e.files.Create(ctx, sessionID, filePath, oldContent)
_, err = e.history.Create(ctx, sessionID, filePath, oldContent)
if err != nil {
// Log error but don't fail the operation
return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
@@ -465,15 +468,15 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
}
if file.Content != oldContent {
// User Manually changed the content store an intermediate version
_, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent)
_, err = e.history.CreateVersion(ctx, sessionID, filePath, oldContent)
if err != nil {
logging.Debug("Error creating file history version", "error", err)
slog.Debug("Error creating file history version", "error", err)
}
}
// Store the new version
_, err = e.files.CreateVersion(ctx, sessionID, filePath, newContent)
_, err = e.history.CreateVersion(ctx, sessionID, filePath, newContent)
if err != nil {
logging.Debug("Error creating file history version", "error", err)
slog.Debug("Error creating file history version", "error", err)
}
recordFileWrite(filePath)

View File

@@ -11,8 +11,8 @@ import (
md "github.com/JohannesKaufmann/html-to-markdown"
"github.com/PuerkitoBio/goquery"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/permission"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/permission"
)
type FetchParams struct {
@@ -122,6 +122,7 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
}
p := t.permissions.Request(
ctx,
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: config.WorkingDirectory(),

View File

@@ -14,7 +14,7 @@ import (
"time"
"github.com/bmatcuk/doublestar/v4"
"github.com/opencode-ai/opencode/internal/config"
"github.com/sst/opencode/internal/config"
)
const (

View File

@@ -14,7 +14,7 @@ import (
"strings"
"time"
"github.com/opencode-ai/opencode/internal/config"
"github.com/sst/opencode/internal/config"
)
type GrepParams struct {

View File

@@ -8,7 +8,7 @@ import (
"path/filepath"
"strings"
"github.com/opencode-ai/opencode/internal/config"
"github.com/sst/opencode/internal/config"
)
type LSParams struct {

View File

@@ -83,19 +83,19 @@ func TestLsTool_Run(t *testing.T) {
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
// Check that visible directories and files are included
assert.Contains(t, response.Content, "dir1")
assert.Contains(t, response.Content, "dir2")
assert.Contains(t, response.Content, "dir3")
assert.Contains(t, response.Content, "file1.txt")
assert.Contains(t, response.Content, "file2.txt")
// Check that hidden files and directories are not included
assert.NotContains(t, response.Content, ".hidden_dir")
assert.NotContains(t, response.Content, ".hidden_file.txt")
assert.NotContains(t, response.Content, ".hidden_root_file.txt")
// Check that __pycache__ is not included
assert.NotContains(t, response.Content, "__pycache__")
})
@@ -122,7 +122,7 @@ func TestLsTool_Run(t *testing.T) {
t.Run("handles empty path parameter", func(t *testing.T) {
// For this test, we need to mock the config.WorkingDirectory function
// Since we can't easily do that, we'll just check that the response doesn't contain an error message
tool := NewLsTool()
params := LSParams{
Path: "",
@@ -138,7 +138,7 @@ func TestLsTool_Run(t *testing.T) {
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
// The response should either contain a valid directory listing or an error
// We'll just check that it's not empty
assert.NotEmpty(t, response.Content)
@@ -173,11 +173,11 @@ func TestLsTool_Run(t *testing.T) {
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
// The output format is a tree, so we need to check for specific patterns
// Check that file1.txt is not directly mentioned
assert.NotContains(t, response.Content, "- file1.txt")
// Check that dir1/ is not directly mentioned
assert.NotContains(t, response.Content, "- dir1/")
})
@@ -189,12 +189,12 @@ func TestLsTool_Run(t *testing.T) {
defer func() {
os.Chdir(origWd)
}()
// Change to a directory above the temp directory
parentDir := filepath.Dir(tempDir)
err = os.Chdir(parentDir)
require.NoError(t, err)
tool := NewLsTool()
params := LSParams{
Path: filepath.Base(tempDir),
@@ -210,7 +210,7 @@ func TestLsTool_Run(t *testing.T) {
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
// Should list the temp directory contents
assert.Contains(t, response.Content, "dir1")
assert.Contains(t, response.Content, "file1.txt")
@@ -291,22 +291,22 @@ func TestCreateFileTree(t *testing.T) {
}
tree := createFileTree(paths)
// Check the structure of the tree
assert.Len(t, tree, 1) // Should have one root node
// Check the root node
rootNode := tree[0]
assert.Equal(t, "path", rootNode.Name)
assert.Equal(t, "directory", rootNode.Type)
assert.Len(t, rootNode.Children, 1)
// Check the "to" node
toNode := rootNode.Children[0]
assert.Equal(t, "to", toNode.Name)
assert.Equal(t, "directory", toNode.Type)
assert.Len(t, toNode.Children, 3) // file1.txt, dir1, dir2
// Find the dir1 node
var dir1Node *TreeNode
for _, child := range toNode.Children {
@@ -315,7 +315,7 @@ func TestCreateFileTree(t *testing.T) {
break
}
}
require.NotNil(t, dir1Node)
assert.Equal(t, "directory", dir1Node.Type)
assert.Len(t, dir1Node.Children, 2) // file2.txt and subdir
@@ -354,9 +354,9 @@ func TestPrintTree(t *testing.T) {
Type: "file",
},
}
result := printTree(tree, "/root")
// Check the output format
assert.Contains(t, result, "- /root/")
assert.Contains(t, result, " - dir1/")
@@ -405,7 +405,7 @@ func TestListDirectory(t *testing.T) {
files, truncated, err := listDirectory(tempDir, []string{}, 1000)
require.NoError(t, err)
assert.False(t, truncated)
// Check that visible files and directories are included
containsPath := func(paths []string, target string) bool {
targetPath := filepath.Join(tempDir, target)
@@ -416,12 +416,12 @@ func TestListDirectory(t *testing.T) {
}
return false
}
assert.True(t, containsPath(files, "dir1"))
assert.True(t, containsPath(files, "file1.txt"))
assert.True(t, containsPath(files, "file2.txt"))
assert.True(t, containsPath(files, "dir1/file3.txt"))
// Check that hidden files and directories are not included
assert.False(t, containsPath(files, ".hidden_dir"))
assert.False(t, containsPath(files, ".hidden_file.txt"))
@@ -438,12 +438,12 @@ func TestListDirectory(t *testing.T) {
files, truncated, err := listDirectory(tempDir, []string{"*.txt"}, 1000)
require.NoError(t, err)
assert.False(t, truncated)
// Check that no .txt files are included
for _, file := range files {
assert.False(t, strings.HasSuffix(file, ".txt"), "Found .txt file: %s", file)
}
// But directories should still be included
containsDir := false
for _, file := range files {
@@ -454,4 +454,4 @@ func TestListDirectory(t *testing.T) {
}
assert.True(t, containsDir)
})
}
}

View File

@@ -0,0 +1,350 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/sst/opencode/internal/lsp"
"github.com/sst/opencode/internal/lsp/protocol"
"github.com/sst/opencode/internal/lsp/util"
)
// CodeActionParams are the JSON-decoded arguments for the codeAction tool.
// Positions arrive 1-based (editor convention) and are converted to the
// 0-based positions the LSP protocol expects before any request is made.
type CodeActionParams struct {
	FilePath  string `json:"file_path"`            // path of the file to inspect
	Line      int    `json:"line"`                 // 1-based line of the position
	Column    int    `json:"column"`               // 1-based column of the position
	EndLine   int    `json:"end_line,omitempty"`   // optional 1-based end line of a range
	EndColumn int    `json:"end_column,omitempty"` // optional 1-based end column of a range
	ActionID  int    `json:"action_id,omitempty"`  // 1-based index of a previously listed action to execute
	LspName   string `json:"lsp_name,omitempty"`   // name of the LSP server that offered that action
}

// codeActionTool lists and executes LSP code actions (quick fixes,
// refactorings, source actions) via the configured language servers.
type codeActionTool struct {
	lspClients map[string]*lsp.Client // keyed by LSP server name
}

const (
	// CodeActionToolName is the registry name of this tool.
	CodeActionToolName = "codeAction"
	// codeActionDescription is the long-form usage text surfaced to the model.
	codeActionDescription = `Get available code actions at a specific position or range in a file.
WHEN TO USE THIS TOOL:
- Use when you need to find available fixes or refactorings for code issues
- Helpful for resolving errors, warnings, or improving code quality
- Great for discovering automated code transformations
HOW TO USE:
- Provide the path to the file containing the code
- Specify the line number (1-based) where the action should be applied
- Specify the column number (1-based) where the action should be applied
- Optionally specify end_line and end_column to define a range
- Results show available code actions with their titles and kinds
TO EXECUTE A CODE ACTION:
- After getting the list of available actions, call the tool again with the same parameters
- Add action_id parameter with the number of the action you want to execute (e.g., 1 for the first action)
- Add lsp_name parameter with the name of the LSP server that provided the action
FEATURES:
- Finds quick fixes for errors and warnings
- Discovers available refactorings
- Shows code organization actions
- Returns detailed information about each action
- Can execute selected code actions
LIMITATIONS:
- Requires a functioning LSP server for the file type
- May not work for all code issues depending on LSP capabilities
- Results depend on the accuracy of the LSP server
TIPS:
- Use in conjunction with Diagnostics tool to find issues that can be fixed
- First call without action_id to see available actions, then call again with action_id to execute
`
)
// NewCodeActionTool builds the codeAction tool backed by the given LSP clients.
func NewCodeActionTool(lspClients map[string]*lsp.Client) BaseTool {
	tool := &codeActionTool{lspClients: lspClients}
	return tool
}
// Info describes the codeAction tool: its registry name, long usage text, and
// the JSON schema of its parameters. file_path, line, and column are required;
// the range end, action_id, and lsp_name are optional.
func (b *codeActionTool) Info() ToolInfo {
	return ToolInfo{
		Name:        CodeActionToolName,
		Description: codeActionDescription,
		Parameters: map[string]any{
			"file_path": map[string]any{
				"type":        "string",
				"description": "The path to the file containing the code",
			},
			"line": map[string]any{
				"type":        "integer",
				"description": "The line number (1-based) where the action should be applied",
			},
			"column": map[string]any{
				"type":        "integer",
				"description": "The column number (1-based) where the action should be applied",
			},
			"end_line": map[string]any{
				"type":        "integer",
				"description": "The ending line number (1-based) for a range (optional)",
			},
			"end_column": map[string]any{
				"type":        "integer",
				"description": "The ending column number (1-based) for a range (optional)",
			},
			"action_id": map[string]any{
				"type":        "integer",
				"description": "The ID of the code action to execute (optional)",
			},
			"lsp_name": map[string]any{
				"type":        "string",
				"description": "The name of the LSP server that provided the action (optional)",
			},
		},
		Required: []string{"file_path", "line", "column"},
	}
}
// Run lists code actions at a position/range, or executes one of them when
// both action_id and lsp_name are supplied.
//
// The input is JSON matching CodeActionParams. Line/column arrive 1-based and
// are converted to the 0-based positions the LSP protocol expects; values <= 0
// are clamped to 0 rather than underflowing.
func (b *codeActionTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
	var params CodeActionParams
	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
	}

	lsps := b.lspClients
	if len(lsps) == 0 {
		return NewTextResponse("\nLSP clients are still initializing. Code actions will be available once they're ready.\n"), nil
	}

	// Ensure the file is open in every LSP server before querying.
	notifyLspOpenFile(ctx, params.FilePath, lsps)

	// Convert 1-based line/column to 0-based for the LSP protocol.
	line := max(0, params.Line-1)
	column := max(0, params.Column-1)

	// The range collapses to a single position unless an explicit end is given.
	endLine := line
	endColumn := column
	if params.EndLine > 0 {
		endLine = max(0, params.EndLine-1)
	}
	if params.EndColumn > 0 {
		endColumn = max(0, params.EndColumn-1)
	}

	if params.ActionID > 0 {
		// BUG FIX: previously an action_id without lsp_name silently fell
		// through to listing the actions again. Executing requires knowing
		// which server offered the action, so report the missing argument.
		if params.LspName == "" {
			return NewTextErrorResponse("action_id was provided without lsp_name; pass the lsp_name reported when the actions were listed"), nil
		}
		return executeCodeAction(ctx, params.FilePath, line, column, endLine, endColumn, params.ActionID, params.LspName, lsps)
	}

	// Otherwise, just list the available actions.
	output := getCodeActions(ctx, params.FilePath, line, column, endLine, endColumn, lsps)
	return NewTextResponse(output), nil
}
// getCodeActions asks every connected LSP server for code actions covering the
// given (0-based) range in filePath and returns a human-readable summary, one
// section per server.
func getCodeActions(ctx context.Context, filePath string, line, column, endLine, endColumn int, lsps map[string]*lsp.Client) string {
	// Request every action kind we know how to present.
	requestedKinds := []protocol.CodeActionKind{
		protocol.QuickFix,
		protocol.Refactor,
		protocol.RefactorExtract,
		protocol.RefactorInline,
		protocol.RefactorRewrite,
		protocol.Source,
		protocol.SourceOrganizeImports,
		protocol.SourceFixAll,
	}

	var out []string
	for name, client := range lsps {
		request := protocol.CodeActionParams{
			TextDocument: protocol.TextDocumentIdentifier{
				URI: protocol.DocumentUri(fmt.Sprintf("file://%s", filePath)),
			},
			Range: protocol.Range{
				Start: protocol.Position{Line: uint32(line), Character: uint32(column)},
				End:   protocol.Position{Line: uint32(endLine), Character: uint32(endColumn)},
			},
			Context: protocol.CodeActionContext{Only: requestedKinds},
		}

		actions, err := client.CodeAction(ctx, request)
		if err != nil {
			out = append(out, fmt.Sprintf("Error from %s: %s", name, err))
			continue
		}
		if len(actions) == 0 {
			out = append(out, fmt.Sprintf("No code actions found by %s", name))
			continue
		}

		out = append(out, fmt.Sprintf("Code actions found by %s:", name))
		for i, action := range actions {
			out = append(out, formatCodeAction(action, i+1))
		}
	}

	if len(out) == 0 {
		return "No code actions found at the specified position."
	}
	return strings.Join(out, "\n")
}
// formatCodeAction renders one code-action result as a numbered list entry.
//
// A result element is either a full CodeAction (title, kind, optional edit,
// command, and the diagnostics it fixes) or a bare Command; anything else is
// reported as an unknown type.
func formatCodeAction(action protocol.Or_Result_textDocument_codeAction_Item0_Elem, index int) string {
	switch v := action.Value.(type) {
	case protocol.CodeAction:
		kind := "Unknown"
		if v.Kind != "" {
			kind = string(v.Kind)
		}
		var details []string
		if v.Edit != nil {
			// BUG FIX: a server may populate Changes (URI -> edits map) and/or
			// DocumentChanges (ordered edits); the old code let the
			// DocumentChanges count overwrite the Changes count. Sum both.
			// len() is safe on nil maps/slices, so no nil checks are needed.
			numChanges := len(v.Edit.Changes) + len(v.Edit.DocumentChanges)
			details = append(details, fmt.Sprintf("Edits: %d changes", numChanges))
		}
		if v.Command != nil {
			details = append(details, fmt.Sprintf("Command: %s", v.Command.Title))
		}
		if len(v.Diagnostics) > 0 {
			details = append(details, fmt.Sprintf("Fixes: %d diagnostics", len(v.Diagnostics)))
		}
		detailsStr := ""
		if len(details) > 0 {
			detailsStr = " (" + strings.Join(details, ", ") + ")"
		}
		return fmt.Sprintf(" %d. %s [%s]%s", index, v.Title, kind, detailsStr)
	case protocol.Command:
		return fmt.Sprintf(" %d. %s [Command]", index, v.Title)
	}
	return fmt.Sprintf(" %d. Unknown code action type", index)
}
// executeCodeAction re-queries the named LSP server for the code actions at
// the given (0-based) range and executes the actionID-th one (1-based).
//
// The server is re-queried rather than caching the earlier listing, so the
// actionID is only meaningful if the file has not changed since the list was
// shown. A CodeAction's workspace edit is applied first, then its command is
// run, matching the order the LSP spec prescribes; a bare Command is executed
// directly. Errors are reported as tool error responses, never as Go errors.
func executeCodeAction(ctx context.Context, filePath string, line, column, endLine, endColumn, actionID int, lspName string, lsps map[string]*lsp.Client) (ToolResponse, error) {
	client, ok := lsps[lspName]
	if !ok {
		return NewTextErrorResponse(fmt.Sprintf("LSP server '%s' not found", lspName)), nil
	}

	// Create code action params (same request shape as the listing path).
	uri := fmt.Sprintf("file://%s", filePath)
	codeActionParams := protocol.CodeActionParams{
		TextDocument: protocol.TextDocumentIdentifier{
			URI: protocol.DocumentUri(uri),
		},
		Range: protocol.Range{
			Start: protocol.Position{
				Line:      uint32(line),
				Character: uint32(column),
			},
			End: protocol.Position{
				Line:      uint32(endLine),
				Character: uint32(endColumn),
			},
		},
		Context: protocol.CodeActionContext{
			// Request all kinds of code actions
			Only: []protocol.CodeActionKind{
				protocol.QuickFix,
				protocol.Refactor,
				protocol.RefactorExtract,
				protocol.RefactorInline,
				protocol.RefactorRewrite,
				protocol.Source,
				protocol.SourceOrganizeImports,
				protocol.SourceFixAll,
			},
		},
	}

	// Get code actions
	codeActions, err := client.CodeAction(ctx, codeActionParams)
	if err != nil {
		return NewTextErrorResponse(fmt.Sprintf("Error getting code actions: %s", err)), nil
	}
	if len(codeActions) == 0 {
		return NewTextErrorResponse("No code actions found"), nil
	}

	// Check if the requested action ID is valid (IDs shown to the user are 1-based).
	if actionID < 1 || actionID > len(codeActions) {
		return NewTextErrorResponse(fmt.Sprintf("Invalid action ID: %d. Available actions: 1-%d", actionID, len(codeActions))), nil
	}

	// Get the selected action (adjust for 0-based index)
	selectedAction := codeActions[actionID-1]

	// Execute the action based on its type
	switch v := selectedAction.Value.(type) {
	case protocol.CodeAction:
		// Apply workspace edit if available — must happen before the command.
		if v.Edit != nil {
			err := util.ApplyWorkspaceEdit(*v.Edit)
			if err != nil {
				return NewTextErrorResponse(fmt.Sprintf("Error applying edit: %s", err)), nil
			}
		}
		// Execute command if available
		if v.Command != nil {
			_, err := client.ExecuteCommand(ctx, protocol.ExecuteCommandParams{
				Command:   v.Command.Command,
				Arguments: v.Command.Arguments,
			})
			if err != nil {
				return NewTextErrorResponse(fmt.Sprintf("Error executing command: %s", err)), nil
			}
		}
		return NewTextResponse(fmt.Sprintf("Successfully executed code action: %s", v.Title)), nil
	case protocol.Command:
		// Execute the command
		_, err := client.ExecuteCommand(ctx, protocol.ExecuteCommandParams{
			Command:   v.Command,
			Arguments: v.Arguments,
		})
		if err != nil {
			return NewTextErrorResponse(fmt.Sprintf("Error executing command: %s", err)), nil
		}
		return NewTextResponse(fmt.Sprintf("Successfully executed command: %s", v.Title)), nil
	}

	return NewTextErrorResponse("Unknown code action type"), nil
}

View File

@@ -0,0 +1,198 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"github.com/sst/opencode/internal/lsp"
"github.com/sst/opencode/internal/lsp/protocol"
)
// DefinitionParams are the JSON-decoded arguments for the definition tool.
// Line and Column arrive 1-based and are converted before the LSP request.
type DefinitionParams struct {
	FilePath string `json:"file_path"` // path of the file containing the symbol
	Line     int    `json:"line"`      // 1-based line of the symbol
	Column   int    `json:"column"`    // 1-based column of the symbol
}

// definitionTool resolves go-to-definition requests via the configured
// language servers.
type definitionTool struct {
	lspClients map[string]*lsp.Client // keyed by LSP server name
}

const (
	// DefinitionToolName is the registry name of this tool.
	DefinitionToolName = "definition"
	// definitionDescription is the long-form usage text surfaced to the model.
	definitionDescription = `Find the definition of a symbol at a specific position in a file.
WHEN TO USE THIS TOOL:
- Use when you need to find where a symbol is defined
- Helpful for understanding code structure and relationships
- Great for navigating between implementation and interface
HOW TO USE:
- Provide the path to the file containing the symbol
- Specify the line number (1-based) where the symbol appears
- Specify the column number (1-based) where the symbol appears
- Results show the location of the symbol's definition
FEATURES:
- Finds definitions across files in the project
- Works with variables, functions, classes, interfaces, etc.
- Returns file path, line, and column of the definition
LIMITATIONS:
- Requires a functioning LSP server for the file type
- May not work for all symbols depending on LSP capabilities
- Results depend on the accuracy of the LSP server
TIPS:
- Use in conjunction with References tool to understand usage
- Combine with View tool to examine the definition
`
)
// NewDefinitionTool builds the definition tool backed by the given LSP clients.
func NewDefinitionTool(lspClients map[string]*lsp.Client) BaseTool {
	tool := &definitionTool{lspClients: lspClients}
	return tool
}
// Info describes the definition tool: its registry name, long usage text, and
// the JSON schema of its parameters. All three parameters are required.
func (b *definitionTool) Info() ToolInfo {
	return ToolInfo{
		Name:        DefinitionToolName,
		Description: definitionDescription,
		Parameters: map[string]any{
			"file_path": map[string]any{
				"type":        "string",
				"description": "The path to the file containing the symbol",
			},
			"line": map[string]any{
				"type":        "integer",
				"description": "The line number (1-based) where the symbol appears",
			},
			"column": map[string]any{
				"type":        "integer",
				"description": "The column number (1-based) where the symbol appears",
			},
		},
		Required: []string{"file_path", "line", "column"},
	}
}
// Run resolves the definition of the symbol at the requested position.
// The input is JSON matching DefinitionParams; line/column arrive 1-based.
func (b *definitionTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
	var params DefinitionParams
	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
	}

	clients := b.lspClients
	if len(clients) == 0 {
		return NewTextResponse("\nLSP clients are still initializing. Definition lookup will be available once they're ready.\n"), nil
	}

	// Make sure the target file is open on every server before querying.
	notifyLspOpenFile(ctx, params.FilePath, clients)

	// LSP positions are 0-based; clamp so non-positive input cannot underflow.
	row := max(0, params.Line-1)
	col := max(0, params.Column-1)

	return NewTextResponse(getDefinition(ctx, params.FilePath, row, col, clients)), nil
}
// getDefinition queries every LSP server for the definition of the symbol at
// the given (0-based) position in filePath and formats the results for
// display (1-based path:line:column).
func getDefinition(ctx context.Context, filePath string, line, column int, lsps map[string]*lsp.Client) string {
	var results []string

	// IDIOM FIX: the old code passed fmt.Sprintf output to slog.Debug, which
	// formats eagerly even when debug logging is off and defeats structured
	// logging. Pass key/value pairs instead; slog defers the formatting.
	slog.Debug("looking for definition", "file", filePath, "line", line+1, "column", column+1)
	slog.Debug("available LSP clients", "clients", getClientNames(lsps))

	for lspName, client := range lsps {
		slog.Debug("trying LSP client", "client", lspName)

		definitionParams := protocol.DefinitionParams{
			TextDocumentPositionParams: protocol.TextDocumentPositionParams{
				TextDocument: protocol.TextDocumentIdentifier{
					URI: protocol.DocumentUri(fmt.Sprintf("file://%s", filePath)),
				},
				Position: protocol.Position{
					Line:      uint32(line),
					Character: uint32(column),
				},
			},
		}

		definition, err := client.Definition(ctx, definitionParams)
		if err != nil {
			slog.Debug("definition request failed", "client", lspName, "error", err)
			results = append(results, fmt.Sprintf("Error from %s: %s", lspName, err))
			continue
		}

		// The response shape varies by server; normalize to plain locations.
		locations := processDefinitionResult(definition)
		slog.Debug("processed definition locations", "client", lspName, "count", len(locations))
		if len(locations) == 0 {
			results = append(results, fmt.Sprintf("No definition found by %s", lspName))
			continue
		}

		for _, loc := range locations {
			path := strings.TrimPrefix(string(loc.URI), "file://")
			// Convert 0-based line/column back to 1-based for display.
			defLine := loc.Range.Start.Line + 1
			defColumn := loc.Range.Start.Character + 1
			results = append(results, fmt.Sprintf("Definition found by %s: %s:%d:%d", lspName, path, defLine, defColumn))
		}
	}

	if len(results) == 0 {
		return "No definition found for the symbol at the specified position."
	}
	return strings.Join(results, "\n")
}
// processDefinitionResult flattens the possible shapes of a
// textDocument/definition response (single location, location list, or
// definition links) into a plain list of locations.
func processDefinitionResult(result protocol.Or_Result_textDocument_definition) []protocol.Location {
	switch value := result.Value.(type) {
	case protocol.Location:
		return []protocol.Location{value}
	case []protocol.Location:
		return append([]protocol.Location(nil), value...)
	case []protocol.DefinitionLink:
		locs := make([]protocol.Location, 0, len(value))
		for _, link := range value {
			locs = append(locs, protocol.Location{
				URI:   link.TargetURI,
				Range: link.TargetRange,
			})
		}
		return locs
	case protocol.Or_Definition:
		switch inner := value.Value.(type) {
		case protocol.Location:
			return []protocol.Location{inner}
		case []protocol.Location:
			return append([]protocol.Location(nil), inner...)
		}
	}
	return nil
}
// getClientNames returns the names of the configured LSP clients, used only
// for debug logging. Map iteration order is random, so the order varies.
func getClientNames(lsps map[string]*lsp.Client) []string {
	out := make([]string, 0, len(lsps))
	for clientName := range lsps {
		out = append(out, clientName)
	}
	return out
}

View File

@@ -9,8 +9,8 @@ import (
"strings"
"time"
"github.com/opencode-ai/opencode/internal/lsp"
"github.com/opencode-ai/opencode/internal/lsp/protocol"
"github.com/sst/opencode/internal/lsp"
"github.com/sst/opencode/internal/lsp/protocol"
)
type DiagnosticsParams struct {
@@ -74,7 +74,8 @@ func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
lsps := b.lspClients
if len(lsps) == 0 {
return NewTextErrorResponse("no LSP clients available"), nil
// Return a more helpful message when LSP clients aren't ready yet
return NewTextResponse("\n<diagnostic_summary>\nLSP clients are still initializing. Diagnostics will be available once they're ready.\n</diagnostic_summary>\n"), nil
}
if params.FilePath != "" {

View File

@@ -0,0 +1,204 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/sst/opencode/internal/lsp"
"github.com/sst/opencode/internal/lsp/protocol"
)
// DocSymbolsParams are the JSON-decoded arguments for the docSymbols tool.
type DocSymbolsParams struct {
	FilePath string `json:"file_path"` // path of the file to list symbols for
}

// docSymbolsTool lists the symbols declared in a single file via the
// configured language servers.
type docSymbolsTool struct {
	lspClients map[string]*lsp.Client // keyed by LSP server name
}

const (
	// DocSymbolsToolName is the registry name of this tool.
	DocSymbolsToolName = "docSymbols"
	// docSymbolsDescription is the long-form usage text surfaced to the model.
	docSymbolsDescription = `Get document symbols for a file.
WHEN TO USE THIS TOOL:
- Use when you need to understand the structure of a file
- Helpful for finding classes, functions, methods, and variables in a file
- Great for getting an overview of a file's organization
HOW TO USE:
- Provide the path to the file to get symbols for
- Results show all symbols defined in the file with their kind and location
FEATURES:
- Lists all symbols in a hierarchical structure
- Shows symbol types (function, class, variable, etc.)
- Provides location information for each symbol
- Organizes symbols by their scope and relationship
LIMITATIONS:
- Requires a functioning LSP server for the file type
- Results depend on the accuracy of the LSP server
- May not work for all file types
TIPS:
- Use to quickly understand the structure of a large file
- Combine with Definition and References tools for deeper code exploration
`
)
// NewDocSymbolsTool builds the docSymbols tool backed by the given LSP clients.
func NewDocSymbolsTool(lspClients map[string]*lsp.Client) BaseTool {
	tool := &docSymbolsTool{lspClients: lspClients}
	return tool
}
// Info describes the docSymbols tool: its registry name, long usage text, and
// the JSON schema of its single required parameter.
func (b *docSymbolsTool) Info() ToolInfo {
	return ToolInfo{
		Name:        DocSymbolsToolName,
		Description: docSymbolsDescription,
		Parameters: map[string]any{
			"file_path": map[string]any{
				"type":        "string",
				"description": "The path to the file to get symbols for",
			},
		},
		Required: []string{"file_path"},
	}
}
// Run lists the document symbols of the requested file.
// The input is JSON matching DocSymbolsParams.
func (b *docSymbolsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
	var params DocSymbolsParams
	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
	}

	clients := b.lspClients
	if len(clients) == 0 {
		return NewTextResponse("\nLSP clients are still initializing. Document symbols lookup will be available once they're ready.\n"), nil
	}

	// Make sure the target file is open on every server before querying.
	notifyLspOpenFile(ctx, params.FilePath, clients)

	return NewTextResponse(getDocumentSymbols(ctx, params.FilePath, clients)), nil
}
// getDocumentSymbols queries every LSP server for the symbols declared in
// filePath and returns an indented, human-readable listing per server.
func getDocumentSymbols(ctx context.Context, filePath string, lsps map[string]*lsp.Client) string {
	var out []string
	for name, client := range lsps {
		request := protocol.DocumentSymbolParams{
			TextDocument: protocol.TextDocumentIdentifier{
				URI: protocol.DocumentUri(fmt.Sprintf("file://%s", filePath)),
			},
		}

		response, err := client.DocumentSymbol(ctx, request)
		if err != nil {
			out = append(out, fmt.Sprintf("Error from %s: %s", name, err))
			continue
		}

		// The response shape varies by server; normalize first.
		symbols := processDocumentSymbolResult(response)
		if len(symbols) == 0 {
			out = append(out, fmt.Sprintf("No symbols found by %s", name))
			continue
		}

		out = append(out, fmt.Sprintf("Symbols found by %s:", name))
		for _, sym := range symbols {
			out = append(out, formatSymbol(sym, 1))
		}
	}

	if len(out) == 0 {
		return "No symbols found in the specified file."
	}
	return strings.Join(out, "\n")
}
// processDocumentSymbolResult normalizes the two possible response shapes of
// textDocument/documentSymbol — flat SymbolInformation entries or hierarchical
// DocumentSymbol trees — into []SymbolInfo.
func processDocumentSymbolResult(result protocol.Or_Result_textDocument_documentSymbol) []SymbolInfo {
	switch value := result.Value.(type) {
	case []protocol.SymbolInformation:
		out := make([]SymbolInfo, 0, len(value))
		for _, si := range value {
			out = append(out, SymbolInfo{
				Name:     si.Name,
				Kind:     symbolKindToString(si.Kind),
				Location: locationToString(si.Location),
			})
		}
		return out
	case []protocol.DocumentSymbol:
		out := make([]SymbolInfo, 0, len(value))
		for _, ds := range value {
			out = append(out, documentSymbolToSymbolInfo(ds))
		}
		return out
	}
	return nil
}
// SymbolInfo represents a symbol in a document, normalized from the various
// LSP response shapes for display purposes.
type SymbolInfo struct {
	Name     string       // symbol identifier
	Kind     string       // display name of the symbol kind (function, class, ...)
	Location string       // human-readable location, e.g. "Line 3-10" or "path:line:col"
	Children []SymbolInfo // nested symbols; nil when the source was flat
}
// documentSymbolToSymbolInfo converts a hierarchical LSP DocumentSymbol
// (including all of its children, recursively) into our display struct.
func documentSymbolToSymbolInfo(symbol protocol.DocumentSymbol) SymbolInfo {
	info := SymbolInfo{
		Name: symbol.Name,
		Kind: symbolKindToString(symbol.Kind),
		// LSP ranges are 0-based; display them 1-based.
		Location: fmt.Sprintf("Line %d-%d",
			symbol.Range.Start.Line+1,
			symbol.Range.End.Line+1),
		// Pre-size: the child count is known up front, avoiding append growth
		// (stays non-nil and empty when there are no children, as before).
		Children: make([]SymbolInfo, 0, len(symbol.Children)),
	}
	for _, child := range symbol.Children {
		info.Children = append(info.Children, documentSymbolToSymbolInfo(child))
	}
	return info
}
// locationToString renders a location's line span 1-based for display.
func locationToString(location protocol.Location) string {
	first := location.Range.Start.Line + 1
	last := location.Range.End.Line + 1
	return fmt.Sprintf("Line %d-%d", first, last)
}
// symbolKindToString maps an LSP SymbolKind to its display name, falling back
// to "Unknown" for kinds missing from the protocol table.
func symbolKindToString(kind protocol.SymbolKind) string {
	name, ok := protocol.TableKindMap[kind]
	if !ok {
		return "Unknown"
	}
	return name
}
// formatSymbol renders a symbol and its children, recursively, as an
// indented bullet tree (one level of indent per nesting depth).
func formatSymbol(symbol SymbolInfo, level int) string {
	indent := strings.Repeat(" ", level)
	parts := []string{fmt.Sprintf("%s- %s (%s) %s", indent, symbol.Name, symbol.Kind, symbol.Location)}
	for _, child := range symbol.Children {
		parts = append(parts, formatSymbol(child, level+1))
	}
	return strings.Join(parts, "\n")
}

View File

@@ -0,0 +1,161 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/sst/opencode/internal/lsp"
"github.com/sst/opencode/internal/lsp/protocol"
)
// ReferencesParams are the JSON-decoded arguments for the references tool.
// Line and Column arrive 1-based and are converted before the LSP request.
type ReferencesParams struct {
	FilePath           string `json:"file_path"`           // path of the file containing the symbol
	Line               int    `json:"line"`                // 1-based line of the symbol
	Column             int    `json:"column"`              // 1-based column of the symbol
	IncludeDeclaration bool   `json:"include_declaration"` // also list the declaration site
}

// referencesTool finds every use of a symbol via the configured language servers.
type referencesTool struct {
	lspClients map[string]*lsp.Client // keyed by LSP server name
}

const (
	// ReferencesToolName is the registry name of this tool.
	ReferencesToolName = "references"
	// referencesDescription is the long-form usage text surfaced to the model.
	referencesDescription = `Find all references to a symbol at a specific position in a file.
WHEN TO USE THIS TOOL:
- Use when you need to find all places where a symbol is used
- Helpful for understanding code usage and dependencies
- Great for refactoring and impact analysis
HOW TO USE:
- Provide the path to the file containing the symbol
- Specify the line number (1-based) where the symbol appears
- Specify the column number (1-based) where the symbol appears
- Optionally set include_declaration to include the declaration in results
- Results show all locations where the symbol is referenced
FEATURES:
- Finds references across files in the project
- Works with variables, functions, classes, interfaces, etc.
- Returns file paths, lines, and columns of all references
LIMITATIONS:
- Requires a functioning LSP server for the file type
- May not find all references depending on LSP capabilities
- Results depend on the accuracy of the LSP server
TIPS:
- Use in conjunction with Definition tool to understand symbol origins
- Combine with View tool to examine the references
`
)
// NewReferencesTool builds the references tool backed by the given LSP clients.
func NewReferencesTool(lspClients map[string]*lsp.Client) BaseTool {
	tool := &referencesTool{lspClients: lspClients}
	return tool
}
// Info describes the references tool: its registry name, long usage text, and
// the JSON schema of its parameters. include_declaration is optional.
func (b *referencesTool) Info() ToolInfo {
	return ToolInfo{
		Name:        ReferencesToolName,
		Description: referencesDescription,
		Parameters: map[string]any{
			"file_path": map[string]any{
				"type":        "string",
				"description": "The path to the file containing the symbol",
			},
			"line": map[string]any{
				"type":        "integer",
				"description": "The line number (1-based) where the symbol appears",
			},
			"column": map[string]any{
				"type":        "integer",
				"description": "The column number (1-based) where the symbol appears",
			},
			"include_declaration": map[string]any{
				"type":        "boolean",
				"description": "Whether to include the declaration in the results",
			},
		},
		Required: []string{"file_path", "line", "column"},
	}
}
// Run finds every reference to the symbol at the requested position.
// The input is JSON matching ReferencesParams; line/column arrive 1-based.
func (b *referencesTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
	var params ReferencesParams
	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
	}

	clients := b.lspClients
	if len(clients) == 0 {
		return NewTextResponse("\nLSP clients are still initializing. References lookup will be available once they're ready.\n"), nil
	}

	// Make sure the target file is open on every server before querying.
	notifyLspOpenFile(ctx, params.FilePath, clients)

	// LSP positions are 0-based; clamp so non-positive input cannot underflow.
	row := max(0, params.Line-1)
	col := max(0, params.Column-1)

	return NewTextResponse(getReferences(ctx, params.FilePath, row, col, params.IncludeDeclaration, clients)), nil
}
// getReferences queries every LSP server for references to the symbol at the
// given (0-based) position and formats them for display (1-based positions).
func getReferences(ctx context.Context, filePath string, line, column int, includeDeclaration bool, lsps map[string]*lsp.Client) string {
	var out []string
	for name, client := range lsps {
		request := protocol.ReferenceParams{
			TextDocumentPositionParams: protocol.TextDocumentPositionParams{
				TextDocument: protocol.TextDocumentIdentifier{
					URI: protocol.DocumentUri(fmt.Sprintf("file://%s", filePath)),
				},
				Position: protocol.Position{
					Line:      uint32(line),
					Character: uint32(column),
				},
			},
			Context: protocol.ReferenceContext{
				IncludeDeclaration: includeDeclaration,
			},
		}

		references, err := client.References(ctx, request)
		if err != nil {
			out = append(out, fmt.Sprintf("Error from %s: %s", name, err))
			continue
		}
		if len(references) == 0 {
			out = append(out, fmt.Sprintf("No references found by %s", name))
			continue
		}

		out = append(out, fmt.Sprintf("References found by %s:", name))
		for _, loc := range references {
			path := strings.TrimPrefix(string(loc.URI), "file://")
			// Convert 0-based line/column back to 1-based for display.
			refLine := loc.Range.Start.Line + 1
			refColumn := loc.Range.Start.Character + 1
			out = append(out, fmt.Sprintf(" %s:%d:%d", path, refLine, refColumn))
		}
	}

	if len(out) == 0 {
		return "No references found for the symbol at the specified position."
	}
	return strings.Join(out, "\n")
}

View File

@@ -0,0 +1,162 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/sst/opencode/internal/lsp"
"github.com/sst/opencode/internal/lsp/protocol"
)
// WorkspaceSymbolsParams are the JSON-decoded arguments for the
// workspaceSymbols tool.
type WorkspaceSymbolsParams struct {
	Query string `json:"query"` // search string forwarded to each LSP server
}

// workspaceSymbolsTool searches for symbols across the whole workspace via
// the configured language servers.
type workspaceSymbolsTool struct {
	lspClients map[string]*lsp.Client // keyed by LSP server name
}

const (
	// WorkspaceSymbolsToolName is the registry name of this tool.
	WorkspaceSymbolsToolName = "workspaceSymbols"
	// workspaceSymbolsDescription is the long-form usage text surfaced to the model.
	workspaceSymbolsDescription = `Find symbols across the workspace matching a query.
WHEN TO USE THIS TOOL:
- Use when you need to find symbols across multiple files
- Helpful for locating classes, functions, or variables in a project
- Great for exploring large codebases
HOW TO USE:
- Provide a query string to search for symbols
- Results show matching symbols from across the workspace
FEATURES:
- Searches across all files in the workspace
- Shows symbol types (function, class, variable, etc.)
- Provides location information for each symbol
- Works with partial matches and fuzzy search (depending on LSP server)
LIMITATIONS:
- Requires a functioning LSP server for the file types
- Results depend on the accuracy of the LSP server
- Query capabilities vary by language server
- May not work for all file types
TIPS:
- Use specific queries to narrow down results
- Combine with DocSymbols tool for detailed file exploration
- Use with Definition tool to jump to symbol definitions
`
)
// NewWorkspaceSymbolsTool builds the workspaceSymbols tool backed by the
// given LSP clients.
func NewWorkspaceSymbolsTool(lspClients map[string]*lsp.Client) BaseTool {
	tool := &workspaceSymbolsTool{lspClients: lspClients}
	return tool
}
// Info describes the workspaceSymbols tool: its registry name, long usage
// text, and the JSON schema of its single required parameter.
func (b *workspaceSymbolsTool) Info() ToolInfo {
	return ToolInfo{
		Name:        WorkspaceSymbolsToolName,
		Description: workspaceSymbolsDescription,
		Parameters: map[string]any{
			"query": map[string]any{
				"type":        "string",
				"description": "The query string to search for symbols",
			},
		},
		Required: []string{"query"},
	}
}
// Run searches the workspace for symbols matching the query.
// The input is JSON matching WorkspaceSymbolsParams.
func (b *workspaceSymbolsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
	var params WorkspaceSymbolsParams
	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
	}

	clients := b.lspClients
	if len(clients) == 0 {
		return NewTextResponse("\nLSP clients are still initializing. Workspace symbols lookup will be available once they're ready.\n"), nil
	}

	return NewTextResponse(getWorkspaceSymbols(ctx, params.Query, clients)), nil
}
// getWorkspaceSymbols queries every LSP server for workspace symbols matching
// query and returns a human-readable listing per server.
func getWorkspaceSymbols(ctx context.Context, query string, lsps map[string]*lsp.Client) string {
	var out []string
	for name, client := range lsps {
		response, err := client.Symbol(ctx, protocol.WorkspaceSymbolParams{Query: query})
		if err != nil {
			out = append(out, fmt.Sprintf("Error from %s: %s", name, err))
			continue
		}

		// The response shape varies by server; normalize first.
		symbols := processWorkspaceSymbolResult(response)
		if len(symbols) == 0 {
			out = append(out, fmt.Sprintf("No symbols found by %s for query '%s'", name, query))
			continue
		}

		out = append(out, fmt.Sprintf("Symbols found by %s for query '%s':", name, query))
		for _, sym := range symbols {
			out = append(out, fmt.Sprintf(" %s (%s) - %s", sym.Name, sym.Kind, sym.Location))
		}
	}

	if len(out) == 0 {
		return fmt.Sprintf("No symbols found matching query '%s'.", query)
	}
	return strings.Join(out, "\n")
}
// processWorkspaceSymbolResult normalizes the two possible response shapes of
// workspace/symbol — SymbolInformation entries or WorkspaceSymbol entries —
// into []SymbolInfo. WorkspaceSymbol locations are a union type and may not
// carry a full Location; those fall back to "Unknown location".
func processWorkspaceSymbolResult(result protocol.Or_Result_workspace_symbol) []SymbolInfo {
	switch value := result.Value.(type) {
	case []protocol.SymbolInformation:
		out := make([]SymbolInfo, 0, len(value))
		for _, si := range value {
			out = append(out, SymbolInfo{
				Name:     si.Name,
				Kind:     symbolKindToString(si.Kind),
				Location: formatWorkspaceLocation(si.Location),
			})
		}
		return out
	case []protocol.WorkspaceSymbol:
		out := make([]SymbolInfo, 0, len(value))
		for _, ws := range value {
			location := "Unknown location"
			// A failed assertion (including on a nil union value) leaves the fallback.
			if concrete, ok := ws.Location.Value.(protocol.Location); ok {
				location = formatWorkspaceLocation(concrete)
			}
			out = append(out, SymbolInfo{
				Name:     ws.Name,
				Kind:     symbolKindToString(ws.Kind),
				Location: location,
			})
		}
		return out
	}
	return nil
}
// formatWorkspaceLocation renders a location as path:line:column, converting
// the LSP's 0-based position to the 1-based form editors display.
func formatWorkspaceLocation(location protocol.Location) string {
	path := strings.TrimPrefix(string(location.URI), "file://")
	line := location.Range.Start.Line + 1
	column := location.Range.Start.Character + 1
	return fmt.Sprintf("%s:%d:%d", path, line, column)
}

View File

@@ -8,12 +8,12 @@ import (
"path/filepath"
"time"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/diff"
"github.com/opencode-ai/opencode/internal/history"
"github.com/opencode-ai/opencode/internal/logging"
"github.com/opencode-ai/opencode/internal/lsp"
"github.com/opencode-ai/opencode/internal/permission"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/diff"
"github.com/sst/opencode/internal/history"
"github.com/sst/opencode/internal/lsp"
"github.com/sst/opencode/internal/permission"
"log/slog"
)
type PatchParams struct {
@@ -193,6 +193,7 @@ func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
dir := filepath.Dir(path)
patchDiff, _, _ := diff.GenerateDiff("", *change.NewContent, path)
p := p.permissions.Request(
ctx,
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: dir,
@@ -220,6 +221,7 @@ func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
patchDiff, _, _ := diff.GenerateDiff(currentContent, newContent, path)
dir := filepath.Dir(path)
p := p.permissions.Request(
ctx,
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: dir,
@@ -239,6 +241,7 @@ func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
dir := filepath.Dir(path)
patchDiff, _, _ := diff.GenerateDiff(*change.OldContent, "", path)
p := p.permissions.Request(
ctx,
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: dir,
@@ -313,12 +316,12 @@ func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
totalRemovals += removals
// Update history
file, err := p.files.GetByPathAndSession(ctx, absPath, sessionID)
file, err := p.files.GetLatestByPathAndSession(ctx, absPath, sessionID)
if err != nil && change.Type != diff.ActionAdd {
// If not adding a file, create history entry for existing file
_, err = p.files.Create(ctx, sessionID, absPath, oldContent)
if err != nil {
logging.Debug("Error creating file history", "error", err)
slog.Debug("Error creating file history", "error", err)
}
}
@@ -326,7 +329,7 @@ func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
// User manually changed content, store intermediate version
_, err = p.files.CreateVersion(ctx, sessionID, absPath, oldContent)
if err != nil {
logging.Debug("Error creating file history version", "error", err)
slog.Debug("Error creating file history version", "error", err)
}
}
@@ -337,7 +340,7 @@ func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
_, err = p.files.CreateVersion(ctx, sessionID, absPath, newContent)
}
if err != nil {
logging.Debug("Error creating file history version", "error", err)
slog.Debug("Error creating file history version", "error", err)
}
// Record file operations

View File

@@ -11,6 +11,8 @@ import (
"sync"
"syscall"
"time"
"github.com/sst/opencode/internal/status"
)
type PersistentShell struct {
@@ -99,7 +101,7 @@ func newPersistentShell(cwd string) *PersistentShell {
go func() {
err := cmd.Wait()
if err != nil {
// Log the error if needed
status.Error(fmt.Sprintf("Shell process exited with error: %v", err))
}
shell.isAlive = false
close(shell.commandQueue)

View File

@@ -1,383 +0,0 @@
package tools
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// SourcegraphParams are the JSON-decoded arguments for the sourcegraph tool.
type SourcegraphParams struct {
	// Query is the Sourcegraph search query (required).
	Query string `json:"query"`
	// Count is the desired number of results (default 10, clamped to 20).
	Count int `json:"count,omitempty"`
	// ContextWindow is how many lines of context to show around each match.
	ContextWindow int `json:"context_window,omitempty"`
	// Timeout is the request timeout in seconds (clamped to 120).
	Timeout int `json:"timeout,omitempty"`
}

// SourcegraphResponseMetadata summarizes a search response.
type SourcegraphResponseMetadata struct {
	NumberOfMatches int  `json:"number_of_matches"`
	Truncated       bool `json:"truncated"`
}

// sourcegraphTool queries Sourcegraph's public GraphQL API; returned by
// NewSourcegraphTool as a BaseTool.
type sourcegraphTool struct {
	client *http.Client // shared client with a default 30s timeout
}
const (
SourcegraphToolName = "sourcegraph"
sourcegraphToolDescription = `Search code across public repositories using Sourcegraph's GraphQL API.
WHEN TO USE THIS TOOL:
- Use when you need to find code examples or implementations across public repositories
- Helpful for researching how others have solved similar problems
- Useful for discovering patterns and best practices in open source code
HOW TO USE:
- Provide a search query using Sourcegraph's query syntax
- Optionally specify the number of results to return (default: 10)
- Optionally set a timeout for the request
QUERY SYNTAX:
- Basic search: "fmt.Println" searches for exact matches
- File filters: "file:.go fmt.Println" limits to Go files
- Repository filters: "repo:^github\.com/golang/go$ fmt.Println" limits to specific repos
- Language filters: "lang:go fmt.Println" limits to Go code
- Boolean operators: "fmt.Println AND log.Fatal" for combined terms
- Regular expressions: "fmt\.(Print|Printf|Println)" for pattern matching
- Quoted strings: "\"exact phrase\"" for exact phrase matching
- Exclude filters: "-file:test" or "-repo:forks" to exclude matches
ADVANCED FILTERS:
- Repository filters:
* "repo:name" - Match repositories with name containing "name"
* "repo:^github\.com/org/repo$" - Exact repository match
* "repo:org/repo@branch" - Search specific branch
* "repo:org/repo rev:branch" - Alternative branch syntax
* "-repo:name" - Exclude repositories
* "fork:yes" or "fork:only" - Include or only show forks
* "archived:yes" or "archived:only" - Include or only show archived repos
* "visibility:public" or "visibility:private" - Filter by visibility
- File filters:
* "file:\.js$" - Files with .js extension
* "file:internal/" - Files in internal directory
* "-file:test" - Exclude test files
* "file:has.content(Copyright)" - Files containing "Copyright"
* "file:has.contributor([email protected])" - Files with specific contributor
- Content filters:
* "content:\"exact string\"" - Search for exact string
* "-content:\"unwanted\"" - Exclude files with unwanted content
* "case:yes" - Case-sensitive search
- Type filters:
* "type:symbol" - Search for symbols (functions, classes, etc.)
* "type:file" - Search file content only
* "type:path" - Search filenames only
* "type:diff" - Search code changes
* "type:commit" - Search commit messages
- Commit/diff search:
* "after:\"1 month ago\"" - Commits after date
* "before:\"2023-01-01\"" - Commits before date
* "author:name" - Commits by author
* "message:\"fix bug\"" - Commits with message
- Result selection:
* "select:repo" - Show only repository names
* "select:file" - Show only file paths
* "select:content" - Show only matching content
* "select:symbol" - Show only matching symbols
- Result control:
* "count:100" - Return up to 100 results
* "count:all" - Return all results
* "timeout:30s" - Set search timeout
EXAMPLES:
- "file:.go context.WithTimeout" - Find Go code using context.WithTimeout
- "lang:typescript useState type:symbol" - Find TypeScript React useState hooks
- "repo:^github\.com/kubernetes/kubernetes$ pod list type:file" - Find Kubernetes files related to pod listing
- "repo:sourcegraph/sourcegraph$ after:\"3 months ago\" type:diff database" - Recent changes to database code
- "file:Dockerfile (alpine OR ubuntu) -content:alpine:latest" - Dockerfiles with specific base images
- "repo:has.path(\.py) file:requirements.txt tensorflow" - Python projects using TensorFlow
BOOLEAN OPERATORS:
- "term1 AND term2" - Results containing both terms
- "term1 OR term2" - Results containing either term
- "term1 NOT term2" - Results with term1 but not term2
- "term1 and (term2 or term3)" - Grouping with parentheses
LIMITATIONS:
- Only searches public repositories
- Rate limits may apply
- Complex queries may take longer to execute
- Maximum of 20 results per query
TIPS:
- Use specific file extensions to narrow results
- Add repo: filters for more targeted searches
- Use type:symbol to find function/method definitions
- Use type:file to find relevant files`
)
// NewSourcegraphTool constructs the sourcegraph code-search tool backed by an
// HTTP client with a 30-second default timeout.
func NewSourcegraphTool() BaseTool {
	httpClient := &http.Client{Timeout: 30 * time.Second}
	return &sourcegraphTool{client: httpClient}
}
// Info reports the tool's name, description, parameter schema, and required
// parameter list.
func (t *sourcegraphTool) Info() ToolInfo {
	parameters := map[string]any{
		"query": map[string]any{
			"type":        "string",
			"description": "The Sourcegraph search query",
		},
		"count": map[string]any{
			"type":        "number",
			"description": "Optional number of results to return (default: 10, max: 20)",
		},
		"context_window": map[string]any{
			"type":        "number",
			"description": "The context around the match to return (default: 10 lines)",
		},
		"timeout": map[string]any{
			"type":        "number",
			"description": "Optional timeout in seconds (max 120)",
		},
	}
	return ToolInfo{
		Name:        SourcegraphToolName,
		Description: sourcegraphToolDescription,
		Parameters:  parameters,
		Required:    []string{"query"},
	}
}
// Run executes a Sourcegraph GraphQL code search and returns the formatted
// matches.
//
// Parameter handling: count defaults to 10 and is capped at 20;
// context_window defaults to 10 lines; timeout (seconds) is capped at 120 and,
// when set, uses a dedicated HTTP client. Malformed parameters and API-level
// failures come back as textual error responses; transport failures are
// returned as errors.
func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
	var params SourcegraphParams
	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
		return NewTextErrorResponse("Failed to parse sourcegraph parameters: " + err.Error()), nil
	}
	if params.Query == "" {
		return NewTextErrorResponse("Query parameter is required"), nil
	}
	if params.Count <= 0 {
		params.Count = 10
	} else if params.Count > 20 {
		params.Count = 20 // Limit to 20 results
	}
	if params.ContextWindow <= 0 {
		params.ContextWindow = 10 // Default context window
	}
	client := t.client
	if params.Timeout > 0 {
		maxTimeout := 120 // 2 minutes
		if params.Timeout > maxTimeout {
			params.Timeout = maxTimeout
		}
		client = &http.Client{
			Timeout: time.Duration(params.Timeout) * time.Second,
		}
	}
	// Fix: the count parameter was previously validated but never applied.
	// Forward it to Sourcegraph via its documented "count:" query filter,
	// unless the caller's query already specifies one.
	query := params.Query
	if !strings.Contains(query, "count:") {
		query = fmt.Sprintf("%s count:%d", query, params.Count)
	}
	type graphqlRequest struct {
		Query     string `json:"query"`
		Variables struct {
			Query string `json:"query"`
		} `json:"variables"`
	}
	request := graphqlRequest{
		Query: "query Search($query: String!) { search(query: $query, version: V2, patternType: keyword ) { results { matchCount, limitHit, resultCount, approximateResultCount, missing { name }, timedout { name }, indexUnavailable, results { __typename, ... on FileMatch { repository { name }, file { path, url, content }, lineMatches { preview, lineNumber, offsetAndLengths } } } } } }",
	}
	request.Variables.Query = query
	graphqlQueryBytes, err := json.Marshal(request)
	if err != nil {
		return ToolResponse{}, fmt.Errorf("failed to marshal GraphQL request: %w", err)
	}
	req, err := http.NewRequestWithContext(
		ctx,
		"POST",
		"https://sourcegraph.com/.api/graphql",
		bytes.NewReader(graphqlQueryBytes), // avoid the string round-trip of the old NewBuffer([]byte(...)) form
	)
	if err != nil {
		return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("User-Agent", "opencode/1.0")
	resp, err := client.Do(req)
	if err != nil {
		return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		if len(body) > 0 {
			return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d, response: %s", resp.StatusCode, string(body))), nil
		}
		return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return ToolResponse{}, fmt.Errorf("failed to read response body: %w", err)
	}
	var result map[string]any
	if err = json.Unmarshal(body, &result); err != nil {
		return ToolResponse{}, fmt.Errorf("failed to unmarshal response: %w", err)
	}
	formattedResults, err := formatSourcegraphResults(result, params.ContextWindow)
	if err != nil {
		return NewTextErrorResponse("Failed to format results: " + err.Error()), nil
	}
	return NewTextResponse(formattedResults), nil
}
// formatSourcegraphResults renders a raw Sourcegraph GraphQL search response
// as markdown. For each FileMatch it prints a repository/path header, an
// optional URL, and each line match fenced in a code block with up to
// contextWindow lines of surrounding context (when file content is present).
// GraphQL-level errors are rendered into the output; a malformed response
// shape returns an error. At most 10 file matches are rendered.
func formatSourcegraphResults(result map[string]any, contextWindow int) (string, error) {
	var buffer strings.Builder

	// Surface API-level errors as readable output rather than failing.
	if errors, ok := result["errors"].([]any); ok && len(errors) > 0 {
		buffer.WriteString("## Sourcegraph API Error\n\n")
		for _, err := range errors {
			if errMap, ok := err.(map[string]any); ok {
				if message, ok := errMap["message"].(string); ok {
					buffer.WriteString(fmt.Sprintf("- %s\n", message))
				}
			}
		}
		return buffer.String(), nil
	}

	data, ok := result["data"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing data field")
	}
	search, ok := data["search"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing search field")
	}
	searchResults, ok := search["results"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing results field")
	}

	matchCount, _ := searchResults["matchCount"].(float64)
	resultCount, _ := searchResults["resultCount"].(float64)
	limitHit, _ := searchResults["limitHit"].(bool)

	buffer.WriteString("# Sourcegraph Search Results\n\n")
	buffer.WriteString(fmt.Sprintf("Found %d matches across %d results\n", int(matchCount), int(resultCount)))
	if limitHit {
		buffer.WriteString("(Result limit reached, try a more specific query)\n")
	}
	buffer.WriteString("\n")

	results, ok := searchResults["results"].([]any)
	if !ok || len(results) == 0 {
		buffer.WriteString("No results found. Try a different query.\n")
		return buffer.String(), nil
	}

	maxResults := 10
	if len(results) > maxResults {
		results = results[:maxResults]
	}

	for i, res := range results {
		fileMatch, ok := res.(map[string]any)
		if !ok {
			continue
		}
		typeName, _ := fileMatch["__typename"].(string)
		if typeName != "FileMatch" {
			continue
		}
		repo, _ := fileMatch["repository"].(map[string]any)
		file, _ := fileMatch["file"].(map[string]any)
		lineMatches, _ := fileMatch["lineMatches"].([]any)
		if repo == nil || file == nil {
			continue
		}
		repoName, _ := repo["name"].(string)
		filePath, _ := file["path"].(string)
		fileURL, _ := file["url"].(string)
		fileContent, _ := file["content"].(string)

		buffer.WriteString(fmt.Sprintf("## Result %d: %s/%s\n\n", i+1, repoName, filePath))
		if fileURL != "" {
			buffer.WriteString(fmt.Sprintf("URL: %s\n\n", fileURL))
		}

		// Fix: split the file content once per file match instead of once per
		// line match — the old per-match split was quadratic for large files
		// with many matches.
		var lines []string
		if fileContent != "" {
			lines = strings.Split(fileContent, "\n")
		}

		for _, lm := range lineMatches {
			lineMatch, ok := lm.(map[string]any)
			if !ok {
				continue
			}
			lineNumber, _ := lineMatch["lineNumber"].(float64)
			preview, _ := lineMatch["preview"].(string)

			if fileContent != "" {
				buffer.WriteString("```\n")
				// Context before the match. startLine >= 1, so the zero-based
				// index startLine-1 never underflows (old `j >= 0` guard was dead).
				startLine := max(1, int(lineNumber)-contextWindow)
				for j := startLine - 1; j < int(lineNumber)-1 && j < len(lines); j++ {
					buffer.WriteString(fmt.Sprintf("%d| %s\n", j+1, lines[j]))
				}
				// The matching line itself, taken from the API's preview text.
				buffer.WriteString(fmt.Sprintf("%d| %s\n", int(lineNumber), preview))
				// Context after the match (the old inner `j < len(lines)` check
				// duplicated the loop bound and is removed).
				endLine := int(lineNumber) + contextWindow
				for j := int(lineNumber); j < endLine && j < len(lines); j++ {
					buffer.WriteString(fmt.Sprintf("%d| %s\n", j+1, lines[j]))
				}
				buffer.WriteString("```\n\n")
			} else {
				buffer.WriteString("```\n")
				buffer.WriteString(fmt.Sprintf("%d| %s\n", int(lineNumber), preview))
				buffer.WriteString("```\n\n")
			}
		}
	}
	return buffer.String(), nil
}

View File

@@ -10,8 +10,8 @@ import (
"path/filepath"
"strings"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/lsp"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/lsp"
)
type ViewParams struct {

View File

@@ -9,12 +9,12 @@ import (
"strings"
"time"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/diff"
"github.com/opencode-ai/opencode/internal/history"
"github.com/opencode-ai/opencode/internal/logging"
"github.com/opencode-ai/opencode/internal/lsp"
"github.com/opencode-ai/opencode/internal/permission"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/diff"
"github.com/sst/opencode/internal/history"
"github.com/sst/opencode/internal/lsp"
"github.com/sst/opencode/internal/permission"
"log/slog"
)
type WriteParams struct {
@@ -167,6 +167,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
permissionPath = rootDir
}
p := w.permissions.Request(
ctx,
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: permissionPath,
@@ -189,7 +190,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
}
// Check if file exists in history
file, err := w.files.GetByPathAndSession(ctx, filePath, sessionID)
file, err := w.files.GetLatestByPathAndSession(ctx, filePath, sessionID)
if err != nil {
_, err = w.files.Create(ctx, sessionID, filePath, oldContent)
if err != nil {
@@ -201,13 +202,13 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
// User Manually changed the content store an intermediate version
_, err = w.files.CreateVersion(ctx, sessionID, filePath, oldContent)
if err != nil {
logging.Debug("Error creating file history version", "error", err)
slog.Debug("Error creating file history version", "error", err)
}
}
// Store the new version
_, err = w.files.CreateVersion(ctx, sessionID, filePath, params.Content)
if err != nil {
logging.Debug("Error creating file history version", "error", err)
slog.Debug("Error creating file history version", "error", err)
}
recordFileWrite(filePath)

View File

@@ -1,78 +0,0 @@
package logging
import (
"fmt"
"log/slog"
"os"
"runtime/debug"
"time"
)
// Info logs msg with key/value args at INFO level via the default slog logger.
func Info(msg string, args ...any) {
	slog.Info(msg, args...)
}

// Debug logs msg with key/value args at DEBUG level via the default slog logger.
func Debug(msg string, args ...any) {
	slog.Debug(msg, args...)
}

// Warn logs msg with key/value args at WARN level via the default slog logger.
func Warn(msg string, args ...any) {
	slog.Warn(msg, args...)
}

// Error logs msg with key/value args at ERROR level via the default slog logger.
func Error(msg string, args ...any) {
	slog.Error(msg, args...)
}
// InfoPersist logs at INFO level and appends the persist marker so the
// logfmt writer flags the message to stay visible in the status bar.
func InfoPersist(msg string, args ...any) {
	args = append(args, persistKeyArg, true)
	slog.Info(msg, args...)
}

// DebugPersist logs at DEBUG level with the persist marker attached.
func DebugPersist(msg string, args ...any) {
	args = append(args, persistKeyArg, true)
	slog.Debug(msg, args...)
}

// WarnPersist logs at WARN level with the persist marker attached.
func WarnPersist(msg string, args ...any) {
	args = append(args, persistKeyArg, true)
	slog.Warn(msg, args...)
}

// ErrorPersist logs at ERROR level with the persist marker attached.
func ErrorPersist(msg string, args ...any) {
	args = append(args, persistKeyArg, true)
	slog.Error(msg, args...)
}
// RecoverPanic handles panics gracefully: it logs the panic, writes a
// timestamped log file containing the stack trace, and finally runs the
// optional cleanup function. It must be called directly from a defer so that
// recover() is effective.
func RecoverPanic(name string, cleanup func()) {
	r := recover()
	if r == nil {
		return
	}

	ErrorPersist(fmt.Sprintf("Panic in %s: %v", name, r))

	// Persist the panic details to a timestamped file.
	stamp := time.Now().Format("20060102-150405")
	logName := fmt.Sprintf("opencode-panic-%s-%s.log", name, stamp)
	if f, err := os.Create(logName); err != nil {
		ErrorPersist(fmt.Sprintf("Failed to create panic log: %v", err))
	} else {
		defer f.Close()
		fmt.Fprintf(f, "Panic in %s: %v\n\n", name, r)
		fmt.Fprintf(f, "Time: %s\n\n", time.Now().Format(time.RFC3339))
		fmt.Fprintf(f, "Stack Trace:\n%s\n", debug.Stack())
		InfoPersist(fmt.Sprintf("Panic details written to %s", logName))
	}

	if cleanup != nil {
		cleanup()
	}
}

292
internal/logging/logging.go Normal file
View File

@@ -0,0 +1,292 @@
package logging
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"fmt"
"io"
"log/slog"
"os"
"runtime/debug"
"strings"
"time"
"github.com/go-logfmt/logfmt"
"github.com/google/uuid"
"github.com/sst/opencode/internal/db"
"github.com/sst/opencode/internal/pubsub"
)
// Log is a single persisted log record, optionally scoped to a session.
type Log struct {
	ID         string            // unique identifier (UUID)
	SessionID  string            // owning session; empty for session-less logs
	Timestamp  time.Time         // when the event occurred
	Level      string            // log level ("info", "debug", "warn", "error")
	Message    string            // human-readable message text
	Attributes map[string]string // additional structured key/value pairs
	CreatedAt  time.Time         // when the database row was written
}
const (
	// EventLogCreated is published on the broker each time a log is stored.
	EventLogCreated pubsub.EventType = "log_created"
)

// Service persists log records and publishes creation events to subscribers.
type Service interface {
	pubsub.Subscriber[Log]
	// Create stores one log record.
	Create(ctx context.Context, timestamp time.Time, level, message string, attributes map[string]string, sessionID string) error
	// ListBySession returns the logs recorded for sessionID.
	ListBySession(ctx context.Context, sessionID string) ([]Log, error)
	// ListAll returns up to limit logs across all sessions.
	ListAll(ctx context.Context, limit int) ([]Log, error)
}

// service is the database-backed Service implementation.
type service struct {
	db     *db.Queries
	broker *pubsub.Broker[Log]
}

// globalLoggingService is the process-wide singleton set up by InitService.
var globalLoggingService *service
// InitService wires the global logging service to the given database
// connection. It may be called at most once; a second call returns an error.
func InitService(dbConn *sql.DB) error {
	if globalLoggingService != nil {
		return fmt.Errorf("logging service already initialized")
	}
	globalLoggingService = &service{
		db:     db.New(dbConn),
		broker: pubsub.NewBroker[Log](),
	}
	return nil
}
// GetService returns the global logging service. It panics if InitService
// has not been called first.
func GetService() Service {
	if globalLoggingService == nil {
		panic("logging service not initialized. Call logging.InitService() first.")
	}
	return globalLoggingService
}
// Create persists one log record and publishes an EventLogCreated event.
// An empty level defaults to "info". Attributes, when present, are stored as
// a JSON object; the timestamp is normalized to UTC RFC3339Nano text.
func (s *service) Create(ctx context.Context, timestamp time.Time, level, message string, attributes map[string]string, sessionID string) error {
	if level == "" {
		level = "info"
	}
	var attributesJSON sql.NullString
	if len(attributes) > 0 {
		attributesBytes, err := json.Marshal(attributes)
		if err != nil {
			return fmt.Errorf("failed to marshal log attributes: %w", err)
		}
		attributesJSON = sql.NullString{String: string(attributesBytes), Valid: true}
	}
	// SessionID is stored as NULL when empty so session-less logs are queryable.
	dbLog, err := s.db.CreateLog(ctx, db.CreateLogParams{
		ID:         uuid.New().String(),
		SessionID:  sql.NullString{String: sessionID, Valid: sessionID != ""},
		Timestamp:  timestamp.UTC().Format(time.RFC3339Nano),
		Level:      level,
		Message:    message,
		Attributes: attributesJSON,
	})
	if err != nil {
		return fmt.Errorf("db.CreateLog: %w", err)
	}
	log := s.fromDBItem(dbLog)
	s.broker.Publish(EventLogCreated, log)
	return nil
}
// ListBySession returns every log row recorded for sessionID, converted to
// the domain Log type.
func (s *service) ListBySession(ctx context.Context, sessionID string) ([]Log, error) {
	rows, err := s.db.ListLogsBySession(ctx, sql.NullString{String: sessionID, Valid: true})
	if err != nil {
		return nil, fmt.Errorf("db.ListLogsBySession: %w", err)
	}
	out := make([]Log, 0, len(rows))
	for _, row := range rows {
		out = append(out, s.fromDBItem(row))
	}
	return out, nil
}
// ListAll returns up to limit log rows across every session, converted to
// the domain Log type.
func (s *service) ListAll(ctx context.Context, limit int) ([]Log, error) {
	rows, err := s.db.ListAllLogs(ctx, int64(limit))
	if err != nil {
		return nil, fmt.Errorf("db.ListAllLogs: %w", err)
	}
	out := make([]Log, 0, len(rows))
	for _, row := range rows {
		out = append(out, s.fromDBItem(row))
	}
	return out, nil
}
// Subscribe returns a channel of log-created events that closes when ctx is
// canceled.
func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[Log] {
	return s.broker.Subscribe(ctx)
}
// fromDBItem converts a database row into the domain Log type. Unparseable
// timestamps fall back to time.Now(), and malformed attribute JSON degrades
// to an empty attribute map (logged, never fatal).
func (s *service) fromDBItem(item db.Log) Log {
	out := Log{
		ID:        item.ID,
		SessionID: item.SessionID.String,
		Level:     item.Level,
		Message:   item.Message,
	}

	// Timestamps are stored as RFC3339Nano text; fall back to "now" on
	// corrupt rows rather than dropping the record.
	if ts, err := time.Parse(time.RFC3339Nano, item.Timestamp); err == nil {
		out.Timestamp = ts
	} else {
		out.Timestamp = time.Now() // Fallback
	}
	if created, err := time.Parse(time.RFC3339Nano, item.CreatedAt); err == nil {
		out.CreatedAt = created
	} else {
		out.CreatedAt = time.Now() // Fallback
	}

	if item.Attributes.Valid && item.Attributes.String != "" {
		if err := json.Unmarshal([]byte(item.Attributes.String), &out.Attributes); err != nil {
			slog.Error("Failed to unmarshal log attributes", "log_id", item.ID, "error", err)
			out.Attributes = make(map[string]string)
		}
	} else {
		out.Attributes = make(map[string]string)
	}
	return out
}
// Create persists a log record via the global logging service.
func Create(ctx context.Context, timestamp time.Time, level, message string, attributes map[string]string, sessionID string) error {
	return GetService().Create(ctx, timestamp, level, message, attributes, sessionID)
}

// ListBySession returns the logs recorded for sessionID via the global service.
func ListBySession(ctx context.Context, sessionID string) ([]Log, error) {
	return GetService().ListBySession(ctx, sessionID)
}

// ListAll returns up to limit logs across all sessions via the global service.
func ListAll(ctx context.Context, limit int) ([]Log, error) {
	return GetService().ListAll(ctx, limit)
}

// Subscribe returns a channel of log-created events from the global service.
func Subscribe(ctx context.Context) <-chan pubsub.Event[Log] {
	return GetService().Subscribe(ctx)
}
// slogWriter adapts logfmt-formatted slog output into persisted Log rows via
// the global logging service.
type slogWriter struct{}

// Write decodes each logfmt record in p and persists it asynchronously.
// Recognized keys: time, level, msg/message, session_id; everything else
// becomes an attribute. It reports len(p) consumed even when the decoder
// fails mid-stream, alongside the error.
//
// NOTE(review): persistence is fire-and-forget (one goroutine per record), so
// row insertion order is not guaranteed to match emission order — confirm
// that is acceptable for log display.
func (sw *slogWriter) Write(p []byte) (n int, err error) {
	// Example: time=2024-05-09T12:34:56.789-05:00 level=INFO msg="User request" session=xyz foo=bar
	d := logfmt.NewDecoder(bytes.NewReader(p))
	for d.ScanRecord() {
		var timestamp time.Time
		var level string
		var message string
		var sessionID string
		var attributes map[string]string
		attributes = make(map[string]string)
		hasTimestamp := false
		for d.ScanKeyval() {
			key := string(d.Key())
			value := string(d.Value())
			switch key {
			case "time":
				// Prefer nanosecond precision, fall back to plain RFC3339,
				// then to "now" if the value is unparseable.
				parsedTime, timeErr := time.Parse(time.RFC3339Nano, value)
				if timeErr != nil {
					parsedTime, timeErr = time.Parse(time.RFC3339, value)
					if timeErr != nil {
						slog.Error("Failed to parse time in slog writer", "value", value, "error", timeErr)
						timestamp = time.Now().UTC()
						hasTimestamp = true
						continue
					}
				}
				timestamp = parsedTime
				hasTimestamp = true
			case "level":
				level = strings.ToLower(value)
			case "msg", "message":
				message = value
			case "session_id":
				sessionID = value
			default:
				attributes[key] = value
			}
		}
		if d.Err() != nil {
			return len(p), fmt.Errorf("logfmt.ScanRecord: %w", d.Err())
		}
		if !hasTimestamp {
			timestamp = time.Now()
		}
		// Create log entry via the service (non-blocking or handle error appropriately)
		// Using context.Background() as this is a low-level logging write.
		go func(timestamp time.Time, level, message string, attributes map[string]string, sessionID string) { // Run in a goroutine to avoid blocking slog
			if globalLoggingService == nil {
				// Service not initialized yet: drop the record silently to
				// avoid recursing into the logger during startup.
				return
			}
			if err := Create(context.Background(), timestamp, level, message, attributes, sessionID); err != nil {
				// Log internal error using a more primitive logger to avoid loops
				fmt.Fprintf(os.Stderr, "ERROR [logging.slogWriter]: failed to persist log: %v\n", err)
			}
		}(timestamp, level, message, attributes, sessionID)
	}
	if d.Err() != nil {
		return len(p), fmt.Errorf("logfmt.ScanRecord final: %w", d.Err())
	}
	return len(p), nil
}
// NewSlogWriter returns an io.Writer suitable as an slog handler destination;
// records written to it are persisted through the logging service.
func NewSlogWriter() io.Writer {
	return &slogWriter{}
}
// RecoverPanic handles panics gracefully: it logs the panic via slog, writes
// a timestamped panic file containing the stack trace, and then runs the
// optional cleanup function. It must be deferred directly so recover() works.
func RecoverPanic(name string, cleanup func()) {
	r := recover()
	if r == nil {
		return
	}

	// Use slog directly here, as our service might be the one panicking.
	slog.Error(fmt.Sprintf("Panic in %s: %v", name, r))

	stamp := time.Now().Format("20060102-150405")
	logName := fmt.Sprintf("opencode-panic-%s-%s.log", name, stamp)
	f, err := os.Create(logName)
	if err != nil {
		slog.Error(fmt.Sprintf("Failed to create panic log file '%s': %v", logName, err))
	} else {
		defer f.Close()
		fmt.Fprintf(f, "Panic in %s: %v\n\n", name, r)
		fmt.Fprintf(f, "Time: %s\n\n", time.Now().Format(time.RFC3339))
		fmt.Fprintf(f, "Stack Trace:\n%s\n", string(debug.Stack()))
		slog.Info(fmt.Sprintf("Panic details written to %s", logName))
	}

	if cleanup != nil {
		cleanup()
	}
}

View File

@@ -1,21 +0,0 @@
package logging
import (
"time"
)
// LogMessage is the event payload for a log message.
type LogMessage struct {
	ID          string
	Time        time.Time
	Level       string
	Persist     bool          // used when we want to show the message in the status bar
	PersistTime time.Duration // used when we want to show the message in the status bar
	Message     string        `json:"msg"`
	Attributes  []Attr
}

// Attr is a single key/value attribute attached to a LogMessage.
type Attr struct {
	Key   string
	Value string
}

View File

@@ -1,101 +0,0 @@
package logging
import (
"bytes"
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/go-logfmt/logfmt"
"github.com/opencode-ai/opencode/internal/pubsub"
)
const (
	// persistKeyArg marks a record that should stay visible in the status bar.
	persistKeyArg = "$_persist"
	// PersistTimeArg carries how long a persisted record stays visible.
	PersistTimeArg = "$_persist_time"
)

// LogData is an in-memory, broadcast-capable store of log messages.
type LogData struct {
	messages []LogMessage
	*pubsub.Broker[LogMessage]
	lock sync.Mutex
}
// Add appends msg to the store and broadcasts a CreatedEvent to subscribers.
func (l *LogData) Add(msg LogMessage) {
	l.lock.Lock()
	defer l.lock.Unlock()
	l.messages = append(l.messages, msg)
	l.Publish(pubsub.CreatedEvent, msg)
}
// List returns a snapshot of all messages recorded so far.
//
// It returns a copy rather than the internal slice: the previous version
// handed out the live backing array, which a concurrent Add (appending within
// capacity) could mutate after the lock was released — a data race for
// callers iterating the result.
func (l *LogData) List() []LogMessage {
	l.lock.Lock()
	defer l.lock.Unlock()
	out := make([]LogMessage, len(l.messages))
	copy(out, l.messages)
	return out
}
// defaultLogData is the package-wide message store fed by the writer below.
var defaultLogData = &LogData{
	messages: make([]LogMessage, 0),
	Broker:   pubsub.NewBroker[LogMessage](),
}

// writer adapts logfmt-formatted slog output into LogMessage records.
type writer struct{}
// Write decodes the logfmt records in p and appends each one as a LogMessage
// to the package-level defaultLogData store. Recognized keys: time, level,
// msg, plus the persist control keys; everything else becomes an attribute.
//
// NOTE(review): a malformed "time" value aborts the whole write and reports
// 0 bytes consumed, even though earlier records in the same payload were
// already stored — confirm this all-or-nothing error behavior is intended.
func (w *writer) Write(p []byte) (int, error) {
	d := logfmt.NewDecoder(bytes.NewReader(p))
	for d.ScanRecord() {
		// Defaults: synthetic ID from the clock, time of receipt.
		msg := LogMessage{
			ID:   fmt.Sprintf("%d", time.Now().UnixNano()),
			Time: time.Now(),
		}
		for d.ScanKeyval() {
			switch string(d.Key()) {
			case "time":
				parsed, err := time.Parse(time.RFC3339, string(d.Value()))
				if err != nil {
					return 0, fmt.Errorf("parsing time: %w", err)
				}
				msg.Time = parsed
			case "level":
				msg.Level = strings.ToLower(string(d.Value()))
			case "msg":
				msg.Message = string(d.Value())
			default:
				// Persist markers are control keys, not user attributes.
				if string(d.Key()) == persistKeyArg {
					msg.Persist = true
				} else if string(d.Key()) == PersistTimeArg {
					parsed, err := time.ParseDuration(string(d.Value()))
					if err != nil {
						continue
					}
					msg.PersistTime = parsed
				} else {
					msg.Attributes = append(msg.Attributes, Attr{
						Key:   string(d.Key()),
						Value: string(d.Value()),
					})
				}
			}
		}
		defaultLogData.Add(msg)
	}
	if d.Err() != nil {
		return 0, d.Err()
	}
	return len(p), nil
}
// NewWriter returns a writer that parses logfmt records into the package's
// in-memory log store.
func NewWriter() *writer {
	return &writer{}
}
// Subscribe returns a channel of log message events from the default store.
func Subscribe(ctx context.Context) <-chan pubsub.Event[LogMessage] {
	return defaultLogData.Subscribe(ctx)
}

// List returns all messages currently held by the default store.
func List() []LogMessage {
	return defaultLogData.List()
}

View File

@@ -14,9 +14,12 @@ import (
"sync/atomic"
"time"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/logging"
"github.com/opencode-ai/opencode/internal/lsp/protocol"
"log/slog"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/logging"
"github.com/sst/opencode/internal/lsp/protocol"
"github.com/sst/opencode/internal/status"
)
type Client struct {
@@ -96,17 +99,17 @@ func NewClient(ctx context.Context, command string, args ...string) (*Client, er
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
fmt.Fprintf(os.Stderr, "LSP Server: %s\n", scanner.Text())
slog.Info("LSP Server", "message", scanner.Text())
}
if err := scanner.Err(); err != nil {
fmt.Fprintf(os.Stderr, "Error reading stderr: %v\n", err)
slog.Error("Error reading LSP stderr", "error", err)
}
}()
// Start message handling loop
go func() {
defer logging.RecoverPanic("LSP-message-handler", func() {
logging.ErrorPersist("LSP message handler crashed, LSP functionality may be impaired")
status.Error("LSP message handler crashed, LSP functionality may be impaired")
})
client.handleMessages()
}()
@@ -300,7 +303,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
defer ticker.Stop()
if cnf.DebugLSP {
logging.Debug("Waiting for LSP server to be ready...")
slog.Debug("Waiting for LSP server to be ready...")
}
// Determine server type for specialized initialization
@@ -309,7 +312,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
// For TypeScript-like servers, we need to open some key files first
if serverType == ServerTypeTypeScript {
if cnf.DebugLSP {
logging.Debug("TypeScript-like server detected, opening key configuration files")
slog.Debug("TypeScript-like server detected, opening key configuration files")
}
c.openKeyConfigFiles(ctx)
}
@@ -326,15 +329,15 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
// Server responded successfully
c.SetServerState(StateReady)
if cnf.DebugLSP {
logging.Debug("LSP server is ready")
slog.Debug("LSP server is ready")
}
return nil
} else {
logging.Debug("LSP server not ready yet", "error", err, "serverType", serverType)
slog.Debug("LSP server not ready yet", "error", err, "serverType", serverType)
}
if cnf.DebugLSP {
logging.Debug("LSP server not ready yet", "error", err, "serverType", serverType)
slog.Debug("LSP server not ready yet", "error", err, "serverType", serverType)
}
}
}
@@ -409,9 +412,9 @@ func (c *Client) openKeyConfigFiles(ctx context.Context) {
if _, err := os.Stat(file); err == nil {
// File exists, try to open it
if err := c.OpenFile(ctx, file); err != nil {
logging.Debug("Failed to open key config file", "file", file, "error", err)
slog.Debug("Failed to open key config file", "file", file, "error", err)
} else {
logging.Debug("Opened key config file for initialization", "file", file)
slog.Debug("Opened key config file for initialization", "file", file)
}
}
}
@@ -487,7 +490,7 @@ func (c *Client) pingTypeScriptServer(ctx context.Context) error {
return nil
})
if err != nil {
logging.Debug("Error walking directory for TypeScript files", "error", err)
slog.Debug("Error walking directory for TypeScript files", "error", err)
}
// Final fallback - just try a generic capability
@@ -527,7 +530,7 @@ func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) {
if err := c.OpenFile(ctx, path); err == nil {
filesOpened++
if cnf.DebugLSP {
logging.Debug("Opened TypeScript file for initialization", "file", path)
slog.Debug("Opened TypeScript file for initialization", "file", path)
}
}
}
@@ -536,11 +539,11 @@ func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) {
})
if err != nil && cnf.DebugLSP {
logging.Debug("Error walking directory for TypeScript files", "error", err)
slog.Debug("Error walking directory for TypeScript files", "error", err)
}
if cnf.DebugLSP {
logging.Debug("Opened TypeScript files for initialization", "count", filesOpened)
slog.Debug("Opened TypeScript files for initialization", "count", filesOpened)
}
}
@@ -627,6 +630,15 @@ func (c *Client) OpenFile(ctx context.Context, filepath string) error {
func (c *Client) NotifyChange(ctx context.Context, filepath string) error {
uri := fmt.Sprintf("file://%s", filepath)
// Verify file exists before attempting to read it
if _, err := os.Stat(filepath); err != nil {
if os.IsNotExist(err) {
// File was deleted - close it in the LSP client instead of notifying change
return c.CloseFile(ctx, filepath)
}
return fmt.Errorf("error checking file: %w", err)
}
content, err := os.ReadFile(filepath)
if err != nil {
return fmt.Errorf("error reading file: %w", err)
@@ -681,7 +693,7 @@ func (c *Client) CloseFile(ctx context.Context, filepath string) error {
}
if cnf.DebugLSP {
logging.Debug("Closing file", "file", filepath)
slog.Debug("Closing file", "file", filepath)
}
if err := c.Notify(ctx, "textDocument/didClose", params); err != nil {
return err
@@ -720,12 +732,12 @@ func (c *Client) CloseAllFiles(ctx context.Context) {
for _, filePath := range filesToClose {
err := c.CloseFile(ctx, filePath)
if err != nil && cnf.DebugLSP {
logging.Warn("Error closing file", "file", filePath, "error", err)
slog.Warn("Error closing file", "file", filePath, "error", err)
}
}
if cnf.DebugLSP {
logging.Debug("Closed all files", "files", filesToClose)
slog.Debug("Closed all files", "files", filesToClose)
}
}

View File

@@ -0,0 +1,65 @@
package discovery
import (
"fmt"
"github.com/sst/opencode/internal/config"
"log/slog"
)
// IntegrateLSPServers discovers languages and LSP servers and integrates them into the application configuration.
// It runs language detection on workingDir, resolves an LSP server binary for
// each detected language, and records newly found servers in the loaded
// config. Languages that already have a configured server are left untouched.
func IntegrateLSPServers(workingDir string) error {
	// Get the current configuration
	cfg := config.Get()
	if cfg == nil {
		return fmt.Errorf("config not loaded")
	}
	// Check if this is the first run
	shouldInit, err := config.ShouldShowInitDialog()
	if err != nil {
		return fmt.Errorf("failed to check initialization status: %w", err)
	}
	// Always run language detection, but log more prominently on first run
	// (or when no servers are configured yet) so the user sees it happen.
	if shouldInit || len(cfg.LSP) == 0 {
		slog.Info("Running initial LSP auto-discovery...")
	} else {
		slog.Debug("Running LSP auto-discovery to detect new languages...")
	}
	// Configure LSP servers
	servers, err := ConfigureLSPServers(workingDir)
	if err != nil {
		return fmt.Errorf("failed to configure LSP servers: %w", err)
	}
	// Guard against a nil map in the loaded config: assigning to a nil map
	// below would panic. len() on nil is fine, so the first-run check above
	// does not catch this.
	if cfg.LSP == nil {
		cfg.LSP = make(map[string]config.LSPConfig, len(servers))
	}
	// Update the configuration with discovered servers
	for langID, serverInfo := range servers {
		// Skip languages that already have a configured server
		if _, exists := cfg.LSP[langID]; exists {
			slog.Debug("LSP server already configured for language", "language", langID)
			continue
		}
		if serverInfo.Available {
			// Only add servers whose binary was actually found on this machine
			cfg.LSP[langID] = config.LSPConfig{
				Disabled: false,
				Command:  serverInfo.Path,
				Args:     serverInfo.Args,
			}
			slog.Info("Added LSP server to configuration",
				"language", langID,
				"command", serverInfo.Command,
				"path", serverInfo.Path)
		} else {
			// Surface the install command so the user can fix it themselves
			slog.Warn("LSP server not available",
				"language", langID,
				"command", serverInfo.Command,
				"installCmd", serverInfo.InstallCmd)
		}
	}
	return nil
}

View File

@@ -0,0 +1,298 @@
package discovery
import (
"os"
"path/filepath"
"strings"
"sync"
"github.com/sst/opencode/internal/lsp"
"log/slog"
)
// LanguageInfo stores information about a detected language.
// Values are accumulated by DetectLanguages while walking the project tree.
type LanguageInfo struct {
	// Language identifier (e.g., "go", "typescript", "python")
	ID string
	// Number of source files detected for this language (by extension)
	FileCount int
	// Paths of project files associated with this language (e.g., go.mod, package.json)
	ProjectFiles []string
	// Whether this is likely a primary language in the project, either
	// because a strong project file was found or because of its file count
	IsPrimary bool
}
// ProjectFile represents a project configuration file whose presence
// indicates a specific language (e.g. go.mod -> go).
type ProjectFile struct {
	// File name to match exactly against a walked file's base name
	Name string
	// Associated language ID
	LanguageID string
	// Whether this file strongly indicates the language is primary
	// (e.g. go.mod is primary evidence; go.sum is not)
	IsPrimary bool
}
// Common project files that indicate specific languages.
// Names are matched exactly against file base names during the walk; each
// entry may also mark its language as primary.
var projectFilePatterns = []ProjectFile{
	{Name: "go.mod", LanguageID: "go", IsPrimary: true},
	{Name: "go.sum", LanguageID: "go", IsPrimary: false},
	{Name: "package.json", LanguageID: "javascript", IsPrimary: true}, // Could be TypeScript too
	{Name: "tsconfig.json", LanguageID: "typescript", IsPrimary: true},
	{Name: "jsconfig.json", LanguageID: "javascript", IsPrimary: true},
	{Name: "pyproject.toml", LanguageID: "python", IsPrimary: true},
	{Name: "setup.py", LanguageID: "python", IsPrimary: true},
	{Name: "requirements.txt", LanguageID: "python", IsPrimary: true},
	{Name: "Cargo.toml", LanguageID: "rust", IsPrimary: true},
	{Name: "Cargo.lock", LanguageID: "rust", IsPrimary: false},
	{Name: "CMakeLists.txt", LanguageID: "cmake", IsPrimary: true},
	{Name: "pom.xml", LanguageID: "java", IsPrimary: true},
	{Name: "build.gradle", LanguageID: "java", IsPrimary: true},
	{Name: "build.gradle.kts", LanguageID: "kotlin", IsPrimary: true},
	{Name: "composer.json", LanguageID: "php", IsPrimary: true},
	{Name: "Gemfile", LanguageID: "ruby", IsPrimary: true},
	{Name: "Rakefile", LanguageID: "ruby", IsPrimary: true},
	{Name: "mix.exs", LanguageID: "elixir", IsPrimary: true},
	{Name: "rebar.config", LanguageID: "erlang", IsPrimary: true},
	{Name: "dune-project", LanguageID: "ocaml", IsPrimary: true},
	{Name: "stack.yaml", LanguageID: "haskell", IsPrimary: true},
	{Name: "cabal.project", LanguageID: "haskell", IsPrimary: true},
	{Name: "Makefile", LanguageID: "make", IsPrimary: false},
	{Name: "Dockerfile", LanguageID: "dockerfile", IsPrimary: false},
}
// Map of file extensions to language IDs.
// Keys are lowercase extensions including the leading dot; lookups should
// lowercase the extension first (see DetectLanguages / GetLanguageIDFromExtension).
var extensionToLanguage = map[string]string{
	".go":    "go",
	".js":    "javascript",
	".jsx":   "javascript",
	".ts":    "typescript",
	".tsx":   "typescript",
	".py":    "python",
	".rs":    "rust",
	".java":  "java",
	".c":     "c",
	".cpp":   "cpp",
	".h":     "c", // ambiguous: .h headers are treated as C, not C++
	".hpp":   "cpp",
	".rb":    "ruby",
	".php":   "php",
	".cs":    "csharp",
	".fs":    "fsharp",
	".swift": "swift",
	".kt":    "kotlin",
	".scala": "scala",
	".hs":    "haskell",
	".ml":    "ocaml",
	".ex":    "elixir",
	".exs":   "elixir",
	".erl":   "erlang",
	".lua":   "lua",
	".r":     "r",
	".sh":    "shell",
	".bash":  "shell",
	".zsh":   "shell",
	".html":  "html",
	".css":   "css",
	".scss":  "scss",
	".sass":  "sass",
	".less":  "less",
	".json":  "json",
	".xml":   "xml",
	".yaml":  "yaml",
	".yml":   "yaml",
	".md":    "markdown",
	".dart":  "dart",
}
// Directories to exclude from scanning: VCS metadata, dependency caches,
// build outputs, editor state, and virtualenvs. Matched by base name only.
var excludedDirs = map[string]bool{
	".git":         true,
	"node_modules": true,
	"vendor":       true,
	"dist":         true,
	"build":        true,
	"target":       true,
	".idea":        true,
	".vscode":      true,
	".github":      true,
	".gitlab":      true,
	"__pycache__":  true,
	".next":        true,
	".nuxt":        true,
	"venv":         true,
	"env":          true,
	".env":         true,
}
// DetectLanguages scans a directory tree to identify programming languages
// used in the project.
//
// Two signals feed the result: well-known project files (go.mod,
// package.json, ...) and source-file extensions. After the walk,
// determinePrimaryLanguages additionally promotes languages with high
// relative file counts. Files that cannot be accessed are skipped silently.
func DetectLanguages(rootDir string) (map[string]LanguageInfo, error) {
	languages := make(map[string]LanguageInfo)
	// filepath.Walk calls the callback from a single goroutine, so the
	// mutex is purely defensive; kept to guard languages if this callback
	// is ever invoked concurrently.
	var mutex sync.Mutex
	// Walk the directory tree
	err := filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return nil // Skip files that can't be accessed
		}
		// Never apply the skip rules to the root itself: rootDir may
		// legitimately be "." or a hidden/excluded name, and returning
		// SkipDir for the root would silently skip the entire walk.
		if path == rootDir {
			return nil
		}
		// Skip excluded and hidden directories
		if info.IsDir() {
			if excludedDirs[info.Name()] || strings.HasPrefix(info.Name(), ".") {
				return filepath.SkipDir
			}
			return nil
		}
		// Skip hidden files
		if strings.HasPrefix(info.Name(), ".") {
			return nil
		}
		// Check for known project files (strong language evidence)
		for _, pattern := range projectFilePatterns {
			if info.Name() == pattern.Name {
				mutex.Lock()
				lang, exists := languages[pattern.LanguageID]
				if !exists {
					lang = LanguageInfo{
						ID:           pattern.LanguageID,
						ProjectFiles: []string{},
					}
				}
				lang.ProjectFiles = append(lang.ProjectFiles, path)
				if pattern.IsPrimary {
					lang.IsPrimary = true
				}
				languages[pattern.LanguageID] = lang
				mutex.Unlock()
				break // pattern names are unique; no further matches possible
			}
		}
		// Count files per language by extension
		ext := strings.ToLower(filepath.Ext(path))
		if langID, ok := extensionToLanguage[ext]; ok {
			mutex.Lock()
			lang, exists := languages[langID]
			if !exists {
				lang = LanguageInfo{
					ID:           langID,
					ProjectFiles: []string{},
				}
			}
			lang.FileCount++
			languages[langID] = lang
			mutex.Unlock()
		}
		return nil
	})
	if err != nil {
		return nil, err
	}
	// Determine primary languages based on file count if not already marked
	determinePrimaryLanguages(languages)
	// Log detected languages
	for id, info := range languages {
		if info.IsPrimary {
			slog.Debug("Detected primary language", "language", id, "files", info.FileCount, "projectFiles", len(info.ProjectFiles))
		} else {
			slog.Debug("Detected secondary language", "language", id, "files", info.FileCount)
		}
	}
	return languages, nil
}
// determinePrimaryLanguages promotes languages to primary based on how many
// files they contribute relative to the most widely used language.
func determinePrimaryLanguages(languages map[string]LanguageInfo) {
	// Highest file count across all detected languages.
	maxFiles := 0
	for _, info := range languages {
		maxFiles = max(maxFiles, info.FileCount)
	}
	// Primary means at least 20% of the max count, with an absolute floor
	// of 5 files so tiny projects don't promote everything.
	threshold := max(maxFiles/5, 5)
	for id, info := range languages {
		if info.IsPrimary || info.FileCount < threshold {
			continue
		}
		info.IsPrimary = true
		languages[id] = info
	}
}
// GetLanguageIDFromExtension returns the language ID for a given file
// extension (with leading dot), or "" when the extension is unknown.
func GetLanguageIDFromExtension(ext string) string {
	// Missing keys yield the zero value "", which is exactly the
	// "unknown" sentinel callers expect.
	return extensionToLanguage[strings.ToLower(ext)]
}
// GetLanguageIDFromProtocol converts a protocol.LanguageKind string to this
// package's language ID.
//
// Only the JSX/TSX variants need an explicit remapping; every other kind the
// original switch listed ("go", "python", ...) was an identity mapping that
// the lowercase fallback already produces, so those cases are folded into
// the default. Behavior is unchanged for all inputs.
func GetLanguageIDFromProtocol(langKind string) string {
	switch langKind {
	case "typescriptreact":
		return "typescript"
	case "javascriptreact":
		return "javascript"
	default:
		// Known lowercase kinds pass through unchanged; anything else is
		// normalized to lowercase as a best effort.
		return strings.ToLower(langKind)
	}
}
// GetLanguageIDFromPath determines the language ID for a file path: the
// extension is tried first, then known project-file names, then the LSP
// package's own URI-based detection as a last resort.
func GetLanguageIDFromPath(path string) string {
	// Extension is the cheapest and most common signal.
	if langID := GetLanguageIDFromExtension(filepath.Ext(path)); langID != "" {
		return langID
	}
	// Project files (go.mod, Gemfile, ...) identify a language even
	// without a recognized extension.
	base := filepath.Base(path)
	for _, pattern := range projectFilePatterns {
		if pattern.Name == base {
			return pattern.LanguageID
		}
	}
	// Fall back to the LSP library's detection on a file:// URI.
	langKind := lsp.DetectLanguageID("file://" + path)
	return GetLanguageIDFromProtocol(string(langKind))
}

View File

@@ -0,0 +1,306 @@
package discovery
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"log/slog"
)
// ServerInfo contains information about an LSP server binary: how to run
// it, how to install it, and — after FindLSPServer has run — whether and
// where it was found on this machine.
type ServerInfo struct {
	// Command to run the server (bare executable name)
	Command string
	// Arguments to pass to the command (e.g. "--stdio")
	Args []string
	// Command to install the server (for user guidance in logs/errors)
	InstallCmd string
	// Whether this server was found on the local machine
	Available bool
	// Full path to the executable (set only when Available is true)
	Path string
}
// LanguageServerMap maps language IDs to their corresponding LSP servers.
// Entries describe only the command and install hint; Available/Path are
// filled in by FindLSPServer at lookup time. Note typescript/javascript
// share one server, as do c/cpp.
var LanguageServerMap = map[string]ServerInfo{
	"go": {
		Command:    "gopls",
		InstallCmd: "go install golang.org/x/tools/gopls@latest",
	},
	"typescript": {
		Command:    "typescript-language-server",
		Args:       []string{"--stdio"},
		InstallCmd: "npm install -g typescript-language-server typescript",
	},
	"javascript": {
		Command:    "typescript-language-server",
		Args:       []string{"--stdio"},
		InstallCmd: "npm install -g typescript-language-server typescript",
	},
	"python": {
		Command:    "pylsp",
		InstallCmd: "pip install python-lsp-server",
	},
	"rust": {
		Command:    "rust-analyzer",
		InstallCmd: "rustup component add rust-analyzer",
	},
	"java": {
		Command:    "jdtls",
		InstallCmd: "Install Eclipse JDT Language Server",
	},
	"c": {
		Command:    "clangd",
		InstallCmd: "Install clangd from your package manager",
	},
	"cpp": {
		Command:    "clangd",
		InstallCmd: "Install clangd from your package manager",
	},
	"php": {
		Command:    "intelephense",
		Args:       []string{"--stdio"},
		InstallCmd: "npm install -g intelephense",
	},
	"ruby": {
		Command:    "solargraph",
		Args:       []string{"stdio"}, // solargraph takes "stdio" as a subcommand, not a flag
		InstallCmd: "gem install solargraph",
	},
	"lua": {
		Command:    "lua-language-server",
		InstallCmd: "Install lua-language-server from your package manager",
	},
	"html": {
		Command:    "vscode-html-language-server",
		Args:       []string{"--stdio"},
		InstallCmd: "npm install -g vscode-langservers-extracted",
	},
	"css": {
		Command:    "vscode-css-language-server",
		Args:       []string{"--stdio"},
		InstallCmd: "npm install -g vscode-langservers-extracted",
	},
	"json": {
		Command:    "vscode-json-language-server",
		Args:       []string{"--stdio"},
		InstallCmd: "npm install -g vscode-langservers-extracted",
	},
	"yaml": {
		Command:    "yaml-language-server",
		Args:       []string{"--stdio"},
		InstallCmd: "npm install -g yaml-language-server",
	},
}
// FindLSPServer locates the LSP server binary for the given language.
// PATH is consulted first, then a set of common per-OS install locations.
// On success the returned ServerInfo has Available set and Path filled in;
// when the binary cannot be found the metadata is still returned alongside
// an error that includes the install command.
func FindLSPServer(languageID string) (ServerInfo, error) {
	serverInfo, exists := LanguageServerMap[languageID]
	if !exists {
		return ServerInfo{}, fmt.Errorf("no LSP server defined for language: %s", languageID)
	}
	// PATH lookup covers the common case.
	if path, err := exec.LookPath(serverInfo.Command); err == nil {
		serverInfo.Available = true
		serverInfo.Path = path
		slog.Debug("Found LSP server in PATH", "language", languageID, "command", serverInfo.Command, "path", path)
		return serverInfo, nil
	}
	// Not on PATH: probe well-known installation directories.
	for _, candidate := range getCommonLSPPaths(languageID, serverInfo.Command) {
		if _, err := os.Stat(candidate); err != nil {
			continue
		}
		serverInfo.Available = true
		serverInfo.Path = candidate
		slog.Debug("Found LSP server in common location", "language", languageID, "command", serverInfo.Command, "path", candidate)
		return serverInfo, nil
	}
	// Server not found anywhere we know to look.
	slog.Debug("LSP server not found", "language", languageID, "command", serverInfo.Command)
	return serverInfo, fmt.Errorf("LSP server for %s not found. Install with: %s", languageID, serverInfo.InstallCmd)
}
// getCommonLSPPaths returns common installation paths for LSP servers based on language and OS.
// Candidates come from three sources: platform-wide bin directories,
// toolchain-specific locations for the language (GOPATH/bin, npm globals,
// pyenv shims, cargo bin, ...), and the VSCode globalStorage directory.
// Entries containing "*" are expanded with filepath.Glob before returning,
// so every returned entry is a concrete path; existence is the caller's
// concern (FindLSPServer stats each one).
func getCommonLSPPaths(languageID, command string) []string {
	var paths []string
	homeDir, err := os.UserHomeDir()
	if err != nil {
		// Without a home directory the user-local candidates can't be
		// built; return the empty list rather than guessing.
		slog.Error("Failed to get user home directory", "error", err)
		return paths
	}
	// Add platform-specific paths
	switch runtime.GOOS {
	case "darwin":
		// macOS paths (/opt/homebrew is the Apple Silicon Homebrew prefix)
		paths = append(paths,
			fmt.Sprintf("/usr/local/bin/%s", command),
			fmt.Sprintf("/opt/homebrew/bin/%s", command),
			fmt.Sprintf("%s/.local/bin/%s", homeDir, command),
		)
	case "linux":
		// Linux paths
		paths = append(paths,
			fmt.Sprintf("/usr/bin/%s", command),
			fmt.Sprintf("/usr/local/bin/%s", command),
			fmt.Sprintf("%s/.local/bin/%s", homeDir, command),
		)
	case "windows":
		// Windows paths (note the .exe suffix)
		paths = append(paths,
			fmt.Sprintf("%s\\AppData\\Local\\Programs\\%s.exe", homeDir, command),
			fmt.Sprintf("C:\\Program Files\\%s\\bin\\%s.exe", command, command),
		)
	}
	// Add language-specific paths
	switch languageID {
	case "go":
		// gopls is normally installed under GOPATH/bin (default ~/go/bin)
		gopath := os.Getenv("GOPATH")
		if gopath == "" {
			gopath = filepath.Join(homeDir, "go")
		}
		paths = append(paths, filepath.Join(gopath, "bin", command))
		if runtime.GOOS == "windows" {
			paths = append(paths, filepath.Join(gopath, "bin", command+".exe"))
		}
	case "typescript", "javascript", "html", "css", "json", "yaml", "php":
		// Node.js global packages
		if runtime.GOOS == "windows" {
			paths = append(paths,
				fmt.Sprintf("%s\\AppData\\Roaming\\npm\\%s.cmd", homeDir, command),
				fmt.Sprintf("%s\\AppData\\Roaming\\npm\\node_modules\\.bin\\%s.cmd", homeDir, command),
			)
		} else {
			paths = append(paths,
				fmt.Sprintf("%s/.npm-global/bin/%s", homeDir, command),
				// glob: matches any nvm-managed node version
				fmt.Sprintf("%s/.nvm/versions/node/*/bin/%s", homeDir, command),
				fmt.Sprintf("/usr/local/lib/node_modules/.bin/%s", command),
			)
		}
	case "python":
		// Python paths
		if runtime.GOOS == "windows" {
			paths = append(paths,
				// glob: matches any installed Python minor version
				fmt.Sprintf("%s\\AppData\\Local\\Programs\\Python\\Python*\\Scripts\\%s.exe", homeDir, command),
				fmt.Sprintf("C:\\Python*\\Scripts\\%s.exe", command),
			)
		} else {
			paths = append(paths,
				fmt.Sprintf("%s/.local/bin/%s", homeDir, command),
				fmt.Sprintf("%s/.pyenv/shims/%s", homeDir, command),
				fmt.Sprintf("/usr/local/bin/%s", command),
			)
		}
	case "rust":
		// Rust paths
		if runtime.GOOS == "windows" {
			paths = append(paths,
				// glob: matches any installed toolchain
				fmt.Sprintf("%s\\.rustup\\toolchains\\*\\bin\\%s.exe", homeDir, command),
				fmt.Sprintf("%s\\.cargo\\bin\\%s.exe", homeDir, command),
			)
		} else {
			paths = append(paths,
				fmt.Sprintf("%s/.rustup/toolchains/*/bin/%s", homeDir, command),
				fmt.Sprintf("%s/.cargo/bin/%s", homeDir, command),
			)
		}
	}
	// Add VSCode extensions path
	vscodePath := getVSCodeExtensionsPath(homeDir)
	if vscodePath != "" {
		paths = append(paths, vscodePath)
	}
	// Expand any glob patterns in paths
	var expandedPaths []string
	for _, path := range paths {
		if strings.Contains(path, "*") {
			// This is a glob pattern, expand it; malformed patterns are
			// silently dropped (Glob only errors on bad patterns)
			matches, err := filepath.Glob(path)
			if err == nil {
				expandedPaths = append(expandedPaths, matches...)
			}
		} else {
			expandedPaths = append(expandedPaths, path)
		}
	}
	return expandedPaths
}
// getVSCodeExtensionsPath returns the VSCode globalStorage directory for
// the current OS, or "" when the OS is unrecognized or the directory does
// not exist on disk.
func getVSCodeExtensionsPath(homeDir string) string {
	// Per-OS path segments below the user's home directory.
	segments := map[string][]string{
		"darwin":  {"Library", "Application Support", "Code", "User", "globalStorage"},
		"linux":   {".config", "Code", "User", "globalStorage"},
		"windows": {"AppData", "Roaming", "Code", "User", "globalStorage"},
	}
	segs, ok := segments[runtime.GOOS]
	if !ok {
		return ""
	}
	basePath := filepath.Join(append([]string{homeDir}, segs...)...)
	// Only report the directory when it actually exists.
	if _, err := os.Stat(basePath); err != nil {
		return ""
	}
	return basePath
}
// ConfigureLSPServers detects the project's languages and resolves an LSP
// server for each one worth configuring. Non-primary languages with fewer
// than 3 files are skipped, as are languages whose server binary cannot be
// located on this machine.
func ConfigureLSPServers(rootDir string) (map[string]ServerInfo, error) {
	detected, err := DetectLanguages(rootDir)
	if err != nil {
		return nil, fmt.Errorf("failed to detect languages: %w", err)
	}
	servers := make(map[string]ServerInfo)
	for langID, langInfo := range detected {
		// Ignore incidental languages: not primary and barely present.
		if !langInfo.IsPrimary && langInfo.FileCount < 3 {
			slog.Debug("Skipping non-primary language with few files", "language", langID, "files", langInfo.FileCount)
			continue
		}
		// Resolve the server binary; a miss is logged but non-fatal.
		serverInfo, lookupErr := FindLSPServer(langID)
		if lookupErr != nil {
			slog.Warn("LSP server not found", "language", langID, "error", lookupErr)
			continue
		}
		servers[langID] = serverInfo
		if langInfo.IsPrimary {
			slog.Info("Configured LSP server for primary language", "language", langID, "command", serverInfo.Command, "path", serverInfo.Path)
		} else {
			slog.Info("Configured LSP server for secondary language", "language", langID, "command", serverInfo.Command, "path", serverInfo.Path)
		}
	}
	return servers, nil
}

View File

@@ -3,10 +3,10 @@ package lsp
import (
"encoding/json"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/logging"
"github.com/opencode-ai/opencode/internal/lsp/protocol"
"github.com/opencode-ai/opencode/internal/lsp/util"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/lsp/protocol"
"github.com/sst/opencode/internal/lsp/util"
"log/slog"
)
// Requests
@@ -18,7 +18,7 @@ func HandleWorkspaceConfiguration(params json.RawMessage) (any, error) {
func HandleRegisterCapability(params json.RawMessage) (any, error) {
var registerParams protocol.RegistrationParams
if err := json.Unmarshal(params, &registerParams); err != nil {
logging.Error("Error unmarshaling registration params", "error", err)
slog.Error("Error unmarshaling registration params", "error", err)
return nil, err
}
@@ -28,13 +28,13 @@ func HandleRegisterCapability(params json.RawMessage) (any, error) {
// Parse the registration options
optionsJSON, err := json.Marshal(reg.RegisterOptions)
if err != nil {
logging.Error("Error marshaling registration options", "error", err)
slog.Error("Error marshaling registration options", "error", err)
continue
}
var options protocol.DidChangeWatchedFilesRegistrationOptions
if err := json.Unmarshal(optionsJSON, &options); err != nil {
logging.Error("Error unmarshaling registration options", "error", err)
slog.Error("Error unmarshaling registration options", "error", err)
continue
}
@@ -54,7 +54,7 @@ func HandleApplyEdit(params json.RawMessage) (any, error) {
err := util.ApplyWorkspaceEdit(edit.Edit)
if err != nil {
logging.Error("Error applying workspace edit", "error", err)
slog.Error("Error applying workspace edit", "error", err)
return protocol.ApplyWorkspaceEditResult{Applied: false, FailureReason: err.Error()}, nil
}
@@ -89,7 +89,7 @@ func HandleServerMessage(params json.RawMessage) {
}
if err := json.Unmarshal(params, &msg); err == nil {
if cnf.DebugLSP {
logging.Debug("Server message", "type", msg.Type, "message", msg.Message)
slog.Debug("Server message", "type", msg.Type, "message", msg.Message)
}
}
}
@@ -97,7 +97,7 @@ func HandleServerMessage(params json.RawMessage) {
func HandleDiagnostics(client *Client, params json.RawMessage) {
var diagParams protocol.PublishDiagnosticsParams
if err := json.Unmarshal(params, &diagParams); err != nil {
logging.Error("Error unmarshaling diagnostics params", "error", err)
slog.Error("Error unmarshaling diagnostics params", "error", err)
return
}

View File

@@ -4,7 +4,7 @@ import (
"path/filepath"
"strings"
"github.com/opencode-ai/opencode/internal/lsp/protocol"
"github.com/sst/opencode/internal/lsp/protocol"
)
func DetectLanguageID(uri string) protocol.LanguageKind {

View File

@@ -4,7 +4,7 @@ package lsp
import (
"context"
"github.com/opencode-ai/opencode/internal/lsp/protocol"
"github.com/sst/opencode/internal/lsp/protocol"
)
// Implementation sends a textDocument/implementation request to the LSP server.

View File

@@ -8,8 +8,8 @@ import (
"io"
"strings"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/logging"
"github.com/sst/opencode/internal/config"
"log/slog"
)
// Write writes an LSP message to the given writer
@@ -21,7 +21,7 @@ func WriteMessage(w io.Writer, msg *Message) error {
cnf := config.Get()
if cnf.DebugLSP {
logging.Debug("Sending message to server", "method", msg.Method, "id", msg.ID)
slog.Debug("Sending message to server", "method", msg.Method, "id", msg.ID)
}
_, err = fmt.Fprintf(w, "Content-Length: %d\r\n\r\n", len(data))
@@ -50,7 +50,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
line = strings.TrimSpace(line)
if cnf.DebugLSP {
logging.Debug("Received header", "line", line)
slog.Debug("Received header", "line", line)
}
if line == "" {
@@ -66,7 +66,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
}
if cnf.DebugLSP {
logging.Debug("Content-Length", "length", contentLength)
slog.Debug("Content-Length", "length", contentLength)
}
// Read content
@@ -77,7 +77,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
}
if cnf.DebugLSP {
logging.Debug("Received content", "content", string(content))
slog.Debug("Received content", "content", string(content))
}
// Parse message
@@ -96,7 +96,7 @@ func (c *Client) handleMessages() {
msg, err := ReadMessage(c.stdout)
if err != nil {
if cnf.DebugLSP {
logging.Error("Error reading message", "error", err)
slog.Error("Error reading message", "error", err)
}
return
}
@@ -104,7 +104,7 @@ func (c *Client) handleMessages() {
// Handle server->client request (has both Method and ID)
if msg.Method != "" && msg.ID != 0 {
if cnf.DebugLSP {
logging.Debug("Received request from server", "method", msg.Method, "id", msg.ID)
slog.Debug("Received request from server", "method", msg.Method, "id", msg.ID)
}
response := &Message{
@@ -144,7 +144,7 @@ func (c *Client) handleMessages() {
// Send response back to server
if err := WriteMessage(c.stdin, response); err != nil {
logging.Error("Error sending response to server", "error", err)
slog.Error("Error sending response to server", "error", err)
}
continue
@@ -158,11 +158,11 @@ func (c *Client) handleMessages() {
if ok {
if cnf.DebugLSP {
logging.Debug("Handling notification", "method", msg.Method)
slog.Debug("Handling notification", "method", msg.Method)
}
go handler(msg.Params)
} else if cnf.DebugLSP {
logging.Debug("No handler for notification", "method", msg.Method)
slog.Debug("No handler for notification", "method", msg.Method)
}
continue
}
@@ -175,12 +175,12 @@ func (c *Client) handleMessages() {
if ok {
if cnf.DebugLSP {
logging.Debug("Received response for request", "id", msg.ID)
slog.Debug("Received response for request", "id", msg.ID)
}
ch <- msg
close(ch)
} else if cnf.DebugLSP {
logging.Debug("No handler for response", "id", msg.ID)
slog.Debug("No handler for response", "id", msg.ID)
}
}
}
@@ -192,7 +192,7 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
id := c.nextID.Add(1)
if cnf.DebugLSP {
logging.Debug("Making call", "method", method, "id", id)
slog.Debug("Making call", "method", method, "id", id)
}
msg, err := NewRequest(id, method, params)
@@ -218,14 +218,14 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
}
if cnf.DebugLSP {
logging.Debug("Request sent", "method", method, "id", id)
slog.Debug("Request sent", "method", method, "id", id)
}
// Wait for response
resp := <-ch
if cnf.DebugLSP {
logging.Debug("Received response", "id", id)
slog.Debug("Received response", "id", id)
}
if resp.Error != nil {
@@ -251,7 +251,7 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
func (c *Client) Notify(ctx context.Context, method string, params any) error {
cnf := config.Get()
if cnf.DebugLSP {
logging.Debug("Sending notification", "method", method)
slog.Debug("Sending notification", "method", method)
}
msg, err := NewNotification(method, params)

View File

@@ -7,7 +7,7 @@ import (
"sort"
"strings"
"github.com/opencode-ai/opencode/internal/lsp/protocol"
"github.com/sst/opencode/internal/lsp/protocol"
)
func applyTextEdits(uri protocol.DocumentUri, edits []protocol.TextEdit) error {

View File

@@ -5,16 +5,17 @@ import (
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/bmatcuk/doublestar/v4"
"github.com/fsnotify/fsnotify"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/logging"
"github.com/opencode-ai/opencode/internal/lsp"
"github.com/opencode-ai/opencode/internal/lsp/protocol"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/lsp"
"github.com/sst/opencode/internal/lsp/protocol"
"log/slog"
)
// WorkspaceWatcher manages LSP file watching
@@ -45,7 +46,7 @@ func NewWorkspaceWatcher(client *lsp.Client) *WorkspaceWatcher {
func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watchers []protocol.FileSystemWatcher) {
cnf := config.Get()
logging.Debug("Adding file watcher registrations")
slog.Debug("Adding file watcher registrations")
w.registrationMu.Lock()
defer w.registrationMu.Unlock()
@@ -54,33 +55,33 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
// Print detailed registration information for debugging
if cnf.DebugLSP {
logging.Debug("Adding file watcher registrations",
slog.Debug("Adding file watcher registrations",
"id", id,
"watchers", len(watchers),
"total", len(w.registrations),
)
for i, watcher := range watchers {
logging.Debug("Registration", "index", i+1)
slog.Debug("Registration", "index", i+1)
// Log the GlobPattern
switch v := watcher.GlobPattern.Value.(type) {
case string:
logging.Debug("GlobPattern", "pattern", v)
slog.Debug("GlobPattern", "pattern", v)
case protocol.RelativePattern:
logging.Debug("GlobPattern", "pattern", v.Pattern)
slog.Debug("GlobPattern", "pattern", v.Pattern)
// Log BaseURI details
switch u := v.BaseURI.Value.(type) {
case string:
logging.Debug("BaseURI", "baseURI", u)
slog.Debug("BaseURI", "baseURI", u)
case protocol.DocumentUri:
logging.Debug("BaseURI", "baseURI", u)
slog.Debug("BaseURI", "baseURI", u)
default:
logging.Debug("BaseURI", "baseURI", u)
slog.Debug("BaseURI", "baseURI", u)
}
default:
logging.Debug("GlobPattern", "unknown type", fmt.Sprintf("%T", v))
slog.Debug("GlobPattern", "unknown type", fmt.Sprintf("%T", v))
}
// Log WatchKind
@@ -89,13 +90,13 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
watchKind = *watcher.Kind
}
logging.Debug("WatchKind", "kind", watchKind)
slog.Debug("WatchKind", "kind", watchKind)
}
}
// Determine server type for specialized handling
serverName := getServerNameFromContext(ctx)
logging.Debug("Server type detected", "serverName", serverName)
slog.Debug("Server type detected", "serverName", serverName)
// Check if this server has sent file watchers
hasFileWatchers := len(watchers) > 0
@@ -123,7 +124,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
filesOpened += highPriorityFilesOpened
if cnf.DebugLSP {
logging.Debug("Opened high-priority files",
slog.Debug("Opened high-priority files",
"count", highPriorityFilesOpened,
"serverName", serverName)
}
@@ -131,7 +132,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
// If we've already opened enough high-priority files, we might not need more
if filesOpened >= maxFilesToOpen {
if cnf.DebugLSP {
logging.Debug("Reached file limit with high-priority files",
slog.Debug("Reached file limit with high-priority files",
"filesOpened", filesOpened,
"maxFiles", maxFilesToOpen)
}
@@ -149,7 +150,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
if d.IsDir() {
if path != w.workspacePath && shouldExcludeDir(path) {
if cnf.DebugLSP {
logging.Debug("Skipping excluded directory", "path", path)
slog.Debug("Skipping excluded directory", "path", path)
}
return filepath.SkipDir
}
@@ -177,7 +178,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
elapsedTime := time.Since(startTime)
if cnf.DebugLSP {
logging.Debug("Limited workspace scan complete",
slog.Debug("Limited workspace scan complete",
"filesOpened", filesOpened,
"maxFiles", maxFilesToOpen,
"elapsedTime", elapsedTime.Seconds(),
@@ -186,11 +187,11 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
}
if err != nil && cnf.DebugLSP {
logging.Debug("Error scanning workspace for files to open", "error", err)
slog.Debug("Error scanning workspace for files to open", "error", err)
}
}()
} else if cnf.DebugLSP {
logging.Debug("Using on-demand file loading for server", "server", serverName)
slog.Debug("Using on-demand file loading for server", "server", serverName)
}
}
@@ -263,7 +264,7 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName
matches, err := doublestar.Glob(os.DirFS(w.workspacePath), pattern)
if err != nil {
if cnf.DebugLSP {
logging.Debug("Error finding high-priority files", "pattern", pattern, "error", err)
slog.Debug("Error finding high-priority files", "pattern", pattern, "error", err)
}
continue
}
@@ -281,12 +282,12 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName
// Open the file
if err := w.client.OpenFile(ctx, fullPath); err != nil {
if cnf.DebugLSP {
logging.Debug("Error opening high-priority file", "path", fullPath, "error", err)
slog.Debug("Error opening high-priority file", "path", fullPath, "error", err)
}
} else {
filesOpened++
if cnf.DebugLSP {
logging.Debug("Opened high-priority file", "path", fullPath)
slog.Debug("Opened high-priority file", "path", fullPath)
}
}
@@ -318,7 +319,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
}
serverName := getServerNameFromContext(ctx)
logging.Debug("Starting workspace watcher", "workspacePath", workspacePath, "serverName", serverName)
slog.Debug("Starting workspace watcher", "workspacePath", workspacePath, "serverName", serverName)
// Register handler for file watcher registrations from the server
lsp.RegisterFileWatchHandler(func(id string, watchers []protocol.FileSystemWatcher) {
@@ -327,7 +328,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
watcher, err := fsnotify.NewWatcher()
if err != nil {
logging.Error("Error creating watcher", "error", err)
slog.Error("Error creating watcher", "error", err)
}
defer watcher.Close()
@@ -341,7 +342,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
if d.IsDir() && path != workspacePath {
if shouldExcludeDir(path) {
if cnf.DebugLSP {
logging.Debug("Skipping excluded directory", "path", path)
slog.Debug("Skipping excluded directory", "path", path)
}
return filepath.SkipDir
}
@@ -351,14 +352,14 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
if d.IsDir() {
err = watcher.Add(path)
if err != nil {
logging.Error("Error watching path", "path", path, "error", err)
slog.Error("Error watching path", "path", path, "error", err)
}
}
return nil
})
if err != nil {
logging.Error("Error walking workspace", "error", err)
slog.Error("Error walking workspace", "error", err)
}
// Event loop
@@ -375,27 +376,37 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
// Add new directories to the watcher
if event.Op&fsnotify.Create != 0 {
if info, err := os.Stat(event.Name); err == nil {
if info.IsDir() {
// Skip excluded directories
if !shouldExcludeDir(event.Name) {
if err := watcher.Add(event.Name); err != nil {
logging.Error("Error adding directory to watcher", "path", event.Name, "error", err)
}
}
} else {
// For newly created files
if !shouldExcludeFile(event.Name) {
w.openMatchingFile(ctx, event.Name)
// Check if the file/directory still exists before processing
info, err := os.Stat(event.Name)
if err != nil {
if os.IsNotExist(err) {
// File was deleted between event and processing - ignore
slog.Debug("File deleted between create event and stat", "path", event.Name)
continue
}
slog.Error("Error getting file info", "path", event.Name, "error", err)
continue
}
if info.IsDir() {
// Skip excluded directories
if !shouldExcludeDir(event.Name) {
if err := watcher.Add(event.Name); err != nil {
slog.Error("Error adding directory to watcher", "path", event.Name, "error", err)
}
}
} else {
// For newly created files
if !shouldExcludeFile(event.Name) {
w.openMatchingFile(ctx, event.Name)
}
}
}
// Debug logging
if cnf.DebugLSP {
matched, kind := w.isPathWatched(event.Name)
logging.Debug("File event",
slog.Debug("File event",
"path", event.Name,
"operation", event.Op.String(),
"watched", matched,
@@ -416,7 +427,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
// Just send the notification if needed
info, err := os.Stat(event.Name)
if err != nil {
logging.Error("Error getting file info", "path", event.Name, "error", err)
slog.Error("Error getting file info", "path", event.Name, "error", err)
return
}
if !info.IsDir() && watchKind&protocol.WatchCreate != 0 {
@@ -444,7 +455,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
if !ok {
return
}
logging.Error("Error watching file", "error", err)
slog.Error("Error watching file", "error", err)
}
}
}
@@ -569,7 +580,7 @@ func matchesSimpleGlob(pattern, path string) bool {
// Fall back to simple matching for simpler patterns
matched, err := filepath.Match(pattern, path)
if err != nil {
logging.Error("Error matching pattern", "pattern", pattern, "path", path, "error", err)
slog.Error("Error matching pattern", "pattern", pattern, "path", path, "error", err)
return false
}
@@ -580,7 +591,7 @@ func matchesSimpleGlob(pattern, path string) bool {
func (w *WorkspaceWatcher) matchesPattern(path string, pattern protocol.GlobPattern) bool {
patternInfo, err := pattern.AsPattern()
if err != nil {
logging.Error("Error parsing pattern", "pattern", pattern, "error", err)
slog.Error("Error parsing pattern", "pattern", pattern, "error", err)
return false
}
@@ -605,7 +616,7 @@ func (w *WorkspaceWatcher) matchesPattern(path string, pattern protocol.GlobPatt
// Make path relative to basePath for matching
relPath, err := filepath.Rel(basePath, path)
if err != nil {
logging.Error("Error getting relative path", "path", path, "basePath", basePath, "error", err)
slog.Error("Error getting relative path", "path", path, "basePath", basePath, "error", err)
return false
}
relPath = filepath.ToSlash(relPath)
@@ -643,19 +654,55 @@ func (w *WorkspaceWatcher) debounceHandleFileEvent(ctx context.Context, uri stri
func (w *WorkspaceWatcher) handleFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) {
// If the file is open and it's a change event, use didChange notification
filePath := uri[7:] // Remove "file://" prefix
if changeType == protocol.FileChangeType(protocol.Deleted) {
// Always clear diagnostics for deleted files
w.client.ClearDiagnosticsForURI(protocol.DocumentUri(uri))
} else if changeType == protocol.FileChangeType(protocol.Changed) && w.client.IsFileOpen(filePath) {
err := w.client.NotifyChange(ctx, filePath)
if err != nil {
logging.Error("Error notifying change", "error", err)
// If the file was open, close it in the LSP client
if w.client.IsFileOpen(filePath) {
if err := w.client.CloseFile(ctx, filePath); err != nil {
slog.Debug("Error closing deleted file in LSP client", "file", filePath, "error", err)
// Continue anyway - the file is gone
}
}
} else if changeType == protocol.FileChangeType(protocol.Changed) {
// For changed files, verify the file still exists before notifying
if _, err := os.Stat(filePath); err != nil {
if os.IsNotExist(err) {
// File was deleted between the event and now - treat as delete
slog.Debug("File deleted between change event and processing", "file", filePath)
w.handleFileEvent(ctx, uri, protocol.FileChangeType(protocol.Deleted))
return
}
slog.Error("Error getting file info", "path", filePath, "error", err)
return
}
// File exists and is open, notify change
if w.client.IsFileOpen(filePath) {
err := w.client.NotifyChange(ctx, filePath)
if err != nil {
slog.Error("Error notifying change", "error", err)
}
return
}
} else if changeType == protocol.FileChangeType(protocol.Created) {
// For created files, verify the file still exists before notifying
if _, err := os.Stat(filePath); err != nil {
if os.IsNotExist(err) {
// File was deleted between the event and now - ignore
slog.Debug("File deleted between create event and processing", "file", filePath)
return
}
slog.Error("Error getting file info", "path", filePath, "error", err)
return
}
return
}
// Notify LSP server about the file event using didChangeWatchedFiles
if err := w.notifyFileEvent(ctx, uri, changeType); err != nil {
logging.Error("Error notifying LSP server about file event", "error", err)
slog.Error("Error notifying LSP server about file event", "error", err)
}
}
@@ -663,7 +710,7 @@ func (w *WorkspaceWatcher) handleFileEvent(ctx context.Context, uri string, chan
func (w *WorkspaceWatcher) notifyFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) error {
cnf := config.Get()
if cnf.DebugLSP {
logging.Debug("Notifying file event",
slog.Debug("Notifying file event",
"uri", uri,
"changeType", changeType,
)
@@ -828,6 +875,11 @@ func shouldExcludeFile(filePath string) bool {
return true
}
// Skip numeric temporary files (often created by editors)
if _, err := strconv.Atoi(fileName); err == nil {
return true
}
// Check file size
info, err := os.Stat(filePath)
if err != nil {
@@ -838,7 +890,7 @@ func shouldExcludeFile(filePath string) bool {
// Skip large files
if info.Size() > maxFileSize {
if cnf.DebugLSP {
logging.Debug("Skipping large file",
slog.Debug("Skipping large file",
"path", filePath,
"size", info.Size(),
"maxSize", maxFileSize,
@@ -856,9 +908,19 @@ func shouldExcludeFile(filePath string) bool {
// openMatchingFile opens a file if it matches any of the registered patterns
func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
cnf := config.Get()
// Skip directories
// Skip directories and verify file exists
info, err := os.Stat(path)
if err != nil || info.IsDir() {
if err != nil {
if os.IsNotExist(err) {
// File was deleted between event and processing - ignore
slog.Debug("File deleted between event and openMatchingFile", "path", path)
return
}
slog.Error("Error getting file info", "path", path, "error", err)
return
}
if info.IsDir() {
return
}
@@ -876,10 +938,10 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
// This helps with project initialization for certain language servers
if isHighPriorityFile(path, serverName) {
if cnf.DebugLSP {
logging.Debug("Opening high-priority file", "path", path, "serverName", serverName)
slog.Debug("Opening high-priority file", "path", path, "serverName", serverName)
}
if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP {
logging.Error("Error opening high-priority file", "path", path, "error", err)
slog.Error("Error opening high-priority file", "path", path, "error", err)
}
return
}
@@ -891,7 +953,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
// Check file size - for preloading we're more conservative
if info.Size() > (1 * 1024 * 1024) { // 1MB limit for preloaded files
if cnf.DebugLSP {
logging.Debug("Skipping large file for preloading", "path", path, "size", info.Size())
slog.Debug("Skipping large file for preloading", "path", path, "size", info.Size())
}
return
}
@@ -923,7 +985,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
if shouldOpen {
// Don't need to check if it's already open - the client.OpenFile handles that
if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP {
logging.Error("Error opening file", "path", path, "error", err)
slog.Error("Error opening file", "path", path, "error", err)
}
}
}

View File

@@ -0,0 +1,8 @@
package message
type Attachment struct {
FilePath string
FileName string
MimeType string
Content []byte
}

View File

@@ -5,7 +5,7 @@ import (
"slices"
"time"
"github.com/opencode-ai/opencode/internal/llm/models"
"github.com/sst/opencode/internal/llm/models"
)
type MessageRole string
@@ -48,7 +48,10 @@ type TextContent struct {
Text string `json:"text"`
}
func (tc TextContent) String() string {
func (tc *TextContent) String() string {
if tc == nil {
return ""
}
return tc.Text
}
@@ -66,13 +69,17 @@ func (iuc ImageURLContent) String() string {
func (ImageURLContent) isPart() {}
type BinaryContent struct {
Path string
MIMEType string
Data []byte
}
func (bc BinaryContent) String() string {
func (bc BinaryContent) String(provider models.ModelProvider) string {
base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
return "data:" + bc.MIMEType + ";base64," + base64Encoded
if provider == models.ProviderOpenAI {
return "data:" + bc.MIMEType + ";base64," + base64Encoded
}
return base64Encoded
}
func (BinaryContent) isPart() {}
@@ -98,30 +105,24 @@ type ToolResult struct {
func (ToolResult) isPart() {}
type Finish struct {
Reason FinishReason `json:"reason"`
Time time.Time `json:"time"`
}
type DBFinish struct {
Reason FinishReason `json:"reason"`
Time int64 `json:"time"`
}
func (Finish) isPart() {}
type Message struct {
ID string
Role MessageRole
SessionID string
Parts []ContentPart
Model models.ModelID
CreatedAt int64
UpdatedAt int64
}
func (m *Message) Content() TextContent {
func (m *Message) Content() *TextContent {
for _, part := range m.Parts {
if c, ok := part.(TextContent); ok {
return c
return &c
}
}
return TextContent{}
return nil
}
func (m *Message) ReasoningContent() ReasoningContent {
@@ -312,7 +313,7 @@ func (m *Message) AddFinish(reason FinishReason) {
break
}
}
m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix()})
m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now()})
}
func (m *Message) AddImageURL(url, detail string) {

View File

@@ -5,12 +5,31 @@ import (
"database/sql"
"encoding/json"
"fmt"
"log/slog"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/opencode-ai/opencode/internal/db"
"github.com/opencode-ai/opencode/internal/llm/models"
"github.com/opencode-ai/opencode/internal/pubsub"
"github.com/sst/opencode/internal/db"
"github.com/sst/opencode/internal/llm/models"
"github.com/sst/opencode/internal/pubsub"
)
type Message struct {
ID string
Role MessageRole
SessionID string
Parts []ContentPart
Model models.ModelID
CreatedAt time.Time
UpdatedAt time.Time
}
const (
EventMessageCreated pubsub.EventType = "message_created"
EventMessageUpdated pubsub.EventType = "message_updated"
EventMessageDeleted pubsub.EventType = "message_deleted"
)
type CreateMessageParams struct {
@@ -20,145 +39,317 @@ type CreateMessageParams struct {
}
type Service interface {
pubsub.Suscriber[Message]
pubsub.Subscriber[Message]
Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
Update(ctx context.Context, message Message) error
Update(ctx context.Context, message Message) (Message, error)
Get(ctx context.Context, id string) (Message, error)
List(ctx context.Context, sessionID string) ([]Message, error)
ListAfter(ctx context.Context, sessionID string, timestamp time.Time) ([]Message, error)
Delete(ctx context.Context, id string) error
DeleteSessionMessages(ctx context.Context, sessionID string) error
}
type service struct {
*pubsub.Broker[Message]
q db.Querier
db *db.Queries
broker *pubsub.Broker[Message]
mu sync.RWMutex
}
func NewService(q db.Querier) Service {
return &service{
Broker: pubsub.NewBroker[Message](),
q: q,
}
}
var globalMessageService *service
func (s *service) Delete(ctx context.Context, id string) error {
message, err := s.Get(ctx, id)
if err != nil {
return err
func InitService(dbConn *sql.DB) error {
if globalMessageService != nil {
return fmt.Errorf("message service already initialized")
}
err = s.q.DeleteMessage(ctx, message.ID)
if err != nil {
return err
queries := db.New(dbConn)
broker := pubsub.NewBroker[Message]()
globalMessageService = &service{
db: queries,
broker: broker,
}
s.Publish(pubsub.DeletedEvent, message)
return nil
}
func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
if params.Role != Assistant {
params.Parts = append(params.Parts, Finish{
Reason: "stop",
})
func GetService() Service {
if globalMessageService == nil {
panic("message service not initialized. Call message.InitService() first.")
}
partsJSON, err := marshallParts(params.Parts)
if err != nil {
return Message{}, err
return globalMessageService
}
func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
s.mu.Lock()
defer s.mu.Unlock()
isFinished := false
for _, p := range params.Parts {
if _, ok := p.(Finish); ok {
isFinished = true
break
}
}
if params.Role == User && !isFinished {
params.Parts = append(params.Parts, Finish{Reason: FinishReasonEndTurn, Time: time.Now()})
}
dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
partsJSON, err := marshallParts(params.Parts)
if err != nil {
return Message{}, fmt.Errorf("failed to marshal message parts: %w", err)
}
dbMsgParams := db.CreateMessageParams{
ID: uuid.New().String(),
SessionID: sessionID,
Role: string(params.Role),
Parts: string(partsJSON),
Model: sql.NullString{String: string(params.Model), Valid: true},
})
if err != nil {
return Message{}, err
Model: sql.NullString{String: string(params.Model), Valid: params.Model != ""},
}
dbMessage, err := s.db.CreateMessage(ctx, dbMsgParams)
if err != nil {
return Message{}, fmt.Errorf("db.CreateMessage: %w", err)
}
message, err := s.fromDBItem(dbMessage)
if err != nil {
return Message{}, err
return Message{}, fmt.Errorf("failed to convert DB message: %w", err)
}
s.Publish(pubsub.CreatedEvent, message)
s.broker.Publish(EventMessageCreated, message)
return message, nil
}
func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
messages, err := s.List(ctx, sessionID)
if err != nil {
return err
func (s *service) Update(ctx context.Context, message Message) (Message, error) {
s.mu.Lock()
defer s.mu.Unlock()
if message.ID == "" {
return Message{}, fmt.Errorf("cannot update message with empty ID")
}
for _, message := range messages {
if message.SessionID == sessionID {
err = s.Delete(ctx, message.ID)
if err != nil {
return err
}
partsJSON, err := marshallParts(message.Parts)
if err != nil {
return Message{}, fmt.Errorf("failed to marshal message parts for update: %w", err)
}
var dbFinishedAt sql.NullString
finishPart := message.FinishPart()
if finishPart != nil && !finishPart.Time.IsZero() {
dbFinishedAt = sql.NullString{
String: finishPart.Time.UTC().Format(time.RFC3339Nano),
Valid: true,
}
}
return nil
}
func (s *service) Update(ctx context.Context, message Message) error {
parts, err := marshallParts(message.Parts)
if err != nil {
return err
}
finishedAt := sql.NullInt64{}
if f := message.FinishPart(); f != nil {
finishedAt.Int64 = f.Time
finishedAt.Valid = true
}
err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
// UpdatedAt is handled by the DB trigger (strftime('%s', 'now'))
err = s.db.UpdateMessage(ctx, db.UpdateMessageParams{
ID: message.ID,
Parts: string(parts),
FinishedAt: finishedAt,
Parts: string(partsJSON),
FinishedAt: dbFinishedAt,
})
if err != nil {
return err
return Message{}, fmt.Errorf("db.UpdateMessage: %w", err)
}
message.UpdatedAt = time.Now().Unix()
s.Publish(pubsub.UpdatedEvent, message)
return nil
dbUpdatedMessage, err := s.db.GetMessage(ctx, message.ID)
if err != nil {
return Message{}, fmt.Errorf("failed to fetch message after update: %w", err)
}
updatedMessage, err := s.fromDBItem(dbUpdatedMessage)
if err != nil {
return Message{}, fmt.Errorf("failed to convert updated DB message: %w", err)
}
s.broker.Publish(EventMessageUpdated, updatedMessage)
return updatedMessage, nil
}
func (s *service) Get(ctx context.Context, id string) (Message, error) {
dbMessage, err := s.q.GetMessage(ctx, id)
s.mu.RLock()
defer s.mu.RUnlock()
dbMessage, err := s.db.GetMessage(ctx, id)
if err != nil {
return Message{}, err
if err == sql.ErrNoRows {
return Message{}, fmt.Errorf("message with ID '%s' not found", id)
}
return Message{}, fmt.Errorf("db.GetMessage: %w", err)
}
return s.fromDBItem(dbMessage)
}
func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
s.mu.RLock()
defer s.mu.RUnlock()
dbMessages, err := s.db.ListMessagesBySession(ctx, sessionID)
if err != nil {
return nil, fmt.Errorf("db.ListMessagesBySession: %w", err)
}
messages := make([]Message, len(dbMessages))
for i, dbMsg := range dbMessages {
msg, convErr := s.fromDBItem(dbMsg)
if convErr != nil {
return nil, fmt.Errorf("failed to convert DB message at index %d: %w", i, convErr)
}
messages[i] = msg
}
return messages, nil
}
func (s *service) ListAfter(ctx context.Context, sessionID string, timestamp time.Time) ([]Message, error) {
s.mu.RLock()
defer s.mu.RUnlock()
dbMessages, err := s.db.ListMessagesBySessionAfter(ctx, db.ListMessagesBySessionAfterParams{
SessionID: sessionID,
CreatedAt: timestamp.Format(time.RFC3339Nano),
})
if err != nil {
return nil, fmt.Errorf("db.ListMessagesBySessionAfter: %w", err)
}
messages := make([]Message, len(dbMessages))
for i, dbMsg := range dbMessages {
msg, convErr := s.fromDBItem(dbMsg)
if convErr != nil {
return nil, fmt.Errorf("failed to convert DB message at index %d (ListAfter): %w", i, convErr)
}
messages[i] = msg
}
return messages, nil
}
func (s *service) Delete(ctx context.Context, id string) error {
s.mu.Lock()
messageToPublish, err := s.getServiceForPublish(ctx, id)
s.mu.Unlock()
if err != nil {
// If error was due to not found, it's not a critical failure for deletion intent
if strings.Contains(err.Error(), "not found") {
return nil // Or return the error if strictness is required
}
return err
}
s.mu.Lock()
defer s.mu.Unlock()
err = s.db.DeleteMessage(ctx, id)
if err != nil {
return fmt.Errorf("db.DeleteMessage: %w", err)
}
if messageToPublish != nil {
s.broker.Publish(EventMessageDeleted, *messageToPublish)
}
return nil
}
func (s *service) getServiceForPublish(ctx context.Context, id string) (*Message, error) {
dbMsg, err := s.db.GetMessage(ctx, id)
if err != nil {
return nil, err
}
messages := make([]Message, len(dbMessages))
for i, dbMessage := range dbMessages {
messages[i], err = s.fromDBItem(dbMessage)
if err != nil {
return nil, err
msg, convErr := s.fromDBItem(dbMsg)
if convErr != nil {
return nil, fmt.Errorf("failed to convert DB message for publishing: %w", convErr)
}
return &msg, nil
}
func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
s.mu.Lock()
defer s.mu.Unlock()
messagesToDelete, err := s.db.ListMessagesBySession(ctx, sessionID)
if err != nil {
return fmt.Errorf("failed to list messages for deletion: %w", err)
}
err = s.db.DeleteSessionMessages(ctx, sessionID)
if err != nil {
return fmt.Errorf("db.DeleteSessionMessages: %w", err)
}
for _, dbMsg := range messagesToDelete {
msg, convErr := s.fromDBItem(dbMsg)
if convErr == nil {
s.broker.Publish(EventMessageDeleted, msg)
} else {
slog.Error("Failed to convert DB message for delete event publishing", "id", dbMsg.ID, "error", convErr)
}
}
return messages, nil
return nil
}
func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[Message] {
return s.broker.Subscribe(ctx)
}
func (s *service) fromDBItem(item db.Message) (Message, error) {
parts, err := unmarshallParts([]byte(item.Parts))
if err != nil {
return Message{}, err
return Message{}, fmt.Errorf("unmarshallParts for message ID %s: %w. Raw parts: %s", item.ID, err, item.Parts)
}
return Message{
// Parse timestamps from ISO strings
createdAt, err := time.Parse(time.RFC3339Nano, item.CreatedAt)
if err != nil {
slog.Error("Failed to parse created_at", "value", item.CreatedAt, "error", err)
createdAt = time.Now() // Fallback
}
updatedAt, err := time.Parse(time.RFC3339Nano, item.UpdatedAt)
if err != nil {
slog.Error("Failed to parse created_at", "value", item.CreatedAt, "error", err)
updatedAt = time.Now() // Fallback
}
msg := Message{
ID: item.ID,
SessionID: item.SessionID,
Role: MessageRole(item.Role),
Parts: parts,
Model: models.ModelID(item.Model.String),
CreatedAt: item.CreatedAt,
UpdatedAt: item.UpdatedAt,
}, nil
CreatedAt: createdAt,
UpdatedAt: updatedAt,
}
return msg, nil
}
func Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
return GetService().Create(ctx, sessionID, params)
}
func Update(ctx context.Context, message Message) (Message, error) {
return GetService().Update(ctx, message)
}
func Get(ctx context.Context, id string) (Message, error) {
return GetService().Get(ctx, id)
}
func List(ctx context.Context, sessionID string) ([]Message, error) {
return GetService().List(ctx, sessionID)
}
func ListAfter(ctx context.Context, sessionID string, timestamp time.Time) ([]Message, error) {
return GetService().ListAfter(ctx, sessionID, timestamp)
}
func Delete(ctx context.Context, id string) error {
return GetService().Delete(ctx, id)
}
func DeleteSessionMessages(ctx context.Context, sessionID string) error {
return GetService().DeleteSessionMessages(ctx, sessionID)
}
func Subscribe(ctx context.Context) <-chan pubsub.Event[Message] {
return GetService().Subscribe(ctx)
}
type partType string
@@ -174,109 +365,139 @@ const (
)
type partWrapper struct {
Type partType `json:"type"`
Data ContentPart `json:"data"`
Type partType `json:"type"`
Data json.RawMessage `json:"data"`
}
func marshallParts(parts []ContentPart) ([]byte, error) {
wrappedParts := make([]partWrapper, len(parts))
wrappedParts := make([]json.RawMessage, len(parts))
for i, part := range parts {
var typ partType
var dataBytes []byte
var err error
switch part.(type) {
switch p := part.(type) {
case ReasoningContent:
typ = reasoningType
dataBytes, err = json.Marshal(p)
case TextContent:
typ = textType
dataBytes, err = json.Marshal(p)
case *TextContent:
typ = textType
dataBytes, err = json.Marshal(p)
case ImageURLContent:
typ = imageURLType
dataBytes, err = json.Marshal(p)
case BinaryContent:
typ = binaryType
dataBytes, err = json.Marshal(p)
case ToolCall:
typ = toolCallType
dataBytes, err = json.Marshal(p)
case ToolResult:
typ = toolResultType
dataBytes, err = json.Marshal(p)
case Finish:
typ = finishType
var dbFinish DBFinish
dbFinish.Reason = p.Reason
dbFinish.Time = p.Time.UnixMilli()
dataBytes, err = json.Marshal(dbFinish)
default:
return nil, fmt.Errorf("unknown part type: %T", part)
return nil, fmt.Errorf("unknown part type for marshalling: %T", part)
}
wrappedParts[i] = partWrapper{
Type: typ,
Data: part,
if err != nil {
return nil, fmt.Errorf("failed to marshal part data for type %s: %w", typ, err)
}
wrapper := struct {
Type partType `json:"type"`
Data json.RawMessage `json:"data"`
}{Type: typ, Data: dataBytes}
wrappedBytes, err := json.Marshal(wrapper)
if err != nil {
return nil, fmt.Errorf("failed to marshal part wrapper for type %s: %w", typ, err)
}
wrappedParts[i] = wrappedBytes
}
return json.Marshal(wrappedParts)
}
func unmarshallParts(data []byte) ([]ContentPart, error) {
temp := []json.RawMessage{}
if err := json.Unmarshal(data, &temp); err != nil {
return nil, err
var rawMessages []json.RawMessage
if err := json.Unmarshal(data, &rawMessages); err != nil {
return nil, fmt.Errorf("failed to unmarshal parts data as array: %w. Data: %s", err, string(data))
}
parts := make([]ContentPart, 0)
for _, rawPart := range temp {
var wrapper struct {
Type partType `json:"type"`
Data json.RawMessage `json:"data"`
}
parts := make([]ContentPart, 0, len(rawMessages))
for _, rawPart := range rawMessages {
var wrapper partWrapper
if err := json.Unmarshal(rawPart, &wrapper); err != nil {
return nil, err
// Fallback for old format where parts might be just TextContent string
var text string
if errText := json.Unmarshal(rawPart, &text); errText == nil {
parts = append(parts, TextContent{Text: text})
continue
}
return nil, fmt.Errorf("failed to unmarshal part wrapper: %w. Raw part: %s", err, string(rawPart))
}
switch wrapper.Type {
case reasoningType:
part := ReasoningContent{}
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
return nil, err
var p ReasoningContent
if err := json.Unmarshal(wrapper.Data, &p); err != nil {
return nil, fmt.Errorf("unmarshal ReasoningContent: %w. Data: %s", err, string(wrapper.Data))
}
parts = append(parts, part)
parts = append(parts, p)
case textType:
part := TextContent{}
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
return nil, err
var p TextContent
if err := json.Unmarshal(wrapper.Data, &p); err != nil {
return nil, fmt.Errorf("unmarshal TextContent: %w. Data: %s", err, string(wrapper.Data))
}
parts = append(parts, part)
parts = append(parts, p)
case imageURLType:
part := ImageURLContent{}
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
return nil, err
var p ImageURLContent
if err := json.Unmarshal(wrapper.Data, &p); err != nil {
return nil, fmt.Errorf("unmarshal ImageURLContent: %w. Data: %s", err, string(wrapper.Data))
}
parts = append(parts, p)
case binaryType:
part := BinaryContent{}
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
return nil, err
var p BinaryContent
if err := json.Unmarshal(wrapper.Data, &p); err != nil {
return nil, fmt.Errorf("unmarshal BinaryContent: %w. Data: %s", err, string(wrapper.Data))
}
parts = append(parts, part)
parts = append(parts, p)
case toolCallType:
part := ToolCall{}
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
return nil, err
var p ToolCall
if err := json.Unmarshal(wrapper.Data, &p); err != nil {
return nil, fmt.Errorf("unmarshal ToolCall: %w. Data: %s", err, string(wrapper.Data))
}
parts = append(parts, part)
parts = append(parts, p)
case toolResultType:
part := ToolResult{}
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
return nil, err
var p ToolResult
if err := json.Unmarshal(wrapper.Data, &p); err != nil {
return nil, fmt.Errorf("unmarshal ToolResult: %w. Data: %s", err, string(wrapper.Data))
}
parts = append(parts, part)
parts = append(parts, p)
case finishType:
part := Finish{}
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
return nil, err
var p DBFinish
if err := json.Unmarshal(wrapper.Data, &p); err != nil {
return nil, fmt.Errorf("unmarshal Finish: %w. Data: %s", err, string(wrapper.Data))
}
parts = append(parts, part)
parts = append(parts, Finish{Reason: FinishReason(p.Reason), Time: time.UnixMilli(p.Time)})
default:
return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
slog.Warn("Unknown part type during unmarshalling, attempting to parse as TextContent", "type", wrapper.Type, "data", string(wrapper.Data))
// Fallback: if type is unknown or empty, try to parse data as TextContent directly
var p TextContent
if err := json.Unmarshal(wrapper.Data, &p); err == nil {
parts = append(parts, p)
} else {
// If that also fails, log it but continue if possible, or return error
slog.Error("Failed to unmarshal unknown part type and fallback to TextContent failed", "type", wrapper.Type, "data", string(wrapper.Data), "error", err)
// Depending on strictness, you might return an error here:
// return nil, fmt.Errorf("unknown part type '%s' and failed fallback: %w", wrapper.Type, err)
}
}
}
return parts, nil
}

View File

@@ -1,14 +1,18 @@
package permission
import (
"context"
"errors"
"fmt"
"path/filepath"
"slices"
"strings"
"sync"
"log/slog"
"github.com/google/uuid"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/pubsub"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/pubsub"
)
var ErrorPermissionDenied = errors.New("permission denied")
@@ -32,56 +36,141 @@ type PermissionRequest struct {
Path string `json:"path"`
}
type PermissionResponse struct {
Request PermissionRequest
Granted bool
}
const (
EventPermissionRequested pubsub.EventType = "permission_requested"
EventPermissionGranted pubsub.EventType = "permission_granted"
EventPermissionDenied pubsub.EventType = "permission_denied"
EventPermissionPersisted pubsub.EventType = "permission_persisted"
)
type Service interface {
pubsub.Suscriber[PermissionRequest]
GrantPersistant(permission PermissionRequest)
Grant(permission PermissionRequest)
Deny(permission PermissionRequest)
Request(opts CreatePermissionRequest) bool
AutoApproveSession(sessionID string)
pubsub.Subscriber[PermissionRequest]
SubscribeToResponseEvents(ctx context.Context) <-chan pubsub.Event[PermissionResponse]
GrantPersistant(ctx context.Context, permission PermissionRequest)
Grant(ctx context.Context, permission PermissionRequest)
Deny(ctx context.Context, permission PermissionRequest)
Request(ctx context.Context, opts CreatePermissionRequest) bool
AutoApproveSession(ctx context.Context, sessionID string)
IsAutoApproved(ctx context.Context, sessionID string) bool
}
type permissionService struct {
*pubsub.Broker[PermissionRequest]
broker *pubsub.Broker[PermissionRequest]
responseBroker *pubsub.Broker[PermissionResponse]
sessionPermissions []PermissionRequest
sessionPermissions map[string][]PermissionRequest
pendingRequests sync.Map
autoApproveSessions []string
autoApproveSessions map[string]bool
mu sync.RWMutex
}
func (s *permissionService) GrantPersistant(permission PermissionRequest) {
var globalPermissionService *permissionService
func InitService() error {
if globalPermissionService != nil {
return fmt.Errorf("permission service already initialized")
}
globalPermissionService = &permissionService{
broker: pubsub.NewBroker[PermissionRequest](),
responseBroker: pubsub.NewBroker[PermissionResponse](),
sessionPermissions: make(map[string][]PermissionRequest),
autoApproveSessions: make(map[string]bool),
}
return nil
}
func GetService() *permissionService {
if globalPermissionService == nil {
panic("permission service not initialized. Call permission.InitService() first.")
}
return globalPermissionService
}
func (s *permissionService) GrantPersistant(ctx context.Context, permission PermissionRequest) {
s.mu.Lock()
s.sessionPermissions[permission.SessionID] = append(s.sessionPermissions[permission.SessionID], permission)
s.mu.Unlock()
respCh, ok := s.pendingRequests.Load(permission.ID)
if ok {
respCh.(chan bool) <- true
select {
case respCh.(chan bool) <- true:
case <-ctx.Done():
slog.Warn("Context cancelled while sending grant persistent response", "request_id", permission.ID)
}
}
s.sessionPermissions = append(s.sessionPermissions, permission)
s.responseBroker.Publish(EventPermissionPersisted, PermissionResponse{Request: permission, Granted: true})
}
func (s *permissionService) Grant(permission PermissionRequest) {
func (s *permissionService) Grant(ctx context.Context, permission PermissionRequest) {
respCh, ok := s.pendingRequests.Load(permission.ID)
if ok {
respCh.(chan bool) <- true
select {
case respCh.(chan bool) <- true:
case <-ctx.Done():
slog.Warn("Context cancelled while sending grant response", "request_id", permission.ID)
}
}
s.responseBroker.Publish(EventPermissionGranted, PermissionResponse{Request: permission, Granted: true})
}
func (s *permissionService) Deny(permission PermissionRequest) {
func (s *permissionService) Deny(ctx context.Context, permission PermissionRequest) {
respCh, ok := s.pendingRequests.Load(permission.ID)
if ok {
respCh.(chan bool) <- false
select {
case respCh.(chan bool) <- false:
case <-ctx.Done():
slog.Warn("Context cancelled while sending deny response", "request_id", permission.ID)
}
}
s.responseBroker.Publish(EventPermissionDenied, PermissionResponse{Request: permission, Granted: false})
}
func (s *permissionService) Request(opts CreatePermissionRequest) bool {
if slices.Contains(s.autoApproveSessions, opts.SessionID) {
func (s *permissionService) Request(ctx context.Context, opts CreatePermissionRequest) bool {
s.mu.RLock()
if s.autoApproveSessions[opts.SessionID] {
s.mu.RUnlock()
return true
}
dir := filepath.Dir(opts.Path)
if dir == "." {
dir = config.WorkingDirectory()
requestPath := opts.Path
if !filepath.IsAbs(requestPath) {
requestPath = filepath.Join(config.WorkingDirectory(), requestPath)
}
permission := PermissionRequest{
requestPath = filepath.Clean(requestPath)
if permissions, ok := s.sessionPermissions[opts.SessionID]; ok {
for _, p := range permissions {
storedPath := p.Path
if !filepath.IsAbs(storedPath) {
storedPath = filepath.Join(config.WorkingDirectory(), storedPath)
}
storedPath = filepath.Clean(storedPath)
if p.ToolName == opts.ToolName && p.Action == opts.Action &&
(requestPath == storedPath || strings.HasPrefix(requestPath, storedPath+string(filepath.Separator))) {
s.mu.RUnlock()
return true
}
}
}
s.mu.RUnlock()
normalizedPath := opts.Path
if !filepath.IsAbs(normalizedPath) {
normalizedPath = filepath.Join(config.WorkingDirectory(), normalizedPath)
}
normalizedPath = filepath.Clean(normalizedPath)
permissionReq := PermissionRequest{
ID: uuid.New().String(),
Path: dir,
Path: normalizedPath,
SessionID: opts.SessionID,
ToolName: opts.ToolName,
Description: opts.Description,
@@ -89,31 +178,69 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
Params: opts.Params,
}
for _, p := range s.sessionPermissions {
if p.ToolName == permission.ToolName && p.Action == permission.Action && p.SessionID == permission.SessionID && p.Path == permission.Path {
return true
}
}
respCh := make(chan bool, 1)
s.pendingRequests.Store(permissionReq.ID, respCh)
defer s.pendingRequests.Delete(permissionReq.ID)
s.pendingRequests.Store(permission.ID, respCh)
defer s.pendingRequests.Delete(permission.ID)
s.broker.Publish(EventPermissionRequested, permissionReq)
s.Publish(pubsub.CreatedEvent, permission)
// Wait for the response with a timeout
resp := <-respCh
return resp
}
func (s *permissionService) AutoApproveSession(sessionID string) {
s.autoApproveSessions = append(s.autoApproveSessions, sessionID)
}
func NewPermissionService() Service {
return &permissionService{
Broker: pubsub.NewBroker[PermissionRequest](),
sessionPermissions: make([]PermissionRequest, 0),
select {
case resp := <-respCh:
return resp
case <-ctx.Done():
slog.Warn("Permission request timed out or context cancelled", "request_id", permissionReq.ID, "tool", opts.ToolName)
return false
}
}
// AutoApproveSession marks every future request in the given session as
// approved. The ctx parameter is accepted for signature symmetry with the
// other service methods; it is not consulted here.
func (s *permissionService) AutoApproveSession(ctx context.Context, sessionID string) {
	s.mu.Lock()
	s.autoApproveSessions[sessionID] = true
	s.mu.Unlock()
}
// IsAutoApproved reports whether the given session has been marked for
// automatic approval. The ctx parameter is unused; it exists for signature
// symmetry with the other service methods.
func (s *permissionService) IsAutoApproved(ctx context.Context, sessionID string) bool {
	s.mu.RLock()
	approved := s.autoApproveSessions[sessionID]
	s.mu.RUnlock()
	return approved
}
// Subscribe streams incoming permission requests (e.g. for a UI prompt).
// The subscription ends when ctx is cancelled.
func (s *permissionService) Subscribe(ctx context.Context) <-chan pubsub.Event[PermissionRequest] {
	return s.broker.Subscribe(ctx)
}

// SubscribeToResponseEvents streams request outcomes (granted, denied,
// persisted). The subscription ends when ctx is cancelled.
func (s *permissionService) SubscribeToResponseEvents(ctx context.Context) <-chan pubsub.Event[PermissionResponse] {
	return s.responseBroker.Subscribe(ctx)
}
// The functions below are package-level convenience wrappers that delegate
// to the global permission service (see GetService).

// GrantPersistant grants the request and persists it for the session.
func GrantPersistant(ctx context.Context, permission PermissionRequest) {
	GetService().GrantPersistant(ctx, permission)
}

// Grant approves the request for this call only.
func Grant(ctx context.Context, permission PermissionRequest) {
	GetService().Grant(ctx, permission)
}

// Deny rejects the request.
func Deny(ctx context.Context, permission PermissionRequest) {
	GetService().Deny(ctx, permission)
}

// Request blocks until the permission request is resolved or ctx is done.
func Request(ctx context.Context, opts CreatePermissionRequest) bool {
	return GetService().Request(ctx, opts)
}

// AutoApproveSession approves all future requests in the session.
func AutoApproveSession(ctx context.Context, sessionID string) {
	GetService().AutoApproveSession(ctx, sessionID)
}

// IsAutoApproved reports whether the session is auto-approved.
func IsAutoApproved(ctx context.Context, sessionID string) bool {
	return GetService().IsAutoApproved(ctx, sessionID)
}

// SubscribeToRequests streams incoming permission requests.
func SubscribeToRequests(ctx context.Context) <-chan pubsub.Event[PermissionRequest] {
	return GetService().Subscribe(ctx)
}

// SubscribeToResponses streams request outcomes.
func SubscribeToResponses(ctx context.Context) <-chan pubsub.Event[PermissionResponse] {
	return GetService().SubscribeToResponseEvents(ctx)
}

View File

@@ -2,115 +2,112 @@ package pubsub
import (
"context"
"fmt"
"log/slog"
"sync"
"time"
)
const bufferSize = 64
const defaultChannelBufferSize = 100
type Broker[T any] struct {
subs map[chan Event[T]]struct{}
mu sync.RWMutex
done chan struct{}
subCount int
maxEvents int
subs map[chan Event[T]]context.CancelFunc
mu sync.RWMutex
isClosed bool
}
func NewBroker[T any]() *Broker[T] {
return NewBrokerWithOptions[T](bufferSize, 1000)
}
func NewBrokerWithOptions[T any](channelBufferSize, maxEvents int) *Broker[T] {
b := &Broker[T]{
subs: make(map[chan Event[T]]struct{}),
done: make(chan struct{}),
subCount: 0,
maxEvents: maxEvents,
return &Broker[T]{
subs: make(map[chan Event[T]]context.CancelFunc),
}
return b
}
func (b *Broker[T]) Shutdown() {
select {
case <-b.done: // Already closed
return
default:
close(b.done)
}
b.mu.Lock()
defer b.mu.Unlock()
for ch := range b.subs {
delete(b.subs, ch)
close(ch)
if b.isClosed {
b.mu.Unlock()
return
}
b.isClosed = true
b.subCount = 0
for ch, cancel := range b.subs {
cancel()
close(ch)
delete(b.subs, ch)
}
b.mu.Unlock()
slog.Debug("PubSub broker shut down", "type", fmt.Sprintf("%T", *new(T)))
}
func (b *Broker[T]) Subscribe(ctx context.Context) <-chan Event[T] {
b.mu.Lock()
defer b.mu.Unlock()
select {
case <-b.done:
ch := make(chan Event[T])
close(ch)
return ch
default:
if b.isClosed {
closedCh := make(chan Event[T])
close(closedCh)
return closedCh
}
sub := make(chan Event[T], bufferSize)
b.subs[sub] = struct{}{}
b.subCount++
subCtx, subCancel := context.WithCancel(ctx)
subscriberChannel := make(chan Event[T], defaultChannelBufferSize)
b.subs[subscriberChannel] = subCancel
go func() {
<-ctx.Done()
<-subCtx.Done()
b.mu.Lock()
defer b.mu.Unlock()
select {
case <-b.done:
return
default:
if _, ok := b.subs[subscriberChannel]; ok {
close(subscriberChannel)
delete(b.subs, subscriberChannel)
}
delete(b.subs, sub)
close(sub)
b.subCount--
}()
return sub
return subscriberChannel
}
// Publish delivers the event to every subscriber. Delivery is non-blocking:
// if a subscriber's buffer is full, the send is retried on a goroutine with
// a 2-second timeout, which may reorder events for that slow subscriber.
// Publishing on a closed broker logs a warning and does nothing.
func (b *Broker[T]) Publish(eventType EventType, payload T) {
	b.mu.RLock()
	defer b.mu.RUnlock()
	if b.isClosed {
		slog.Warn("Attempted to publish on a closed pubsub broker", "type", eventType, "payload_type", fmt.Sprintf("%T", payload))
		return
	}
	event := Event[T]{Type: eventType, Payload: payload}
	for ch := range b.subs {
		// Non-blocking send with a fallback to a goroutine to prevent slow subscribers
		// from blocking the publisher.
		select {
		case ch <- event:
			// Successfully sent
		default:
			// Subscriber channel is full or receiver is slow.
			// Send in a new goroutine to avoid blocking the publisher.
			// This might lead to out-of-order delivery for this specific slow subscriber.
			go func(sChan chan Event[T], ev Event[T]) {
				// Re-check if broker is closed before attempting send in goroutine
				b.mu.RLock()
				isBrokerClosed := b.isClosed
				b.mu.RUnlock()
				if isBrokerClosed {
					return
				}
				// NOTE(review): between this isClosed check and the send
				// below, Shutdown or an unsubscribe could close sChan,
				// which would make the send panic — confirm whether that
				// window is acceptable or needs a recover/ownership fix.
				select {
				case sChan <- ev:
				case <-time.After(2 * time.Second): // Timeout for slow subscriber
					slog.Warn("PubSub: Dropped event for slow subscriber after timeout", "type", ev.Type)
				}
			}(ch, event)
		}
	}
}
// GetSubscriberCount reports how many channels are currently subscribed.
// The map length is the single source of truth (the separate subCount
// counter from the old implementation is gone).
func (b *Broker[T]) GetSubscriberCount() int {
	b.mu.RLock()
	defer b.mu.RUnlock()
	return len(b.subs)
}

View File

@@ -0,0 +1,144 @@
package pubsub
import (
"context"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestBrokerSubscribe covers the subscription lifecycle: cancelling the
// subscriber's context must remove it, and Shutdown must clean up
// subscribers created with a non-cancellable context.
func TestBrokerSubscribe(t *testing.T) {
	t.Parallel()
	t.Run("with cancellable context", func(t *testing.T) {
		t.Parallel()
		broker := NewBroker[string]()
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		ch := broker.Subscribe(ctx)
		assert.NotNil(t, ch)
		assert.Equal(t, 1, broker.GetSubscriberCount())
		// Cancel the context should remove the subscription
		cancel()
		// Cleanup runs on a background goroutine, so allow it a moment.
		time.Sleep(10 * time.Millisecond) // Give time for goroutine to process
		assert.Equal(t, 0, broker.GetSubscriberCount())
	})
	t.Run("with background context", func(t *testing.T) {
		t.Parallel()
		broker := NewBroker[string]()
		// Using context.Background() should not leak goroutines
		ch := broker.Subscribe(context.Background())
		assert.NotNil(t, ch)
		assert.Equal(t, 1, broker.GetSubscriberCount())
		// Shutdown should clean up all subscriptions
		broker.Shutdown()
		assert.Equal(t, 0, broker.GetSubscriberCount())
	})
}
// TestBrokerPublish verifies that a published event reaches a subscriber
// with its type and payload intact.
func TestBrokerPublish(t *testing.T) {
	t.Parallel()
	broker := NewBroker[string]()
	events := broker.Subscribe(t.Context())

	broker.Publish(EventTypeCreated, "test message")

	// The subscriber channel is buffered, so the event should arrive
	// promptly; fail fast if it does not.
	select {
	case <-time.After(100 * time.Millisecond):
		t.Fatal("timeout waiting for message")
	case event := <-events:
		assert.Equal(t, EventTypeCreated, event.Type)
		assert.Equal(t, "test message", event.Payload)
	}
}
// TestBrokerShutdown verifies that Shutdown closes every subscriber channel
// and resets the subscriber count.
func TestBrokerShutdown(t *testing.T) {
	t.Parallel()
	broker := NewBroker[string]()

	// Register two independent subscribers.
	first := broker.Subscribe(context.Background())
	second := broker.Subscribe(context.Background())
	assert.Equal(t, 2, broker.GetSubscriberCount())

	broker.Shutdown()

	// A receive on a closed channel reports ok == false.
	_, firstOpen := <-first
	_, secondOpen := <-second
	assert.False(t, firstOpen, "channel 1 should be closed")
	assert.False(t, secondOpen, "channel 2 should be closed")
	assert.Equal(t, 0, broker.GetSubscriberCount())
}
// TestBrokerConcurrency exercises many concurrent subscribers: each one
// receives exactly one event, unsubscribes, and the broker must end with
// zero registered subscribers.
func TestBrokerConcurrency(t *testing.T) {
	t.Parallel()
	broker := NewBroker[int]()
	// Create a large number of subscribers
	const numSubscribers = 100
	var wg sync.WaitGroup
	wg.Add(numSubscribers)
	// Create a channel to collect received events
	receivedEvents := make(chan int, numSubscribers)
	for i := range numSubscribers {
		go func(id int) {
			defer wg.Done()
			ctx, cancel := context.WithCancel(context.Background())
			defer cancel()
			ch := broker.Subscribe(ctx)
			// Receive one message then cancel
			select {
			case event := <-ch:
				receivedEvents <- event.Payload
			case <-time.After(1 * time.Second):
				t.Errorf("timeout waiting for message %d", id)
			}
			cancel()
		}(i)
	}
	// Give subscribers time to set up
	time.Sleep(10 * time.Millisecond)
	// Publish messages to all subscribers
	for i := range numSubscribers {
		broker.Publish(EventTypeCreated, i)
	}
	// Wait for all subscribers to finish
	wg.Wait()
	close(receivedEvents)
	// Give time for cleanup goroutines to run
	time.Sleep(10 * time.Millisecond)
	// Verify all subscribers are cleaned up
	assert.Equal(t, 0, broker.GetSubscriberCount())
	// Verify we received the expected number of events
	count := 0
	for range receivedEvents {
		count++
	}
	assert.Equal(t, numSubscribers, count)
}

View File

@@ -2,27 +2,23 @@ package pubsub
import "context"
// EventType identifies the kind of lifecycle event being published.
type EventType string

// Generic lifecycle event types shared by all brokers.
const (
	EventTypeCreated EventType = "created"
	EventTypeUpdated EventType = "updated"
	EventTypeDeleted EventType = "deleted"
)

// Event pairs an event type with its payload.
type Event[T any] struct {
	Type    EventType
	Payload T
}

// Subscriber is implemented by anything that can hand out a stream of events
// scoped to a context.
type Subscriber[T any] interface {
	Subscribe(ctx context.Context) <-chan Event[T]
}

// Publisher is implemented by anything that can emit events.
type Publisher[T any] interface {
	Publish(eventType EventType, payload T)
}

View File

@@ -3,10 +3,13 @@ package session
import (
"context"
"database/sql"
"fmt"
"sync"
"time"
"github.com/google/uuid"
"github.com/opencode-ai/opencode/internal/db"
"github.com/opencode-ai/opencode/internal/pubsub"
"github.com/sst/opencode/internal/db"
"github.com/sst/opencode/internal/pubsub"
)
type Session struct {
@@ -17,117 +20,197 @@ type Session struct {
PromptTokens int64
CompletionTokens int64
Cost float64
CreatedAt int64
UpdatedAt int64
Summary string
SummarizedAt time.Time
CreatedAt time.Time
UpdatedAt time.Time
}
// Event types published on the session broker.
const (
	EventSessionCreated pubsub.EventType = "session_created"
	EventSessionUpdated pubsub.EventType = "session_updated"
	EventSessionDeleted pubsub.EventType = "session_deleted"
)
// Service is the session store's public surface: CRUD over sessions plus a
// pubsub stream of session lifecycle events.
type Service interface {
	pubsub.Subscriber[Session]
	Create(ctx context.Context, title string) (Session, error)
	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
	Get(ctx context.Context, id string) (Session, error)
	List(ctx context.Context) ([]Session, error)
	Update(ctx context.Context, session Session) (Session, error)
	Delete(ctx context.Context, id string) error
}
type service struct {
*pubsub.Broker[Session]
q db.Querier
db *db.Queries
broker *pubsub.Broker[Session]
mu sync.RWMutex
}
// globalSessionService is the process-wide session service singleton.
var globalSessionService *service

// InitService constructs the singleton from an open database connection.
// It must be called once at startup; a second call returns an error.
func InitService(dbConn *sql.DB) error {
	if globalSessionService != nil {
		return fmt.Errorf("session service already initialized")
	}
	queries := db.New(dbConn)
	broker := pubsub.NewBroker[Session]()
	globalSessionService = &service{
		db:     queries,
		broker: broker,
	}
	return nil
}
// GetService returns the global session service, panicking if InitService
// has not been called yet.
func GetService() Service {
	if globalSessionService == nil {
		panic("session service not initialized. Call session.InitService() first.")
	}
	return globalSessionService
}
func (s *service) Create(ctx context.Context, title string) (Session, error) {
dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
s.mu.Lock()
defer s.mu.Unlock()
if title == "" {
title = "New Session - " + time.Now().Format("2006-01-02 15:04:05")
}
dbSessParams := db.CreateSessionParams{
ID: uuid.New().String(),
Title: title,
})
if err != nil {
return Session{}, err
}
dbSession, err := s.db.CreateSession(ctx, dbSessParams)
if err != nil {
return Session{}, fmt.Errorf("db.CreateSession: %w", err)
}
session := s.fromDBItem(dbSession)
s.Publish(pubsub.CreatedEvent, session)
s.broker.Publish(EventSessionCreated, session)
return session, nil
}
func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
s.mu.Lock()
defer s.mu.Unlock()
if title == "" {
title = "Task Session - " + time.Now().Format("2006-01-02 15:04:05")
}
if toolCallID == "" {
toolCallID = uuid.New().String()
}
dbSessParams := db.CreateSessionParams{
ID: toolCallID,
ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
ParentSessionID: sql.NullString{String: parentSessionID, Valid: parentSessionID != ""},
Title: title,
})
}
dbSession, err := s.db.CreateSession(ctx, dbSessParams)
if err != nil {
return Session{}, err
return Session{}, fmt.Errorf("db.CreateTaskSession: %w", err)
}
session := s.fromDBItem(dbSession)
s.Publish(pubsub.CreatedEvent, session)
s.broker.Publish(EventSessionCreated, session)
return session, nil
}
// CreateTitleSession inserts the deterministic "title-<parent>" child session
// used for title generation, following the same locking, error-wrapping, and
// event-publishing conventions as the other Create* methods.
func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	dbSession, err := s.db.CreateSession(ctx, db.CreateSessionParams{
		ID:              "title-" + parentSessionID,
		ParentSessionID: sql.NullString{String: parentSessionID, Valid: parentSessionID != ""},
		Title:           "Generate a title",
	})
	if err != nil {
		return Session{}, fmt.Errorf("db.CreateTitleSession: %w", err)
	}
	session := s.fromDBItem(dbSession)
	s.broker.Publish(EventSessionCreated, session)
	return session, nil
}
// Get loads a single session by ID, translating sql.ErrNoRows into a
// friendlier not-found error. (The superseded old-API Delete that previously
// occupied this spot is removed; the current Delete is defined below.)
func (s *service) Get(ctx context.Context, id string) (Session, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	dbSession, err := s.db.GetSessionByID(ctx, id)
	if err != nil {
		if err == sql.ErrNoRows {
			return Session{}, fmt.Errorf("session ID '%s' not found", id)
		}
		return Session{}, fmt.Errorf("db.GetSessionByID: %w", err)
	}
	return s.fromDBItem(dbSession), nil
}
func (s *service) Save(ctx context.Context, session Session) (Session, error) {
dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
func (s *service) List(ctx context.Context) ([]Session, error) {
s.mu.RLock()
defer s.mu.RUnlock()
dbSessions, err := s.db.ListSessions(ctx)
if err != nil {
return nil, fmt.Errorf("db.ListSessions: %w", err)
}
sessions := make([]Session, len(dbSessions))
for i, dbSess := range dbSessions {
sessions[i] = s.fromDBItem(dbSess)
}
return sessions, nil
}
func (s *service) Update(ctx context.Context, session Session) (Session, error) {
s.mu.Lock()
defer s.mu.Unlock()
if session.ID == "" {
return Session{}, fmt.Errorf("cannot update session with empty ID")
}
params := db.UpdateSessionParams{
ID: session.ID,
Title: session.Title,
PromptTokens: session.PromptTokens,
CompletionTokens: session.CompletionTokens,
Cost: session.Cost,
})
if err != nil {
return Session{}, err
Summary: sql.NullString{String: session.Summary, Valid: session.Summary != ""},
SummarizedAt: sql.NullString{String: session.SummarizedAt.UTC().Format(time.RFC3339Nano), Valid: !session.SummarizedAt.IsZero()},
}
session = s.fromDBItem(dbSession)
s.Publish(pubsub.UpdatedEvent, session)
return session, nil
dbSession, err := s.db.UpdateSession(ctx, params)
if err != nil {
return Session{}, fmt.Errorf("db.UpdateSession: %w", err)
}
updatedSession := s.fromDBItem(dbSession)
s.broker.Publish(EventSessionUpdated, updatedSession)
return updatedSession, nil
}
func (s *service) List(ctx context.Context) ([]Session, error) {
dbSessions, err := s.q.ListSessions(ctx)
func (s *service) Delete(ctx context.Context, id string) error {
s.mu.Lock()
dbSess, err := s.db.GetSessionByID(ctx, id)
if err != nil {
return nil, err
s.mu.Unlock()
if err == sql.ErrNoRows {
return fmt.Errorf("session ID '%s' not found for deletion", id)
}
return fmt.Errorf("db.GetSessionByID before delete: %w", err)
}
sessions := make([]Session, len(dbSessions))
for i, dbSession := range dbSessions {
sessions[i] = s.fromDBItem(dbSession)
sessionToPublish := s.fromDBItem(dbSess)
s.mu.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
err = s.db.DeleteSession(ctx, id)
if err != nil {
return fmt.Errorf("db.DeleteSession: %w", err)
}
return sessions, nil
s.broker.Publish(EventSessionDeleted, sessionToPublish)
return nil
}
func (s service) fromDBItem(item db.Session) Session {
func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[Session] {
return s.broker.Subscribe(ctx)
}
func (s *service) fromDBItem(item db.Session) Session {
var summarizedAt time.Time
if item.SummarizedAt.Valid {
parsedTime, err := time.Parse(time.RFC3339Nano, item.SummarizedAt.String)
if err == nil {
summarizedAt = parsedTime
}
}
createdAt, _ := time.Parse(time.RFC3339Nano, item.CreatedAt)
updatedAt, _ := time.Parse(time.RFC3339Nano, item.UpdatedAt)
return Session{
ID: item.ID,
ParentSessionID: item.ParentSessionID.String,
@@ -136,15 +219,37 @@ func (s service) fromDBItem(item db.Session) Session {
PromptTokens: item.PromptTokens,
CompletionTokens: item.CompletionTokens,
Cost: item.Cost,
CreatedAt: item.CreatedAt,
UpdatedAt: item.UpdatedAt,
Summary: item.Summary.String,
SummarizedAt: summarizedAt,
CreatedAt: createdAt,
UpdatedAt: updatedAt,
}
}
func NewService(q db.Querier) Service {
broker := pubsub.NewBroker[Session]()
return &service{
broker,
q,
}
// The functions below are package-level convenience wrappers that delegate
// to the global session service (see GetService).

// Create inserts a new session with the given title.
func Create(ctx context.Context, title string) (Session, error) {
	return GetService().Create(ctx, title)
}

// CreateTaskSession inserts a child session for an agent tool call.
func CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
	return GetService().CreateTaskSession(ctx, toolCallID, parentSessionID, title)
}

// Get loads a session by ID.
func Get(ctx context.Context, id string) (Session, error) {
	return GetService().Get(ctx, id)
}

// List returns all sessions.
func List(ctx context.Context) ([]Session, error) {
	return GetService().List(ctx)
}

// Update persists mutable session fields.
func Update(ctx context.Context, session Session) (Session, error) {
	return GetService().Update(ctx, session)
}

// Delete removes a session by ID.
func Delete(ctx context.Context, id string) error {
	return GetService().Delete(ctx, id)
}

// Subscribe streams session lifecycle events.
func Subscribe(ctx context.Context) <-chan pubsub.Event[Session] {
	return GetService().Subscribe(ctx)
}

117
internal/status/status.go Normal file
View File

@@ -0,0 +1,117 @@
package status
import (
"context"
"fmt"
"log/slog"
"sync"
"time"
"github.com/sst/opencode/internal/pubsub"
)
// Level classifies the severity of a status message.
type Level string

const (
	LevelInfo  Level = "info"
	LevelWarn  Level = "warn"
	LevelError Level = "error"
	LevelDebug Level = "debug"
)

// StatusMessage is the payload subscribers receive for each published status.
type StatusMessage struct {
	Level     Level     `json:"level"`
	Message   string    `json:"message"`
	Timestamp time.Time `json:"timestamp"`
}

const (
	// EventStatusPublished is the single event type used on the status broker.
	EventStatusPublished pubsub.EventType = "status_published"
)
// Service publishes leveled status messages to subscribers and mirrors each
// one to the structured logger.
type Service interface {
	pubsub.Subscriber[StatusMessage]
	Info(message string)
	Warn(message string)
	Error(message string)
	Debug(message string)
}

// service is the broker-backed implementation of Service.
type service struct {
	broker *pubsub.Broker[StatusMessage]
	// NOTE(review): mu is not referenced by any method visible here —
	// confirm whether it can be removed.
	mu sync.RWMutex
}
// globalStatusService is the process-wide status service singleton.
var globalStatusService *service

// InitService wires up the global status service. It must be called exactly
// once at startup; a second call returns an error.
func InitService() error {
	if globalStatusService != nil {
		return fmt.Errorf("status service already initialized")
	}
	globalStatusService = &service{
		broker: pubsub.NewBroker[StatusMessage](),
	}
	return nil
}
// GetService returns the global status service, panicking if InitService has
// not run yet — status reporting is assumed to be available everywhere.
func GetService() Service {
	if svc := globalStatusService; svc != nil {
		return svc
	}
	panic("status service not initialized. Call status.InitService() at application startup.")
}
// Info publishes an info-level status and mirrors it to the logger.
func (s *service) Info(message string) {
	s.publish(LevelInfo, message)
	slog.Info(message)
}

// Warn publishes a warn-level status and mirrors it to the logger.
func (s *service) Warn(message string) {
	s.publish(LevelWarn, message)
	slog.Warn(message)
}

// Error publishes an error-level status and mirrors it to the logger.
func (s *service) Error(message string) {
	s.publish(LevelError, message)
	slog.Error(message)
}

// Debug publishes a debug-level status and mirrors it to the logger.
func (s *service) Debug(message string) {
	s.publish(LevelDebug, message)
	slog.Debug(message)
}
// publish stamps the message with the current time and emits it on the
// broker under the single status event type.
func (s *service) publish(level Level, messageText string) {
	s.broker.Publish(EventStatusPublished, StatusMessage{
		Level:     level,
		Message:   messageText,
		Timestamp: time.Now(),
	})
}
// Subscribe streams status messages until ctx is cancelled.
func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[StatusMessage] {
	return s.broker.Subscribe(ctx)
}

// The functions below are package-level convenience wrappers that delegate
// to the global status service (see GetService).

// Info publishes an info-level status message.
func Info(message string) {
	GetService().Info(message)
}

// Warn publishes a warn-level status message.
func Warn(message string) {
	GetService().Warn(message)
}

// Error publishes an error-level status message.
func Error(message string) {
	GetService().Error(message)
}

// Debug publishes a debug-level status message.
func Debug(message string) {
	GetService().Debug(message)
}

// Subscribe streams status messages until ctx is cancelled.
func Subscribe(ctx context.Context) <-chan pubsub.Event[StatusMessage] {
	return GetService().Subscribe(ctx)
}

View File

@@ -6,28 +6,41 @@ import (
"github.com/charmbracelet/lipgloss"
"github.com/charmbracelet/x/ansi"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/session"
"github.com/opencode-ai/opencode/internal/tui/styles"
"github.com/opencode-ai/opencode/internal/version"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/message"
"github.com/sst/opencode/internal/tui/styles"
"github.com/sst/opencode/internal/tui/theme"
"github.com/sst/opencode/internal/version"
)
type SendMsg struct {
Text string
Text string
Attachments []message.Attachment
}
type SessionSelectedMsg = session.Session
type SessionClearedMsg struct{}
type EditorFocusMsg bool
// header stacks the logo, repository link, a blank spacer line, and the
// current working directory into the chat page header.
func header(width int) string {
	return lipgloss.JoinVertical(
		lipgloss.Top,
		logo(width),
		repo(width),
		"",
		cwd(width),
	)
}
func lspsConfigured(width int) string {
cfg := config.Get()
title := "LSP Configuration"
title := "LSP Servers"
title = ansi.Truncate(title, width, "…")
lsps := styles.BaseStyle.Width(width).Foreground(styles.PrimaryColor).Bold(true).Render(title)
t := theme.CurrentTheme()
baseStyle := styles.BaseStyle()
lsps := baseStyle.
Width(width).
Foreground(t.Primary()).
Bold(true).
Render(title)
// Get LSP names and sort them for consistent ordering
var lspNames []string
@@ -39,16 +52,19 @@ func lspsConfigured(width int) string {
var lspViews []string
for _, name := range lspNames {
lsp := cfg.LSP[name]
lspName := styles.BaseStyle.Foreground(styles.Forground).Render(
fmt.Sprintf("• %s", name),
)
lspName := baseStyle.
Foreground(t.Text()).
Render(fmt.Sprintf("• %s", name))
cmd := lsp.Command
cmd = ansi.Truncate(cmd, width-lipgloss.Width(lspName)-3, "…")
lspPath := styles.BaseStyle.Foreground(styles.ForgroundDim).Render(
fmt.Sprintf(" (%s)", cmd),
)
lspPath := baseStyle.
Foreground(t.TextMuted()).
Render(fmt.Sprintf(" (%s)", cmd))
lspViews = append(lspViews,
styles.BaseStyle.
baseStyle.
Width(width).
Render(
lipgloss.JoinHorizontal(
@@ -59,7 +75,8 @@ func lspsConfigured(width int) string {
),
)
}
return styles.BaseStyle.
return baseStyle.
Width(width).
Render(
lipgloss.JoinVertical(
@@ -75,10 +92,14 @@ func lspsConfigured(width int) string {
func logo(width int) string {
logo := fmt.Sprintf("%s %s", styles.OpenCodeIcon, "OpenCode")
t := theme.CurrentTheme()
baseStyle := styles.BaseStyle()
version := styles.BaseStyle.Foreground(styles.ForgroundDim).Render(version.Version)
versionText := baseStyle.
Foreground(t.TextMuted()).
Render(version.Version)
return styles.BaseStyle.
return baseStyle.
Bold(true).
Width(width).
Render(
@@ -86,34 +107,27 @@ func logo(width int) string {
lipgloss.Left,
logo,
" ",
version,
versionText,
),
)
}
func repo(width int) string {
repo := "https://github.com/opencode-ai/opencode"
return styles.BaseStyle.
Foreground(styles.ForgroundDim).
repo := "github.com/sst/opencode"
t := theme.CurrentTheme()
return styles.BaseStyle().
Foreground(t.TextMuted()).
Width(width).
Render(repo)
}
func cwd(width int) string {
cwd := fmt.Sprintf("cwd: %s", config.WorkingDirectory())
return styles.BaseStyle.
Foreground(styles.ForgroundDim).
t := theme.CurrentTheme()
return styles.BaseStyle().
Foreground(t.TextMuted()).
Width(width).
Render(cwd)
}
func header(width int) string {
header := lipgloss.JoinVertical(
lipgloss.Top,
logo(width),
repo(width),
"",
cwd(width),
)
return header
}

View File

@@ -1,24 +1,33 @@
package chat
import (
"fmt"
"os"
"os/exec"
"slices"
"unicode"
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/textarea"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/opencode-ai/opencode/internal/app"
"github.com/opencode-ai/opencode/internal/session"
"github.com/opencode-ai/opencode/internal/tui/layout"
"github.com/opencode-ai/opencode/internal/tui/styles"
"github.com/opencode-ai/opencode/internal/tui/util"
"github.com/sst/opencode/internal/app"
"github.com/sst/opencode/internal/message"
"github.com/sst/opencode/internal/status"
"github.com/sst/opencode/internal/tui/components/dialog"
"github.com/sst/opencode/internal/tui/layout"
"github.com/sst/opencode/internal/tui/styles"
"github.com/sst/opencode/internal/tui/theme"
"github.com/sst/opencode/internal/tui/util"
)
// editorCmp is the chat input component: a textarea plus a small queue of
// image attachments and a digit-driven attachment-deletion mode.
type editorCmp struct {
	app      *app.App
	textarea textarea.Model
	// width/height are the component's last known size; View and SetSize
	// both read them.
	width       int
	height      int
	attachments []message.Attachment
	// deleteMode is true after ctrl+r, while waiting for a digit/r/esc.
	deleteMode bool
}
type EditorKeyMaps struct {
@@ -31,6 +40,11 @@ type bluredEditorKeyMaps struct {
Focus key.Binding
OpenEditor key.Binding
}
// DeleteAttachmentKeyMaps groups the key bindings for the attachment
// deletion mini-mode (enter mode, cancel, delete all).
type DeleteAttachmentKeyMaps struct {
	AttachmentDeleteMode key.Binding
	Escape               key.Binding
	DeleteAllAttachments key.Binding
}
var editorMaps = EditorKeyMaps{
Send: key.NewBinding(
@@ -43,15 +57,36 @@ var editorMaps = EditorKeyMaps{
),
}
func openEditor() tea.Cmd {
// DeleteKeyMaps drives the attachment-deletion mini-mode: ctrl+r enters the
// mode, a following digit deletes that attachment, r deletes all, and esc
// cancels. (Fixes the "attchments" typo in the user-visible help text.)
var DeleteKeyMaps = DeleteAttachmentKeyMaps{
	AttachmentDeleteMode: key.NewBinding(
		key.WithKeys("ctrl+r"),
		key.WithHelp("ctrl+r+{i}", "delete attachment at index i"),
	),
	Escape: key.NewBinding(
		key.WithKeys("esc"),
		key.WithHelp("esc", "cancel delete mode"),
	),
	DeleteAllAttachments: key.NewBinding(
		key.WithKeys("r"),
		key.WithHelp("ctrl+r+r", "delete all attachments"),
	),
}
const (
maxAttachments = 5
)
func (m *editorCmp) openEditor(value string) tea.Cmd {
editor := os.Getenv("EDITOR")
if editor == "" {
editor = "nvim"
}
tmpfile, err := os.CreateTemp("", "msg_*.md")
tmpfile.WriteString(value)
if err != nil {
return util.ReportError(err)
status.Error(err.Error())
return nil
}
tmpfile.Close()
c := exec.Command(editor, tmpfile.Name()) //nolint:gosec
@@ -60,18 +95,24 @@ func openEditor() tea.Cmd {
c.Stderr = os.Stderr
return tea.ExecProcess(c, func(err error) tea.Msg {
if err != nil {
return util.ReportError(err)
status.Error(err.Error())
return nil
}
content, err := os.ReadFile(tmpfile.Name())
if err != nil {
return util.ReportError(err)
status.Error(err.Error())
return nil
}
if len(content) == 0 {
return util.ReportWarn("Message is empty")
status.Warn("Message is empty")
return nil
}
os.Remove(tmpfile.Name())
attachments := m.attachments
m.attachments = nil
return SendMsg{
Text: string(content),
Text: string(content),
Attachments: attachments,
}
})
}
@@ -81,18 +122,23 @@ func (m *editorCmp) Init() tea.Cmd {
}
func (m *editorCmp) send() tea.Cmd {
if m.app.CoderAgent.IsSessionBusy(m.session.ID) {
return util.ReportWarn("Agent is working, please wait...")
if m.app.PrimaryAgent.IsSessionBusy(m.app.CurrentSession.ID) {
status.Warn("Agent is working, please wait...")
return nil
}
value := m.textarea.Value()
m.textarea.Reset()
attachments := m.attachments
m.attachments = nil
if value == "" {
return nil
}
return tea.Batch(
util.CmdHandler(SendMsg{
Text: value,
Text: value,
Attachments: attachments,
}),
)
}
@@ -100,21 +146,53 @@ func (m *editorCmp) send() tea.Cmd {
func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
switch msg := msg.(type) {
case SessionSelectedMsg:
if msg.ID != m.session.ID {
m.session = msg
}
case dialog.ThemeChangedMsg:
m.textarea = CreateTextArea(&m.textarea)
return m, nil
case dialog.AttachmentAddedMsg:
if len(m.attachments) >= maxAttachments {
status.Error(fmt.Sprintf("cannot add more than %d images", maxAttachments))
return m, cmd
}
m.attachments = append(m.attachments, msg.Attachment)
case tea.KeyMsg:
if key.Matches(msg, DeleteKeyMaps.AttachmentDeleteMode) {
m.deleteMode = true
return m, nil
}
if key.Matches(msg, DeleteKeyMaps.DeleteAllAttachments) && m.deleteMode {
m.deleteMode = false
m.attachments = nil
return m, nil
}
if m.deleteMode && len(msg.Runes) > 0 && unicode.IsDigit(msg.Runes[0]) {
num := int(msg.Runes[0] - '0')
m.deleteMode = false
if num < 10 && len(m.attachments) > num {
if num == 0 {
m.attachments = m.attachments[num+1:]
} else {
m.attachments = slices.Delete(m.attachments, num, num+1)
}
return m, nil
}
}
if key.Matches(msg, messageKeys.PageUp) || key.Matches(msg, messageKeys.PageDown) ||
key.Matches(msg, messageKeys.HalfPageUp) || key.Matches(msg, messageKeys.HalfPageDown) {
return m, nil
}
if key.Matches(msg, editorMaps.OpenEditor) {
if m.app.CoderAgent.IsSessionBusy(m.session.ID) {
return m, util.ReportWarn("Agent is working, please wait...")
if m.app.PrimaryAgent.IsSessionBusy(m.app.CurrentSession.ID) {
status.Warn("Agent is working, please wait...")
return m, nil
}
return m, openEditor()
value := m.textarea.Value()
m.textarea.Reset()
return m, m.openEditor(value)
}
if key.Matches(msg, DeleteKeyMaps.Escape) {
m.deleteMode = false
return m, nil
}
// Handle Enter key
if m.textarea.Focused() && key.Matches(msg, editorMaps.Send) {
@@ -128,20 +206,38 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, m.send()
}
}
}
m.textarea, cmd = m.textarea.Update(msg)
return m, cmd
}
// View renders the prompt marker and textarea; when attachments are queued,
// an attachment strip is stacked above them and the textarea gives up one
// row of height to make room.
func (m *editorCmp) View() string {
	t := theme.CurrentTheme()
	// Style the prompt with theme colors
	style := lipgloss.NewStyle().
		Padding(0, 0, 0, 1).
		Bold(true).
		Foreground(t.Primary())
	if len(m.attachments) == 0 {
		return lipgloss.JoinHorizontal(lipgloss.Top, style.Render(">"), m.textarea.View())
	}
	m.textarea.SetHeight(m.height - 1)
	return lipgloss.JoinVertical(lipgloss.Top,
		m.attachmentsContent(),
		lipgloss.JoinHorizontal(lipgloss.Top, style.Render(">"),
			m.textarea.View()),
	)
}
func (m *editorCmp) SetSize(width, height int) tea.Cmd {
m.width = width
m.height = height
m.textarea.SetWidth(width - 3) // account for the prompt and padding right
m.textarea.SetHeight(height)
m.textarea.SetWidth(width)
return nil
}
@@ -149,29 +245,70 @@ func (m *editorCmp) GetSize() (int, int) {
return m.textarea.Width(), m.textarea.Height()
}
// attachmentsContent renders the editor's current attachments as a single
// horizontal row of styled chips. Filenames longer than 10 bytes are
// shortened to their first 7 bytes plus "..." (NOTE(review): byte slicing
// may cut a multi-byte rune mid-sequence — confirm filenames are ASCII).
// While delete mode is active each chip is prefixed with its index so the
// user can pick which attachment to remove.
func (m *editorCmp) attachmentsContent() string {
	t := theme.CurrentTheme()
	chipStyle := styles.BaseStyle().
		MarginLeft(1).
		Background(t.TextMuted()).
		Foreground(t.Text())

	rendered := make([]string, 0, len(m.attachments))
	for idx, att := range m.attachments {
		// Default label: icon plus full filename.
		label := fmt.Sprintf(" %s %s", styles.DocumentIcon, att.FileName)
		if len(att.FileName) > 10 {
			// Too long: keep the first 7 bytes and mark the truncation.
			label = fmt.Sprintf(" %s %s...", styles.DocumentIcon, att.FileName[:7])
		}
		if m.deleteMode {
			// Prefix with the index used by the digit-key delete shortcut.
			label = fmt.Sprintf("%d%s", idx, label)
		}
		rendered = append(rendered, chipStyle.Render(label))
	}
	return lipgloss.JoinHorizontal(lipgloss.Left, rendered...)
}
// BindingKeys returns every key binding the editor component responds to,
// combining the editor key map with the attachment-delete key map.
func (m *editorCmp) BindingKeys() []key.Binding {
	return append(
		layout.KeyMapToSlice(editorMaps),
		layout.KeyMapToSlice(DeleteKeyMaps)...,
	)
}
func NewEditorCmp(app *app.App) tea.Model {
ti := textarea.New()
ti.Prompt = " "
ti.ShowLineNumbers = false
ti.BlurredStyle.Base = ti.BlurredStyle.Base.Background(styles.Background)
ti.BlurredStyle.CursorLine = ti.BlurredStyle.CursorLine.Background(styles.Background)
ti.BlurredStyle.Placeholder = ti.BlurredStyle.Placeholder.Background(styles.Background)
ti.BlurredStyle.Text = ti.BlurredStyle.Text.Background(styles.Background)
func CreateTextArea(existing *textarea.Model) textarea.Model {
t := theme.CurrentTheme()
bgColor := t.Background()
textColor := t.Text()
textMutedColor := t.TextMuted()
ti.FocusedStyle.Base = ti.FocusedStyle.Base.Background(styles.Background)
ti.FocusedStyle.CursorLine = ti.FocusedStyle.CursorLine.Background(styles.Background)
ti.FocusedStyle.Placeholder = ti.FocusedStyle.Placeholder.Background(styles.Background)
ti.FocusedStyle.Text = ti.BlurredStyle.Text.Background(styles.Background)
ti.CharLimit = -1
ti.Focus()
ta := textarea.New()
ta.BlurredStyle.Base = styles.BaseStyle().Background(bgColor).Foreground(textColor)
ta.BlurredStyle.CursorLine = styles.BaseStyle().Background(bgColor)
ta.BlurredStyle.Placeholder = styles.BaseStyle().Background(bgColor).Foreground(textMutedColor)
ta.BlurredStyle.Text = styles.BaseStyle().Background(bgColor).Foreground(textColor)
ta.FocusedStyle.Base = styles.BaseStyle().Background(bgColor).Foreground(textColor)
ta.FocusedStyle.CursorLine = styles.BaseStyle().Background(bgColor)
ta.FocusedStyle.Placeholder = styles.BaseStyle().Background(bgColor).Foreground(textMutedColor)
ta.FocusedStyle.Text = styles.BaseStyle().Background(bgColor).Foreground(textColor)
ta.Prompt = " "
ta.ShowLineNumbers = false
ta.CharLimit = -1
if existing != nil {
ta.SetValue(existing.Value())
ta.SetWidth(existing.Width())
ta.SetHeight(existing.Height())
}
ta.Focus()
return ta
}
func NewEditorCmp(app *app.App) tea.Model {
ta := CreateTextArea(nil)
return &editorCmp{
app: app,
textarea: ti,
textarea: ta,
}
}

View File

@@ -4,37 +4,44 @@ import (
"context"
"fmt"
"math"
"time"
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/spinner"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/opencode-ai/opencode/internal/app"
"github.com/opencode-ai/opencode/internal/message"
"github.com/opencode-ai/opencode/internal/pubsub"
"github.com/opencode-ai/opencode/internal/session"
"github.com/opencode-ai/opencode/internal/tui/styles"
"github.com/opencode-ai/opencode/internal/tui/util"
"github.com/sst/opencode/internal/app"
"github.com/sst/opencode/internal/message"
"github.com/sst/opencode/internal/pubsub"
"github.com/sst/opencode/internal/session"
"github.com/sst/opencode/internal/status"
"github.com/sst/opencode/internal/tui/components/dialog"
"github.com/sst/opencode/internal/tui/state"
"github.com/sst/opencode/internal/tui/styles"
"github.com/sst/opencode/internal/tui/theme"
)
// cacheItem stores the rendered UI messages for a single chat message,
// keyed by the width they were rendered at, so renderView can reuse the
// result instead of re-rendering when neither the message nor the
// viewport width has changed.
type cacheItem struct {
	width   int         // viewport width the content was rendered for
	content []uiMessage // rendered UI messages for this entry
}
type messagesCmp struct {
app *app.App
width, height int
viewport viewport.Model
session session.Session
messages []message.Message
uiMessages []uiMessage
currentMsgID string
cachedContent map[string]cacheItem
spinner spinner.Model
rendering bool
app *app.App
width, height int
viewport viewport.Model
messages []message.Message
uiMessages []uiMessage
currentMsgID string
cachedContent map[string]cacheItem
spinner spinner.Model
rendering bool
attachments viewport.Model
showToolMessages bool
}
// renderFinishedMsg is emitted when the asynchronous render pass completes,
// letting the component leave the "rendering" state and scroll to bottom.
type renderFinishedMsg struct{}

// ToggleToolMessagesMsg requests toggling the visibility of tool
// call/result messages in the conversation view.
type ToggleToolMessagesMsg struct{}
type MessageKeys struct {
PageDown key.Binding
@@ -69,20 +76,23 @@ func (m *messagesCmp) Init() tea.Cmd {
func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmds []tea.Cmd
switch msg := msg.(type) {
case SessionSelectedMsg:
if msg.ID != m.session.ID {
cmd := m.SetSession(msg)
return m, cmd
}
case dialog.ThemeChangedMsg:
m.rerender()
return m, nil
case SessionClearedMsg:
m.session = session.Session{}
case ToggleToolMessagesMsg:
m.showToolMessages = !m.showToolMessages
// Clear the cache to force re-rendering of all messages
m.cachedContent = make(map[string]cacheItem)
m.renderView()
return m, nil
case state.SessionSelectedMsg:
cmd := m.Reload(msg)
return m, cmd
case state.SessionClearedMsg:
m.messages = make([]message.Message, 0)
m.currentMsgID = ""
m.rendering = false
return m, nil
case tea.KeyMsg:
if key.Matches(msg, messageKeys.PageUp) || key.Matches(msg, messageKeys.PageDown) ||
key.Matches(msg, messageKeys.HalfPageUp) || key.Matches(msg, messageKeys.HalfPageDown) {
@@ -90,15 +100,13 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.viewport = u
cmds = append(cmds, cmd)
}
case renderFinishedMsg:
m.rendering = false
m.viewport.GotoBottom()
case pubsub.Event[message.Message]:
needsRerender := false
if msg.Type == pubsub.CreatedEvent {
if msg.Payload.SessionID == m.session.ID {
if msg.Type == message.EventMessageCreated {
if msg.Payload.SessionID == m.app.CurrentSession.ID {
messageExists := false
for _, v := range m.messages {
if v.ID == msg.Payload.ID {
@@ -128,7 +136,7 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
}
}
} else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID {
} else if msg.Type == message.EventMessageUpdated && msg.Payload.SessionID == m.app.CurrentSession.ID {
for i, v := range m.messages {
if v.ID == msg.Payload.ID {
m.messages[i] = msg.Payload
@@ -141,8 +149,8 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if needsRerender {
m.renderView()
if len(m.messages) > 0 {
if (msg.Type == pubsub.CreatedEvent) ||
(msg.Type == pubsub.UpdatedEvent && msg.Payload.ID == m.messages[len(m.messages)-1].ID) {
if (msg.Type == message.EventMessageCreated) ||
(msg.Type == message.EventMessageUpdated && msg.Payload.ID == m.messages[len(m.messages)-1].ID) {
m.viewport.GotoBottom()
}
}
@@ -156,7 +164,7 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
func (m *messagesCmp) IsAgentWorking() bool {
return m.app.CoderAgent.IsSessionBusy(m.session.ID)
return m.app.PrimaryAgent.IsSessionBusy(m.app.CurrentSession.ID)
}
func formatTimeDifference(unixTime1, unixTime2 int64) string {
@@ -174,6 +182,7 @@ func formatTimeDifference(unixTime1, unixTime2 int64) string {
func (m *messagesCmp) renderView() {
m.uiMessages = make([]uiMessage, 0)
pos := 0
baseStyle := styles.BaseStyle()
if m.width == 0 {
return
@@ -210,6 +219,7 @@ func (m *messagesCmp) renderView() {
m.currentMsgID,
m.width,
pos,
m.showToolMessages,
)
for _, msg := range assistantMessages {
m.uiMessages = append(m.uiMessages, msg)
@@ -224,16 +234,17 @@ func (m *messagesCmp) renderView() {
messages := make([]string, 0)
for _, v := range m.uiMessages {
messages = append(messages, v.content,
styles.BaseStyle.
messages = append(messages, lipgloss.JoinVertical(lipgloss.Left, v.content),
baseStyle.
Width(m.width).
Render(
"",
),
)
}
m.viewport.SetContent(
styles.BaseStyle.
baseStyle.
Width(m.width).
Render(
lipgloss.JoinVertical(
@@ -245,8 +256,10 @@ func (m *messagesCmp) renderView() {
}
func (m *messagesCmp) View() string {
baseStyle := styles.BaseStyle()
if m.rendering {
return styles.BaseStyle.
return baseStyle.
Width(m.width).
Render(
lipgloss.JoinVertical(
@@ -258,14 +271,14 @@ func (m *messagesCmp) View() string {
)
}
if len(m.messages) == 0 {
content := styles.BaseStyle.
content := baseStyle.
Width(m.width).
Height(m.height - 1).
Render(
m.initialScreen(),
)
return styles.BaseStyle.
return baseStyle.
Width(m.width).
Render(
lipgloss.JoinVertical(
@@ -277,7 +290,7 @@ func (m *messagesCmp) View() string {
)
}
return styles.BaseStyle.
return baseStyle.
Width(m.width).
Render(
lipgloss.JoinVertical(
@@ -328,6 +341,9 @@ func hasUnfinishedToolCalls(messages []message.Message) bool {
func (m *messagesCmp) working() string {
text := ""
if m.IsAgentWorking() && len(m.messages) > 0 {
t := theme.CurrentTheme()
baseStyle := styles.BaseStyle()
task := "Thinking..."
lastMessage := m.messages[len(m.messages)-1]
if hasToolsWithoutResponse(m.messages) {
@@ -338,42 +354,51 @@ func (m *messagesCmp) working() string {
task = "Generating..."
}
if task != "" {
text += styles.BaseStyle.Width(m.width).Foreground(styles.PrimaryColor).Bold(true).Render(
fmt.Sprintf("%s %s ", m.spinner.View(), task),
)
text += baseStyle.
Width(m.width).
Foreground(t.Primary()).
Bold(true).
Render(fmt.Sprintf("%s %s ", m.spinner.View(), task))
}
}
return text
}
func (m *messagesCmp) help() string {
t := theme.CurrentTheme()
baseStyle := styles.BaseStyle()
text := ""
if m.app.CoderAgent.IsBusy() {
if m.app.PrimaryAgent.IsBusy() {
text += lipgloss.JoinHorizontal(
lipgloss.Left,
styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "),
styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("esc"),
styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" to exit cancel"),
baseStyle.Foreground(t.TextMuted()).Bold(true).Render("press "),
baseStyle.Foreground(t.Text()).Bold(true).Render("esc"),
baseStyle.Foreground(t.TextMuted()).Bold(true).Render(" to interrupt"),
)
} else {
text += lipgloss.JoinHorizontal(
lipgloss.Left,
styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "),
styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("enter"),
styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" to send the message,"),
styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" write"),
styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render(" \\"),
styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" and enter to add a new line"),
baseStyle.Foreground(t.Text()).Bold(true).Render("enter"),
baseStyle.Foreground(t.TextMuted()).Bold(true).Render(" to send,"),
baseStyle.Foreground(t.Text()).Bold(true).Render(" \\"),
baseStyle.Foreground(t.TextMuted()).Bold(true).Render("+"),
baseStyle.Foreground(t.Text()).Bold(true).Render("enter"),
baseStyle.Foreground(t.TextMuted()).Bold(true).Render(" for newline,"),
baseStyle.Foreground(t.Text()).Bold(true).Render(" ctrl+h"),
baseStyle.Foreground(t.TextMuted()).Bold(true).Render(" to toggle tool messages"),
)
}
return styles.BaseStyle.
return baseStyle.
Width(m.width).
Render(text)
}
func (m *messagesCmp) initialScreen() string {
return styles.BaseStyle.Width(m.width).Render(
baseStyle := styles.BaseStyle()
return baseStyle.Width(m.width).Render(
lipgloss.JoinVertical(
lipgloss.Top,
header(m.width),
@@ -383,6 +408,13 @@ func (m *messagesCmp) initialScreen() string {
)
}
// rerender invalidates the render cache entry for every message currently
// held and rebuilds the viewport content from scratch (e.g. after a theme
// change or a width change).
func (m *messagesCmp) rerender() {
	for i := range m.messages {
		delete(m.cachedContent, m.messages[i].ID)
	}
	m.renderView()
}
func (m *messagesCmp) SetSize(width, height int) tea.Cmd {
if m.width == width && m.height == height {
return nil
@@ -391,11 +423,9 @@ func (m *messagesCmp) SetSize(width, height int) tea.Cmd {
m.height = height
m.viewport.Width = width
m.viewport.Height = height - 2
for _, msg := range m.messages {
delete(m.cachedContent, msg.ID)
}
m.uiMessages = make([]uiMessage, 0)
m.renderView()
m.attachments.Width = width + 40
m.attachments.Height = 3
m.rerender()
return nil
}
@@ -403,17 +433,16 @@ func (m *messagesCmp) GetSize() (int, int) {
return m.width, m.height
}
func (m *messagesCmp) SetSession(session session.Session) tea.Cmd {
if m.session.ID == session.ID {
return nil
}
m.session = session
func (m *messagesCmp) Reload(session *session.Session) tea.Cmd {
messages, err := m.app.Messages.List(context.Background(), session.ID)
if err != nil {
return util.ReportError(err)
status.Error(err.Error())
return nil
}
m.messages = messages
m.currentMsgID = m.messages[len(m.messages)-1].ID
if len(m.messages) > 0 {
m.currentMsgID = m.messages[len(m.messages)-1].ID
}
delete(m.cachedContent, m.currentMsgID)
m.rendering = true
return func() tea.Msg {
@@ -432,17 +461,23 @@ func (m *messagesCmp) BindingKeys() []key.Binding {
}
func NewMessagesCmp(app *app.App) tea.Model {
s := spinner.New()
s.Spinner = spinner.Pulse
customSpinner := spinner.Spinner{
Frames: []string{" ", "┃", "┃"},
FPS: time.Second / 3,
}
s := spinner.New(spinner.WithSpinner(customSpinner))
vp := viewport.New(0, 0)
attachmets := viewport.New(0, 0)
vp.KeyMap.PageUp = messageKeys.PageUp
vp.KeyMap.PageDown = messageKeys.PageDown
vp.KeyMap.HalfPageUp = messageKeys.HalfPageUp
vp.KeyMap.HalfPageDown = messageKeys.HalfPageDown
return &messagesCmp{
app: app,
cachedContent: make(map[string]cacheItem),
viewport: vp,
spinner: s,
app: app,
cachedContent: make(map[string]cacheItem),
viewport: vp,
spinner: s,
attachments: attachmets,
showToolMessages: true,
}
}

View File

@@ -6,19 +6,18 @@ import (
"fmt"
"path/filepath"
"strings"
"sync"
"time"
"github.com/charmbracelet/glamour"
"github.com/charmbracelet/lipgloss"
"github.com/charmbracelet/x/ansi"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/diff"
"github.com/opencode-ai/opencode/internal/llm/agent"
"github.com/opencode-ai/opencode/internal/llm/models"
"github.com/opencode-ai/opencode/internal/llm/tools"
"github.com/opencode-ai/opencode/internal/message"
"github.com/opencode-ai/opencode/internal/tui/styles"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/diff"
"github.com/sst/opencode/internal/llm/agent"
"github.com/sst/opencode/internal/llm/models"
"github.com/sst/opencode/internal/llm/tools"
"github.com/sst/opencode/internal/message"
"github.com/sst/opencode/internal/tui/styles"
"github.com/sst/opencode/internal/tui/theme"
)
type uiMessageType int
@@ -31,8 +30,6 @@ const (
maxResultHeight = 10
)
var diffStyle = diff.NewStyleConfig(diff.WithShowHeader(false), diff.WithShowHunkHeader(false))
type uiMessage struct {
ID string
messageType uiMessageType
@@ -41,46 +38,37 @@ type uiMessage struct {
content string
}
type renderCache struct {
mutex sync.Mutex
cache map[string][]uiMessage
}
func toMarkdown(content string, focused bool, width int) string {
r, _ := glamour.NewTermRenderer(
glamour.WithStyles(styles.MarkdownTheme(false)),
glamour.WithWordWrap(width),
)
if focused {
r, _ = glamour.NewTermRenderer(
glamour.WithStyles(styles.MarkdownTheme(true)),
glamour.WithWordWrap(width),
)
}
r := styles.GetMarkdownRenderer(width)
rendered, _ := r.Render(content)
return rendered
}
func renderMessage(msg string, isUser bool, isFocused bool, width int, info ...string) string {
style := styles.BaseStyle.
t := theme.CurrentTheme()
style := styles.BaseStyle().
Width(width - 1).
BorderLeft(true).
Foreground(styles.ForgroundDim).
BorderForeground(styles.PrimaryColor).
Foreground(t.TextMuted()).
BorderForeground(t.Primary()).
BorderStyle(lipgloss.ThickBorder())
if isUser {
style = style.
BorderForeground(styles.Blue)
}
parts := []string{
styles.ForceReplaceBackgroundWithLipgloss(toMarkdown(msg, isFocused, width), styles.Background),
style = style.BorderForeground(t.Secondary())
}
// remove newline at the end
// Apply markdown formatting and handle background color
parts := []string{
styles.ForceReplaceBackgroundWithLipgloss(toMarkdown(msg, isFocused, width), t.Background()),
}
// Remove newline at the end
parts[0] = strings.TrimSuffix(parts[0], "\n")
if len(info) > 0 {
parts = append(parts, info...)
}
rendered := style.Render(
lipgloss.JoinVertical(
lipgloss.Left,
@@ -92,7 +80,41 @@ func renderMessage(msg string, isUser bool, isFocused bool, width int, info ...s
}
func renderUserMessage(msg message.Message, isFocused bool, width int, position int) uiMessage {
content := renderMessage(msg.Content().String(), true, isFocused, width)
var styledAttachments []string
t := theme.CurrentTheme()
baseStyle := styles.BaseStyle()
attachmentStyles := baseStyle.
MarginLeft(1).
Background(t.TextMuted()).
Foreground(t.Text())
for _, attachment := range msg.BinaryContent() {
file := filepath.Base(attachment.Path)
var filename string
if len(file) > 10 {
filename = fmt.Sprintf(" %s %s...", styles.DocumentIcon, file[0:7])
} else {
filename = fmt.Sprintf(" %s %s", styles.DocumentIcon, file)
}
styledAttachments = append(styledAttachments, attachmentStyles.Render(filename))
}
// Add timestamp info
info := []string{}
timestamp := msg.CreatedAt.Local().Format("02 Jan 2006 03:04 PM")
username, _ := config.GetUsername()
info = append(info, baseStyle.
Width(width-1).
Foreground(t.TextMuted()).
Render(fmt.Sprintf(" %s (%s)", username, timestamp)),
)
content := ""
if len(styledAttachments) > 0 {
attachmentContent := baseStyle.Width(width).Render(lipgloss.JoinHorizontal(lipgloss.Left, styledAttachments...))
content = renderMessage(msg.Content().String(), true, isFocused, width, append(info, attachmentContent)...)
} else {
content = renderMessage(msg.Content().String(), true, isFocused, width, info...)
}
userMsg := uiMessage{
ID: msg.ID,
messageType: userMessageType,
@@ -112,37 +134,56 @@ func renderAssistantMessage(
focusedUIMessageId string,
width int,
position int,
showToolMessages bool,
) []uiMessage {
messages := []uiMessage{}
content := msg.Content().String()
content := strings.TrimSpace(msg.Content().String())
thinking := msg.IsThinking()
thinkingContent := msg.ReasoningContent().Thinking
finished := msg.IsFinished()
finishData := msg.FinishPart()
info := []string{}
// Add finish info if available
t := theme.CurrentTheme()
baseStyle := styles.BaseStyle()
// Always add timestamp info
timestamp := msg.CreatedAt.Local().Format("02 Jan 2006 03:04 PM")
modelName := "Assistant"
if msg.Model != "" {
modelName = models.SupportedModels[msg.Model].Name
}
info = append(info, baseStyle.
Width(width-1).
Foreground(t.TextMuted()).
Render(fmt.Sprintf(" %s (%s)", modelName, timestamp)),
)
if finished {
// Add finish info if available
switch finishData.Reason {
case message.FinishReasonEndTurn:
took := formatTimeDifference(msg.CreatedAt, finishData.Time)
info = append(info, styles.BaseStyle.Width(width-1).Foreground(styles.ForgroundDim).Render(
fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, took),
))
case message.FinishReasonCanceled:
info = append(info, styles.BaseStyle.Width(width-1).Foreground(styles.ForgroundDim).Render(
fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, "canceled"),
))
info = append(info, baseStyle.
Width(width-1).
Foreground(t.Warning()).
Render("(canceled)"),
)
case message.FinishReasonError:
info = append(info, styles.BaseStyle.Width(width-1).Foreground(styles.ForgroundDim).Render(
fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, "error"),
))
info = append(info, baseStyle.
Width(width-1).
Foreground(t.Error()).
Render("(error)"),
)
case message.FinishReasonPermissionDenied:
info = append(info, styles.BaseStyle.Width(width-1).Foreground(styles.ForgroundDim).Render(
fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, "permission denied"),
))
info = append(info, baseStyle.
Width(width-1).
Foreground(t.Info()).
Render("(permission denied)"),
)
}
}
if content != "" || (finished && finishData.Reason == message.FinishReasonEndTurn) {
if content == "" {
content = "*Finished without output*"
@@ -159,23 +200,35 @@ func renderAssistantMessage(
position += messages[0].height
position++ // for the space
} else if thinking && thinkingContent != "" {
// Render the thinking content
content = renderMessage(thinkingContent, false, msg.ID == focusedUIMessageId, width)
// Render the thinking content with timestamp
content = renderMessage(thinkingContent, false, msg.ID == focusedUIMessageId, width, info...)
messages = append(messages, uiMessage{
ID: msg.ID,
messageType: assistantMessageType,
position: position,
height: lipgloss.Height(content),
content: content,
})
position += lipgloss.Height(content)
position++ // for the space
}
for i, toolCall := range msg.ToolCalls() {
toolCallContent := renderToolMessage(
toolCall,
allMessages,
messagesService,
focusedUIMessageId,
false,
width,
i+1,
)
messages = append(messages, toolCallContent)
position += toolCallContent.height
position++ // for the space
// Only render tool messages if they should be shown
if showToolMessages {
for i, toolCall := range msg.ToolCalls() {
toolCallContent := renderToolMessage(
toolCall,
allMessages,
messagesService,
focusedUIMessageId,
false,
width,
i+1,
)
messages = append(messages, toolCallContent)
position += toolCallContent.height
position++ // for the space
}
}
return messages
}
@@ -207,8 +260,6 @@ func toolName(name string) string {
return "Grep"
case tools.LSToolName:
return "List"
case tools.SourcegraphToolName:
return "Sourcegraph"
case tools.ViewToolName:
return "View"
case tools.WriteToolName:
@@ -235,8 +286,6 @@ func getToolAction(name string) string {
return "Searching content..."
case tools.LSToolName:
return "Listing directory..."
case tools.SourcegraphToolName:
return "Searching code..."
case tools.ViewToolName:
return "Reading file..."
case tools.WriteToolName:
@@ -375,10 +424,6 @@ func renderToolParams(paramWidth int, toolCall message.ToolCall) string {
path = "."
}
return renderParams(paramWidth, path)
case tools.SourcegraphToolName:
var params tools.SourcegraphParams
json.Unmarshal([]byte(toolCall.Input), &params)
return renderParams(paramWidth, params.Query)
case tools.ViewToolName:
var params tools.ViewParams
json.Unmarshal([]byte(toolCall.Input), &params)
@@ -414,32 +459,35 @@ func truncateHeight(content string, height int) string {
}
func renderToolResponse(toolCall message.ToolCall, response message.ToolResult, width int) string {
t := theme.CurrentTheme()
baseStyle := styles.BaseStyle()
if response.IsError {
errContent := fmt.Sprintf("Error: %s", strings.ReplaceAll(response.Content, "\n", " "))
errContent = ansi.Truncate(errContent, width-1, "...")
return styles.BaseStyle.
return baseStyle.
Width(width).
Foreground(styles.Error).
Foreground(t.Error()).
Render(errContent)
}
resultContent := truncateHeight(response.Content, maxResultHeight)
switch toolCall.Name {
case agent.AgentToolName:
return styles.ForceReplaceBackgroundWithLipgloss(
toMarkdown(resultContent, false, width),
styles.Background,
t.Background(),
)
case tools.BashToolName:
resultContent = fmt.Sprintf("```bash\n%s\n```", resultContent)
return styles.ForceReplaceBackgroundWithLipgloss(
toMarkdown(resultContent, true, width),
styles.Background,
t.Background(),
)
case tools.EditToolName:
metadata := tools.EditResponseMetadata{}
json.Unmarshal([]byte(response.Metadata), &metadata)
truncDiff := truncateHeight(metadata.Diff, maxResultHeight)
formattedDiff, _ := diff.FormatDiff(truncDiff, diff.WithTotalWidth(width), diff.WithStyle(diffStyle))
formattedDiff, _ := diff.FormatDiff(metadata.Diff, diff.WithTotalWidth(width))
return formattedDiff
case tools.FetchToolName:
var params tools.FetchParams
@@ -454,16 +502,14 @@ func renderToolResponse(toolCall message.ToolCall, response message.ToolResult,
resultContent = fmt.Sprintf("```%s\n%s\n```", mdFormat, resultContent)
return styles.ForceReplaceBackgroundWithLipgloss(
toMarkdown(resultContent, true, width),
styles.Background,
t.Background(),
)
case tools.GlobToolName:
return styles.BaseStyle.Width(width).Foreground(styles.ForgroundMid).Render(resultContent)
return baseStyle.Width(width).Foreground(t.TextMuted()).Render(resultContent)
case tools.GrepToolName:
return styles.BaseStyle.Width(width).Foreground(styles.ForgroundMid).Render(resultContent)
return baseStyle.Width(width).Foreground(t.TextMuted()).Render(resultContent)
case tools.LSToolName:
return styles.BaseStyle.Width(width).Foreground(styles.ForgroundMid).Render(resultContent)
case tools.SourcegraphToolName:
return styles.BaseStyle.Width(width).Foreground(styles.ForgroundMid).Render(resultContent)
return baseStyle.Width(width).Foreground(t.TextMuted()).Render(resultContent)
case tools.ViewToolName:
metadata := tools.ViewResponseMetadata{}
json.Unmarshal([]byte(response.Metadata), &metadata)
@@ -476,7 +522,7 @@ func renderToolResponse(toolCall message.ToolCall, response message.ToolResult,
resultContent = fmt.Sprintf("```%s\n%s\n```", ext, truncateHeight(metadata.Content, maxResultHeight))
return styles.ForceReplaceBackgroundWithLipgloss(
toMarkdown(resultContent, true, width),
styles.Background,
t.Background(),
)
case tools.WriteToolName:
params := tools.WriteParams{}
@@ -492,13 +538,13 @@ func renderToolResponse(toolCall message.ToolCall, response message.ToolResult,
resultContent = fmt.Sprintf("```%s\n%s\n```", ext, truncateHeight(params.Content, maxResultHeight))
return styles.ForceReplaceBackgroundWithLipgloss(
toMarkdown(resultContent, true, width),
styles.Background,
t.Background(),
)
default:
resultContent = fmt.Sprintf("```text\n%s\n```", resultContent)
return styles.ForceReplaceBackgroundWithLipgloss(
toMarkdown(resultContent, true, width),
styles.Background,
t.Background(),
)
}
}
@@ -515,39 +561,31 @@ func renderToolMessage(
if nested {
width = width - 3
}
style := styles.BaseStyle.
t := theme.CurrentTheme()
baseStyle := styles.BaseStyle()
style := baseStyle.
Width(width - 1).
BorderLeft(true).
BorderStyle(lipgloss.ThickBorder()).
PaddingLeft(1).
BorderForeground(styles.ForgroundDim)
BorderForeground(t.TextMuted())
response := findToolResponse(toolCall.ID, allMessages)
toolName := styles.BaseStyle.Foreground(styles.ForgroundDim).Render(fmt.Sprintf("%s: ", toolName(toolCall.Name)))
toolNameText := baseStyle.Foreground(t.TextMuted()).
Render(fmt.Sprintf("%s: ", toolName(toolCall.Name)))
if !toolCall.Finished {
// Get a brief description of what the tool is doing
toolAction := getToolAction(toolCall.Name)
// toolInput := strings.ReplaceAll(toolCall.Input, "\n", " ")
// truncatedInput := toolInput
// if len(truncatedInput) > 10 {
// truncatedInput = truncatedInput[len(truncatedInput)-10:]
// }
//
// truncatedInput = styles.BaseStyle.
// Italic(true).
// Width(width - 2 - lipgloss.Width(toolName)).
// Background(styles.BackgroundDim).
// Foreground(styles.ForgroundMid).
// Render(truncatedInput)
progressText := styles.BaseStyle.
Width(width - 2 - lipgloss.Width(toolName)).
Foreground(styles.ForgroundDim).
progressText := baseStyle.
Width(width - 2 - lipgloss.Width(toolNameText)).
Foreground(t.TextMuted()).
Render(fmt.Sprintf("%s", toolAction))
content := style.Render(lipgloss.JoinHorizontal(lipgloss.Left, toolName, progressText))
content := style.Render(lipgloss.JoinHorizontal(lipgloss.Left, toolNameText, progressText))
toolMsg := uiMessage{
messageType: toolMessageType,
position: position,
@@ -556,37 +594,39 @@ func renderToolMessage(
}
return toolMsg
}
params := renderToolParams(width-2-lipgloss.Width(toolName), toolCall)
params := renderToolParams(width-1-lipgloss.Width(toolNameText), toolCall)
responseContent := ""
if response != nil {
responseContent = renderToolResponse(toolCall, *response, width-2)
responseContent = strings.TrimSuffix(responseContent, "\n")
} else {
responseContent = styles.BaseStyle.
responseContent = baseStyle.
Italic(true).
Width(width - 2).
Foreground(styles.ForgroundDim).
Foreground(t.TextMuted()).
Render("Waiting for response...")
}
parts := []string{}
if !nested {
params := styles.BaseStyle.
Width(width - 2 - lipgloss.Width(toolName)).
Foreground(styles.ForgroundDim).
formattedParams := baseStyle.
Width(width - 2 - lipgloss.Width(toolNameText)).
Foreground(t.TextMuted()).
Render(params)
parts = append(parts, lipgloss.JoinHorizontal(lipgloss.Left, toolName, params))
parts = append(parts, lipgloss.JoinHorizontal(lipgloss.Left, toolNameText, formattedParams))
} else {
prefix := styles.BaseStyle.
Foreground(styles.ForgroundDim).
prefix := baseStyle.
Foreground(t.TextMuted()).
Render(" └ ")
params := styles.BaseStyle.
Width(width - 2 - lipgloss.Width(toolName)).
Foreground(styles.ForgroundMid).
formattedParams := baseStyle.
Width(width - 2 - lipgloss.Width(toolNameText)).
Foreground(t.TextMuted()).
Render(params)
parts = append(parts, lipgloss.JoinHorizontal(lipgloss.Left, prefix, toolName, params))
parts = append(parts, lipgloss.JoinHorizontal(lipgloss.Left, prefix, toolNameText, formattedParams))
}
if toolCall.Name == agent.AgentToolName {
taskMessages, _ := messagesService.List(context.Background(), toolCall.ID)
toolCalls := []message.ToolCall{}

Some files were not shown because too many files have changed in this diff Show More