mirror of
https://github.com/anomalyco/opencode.git
synced 2026-04-24 06:45:22 +00:00
Compare commits
103 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c8f8d67a88 | ||
|
|
182e32e4f7 | ||
|
|
5ea989fb74 | ||
|
|
45c778b90d | ||
|
|
47cbb650a0 | ||
|
|
e91371c6a5 | ||
|
|
9d17314309 | ||
|
|
3982be4310 | ||
|
|
4c998d4f4f | ||
|
|
f7849c2d59 | ||
|
|
463002185b | ||
|
|
53a80eac1e | ||
|
|
01b6bf5bb7 | ||
|
|
d8f3b60625 | ||
|
|
cf8e16018d | ||
|
|
674797bd48 | ||
|
|
1f9610e266 | ||
|
|
ae86ef519c | ||
|
|
2391e338b4 | ||
|
|
1e9399fbee | ||
|
|
e9f74b867f | ||
|
|
5079556896 | ||
|
|
7f0e68b933 | ||
|
|
0c21ca5318 | ||
|
|
0117c72a2c | ||
|
|
e3eb9e5435 | ||
|
|
d941be3f1f | ||
|
|
36e5ae804e | ||
|
|
c9b90dd184 | ||
|
|
8270a1e4b1 | ||
|
|
7f9c992993 | ||
|
|
b6524c0982 | ||
|
|
425c0f1bab | ||
|
|
d20d0c5a95 | ||
|
|
5af3c05d41 | ||
|
|
df4a9295c0 | ||
|
|
8cbfc581b5 | ||
|
|
4bb350a09b | ||
|
|
17c5b9c12c | ||
|
|
1f8580553c | ||
|
|
f92b2b76dc | ||
|
|
1d1a1ddcbf | ||
|
|
dfe5fd8d97 | ||
|
|
ed9fba99c9 | ||
|
|
f100777199 | ||
|
|
f41b7bbd0a | ||
|
|
e35ea2d448 | ||
|
|
bab17d7520 | ||
|
|
051d7d7936 | ||
|
|
b638dafe5f | ||
|
|
e387b1f16c | ||
|
|
71a68dd56d | ||
|
|
3ee8ebd3d3 | ||
|
|
ef298b2f18 | ||
|
|
3cc08494a5 | ||
|
|
afcdabd095 | ||
|
|
efaba6c5b8 | ||
|
|
874715838a | ||
|
|
167eb9ddfa | ||
|
|
fba344718f | ||
|
|
cdd906e32e | ||
|
|
ff0ef3bb43 | ||
|
|
0095832be3 | ||
|
|
406ccf9b87 | ||
|
|
f90d6238ed | ||
|
|
f004a0b8c3 | ||
|
|
49423da081 | ||
|
|
364cf5b429 | ||
|
|
b2f24e38ed | ||
|
|
49037e7b28 | ||
|
|
c66832d299 | ||
|
|
7398b4ce70 | ||
|
|
a61b2026eb | ||
|
|
69ade34c2c | ||
|
|
fbca5441f6 | ||
|
|
e4680caebb | ||
|
|
e760d28c5a | ||
|
|
7d5f0f9d18 | ||
|
|
515f4e8642 | ||
|
|
f2b36b9234 | ||
|
|
f224978bbc | ||
|
|
8819a37a05 | ||
|
|
769dff00ba | ||
|
|
d1be7a984e | ||
|
|
3e30607a6d | ||
|
|
d08e58279d | ||
|
|
7bc542abff | ||
|
|
ed50c36789 | ||
|
|
98cf65b425 | ||
|
|
5406083850 | ||
|
|
91ae9b33d3 | ||
|
|
a42175c067 | ||
|
|
8497145db2 | ||
|
|
89544fad61 | ||
|
|
1151accf4b | ||
|
|
1ae3f1830b | ||
|
|
1e958b62ad | ||
|
|
fdf5367f4f | ||
|
|
0e8842a007 | ||
|
|
060994f393 | ||
|
|
61b605e724 | ||
|
|
61d9dc9511 | ||
|
|
76275e533e |
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@@ -4,7 +4,7 @@ on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
|
||||
concurrency: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
||||
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -36,5 +36,5 @@ jobs:
|
||||
version: latest
|
||||
args: release --clean
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.HOMEBREW_GITHUB_TOKEN }}
|
||||
GITHUB_TOKEN: ${{ secrets.SST_GITHUB_TOKEN }}
|
||||
AUR_KEY: ${{ secrets.AUR_KEY }}
|
||||
|
||||
@@ -12,7 +12,7 @@ builds:
|
||||
- amd64
|
||||
- arm64
|
||||
ldflags:
|
||||
- -s -w -X github.com/opencode-ai/opencode/internal/version.Version={{.Version}}
|
||||
- -s -w -X github.com/sst/opencode/internal/version.Version={{.Version}}
|
||||
main: ./main.go
|
||||
|
||||
archives:
|
||||
@@ -35,10 +35,11 @@ snapshot:
|
||||
name_template: "0.0.0-{{ .Timestamp }}"
|
||||
aurs:
|
||||
- name: opencode
|
||||
homepage: "https://github.com/opencode-ai/opencode"
|
||||
homepage: "https://github.com/sst/opencode"
|
||||
description: "terminal based agent that can build anything"
|
||||
maintainers:
|
||||
- "kujtimiihoxha <kujtimii.h@gmail.com>"
|
||||
- "dax"
|
||||
- "adam"
|
||||
license: "MIT"
|
||||
private_key: "{{ .Env.AUR_KEY }}"
|
||||
git_url: "ssh://aur@aur.archlinux.org/opencode-bin.git"
|
||||
@@ -50,7 +51,7 @@ aurs:
|
||||
install -Dm755 ./opencode "${pkgdir}/usr/bin/opencode"
|
||||
brews:
|
||||
- repository:
|
||||
owner: opencode-ai
|
||||
owner: sst
|
||||
name: homebrew-tap
|
||||
nfpms:
|
||||
- maintainer: kujtimiihoxha
|
||||
|
||||
@@ -1,8 +1,3 @@
|
||||
{
|
||||
"$schema": "./opencode-schema.json",
|
||||
"lsp": {
|
||||
"gopls": {
|
||||
"command": "gopls"
|
||||
}
|
||||
}
|
||||
"$schema": "./opencode-schema.json"
|
||||
}
|
||||
|
||||
24
CONTEXT.md
Normal file
24
CONTEXT.md
Normal file
@@ -0,0 +1,24 @@
|
||||
# OpenCode Development Context
|
||||
|
||||
## Build Commands
|
||||
- Build: `go build`
|
||||
- Run: `go run main.go`
|
||||
- Test: `go test ./...`
|
||||
- Test single package: `go test ./internal/package/...`
|
||||
- Test single test: `go test ./internal/package -run TestName`
|
||||
- Verbose test: `go test -v ./...`
|
||||
- Coverage: `go test -cover ./...`
|
||||
- Lint: `go vet ./...`
|
||||
- Format: `go fmt ./...`
|
||||
- Build snapshot: `./scripts/snapshot`
|
||||
|
||||
## Code Style
|
||||
- Use Go 1.24+ features
|
||||
- Follow standard Go formatting (gofmt)
|
||||
- Use table-driven tests with t.Parallel() when possible
|
||||
- Error handling: check errors immediately, return early
|
||||
- Naming: CamelCase for exported, camelCase for unexported
|
||||
- Imports: standard library first, then external, then internal
|
||||
- Use context.Context for cancellation and timeouts
|
||||
- Prefer interfaces for dependencies to enable testing
|
||||
- Use testify for assertions in tests
|
||||
160
README.md
160
README.md
@@ -1,5 +1,7 @@
|
||||
# ⌬ OpenCode
|
||||
|
||||

|
||||
|
||||
> **⚠️ Early Development Notice:** This project is in early development and is not yet ready for production use. Features may change, break, or be incomplete. Use at your own risk.
|
||||
|
||||
A powerful terminal-based AI assistant for developers, providing intelligent coding assistance directly in your terminal.
|
||||
@@ -35,7 +37,7 @@ curl -fsSL https://opencode.ai/install | VERSION=0.1.0 bash
|
||||
### Using Homebrew (macOS and Linux)
|
||||
|
||||
```bash
|
||||
brew install opencode-ai/tap/opencode
|
||||
brew install sst/tap/opencode
|
||||
```
|
||||
|
||||
### Using AUR (Arch Linux)
|
||||
@@ -51,7 +53,7 @@ paru -S opencode-bin
|
||||
### Using Go
|
||||
|
||||
```bash
|
||||
go install github.com/opencode-ai/opencode@latest
|
||||
go install github.com/sst/opencode@latest
|
||||
```
|
||||
|
||||
## Configuration
|
||||
@@ -67,7 +69,7 @@ OpenCode looks for configuration in the following locations:
|
||||
You can configure OpenCode using environment variables:
|
||||
|
||||
| Environment Variable | Purpose |
|
||||
|----------------------------|--------------------------------------------------------|
|
||||
| -------------------------- | ------------------------------------------------------ |
|
||||
| `ANTHROPIC_API_KEY` | For Claude models |
|
||||
| `OPENAI_API_KEY` | For OpenAI models |
|
||||
| `GEMINI_API_KEY` | For Google Gemini models |
|
||||
@@ -79,7 +81,6 @@ You can configure OpenCode using environment variables:
|
||||
| `AZURE_OPENAI_API_KEY` | For Azure OpenAI models (optional when using Entra ID) |
|
||||
| `AZURE_OPENAI_API_VERSION` | For Azure OpenAI models |
|
||||
|
||||
|
||||
### Configuration File Structure
|
||||
|
||||
```json
|
||||
@@ -106,7 +107,7 @@ You can configure OpenCode using environment variables:
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"coder": {
|
||||
"primary": {
|
||||
"model": "claude-3.7-sonnet",
|
||||
"maxTokens": 5000
|
||||
},
|
||||
@@ -296,12 +297,81 @@ OpenCode's AI assistant has access to various tools to help with coding tasks:
|
||||
|
||||
### Other Tools
|
||||
|
||||
| Tool | Description | Parameters |
|
||||
| ------------- | -------------------------------------- | ----------------------------------------------------------------------------------------- |
|
||||
| `bash` | Execute shell commands | `command` (required), `timeout` (optional) |
|
||||
| `fetch` | Fetch data from URLs | `url` (required), `format` (required), `timeout` (optional) |
|
||||
| `sourcegraph` | Search code across public repositories | `query` (required), `count` (optional), `context_window` (optional), `timeout` (optional) |
|
||||
| `agent` | Run sub-tasks with the AI agent | `prompt` (required) |
|
||||
| Tool | Description | Parameters |
|
||||
| ------- | ------------------------------- | ----------------------------------------------------------- |
|
||||
| `bash` | Execute shell commands | `command` (required), `timeout` (optional) |
|
||||
| `fetch` | Fetch data from URLs | `url` (required), `format` (required), `timeout` (optional) |
|
||||
| `agent` | Run sub-tasks with the AI agent | `prompt` (required) |
|
||||
|
||||
## Theming
|
||||
|
||||
OpenCode supports multiple themes for customizing the appearance of the terminal interface.
|
||||
|
||||
### Available Themes
|
||||
|
||||
The following predefined themes are available:
|
||||
|
||||
- `opencode` (default)
|
||||
- `catppuccin`
|
||||
- `dracula`
|
||||
- `flexoki`
|
||||
- `gruvbox`
|
||||
- `monokai`
|
||||
- `onedark`
|
||||
- `tokyonight`
|
||||
- `tron`
|
||||
- `custom` (user-defined)
|
||||
|
||||
### Setting a Theme
|
||||
|
||||
You can set a theme in your `.opencode.json` configuration file:
|
||||
|
||||
```json
|
||||
{
|
||||
"tui": {
|
||||
"theme": "monokai"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Custom Themes
|
||||
|
||||
You can define your own custom theme by setting the `theme` to `"custom"` and providing color definitions in the `customTheme` map:
|
||||
|
||||
```json
|
||||
{
|
||||
"tui": {
|
||||
"theme": "custom",
|
||||
"customTheme": {
|
||||
"primary": "#ffcc00",
|
||||
"secondary": "#00ccff",
|
||||
"accent": { "dark": "#aa00ff", "light": "#ddccff" },
|
||||
"error": "#ff0000"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Color Definition Formats
|
||||
|
||||
Custom theme colors support two formats:
|
||||
|
||||
1. **Simple Hex String**: A single hex color string (e.g., `"#aabbcc"`) that will be used for both light and dark terminal backgrounds.
|
||||
|
||||
2. **Adaptive Object**: An object with `dark` and `light` keys, each holding a hex color string. This allows for adaptive colors based on the terminal's background.
|
||||
|
||||
#### Available Color Keys
|
||||
|
||||
You can define any of the following color keys in your `customTheme`:
|
||||
|
||||
- Base colors: `primary`, `secondary`, `accent`
|
||||
- Status colors: `error`, `warning`, `success`, `info`
|
||||
- Text colors: `text`, `textMuted`, `textEmphasized`
|
||||
- Background colors: `background`, `backgroundSecondary`, `backgroundDarker`
|
||||
- Border colors: `borderNormal`, `borderFocused`, `borderDim`
|
||||
- Diff view colors: `diffAdded`, `diffRemoved`, `diffContext`, etc.
|
||||
|
||||
You don't need to define all colors. Any undefined colors will fall back to the default "opencode" theme colors.
|
||||
|
||||
## Architecture
|
||||
|
||||
@@ -318,6 +388,72 @@ OpenCode is built with a modular architecture:
|
||||
- **internal/session**: Session management
|
||||
- **internal/lsp**: Language Server Protocol integration
|
||||
|
||||
## Custom Commands
|
||||
|
||||
OpenCode supports custom commands that can be created by users to quickly send predefined prompts to the AI assistant.
|
||||
|
||||
### Creating Custom Commands
|
||||
|
||||
Custom commands are predefined prompts stored as Markdown files in one of three locations:
|
||||
|
||||
1. **User Commands** (prefixed with `user:`):
|
||||
|
||||
```
|
||||
$XDG_CONFIG_HOME/opencode/commands/
|
||||
```
|
||||
|
||||
(typically `~/.config/opencode/commands/` on Linux/macOS)
|
||||
|
||||
or
|
||||
|
||||
```
|
||||
$HOME/.opencode/commands/
|
||||
```
|
||||
|
||||
2. **Project Commands** (prefixed with `project:`):
|
||||
```
|
||||
<PROJECT DIR>/.opencode/commands/
|
||||
```
|
||||
|
||||
Each `.md` file in these directories becomes a custom command. The file name (without extension) becomes the command ID.
|
||||
|
||||
For example, creating a file at `~/.config/opencode/commands/prime-context.md` with content:
|
||||
|
||||
```markdown
|
||||
RUN git ls-files
|
||||
READ README.md
|
||||
```
|
||||
|
||||
This creates a command called `user:prime-context`.
|
||||
|
||||
### Command Arguments
|
||||
|
||||
You can create commands that accept arguments by including the `$ARGUMENTS` placeholder in your command file:
|
||||
|
||||
```markdown
|
||||
RUN git show $ARGUMENTS
|
||||
```
|
||||
|
||||
When you run this command, OpenCode will prompt you to enter the text that should replace `$ARGUMENTS`.
|
||||
|
||||
### Organizing Commands
|
||||
|
||||
You can organize commands in subdirectories:
|
||||
|
||||
```
|
||||
~/.config/opencode/commands/git/commit.md
|
||||
```
|
||||
|
||||
This creates a command with ID `user:git:commit`.
|
||||
|
||||
### Using Custom Commands
|
||||
|
||||
1. Press `Ctrl+K` to open the command dialog
|
||||
2. Select your custom command (prefixed with either `user:` or `project:`)
|
||||
3. Press Enter to execute the command
|
||||
|
||||
The content of the command file will be sent as a message to the AI assistant.
|
||||
|
||||
## MCP (Model Context Protocol)
|
||||
|
||||
OpenCode implements the Model Context Protocol (MCP) to extend its capabilities through external tools. MCP provides a standardized way for the AI assistant to interact with external services and tools.
|
||||
@@ -408,7 +544,7 @@ While the LSP client implementation supports the full LSP protocol (including co
|
||||
|
||||
```bash
|
||||
# Clone the repository
|
||||
git clone https://github.com/opencode-ai/opencode.git
|
||||
git clone https://github.com/sst/opencode.git
|
||||
cd opencode
|
||||
|
||||
# Build
|
||||
|
||||
99
cmd/root.go
99
cmd/root.go
@@ -7,19 +7,42 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"log/slog"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/opencode-ai/opencode/internal/app"
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/db"
|
||||
"github.com/opencode-ai/opencode/internal/llm/agent"
|
||||
"github.com/opencode-ai/opencode/internal/logging"
|
||||
"github.com/opencode-ai/opencode/internal/pubsub"
|
||||
"github.com/opencode-ai/opencode/internal/tui"
|
||||
"github.com/opencode-ai/opencode/internal/version"
|
||||
zone "github.com/lrstanley/bubblezone"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/sst/opencode/internal/app"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/db"
|
||||
"github.com/sst/opencode/internal/llm/agent"
|
||||
"github.com/sst/opencode/internal/logging"
|
||||
"github.com/sst/opencode/internal/lsp/discovery"
|
||||
"github.com/sst/opencode/internal/pubsub"
|
||||
"github.com/sst/opencode/internal/tui"
|
||||
"github.com/sst/opencode/internal/version"
|
||||
)
|
||||
|
||||
type SessionIDHandler struct {
|
||||
slog.Handler
|
||||
app *app.App
|
||||
}
|
||||
|
||||
func (h *SessionIDHandler) Handle(ctx context.Context, r slog.Record) error {
|
||||
if h.app != nil {
|
||||
sessionID := h.app.CurrentSession.ID
|
||||
if sessionID != "" {
|
||||
r.AddAttrs(slog.String("session_id", sessionID))
|
||||
}
|
||||
}
|
||||
return h.Handler.Handle(ctx, r)
|
||||
}
|
||||
|
||||
func (h *SessionIDHandler) WithApp(app *app.App) *SessionIDHandler {
|
||||
h.app = app
|
||||
return h
|
||||
}
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "OpenCode",
|
||||
Short: "A terminal AI assistant for software development",
|
||||
@@ -37,6 +60,13 @@ to assist developers in writing, debugging, and understanding code directly from
|
||||
return nil
|
||||
}
|
||||
|
||||
// Setup logging
|
||||
lvl := new(slog.LevelVar)
|
||||
textHandler := slog.NewTextHandler(logging.NewSlogWriter(), &slog.HandlerOptions{Level: lvl})
|
||||
sessionAwareHandler := &SessionIDHandler{Handler: textHandler}
|
||||
logger := slog.New(sessionAwareHandler)
|
||||
slog.SetDefault(logger)
|
||||
|
||||
// Load the config
|
||||
debug, _ := cmd.Flags().GetBool("debug")
|
||||
cwd, _ := cmd.Flags().GetString("cwd")
|
||||
@@ -53,11 +83,17 @@ to assist developers in writing, debugging, and understanding code directly from
|
||||
}
|
||||
cwd = c
|
||||
}
|
||||
_, err := config.Load(cwd, debug)
|
||||
_, err := config.Load(cwd, debug, lvl)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Run LSP auto-discovery
|
||||
if err := discovery.IntegrateLSPServers(cwd); err != nil {
|
||||
slog.Warn("Failed to auto-discover LSP servers", "error", err)
|
||||
// Continue anyway, this is not a fatal error
|
||||
}
|
||||
|
||||
// Connect DB, this will also run migrations
|
||||
conn, err := db.Connect()
|
||||
if err != nil {
|
||||
@@ -70,16 +106,16 @@ to assist developers in writing, debugging, and understanding code directly from
|
||||
|
||||
app, err := app.New(ctx, conn)
|
||||
if err != nil {
|
||||
logging.Error("Failed to create app: %v", err)
|
||||
slog.Error("Failed to create app", "error", err)
|
||||
return err
|
||||
}
|
||||
sessionAwareHandler.WithApp(app)
|
||||
|
||||
// Set up the TUI
|
||||
zone.NewGlobal()
|
||||
program := tea.NewProgram(
|
||||
tui.New(app),
|
||||
tea.WithAltScreen(),
|
||||
tea.WithMouseCellMotion(),
|
||||
)
|
||||
|
||||
// Initialize MCP tools in the background
|
||||
@@ -103,11 +139,11 @@ to assist developers in writing, debugging, and understanding code directly from
|
||||
for {
|
||||
select {
|
||||
case <-tuiCtx.Done():
|
||||
logging.Info("TUI message handler shutting down")
|
||||
slog.Info("TUI message handler shutting down")
|
||||
return
|
||||
case msg, ok := <-ch:
|
||||
if !ok {
|
||||
logging.Info("TUI message channel closed")
|
||||
slog.Info("TUI message channel closed")
|
||||
return
|
||||
}
|
||||
program.Send(msg)
|
||||
@@ -117,19 +153,19 @@ to assist developers in writing, debugging, and understanding code directly from
|
||||
|
||||
// Cleanup function for when the program exits
|
||||
cleanup := func() {
|
||||
// Shutdown the app
|
||||
app.Shutdown()
|
||||
|
||||
// Cancel subscriptions first
|
||||
cancelSubs()
|
||||
|
||||
// Then shutdown the app
|
||||
app.Shutdown()
|
||||
|
||||
// Then cancel TUI message handler
|
||||
tuiCancel()
|
||||
|
||||
// Wait for TUI message handler to finish
|
||||
tuiWg.Wait()
|
||||
|
||||
logging.Info("All goroutines cleaned up")
|
||||
slog.Info("All goroutines cleaned up")
|
||||
}
|
||||
|
||||
// Run the TUI
|
||||
@@ -137,18 +173,18 @@ to assist developers in writing, debugging, and understanding code directly from
|
||||
cleanup()
|
||||
|
||||
if err != nil {
|
||||
logging.Error("TUI error: %v", err)
|
||||
slog.Error("TUI error", "error", err)
|
||||
return fmt.Errorf("TUI error: %v", err)
|
||||
}
|
||||
|
||||
logging.Info("TUI exited with result: %v", result)
|
||||
slog.Info("TUI exited", "result", result)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// attemptTUIRecovery tries to recover the TUI after a panic
|
||||
func attemptTUIRecovery(program *tea.Program) {
|
||||
logging.Info("Attempting to recover TUI after panic")
|
||||
slog.Info("Attempting to recover TUI after panic")
|
||||
|
||||
// We could try to restart the TUI or gracefully exit
|
||||
// For now, we'll just quit the program to avoid further issues
|
||||
@@ -165,7 +201,7 @@ func initMCPTools(ctx context.Context, app *app.App) {
|
||||
|
||||
// Set this up once with proper error handling
|
||||
agent.GetMcpTools(ctxWithTimeout, app.Permissions)
|
||||
logging.Info("MCP message handling goroutine exiting")
|
||||
slog.Info("MCP message handling goroutine exiting")
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -182,12 +218,16 @@ func setupSubscriber[T any](
|
||||
defer logging.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil)
|
||||
|
||||
subCh := subscriber(ctx)
|
||||
if subCh == nil {
|
||||
slog.Warn("subscription channel is nil", "name", name)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-subCh:
|
||||
if !ok {
|
||||
logging.Info("subscription channel closed", "name", name)
|
||||
slog.Info("subscription channel closed", "name", name)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -196,13 +236,13 @@ func setupSubscriber[T any](
|
||||
select {
|
||||
case outputCh <- msg:
|
||||
case <-time.After(2 * time.Second):
|
||||
logging.Warn("message dropped due to slow consumer", "name", name)
|
||||
slog.Warn("message dropped due to slow consumer", "name", name)
|
||||
case <-ctx.Done():
|
||||
logging.Info("subscription cancelled", "name", name)
|
||||
slog.Info("subscription cancelled", "name", name)
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
logging.Info("subscription cancelled", "name", name)
|
||||
slog.Info("subscription cancelled", "name", name)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -215,13 +255,14 @@ func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg,
|
||||
wg := sync.WaitGroup{}
|
||||
ctx, cancel := context.WithCancel(parentCtx) // Inherit from parent context
|
||||
|
||||
setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch)
|
||||
setupSubscriber(ctx, &wg, "logging", app.Logs.Subscribe, ch)
|
||||
setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
|
||||
setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
|
||||
setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
|
||||
setupSubscriber(ctx, &wg, "status", app.Status.Subscribe, ch)
|
||||
|
||||
cleanupFunc := func() {
|
||||
logging.Info("Cancelling all subscriptions")
|
||||
slog.Info("Cancelling all subscriptions")
|
||||
cancel() // Signal all goroutines to stop
|
||||
|
||||
waitCh := make(chan struct{})
|
||||
@@ -233,10 +274,10 @@ func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg,
|
||||
|
||||
select {
|
||||
case <-waitCh:
|
||||
logging.Info("All subscription goroutines completed successfully")
|
||||
slog.Info("All subscription goroutines completed successfully")
|
||||
close(ch) // Only close after all writers are confirmed done
|
||||
case <-time.After(5 * time.Second):
|
||||
logging.Warn("Timed out waiting for some subscription goroutines to complete")
|
||||
slog.Warn("Timed out waiting for some subscription goroutines to complete")
|
||||
close(ch)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ Here's an example configuration that conforms to the schema:
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"coder": {
|
||||
"primary": {
|
||||
"model": "claude-3.7-sonnet",
|
||||
"maxTokens": 5000,
|
||||
"reasoningEffort": "medium"
|
||||
@@ -61,4 +61,5 @@ Here's an example configuration that conforms to the schema:
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
```
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/llm/models"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
)
|
||||
|
||||
// JSONSchemaType represents a JSON Schema type
|
||||
@@ -98,6 +98,57 @@ func generateSchema() map[string]any {
|
||||
},
|
||||
}
|
||||
|
||||
schema["properties"].(map[string]any)["tui"] = map[string]any{
|
||||
"type": "object",
|
||||
"description": "Terminal User Interface configuration",
|
||||
"properties": map[string]any{
|
||||
"theme": map[string]any{
|
||||
"type": "string",
|
||||
"description": "TUI theme name",
|
||||
"default": "opencode",
|
||||
"enum": []string{
|
||||
"opencode",
|
||||
"catppuccin",
|
||||
"dracula",
|
||||
"flexoki",
|
||||
"gruvbox",
|
||||
"monokai",
|
||||
"onedark",
|
||||
"tokyonight",
|
||||
"tron",
|
||||
"custom",
|
||||
},
|
||||
},
|
||||
"customTheme": map[string]any{
|
||||
"type": "object",
|
||||
"description": "Custom theme color definitions",
|
||||
"additionalProperties": map[string]any{
|
||||
"oneOf": []map[string]any{
|
||||
{
|
||||
"type": "string",
|
||||
"pattern": "^#[0-9a-fA-F]{6}$",
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"dark": map[string]any{
|
||||
"type": "string",
|
||||
"pattern": "^#[0-9a-fA-F]{6}$",
|
||||
},
|
||||
"light": map[string]any{
|
||||
"type": "string",
|
||||
"pattern": "^#[0-9a-fA-F]{6}$",
|
||||
},
|
||||
},
|
||||
"required": []string{"dark", "light"},
|
||||
"additionalProperties": false,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Add MCP servers
|
||||
schema["properties"].(map[string]any)["mcpServers"] = map[string]any{
|
||||
"type": "object",
|
||||
@@ -223,7 +274,7 @@ func generateSchema() map[string]any {
|
||||
// Add specific agent properties
|
||||
agentProperties := map[string]any{}
|
||||
knownAgents := []string{
|
||||
string(config.AgentCoder),
|
||||
string(config.AgentPrimary),
|
||||
string(config.AgentTask),
|
||||
string(config.AgentTitle),
|
||||
}
|
||||
|
||||
24
go.mod
24
go.mod
@@ -1,9 +1,7 @@
|
||||
module github.com/opencode-ai/opencode
|
||||
module github.com/sst/opencode
|
||||
|
||||
go 1.24.0
|
||||
|
||||
toolchain go1.24.2
|
||||
|
||||
require (
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0
|
||||
github.com/JohannesKaufmann/html-to-markdown v1.6.0
|
||||
@@ -16,12 +14,10 @@ require (
|
||||
github.com/charmbracelet/bubbles v0.20.0
|
||||
github.com/charmbracelet/bubbletea v1.3.4
|
||||
github.com/charmbracelet/glamour v0.9.1
|
||||
github.com/charmbracelet/huh v0.6.0
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/charmbracelet/x/ansi v0.8.0
|
||||
github.com/fsnotify/fsnotify v1.8.0
|
||||
github.com/go-logfmt/logfmt v0.6.0
|
||||
github.com/google/generative-ai-go v0.19.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/lrstanley/bubblezone v0.0.0-20250315020633-c249a3fe1231
|
||||
github.com/mark3labs/mcp-go v0.17.0
|
||||
@@ -35,16 +31,12 @@ require (
|
||||
github.com/spf13/cobra v1.9.1
|
||||
github.com/spf13/viper v1.20.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
google.golang.org/api v0.215.0
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go v0.116.0 // indirect
|
||||
cloud.google.com/go/ai v0.8.0 // indirect
|
||||
cloud.google.com/go/auth v0.13.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.6.0 // indirect
|
||||
cloud.google.com/go/longrunning v0.5.7 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect
|
||||
@@ -68,30 +60,30 @@ require (
|
||||
github.com/aymerick/douceur v0.2.0 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
|
||||
github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 // indirect
|
||||
github.com/charmbracelet/x/term v0.2.1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/disintegration/imaging v1.6.2
|
||||
github.com/dlclark/regexp2 v1.11.4 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/go-logr/logr v1.4.2 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/s2a-go v0.1.8 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.14.1 // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/kylelemons/godebug v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/mfridman/interpolate v0.0.2 // indirect
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/ncruces/julianday v1.0.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
|
||||
@@ -116,21 +108,19 @@ require (
|
||||
github.com/yuin/goldmark v1.7.8 // indirect
|
||||
github.com/yuin/goldmark-emoji v1.0.5 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect
|
||||
go.opentelemetry.io/otel v1.35.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.35.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.35.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/crypto v0.37.0 // indirect
|
||||
golang.org/x/image v0.26.0 // indirect
|
||||
golang.org/x/net v0.39.0 // indirect
|
||||
golang.org/x/oauth2 v0.25.0 // indirect
|
||||
golang.org/x/sync v0.13.0 // indirect
|
||||
golang.org/x/sys v0.32.0 // indirect
|
||||
golang.org/x/term v0.31.0 // indirect
|
||||
golang.org/x/text v0.24.0 // indirect
|
||||
golang.org/x/time v0.8.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 // indirect
|
||||
google.golang.org/genai v1.3.0
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250324211829-b45e905df463 // indirect
|
||||
google.golang.org/grpc v1.71.0 // indirect
|
||||
google.golang.org/protobuf v1.36.6 // indirect
|
||||
|
||||
33
go.sum
33
go.sum
@@ -1,15 +1,9 @@
|
||||
cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE=
|
||||
cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U=
|
||||
cloud.google.com/go/ai v0.8.0 h1:rXUEz8Wp2OlrM8r1bfmpF2+VKqc1VJpafE3HgzRnD/w=
|
||||
cloud.google.com/go/ai v0.8.0/go.mod h1:t3Dfk4cM61sytiggo2UyGsDVW3RF1qGZaUKDrZFyqkE=
|
||||
cloud.google.com/go/auth v0.13.0 h1:8Fu8TZy167JkW8Tj3q7dIkr2v4cndv41ouecJx0PAHs=
|
||||
cloud.google.com/go/auth v0.13.0/go.mod h1:COOjD9gwfKNKz+IIduatIhYJQIc0mG3H102r/EMxX6Q=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.6 h1:V6a6XDu2lTwPZWOawrAa9HUK+DB2zfJyTuciBG5hFkU=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.6/go.mod h1:AlmsELtlEBnaNTL7jCj8VQFLy6mbZv0s4Q7NGBeQ5E8=
|
||||
cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
|
||||
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
|
||||
cloud.google.com/go/longrunning v0.5.7 h1:WLbHekDbjK1fVFD3ibpFFVoyizlLRl73I7YKuAKilhU=
|
||||
cloud.google.com/go/longrunning v0.5.7/go.mod h1:8GClkudohy1Fxm3owmBGid8W0pSgodEMwEAztp38Xng=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 h1:g0EZJwz7xkXQiZAI5xi9f3WWFYBlX1CPTrR+NDToRkQ=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0/go.mod h1:XCW7KnZet0Opnr7HccfUw1PLc4CjHqpcaxW8DHklNkQ=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 h1:tfLQ34V6F7tVSwoTf/4lH5sE0o6eCJuNDTmH09nDpbc=
|
||||
@@ -82,8 +76,6 @@ github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4p
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
|
||||
github.com/charmbracelet/glamour v0.9.1 h1:11dEfiGP8q1BEqvGoIjivuc2rBk+5qEXdPtaQ2WoiCM=
|
||||
github.com/charmbracelet/glamour v0.9.1/go.mod h1:+SHvIS8qnwhgTpVMiXwn7OfGomSqff1cHBCI8jLOetk=
|
||||
github.com/charmbracelet/huh v0.6.0 h1:mZM8VvZGuE0hoDXq6XLxRtgfWyTI3b2jZNKh0xWmax8=
|
||||
github.com/charmbracelet/huh v0.6.0/go.mod h1:GGNKeWCeNzKpEOh/OJD8WBwTQjV3prFAtQPpLv+AVwU=
|
||||
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||
github.com/charmbracelet/x/ansi v0.8.0 h1:9GTq3xq9caJW8ZrBTe0LIe2fvfLR/bYXKTx2llXn7xE=
|
||||
@@ -92,14 +84,14 @@ github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0G
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b h1:MnAMdlwSltxJyULnrYbkZpp4k58Co7Tah3ciKhSNo0Q=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
|
||||
github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 h1:qko3AQ4gK1MTS/de7F5hPGx6/k1u0w4TeYmBFwzYVP4=
|
||||
github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0/go.mod h1:pBhA0ybfXv6hDjQUZ7hk1lVxBiUbupdw5R31yPUViVQ=
|
||||
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
|
||||
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/disintegration/imaging v1.6.2 h1:w1LecBlG2Lnp8B3jk5zSuNqd7b4DXhcjwek1ei82L+c=
|
||||
github.com/disintegration/imaging v1.6.2/go.mod h1:44/5580QXChDfwIclfc/PCwrr44amcmDAg8hxG0Ewe4=
|
||||
github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo=
|
||||
github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
@@ -125,8 +117,6 @@ github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeD
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/generative-ai-go v0.19.0 h1:R71szggh8wHMCUlEMsW2A/3T+5LdEIkiaHSYgSpUgdg=
|
||||
github.com/google/generative-ai-go v0.19.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM=
|
||||
@@ -139,6 +129,8 @@ github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrk
|
||||
github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA=
|
||||
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
|
||||
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
@@ -169,8 +161,6 @@ github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6B
|
||||
github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
@@ -258,8 +248,6 @@ github.com/yuin/goldmark-emoji v1.0.5 h1:EMVWyCGPlXJfUXBXpuMu+ii3TIaxbVBnEX9uaDC
|
||||
github.com/yuin/goldmark-emoji v1.0.5/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 h1:r6I7RJCN86bpD/FQwedZ0vSixDpwuWREjW9oRMsmqDc=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0/go.mod h1:B9yO6b04uB80CzjedvewuqDhxJxi11s7/GtiGa8bAjI=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8=
|
||||
go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ=
|
||||
@@ -283,6 +271,9 @@ golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
|
||||
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
|
||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 h1:nDVHiLt8aIbd/VzvPWN6kSOPE7+F/fNFDSXLVYkE/Iw=
|
||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394/go.mod h1:sIifuuw/Yco/y6yb6+bDNfyeQ/MdPUy/hKEMYQV17cM=
|
||||
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
golang.org/x/image v0.26.0 h1:4XjIFEZWQmCZi6Wv8BoxsDhRU3RVnLX04dToTDAEPlY=
|
||||
golang.org/x/image v0.26.0/go.mod h1:lcxbMFAovzpnJxzXS3nyL83K27tmqtKzIJpctK8YO5c=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
@@ -296,8 +287,6 @@ golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
|
||||
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
|
||||
golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70=
|
||||
golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -338,17 +327,13 @@ golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
|
||||
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
|
||||
golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg=
|
||||
golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/api v0.215.0 h1:jdYF4qnyczlEz2ReWIsosNLDuzXyvFHJtI5gcr0J7t0=
|
||||
google.golang.org/api v0.215.0/go.mod h1:fta3CVtuJYOEdugLNWm6WodzOS8KdFckABwN4I40hzY=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 h1:GVIKPyP/kLIyVOgOnTwFOrvQaQUzOzGMCxgFUOEmm24=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422/go.mod h1:b6h1vNKhxaSoEI+5jc3PJUCustfli/mRab7295pY7rw=
|
||||
google.golang.org/genai v1.3.0 h1:tXhPJF30skOjnnDY7ZnjK3q7IKy4PuAlEA0fk7uEaEI=
|
||||
google.golang.org/genai v1.3.0/go.mod h1:TyfOKRz/QyCaj6f/ZDt505x+YreXnY40l2I6k8TvgqY=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250324211829-b45e905df463 h1:e0AIkUUhxyBKh6ssZNrAMeqhA7RKUj42346d1y02i2g=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250324211829-b45e905df463/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
|
||||
google.golang.org/grpc v1.71.0 h1:kF77BGdPTQ4/JZWMlb9VpJ5pa25aqvVqogsxNHHdeBg=
|
||||
|
||||
6
install
6
install
@@ -40,15 +40,15 @@ INSTALL_DIR=$HOME/.opencode/bin
|
||||
mkdir -p "$INSTALL_DIR"
|
||||
|
||||
if [ -z "$requested_version" ]; then
|
||||
url="https://github.com/opencode-ai/opencode/releases/latest/download/$filename"
|
||||
specific_version=$(curl -s https://api.github.com/repos/opencode-ai/opencode/releases/latest | awk -F'"' '/"tag_name": "/ {gsub(/^v/, "", $4); print $4}')
|
||||
url="https://github.com/sst/opencode/releases/latest/download/$filename"
|
||||
specific_version=$(curl -s https://api.github.com/repos/sst/opencode/releases/latest | awk -F'"' '/"tag_name": "/ {gsub(/^v/, "", $4); print $4}')
|
||||
|
||||
if [[ $? -ne 0 ]]; then
|
||||
echo "${RED}Failed to fetch version information${NC}"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
url="https://github.com/opencode-ai/opencode/releases/download/v${requested_version}/$filename"
|
||||
url="https://github.com/sst/opencode/releases/download/v${requested_version}/$filename"
|
||||
specific_version=$requested_version
|
||||
fi
|
||||
|
||||
|
||||
@@ -7,24 +7,30 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/db"
|
||||
"github.com/opencode-ai/opencode/internal/history"
|
||||
"github.com/opencode-ai/opencode/internal/llm/agent"
|
||||
"github.com/opencode-ai/opencode/internal/logging"
|
||||
"github.com/opencode-ai/opencode/internal/lsp"
|
||||
"github.com/opencode-ai/opencode/internal/message"
|
||||
"github.com/opencode-ai/opencode/internal/permission"
|
||||
"github.com/opencode-ai/opencode/internal/session"
|
||||
"log/slog"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/history"
|
||||
"github.com/sst/opencode/internal/llm/agent"
|
||||
"github.com/sst/opencode/internal/logging"
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/permission"
|
||||
"github.com/sst/opencode/internal/session"
|
||||
"github.com/sst/opencode/internal/status"
|
||||
"github.com/sst/opencode/internal/tui/theme"
|
||||
)
|
||||
|
||||
type App struct {
|
||||
Sessions session.Service
|
||||
Messages message.Service
|
||||
History history.Service
|
||||
Permissions permission.Service
|
||||
CurrentSession *session.Session
|
||||
Logs logging.Service
|
||||
Sessions session.Service
|
||||
Messages message.Service
|
||||
History history.Service
|
||||
Permissions permission.Service
|
||||
Status status.Service
|
||||
|
||||
CoderAgent agent.Service
|
||||
PrimaryAgent agent.Service
|
||||
|
||||
LSPClients map[string]*lsp.Client
|
||||
|
||||
@@ -36,28 +42,59 @@ type App struct {
|
||||
}
|
||||
|
||||
func New(ctx context.Context, conn *sql.DB) (*App, error) {
|
||||
q := db.New(conn)
|
||||
sessions := session.NewService(q)
|
||||
messages := message.NewService(q)
|
||||
files := history.NewService(q, conn)
|
||||
err := logging.InitService(conn)
|
||||
if err != nil {
|
||||
slog.Error("Failed to initialize logging service", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
err = session.InitService(conn)
|
||||
if err != nil {
|
||||
slog.Error("Failed to initialize session service", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
err = message.InitService(conn)
|
||||
if err != nil {
|
||||
slog.Error("Failed to initialize message service", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
err = history.InitService(conn)
|
||||
if err != nil {
|
||||
slog.Error("Failed to initialize history service", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
err = permission.InitService()
|
||||
if err != nil {
|
||||
slog.Error("Failed to initialize permission service", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
err = status.InitService()
|
||||
if err != nil {
|
||||
slog.Error("Failed to initialize status service", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
app := &App{
|
||||
Sessions: sessions,
|
||||
Messages: messages,
|
||||
History: files,
|
||||
Permissions: permission.NewPermissionService(),
|
||||
LSPClients: make(map[string]*lsp.Client),
|
||||
CurrentSession: &session.Session{},
|
||||
Logs: logging.GetService(),
|
||||
Sessions: session.GetService(),
|
||||
Messages: message.GetService(),
|
||||
History: history.GetService(),
|
||||
Permissions: permission.GetService(),
|
||||
Status: status.GetService(),
|
||||
LSPClients: make(map[string]*lsp.Client),
|
||||
}
|
||||
|
||||
// Initialize theme based on configuration
|
||||
app.initTheme()
|
||||
|
||||
// Initialize LSP clients in the background
|
||||
go app.initLSPClients(ctx)
|
||||
|
||||
var err error
|
||||
app.CoderAgent, err = agent.NewAgent(
|
||||
config.AgentCoder,
|
||||
app.PrimaryAgent, err = agent.NewAgent(
|
||||
config.AgentPrimary,
|
||||
app.Sessions,
|
||||
app.Messages,
|
||||
agent.CoderAgentTools(
|
||||
agent.PrimaryAgentTools(
|
||||
app.Permissions,
|
||||
app.Sessions,
|
||||
app.Messages,
|
||||
@@ -66,13 +103,28 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) {
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
logging.Error("Failed to create coder agent", err)
|
||||
slog.Error("Failed to create primary agent", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return app, nil
|
||||
}
|
||||
|
||||
// initTheme sets the application theme based on the configuration
|
||||
func (app *App) initTheme() {
|
||||
cfg := config.Get()
|
||||
if cfg == nil || cfg.TUI.Theme == "" {
|
||||
return // Use default theme
|
||||
}
|
||||
|
||||
// Try to set the theme from config
|
||||
err := theme.SetTheme(cfg.TUI.Theme)
|
||||
if err != nil {
|
||||
slog.Warn("Failed to set theme from config, using default theme", "theme", cfg.TUI.Theme, "error", err)
|
||||
} else {
|
||||
slog.Debug("Set theme from config", "theme", cfg.TUI.Theme)
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown performs a clean shutdown of the application
|
||||
func (app *App) Shutdown() {
|
||||
@@ -93,7 +145,7 @@ func (app *App) Shutdown() {
|
||||
for name, client := range clients {
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
if err := client.Shutdown(shutdownCtx); err != nil {
|
||||
logging.Error("Failed to shutdown LSP client", "name", name, "error", err)
|
||||
slog.Error("Failed to shutdown LSP client", "name", name, "error", err)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
|
||||
@@ -4,10 +4,12 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/logging"
|
||||
"github.com/opencode-ai/opencode/internal/lsp"
|
||||
"github.com/opencode-ai/opencode/internal/lsp/watcher"
|
||||
"log/slog"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/logging"
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/lsp/watcher"
|
||||
)
|
||||
|
||||
func (app *App) initLSPClients(ctx context.Context) {
|
||||
@@ -18,29 +20,29 @@ func (app *App) initLSPClients(ctx context.Context) {
|
||||
// Start each client initialization in its own goroutine
|
||||
go app.createAndStartLSPClient(ctx, name, clientConfig.Command, clientConfig.Args...)
|
||||
}
|
||||
logging.Info("LSP clients initialization started in background")
|
||||
slog.Info("LSP clients initialization started in background")
|
||||
}
|
||||
|
||||
// createAndStartLSPClient creates a new LSP client, initializes it, and starts its workspace watcher
|
||||
func (app *App) createAndStartLSPClient(ctx context.Context, name string, command string, args ...string) {
|
||||
// Create a specific context for initialization with a timeout
|
||||
logging.Info("Creating LSP client", "name", name, "command", command, "args", args)
|
||||
|
||||
slog.Info("Creating LSP client", "name", name, "command", command, "args", args)
|
||||
|
||||
// Create the LSP client
|
||||
lspClient, err := lsp.NewClient(ctx, command, args...)
|
||||
if err != nil {
|
||||
logging.Error("Failed to create LSP client for", name, err)
|
||||
slog.Error("Failed to create LSP client for", name, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a longer timeout for initialization (some servers take time to start)
|
||||
initCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
|
||||
// Initialize with the initialization context
|
||||
_, err = lspClient.InitializeLSPClient(initCtx, config.WorkingDirectory())
|
||||
if err != nil {
|
||||
logging.Error("Initialize failed", "name", name, "error", err)
|
||||
slog.Error("Initialize failed", "name", name, "error", err)
|
||||
// Clean up the client to prevent resource leaks
|
||||
lspClient.Close()
|
||||
return
|
||||
@@ -48,22 +50,22 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, comman
|
||||
|
||||
// Wait for the server to be ready
|
||||
if err := lspClient.WaitForServerReady(initCtx); err != nil {
|
||||
logging.Error("Server failed to become ready", "name", name, "error", err)
|
||||
slog.Error("Server failed to become ready", "name", name, "error", err)
|
||||
// We'll continue anyway, as some functionality might still work
|
||||
lspClient.SetServerState(lsp.StateError)
|
||||
} else {
|
||||
logging.Info("LSP server is ready", "name", name)
|
||||
slog.Info("LSP server is ready", "name", name)
|
||||
lspClient.SetServerState(lsp.StateReady)
|
||||
}
|
||||
|
||||
logging.Info("LSP client initialized", "name", name)
|
||||
|
||||
slog.Info("LSP client initialized", "name", name)
|
||||
|
||||
// Create a child context that can be canceled when the app is shutting down
|
||||
watchCtx, cancelFunc := context.WithCancel(ctx)
|
||||
|
||||
|
||||
// Create a context with the server name for better identification
|
||||
watchCtx = context.WithValue(watchCtx, "serverName", name)
|
||||
|
||||
|
||||
// Create the workspace watcher
|
||||
workspaceWatcher := watcher.NewWorkspaceWatcher(lspClient)
|
||||
|
||||
@@ -92,7 +94,7 @@ func (app *App) runWorkspaceWatcher(ctx context.Context, name string, workspaceW
|
||||
})
|
||||
|
||||
workspaceWatcher.WatchWorkspace(ctx, config.WorkingDirectory())
|
||||
logging.Info("Workspace watcher stopped", "client", name)
|
||||
slog.Info("Workspace watcher stopped", "client", name)
|
||||
}
|
||||
|
||||
// restartLSPClient attempts to restart a crashed or failed LSP client
|
||||
@@ -101,7 +103,7 @@ func (app *App) restartLSPClient(ctx context.Context, name string) {
|
||||
cfg := config.Get()
|
||||
clientConfig, exists := cfg.LSP[name]
|
||||
if !exists {
|
||||
logging.Error("Cannot restart client, configuration not found", "client", name)
|
||||
slog.Error("Cannot restart client, configuration not found", "client", name)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -118,9 +120,15 @@ func (app *App) restartLSPClient(ctx context.Context, name string) {
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
_ = oldClient.Shutdown(shutdownCtx)
|
||||
cancel()
|
||||
|
||||
// Ensure we close the client to free resources
|
||||
_ = oldClient.Close()
|
||||
}
|
||||
|
||||
// Wait a moment before restarting to avoid rapid restart cycles
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// Create a new client using the shared function
|
||||
app.createAndStartLSPClient(ctx, name, clientConfig.Command, clientConfig.Args...)
|
||||
logging.Info("Successfully restarted LSP client", "client", name)
|
||||
slog.Info("Successfully restarted LSP client", "client", name)
|
||||
}
|
||||
|
||||
@@ -2,14 +2,16 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/llm/models"
|
||||
"github.com/opencode-ai/opencode/internal/logging"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
)
|
||||
|
||||
// MCPType defines the type of MCP (Model Control Protocol) server.
|
||||
@@ -34,9 +36,9 @@ type MCPServer struct {
|
||||
type AgentName string
|
||||
|
||||
const (
|
||||
AgentCoder AgentName = "coder"
|
||||
AgentTask AgentName = "task"
|
||||
AgentTitle AgentName = "title"
|
||||
AgentPrimary AgentName = "primary"
|
||||
AgentTask AgentName = "task"
|
||||
AgentTitle AgentName = "title"
|
||||
)
|
||||
|
||||
// Agent defines configuration for different LLM models and their token limits.
|
||||
@@ -65,6 +67,12 @@ type LSPConfig struct {
|
||||
Options any `json:"options"`
|
||||
}
|
||||
|
||||
// TUIConfig defines the configuration for the Terminal User Interface.
|
||||
type TUIConfig struct {
|
||||
Theme string `json:"theme,omitempty"`
|
||||
CustomTheme map[string]any `json:"customTheme,omitempty"`
|
||||
}
|
||||
|
||||
// Config is the main configuration structure for the application.
|
||||
type Config struct {
|
||||
Data Data `json:"data"`
|
||||
@@ -76,6 +84,7 @@ type Config struct {
|
||||
Debug bool `json:"debug,omitempty"`
|
||||
DebugLSP bool `json:"debugLSP,omitempty"`
|
||||
ContextPaths []string `json:"contextPaths,omitempty"`
|
||||
TUI TUIConfig `json:"tui"`
|
||||
}
|
||||
|
||||
// Application constants
|
||||
@@ -93,6 +102,8 @@ var defaultContextPaths = []string{
|
||||
".cursor/rules/",
|
||||
"CLAUDE.md",
|
||||
"CLAUDE.local.md",
|
||||
"CONTEXT.md",
|
||||
"CONTEXT.local.md",
|
||||
"opencode.md",
|
||||
"opencode.local.md",
|
||||
"OpenCode.md",
|
||||
@@ -107,7 +118,7 @@ var cfg *Config
|
||||
// Load initializes the configuration from environment variables and config files.
|
||||
// If debug is true, debug mode is enabled and log level is set to debug.
|
||||
// It returns an error if configuration loading fails.
|
||||
func Load(workingDir string, debug bool) (*Config, error) {
|
||||
func Load(workingDir string, debug bool, lvl *slog.LevelVar) (*Config, error) {
|
||||
if cfg != nil {
|
||||
return cfg, nil
|
||||
}
|
||||
@@ -121,7 +132,6 @@ func Load(workingDir string, debug bool) (*Config, error) {
|
||||
|
||||
configureViper()
|
||||
setDefaults(debug)
|
||||
setProviderDefaults()
|
||||
|
||||
// Read global config
|
||||
if err := readConfig(viper.ReadInConfig()); err != nil {
|
||||
@@ -131,45 +141,21 @@ func Load(workingDir string, debug bool) (*Config, error) {
|
||||
// Load and merge local config
|
||||
mergeLocalConfig(workingDir)
|
||||
|
||||
setProviderDefaults()
|
||||
|
||||
// Apply configuration to the struct
|
||||
if err := viper.Unmarshal(cfg); err != nil {
|
||||
return cfg, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
|
||||
applyDefaultValues()
|
||||
|
||||
defaultLevel := slog.LevelInfo
|
||||
if cfg.Debug {
|
||||
defaultLevel = slog.LevelDebug
|
||||
}
|
||||
if os.Getenv("OPENCODE_DEV_DEBUG") == "true" {
|
||||
loggingFile := fmt.Sprintf("%s/%s", cfg.Data.Directory, "debug.log")
|
||||
|
||||
// if file does not exist create it
|
||||
if _, err := os.Stat(loggingFile); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(cfg.Data.Directory, 0o755); err != nil {
|
||||
return cfg, fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
if _, err := os.Create(loggingFile); err != nil {
|
||||
return cfg, fmt.Errorf("failed to create log file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666)
|
||||
if err != nil {
|
||||
return cfg, fmt.Errorf("failed to open log file: %w", err)
|
||||
}
|
||||
// Configure logger
|
||||
logger := slog.New(slog.NewTextHandler(sloggingFileWriter, &slog.HandlerOptions{
|
||||
Level: defaultLevel,
|
||||
}))
|
||||
slog.SetDefault(logger)
|
||||
} else {
|
||||
// Configure logger
|
||||
logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{
|
||||
Level: defaultLevel,
|
||||
}))
|
||||
slog.SetDefault(logger)
|
||||
}
|
||||
lvl.Set(defaultLevel)
|
||||
slog.SetLogLoggerLevel(defaultLevel)
|
||||
|
||||
// Validate configuration
|
||||
if err := Validate(); err != nil {
|
||||
@@ -203,6 +189,7 @@ func configureViper() {
|
||||
func setDefaults(debug bool) {
|
||||
viper.SetDefault("data.directory", defaultDataDirectory)
|
||||
viper.SetDefault("contextPaths", defaultContextPaths)
|
||||
viper.SetDefault("tui.theme", "opencode")
|
||||
|
||||
if debug {
|
||||
viper.SetDefault("debug", true)
|
||||
@@ -213,7 +200,8 @@ func setDefaults(debug bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// setProviderDefaults configures LLM provider defaults based on environment variables.
|
||||
// setProviderDefaults configures LLM provider defaults based on provider provided by
|
||||
// environment variables and configuration file.
|
||||
func setProviderDefaults() {
|
||||
// Set all API keys we can find in the environment
|
||||
if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" {
|
||||
@@ -228,66 +216,85 @@ func setProviderDefaults() {
|
||||
if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" {
|
||||
viper.SetDefault("providers.groq.apiKey", apiKey)
|
||||
}
|
||||
if apiKey := os.Getenv("OPENROUTER_API_KEY"); apiKey != "" {
|
||||
viper.SetDefault("providers.openrouter.apiKey", apiKey)
|
||||
}
|
||||
if apiKey := os.Getenv("XAI_API_KEY"); apiKey != "" {
|
||||
viper.SetDefault("providers.xai.apiKey", apiKey)
|
||||
}
|
||||
if apiKey := os.Getenv("AZURE_OPENAI_ENDPOINT"); apiKey != "" {
|
||||
// api-key may be empty when using Entra ID credentials – that's okay
|
||||
viper.SetDefault("providers.azure.apiKey", os.Getenv("AZURE_OPENAI_API_KEY"))
|
||||
}
|
||||
|
||||
// Use this order to set the default models
|
||||
// 1. Anthropic
|
||||
// 2. OpenAI
|
||||
// 3. Google Gemini
|
||||
// 4. Groq
|
||||
// 5. AWS Bedrock
|
||||
// 5. OpenRouter
|
||||
// 6. AWS Bedrock
|
||||
// 7. Azure
|
||||
|
||||
// Anthropic configuration
|
||||
if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" {
|
||||
viper.SetDefault("agents.coder.model", models.Claude37Sonnet)
|
||||
if key := viper.GetString("providers.anthropic.apiKey"); strings.TrimSpace(key) != "" {
|
||||
viper.SetDefault("agents.primary.model", models.Claude37Sonnet)
|
||||
viper.SetDefault("agents.task.model", models.Claude37Sonnet)
|
||||
viper.SetDefault("agents.title.model", models.Claude37Sonnet)
|
||||
return
|
||||
}
|
||||
|
||||
// OpenAI configuration
|
||||
if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" {
|
||||
viper.SetDefault("agents.coder.model", models.GPT41)
|
||||
if key := viper.GetString("providers.openai.apiKey"); strings.TrimSpace(key) != "" {
|
||||
viper.SetDefault("agents.primary.model", models.GPT41)
|
||||
viper.SetDefault("agents.task.model", models.GPT41Mini)
|
||||
viper.SetDefault("agents.title.model", models.GPT41Mini)
|
||||
return
|
||||
}
|
||||
|
||||
// Google Gemini configuration
|
||||
if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" {
|
||||
viper.SetDefault("agents.coder.model", models.Gemini25)
|
||||
if key := viper.GetString("providers.gemini.apiKey"); strings.TrimSpace(key) != "" {
|
||||
viper.SetDefault("agents.primary.model", models.Gemini25)
|
||||
viper.SetDefault("agents.task.model", models.Gemini25Flash)
|
||||
viper.SetDefault("agents.title.model", models.Gemini25Flash)
|
||||
return
|
||||
}
|
||||
|
||||
// Groq configuration
|
||||
if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" {
|
||||
viper.SetDefault("agents.coder.model", models.QWENQwq)
|
||||
if key := viper.GetString("providers.groq.apiKey"); strings.TrimSpace(key) != "" {
|
||||
viper.SetDefault("agents.primary.model", models.QWENQwq)
|
||||
viper.SetDefault("agents.task.model", models.QWENQwq)
|
||||
viper.SetDefault("agents.title.model", models.QWENQwq)
|
||||
return
|
||||
}
|
||||
|
||||
// OpenRouter configuration
|
||||
if apiKey := os.Getenv("OPENROUTER_API_KEY"); apiKey != "" {
|
||||
viper.SetDefault("providers.openrouter.apiKey", apiKey)
|
||||
viper.SetDefault("agents.coder.model", models.OpenRouterClaude37Sonnet)
|
||||
if key := viper.GetString("providers.openrouter.apiKey"); strings.TrimSpace(key) != "" {
|
||||
viper.SetDefault("agents.primary.model", models.OpenRouterClaude37Sonnet)
|
||||
viper.SetDefault("agents.task.model", models.OpenRouterClaude37Sonnet)
|
||||
viper.SetDefault("agents.title.model", models.OpenRouterClaude35Haiku)
|
||||
return
|
||||
}
|
||||
|
||||
// XAI configuration
|
||||
if key := viper.GetString("providers.xai.apiKey"); strings.TrimSpace(key) != "" {
|
||||
viper.SetDefault("agents.primary.model", models.XAIGrok3Beta)
|
||||
viper.SetDefault("agents.task.model", models.XAIGrok3Beta)
|
||||
viper.SetDefault("agents.title.model", models.XAiGrok3MiniFastBeta)
|
||||
return
|
||||
}
|
||||
|
||||
// AWS Bedrock configuration
|
||||
if hasAWSCredentials() {
|
||||
viper.SetDefault("agents.coder.model", models.BedrockClaude37Sonnet)
|
||||
viper.SetDefault("agents.primary.model", models.BedrockClaude37Sonnet)
|
||||
viper.SetDefault("agents.task.model", models.BedrockClaude37Sonnet)
|
||||
viper.SetDefault("agents.title.model", models.BedrockClaude37Sonnet)
|
||||
return
|
||||
}
|
||||
|
||||
// Azure OpenAI configuration
|
||||
if os.Getenv("AZURE_OPENAI_ENDPOINT") != "" {
|
||||
// api-key may be empty when using Entra ID credentials – that's okay
|
||||
viper.SetDefault("providers.azure.apiKey", os.Getenv("AZURE_OPENAI_API_KEY"))
|
||||
viper.SetDefault("agents.coder.model", models.AzureGPT41)
|
||||
viper.SetDefault("agents.primary.model", models.AzureGPT41)
|
||||
viper.SetDefault("agents.task.model", models.AzureGPT41Mini)
|
||||
viper.SetDefault("agents.title.model", models.AzureGPT41Mini)
|
||||
return
|
||||
@@ -363,13 +370,13 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error {
|
||||
// Check if model exists
|
||||
model, modelExists := models.SupportedModels[agent.Model]
|
||||
if !modelExists {
|
||||
logging.Warn("unsupported model configured, reverting to default",
|
||||
slog.Warn("unsupported model configured, reverting to default",
|
||||
"agent", name,
|
||||
"configured_model", agent.Model)
|
||||
|
||||
// Set default model based on available providers
|
||||
if setDefaultModelForAgent(name) {
|
||||
logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
|
||||
slog.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
|
||||
} else {
|
||||
return fmt.Errorf("no valid provider available for agent %s", name)
|
||||
}
|
||||
@@ -384,14 +391,14 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error {
|
||||
// Provider not configured, check if we have environment variables
|
||||
apiKey := getProviderAPIKey(provider)
|
||||
if apiKey == "" {
|
||||
logging.Warn("provider not configured for model, reverting to default",
|
||||
slog.Warn("provider not configured for model, reverting to default",
|
||||
"agent", name,
|
||||
"model", agent.Model,
|
||||
"provider", provider)
|
||||
|
||||
// Set default model based on available providers
|
||||
if setDefaultModelForAgent(name) {
|
||||
logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
|
||||
slog.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
|
||||
} else {
|
||||
return fmt.Errorf("no valid provider available for agent %s", name)
|
||||
}
|
||||
@@ -400,18 +407,18 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error {
|
||||
cfg.Providers[provider] = Provider{
|
||||
APIKey: apiKey,
|
||||
}
|
||||
logging.Info("added provider from environment", "provider", provider)
|
||||
slog.Info("added provider from environment", "provider", provider)
|
||||
}
|
||||
} else if providerCfg.Disabled || providerCfg.APIKey == "" {
|
||||
// Provider is disabled or has no API key
|
||||
logging.Warn("provider is disabled or has no API key, reverting to default",
|
||||
slog.Warn("provider is disabled or has no API key, reverting to default",
|
||||
"agent", name,
|
||||
"model", agent.Model,
|
||||
"provider", provider)
|
||||
|
||||
// Set default model based on available providers
|
||||
if setDefaultModelForAgent(name) {
|
||||
logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
|
||||
slog.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
|
||||
} else {
|
||||
return fmt.Errorf("no valid provider available for agent %s", name)
|
||||
}
|
||||
@@ -419,7 +426,7 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error {
|
||||
|
||||
// Validate max tokens
|
||||
if agent.MaxTokens <= 0 {
|
||||
logging.Warn("invalid max tokens, setting to default",
|
||||
slog.Warn("invalid max tokens, setting to default",
|
||||
"agent", name,
|
||||
"model", agent.Model,
|
||||
"max_tokens", agent.MaxTokens)
|
||||
@@ -434,7 +441,7 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error {
|
||||
cfg.Agents[name] = updatedAgent
|
||||
} else if model.ContextWindow > 0 && agent.MaxTokens > model.ContextWindow/2 {
|
||||
// Ensure max tokens doesn't exceed half the context window (reasonable limit)
|
||||
logging.Warn("max tokens exceeds half the context window, adjusting",
|
||||
slog.Warn("max tokens exceeds half the context window, adjusting",
|
||||
"agent", name,
|
||||
"model", agent.Model,
|
||||
"max_tokens", agent.MaxTokens,
|
||||
@@ -450,7 +457,7 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error {
|
||||
if model.CanReason && provider == models.ProviderOpenAI {
|
||||
if agent.ReasoningEffort == "" {
|
||||
// Set default reasoning effort for models that support it
|
||||
logging.Info("setting default reasoning effort for model that supports reasoning",
|
||||
slog.Info("setting default reasoning effort for model that supports reasoning",
|
||||
"agent", name,
|
||||
"model", agent.Model)
|
||||
|
||||
@@ -462,7 +469,7 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error {
|
||||
// Check if reasoning effort is valid (low, medium, high)
|
||||
effort := strings.ToLower(agent.ReasoningEffort)
|
||||
if effort != "low" && effort != "medium" && effort != "high" {
|
||||
logging.Warn("invalid reasoning effort, setting to medium",
|
||||
slog.Warn("invalid reasoning effort, setting to medium",
|
||||
"agent", name,
|
||||
"model", agent.Model,
|
||||
"reasoning_effort", agent.ReasoningEffort)
|
||||
@@ -475,7 +482,7 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error {
|
||||
}
|
||||
} else if !model.CanReason && agent.ReasoningEffort != "" {
|
||||
// Model doesn't support reasoning but reasoning effort is set
|
||||
logging.Warn("model doesn't support reasoning but reasoning effort is set, ignoring",
|
||||
slog.Warn("model doesn't support reasoning but reasoning effort is set, ignoring",
|
||||
"agent", name,
|
||||
"model", agent.Model,
|
||||
"reasoning_effort", agent.ReasoningEffort)
|
||||
@@ -505,7 +512,7 @@ func Validate() error {
|
||||
// Validate providers
|
||||
for provider, providerCfg := range cfg.Providers {
|
||||
if providerCfg.APIKey == "" && !providerCfg.Disabled {
|
||||
logging.Warn("provider has no API key, marking as disabled", "provider", provider)
|
||||
slog.Warn("provider has no API key, marking as disabled", "provider", provider)
|
||||
providerCfg.Disabled = true
|
||||
cfg.Providers[provider] = providerCfg
|
||||
}
|
||||
@@ -514,7 +521,7 @@ func Validate() error {
|
||||
// Validate LSP configurations
|
||||
for language, lspConfig := range cfg.LSP {
|
||||
if lspConfig.Command == "" && !lspConfig.Disabled {
|
||||
logging.Warn("LSP configuration has no command, marking as disabled", "language", language)
|
||||
slog.Warn("LSP configuration has no command, marking as disabled", "language", language)
|
||||
lspConfig.Disabled = true
|
||||
cfg.LSP[language] = lspConfig
|
||||
}
|
||||
@@ -679,6 +686,24 @@ func WorkingDirectory() string {
|
||||
return cfg.WorkingDir
|
||||
}
|
||||
|
||||
// GetHostname returns the system hostname or "User" if it can't be determined
|
||||
func GetHostname() (string, error) {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
return "User", err
|
||||
}
|
||||
return hostname, nil
|
||||
}
|
||||
|
||||
// GetUsername returns the current user's username
|
||||
func GetUsername() (string, error) {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
return "User", err
|
||||
}
|
||||
return currentUser.Username, nil
|
||||
}
|
||||
|
||||
func UpdateAgentModel(agentName AgentName, modelID models.ModelID) error {
|
||||
if cfg == nil {
|
||||
panic("config not loaded")
|
||||
@@ -711,3 +736,62 @@ func UpdateAgentModel(agentName AgentName, modelID models.ModelID) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateTheme updates the theme in the configuration and writes it to the config file.
|
||||
func UpdateTheme(themeName string) error {
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("config not loaded")
|
||||
}
|
||||
|
||||
// Update the in-memory config
|
||||
cfg.TUI.Theme = themeName
|
||||
|
||||
// Get the config file path
|
||||
configFile := viper.ConfigFileUsed()
|
||||
var configData []byte
|
||||
if configFile == "" {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
configFile = filepath.Join(homeDir, fmt.Sprintf(".%s.json", appName))
|
||||
slog.Info("config file not found, creating new one", "path", configFile)
|
||||
configData = []byte(`{}`)
|
||||
} else {
|
||||
// Read the existing config file
|
||||
data, err := os.ReadFile(configFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
configData = data
|
||||
}
|
||||
|
||||
// Parse the JSON
|
||||
var configMap map[string]any
|
||||
if err := json.Unmarshal(configData, &configMap); err != nil {
|
||||
return fmt.Errorf("failed to parse config file: %w", err)
|
||||
}
|
||||
|
||||
// Update just the theme value
|
||||
tuiConfig, ok := configMap["tui"].(map[string]any)
|
||||
if !ok {
|
||||
// TUI config doesn't exist yet, create it
|
||||
configMap["tui"] = map[string]any{"theme": themeName}
|
||||
} else {
|
||||
// Update existing TUI config
|
||||
tuiConfig["theme"] = themeName
|
||||
configMap["tui"] = tuiConfig
|
||||
}
|
||||
|
||||
// Write the updated config back to file
|
||||
updatedData, err := json.MarshalIndent(configMap, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(configFile, updatedData, 0o644); err != nil {
|
||||
return fmt.Errorf("failed to write config file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -58,4 +58,3 @@ func MarkProjectInitialized() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/logging"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"log/slog"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
)
|
||||
@@ -47,21 +47,21 @@ func Connect() (*sql.DB, error) {
|
||||
|
||||
for _, pragma := range pragmas {
|
||||
if _, err = db.Exec(pragma); err != nil {
|
||||
logging.Error("Failed to set pragma", pragma, err)
|
||||
slog.Error("Failed to set pragma", pragma, err)
|
||||
} else {
|
||||
logging.Debug("Set pragma", "pragma", pragma)
|
||||
slog.Debug("Set pragma", "pragma", pragma)
|
||||
}
|
||||
}
|
||||
|
||||
goose.SetBaseFS(FS)
|
||||
|
||||
if err := goose.SetDialect("sqlite3"); err != nil {
|
||||
logging.Error("Failed to set dialect", "error", err)
|
||||
slog.Error("Failed to set dialect", "error", err)
|
||||
return nil, fmt.Errorf("failed to set dialect: %w", err)
|
||||
}
|
||||
|
||||
if err := goose.Up(db, "migrations"); err != nil {
|
||||
logging.Error("Failed to apply migrations", "error", err)
|
||||
slog.Error("Failed to apply migrations", "error", err)
|
||||
return nil, fmt.Errorf("failed to apply migrations: %w", err)
|
||||
}
|
||||
return db, nil
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.27.0
|
||||
// sqlc v1.29.0
|
||||
|
||||
package db
|
||||
|
||||
@@ -27,6 +27,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
|
||||
if q.createFileStmt, err = db.PrepareContext(ctx, createFile); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query CreateFile: %w", err)
|
||||
}
|
||||
if q.createLogStmt, err = db.PrepareContext(ctx, createLog); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query CreateLog: %w", err)
|
||||
}
|
||||
if q.createMessageStmt, err = db.PrepareContext(ctx, createMessage); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query CreateMessage: %w", err)
|
||||
}
|
||||
@@ -60,6 +63,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
|
||||
if q.getSessionByIDStmt, err = db.PrepareContext(ctx, getSessionByID); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query GetSessionByID: %w", err)
|
||||
}
|
||||
if q.listAllLogsStmt, err = db.PrepareContext(ctx, listAllLogs); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query ListAllLogs: %w", err)
|
||||
}
|
||||
if q.listFilesByPathStmt, err = db.PrepareContext(ctx, listFilesByPath); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query ListFilesByPath: %w", err)
|
||||
}
|
||||
@@ -69,9 +75,15 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
|
||||
if q.listLatestSessionFilesStmt, err = db.PrepareContext(ctx, listLatestSessionFiles); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query ListLatestSessionFiles: %w", err)
|
||||
}
|
||||
if q.listLogsBySessionStmt, err = db.PrepareContext(ctx, listLogsBySession); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query ListLogsBySession: %w", err)
|
||||
}
|
||||
if q.listMessagesBySessionStmt, err = db.PrepareContext(ctx, listMessagesBySession); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query ListMessagesBySession: %w", err)
|
||||
}
|
||||
if q.listMessagesBySessionAfterStmt, err = db.PrepareContext(ctx, listMessagesBySessionAfter); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query ListMessagesBySessionAfter: %w", err)
|
||||
}
|
||||
if q.listNewFilesStmt, err = db.PrepareContext(ctx, listNewFiles); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query ListNewFiles: %w", err)
|
||||
}
|
||||
@@ -97,6 +109,11 @@ func (q *Queries) Close() error {
|
||||
err = fmt.Errorf("error closing createFileStmt: %w", cerr)
|
||||
}
|
||||
}
|
||||
if q.createLogStmt != nil {
|
||||
if cerr := q.createLogStmt.Close(); cerr != nil {
|
||||
err = fmt.Errorf("error closing createLogStmt: %w", cerr)
|
||||
}
|
||||
}
|
||||
if q.createMessageStmt != nil {
|
||||
if cerr := q.createMessageStmt.Close(); cerr != nil {
|
||||
err = fmt.Errorf("error closing createMessageStmt: %w", cerr)
|
||||
@@ -152,6 +169,11 @@ func (q *Queries) Close() error {
|
||||
err = fmt.Errorf("error closing getSessionByIDStmt: %w", cerr)
|
||||
}
|
||||
}
|
||||
if q.listAllLogsStmt != nil {
|
||||
if cerr := q.listAllLogsStmt.Close(); cerr != nil {
|
||||
err = fmt.Errorf("error closing listAllLogsStmt: %w", cerr)
|
||||
}
|
||||
}
|
||||
if q.listFilesByPathStmt != nil {
|
||||
if cerr := q.listFilesByPathStmt.Close(); cerr != nil {
|
||||
err = fmt.Errorf("error closing listFilesByPathStmt: %w", cerr)
|
||||
@@ -167,11 +189,21 @@ func (q *Queries) Close() error {
|
||||
err = fmt.Errorf("error closing listLatestSessionFilesStmt: %w", cerr)
|
||||
}
|
||||
}
|
||||
if q.listLogsBySessionStmt != nil {
|
||||
if cerr := q.listLogsBySessionStmt.Close(); cerr != nil {
|
||||
err = fmt.Errorf("error closing listLogsBySessionStmt: %w", cerr)
|
||||
}
|
||||
}
|
||||
if q.listMessagesBySessionStmt != nil {
|
||||
if cerr := q.listMessagesBySessionStmt.Close(); cerr != nil {
|
||||
err = fmt.Errorf("error closing listMessagesBySessionStmt: %w", cerr)
|
||||
}
|
||||
}
|
||||
if q.listMessagesBySessionAfterStmt != nil {
|
||||
if cerr := q.listMessagesBySessionAfterStmt.Close(); cerr != nil {
|
||||
err = fmt.Errorf("error closing listMessagesBySessionAfterStmt: %w", cerr)
|
||||
}
|
||||
}
|
||||
if q.listNewFilesStmt != nil {
|
||||
if cerr := q.listNewFilesStmt.Close(); cerr != nil {
|
||||
err = fmt.Errorf("error closing listNewFilesStmt: %w", cerr)
|
||||
@@ -234,55 +266,63 @@ func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, ar
|
||||
}
|
||||
|
||||
type Queries struct {
|
||||
db DBTX
|
||||
tx *sql.Tx
|
||||
createFileStmt *sql.Stmt
|
||||
createMessageStmt *sql.Stmt
|
||||
createSessionStmt *sql.Stmt
|
||||
deleteFileStmt *sql.Stmt
|
||||
deleteMessageStmt *sql.Stmt
|
||||
deleteSessionStmt *sql.Stmt
|
||||
deleteSessionFilesStmt *sql.Stmt
|
||||
deleteSessionMessagesStmt *sql.Stmt
|
||||
getFileStmt *sql.Stmt
|
||||
getFileByPathAndSessionStmt *sql.Stmt
|
||||
getMessageStmt *sql.Stmt
|
||||
getSessionByIDStmt *sql.Stmt
|
||||
listFilesByPathStmt *sql.Stmt
|
||||
listFilesBySessionStmt *sql.Stmt
|
||||
listLatestSessionFilesStmt *sql.Stmt
|
||||
listMessagesBySessionStmt *sql.Stmt
|
||||
listNewFilesStmt *sql.Stmt
|
||||
listSessionsStmt *sql.Stmt
|
||||
updateFileStmt *sql.Stmt
|
||||
updateMessageStmt *sql.Stmt
|
||||
updateSessionStmt *sql.Stmt
|
||||
db DBTX
|
||||
tx *sql.Tx
|
||||
createFileStmt *sql.Stmt
|
||||
createLogStmt *sql.Stmt
|
||||
createMessageStmt *sql.Stmt
|
||||
createSessionStmt *sql.Stmt
|
||||
deleteFileStmt *sql.Stmt
|
||||
deleteMessageStmt *sql.Stmt
|
||||
deleteSessionStmt *sql.Stmt
|
||||
deleteSessionFilesStmt *sql.Stmt
|
||||
deleteSessionMessagesStmt *sql.Stmt
|
||||
getFileStmt *sql.Stmt
|
||||
getFileByPathAndSessionStmt *sql.Stmt
|
||||
getMessageStmt *sql.Stmt
|
||||
getSessionByIDStmt *sql.Stmt
|
||||
listAllLogsStmt *sql.Stmt
|
||||
listFilesByPathStmt *sql.Stmt
|
||||
listFilesBySessionStmt *sql.Stmt
|
||||
listLatestSessionFilesStmt *sql.Stmt
|
||||
listLogsBySessionStmt *sql.Stmt
|
||||
listMessagesBySessionStmt *sql.Stmt
|
||||
listMessagesBySessionAfterStmt *sql.Stmt
|
||||
listNewFilesStmt *sql.Stmt
|
||||
listSessionsStmt *sql.Stmt
|
||||
updateFileStmt *sql.Stmt
|
||||
updateMessageStmt *sql.Stmt
|
||||
updateSessionStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
|
||||
return &Queries{
|
||||
db: tx,
|
||||
tx: tx,
|
||||
createFileStmt: q.createFileStmt,
|
||||
createMessageStmt: q.createMessageStmt,
|
||||
createSessionStmt: q.createSessionStmt,
|
||||
deleteFileStmt: q.deleteFileStmt,
|
||||
deleteMessageStmt: q.deleteMessageStmt,
|
||||
deleteSessionStmt: q.deleteSessionStmt,
|
||||
deleteSessionFilesStmt: q.deleteSessionFilesStmt,
|
||||
deleteSessionMessagesStmt: q.deleteSessionMessagesStmt,
|
||||
getFileStmt: q.getFileStmt,
|
||||
getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt,
|
||||
getMessageStmt: q.getMessageStmt,
|
||||
getSessionByIDStmt: q.getSessionByIDStmt,
|
||||
listFilesByPathStmt: q.listFilesByPathStmt,
|
||||
listFilesBySessionStmt: q.listFilesBySessionStmt,
|
||||
listLatestSessionFilesStmt: q.listLatestSessionFilesStmt,
|
||||
listMessagesBySessionStmt: q.listMessagesBySessionStmt,
|
||||
listNewFilesStmt: q.listNewFilesStmt,
|
||||
listSessionsStmt: q.listSessionsStmt,
|
||||
updateFileStmt: q.updateFileStmt,
|
||||
updateMessageStmt: q.updateMessageStmt,
|
||||
updateSessionStmt: q.updateSessionStmt,
|
||||
db: tx,
|
||||
tx: tx,
|
||||
createFileStmt: q.createFileStmt,
|
||||
createLogStmt: q.createLogStmt,
|
||||
createMessageStmt: q.createMessageStmt,
|
||||
createSessionStmt: q.createSessionStmt,
|
||||
deleteFileStmt: q.deleteFileStmt,
|
||||
deleteMessageStmt: q.deleteMessageStmt,
|
||||
deleteSessionStmt: q.deleteSessionStmt,
|
||||
deleteSessionFilesStmt: q.deleteSessionFilesStmt,
|
||||
deleteSessionMessagesStmt: q.deleteSessionMessagesStmt,
|
||||
getFileStmt: q.getFileStmt,
|
||||
getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt,
|
||||
getMessageStmt: q.getMessageStmt,
|
||||
getSessionByIDStmt: q.getSessionByIDStmt,
|
||||
listAllLogsStmt: q.listAllLogsStmt,
|
||||
listFilesByPathStmt: q.listFilesByPathStmt,
|
||||
listFilesBySessionStmt: q.listFilesBySessionStmt,
|
||||
listLatestSessionFilesStmt: q.listLatestSessionFilesStmt,
|
||||
listLogsBySessionStmt: q.listLogsBySessionStmt,
|
||||
listMessagesBySessionStmt: q.listMessagesBySessionStmt,
|
||||
listMessagesBySessionAfterStmt: q.listMessagesBySessionAfterStmt,
|
||||
listNewFilesStmt: q.listNewFilesStmt,
|
||||
listSessionsStmt: q.listSessionsStmt,
|
||||
updateFileStmt: q.updateFileStmt,
|
||||
updateMessageStmt: q.updateMessageStmt,
|
||||
updateSessionStmt: q.updateSessionStmt,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.27.0
|
||||
// sqlc v1.29.0
|
||||
// source: files.sql
|
||||
|
||||
package db
|
||||
@@ -15,13 +15,11 @@ INSERT INTO files (
|
||||
session_id,
|
||||
path,
|
||||
content,
|
||||
version,
|
||||
created_at,
|
||||
updated_at
|
||||
version
|
||||
) VALUES (
|
||||
?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
|
||||
?, ?, ?, ?, ?
|
||||
)
|
||||
RETURNING id, session_id, path, content, version, created_at, updated_at
|
||||
RETURNING id, session_id, path, content, version, is_new, created_at, updated_at
|
||||
`
|
||||
|
||||
type CreateFileParams struct {
|
||||
@@ -47,6 +45,7 @@ func (q *Queries) CreateFile(ctx context.Context, arg CreateFileParams) (File, e
|
||||
&i.Path,
|
||||
&i.Content,
|
||||
&i.Version,
|
||||
&i.IsNew,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
@@ -74,7 +73,7 @@ func (q *Queries) DeleteSessionFiles(ctx context.Context, sessionID string) erro
|
||||
}
|
||||
|
||||
const getFile = `-- name: GetFile :one
|
||||
SELECT id, session_id, path, content, version, created_at, updated_at
|
||||
SELECT id, session_id, path, content, version, is_new, created_at, updated_at
|
||||
FROM files
|
||||
WHERE id = ? LIMIT 1
|
||||
`
|
||||
@@ -88,6 +87,7 @@ func (q *Queries) GetFile(ctx context.Context, id string) (File, error) {
|
||||
&i.Path,
|
||||
&i.Content,
|
||||
&i.Version,
|
||||
&i.IsNew,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
@@ -95,7 +95,7 @@ func (q *Queries) GetFile(ctx context.Context, id string) (File, error) {
|
||||
}
|
||||
|
||||
const getFileByPathAndSession = `-- name: GetFileByPathAndSession :one
|
||||
SELECT id, session_id, path, content, version, created_at, updated_at
|
||||
SELECT id, session_id, path, content, version, is_new, created_at, updated_at
|
||||
FROM files
|
||||
WHERE path = ? AND session_id = ?
|
||||
ORDER BY created_at DESC
|
||||
@@ -116,6 +116,7 @@ func (q *Queries) GetFileByPathAndSession(ctx context.Context, arg GetFileByPath
|
||||
&i.Path,
|
||||
&i.Content,
|
||||
&i.Version,
|
||||
&i.IsNew,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
@@ -123,7 +124,7 @@ func (q *Queries) GetFileByPathAndSession(ctx context.Context, arg GetFileByPath
|
||||
}
|
||||
|
||||
const listFilesByPath = `-- name: ListFilesByPath :many
|
||||
SELECT id, session_id, path, content, version, created_at, updated_at
|
||||
SELECT id, session_id, path, content, version, is_new, created_at, updated_at
|
||||
FROM files
|
||||
WHERE path = ?
|
||||
ORDER BY created_at DESC
|
||||
@@ -144,6 +145,7 @@ func (q *Queries) ListFilesByPath(ctx context.Context, path string) ([]File, err
|
||||
&i.Path,
|
||||
&i.Content,
|
||||
&i.Version,
|
||||
&i.IsNew,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
@@ -161,7 +163,7 @@ func (q *Queries) ListFilesByPath(ctx context.Context, path string) ([]File, err
|
||||
}
|
||||
|
||||
const listFilesBySession = `-- name: ListFilesBySession :many
|
||||
SELECT id, session_id, path, content, version, created_at, updated_at
|
||||
SELECT id, session_id, path, content, version, is_new, created_at, updated_at
|
||||
FROM files
|
||||
WHERE session_id = ?
|
||||
ORDER BY created_at ASC
|
||||
@@ -182,6 +184,7 @@ func (q *Queries) ListFilesBySession(ctx context.Context, sessionID string) ([]F
|
||||
&i.Path,
|
||||
&i.Content,
|
||||
&i.Version,
|
||||
&i.IsNew,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
@@ -199,7 +202,7 @@ func (q *Queries) ListFilesBySession(ctx context.Context, sessionID string) ([]F
|
||||
}
|
||||
|
||||
const listLatestSessionFiles = `-- name: ListLatestSessionFiles :many
|
||||
SELECT f.id, f.session_id, f.path, f.content, f.version, f.created_at, f.updated_at
|
||||
SELECT f.id, f.session_id, f.path, f.content, f.version, f.is_new, f.created_at, f.updated_at
|
||||
FROM files f
|
||||
INNER JOIN (
|
||||
SELECT path, MAX(created_at) as max_created_at
|
||||
@@ -225,6 +228,7 @@ func (q *Queries) ListLatestSessionFiles(ctx context.Context, sessionID string)
|
||||
&i.Path,
|
||||
&i.Content,
|
||||
&i.Version,
|
||||
&i.IsNew,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
@@ -242,7 +246,7 @@ func (q *Queries) ListLatestSessionFiles(ctx context.Context, sessionID string)
|
||||
}
|
||||
|
||||
const listNewFiles = `-- name: ListNewFiles :many
|
||||
SELECT id, session_id, path, content, version, created_at, updated_at
|
||||
SELECT id, session_id, path, content, version, is_new, created_at, updated_at
|
||||
FROM files
|
||||
WHERE is_new = 1
|
||||
ORDER BY created_at DESC
|
||||
@@ -263,6 +267,7 @@ func (q *Queries) ListNewFiles(ctx context.Context) ([]File, error) {
|
||||
&i.Path,
|
||||
&i.Content,
|
||||
&i.Version,
|
||||
&i.IsNew,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
@@ -284,9 +289,9 @@ UPDATE files
|
||||
SET
|
||||
content = ?,
|
||||
version = ?,
|
||||
updated_at = strftime('%s', 'now')
|
||||
updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
|
||||
WHERE id = ?
|
||||
RETURNING id, session_id, path, content, version, created_at, updated_at
|
||||
RETURNING id, session_id, path, content, version, is_new, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateFileParams struct {
|
||||
@@ -304,6 +309,7 @@ func (q *Queries) UpdateFile(ctx context.Context, arg UpdateFileParams) (File, e
|
||||
&i.Path,
|
||||
&i.Content,
|
||||
&i.Version,
|
||||
&i.IsNew,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
|
||||
137
internal/db/logs.sql.go
Normal file
137
internal/db/logs.sql.go
Normal file
@@ -0,0 +1,137 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.29.0
|
||||
// source: logs.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
const createLog = `-- name: CreateLog :one
|
||||
INSERT INTO logs (
|
||||
id,
|
||||
session_id,
|
||||
timestamp,
|
||||
level,
|
||||
message,
|
||||
attributes
|
||||
) VALUES (
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?
|
||||
) RETURNING id, session_id, timestamp, level, message, attributes, created_at, updated_at
|
||||
`
|
||||
|
||||
type CreateLogParams struct {
|
||||
ID string `json:"id"`
|
||||
SessionID sql.NullString `json:"session_id"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
Level string `json:"level"`
|
||||
Message string `json:"message"`
|
||||
Attributes sql.NullString `json:"attributes"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateLog(ctx context.Context, arg CreateLogParams) (Log, error) {
|
||||
row := q.queryRow(ctx, q.createLogStmt, createLog,
|
||||
arg.ID,
|
||||
arg.SessionID,
|
||||
arg.Timestamp,
|
||||
arg.Level,
|
||||
arg.Message,
|
||||
arg.Attributes,
|
||||
)
|
||||
var i Log
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.Timestamp,
|
||||
&i.Level,
|
||||
&i.Message,
|
||||
&i.Attributes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const listAllLogs = `-- name: ListAllLogs :many
|
||||
SELECT id, session_id, timestamp, level, message, attributes, created_at, updated_at FROM logs
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
func (q *Queries) ListAllLogs(ctx context.Context, limit int64) ([]Log, error) {
|
||||
rows, err := q.query(ctx, q.listAllLogsStmt, listAllLogs, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []Log{}
|
||||
for rows.Next() {
|
||||
var i Log
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.Timestamp,
|
||||
&i.Level,
|
||||
&i.Message,
|
||||
&i.Attributes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listLogsBySession = `-- name: ListLogsBySession :many
|
||||
SELECT id, session_id, timestamp, level, message, attributes, created_at, updated_at FROM logs
|
||||
WHERE session_id = ?
|
||||
ORDER BY timestamp DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListLogsBySession(ctx context.Context, sessionID sql.NullString) ([]Log, error) {
|
||||
rows, err := q.query(ctx, q.listLogsBySessionStmt, listLogsBySession, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []Log{}
|
||||
for rows.Next() {
|
||||
var i Log
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.Timestamp,
|
||||
&i.Level,
|
||||
&i.Message,
|
||||
&i.Attributes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.27.0
|
||||
// sqlc v1.29.0
|
||||
// source: messages.sql
|
||||
|
||||
package db
|
||||
@@ -16,11 +16,9 @@ INSERT INTO messages (
|
||||
session_id,
|
||||
role,
|
||||
parts,
|
||||
model,
|
||||
created_at,
|
||||
updated_at
|
||||
model
|
||||
) VALUES (
|
||||
?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
|
||||
?, ?, ?, ?, ?
|
||||
)
|
||||
RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at
|
||||
`
|
||||
@@ -136,19 +134,63 @@ func (q *Queries) ListMessagesBySession(ctx context.Context, sessionID string) (
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listMessagesBySessionAfter = `-- name: ListMessagesBySessionAfter :many
|
||||
SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at
|
||||
FROM messages
|
||||
WHERE session_id = ? AND created_at > ?
|
||||
ORDER BY created_at ASC
|
||||
`
|
||||
|
||||
type ListMessagesBySessionAfterParams struct {
|
||||
SessionID string `json:"session_id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
func (q *Queries) ListMessagesBySessionAfter(ctx context.Context, arg ListMessagesBySessionAfterParams) ([]Message, error) {
|
||||
rows, err := q.query(ctx, q.listMessagesBySessionAfterStmt, listMessagesBySessionAfter, arg.SessionID, arg.CreatedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []Message{}
|
||||
for rows.Next() {
|
||||
var i Message
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.Role,
|
||||
&i.Parts,
|
||||
&i.Model,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.FinishedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateMessage = `-- name: UpdateMessage :exec
|
||||
UPDATE messages
|
||||
SET
|
||||
parts = ?,
|
||||
finished_at = ?,
|
||||
updated_at = strftime('%s', 'now')
|
||||
updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
type UpdateMessageParams struct {
|
||||
Parts string `json:"parts"`
|
||||
FinishedAt sql.NullInt64 `json:"finished_at"`
|
||||
ID string `json:"id"`
|
||||
Parts string `json:"parts"`
|
||||
FinishedAt sql.NullString `json:"finished_at"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateMessage(ctx context.Context, arg UpdateMessageParams) error {
|
||||
|
||||
@@ -6,17 +6,19 @@ CREATE TABLE IF NOT EXISTS sessions (
|
||||
parent_session_id TEXT,
|
||||
title TEXT NOT NULL,
|
||||
message_count INTEGER NOT NULL DEFAULT 0 CHECK (message_count >= 0),
|
||||
prompt_tokens INTEGER NOT NULL DEFAULT 0 CHECK (prompt_tokens >= 0),
|
||||
completion_tokens INTEGER NOT NULL DEFAULT 0 CHECK (completion_tokens>= 0),
|
||||
prompt_tokens INTEGER NOT NULL DEFAULT 0 CHECK (prompt_tokens >= 0),
|
||||
completion_tokens INTEGER NOT NULL DEFAULT 0 CHECK (completion_tokens >= 0),
|
||||
cost REAL NOT NULL DEFAULT 0.0 CHECK (cost >= 0.0),
|
||||
updated_at INTEGER NOT NULL, -- Unix timestamp in milliseconds
|
||||
created_at INTEGER NOT NULL -- Unix timestamp in milliseconds
|
||||
summary TEXT,
|
||||
summarized_at TEXT,
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')),
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now'))
|
||||
);
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS update_sessions_updated_at
|
||||
AFTER UPDATE ON sessions
|
||||
BEGIN
|
||||
UPDATE sessions SET updated_at = strftime('%s', 'now')
|
||||
UPDATE sessions SET updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
|
||||
WHERE id = new.id;
|
||||
END;
|
||||
|
||||
@@ -27,8 +29,9 @@ CREATE TABLE IF NOT EXISTS files (
|
||||
path TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
version TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL, -- Unix timestamp in milliseconds
|
||||
updated_at INTEGER NOT NULL, -- Unix timestamp in milliseconds
|
||||
is_new INTEGER DEFAULT 0,
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')),
|
||||
FOREIGN KEY (session_id) REFERENCES sessions (id) ON DELETE CASCADE,
|
||||
UNIQUE(path, session_id, version)
|
||||
);
|
||||
@@ -39,7 +42,7 @@ CREATE INDEX IF NOT EXISTS idx_files_path ON files (path);
|
||||
CREATE TRIGGER IF NOT EXISTS update_files_updated_at
|
||||
AFTER UPDATE ON files
|
||||
BEGIN
|
||||
UPDATE files SET updated_at = strftime('%s', 'now')
|
||||
UPDATE files SET updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
|
||||
WHERE id = new.id;
|
||||
END;
|
||||
|
||||
@@ -50,9 +53,9 @@ CREATE TABLE IF NOT EXISTS messages (
|
||||
role TEXT NOT NULL,
|
||||
parts TEXT NOT NULL default '[]',
|
||||
model TEXT,
|
||||
created_at INTEGER NOT NULL, -- Unix timestamp in milliseconds
|
||||
updated_at INTEGER NOT NULL, -- Unix timestamp in milliseconds
|
||||
finished_at INTEGER, -- Unix timestamp in milliseconds
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')),
|
||||
finished_at TEXT,
|
||||
FOREIGN KEY (session_id) REFERENCES sessions (id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
@@ -61,7 +64,7 @@ CREATE INDEX IF NOT EXISTS idx_messages_session_id ON messages (session_id);
|
||||
CREATE TRIGGER IF NOT EXISTS update_messages_updated_at
|
||||
AFTER UPDATE ON messages
|
||||
BEGIN
|
||||
UPDATE messages SET updated_at = strftime('%s', 'now')
|
||||
UPDATE messages SET updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
|
||||
WHERE id = new.id;
|
||||
END;
|
||||
|
||||
@@ -81,6 +84,28 @@ UPDATE sessions SET
|
||||
WHERE id = old.session_id;
|
||||
END;
|
||||
|
||||
-- Logs
|
||||
CREATE TABLE IF NOT EXISTS logs (
|
||||
id TEXT PRIMARY KEY,
|
||||
session_id TEXT REFERENCES sessions(id) ON DELETE CASCADE,
|
||||
timestamp TEXT NOT NULL,
|
||||
level TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
attributes TEXT,
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now'))
|
||||
);
|
||||
|
||||
CREATE INDEX logs_session_id_idx ON logs(session_id);
|
||||
CREATE INDEX logs_timestamp_idx ON logs(timestamp);
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS update_logs_updated_at
|
||||
AFTER UPDATE ON logs
|
||||
BEGIN
|
||||
UPDATE logs SET updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
|
||||
WHERE id = new.id;
|
||||
END;
|
||||
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
@@ -88,11 +113,13 @@ END;
|
||||
DROP TRIGGER IF EXISTS update_sessions_updated_at;
|
||||
DROP TRIGGER IF EXISTS update_messages_updated_at;
|
||||
DROP TRIGGER IF EXISTS update_files_updated_at;
|
||||
DROP TRIGGER IF EXISTS update_logs_updated_at;
|
||||
|
||||
DROP TRIGGER IF EXISTS update_session_message_count_on_delete;
|
||||
DROP TRIGGER IF EXISTS update_session_message_count_on_insert;
|
||||
|
||||
DROP TABLE IF EXISTS sessions;
|
||||
DROP TABLE IF EXISTS logs;
|
||||
DROP TABLE IF EXISTS messages;
|
||||
DROP TABLE IF EXISTS files;
|
||||
DROP TABLE IF EXISTS sessions;
|
||||
-- +goose StatementEnd
|
||||
@@ -1,6 +1,6 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.27.0
|
||||
// sqlc v1.29.0
|
||||
|
||||
package db
|
||||
|
||||
@@ -9,13 +9,25 @@ import (
|
||||
)
|
||||
|
||||
type File struct {
|
||||
ID string `json:"id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Path string `json:"path"`
|
||||
Content string `json:"content"`
|
||||
Version string `json:"version"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
ID string `json:"id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Path string `json:"path"`
|
||||
Content string `json:"content"`
|
||||
Version string `json:"version"`
|
||||
IsNew sql.NullInt64 `json:"is_new"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
type Log struct {
|
||||
ID string `json:"id"`
|
||||
SessionID sql.NullString `json:"session_id"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
Level string `json:"level"`
|
||||
Message string `json:"message"`
|
||||
Attributes sql.NullString `json:"attributes"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
@@ -24,9 +36,9 @@ type Message struct {
|
||||
Role string `json:"role"`
|
||||
Parts string `json:"parts"`
|
||||
Model sql.NullString `json:"model"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
FinishedAt sql.NullInt64 `json:"finished_at"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
FinishedAt sql.NullString `json:"finished_at"`
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
@@ -37,6 +49,8 @@ type Session struct {
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Summary sql.NullString `json:"summary"`
|
||||
SummarizedAt sql.NullString `json:"summarized_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.27.0
|
||||
// sqlc v1.29.0
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
type Querier interface {
|
||||
CreateFile(ctx context.Context, arg CreateFileParams) (File, error)
|
||||
CreateLog(ctx context.Context, arg CreateLogParams) (Log, error)
|
||||
CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error)
|
||||
CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error)
|
||||
DeleteFile(ctx context.Context, id string) error
|
||||
@@ -21,10 +23,13 @@ type Querier interface {
|
||||
GetFileByPathAndSession(ctx context.Context, arg GetFileByPathAndSessionParams) (File, error)
|
||||
GetMessage(ctx context.Context, id string) (Message, error)
|
||||
GetSessionByID(ctx context.Context, id string) (Session, error)
|
||||
ListAllLogs(ctx context.Context, limit int64) ([]Log, error)
|
||||
ListFilesByPath(ctx context.Context, path string) ([]File, error)
|
||||
ListFilesBySession(ctx context.Context, sessionID string) ([]File, error)
|
||||
ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error)
|
||||
ListLogsBySession(ctx context.Context, sessionID sql.NullString) ([]Log, error)
|
||||
ListMessagesBySession(ctx context.Context, sessionID string) ([]Message, error)
|
||||
ListMessagesBySessionAfter(ctx context.Context, arg ListMessagesBySessionAfterParams) ([]Message, error)
|
||||
ListNewFiles(ctx context.Context) ([]File, error)
|
||||
ListSessions(ctx context.Context) ([]Session, error)
|
||||
UpdateFile(ctx context.Context, arg UpdateFileParams) (File, error)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.27.0
|
||||
// sqlc v1.29.0
|
||||
// source: sessions.sql
|
||||
|
||||
package db
|
||||
@@ -19,8 +19,8 @@ INSERT INTO sessions (
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
cost,
|
||||
updated_at,
|
||||
created_at
|
||||
summary,
|
||||
summarized_at
|
||||
) VALUES (
|
||||
?,
|
||||
?,
|
||||
@@ -29,9 +29,9 @@ INSERT INTO sessions (
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
strftime('%s', 'now'),
|
||||
strftime('%s', 'now')
|
||||
) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
|
||||
?,
|
||||
?
|
||||
) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, summary, summarized_at, updated_at, created_at
|
||||
`
|
||||
|
||||
type CreateSessionParams struct {
|
||||
@@ -42,6 +42,8 @@ type CreateSessionParams struct {
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
Summary sql.NullString `json:"summary"`
|
||||
SummarizedAt sql.NullString `json:"summarized_at"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) {
|
||||
@@ -53,6 +55,8 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
|
||||
arg.PromptTokens,
|
||||
arg.CompletionTokens,
|
||||
arg.Cost,
|
||||
arg.Summary,
|
||||
arg.SummarizedAt,
|
||||
)
|
||||
var i Session
|
||||
err := row.Scan(
|
||||
@@ -63,6 +67,8 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
|
||||
&i.PromptTokens,
|
||||
&i.CompletionTokens,
|
||||
&i.Cost,
|
||||
&i.Summary,
|
||||
&i.SummarizedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CreatedAt,
|
||||
)
|
||||
@@ -80,7 +86,7 @@ func (q *Queries) DeleteSession(ctx context.Context, id string) error {
|
||||
}
|
||||
|
||||
const getSessionByID = `-- name: GetSessionByID :one
|
||||
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
|
||||
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, summary, summarized_at, updated_at, created_at
|
||||
FROM sessions
|
||||
WHERE id = ? LIMIT 1
|
||||
`
|
||||
@@ -96,6 +102,8 @@ func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error
|
||||
&i.PromptTokens,
|
||||
&i.CompletionTokens,
|
||||
&i.Cost,
|
||||
&i.Summary,
|
||||
&i.SummarizedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CreatedAt,
|
||||
)
|
||||
@@ -103,7 +111,7 @@ func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error
|
||||
}
|
||||
|
||||
const listSessions = `-- name: ListSessions :many
|
||||
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
|
||||
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, summary, summarized_at, updated_at, created_at
|
||||
FROM sessions
|
||||
WHERE parent_session_id is NULL
|
||||
ORDER BY created_at DESC
|
||||
@@ -126,6 +134,8 @@ func (q *Queries) ListSessions(ctx context.Context) ([]Session, error) {
|
||||
&i.PromptTokens,
|
||||
&i.CompletionTokens,
|
||||
&i.Cost,
|
||||
&i.Summary,
|
||||
&i.SummarizedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CreatedAt,
|
||||
); err != nil {
|
||||
@@ -148,17 +158,21 @@ SET
|
||||
title = ?,
|
||||
prompt_tokens = ?,
|
||||
completion_tokens = ?,
|
||||
cost = ?
|
||||
cost = ?,
|
||||
summary = ?,
|
||||
summarized_at = ?
|
||||
WHERE id = ?
|
||||
RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
|
||||
RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, summary, summarized_at, updated_at, created_at
|
||||
`
|
||||
|
||||
type UpdateSessionParams struct {
|
||||
Title string `json:"title"`
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
Summary sql.NullString `json:"summary"`
|
||||
SummarizedAt sql.NullString `json:"summarized_at"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) {
|
||||
@@ -167,6 +181,8 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
|
||||
arg.PromptTokens,
|
||||
arg.CompletionTokens,
|
||||
arg.Cost,
|
||||
arg.Summary,
|
||||
arg.SummarizedAt,
|
||||
arg.ID,
|
||||
)
|
||||
var i Session
|
||||
@@ -178,6 +194,8 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
|
||||
&i.PromptTokens,
|
||||
&i.CompletionTokens,
|
||||
&i.Cost,
|
||||
&i.Summary,
|
||||
&i.SummarizedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CreatedAt,
|
||||
)
|
||||
|
||||
@@ -28,11 +28,9 @@ INSERT INTO files (
|
||||
session_id,
|
||||
path,
|
||||
content,
|
||||
version,
|
||||
created_at,
|
||||
updated_at
|
||||
version
|
||||
) VALUES (
|
||||
?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
|
||||
?, ?, ?, ?, ?
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
@@ -41,7 +39,7 @@ UPDATE files
|
||||
SET
|
||||
content = ?,
|
||||
version = ?,
|
||||
updated_at = strftime('%s', 'now')
|
||||
updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
|
||||
WHERE id = ?
|
||||
RETURNING *;
|
||||
|
||||
|
||||
26
internal/db/sql/logs.sql
Normal file
26
internal/db/sql/logs.sql
Normal file
@@ -0,0 +1,26 @@
|
||||
-- name: CreateLog :one
|
||||
INSERT INTO logs (
|
||||
id,
|
||||
session_id,
|
||||
timestamp,
|
||||
level,
|
||||
message,
|
||||
attributes
|
||||
) VALUES (
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?
|
||||
) RETURNING *;
|
||||
|
||||
-- name: ListLogsBySession :many
|
||||
SELECT * FROM logs
|
||||
WHERE session_id = ?
|
||||
ORDER BY timestamp DESC;
|
||||
|
||||
-- name: ListAllLogs :many
|
||||
SELECT * FROM logs
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?;
|
||||
@@ -9,17 +9,21 @@ FROM messages
|
||||
WHERE session_id = ?
|
||||
ORDER BY created_at ASC;
|
||||
|
||||
-- name: ListMessagesBySessionAfter :many
|
||||
SELECT *
|
||||
FROM messages
|
||||
WHERE session_id = ? AND created_at > ?
|
||||
ORDER BY created_at ASC;
|
||||
|
||||
-- name: CreateMessage :one
|
||||
INSERT INTO messages (
|
||||
id,
|
||||
session_id,
|
||||
role,
|
||||
parts,
|
||||
model,
|
||||
created_at,
|
||||
updated_at
|
||||
model
|
||||
) VALUES (
|
||||
?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
|
||||
?, ?, ?, ?, ?
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
@@ -28,7 +32,7 @@ UPDATE messages
|
||||
SET
|
||||
parts = ?,
|
||||
finished_at = ?,
|
||||
updated_at = strftime('%s', 'now')
|
||||
updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
|
||||
WHERE id = ?;
|
||||
|
||||
|
||||
|
||||
@@ -7,8 +7,8 @@ INSERT INTO sessions (
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
cost,
|
||||
updated_at,
|
||||
created_at
|
||||
summary,
|
||||
summarized_at
|
||||
) VALUES (
|
||||
?,
|
||||
?,
|
||||
@@ -17,8 +17,8 @@ INSERT INTO sessions (
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
strftime('%s', 'now'),
|
||||
strftime('%s', 'now')
|
||||
?,
|
||||
?
|
||||
) RETURNING *;
|
||||
|
||||
-- name: GetSessionByID :one
|
||||
@@ -38,7 +38,9 @@ SET
|
||||
title = ?,
|
||||
prompt_tokens = ?,
|
||||
completion_tokens = ?,
|
||||
cost = ?
|
||||
cost = ?,
|
||||
summary = ?,
|
||||
summarized_at = ?
|
||||
WHERE id = ?
|
||||
RETURNING *;
|
||||
|
||||
|
||||
@@ -15,8 +15,9 @@ import (
|
||||
"github.com/aymanbagabas/go-udiff"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/charmbracelet/x/ansi"
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/sergi/go-diff/diffmatchpatch"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/tui/theme"
|
||||
)
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
@@ -68,143 +69,6 @@ type linePair struct {
|
||||
right *DiffLine
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Style Configuration
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// StyleConfig defines styling for diff rendering
|
||||
type StyleConfig struct {
|
||||
ShowHeader bool
|
||||
ShowHunkHeader bool
|
||||
FileNameFg lipgloss.Color
|
||||
// Background colors
|
||||
RemovedLineBg lipgloss.Color
|
||||
AddedLineBg lipgloss.Color
|
||||
ContextLineBg lipgloss.Color
|
||||
HunkLineBg lipgloss.Color
|
||||
RemovedLineNumberBg lipgloss.Color
|
||||
AddedLineNamerBg lipgloss.Color
|
||||
|
||||
// Foreground colors
|
||||
HunkLineFg lipgloss.Color
|
||||
RemovedFg lipgloss.Color
|
||||
AddedFg lipgloss.Color
|
||||
LineNumberFg lipgloss.Color
|
||||
RemovedHighlightFg lipgloss.Color
|
||||
AddedHighlightFg lipgloss.Color
|
||||
|
||||
// Highlight settings
|
||||
HighlightStyle string
|
||||
RemovedHighlightBg lipgloss.Color
|
||||
AddedHighlightBg lipgloss.Color
|
||||
}
|
||||
|
||||
// StyleOption is a function that modifies a StyleConfig
|
||||
type StyleOption func(*StyleConfig)
|
||||
|
||||
// NewStyleConfig creates a StyleConfig with default values
|
||||
func NewStyleConfig(opts ...StyleOption) StyleConfig {
|
||||
// Default color scheme
|
||||
config := StyleConfig{
|
||||
ShowHeader: true,
|
||||
ShowHunkHeader: true,
|
||||
FileNameFg: lipgloss.Color("#a0a0a0"),
|
||||
RemovedLineBg: lipgloss.Color("#3A3030"),
|
||||
AddedLineBg: lipgloss.Color("#303A30"),
|
||||
ContextLineBg: lipgloss.Color("#212121"),
|
||||
HunkLineBg: lipgloss.Color("#212121"),
|
||||
HunkLineFg: lipgloss.Color("#a0a0a0"),
|
||||
RemovedFg: lipgloss.Color("#7C4444"),
|
||||
AddedFg: lipgloss.Color("#478247"),
|
||||
LineNumberFg: lipgloss.Color("#888888"),
|
||||
HighlightStyle: "dracula",
|
||||
RemovedHighlightBg: lipgloss.Color("#612726"),
|
||||
AddedHighlightBg: lipgloss.Color("#256125"),
|
||||
RemovedLineNumberBg: lipgloss.Color("#332929"),
|
||||
AddedLineNamerBg: lipgloss.Color("#293229"),
|
||||
RemovedHighlightFg: lipgloss.Color("#FADADD"),
|
||||
AddedHighlightFg: lipgloss.Color("#DAFADA"),
|
||||
}
|
||||
|
||||
// Apply all provided options
|
||||
for _, opt := range opts {
|
||||
opt(&config)
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// Style option functions
|
||||
func WithFileNameFg(color lipgloss.Color) StyleOption {
|
||||
return func(s *StyleConfig) { s.FileNameFg = color }
|
||||
}
|
||||
|
||||
func WithRemovedLineBg(color lipgloss.Color) StyleOption {
|
||||
return func(s *StyleConfig) { s.RemovedLineBg = color }
|
||||
}
|
||||
|
||||
func WithAddedLineBg(color lipgloss.Color) StyleOption {
|
||||
return func(s *StyleConfig) { s.AddedLineBg = color }
|
||||
}
|
||||
|
||||
func WithContextLineBg(color lipgloss.Color) StyleOption {
|
||||
return func(s *StyleConfig) { s.ContextLineBg = color }
|
||||
}
|
||||
|
||||
func WithRemovedFg(color lipgloss.Color) StyleOption {
|
||||
return func(s *StyleConfig) { s.RemovedFg = color }
|
||||
}
|
||||
|
||||
func WithAddedFg(color lipgloss.Color) StyleOption {
|
||||
return func(s *StyleConfig) { s.AddedFg = color }
|
||||
}
|
||||
|
||||
func WithLineNumberFg(color lipgloss.Color) StyleOption {
|
||||
return func(s *StyleConfig) { s.LineNumberFg = color }
|
||||
}
|
||||
|
||||
func WithHighlightStyle(style string) StyleOption {
|
||||
return func(s *StyleConfig) { s.HighlightStyle = style }
|
||||
}
|
||||
|
||||
func WithRemovedHighlightColors(bg, fg lipgloss.Color) StyleOption {
|
||||
return func(s *StyleConfig) {
|
||||
s.RemovedHighlightBg = bg
|
||||
s.RemovedHighlightFg = fg
|
||||
}
|
||||
}
|
||||
|
||||
func WithAddedHighlightColors(bg, fg lipgloss.Color) StyleOption {
|
||||
return func(s *StyleConfig) {
|
||||
s.AddedHighlightBg = bg
|
||||
s.AddedHighlightFg = fg
|
||||
}
|
||||
}
|
||||
|
||||
func WithRemovedLineNumberBg(color lipgloss.Color) StyleOption {
|
||||
return func(s *StyleConfig) { s.RemovedLineNumberBg = color }
|
||||
}
|
||||
|
||||
func WithAddedLineNumberBg(color lipgloss.Color) StyleOption {
|
||||
return func(s *StyleConfig) { s.AddedLineNamerBg = color }
|
||||
}
|
||||
|
||||
func WithHunkLineBg(color lipgloss.Color) StyleOption {
|
||||
return func(s *StyleConfig) { s.HunkLineBg = color }
|
||||
}
|
||||
|
||||
func WithHunkLineFg(color lipgloss.Color) StyleOption {
|
||||
return func(s *StyleConfig) { s.HunkLineFg = color }
|
||||
}
|
||||
|
||||
func WithShowHeader(show bool) StyleOption {
|
||||
return func(s *StyleConfig) { s.ShowHeader = show }
|
||||
}
|
||||
|
||||
func WithShowHunkHeader(show bool) StyleOption {
|
||||
return func(s *StyleConfig) { s.ShowHunkHeader = show }
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Parse Configuration
|
||||
// -------------------------------------------------------------------------
|
||||
@@ -233,7 +97,6 @@ func WithContextSize(size int) ParseOption {
|
||||
// SideBySideConfig configures the rendering of side-by-side diffs
|
||||
type SideBySideConfig struct {
|
||||
TotalWidth int
|
||||
Style StyleConfig
|
||||
}
|
||||
|
||||
// SideBySideOption modifies a SideBySideConfig
|
||||
@@ -243,7 +106,6 @@ type SideBySideOption func(*SideBySideConfig)
|
||||
func NewSideBySideConfig(opts ...SideBySideOption) SideBySideConfig {
|
||||
config := SideBySideConfig{
|
||||
TotalWidth: 160, // Default width for side-by-side view
|
||||
Style: NewStyleConfig(),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
@@ -262,20 +124,6 @@ func WithTotalWidth(width int) SideBySideOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithStyle sets the styling configuration
|
||||
func WithStyle(style StyleConfig) SideBySideOption {
|
||||
return func(s *SideBySideConfig) {
|
||||
s.Style = style
|
||||
}
|
||||
}
|
||||
|
||||
// WithStyleOptions applies the specified style options
|
||||
func WithStyleOptions(opts ...StyleOption) SideBySideOption {
|
||||
return func(s *SideBySideConfig) {
|
||||
s.Style = NewStyleConfig(opts...)
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Diff Parsing
|
||||
// -------------------------------------------------------------------------
|
||||
@@ -382,7 +230,7 @@ func ParseUnifiedDiff(diff string) (DiffResult, error) {
|
||||
}
|
||||
|
||||
// HighlightIntralineChanges updates lines in a hunk to show character-level differences
|
||||
func HighlightIntralineChanges(h *Hunk, style StyleConfig) {
|
||||
func HighlightIntralineChanges(h *Hunk) {
|
||||
var updated []DiffLine
|
||||
dmp := diffmatchpatch.New()
|
||||
|
||||
@@ -476,6 +324,8 @@ func pairLines(lines []DiffLine) []linePair {
|
||||
|
||||
// SyntaxHighlight applies syntax highlighting to text based on file extension
|
||||
func SyntaxHighlight(w io.Writer, source, fileName, formatter string, bg lipgloss.TerminalColor) error {
|
||||
t := theme.CurrentTheme()
|
||||
|
||||
// Determine the language lexer to use
|
||||
l := lexers.Match(fileName)
|
||||
if l == nil {
|
||||
@@ -491,93 +341,175 @@ func SyntaxHighlight(w io.Writer, source, fileName, formatter string, bg lipglos
|
||||
if f == nil {
|
||||
f = formatters.Fallback
|
||||
}
|
||||
theme := `
|
||||
<style name="vscode-dark-plus">
|
||||
<!-- Base colors -->
|
||||
<entry type="Background" style="bg:#1E1E1E"/>
|
||||
<entry type="Text" style="#D4D4D4"/>
|
||||
<entry type="Other" style="#D4D4D4"/>
|
||||
<entry type="Error" style="#F44747"/>
|
||||
<!-- Keywords - using the Control flow / Special keywords color -->
|
||||
<entry type="Keyword" style="#C586C0"/>
|
||||
<entry type="KeywordConstant" style="#4FC1FF"/>
|
||||
<entry type="KeywordDeclaration" style="#C586C0"/>
|
||||
<entry type="KeywordNamespace" style="#C586C0"/>
|
||||
<entry type="KeywordPseudo" style="#C586C0"/>
|
||||
<entry type="KeywordReserved" style="#C586C0"/>
|
||||
<entry type="KeywordType" style="#4EC9B0"/>
|
||||
<!-- Names -->
|
||||
<entry type="Name" style="#D4D4D4"/>
|
||||
<entry type="NameAttribute" style="#9CDCFE"/>
|
||||
<entry type="NameBuiltin" style="#4EC9B0"/>
|
||||
<entry type="NameBuiltinPseudo" style="#9CDCFE"/>
|
||||
<entry type="NameClass" style="#4EC9B0"/>
|
||||
<entry type="NameConstant" style="#4FC1FF"/>
|
||||
<entry type="NameDecorator" style="#DCDCAA"/>
|
||||
<entry type="NameEntity" style="#9CDCFE"/>
|
||||
<entry type="NameException" style="#4EC9B0"/>
|
||||
<entry type="NameFunction" style="#DCDCAA"/>
|
||||
<entry type="NameLabel" style="#C8C8C8"/>
|
||||
<entry type="NameNamespace" style="#4EC9B0"/>
|
||||
<entry type="NameOther" style="#9CDCFE"/>
|
||||
<entry type="NameTag" style="#569CD6"/>
|
||||
<entry type="NameVariable" style="#9CDCFE"/>
|
||||
<entry type="NameVariableClass" style="#9CDCFE"/>
|
||||
<entry type="NameVariableGlobal" style="#9CDCFE"/>
|
||||
<entry type="NameVariableInstance" style="#9CDCFE"/>
|
||||
<!-- Literals -->
|
||||
<entry type="Literal" style="#CE9178"/>
|
||||
<entry type="LiteralDate" style="#CE9178"/>
|
||||
<entry type="LiteralString" style="#CE9178"/>
|
||||
<entry type="LiteralStringBacktick" style="#CE9178"/>
|
||||
<entry type="LiteralStringChar" style="#CE9178"/>
|
||||
<entry type="LiteralStringDoc" style="#CE9178"/>
|
||||
<entry type="LiteralStringDouble" style="#CE9178"/>
|
||||
<entry type="LiteralStringEscape" style="#d7ba7d"/>
|
||||
<entry type="LiteralStringHeredoc" style="#CE9178"/>
|
||||
<entry type="LiteralStringInterpol" style="#CE9178"/>
|
||||
<entry type="LiteralStringOther" style="#CE9178"/>
|
||||
<entry type="LiteralStringRegex" style="#d16969"/>
|
||||
<entry type="LiteralStringSingle" style="#CE9178"/>
|
||||
<entry type="LiteralStringSymbol" style="#CE9178"/>
|
||||
<!-- Numbers - using the numberLiteral color -->
|
||||
<entry type="LiteralNumber" style="#b5cea8"/>
|
||||
<entry type="LiteralNumberBin" style="#b5cea8"/>
|
||||
<entry type="LiteralNumberFloat" style="#b5cea8"/>
|
||||
<entry type="LiteralNumberHex" style="#b5cea8"/>
|
||||
<entry type="LiteralNumberInteger" style="#b5cea8"/>
|
||||
<entry type="LiteralNumberIntegerLong" style="#b5cea8"/>
|
||||
<entry type="LiteralNumberOct" style="#b5cea8"/>
|
||||
<!-- Operators -->
|
||||
<entry type="Operator" style="#D4D4D4"/>
|
||||
<entry type="OperatorWord" style="#C586C0"/>
|
||||
<entry type="Punctuation" style="#D4D4D4"/>
|
||||
<!-- Comments - standard VSCode Dark+ comment color -->
|
||||
<entry type="Comment" style="#6A9955"/>
|
||||
<entry type="CommentHashbang" style="#6A9955"/>
|
||||
<entry type="CommentMultiline" style="#6A9955"/>
|
||||
<entry type="CommentSingle" style="#6A9955"/>
|
||||
<entry type="CommentSpecial" style="#6A9955"/>
|
||||
<entry type="CommentPreproc" style="#C586C0"/>
|
||||
<!-- Generic styles -->
|
||||
<entry type="Generic" style="#D4D4D4"/>
|
||||
<entry type="GenericDeleted" style="#F44747"/>
|
||||
<entry type="GenericEmph" style="italic #D4D4D4"/>
|
||||
<entry type="GenericError" style="#F44747"/>
|
||||
<entry type="GenericHeading" style="bold #D4D4D4"/>
|
||||
<entry type="GenericInserted" style="#b5cea8"/>
|
||||
<entry type="GenericOutput" style="#808080"/>
|
||||
<entry type="GenericPrompt" style="#D4D4D4"/>
|
||||
<entry type="GenericStrong" style="bold #D4D4D4"/>
|
||||
<entry type="GenericSubheading" style="bold #D4D4D4"/>
|
||||
<entry type="GenericTraceback" style="#F44747"/>
|
||||
<entry type="GenericUnderline" style="underline"/>
|
||||
<entry type="TextWhitespace" style="#D4D4D4"/>
|
||||
</style>
|
||||
`
|
||||
|
||||
r := strings.NewReader(theme)
|
||||
// Dynamic theme based on current theme values
|
||||
syntaxThemeXml := fmt.Sprintf(`
|
||||
<style name="opencode-theme">
|
||||
<!-- Base colors -->
|
||||
<entry type="Background" style="bg:%s"/>
|
||||
<entry type="Text" style="%s"/>
|
||||
<entry type="Other" style="%s"/>
|
||||
<entry type="Error" style="%s"/>
|
||||
<!-- Keywords -->
|
||||
<entry type="Keyword" style="%s"/>
|
||||
<entry type="KeywordConstant" style="%s"/>
|
||||
<entry type="KeywordDeclaration" style="%s"/>
|
||||
<entry type="KeywordNamespace" style="%s"/>
|
||||
<entry type="KeywordPseudo" style="%s"/>
|
||||
<entry type="KeywordReserved" style="%s"/>
|
||||
<entry type="KeywordType" style="%s"/>
|
||||
<!-- Names -->
|
||||
<entry type="Name" style="%s"/>
|
||||
<entry type="NameAttribute" style="%s"/>
|
||||
<entry type="NameBuiltin" style="%s"/>
|
||||
<entry type="NameBuiltinPseudo" style="%s"/>
|
||||
<entry type="NameClass" style="%s"/>
|
||||
<entry type="NameConstant" style="%s"/>
|
||||
<entry type="NameDecorator" style="%s"/>
|
||||
<entry type="NameEntity" style="%s"/>
|
||||
<entry type="NameException" style="%s"/>
|
||||
<entry type="NameFunction" style="%s"/>
|
||||
<entry type="NameLabel" style="%s"/>
|
||||
<entry type="NameNamespace" style="%s"/>
|
||||
<entry type="NameOther" style="%s"/>
|
||||
<entry type="NameTag" style="%s"/>
|
||||
<entry type="NameVariable" style="%s"/>
|
||||
<entry type="NameVariableClass" style="%s"/>
|
||||
<entry type="NameVariableGlobal" style="%s"/>
|
||||
<entry type="NameVariableInstance" style="%s"/>
|
||||
<!-- Literals -->
|
||||
<entry type="Literal" style="%s"/>
|
||||
<entry type="LiteralDate" style="%s"/>
|
||||
<entry type="LiteralString" style="%s"/>
|
||||
<entry type="LiteralStringBacktick" style="%s"/>
|
||||
<entry type="LiteralStringChar" style="%s"/>
|
||||
<entry type="LiteralStringDoc" style="%s"/>
|
||||
<entry type="LiteralStringDouble" style="%s"/>
|
||||
<entry type="LiteralStringEscape" style="%s"/>
|
||||
<entry type="LiteralStringHeredoc" style="%s"/>
|
||||
<entry type="LiteralStringInterpol" style="%s"/>
|
||||
<entry type="LiteralStringOther" style="%s"/>
|
||||
<entry type="LiteralStringRegex" style="%s"/>
|
||||
<entry type="LiteralStringSingle" style="%s"/>
|
||||
<entry type="LiteralStringSymbol" style="%s"/>
|
||||
<!-- Numbers -->
|
||||
<entry type="LiteralNumber" style="%s"/>
|
||||
<entry type="LiteralNumberBin" style="%s"/>
|
||||
<entry type="LiteralNumberFloat" style="%s"/>
|
||||
<entry type="LiteralNumberHex" style="%s"/>
|
||||
<entry type="LiteralNumberInteger" style="%s"/>
|
||||
<entry type="LiteralNumberIntegerLong" style="%s"/>
|
||||
<entry type="LiteralNumberOct" style="%s"/>
|
||||
<!-- Operators -->
|
||||
<entry type="Operator" style="%s"/>
|
||||
<entry type="OperatorWord" style="%s"/>
|
||||
<entry type="Punctuation" style="%s"/>
|
||||
<!-- Comments -->
|
||||
<entry type="Comment" style="%s"/>
|
||||
<entry type="CommentHashbang" style="%s"/>
|
||||
<entry type="CommentMultiline" style="%s"/>
|
||||
<entry type="CommentSingle" style="%s"/>
|
||||
<entry type="CommentSpecial" style="%s"/>
|
||||
<entry type="CommentPreproc" style="%s"/>
|
||||
<!-- Generic styles -->
|
||||
<entry type="Generic" style="%s"/>
|
||||
<entry type="GenericDeleted" style="%s"/>
|
||||
<entry type="GenericEmph" style="italic %s"/>
|
||||
<entry type="GenericError" style="%s"/>
|
||||
<entry type="GenericHeading" style="bold %s"/>
|
||||
<entry type="GenericInserted" style="%s"/>
|
||||
<entry type="GenericOutput" style="%s"/>
|
||||
<entry type="GenericPrompt" style="%s"/>
|
||||
<entry type="GenericStrong" style="bold %s"/>
|
||||
<entry type="GenericSubheading" style="bold %s"/>
|
||||
<entry type="GenericTraceback" style="%s"/>
|
||||
<entry type="GenericUnderline" style="underline"/>
|
||||
<entry type="TextWhitespace" style="%s"/>
|
||||
</style>
|
||||
`,
|
||||
getColor(t.Background()), // Background
|
||||
getColor(t.Text()), // Text
|
||||
getColor(t.Text()), // Other
|
||||
getColor(t.Error()), // Error
|
||||
|
||||
getColor(t.SyntaxKeyword()), // Keyword
|
||||
getColor(t.SyntaxKeyword()), // KeywordConstant
|
||||
getColor(t.SyntaxKeyword()), // KeywordDeclaration
|
||||
getColor(t.SyntaxKeyword()), // KeywordNamespace
|
||||
getColor(t.SyntaxKeyword()), // KeywordPseudo
|
||||
getColor(t.SyntaxKeyword()), // KeywordReserved
|
||||
getColor(t.SyntaxType()), // KeywordType
|
||||
|
||||
getColor(t.Text()), // Name
|
||||
getColor(t.SyntaxVariable()), // NameAttribute
|
||||
getColor(t.SyntaxType()), // NameBuiltin
|
||||
getColor(t.SyntaxVariable()), // NameBuiltinPseudo
|
||||
getColor(t.SyntaxType()), // NameClass
|
||||
getColor(t.SyntaxVariable()), // NameConstant
|
||||
getColor(t.SyntaxFunction()), // NameDecorator
|
||||
getColor(t.SyntaxVariable()), // NameEntity
|
||||
getColor(t.SyntaxType()), // NameException
|
||||
getColor(t.SyntaxFunction()), // NameFunction
|
||||
getColor(t.Text()), // NameLabel
|
||||
getColor(t.SyntaxType()), // NameNamespace
|
||||
getColor(t.SyntaxVariable()), // NameOther
|
||||
getColor(t.SyntaxKeyword()), // NameTag
|
||||
getColor(t.SyntaxVariable()), // NameVariable
|
||||
getColor(t.SyntaxVariable()), // NameVariableClass
|
||||
getColor(t.SyntaxVariable()), // NameVariableGlobal
|
||||
getColor(t.SyntaxVariable()), // NameVariableInstance
|
||||
|
||||
getColor(t.SyntaxString()), // Literal
|
||||
getColor(t.SyntaxString()), // LiteralDate
|
||||
getColor(t.SyntaxString()), // LiteralString
|
||||
getColor(t.SyntaxString()), // LiteralStringBacktick
|
||||
getColor(t.SyntaxString()), // LiteralStringChar
|
||||
getColor(t.SyntaxString()), // LiteralStringDoc
|
||||
getColor(t.SyntaxString()), // LiteralStringDouble
|
||||
getColor(t.SyntaxString()), // LiteralStringEscape
|
||||
getColor(t.SyntaxString()), // LiteralStringHeredoc
|
||||
getColor(t.SyntaxString()), // LiteralStringInterpol
|
||||
getColor(t.SyntaxString()), // LiteralStringOther
|
||||
getColor(t.SyntaxString()), // LiteralStringRegex
|
||||
getColor(t.SyntaxString()), // LiteralStringSingle
|
||||
getColor(t.SyntaxString()), // LiteralStringSymbol
|
||||
|
||||
getColor(t.SyntaxNumber()), // LiteralNumber
|
||||
getColor(t.SyntaxNumber()), // LiteralNumberBin
|
||||
getColor(t.SyntaxNumber()), // LiteralNumberFloat
|
||||
getColor(t.SyntaxNumber()), // LiteralNumberHex
|
||||
getColor(t.SyntaxNumber()), // LiteralNumberInteger
|
||||
getColor(t.SyntaxNumber()), // LiteralNumberIntegerLong
|
||||
getColor(t.SyntaxNumber()), // LiteralNumberOct
|
||||
|
||||
getColor(t.SyntaxOperator()), // Operator
|
||||
getColor(t.SyntaxKeyword()), // OperatorWord
|
||||
getColor(t.SyntaxPunctuation()), // Punctuation
|
||||
|
||||
getColor(t.SyntaxComment()), // Comment
|
||||
getColor(t.SyntaxComment()), // CommentHashbang
|
||||
getColor(t.SyntaxComment()), // CommentMultiline
|
||||
getColor(t.SyntaxComment()), // CommentSingle
|
||||
getColor(t.SyntaxComment()), // CommentSpecial
|
||||
getColor(t.SyntaxKeyword()), // CommentPreproc
|
||||
|
||||
getColor(t.Text()), // Generic
|
||||
getColor(t.Error()), // GenericDeleted
|
||||
getColor(t.Text()), // GenericEmph
|
||||
getColor(t.Error()), // GenericError
|
||||
getColor(t.Text()), // GenericHeading
|
||||
getColor(t.Success()), // GenericInserted
|
||||
getColor(t.TextMuted()), // GenericOutput
|
||||
getColor(t.Text()), // GenericPrompt
|
||||
getColor(t.Text()), // GenericStrong
|
||||
getColor(t.Text()), // GenericSubheading
|
||||
getColor(t.Error()), // GenericTraceback
|
||||
getColor(t.Text()), // TextWhitespace
|
||||
)
|
||||
|
||||
r := strings.NewReader(syntaxThemeXml)
|
||||
style := chroma.MustNewXMLStyle(r)
|
||||
|
||||
// Modify the style to use the provided background
|
||||
s, err := style.Builder().Transform(
|
||||
func(t chroma.StyleEntry) chroma.StyleEntry {
|
||||
@@ -599,6 +531,14 @@ func SyntaxHighlight(w io.Writer, source, fileName, formatter string, bg lipglos
|
||||
return f.Format(w, s, it)
|
||||
}
|
||||
|
||||
// getColor returns the appropriate hex color string based on terminal background
|
||||
func getColor(adaptiveColor lipgloss.AdaptiveColor) string {
|
||||
if lipgloss.HasDarkBackground() {
|
||||
return adaptiveColor.Dark
|
||||
}
|
||||
return adaptiveColor.Light
|
||||
}
|
||||
|
||||
// highlightLine applies syntax highlighting to a single line
|
||||
func highlightLine(fileName string, line string, bg lipgloss.TerminalColor) string {
|
||||
var buf bytes.Buffer
|
||||
@@ -610,11 +550,11 @@ func highlightLine(fileName string, line string, bg lipgloss.TerminalColor) stri
|
||||
}
|
||||
|
||||
// createStyles generates the lipgloss styles needed for rendering diffs
|
||||
func createStyles(config StyleConfig) (removedLineStyle, addedLineStyle, contextLineStyle, lineNumberStyle lipgloss.Style) {
|
||||
removedLineStyle = lipgloss.NewStyle().Background(config.RemovedLineBg)
|
||||
addedLineStyle = lipgloss.NewStyle().Background(config.AddedLineBg)
|
||||
contextLineStyle = lipgloss.NewStyle().Background(config.ContextLineBg)
|
||||
lineNumberStyle = lipgloss.NewStyle().Foreground(config.LineNumberFg)
|
||||
func createStyles(t theme.Theme) (removedLineStyle, addedLineStyle, contextLineStyle, lineNumberStyle lipgloss.Style) {
|
||||
removedLineStyle = lipgloss.NewStyle().Background(t.DiffRemovedBg())
|
||||
addedLineStyle = lipgloss.NewStyle().Background(t.DiffAddedBg())
|
||||
contextLineStyle = lipgloss.NewStyle().Background(t.DiffContextBg())
|
||||
lineNumberStyle = lipgloss.NewStyle().Foreground(t.DiffLineNumber())
|
||||
|
||||
return
|
||||
}
|
||||
@@ -624,8 +564,7 @@ func createStyles(config StyleConfig) (removedLineStyle, addedLineStyle, context
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// applyHighlighting applies intra-line highlighting to a piece of text
|
||||
func applyHighlighting(content string, segments []Segment, segmentType LineType, highlightBg lipgloss.Color,
|
||||
) string {
|
||||
func applyHighlighting(content string, segments []Segment, segmentType LineType, highlightBg lipgloss.AdaptiveColor) string {
|
||||
// Find all ANSI sequences in the content
|
||||
ansiRegex := regexp.MustCompile(`\x1b(?:[@-Z\\-_]|\[[0-9?]*(?:;[0-9?]*)*[@-~])`)
|
||||
ansiMatches := ansiRegex.FindAllStringIndex(content, -1)
|
||||
@@ -663,6 +602,10 @@ func applyHighlighting(content string, segments []Segment, segmentType LineType,
|
||||
inSelection := false
|
||||
currentPos := 0
|
||||
|
||||
// Get the appropriate color based on terminal background
|
||||
bgColor := lipgloss.Color(getColor(highlightBg))
|
||||
fgColor := lipgloss.Color(getColor(theme.CurrentTheme().Background()))
|
||||
|
||||
for i := 0; i < len(content); {
|
||||
// Check if we're at an ANSI sequence
|
||||
isAnsi := false
|
||||
@@ -697,12 +640,17 @@ func applyHighlighting(content string, segments []Segment, segmentType LineType,
|
||||
// Get the current styling
|
||||
currentStyle := ansiSequences[currentPos]
|
||||
|
||||
// Apply background highlight
|
||||
// Apply foreground and background highlight
|
||||
sb.WriteString("\x1b[38;2;")
|
||||
r, g, b, _ := fgColor.RGBA()
|
||||
sb.WriteString(fmt.Sprintf("%d;%d;%dm", r>>8, g>>8, b>>8))
|
||||
sb.WriteString("\x1b[48;2;")
|
||||
r, g, b, _ := highlightBg.RGBA()
|
||||
r, g, b, _ = bgColor.RGBA()
|
||||
sb.WriteString(fmt.Sprintf("%d;%d;%dm", r>>8, g>>8, b>>8))
|
||||
sb.WriteString(char)
|
||||
sb.WriteString("\x1b[49m") // Reset only background
|
||||
|
||||
// Full reset of all attributes to ensure clean state
|
||||
sb.WriteString("\x1b[0m")
|
||||
|
||||
// Reapply the original ANSI sequence
|
||||
sb.WriteString(currentStyle)
|
||||
@@ -718,50 +666,98 @@ func applyHighlighting(content string, segments []Segment, segmentType LineType,
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// renderLeftColumn formats the left side of a side-by-side diff
|
||||
func renderLeftColumn(fileName string, dl *DiffLine, colWidth int, styles StyleConfig) string {
|
||||
// renderDiffColumnLine is a helper function that handles the common logic for rendering diff columns
|
||||
func renderDiffColumnLine(
|
||||
fileName string,
|
||||
dl *DiffLine,
|
||||
colWidth int,
|
||||
isLeftColumn bool,
|
||||
t theme.Theme,
|
||||
) string {
|
||||
if dl == nil {
|
||||
contextLineStyle := lipgloss.NewStyle().Background(styles.ContextLineBg)
|
||||
contextLineStyle := lipgloss.NewStyle().Background(t.DiffContextBg())
|
||||
return contextLineStyle.Width(colWidth).Render("")
|
||||
}
|
||||
|
||||
removedLineStyle, _, contextLineStyle, lineNumberStyle := createStyles(styles)
|
||||
removedLineStyle, addedLineStyle, contextLineStyle, lineNumberStyle := createStyles(t)
|
||||
|
||||
// Determine line style based on line type
|
||||
// Determine line style based on line type and column
|
||||
var marker string
|
||||
var bgStyle lipgloss.Style
|
||||
switch dl.Kind {
|
||||
case LineRemoved:
|
||||
marker = removedLineStyle.Foreground(styles.RemovedFg).Render("-")
|
||||
bgStyle = removedLineStyle
|
||||
lineNumberStyle = lineNumberStyle.Foreground(styles.RemovedFg).Background(styles.RemovedLineNumberBg)
|
||||
case LineAdded:
|
||||
marker = "?"
|
||||
bgStyle = contextLineStyle
|
||||
case LineContext:
|
||||
marker = contextLineStyle.Render(" ")
|
||||
bgStyle = contextLineStyle
|
||||
var lineNum string
|
||||
var highlightType LineType
|
||||
var highlightColor lipgloss.AdaptiveColor
|
||||
|
||||
if isLeftColumn {
|
||||
// Left column logic
|
||||
switch dl.Kind {
|
||||
case LineRemoved:
|
||||
marker = "-"
|
||||
bgStyle = removedLineStyle
|
||||
lineNumberStyle = lineNumberStyle.Foreground(t.DiffRemoved()).Background(t.DiffRemovedLineNumberBg())
|
||||
highlightType = LineRemoved
|
||||
highlightColor = t.DiffHighlightRemoved()
|
||||
case LineAdded:
|
||||
marker = "?"
|
||||
bgStyle = contextLineStyle
|
||||
case LineContext:
|
||||
marker = " "
|
||||
bgStyle = contextLineStyle
|
||||
}
|
||||
|
||||
// Format line number for left column
|
||||
if dl.OldLineNo > 0 {
|
||||
lineNum = fmt.Sprintf("%6d", dl.OldLineNo)
|
||||
}
|
||||
} else {
|
||||
// Right column logic
|
||||
switch dl.Kind {
|
||||
case LineAdded:
|
||||
marker = "+"
|
||||
bgStyle = addedLineStyle
|
||||
lineNumberStyle = lineNumberStyle.Foreground(t.DiffAdded()).Background(t.DiffAddedLineNumberBg())
|
||||
highlightType = LineAdded
|
||||
highlightColor = t.DiffHighlightAdded()
|
||||
case LineRemoved:
|
||||
marker = "?"
|
||||
bgStyle = contextLineStyle
|
||||
case LineContext:
|
||||
marker = " "
|
||||
bgStyle = contextLineStyle
|
||||
}
|
||||
|
||||
// Format line number for right column
|
||||
if dl.NewLineNo > 0 {
|
||||
lineNum = fmt.Sprintf("%6d", dl.NewLineNo)
|
||||
}
|
||||
}
|
||||
|
||||
// Format line number
|
||||
lineNum := ""
|
||||
if dl.OldLineNo > 0 {
|
||||
lineNum = fmt.Sprintf("%6d", dl.OldLineNo)
|
||||
// Style the marker based on line type
|
||||
var styledMarker string
|
||||
switch dl.Kind {
|
||||
case LineRemoved:
|
||||
styledMarker = removedLineStyle.Foreground(t.DiffRemoved()).Render(marker)
|
||||
case LineAdded:
|
||||
styledMarker = addedLineStyle.Foreground(t.DiffAdded()).Render(marker)
|
||||
case LineContext:
|
||||
styledMarker = contextLineStyle.Foreground(t.TextMuted()).Render(marker)
|
||||
default:
|
||||
styledMarker = marker
|
||||
}
|
||||
|
||||
// Create the line prefix
|
||||
prefix := lineNumberStyle.Render(lineNum + " " + marker)
|
||||
prefix := lineNumberStyle.Render(lineNum + " " + styledMarker)
|
||||
|
||||
// Apply syntax highlighting
|
||||
content := highlightLine(fileName, dl.Content, bgStyle.GetBackground())
|
||||
|
||||
// Apply intra-line highlighting for removed lines
|
||||
if dl.Kind == LineRemoved && len(dl.Segments) > 0 {
|
||||
content = applyHighlighting(content, dl.Segments, LineRemoved, styles.RemovedHighlightBg)
|
||||
// Apply intra-line highlighting if needed
|
||||
if (dl.Kind == LineRemoved && isLeftColumn || dl.Kind == LineAdded && !isLeftColumn) && len(dl.Segments) > 0 {
|
||||
content = applyHighlighting(content, dl.Segments, highlightType, highlightColor)
|
||||
}
|
||||
|
||||
// Add a padding space for removed lines
|
||||
if dl.Kind == LineRemoved {
|
||||
// Add a padding space for added/removed lines
|
||||
if (dl.Kind == LineRemoved && isLeftColumn) || (dl.Kind == LineAdded && !isLeftColumn) {
|
||||
content = bgStyle.Render(" ") + content
|
||||
}
|
||||
|
||||
@@ -771,67 +767,19 @@ func renderLeftColumn(fileName string, dl *DiffLine, colWidth int, styles StyleC
|
||||
ansi.Truncate(
|
||||
lineText,
|
||||
colWidth,
|
||||
lipgloss.NewStyle().Background(styles.HunkLineBg).Foreground(styles.HunkLineFg).Render("..."),
|
||||
lipgloss.NewStyle().Background(bgStyle.GetBackground()).Foreground(t.TextMuted()).Render("..."),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// renderLeftColumn formats the left side of a side-by-side diff
|
||||
func renderLeftColumn(fileName string, dl *DiffLine, colWidth int) string {
|
||||
return renderDiffColumnLine(fileName, dl, colWidth, true, theme.CurrentTheme())
|
||||
}
|
||||
|
||||
// renderRightColumn formats the right side of a side-by-side diff
|
||||
func renderRightColumn(fileName string, dl *DiffLine, colWidth int, styles StyleConfig) string {
|
||||
if dl == nil {
|
||||
contextLineStyle := lipgloss.NewStyle().Background(styles.ContextLineBg)
|
||||
return contextLineStyle.Width(colWidth).Render("")
|
||||
}
|
||||
|
||||
_, addedLineStyle, contextLineStyle, lineNumberStyle := createStyles(styles)
|
||||
|
||||
// Determine line style based on line type
|
||||
var marker string
|
||||
var bgStyle lipgloss.Style
|
||||
switch dl.Kind {
|
||||
case LineAdded:
|
||||
marker = addedLineStyle.Foreground(styles.AddedFg).Render("+")
|
||||
bgStyle = addedLineStyle
|
||||
lineNumberStyle = lineNumberStyle.Foreground(styles.AddedFg).Background(styles.AddedLineNamerBg)
|
||||
case LineRemoved:
|
||||
marker = "?"
|
||||
bgStyle = contextLineStyle
|
||||
case LineContext:
|
||||
marker = contextLineStyle.Render(" ")
|
||||
bgStyle = contextLineStyle
|
||||
}
|
||||
|
||||
// Format line number
|
||||
lineNum := ""
|
||||
if dl.NewLineNo > 0 {
|
||||
lineNum = fmt.Sprintf("%6d", dl.NewLineNo)
|
||||
}
|
||||
|
||||
// Create the line prefix
|
||||
prefix := lineNumberStyle.Render(lineNum + " " + marker)
|
||||
|
||||
// Apply syntax highlighting
|
||||
content := highlightLine(fileName, dl.Content, bgStyle.GetBackground())
|
||||
|
||||
// Apply intra-line highlighting for added lines
|
||||
if dl.Kind == LineAdded && len(dl.Segments) > 0 {
|
||||
content = applyHighlighting(content, dl.Segments, LineAdded, styles.AddedHighlightBg)
|
||||
}
|
||||
|
||||
// Add a padding space for added lines
|
||||
if dl.Kind == LineAdded {
|
||||
content = bgStyle.Render(" ") + content
|
||||
}
|
||||
|
||||
// Create the final line and truncate if needed
|
||||
lineText := prefix + content
|
||||
return bgStyle.MaxHeight(1).Width(colWidth).Render(
|
||||
ansi.Truncate(
|
||||
lineText,
|
||||
colWidth,
|
||||
lipgloss.NewStyle().Background(styles.HunkLineBg).Foreground(styles.HunkLineFg).Render("..."),
|
||||
),
|
||||
)
|
||||
func renderRightColumn(fileName string, dl *DiffLine, colWidth int) string {
|
||||
return renderDiffColumnLine(fileName, dl, colWidth, false, theme.CurrentTheme())
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
@@ -848,7 +796,7 @@ func RenderSideBySideHunk(fileName string, h Hunk, opts ...SideBySideOption) str
|
||||
copy(hunkCopy.Lines, h.Lines)
|
||||
|
||||
// Highlight changes within lines
|
||||
HighlightIntralineChanges(&hunkCopy, config.Style)
|
||||
HighlightIntralineChanges(&hunkCopy)
|
||||
|
||||
// Pair lines for side-by-side display
|
||||
pairs := pairLines(hunkCopy.Lines)
|
||||
@@ -860,8 +808,8 @@ func RenderSideBySideHunk(fileName string, h Hunk, opts ...SideBySideOption) str
|
||||
rightWidth := config.TotalWidth - colWidth
|
||||
var sb strings.Builder
|
||||
for _, p := range pairs {
|
||||
leftStr := renderLeftColumn(fileName, p.left, leftWidth, config.Style)
|
||||
rightStr := renderRightColumn(fileName, p.right, rightWidth, config.Style)
|
||||
leftStr := renderLeftColumn(fileName, p.left, leftWidth)
|
||||
rightStr := renderRightColumn(fileName, p.right, rightWidth)
|
||||
sb.WriteString(leftStr + rightStr + "\n")
|
||||
}
|
||||
|
||||
@@ -870,6 +818,7 @@ func RenderSideBySideHunk(fileName string, h Hunk, opts ...SideBySideOption) str
|
||||
|
||||
// FormatDiff creates a side-by-side formatted view of a diff
|
||||
func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) {
|
||||
t := theme.CurrentTheme()
|
||||
diffResult, err := ParseUnifiedDiff(diffText)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -877,53 +826,14 @@ func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) {
|
||||
|
||||
var sb strings.Builder
|
||||
config := NewSideBySideConfig(opts...)
|
||||
|
||||
if config.Style.ShowHeader {
|
||||
removeIcon := lipgloss.NewStyle().
|
||||
Background(config.Style.RemovedLineBg).
|
||||
Foreground(config.Style.RemovedFg).
|
||||
Render("⏹")
|
||||
addIcon := lipgloss.NewStyle().
|
||||
Background(config.Style.AddedLineBg).
|
||||
Foreground(config.Style.AddedFg).
|
||||
Render("⏹")
|
||||
|
||||
fileName := lipgloss.NewStyle().
|
||||
Background(config.Style.ContextLineBg).
|
||||
Foreground(config.Style.FileNameFg).
|
||||
Render(" " + diffResult.OldFile)
|
||||
for _, h := range diffResult.Hunks {
|
||||
sb.WriteString(
|
||||
lipgloss.NewStyle().
|
||||
Background(config.Style.ContextLineBg).
|
||||
Padding(0, 1, 0, 1).
|
||||
Foreground(config.Style.FileNameFg).
|
||||
BorderStyle(lipgloss.NormalBorder()).
|
||||
BorderTop(true).
|
||||
BorderBottom(true).
|
||||
BorderForeground(config.Style.FileNameFg).
|
||||
BorderBackground(config.Style.ContextLineBg).
|
||||
Background(t.DiffHunkHeader()).
|
||||
Foreground(t.Background()).
|
||||
Width(config.TotalWidth).
|
||||
Render(
|
||||
lipgloss.JoinHorizontal(lipgloss.Top,
|
||||
removeIcon,
|
||||
addIcon,
|
||||
fileName,
|
||||
),
|
||||
) + "\n",
|
||||
Render(h.Header) + "\n",
|
||||
)
|
||||
}
|
||||
|
||||
for _, h := range diffResult.Hunks {
|
||||
// Render hunk header
|
||||
if config.Style.ShowHunkHeader {
|
||||
sb.WriteString(
|
||||
lipgloss.NewStyle().
|
||||
Background(config.Style.HunkLineBg).
|
||||
Foreground(config.Style.HunkLineFg).
|
||||
Width(config.TotalWidth).
|
||||
Render(h.Header) + "\n",
|
||||
)
|
||||
}
|
||||
sb.WriteString(RenderSideBySideHunk(diffResult.OldFile, h, opts...))
|
||||
}
|
||||
|
||||
@@ -938,14 +848,16 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in
|
||||
fileName = strings.TrimPrefix(fileName, cwd)
|
||||
fileName = strings.TrimPrefix(fileName, "/")
|
||||
|
||||
edits := udiff.Strings(beforeContent, afterContent)
|
||||
unified, _ := udiff.ToUnified("a/"+fileName, "b/"+fileName, beforeContent, edits, 8)
|
||||
|
||||
var (
|
||||
unified = udiff.Unified("a/"+fileName, "b/"+fileName, beforeContent, afterContent)
|
||||
additions = 0
|
||||
removals = 0
|
||||
)
|
||||
|
||||
lines := strings.Split(unified, "\n")
|
||||
for _, line := range lines {
|
||||
lines := strings.SplitSeq(unified, "\n")
|
||||
for line := range lines {
|
||||
if strings.HasPrefix(line, "+") && !strings.HasPrefix(line, "+++") {
|
||||
additions++
|
||||
} else if strings.HasPrefix(line, "-") && !strings.HasPrefix(line, "---") {
|
||||
|
||||
103
internal/diff/diff_test.go
Normal file
103
internal/diff/diff_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestApplyHighlighting tests the applyHighlighting function with various ANSI sequences
|
||||
func TestApplyHighlighting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Mock theme colors for testing
|
||||
mockHighlightBg := lipgloss.AdaptiveColor{
|
||||
Dark: "#FF0000", // Red background for highlighting
|
||||
Light: "#FF0000",
|
||||
}
|
||||
|
||||
// Test cases
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
segments []Segment
|
||||
segmentType LineType
|
||||
expectContains string
|
||||
}{
|
||||
{
|
||||
name: "Simple text with no ANSI",
|
||||
content: "This is a test",
|
||||
segments: []Segment{{Start: 0, End: 4, Type: LineAdded}},
|
||||
segmentType: LineAdded,
|
||||
// Should contain full reset sequence after highlighting
|
||||
expectContains: "\x1b[0m",
|
||||
},
|
||||
{
|
||||
name: "Text with existing ANSI foreground",
|
||||
content: "This \x1b[32mis\x1b[0m a test", // "is" in green
|
||||
segments: []Segment{{Start: 5, End: 7, Type: LineAdded}},
|
||||
segmentType: LineAdded,
|
||||
// Should contain full reset sequence after highlighting
|
||||
expectContains: "\x1b[0m",
|
||||
},
|
||||
{
|
||||
name: "Text with existing ANSI background",
|
||||
content: "This \x1b[42mis\x1b[0m a test", // "is" with green background
|
||||
segments: []Segment{{Start: 5, End: 7, Type: LineAdded}},
|
||||
segmentType: LineAdded,
|
||||
// Should contain full reset sequence after highlighting
|
||||
expectContains: "\x1b[0m",
|
||||
},
|
||||
{
|
||||
name: "Text with complex ANSI styling",
|
||||
content: "This \x1b[1;32;45mis\x1b[0m a test", // "is" bold green on magenta
|
||||
segments: []Segment{{Start: 5, End: 7, Type: LineAdded}},
|
||||
segmentType: LineAdded,
|
||||
// Should contain full reset sequence after highlighting
|
||||
expectContains: "\x1b[0m",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc // Capture range variable for parallel testing
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := applyHighlighting(tc.content, tc.segments, tc.segmentType, mockHighlightBg)
|
||||
|
||||
// Verify the result contains the expected sequence
|
||||
assert.Contains(t, result, tc.expectContains,
|
||||
"Result should contain full reset sequence")
|
||||
|
||||
// Print the result for manual inspection if needed
|
||||
if t.Failed() {
|
||||
fmt.Printf("Original: %q\nResult: %q\n", tc.content, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyHighlightingWithMultipleSegments tests highlighting multiple segments
|
||||
func TestApplyHighlightingWithMultipleSegments(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Mock theme colors for testing
|
||||
mockHighlightBg := lipgloss.AdaptiveColor{
|
||||
Dark: "#FF0000", // Red background for highlighting
|
||||
Light: "#FF0000",
|
||||
}
|
||||
|
||||
content := "This is a test with multiple segments to highlight"
|
||||
segments := []Segment{
|
||||
{Start: 0, End: 4, Type: LineAdded}, // "This"
|
||||
{Start: 8, End: 9, Type: LineAdded}, // "a"
|
||||
{Start: 15, End: 23, Type: LineAdded}, // "multiple"
|
||||
}
|
||||
|
||||
result := applyHighlighting(content, segments, LineAdded, mockHighlightBg)
|
||||
|
||||
// Verify the result contains the full reset sequence
|
||||
assert.Contains(t, result, "\x1b[0m",
|
||||
"Result should contain full reset sequence")
|
||||
}
|
||||
@@ -1,252 +0,0 @@
|
||||
package history
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/opencode-ai/opencode/internal/db"
|
||||
"github.com/opencode-ai/opencode/internal/pubsub"
|
||||
)
|
||||
|
||||
const (
|
||||
InitialVersion = "initial"
|
||||
)
|
||||
|
||||
type File struct {
|
||||
ID string
|
||||
SessionID string
|
||||
Path string
|
||||
Content string
|
||||
Version string
|
||||
CreatedAt int64
|
||||
UpdatedAt int64
|
||||
}
|
||||
|
||||
type Service interface {
|
||||
pubsub.Suscriber[File]
|
||||
Create(ctx context.Context, sessionID, path, content string) (File, error)
|
||||
CreateVersion(ctx context.Context, sessionID, path, content string) (File, error)
|
||||
Get(ctx context.Context, id string) (File, error)
|
||||
GetByPathAndSession(ctx context.Context, path, sessionID string) (File, error)
|
||||
ListBySession(ctx context.Context, sessionID string) ([]File, error)
|
||||
ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error)
|
||||
Update(ctx context.Context, file File) (File, error)
|
||||
Delete(ctx context.Context, id string) error
|
||||
DeleteSessionFiles(ctx context.Context, sessionID string) error
|
||||
}
|
||||
|
||||
type service struct {
|
||||
*pubsub.Broker[File]
|
||||
db *sql.DB
|
||||
q *db.Queries
|
||||
}
|
||||
|
||||
func NewService(q *db.Queries, db *sql.DB) Service {
|
||||
return &service{
|
||||
Broker: pubsub.NewBroker[File](),
|
||||
q: q,
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *service) Create(ctx context.Context, sessionID, path, content string) (File, error) {
|
||||
return s.createWithVersion(ctx, sessionID, path, content, InitialVersion)
|
||||
}
|
||||
|
||||
func (s *service) CreateVersion(ctx context.Context, sessionID, path, content string) (File, error) {
|
||||
// Get the latest version for this path
|
||||
files, err := s.q.ListFilesByPath(ctx, path)
|
||||
if err != nil {
|
||||
return File{}, err
|
||||
}
|
||||
|
||||
if len(files) == 0 {
|
||||
// No previous versions, create initial
|
||||
return s.Create(ctx, sessionID, path, content)
|
||||
}
|
||||
|
||||
// Get the latest version
|
||||
latestFile := files[0] // Files are ordered by created_at DESC
|
||||
latestVersion := latestFile.Version
|
||||
|
||||
// Generate the next version
|
||||
var nextVersion string
|
||||
if latestVersion == InitialVersion {
|
||||
nextVersion = "v1"
|
||||
} else if strings.HasPrefix(latestVersion, "v") {
|
||||
versionNum, err := strconv.Atoi(latestVersion[1:])
|
||||
if err != nil {
|
||||
// If we can't parse the version, just use a timestamp-based version
|
||||
nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt)
|
||||
} else {
|
||||
nextVersion = fmt.Sprintf("v%d", versionNum+1)
|
||||
}
|
||||
} else {
|
||||
// If the version format is unexpected, use a timestamp-based version
|
||||
nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt)
|
||||
}
|
||||
|
||||
return s.createWithVersion(ctx, sessionID, path, content, nextVersion)
|
||||
}
|
||||
|
||||
func (s *service) createWithVersion(ctx context.Context, sessionID, path, content, version string) (File, error) {
|
||||
// Maximum number of retries for transaction conflicts
|
||||
const maxRetries = 3
|
||||
var file File
|
||||
var err error
|
||||
|
||||
// Retry loop for transaction conflicts
|
||||
for attempt := range maxRetries {
|
||||
// Start a transaction
|
||||
tx, txErr := s.db.Begin()
|
||||
if txErr != nil {
|
||||
return File{}, fmt.Errorf("failed to begin transaction: %w", txErr)
|
||||
}
|
||||
|
||||
// Create a new queries instance with the transaction
|
||||
qtx := s.q.WithTx(tx)
|
||||
|
||||
// Try to create the file within the transaction
|
||||
dbFile, txErr := qtx.CreateFile(ctx, db.CreateFileParams{
|
||||
ID: uuid.New().String(),
|
||||
SessionID: sessionID,
|
||||
Path: path,
|
||||
Content: content,
|
||||
Version: version,
|
||||
})
|
||||
if txErr != nil {
|
||||
// Rollback the transaction
|
||||
tx.Rollback()
|
||||
|
||||
// Check if this is a uniqueness constraint violation
|
||||
if strings.Contains(txErr.Error(), "UNIQUE constraint failed") {
|
||||
if attempt < maxRetries-1 {
|
||||
// If we have retries left, generate a new version and try again
|
||||
if strings.HasPrefix(version, "v") {
|
||||
versionNum, parseErr := strconv.Atoi(version[1:])
|
||||
if parseErr == nil {
|
||||
version = fmt.Sprintf("v%d", versionNum+1)
|
||||
continue
|
||||
}
|
||||
}
|
||||
// If we can't parse the version, use a timestamp-based version
|
||||
version = fmt.Sprintf("v%d", time.Now().Unix())
|
||||
continue
|
||||
}
|
||||
}
|
||||
return File{}, txErr
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if txErr = tx.Commit(); txErr != nil {
|
||||
return File{}, fmt.Errorf("failed to commit transaction: %w", txErr)
|
||||
}
|
||||
|
||||
file = s.fromDBItem(dbFile)
|
||||
s.Publish(pubsub.CreatedEvent, file)
|
||||
return file, nil
|
||||
}
|
||||
|
||||
return file, err
|
||||
}
|
||||
|
||||
func (s *service) Get(ctx context.Context, id string) (File, error) {
|
||||
dbFile, err := s.q.GetFile(ctx, id)
|
||||
if err != nil {
|
||||
return File{}, err
|
||||
}
|
||||
return s.fromDBItem(dbFile), nil
|
||||
}
|
||||
|
||||
func (s *service) GetByPathAndSession(ctx context.Context, path, sessionID string) (File, error) {
|
||||
dbFile, err := s.q.GetFileByPathAndSession(ctx, db.GetFileByPathAndSessionParams{
|
||||
Path: path,
|
||||
SessionID: sessionID,
|
||||
})
|
||||
if err != nil {
|
||||
return File{}, err
|
||||
}
|
||||
return s.fromDBItem(dbFile), nil
|
||||
}
|
||||
|
||||
func (s *service) ListBySession(ctx context.Context, sessionID string) ([]File, error) {
|
||||
dbFiles, err := s.q.ListFilesBySession(ctx, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
files := make([]File, len(dbFiles))
|
||||
for i, dbFile := range dbFiles {
|
||||
files[i] = s.fromDBItem(dbFile)
|
||||
}
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (s *service) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) {
|
||||
dbFiles, err := s.q.ListLatestSessionFiles(ctx, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
files := make([]File, len(dbFiles))
|
||||
for i, dbFile := range dbFiles {
|
||||
files[i] = s.fromDBItem(dbFile)
|
||||
}
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (s *service) Update(ctx context.Context, file File) (File, error) {
|
||||
dbFile, err := s.q.UpdateFile(ctx, db.UpdateFileParams{
|
||||
ID: file.ID,
|
||||
Content: file.Content,
|
||||
Version: file.Version,
|
||||
})
|
||||
if err != nil {
|
||||
return File{}, err
|
||||
}
|
||||
updatedFile := s.fromDBItem(dbFile)
|
||||
s.Publish(pubsub.UpdatedEvent, updatedFile)
|
||||
return updatedFile, nil
|
||||
}
|
||||
|
||||
func (s *service) Delete(ctx context.Context, id string) error {
|
||||
file, err := s.Get(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.q.DeleteFile(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.Publish(pubsub.DeletedEvent, file)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) DeleteSessionFiles(ctx context.Context, sessionID string) error {
|
||||
files, err := s.ListBySession(ctx, sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, file := range files {
|
||||
err = s.Delete(ctx, file.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) fromDBItem(item db.File) File {
|
||||
return File{
|
||||
ID: item.ID,
|
||||
SessionID: item.SessionID,
|
||||
Path: item.Path,
|
||||
Content: item.Content,
|
||||
Version: item.Version,
|
||||
CreatedAt: item.CreatedAt,
|
||||
UpdatedAt: item.UpdatedAt,
|
||||
}
|
||||
}
|
||||
441
internal/history/history.go
Normal file
441
internal/history/history.go
Normal file
@@ -0,0 +1,441 @@
|
||||
package history
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sst/opencode/internal/db"
|
||||
"github.com/sst/opencode/internal/pubsub"
|
||||
)
|
||||
|
||||
const (
|
||||
InitialVersion = "initial"
|
||||
)
|
||||
|
||||
type File struct {
|
||||
ID string
|
||||
SessionID string
|
||||
Path string
|
||||
Content string
|
||||
Version string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
EventFileCreated pubsub.EventType = "history_file_created"
|
||||
EventFileVersionCreated pubsub.EventType = "history_file_version_created"
|
||||
EventFileUpdated pubsub.EventType = "history_file_updated"
|
||||
EventFileDeleted pubsub.EventType = "history_file_deleted"
|
||||
EventSessionFilesDeleted pubsub.EventType = "history_session_files_deleted"
|
||||
)
|
||||
|
||||
type Service interface {
|
||||
pubsub.Subscriber[File]
|
||||
|
||||
Create(ctx context.Context, sessionID, path, content string) (File, error)
|
||||
CreateVersion(ctx context.Context, sessionID, path, content string) (File, error)
|
||||
Get(ctx context.Context, id string) (File, error)
|
||||
GetByPathAndVersion(ctx context.Context, sessionID, path, version string) (File, error)
|
||||
GetLatestByPathAndSession(ctx context.Context, path, sessionID string) (File, error)
|
||||
ListBySession(ctx context.Context, sessionID string) ([]File, error)
|
||||
ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error)
|
||||
ListVersionsByPath(ctx context.Context, path string) ([]File, error)
|
||||
Update(ctx context.Context, file File) (File, error)
|
||||
Delete(ctx context.Context, id string) error
|
||||
DeleteSessionFiles(ctx context.Context, sessionID string) error
|
||||
}
|
||||
|
||||
type service struct {
|
||||
db *db.Queries
|
||||
sqlDB *sql.DB
|
||||
broker *pubsub.Broker[File]
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var globalHistoryService *service
|
||||
|
||||
func InitService(sqlDatabase *sql.DB) error {
|
||||
if globalHistoryService != nil {
|
||||
return fmt.Errorf("history service already initialized")
|
||||
}
|
||||
queries := db.New(sqlDatabase)
|
||||
broker := pubsub.NewBroker[File]()
|
||||
|
||||
globalHistoryService = &service{
|
||||
db: queries,
|
||||
sqlDB: sqlDatabase,
|
||||
broker: broker,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetService() Service {
|
||||
if globalHistoryService == nil {
|
||||
panic("history service not initialized. Call history.InitService() first.")
|
||||
}
|
||||
return globalHistoryService
|
||||
}
|
||||
|
||||
func (s *service) Create(ctx context.Context, sessionID, path, content string) (File, error) {
|
||||
return s.createWithVersion(ctx, sessionID, path, content, InitialVersion, EventFileCreated)
|
||||
}
|
||||
|
||||
func (s *service) CreateVersion(ctx context.Context, sessionID, path, content string) (File, error) {
|
||||
s.mu.RLock()
|
||||
files, err := s.db.ListFilesByPath(ctx, path)
|
||||
s.mu.RUnlock()
|
||||
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return File{}, fmt.Errorf("db.ListFilesByPath for next version: %w", err)
|
||||
}
|
||||
|
||||
latestVersionNumber := 0
|
||||
if len(files) > 0 {
|
||||
// Sort to be absolutely sure about the latest version globally for this path
|
||||
slices.SortFunc(files, func(a, b db.File) int {
|
||||
if strings.HasPrefix(a.Version, "v") && strings.HasPrefix(b.Version, "v") {
|
||||
vA, _ := strconv.Atoi(a.Version[1:])
|
||||
vB, _ := strconv.Atoi(b.Version[1:])
|
||||
return vB - vA // Descending to get latest first
|
||||
}
|
||||
if a.Version == InitialVersion && b.Version != InitialVersion {
|
||||
return 1 // initial comes after vX
|
||||
}
|
||||
if b.Version == InitialVersion && a.Version != InitialVersion {
|
||||
return -1
|
||||
}
|
||||
// Compare timestamps as strings (ISO format sorts correctly)
|
||||
if b.CreatedAt > a.CreatedAt {
|
||||
return 1
|
||||
} else if a.CreatedAt > b.CreatedAt {
|
||||
return -1
|
||||
}
|
||||
return 0 // Equal timestamps
|
||||
})
|
||||
|
||||
latestFile := files[0]
|
||||
if strings.HasPrefix(latestFile.Version, "v") {
|
||||
vNum, parseErr := strconv.Atoi(latestFile.Version[1:])
|
||||
if parseErr == nil {
|
||||
latestVersionNumber = vNum
|
||||
}
|
||||
}
|
||||
}
|
||||
nextVersionStr := fmt.Sprintf("v%d", latestVersionNumber+1)
|
||||
return s.createWithVersion(ctx, sessionID, path, content, nextVersionStr, EventFileVersionCreated)
|
||||
}
|
||||
|
||||
func (s *service) createWithVersion(ctx context.Context, sessionID, path, content, version string, eventType pubsub.EventType) (File, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
const maxRetries = 3
|
||||
var file File
|
||||
var err error
|
||||
|
||||
for attempt := range maxRetries {
|
||||
tx, txErr := s.sqlDB.BeginTx(ctx, nil)
|
||||
if txErr != nil {
|
||||
return File{}, fmt.Errorf("failed to begin transaction: %w", txErr)
|
||||
}
|
||||
qtx := s.db.WithTx(tx)
|
||||
|
||||
dbFile, createErr := qtx.CreateFile(ctx, db.CreateFileParams{
|
||||
ID: uuid.New().String(),
|
||||
SessionID: sessionID,
|
||||
Path: path,
|
||||
Content: content,
|
||||
Version: version,
|
||||
})
|
||||
|
||||
if createErr != nil {
|
||||
if rbErr := tx.Rollback(); rbErr != nil {
|
||||
slog.Error("Failed to rollback transaction on create error", "error", rbErr)
|
||||
}
|
||||
if strings.Contains(createErr.Error(), "UNIQUE constraint failed: files.path, files.session_id, files.version") {
|
||||
if attempt < maxRetries-1 {
|
||||
slog.Warn("Unique constraint violation for file version, retrying with incremented version", "path", path, "session", sessionID, "attempted_version", version, "attempt", attempt+1)
|
||||
// Increment version string like v1, v2, v3...
|
||||
if strings.HasPrefix(version, "v") {
|
||||
numPart := version[1:]
|
||||
num, parseErr := strconv.Atoi(numPart)
|
||||
if parseErr == nil {
|
||||
version = fmt.Sprintf("v%d", num+1)
|
||||
continue // Retry with new version
|
||||
}
|
||||
}
|
||||
// Fallback if version is not "vX" or parsing failed
|
||||
version = fmt.Sprintf("%s-retry%d", version, attempt+1)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return File{}, fmt.Errorf("db.CreateFile within transaction: %w", createErr)
|
||||
}
|
||||
|
||||
if commitErr := tx.Commit(); commitErr != nil {
|
||||
return File{}, fmt.Errorf("failed to commit transaction: %w", commitErr)
|
||||
}
|
||||
|
||||
file = s.fromDBItem(dbFile)
|
||||
s.broker.Publish(eventType, file)
|
||||
return file, nil // Success
|
||||
}
|
||||
|
||||
return File{}, fmt.Errorf("failed to create file after %d retries due to version conflicts: %w", maxRetries, err)
|
||||
}
|
||||
|
||||
func (s *service) Get(ctx context.Context, id string) (File, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
dbFile, err := s.db.GetFile(ctx, id)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return File{}, fmt.Errorf("file with ID '%s' not found", id)
|
||||
}
|
||||
return File{}, fmt.Errorf("db.GetFile: %w", err)
|
||||
}
|
||||
return s.fromDBItem(dbFile), nil
|
||||
}
|
||||
|
||||
func (s *service) GetByPathAndVersion(ctx context.Context, sessionID, path, version string) (File, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// sqlc doesn't directly support GetyByPathAndVersionAndSession
|
||||
// We list and filter. This could be optimized with a custom query if performance is an issue.
|
||||
allFilesForPath, err := s.db.ListFilesByPath(ctx, path)
|
||||
if err != nil {
|
||||
return File{}, fmt.Errorf("db.ListFilesByPath for GetByPathAndVersion: %w", err)
|
||||
}
|
||||
|
||||
for _, dbFile := range allFilesForPath {
|
||||
if dbFile.SessionID == sessionID && dbFile.Version == version {
|
||||
return s.fromDBItem(dbFile), nil
|
||||
}
|
||||
}
|
||||
return File{}, fmt.Errorf("file not found for session '%s', path '%s', version '%s'", sessionID, path, version)
|
||||
}
|
||||
|
||||
func (s *service) GetLatestByPathAndSession(ctx context.Context, path, sessionID string) (File, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
// GetFileByPathAndSession in sqlc already orders by created_at DESC and takes LIMIT 1
|
||||
dbFile, err := s.db.GetFileByPathAndSession(ctx, db.GetFileByPathAndSessionParams{
|
||||
Path: path,
|
||||
SessionID: sessionID,
|
||||
})
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return File{}, fmt.Errorf("no file found for path '%s' in session '%s'", path, sessionID)
|
||||
}
|
||||
return File{}, fmt.Errorf("db.GetFileByPathAndSession: %w", err)
|
||||
}
|
||||
return s.fromDBItem(dbFile), nil
|
||||
}
|
||||
|
||||
func (s *service) ListBySession(ctx context.Context, sessionID string) ([]File, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
dbFiles, err := s.db.ListFilesBySession(ctx, sessionID) // Assumes this orders by created_at ASC
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.ListFilesBySession: %w", err)
|
||||
}
|
||||
files := make([]File, len(dbFiles))
|
||||
for i, dbF := range dbFiles {
|
||||
files[i] = s.fromDBItem(dbF)
|
||||
}
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (s *service) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
dbFiles, err := s.db.ListLatestSessionFiles(ctx, sessionID) // Uses the specific sqlc query
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.ListLatestSessionFiles: %w", err)
|
||||
}
|
||||
files := make([]File, len(dbFiles))
|
||||
for i, dbF := range dbFiles {
|
||||
files[i] = s.fromDBItem(dbF)
|
||||
}
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (s *service) ListVersionsByPath(ctx context.Context, path string) ([]File, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
dbFiles, err := s.db.ListFilesByPath(ctx, path) // sqlc query orders by created_at DESC
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.ListFilesByPath: %w", err)
|
||||
}
|
||||
files := make([]File, len(dbFiles))
|
||||
for i, dbF := range dbFiles {
|
||||
files[i] = s.fromDBItem(dbF)
|
||||
}
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (s *service) Update(ctx context.Context, file File) (File, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if file.ID == "" {
|
||||
return File{}, fmt.Errorf("cannot update file with empty ID")
|
||||
}
|
||||
// UpdatedAt is handled by DB trigger
|
||||
dbFile, err := s.db.UpdateFile(ctx, db.UpdateFileParams{
|
||||
ID: file.ID,
|
||||
Content: file.Content,
|
||||
Version: file.Version,
|
||||
})
|
||||
if err != nil {
|
||||
return File{}, fmt.Errorf("db.UpdateFile: %w", err)
|
||||
}
|
||||
updatedFile := s.fromDBItem(dbFile)
|
||||
s.broker.Publish(EventFileUpdated, updatedFile)
|
||||
return updatedFile, nil
|
||||
}
|
||||
|
||||
func (s *service) Delete(ctx context.Context, id string) error {
|
||||
s.mu.Lock()
|
||||
fileToPublish, err := s.getServiceForPublish(ctx, id) // Use internal method with appropriate locking
|
||||
s.mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
slog.Warn("Attempted to delete non-existent file history", "id", id)
|
||||
return nil // Or return specific error if needed
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
err = s.db.DeleteFile(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db.DeleteFile: %w", err)
|
||||
}
|
||||
if fileToPublish != nil {
|
||||
s.broker.Publish(EventFileDeleted, *fileToPublish)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) getServiceForPublish(ctx context.Context, id string) (*File, error) {
|
||||
// Assumes outer lock is NOT held or caller manages it.
|
||||
// For GetFile, it has its own RLock.
|
||||
dbFile, err := s.db.GetFile(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
file := s.fromDBItem(dbFile)
|
||||
return &file, nil
|
||||
}
|
||||
|
||||
func (s *service) DeleteSessionFiles(ctx context.Context, sessionID string) error {
|
||||
s.mu.Lock() // Lock for the entire operation
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Get files first for publishing events
|
||||
filesToDelete, err := s.db.ListFilesBySession(ctx, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db.ListFilesBySession for deletion: %w", err)
|
||||
}
|
||||
|
||||
err = s.db.DeleteSessionFiles(ctx, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db.DeleteSessionFiles: %w", err)
|
||||
}
|
||||
|
||||
for _, dbFile := range filesToDelete {
|
||||
file := s.fromDBItem(dbFile)
|
||||
s.broker.Publish(EventFileDeleted, file) // Individual delete events
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[File] {
|
||||
return s.broker.Subscribe(ctx)
|
||||
}
|
||||
|
||||
func (s *service) fromDBItem(item db.File) File {
|
||||
// Parse timestamps from ISO strings
|
||||
createdAt, err := time.Parse(time.RFC3339Nano, item.CreatedAt)
|
||||
if err != nil {
|
||||
slog.Error("Failed to parse created_at", "value", item.CreatedAt, "error", err)
|
||||
createdAt = time.Now() // Fallback
|
||||
}
|
||||
|
||||
updatedAt, err := time.Parse(time.RFC3339Nano, item.UpdatedAt)
|
||||
if err != nil {
|
||||
slog.Error("Failed to parse created_at", "value", item.CreatedAt, "error", err)
|
||||
updatedAt = time.Now() // Fallback
|
||||
}
|
||||
|
||||
return File{
|
||||
ID: item.ID,
|
||||
SessionID: item.SessionID,
|
||||
Path: item.Path,
|
||||
Content: item.Content,
|
||||
Version: item.Version,
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: updatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func Create(ctx context.Context, sessionID, path, content string) (File, error) {
|
||||
return GetService().Create(ctx, sessionID, path, content)
|
||||
}
|
||||
|
||||
func CreateVersion(ctx context.Context, sessionID, path, content string) (File, error) {
|
||||
return GetService().CreateVersion(ctx, sessionID, path, content)
|
||||
}
|
||||
|
||||
func Get(ctx context.Context, id string) (File, error) {
|
||||
return GetService().Get(ctx, id)
|
||||
}
|
||||
|
||||
func GetByPathAndVersion(ctx context.Context, sessionID, path, version string) (File, error) {
|
||||
return GetService().GetByPathAndVersion(ctx, sessionID, path, version)
|
||||
}
|
||||
|
||||
func GetLatestByPathAndSession(ctx context.Context, path, sessionID string) (File, error) {
|
||||
return GetService().GetLatestByPathAndSession(ctx, path, sessionID)
|
||||
}
|
||||
|
||||
func ListBySession(ctx context.Context, sessionID string) ([]File, error) {
|
||||
return GetService().ListBySession(ctx, sessionID)
|
||||
}
|
||||
|
||||
func ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) {
|
||||
return GetService().ListLatestSessionFiles(ctx, sessionID)
|
||||
}
|
||||
|
||||
func ListVersionsByPath(ctx context.Context, path string) ([]File, error) {
|
||||
return GetService().ListVersionsByPath(ctx, path)
|
||||
}
|
||||
|
||||
func Update(ctx context.Context, file File) (File, error) {
|
||||
return GetService().Update(ctx, file)
|
||||
}
|
||||
|
||||
func Delete(ctx context.Context, id string) error {
|
||||
return GetService().Delete(ctx, id)
|
||||
}
|
||||
|
||||
func DeleteSessionFiles(ctx context.Context, sessionID string) error {
|
||||
return GetService().DeleteSessionFiles(ctx, sessionID)
|
||||
}
|
||||
|
||||
func Subscribe(ctx context.Context) <-chan pubsub.Event[File] {
|
||||
return GetService().Subscribe(ctx)
|
||||
}
|
||||
@@ -5,11 +5,11 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/llm/tools"
|
||||
"github.com/opencode-ai/opencode/internal/lsp"
|
||||
"github.com/opencode-ai/opencode/internal/message"
|
||||
"github.com/opencode-ai/opencode/internal/session"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/session"
|
||||
)
|
||||
|
||||
type agentTool struct {
|
||||
@@ -88,10 +88,8 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
|
||||
}
|
||||
|
||||
parentSession.Cost += updatedSession.Cost
|
||||
parentSession.PromptTokens += updatedSession.PromptTokens
|
||||
parentSession.CompletionTokens += updatedSession.CompletionTokens
|
||||
|
||||
_, err = b.sessions.Save(ctx, parentSession)
|
||||
_, err = b.sessions.Update(ctx, parentSession)
|
||||
if err != nil {
|
||||
return tools.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
|
||||
}
|
||||
|
||||
@@ -6,16 +6,20 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/llm/models"
|
||||
"github.com/opencode-ai/opencode/internal/llm/prompt"
|
||||
"github.com/opencode-ai/opencode/internal/llm/provider"
|
||||
"github.com/opencode-ai/opencode/internal/llm/tools"
|
||||
"github.com/opencode-ai/opencode/internal/logging"
|
||||
"github.com/opencode-ai/opencode/internal/message"
|
||||
"github.com/opencode-ai/opencode/internal/permission"
|
||||
"github.com/opencode-ai/opencode/internal/session"
|
||||
"log/slog"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
"github.com/sst/opencode/internal/llm/prompt"
|
||||
"github.com/sst/opencode/internal/llm/provider"
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/logging"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/permission"
|
||||
"github.com/sst/opencode/internal/session"
|
||||
"github.com/sst/opencode/internal/status"
|
||||
)
|
||||
|
||||
// Common errors
|
||||
@@ -38,11 +42,14 @@ func (e *AgentEvent) Response() message.Message {
|
||||
}
|
||||
|
||||
type Service interface {
|
||||
Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error)
|
||||
Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
|
||||
Cancel(sessionID string)
|
||||
IsSessionBusy(sessionID string) bool
|
||||
IsBusy() bool
|
||||
Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
|
||||
CompactSession(ctx context.Context, sessionID string, force bool) error
|
||||
GetUsage(ctx context.Context, sessionID string) (*int64, error)
|
||||
EstimateContextWindowUsage(ctx context.Context, sessionID string) (float64, bool, error)
|
||||
}
|
||||
|
||||
type agent struct {
|
||||
@@ -68,8 +75,8 @@ func NewAgent(
|
||||
return nil, err
|
||||
}
|
||||
var titleProvider provider.Provider
|
||||
// Only generate titles for the coder agent
|
||||
if agentName == config.AgentCoder {
|
||||
// Only generate titles for the primary agent
|
||||
if agentName == config.AgentPrimary {
|
||||
titleProvider, err = createAgentProvider(config.AgentTitle)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -91,7 +98,7 @@ func NewAgent(
|
||||
func (a *agent) Cancel(sessionID string) {
|
||||
if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
|
||||
if cancel, ok := cancelFunc.(context.CancelFunc); ok {
|
||||
logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
|
||||
status.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
@@ -117,6 +124,9 @@ func (a *agent) IsSessionBusy(sessionID string) bool {
|
||||
}
|
||||
|
||||
func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
|
||||
if content == "" {
|
||||
return nil
|
||||
}
|
||||
if a.titleProvider == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -124,16 +134,13 @@ func (a *agent) generateTitle(ctx context.Context, sessionID string, content str
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
parts := []message.ContentPart{message.TextContent{Text: content}}
|
||||
response, err := a.titleProvider.SendMessages(
|
||||
ctx,
|
||||
[]message.Message{
|
||||
{
|
||||
Role: message.User,
|
||||
Parts: []message.ContentPart{
|
||||
message.TextContent{
|
||||
Text: content,
|
||||
},
|
||||
},
|
||||
Role: message.User,
|
||||
Parts: parts,
|
||||
},
|
||||
},
|
||||
make([]tools.BaseTool, 0),
|
||||
@@ -148,7 +155,7 @@ func (a *agent) generateTitle(ctx context.Context, sessionID string, content str
|
||||
}
|
||||
|
||||
session.Title = title
|
||||
_, err = a.sessions.Save(ctx, session)
|
||||
_, err = a.sessions.Update(ctx, session)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -158,7 +165,10 @@ func (a *agent) err(err error) AgentEvent {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) {
|
||||
func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
|
||||
if !a.provider.Model().SupportsAttachments && attachments != nil {
|
||||
attachments = nil
|
||||
}
|
||||
events := make(chan AgentEvent)
|
||||
if a.IsSessionBusy(sessionID) {
|
||||
return nil, ErrSessionBusy
|
||||
@@ -168,49 +178,98 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-ch
|
||||
|
||||
a.activeRequests.Store(sessionID, cancel)
|
||||
go func() {
|
||||
logging.Debug("Request started", "sessionID", sessionID)
|
||||
slog.Debug("Request started", "sessionID", sessionID)
|
||||
defer logging.RecoverPanic("agent.Run", func() {
|
||||
events <- a.err(fmt.Errorf("panic while running the agent"))
|
||||
})
|
||||
|
||||
result := a.processGeneration(genCtx, sessionID, content)
|
||||
if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
|
||||
logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, result))
|
||||
var attachmentParts []message.ContentPart
|
||||
for _, attachment := range attachments {
|
||||
attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
|
||||
}
|
||||
logging.Debug("Request completed", "sessionID", sessionID)
|
||||
result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
|
||||
if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
|
||||
status.Error(result.Err().Error())
|
||||
}
|
||||
slog.Debug("Request completed", "sessionID", sessionID)
|
||||
a.activeRequests.Delete(sessionID)
|
||||
cancel()
|
||||
events <- result
|
||||
close(events)
|
||||
}()
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
func (a *agent) processGeneration(ctx context.Context, sessionID, content string) AgentEvent {
|
||||
// List existing messages; if none, start title generation asynchronously.
|
||||
msgs, err := a.messages.List(ctx, sessionID)
|
||||
func (a *agent) prepareMessageHistory(ctx context.Context, sessionID string) (session.Session, []message.Message, error) {
|
||||
currentSession, err := a.sessions.Get(ctx, sessionID)
|
||||
if err != nil {
|
||||
return a.err(fmt.Errorf("failed to list messages: %w", err))
|
||||
}
|
||||
if len(msgs) == 0 {
|
||||
go func() {
|
||||
defer logging.RecoverPanic("agent.Run", func() {
|
||||
logging.ErrorPersist("panic while generating title")
|
||||
})
|
||||
titleErr := a.generateTitle(context.Background(), sessionID, content)
|
||||
if titleErr != nil {
|
||||
logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
|
||||
}
|
||||
}()
|
||||
return currentSession, nil, fmt.Errorf("failed to get session: %w", err)
|
||||
}
|
||||
|
||||
userMsg, err := a.createUserMessage(ctx, sessionID, content)
|
||||
var sessionMessages []message.Message
|
||||
if currentSession.Summary != "" && !currentSession.SummarizedAt.IsZero() {
|
||||
// If summary exists, only fetch messages after the summarization timestamp
|
||||
sessionMessages, err = a.messages.ListAfter(ctx, sessionID, currentSession.SummarizedAt)
|
||||
if err != nil {
|
||||
return currentSession, nil, fmt.Errorf("failed to list messages after summary: %w", err)
|
||||
}
|
||||
} else {
|
||||
// If no summary, fetch all messages
|
||||
sessionMessages, err = a.messages.List(ctx, sessionID)
|
||||
if err != nil {
|
||||
return currentSession, nil, fmt.Errorf("failed to list messages: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var messages []message.Message
|
||||
if currentSession.Summary != "" && !currentSession.SummarizedAt.IsZero() {
|
||||
// If summary exists, create a temporary message for the summary
|
||||
summaryMessage := message.Message{
|
||||
Role: message.Assistant,
|
||||
Parts: []message.ContentPart{
|
||||
message.TextContent{Text: currentSession.Summary},
|
||||
},
|
||||
}
|
||||
// Start with the summary, then add messages after the summary timestamp
|
||||
messages = append([]message.Message{summaryMessage}, sessionMessages...)
|
||||
} else {
|
||||
// If no summary, just use all messages
|
||||
messages = sessionMessages
|
||||
}
|
||||
|
||||
return currentSession, messages, nil
|
||||
}
|
||||
|
||||
func (a *agent) triggerTitleGeneration(sessionID string, content string) {
|
||||
go func() {
|
||||
defer logging.RecoverPanic("agent.Run", func() {
|
||||
status.Error("panic while generating title")
|
||||
})
|
||||
titleErr := a.generateTitle(context.Background(), sessionID, content)
|
||||
if titleErr != nil {
|
||||
status.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
|
||||
currentSession, sessionMessages, err := a.prepareMessageHistory(ctx, sessionID)
|
||||
if err != nil {
|
||||
return a.err(err)
|
||||
}
|
||||
|
||||
// If this is a new session, start title generation asynchronously
|
||||
if len(sessionMessages) == 0 && currentSession.Summary == "" {
|
||||
a.triggerTitleGeneration(sessionID, content)
|
||||
}
|
||||
|
||||
userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
|
||||
if err != nil {
|
||||
return a.err(fmt.Errorf("failed to create user message: %w", err))
|
||||
}
|
||||
|
||||
// Append the new user message to the conversation history.
|
||||
msgHistory := append(msgs, userMsg)
|
||||
messages := append(sessionMessages, userMsg)
|
||||
|
||||
for {
|
||||
// Check for cancellation before each iteration
|
||||
select {
|
||||
@@ -219,7 +278,42 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
|
||||
default:
|
||||
// Continue processing
|
||||
}
|
||||
agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
|
||||
|
||||
// Check if auto-compaction is needed before calling the provider
|
||||
usagePercentage, needsCompaction, errEstimate := a.EstimateContextWindowUsage(ctx, sessionID)
|
||||
if errEstimate != nil {
|
||||
slog.Warn("Failed to estimate context window usage for auto-compaction", "error", errEstimate, "sessionID", sessionID)
|
||||
} else if needsCompaction {
|
||||
status.Info(fmt.Sprintf("Context window usage is at %.2f%%. Auto-compacting conversation...", usagePercentage))
|
||||
|
||||
// Run compaction synchronously
|
||||
compactCtx, cancelCompact := context.WithTimeout(ctx, 30*time.Second) // Use appropriate context
|
||||
errCompact := a.CompactSession(compactCtx, sessionID, true)
|
||||
cancelCompact()
|
||||
|
||||
if errCompact != nil {
|
||||
status.Warn(fmt.Sprintf("Auto-compaction failed: %v. Context window usage may continue to grow.", errCompact))
|
||||
} else {
|
||||
status.Info("Auto-compaction completed successfully.")
|
||||
// After compaction, message history needs to be re-prepared.
|
||||
// The 'messages' slice needs to be updated with the new summary and subsequent messages,
|
||||
// ensuring the latest user message is correctly appended.
|
||||
_, sessionMessagesFromCompact, errPrepare := a.prepareMessageHistory(ctx, sessionID)
|
||||
if errPrepare != nil {
|
||||
return a.err(fmt.Errorf("failed to re-prepare message history after compaction: %w", errPrepare))
|
||||
}
|
||||
messages = sessionMessagesFromCompact
|
||||
|
||||
// Ensure the user message that triggered this cycle is the last one.
|
||||
// 'userMsg' was created before this loop using a.createUserMessage.
|
||||
// It should be appended to the 'messages' slice if it's not already the last element.
|
||||
if len(messages) == 0 || (len(messages) > 0 && messages[len(messages)-1].ID != userMsg.ID) {
|
||||
messages = append(messages, userMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, messages)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
agentMessage.AddFinish(message.FinishReasonCanceled)
|
||||
@@ -228,10 +322,10 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
|
||||
}
|
||||
return a.err(fmt.Errorf("failed to process events: %w", err))
|
||||
}
|
||||
logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
|
||||
slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
|
||||
if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
|
||||
// We are not done, we need to respond with the tool response
|
||||
msgHistory = append(msgHistory, agentMessage, *toolResults)
|
||||
messages = append(messages, agentMessage, *toolResults)
|
||||
continue
|
||||
}
|
||||
return AgentEvent{
|
||||
@@ -240,15 +334,36 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
|
||||
}
|
||||
}
|
||||
|
||||
func (a *agent) createUserMessage(ctx context.Context, sessionID, content string) (message.Message, error) {
|
||||
func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
|
||||
parts := []message.ContentPart{message.TextContent{Text: content}}
|
||||
parts = append(parts, attachmentParts...)
|
||||
return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
|
||||
Role: message.User,
|
||||
Parts: []message.ContentPart{
|
||||
message.TextContent{Text: content},
|
||||
},
|
||||
Role: message.User,
|
||||
Parts: parts,
|
||||
})
|
||||
}
|
||||
|
||||
func (a *agent) createToolResponseMessage(ctx context.Context, sessionID string, toolResults []message.ToolResult) (*message.Message, error) {
|
||||
if len(toolResults) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
parts := make([]message.ContentPart, 0, len(toolResults))
|
||||
for _, tr := range toolResults {
|
||||
parts = append(parts, tr)
|
||||
}
|
||||
|
||||
msg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
|
||||
Role: message.Tool,
|
||||
Parts: parts,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create tool response message: %w", err)
|
||||
}
|
||||
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
|
||||
eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
|
||||
|
||||
@@ -277,12 +392,37 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
|
||||
}
|
||||
}
|
||||
|
||||
toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
|
||||
toolCalls := assistantMsg.ToolCalls()
|
||||
// If the assistant wants to use tools, execute them
|
||||
if assistantMsg.FinishReason() == message.FinishReasonToolUse {
|
||||
toolCalls := assistantMsg.ToolCalls()
|
||||
if len(toolCalls) > 0 {
|
||||
toolResults, err := a.executeToolCalls(ctx, toolCalls)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
|
||||
}
|
||||
return assistantMsg, nil, err
|
||||
}
|
||||
|
||||
// Create a message with the tool results
|
||||
toolResponseMsg, err := a.createToolResponseMessage(ctx, sessionID, toolResults)
|
||||
if err != nil {
|
||||
return assistantMsg, nil, err
|
||||
}
|
||||
|
||||
return assistantMsg, toolResponseMsg, nil
|
||||
}
|
||||
}
|
||||
|
||||
return assistantMsg, nil, nil
|
||||
}
|
||||
|
||||
func (a *agent) executeToolCalls(ctx context.Context, toolCalls []message.ToolCall) ([]message.ToolResult, error) {
|
||||
toolResults := make([]message.ToolResult, len(toolCalls))
|
||||
|
||||
for i, toolCall := range toolCalls {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
|
||||
// Make all future tool calls cancelled
|
||||
for j := i; j < len(toolCalls); j++ {
|
||||
toolResults[j] = message.ToolResult{
|
||||
@@ -291,7 +431,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
|
||||
IsError: true,
|
||||
}
|
||||
}
|
||||
goto out
|
||||
return toolResults, ctx.Err()
|
||||
default:
|
||||
// Continue processing
|
||||
var tool tools.BaseTool
|
||||
@@ -316,6 +456,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
|
||||
Name: toolCall.Name,
|
||||
Input: toolCall.Input,
|
||||
})
|
||||
|
||||
if toolErr != nil {
|
||||
if errors.Is(toolErr, permission.ErrorPermissionDenied) {
|
||||
toolResults[i] = message.ToolResult{
|
||||
@@ -323,6 +464,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
|
||||
Content: "Permission denied",
|
||||
IsError: true,
|
||||
}
|
||||
// Cancel all remaining tool calls if permission is denied
|
||||
for j := i + 1; j < len(toolCalls); j++ {
|
||||
toolResults[j] = message.ToolResult{
|
||||
ToolCallID: toolCalls[j].ID,
|
||||
@@ -330,10 +472,18 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
|
||||
IsError: true,
|
||||
}
|
||||
}
|
||||
a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
|
||||
break
|
||||
return toolResults, nil
|
||||
}
|
||||
|
||||
// Handle other errors
|
||||
toolResults[i] = message.ToolResult{
|
||||
ToolCallID: toolCall.ID,
|
||||
Content: toolErr.Error(),
|
||||
IsError: true,
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
toolResults[i] = message.ToolResult{
|
||||
ToolCallID: toolCall.ID,
|
||||
Content: toolResult.Content,
|
||||
@@ -342,28 +492,13 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
|
||||
}
|
||||
}
|
||||
}
|
||||
out:
|
||||
if len(toolResults) == 0 {
|
||||
return assistantMsg, nil, nil
|
||||
}
|
||||
parts := make([]message.ContentPart, 0)
|
||||
for _, tr := range toolResults {
|
||||
parts = append(parts, tr)
|
||||
}
|
||||
msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
|
||||
Role: message.Tool,
|
||||
Parts: parts,
|
||||
})
|
||||
if err != nil {
|
||||
return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
|
||||
}
|
||||
|
||||
return assistantMsg, &msg, err
|
||||
return toolResults, nil
|
||||
}
|
||||
|
||||
func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
|
||||
msg.AddFinish(finishReson)
|
||||
_ = a.messages.Update(ctx, *msg)
|
||||
_, _ = a.messages.Update(ctx, *msg)
|
||||
}
|
||||
|
||||
func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
|
||||
@@ -371,19 +506,22 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
// Continue processing.
|
||||
// Continue processing
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case provider.EventThinkingDelta:
|
||||
assistantMsg.AppendReasoningContent(event.Content)
|
||||
return a.messages.Update(ctx, *assistantMsg)
|
||||
_, err := a.messages.Update(ctx, *assistantMsg)
|
||||
return err
|
||||
case provider.EventContentDelta:
|
||||
assistantMsg.AppendContent(event.Content)
|
||||
return a.messages.Update(ctx, *assistantMsg)
|
||||
_, err := a.messages.Update(ctx, *assistantMsg)
|
||||
return err
|
||||
case provider.EventToolUseStart:
|
||||
assistantMsg.AddToolCall(*event.ToolCall)
|
||||
return a.messages.Update(ctx, *assistantMsg)
|
||||
_, err := a.messages.Update(ctx, *assistantMsg)
|
||||
return err
|
||||
// TODO: see how to handle this
|
||||
// case provider.EventToolUseDelta:
|
||||
// tm := time.Unix(assistantMsg.UpdatedAt, 0)
|
||||
@@ -395,18 +533,19 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
|
||||
// }
|
||||
case provider.EventToolUseStop:
|
||||
assistantMsg.FinishToolCall(event.ToolCall.ID)
|
||||
return a.messages.Update(ctx, *assistantMsg)
|
||||
_, err := a.messages.Update(ctx, *assistantMsg)
|
||||
return err
|
||||
case provider.EventError:
|
||||
if errors.Is(event.Error, context.Canceled) {
|
||||
logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
|
||||
status.Info(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
|
||||
return context.Canceled
|
||||
}
|
||||
logging.ErrorPersist(event.Error.Error())
|
||||
status.Error(event.Error.Error())
|
||||
return event.Error
|
||||
case provider.EventComplete:
|
||||
assistantMsg.SetToolCalls(event.Response.ToolCalls)
|
||||
assistantMsg.AddFinish(event.Response.FinishReason)
|
||||
if err := a.messages.Update(ctx, *assistantMsg); err != nil {
|
||||
if _, err := a.messages.Update(ctx, *assistantMsg); err != nil {
|
||||
return fmt.Errorf("failed to update message: %w", err)
|
||||
}
|
||||
return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
|
||||
@@ -415,6 +554,49 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *agent) GetUsage(ctx context.Context, sessionID string) (*int64, error) {
|
||||
session, err := a.sessions.Get(ctx, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get session: %w", err)
|
||||
}
|
||||
|
||||
usage := session.PromptTokens + session.CompletionTokens
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
func (a *agent) EstimateContextWindowUsage(ctx context.Context, sessionID string) (float64, bool, error) {
|
||||
session, err := a.sessions.Get(ctx, sessionID)
|
||||
if err != nil {
|
||||
return 0, false, fmt.Errorf("failed to get session: %w", err)
|
||||
}
|
||||
|
||||
// Get the model's context window size
|
||||
model := a.provider.Model()
|
||||
contextWindow := model.ContextWindow
|
||||
if contextWindow <= 0 {
|
||||
// Default to a reasonable size if not specified
|
||||
contextWindow = 100000
|
||||
}
|
||||
|
||||
// Calculate current token usage
|
||||
currentTokens := session.PromptTokens + session.CompletionTokens
|
||||
|
||||
// Get the max tokens setting for the agent
|
||||
maxTokens := a.provider.MaxTokens()
|
||||
|
||||
// Calculate percentage of context window used
|
||||
usagePercentage := float64(currentTokens) / float64(contextWindow)
|
||||
|
||||
// Check if we need to auto-compact
|
||||
// Auto-compact when:
|
||||
// 1. Usage exceeds 90% of context window, OR
|
||||
// 2. Current usage + maxTokens would exceed 100% of context window
|
||||
needsCompaction := usagePercentage >= 0.9 ||
|
||||
float64(currentTokens+maxTokens) > float64(contextWindow)
|
||||
|
||||
return usagePercentage * 100, needsCompaction, nil
|
||||
}
|
||||
|
||||
func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
|
||||
sess, err := a.sessions.Get(ctx, sessionID)
|
||||
if err != nil {
|
||||
@@ -427,10 +609,10 @@ func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.M
|
||||
model.CostPer1MOut/1e6*float64(usage.OutputTokens)
|
||||
|
||||
sess.Cost += cost
|
||||
sess.CompletionTokens += usage.OutputTokens
|
||||
sess.PromptTokens += usage.InputTokens
|
||||
sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
|
||||
sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
|
||||
|
||||
_, err = a.sessions.Save(ctx, sess)
|
||||
_, err = a.sessions.Update(ctx, sess)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save session: %w", err)
|
||||
}
|
||||
@@ -456,6 +638,127 @@ func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (mode
|
||||
return a.provider.Model(), nil
|
||||
}
|
||||
|
||||
func (a *agent) CompactSession(ctx context.Context, sessionID string, force bool) error {
|
||||
// Check if the session is busy
|
||||
if a.IsSessionBusy(sessionID) && !force {
|
||||
return ErrSessionBusy
|
||||
}
|
||||
|
||||
// Create a cancellable context
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// Mark the session as busy during compaction
|
||||
compactionCancelFunc := func() {}
|
||||
a.activeRequests.Store(sessionID+"-compact", compactionCancelFunc)
|
||||
defer a.activeRequests.Delete(sessionID + "-compact")
|
||||
|
||||
// Fetch the session
|
||||
session, err := a.sessions.Get(ctx, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get session: %w", err)
|
||||
}
|
||||
|
||||
// Fetch all messages for the session
|
||||
sessionMessages, err := a.messages.List(ctx, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list messages: %w", err)
|
||||
}
|
||||
|
||||
var existingSummary string
|
||||
if session.Summary != "" && !session.SummarizedAt.IsZero() {
|
||||
// Filter messages that were created after the last summarization
|
||||
var newMessages []message.Message
|
||||
for _, msg := range sessionMessages {
|
||||
if msg.CreatedAt.After(session.SummarizedAt) {
|
||||
newMessages = append(newMessages, msg)
|
||||
}
|
||||
}
|
||||
sessionMessages = newMessages
|
||||
existingSummary = session.Summary
|
||||
}
|
||||
|
||||
// If there are no messages to summarize and no existing summary, return early
|
||||
if len(sessionMessages) == 0 && existingSummary == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
messages := []message.Message{
|
||||
message.Message{
|
||||
Role: message.System,
|
||||
Parts: []message.ContentPart{
|
||||
message.TextContent{
|
||||
Text: `You are a helpful AI assistant tasked with summarizing conversations.
|
||||
|
||||
When asked to summarize, provide a detailed but concise summary of the conversation.
|
||||
Focus on information that would be helpful for continuing the conversation, including:
|
||||
- What was done
|
||||
- What is currently being worked on
|
||||
- Which files are being modified
|
||||
- What needs to be done next
|
||||
|
||||
Your summary should be comprehensive enough to provide context but concise enough to be quickly understood.`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// If there's an existing summary, include it
|
||||
if existingSummary != "" {
|
||||
messages = append(messages, message.Message{
|
||||
Role: message.Assistant,
|
||||
Parts: []message.ContentPart{
|
||||
message.TextContent{
|
||||
Text: existingSummary,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Add all messages since the last summarized message
|
||||
messages = append(messages, sessionMessages...)
|
||||
|
||||
// Add a final user message requesting the summary
|
||||
messages = append(messages, message.Message{
|
||||
Role: message.User,
|
||||
Parts: []message.ContentPart{
|
||||
message.TextContent{
|
||||
Text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// Call provider to get the summary
|
||||
response, err := a.provider.SendMessages(ctx, messages, a.tools)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get summary from the assistant: %w", err)
|
||||
}
|
||||
|
||||
// Extract the summary text
|
||||
summaryText := strings.TrimSpace(response.Content)
|
||||
if summaryText == "" {
|
||||
return fmt.Errorf("received empty summary from the assistant")
|
||||
}
|
||||
|
||||
// Update the session with the new summary
|
||||
session.Summary = summaryText
|
||||
session.SummarizedAt = time.Now()
|
||||
|
||||
// Save the updated session
|
||||
_, err = a.sessions.Update(ctx, session)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save session with summary: %w", err)
|
||||
}
|
||||
|
||||
// Track token usage
|
||||
err = a.TrackUsage(ctx, sessionID, a.provider.Model(), response.Usage)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to track usage: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
|
||||
cfg := config.Get()
|
||||
agentConfig, ok := cfg.Agents[agentName]
|
||||
@@ -491,7 +794,7 @@ func createAgentProvider(agentName config.AgentName) (provider.Provider, error)
|
||||
provider.WithReasoningEffort(agentConfig.ReasoningEffort),
|
||||
),
|
||||
)
|
||||
} else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder {
|
||||
} else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentPrimary {
|
||||
opts = append(
|
||||
opts,
|
||||
provider.WithAnthropicOptions(
|
||||
|
||||
@@ -5,11 +5,11 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/llm/tools"
|
||||
"github.com/opencode-ai/opencode/internal/logging"
|
||||
"github.com/opencode-ai/opencode/internal/permission"
|
||||
"github.com/opencode-ai/opencode/internal/version"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/permission"
|
||||
"github.com/sst/opencode/internal/version"
|
||||
"log/slog"
|
||||
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
@@ -86,6 +86,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
|
||||
}
|
||||
permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
|
||||
p := b.permissions.Request(
|
||||
ctx,
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: config.WorkingDirectory(),
|
||||
@@ -146,13 +147,13 @@ func getTools(ctx context.Context, name string, m config.MCPServer, permissions
|
||||
|
||||
_, err := c.Initialize(ctx, initRequest)
|
||||
if err != nil {
|
||||
logging.Error("error initializing mcp client", "error", err)
|
||||
slog.Error("error initializing mcp client", "error", err)
|
||||
return stdioTools
|
||||
}
|
||||
toolsRequest := mcp.ListToolsRequest{}
|
||||
tools, err := c.ListTools(ctx, toolsRequest)
|
||||
if err != nil {
|
||||
logging.Error("error listing tools", "error", err)
|
||||
slog.Error("error listing tools", "error", err)
|
||||
return stdioTools
|
||||
}
|
||||
for _, t := range tools.Tools {
|
||||
@@ -175,7 +176,7 @@ func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.Ba
|
||||
m.Args...,
|
||||
)
|
||||
if err != nil {
|
||||
logging.Error("error creating mcp client", "error", err)
|
||||
slog.Error("error creating mcp client", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -186,7 +187,7 @@ func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.Ba
|
||||
client.WithHeaders(m.Headers),
|
||||
)
|
||||
if err != nil {
|
||||
logging.Error("error creating mcp client", "error", err)
|
||||
slog.Error("error creating mcp client", "error", err)
|
||||
continue
|
||||
}
|
||||
mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
|
||||
|
||||
@@ -3,15 +3,15 @@ package agent
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/history"
|
||||
"github.com/opencode-ai/opencode/internal/llm/tools"
|
||||
"github.com/opencode-ai/opencode/internal/lsp"
|
||||
"github.com/opencode-ai/opencode/internal/message"
|
||||
"github.com/opencode-ai/opencode/internal/permission"
|
||||
"github.com/opencode-ai/opencode/internal/session"
|
||||
"github.com/sst/opencode/internal/history"
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/permission"
|
||||
"github.com/sst/opencode/internal/session"
|
||||
)
|
||||
|
||||
func CoderAgentTools(
|
||||
func PrimaryAgentTools(
|
||||
permissions permission.Service,
|
||||
sessions session.Service,
|
||||
messages message.Service,
|
||||
@@ -19,10 +19,8 @@ func CoderAgentTools(
|
||||
lspClients map[string]*lsp.Client,
|
||||
) []tools.BaseTool {
|
||||
ctx := context.Background()
|
||||
otherTools := GetMcpTools(ctx, permissions)
|
||||
if len(lspClients) > 0 {
|
||||
otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
|
||||
}
|
||||
mcpTools := GetMcpTools(ctx, permissions)
|
||||
|
||||
return append(
|
||||
[]tools.BaseTool{
|
||||
tools.NewBashTool(permissions),
|
||||
@@ -31,12 +29,17 @@ func CoderAgentTools(
|
||||
tools.NewGlobTool(),
|
||||
tools.NewGrepTool(),
|
||||
tools.NewLsTool(),
|
||||
tools.NewSourcegraphTool(),
|
||||
tools.NewViewTool(lspClients),
|
||||
tools.NewPatchTool(lspClients, permissions, history),
|
||||
tools.NewWriteTool(lspClients, permissions, history),
|
||||
tools.NewDiagnosticsTool(lspClients),
|
||||
tools.NewDefinitionTool(lspClients),
|
||||
tools.NewReferencesTool(lspClients),
|
||||
tools.NewDocSymbolsTool(lspClients),
|
||||
tools.NewWorkspaceSymbolsTool(lspClients),
|
||||
tools.NewCodeActionTool(lspClients),
|
||||
NewAgentTool(sessions, messages, lspClients),
|
||||
}, otherTools...,
|
||||
}, mcpTools...,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -45,7 +48,10 @@ func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool {
|
||||
tools.NewGlobTool(),
|
||||
tools.NewGrepTool(),
|
||||
tools.NewLsTool(),
|
||||
tools.NewSourcegraphTool(),
|
||||
tools.NewViewTool(lspClients),
|
||||
tools.NewDefinitionTool(lspClients),
|
||||
tools.NewReferencesTool(lspClients),
|
||||
tools.NewDocSymbolsTool(lspClients),
|
||||
tools.NewWorkspaceSymbolsTool(lspClients),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,64 +14,69 @@ const (
|
||||
// https://docs.anthropic.com/en/docs/about-claude/models/all-models
|
||||
var AnthropicModels = map[ModelID]Model{
|
||||
Claude35Sonnet: {
|
||||
ID: Claude35Sonnet,
|
||||
Name: "Claude 3.5 Sonnet",
|
||||
Provider: ProviderAnthropic,
|
||||
APIModel: "claude-3-5-sonnet-latest",
|
||||
CostPer1MIn: 3.0,
|
||||
CostPer1MInCached: 3.75,
|
||||
CostPer1MOutCached: 0.30,
|
||||
CostPer1MOut: 15.0,
|
||||
ContextWindow: 200000,
|
||||
DefaultMaxTokens: 5000,
|
||||
ID: Claude35Sonnet,
|
||||
Name: "Claude 3.5 Sonnet",
|
||||
Provider: ProviderAnthropic,
|
||||
APIModel: "claude-3-5-sonnet-latest",
|
||||
CostPer1MIn: 3.0,
|
||||
CostPer1MInCached: 3.75,
|
||||
CostPer1MOutCached: 0.30,
|
||||
CostPer1MOut: 15.0,
|
||||
ContextWindow: 200000,
|
||||
DefaultMaxTokens: 5000,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
Claude3Haiku: {
|
||||
ID: Claude3Haiku,
|
||||
Name: "Claude 3 Haiku",
|
||||
Provider: ProviderAnthropic,
|
||||
APIModel: "claude-3-haiku-20240307", // doesn't support "-latest"
|
||||
CostPer1MIn: 0.25,
|
||||
CostPer1MInCached: 0.30,
|
||||
CostPer1MOutCached: 0.03,
|
||||
CostPer1MOut: 1.25,
|
||||
ContextWindow: 200000,
|
||||
DefaultMaxTokens: 4096,
|
||||
ID: Claude3Haiku,
|
||||
Name: "Claude 3 Haiku",
|
||||
Provider: ProviderAnthropic,
|
||||
APIModel: "claude-3-haiku-20240307", // doesn't support "-latest"
|
||||
CostPer1MIn: 0.25,
|
||||
CostPer1MInCached: 0.30,
|
||||
CostPer1MOutCached: 0.03,
|
||||
CostPer1MOut: 1.25,
|
||||
ContextWindow: 200000,
|
||||
DefaultMaxTokens: 4096,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
Claude37Sonnet: {
|
||||
ID: Claude37Sonnet,
|
||||
Name: "Claude 3.7 Sonnet",
|
||||
Provider: ProviderAnthropic,
|
||||
APIModel: "claude-3-7-sonnet-latest",
|
||||
CostPer1MIn: 3.0,
|
||||
CostPer1MInCached: 3.75,
|
||||
CostPer1MOutCached: 0.30,
|
||||
CostPer1MOut: 15.0,
|
||||
ContextWindow: 200000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
ID: Claude37Sonnet,
|
||||
Name: "Claude 3.7 Sonnet",
|
||||
Provider: ProviderAnthropic,
|
||||
APIModel: "claude-3-7-sonnet-latest",
|
||||
CostPer1MIn: 3.0,
|
||||
CostPer1MInCached: 3.75,
|
||||
CostPer1MOutCached: 0.30,
|
||||
CostPer1MOut: 15.0,
|
||||
ContextWindow: 200000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
Claude35Haiku: {
|
||||
ID: Claude35Haiku,
|
||||
Name: "Claude 3.5 Haiku",
|
||||
Provider: ProviderAnthropic,
|
||||
APIModel: "claude-3-5-haiku-latest",
|
||||
CostPer1MIn: 0.80,
|
||||
CostPer1MInCached: 1.0,
|
||||
CostPer1MOutCached: 0.08,
|
||||
CostPer1MOut: 4.0,
|
||||
ContextWindow: 200000,
|
||||
DefaultMaxTokens: 4096,
|
||||
ID: Claude35Haiku,
|
||||
Name: "Claude 3.5 Haiku",
|
||||
Provider: ProviderAnthropic,
|
||||
APIModel: "claude-3-5-haiku-latest",
|
||||
CostPer1MIn: 0.80,
|
||||
CostPer1MInCached: 1.0,
|
||||
CostPer1MOutCached: 0.08,
|
||||
CostPer1MOut: 4.0,
|
||||
ContextWindow: 200000,
|
||||
DefaultMaxTokens: 4096,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
Claude3Opus: {
|
||||
ID: Claude3Opus,
|
||||
Name: "Claude 3 Opus",
|
||||
Provider: ProviderAnthropic,
|
||||
APIModel: "claude-3-opus-latest",
|
||||
CostPer1MIn: 15.0,
|
||||
CostPer1MInCached: 18.75,
|
||||
CostPer1MOutCached: 1.50,
|
||||
CostPer1MOut: 75.0,
|
||||
ContextWindow: 200000,
|
||||
DefaultMaxTokens: 4096,
|
||||
ID: Claude3Opus,
|
||||
Name: "Claude 3 Opus",
|
||||
Provider: ProviderAnthropic,
|
||||
APIModel: "claude-3-opus-latest",
|
||||
CostPer1MIn: 15.0,
|
||||
CostPer1MInCached: 18.75,
|
||||
CostPer1MOutCached: 1.50,
|
||||
CostPer1MOut: 75.0,
|
||||
ContextWindow: 200000,
|
||||
DefaultMaxTokens: 4096,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -18,140 +18,151 @@ const (
|
||||
|
||||
var AzureModels = map[ModelID]Model{
|
||||
AzureGPT41: {
|
||||
ID: AzureGPT41,
|
||||
Name: "Azure OpenAI – GPT 4.1",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "gpt-4.1",
|
||||
CostPer1MIn: OpenAIModels[GPT41].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT41].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT41].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT41].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT41].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT41].DefaultMaxTokens,
|
||||
ID: AzureGPT41,
|
||||
Name: "Azure OpenAI – GPT 4.1",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "gpt-4.1",
|
||||
CostPer1MIn: OpenAIModels[GPT41].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT41].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT41].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT41].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT41].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT41].DefaultMaxTokens,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
AzureGPT41Mini: {
|
||||
ID: AzureGPT41Mini,
|
||||
Name: "Azure OpenAI – GPT 4.1 mini",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "gpt-4.1-mini",
|
||||
CostPer1MIn: OpenAIModels[GPT41Mini].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT41Mini].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT41Mini].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT41Mini].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT41Mini].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT41Mini].DefaultMaxTokens,
|
||||
ID: AzureGPT41Mini,
|
||||
Name: "Azure OpenAI – GPT 4.1 mini",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "gpt-4.1-mini",
|
||||
CostPer1MIn: OpenAIModels[GPT41Mini].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT41Mini].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT41Mini].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT41Mini].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT41Mini].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT41Mini].DefaultMaxTokens,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
AzureGPT41Nano: {
|
||||
ID: AzureGPT41Nano,
|
||||
Name: "Azure OpenAI – GPT 4.1 nano",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "gpt-4.1-nano",
|
||||
CostPer1MIn: OpenAIModels[GPT41Nano].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT41Nano].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT41Nano].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT41Nano].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT41Nano].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT41Nano].DefaultMaxTokens,
|
||||
ID: AzureGPT41Nano,
|
||||
Name: "Azure OpenAI – GPT 4.1 nano",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "gpt-4.1-nano",
|
||||
CostPer1MIn: OpenAIModels[GPT41Nano].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT41Nano].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT41Nano].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT41Nano].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT41Nano].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT41Nano].DefaultMaxTokens,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
AzureGPT45Preview: {
|
||||
ID: AzureGPT45Preview,
|
||||
Name: "Azure OpenAI – GPT 4.5 preview",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "gpt-4.5-preview",
|
||||
CostPer1MIn: OpenAIModels[GPT45Preview].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT45Preview].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT45Preview].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT45Preview].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT45Preview].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT45Preview].DefaultMaxTokens,
|
||||
ID: AzureGPT45Preview,
|
||||
Name: "Azure OpenAI – GPT 4.5 preview",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "gpt-4.5-preview",
|
||||
CostPer1MIn: OpenAIModels[GPT45Preview].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT45Preview].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT45Preview].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT45Preview].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT45Preview].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT45Preview].DefaultMaxTokens,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
AzureGPT4o: {
|
||||
ID: AzureGPT4o,
|
||||
Name: "Azure OpenAI – GPT-4o",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "gpt-4o",
|
||||
CostPer1MIn: OpenAIModels[GPT4o].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT4o].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT4o].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT4o].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT4o].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT4o].DefaultMaxTokens,
|
||||
ID: AzureGPT4o,
|
||||
Name: "Azure OpenAI – GPT-4o",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "gpt-4o",
|
||||
CostPer1MIn: OpenAIModels[GPT4o].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT4o].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT4o].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT4o].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT4o].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT4o].DefaultMaxTokens,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
AzureGPT4oMini: {
|
||||
ID: AzureGPT4oMini,
|
||||
Name: "Azure OpenAI – GPT-4o mini",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "gpt-4o-mini",
|
||||
CostPer1MIn: OpenAIModels[GPT4oMini].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT4oMini].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT4oMini].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT4oMini].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT4oMini].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT4oMini].DefaultMaxTokens,
|
||||
ID: AzureGPT4oMini,
|
||||
Name: "Azure OpenAI – GPT-4o mini",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "gpt-4o-mini",
|
||||
CostPer1MIn: OpenAIModels[GPT4oMini].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT4oMini].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT4oMini].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT4oMini].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT4oMini].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT4oMini].DefaultMaxTokens,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
AzureO1: {
|
||||
ID: AzureO1,
|
||||
Name: "Azure OpenAI – O1",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "o1",
|
||||
CostPer1MIn: OpenAIModels[O1].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[O1].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[O1].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[O1].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[O1].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[O1].DefaultMaxTokens,
|
||||
CanReason: OpenAIModels[O1].CanReason,
|
||||
ID: AzureO1,
|
||||
Name: "Azure OpenAI – O1",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "o1",
|
||||
CostPer1MIn: OpenAIModels[O1].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[O1].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[O1].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[O1].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[O1].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[O1].DefaultMaxTokens,
|
||||
CanReason: OpenAIModels[O1].CanReason,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
AzureO1Mini: {
|
||||
ID: AzureO1Mini,
|
||||
Name: "Azure OpenAI – O1 mini",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "o1-mini",
|
||||
CostPer1MIn: OpenAIModels[O1Mini].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[O1Mini].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[O1Mini].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[O1Mini].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[O1Mini].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[O1Mini].DefaultMaxTokens,
|
||||
CanReason: OpenAIModels[O1Mini].CanReason,
|
||||
ID: AzureO1Mini,
|
||||
Name: "Azure OpenAI – O1 mini",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "o1-mini",
|
||||
CostPer1MIn: OpenAIModels[O1Mini].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[O1Mini].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[O1Mini].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[O1Mini].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[O1Mini].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[O1Mini].DefaultMaxTokens,
|
||||
CanReason: OpenAIModels[O1Mini].CanReason,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
AzureO3: {
|
||||
ID: AzureO3,
|
||||
Name: "Azure OpenAI – O3",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "o3",
|
||||
CostPer1MIn: OpenAIModels[O3].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[O3].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[O3].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[O3].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[O3].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[O3].DefaultMaxTokens,
|
||||
CanReason: OpenAIModels[O3].CanReason,
|
||||
ID: AzureO3,
|
||||
Name: "Azure OpenAI – O3",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "o3",
|
||||
CostPer1MIn: OpenAIModels[O3].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[O3].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[O3].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[O3].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[O3].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[O3].DefaultMaxTokens,
|
||||
CanReason: OpenAIModels[O3].CanReason,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
AzureO3Mini: {
|
||||
ID: AzureO3Mini,
|
||||
Name: "Azure OpenAI – O3 mini",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "o3-mini",
|
||||
CostPer1MIn: OpenAIModels[O3Mini].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[O3Mini].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[O3Mini].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[O3Mini].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[O3Mini].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[O3Mini].DefaultMaxTokens,
|
||||
CanReason: OpenAIModels[O3Mini].CanReason,
|
||||
ID: AzureO3Mini,
|
||||
Name: "Azure OpenAI – O3 mini",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "o3-mini",
|
||||
CostPer1MIn: OpenAIModels[O3Mini].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[O3Mini].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[O3Mini].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[O3Mini].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[O3Mini].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[O3Mini].DefaultMaxTokens,
|
||||
CanReason: OpenAIModels[O3Mini].CanReason,
|
||||
SupportsAttachments: false,
|
||||
},
|
||||
AzureO4Mini: {
|
||||
ID: AzureO4Mini,
|
||||
Name: "Azure OpenAI – O4 mini",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "o4-mini",
|
||||
CostPer1MIn: OpenAIModels[O4Mini].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[O4Mini].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[O4Mini].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[O4Mini].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[O4Mini].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[O4Mini].DefaultMaxTokens,
|
||||
CanReason: OpenAIModels[O4Mini].CanReason,
|
||||
ID: AzureO4Mini,
|
||||
Name: "Azure OpenAI – O4 mini",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "o4-mini",
|
||||
CostPer1MIn: OpenAIModels[O4Mini].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[O4Mini].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[O4Mini].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[O4Mini].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[O4Mini].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[O4Mini].DefaultMaxTokens,
|
||||
CanReason: OpenAIModels[O4Mini].CanReason,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -12,52 +12,56 @@ const (
|
||||
|
||||
var GeminiModels = map[ModelID]Model{
|
||||
Gemini25Flash: {
|
||||
ID: Gemini25Flash,
|
||||
Name: "Gemini 2.5 Flash",
|
||||
Provider: ProviderGemini,
|
||||
APIModel: "gemini-2.5-flash-preview-04-17",
|
||||
CostPer1MIn: 0.15,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0.60,
|
||||
ContextWindow: 1000000,
|
||||
DefaultMaxTokens: 50000,
|
||||
ID: Gemini25Flash,
|
||||
Name: "Gemini 2.5 Flash",
|
||||
Provider: ProviderGemini,
|
||||
APIModel: "gemini-2.5-flash-preview-04-17",
|
||||
CostPer1MIn: 0.15,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0.60,
|
||||
ContextWindow: 1000000,
|
||||
DefaultMaxTokens: 50000,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
Gemini25: {
|
||||
ID: Gemini25,
|
||||
Name: "Gemini 2.5 Pro",
|
||||
Provider: ProviderGemini,
|
||||
APIModel: "gemini-2.5-pro-preview-03-25",
|
||||
CostPer1MIn: 1.25,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 10,
|
||||
ContextWindow: 1000000,
|
||||
DefaultMaxTokens: 50000,
|
||||
ID: Gemini25,
|
||||
Name: "Gemini 2.5 Pro",
|
||||
Provider: ProviderGemini,
|
||||
APIModel: "gemini-2.5-pro-preview-03-25",
|
||||
CostPer1MIn: 1.25,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 10,
|
||||
ContextWindow: 1000000,
|
||||
DefaultMaxTokens: 50000,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
|
||||
Gemini20Flash: {
|
||||
ID: Gemini20Flash,
|
||||
Name: "Gemini 2.0 Flash",
|
||||
Provider: ProviderGemini,
|
||||
APIModel: "gemini-2.0-flash",
|
||||
CostPer1MIn: 0.10,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0.40,
|
||||
ContextWindow: 1000000,
|
||||
DefaultMaxTokens: 6000,
|
||||
ID: Gemini20Flash,
|
||||
Name: "Gemini 2.0 Flash",
|
||||
Provider: ProviderGemini,
|
||||
APIModel: "gemini-2.0-flash",
|
||||
CostPer1MIn: 0.10,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0.40,
|
||||
ContextWindow: 1000000,
|
||||
DefaultMaxTokens: 6000,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
Gemini20FlashLite: {
|
||||
ID: Gemini20FlashLite,
|
||||
Name: "Gemini 2.0 Flash Lite",
|
||||
Provider: ProviderGemini,
|
||||
APIModel: "gemini-2.0-flash-lite",
|
||||
CostPer1MIn: 0.05,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0.30,
|
||||
ContextWindow: 1000000,
|
||||
DefaultMaxTokens: 6000,
|
||||
ID: Gemini20FlashLite,
|
||||
Name: "Gemini 2.0 Flash Lite",
|
||||
Provider: ProviderGemini,
|
||||
APIModel: "gemini-2.0-flash-lite",
|
||||
CostPer1MIn: 0.05,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0.30,
|
||||
ContextWindow: 1000000,
|
||||
DefaultMaxTokens: 6000,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -28,55 +28,60 @@ var GroqModels = map[ModelID]Model{
|
||||
ContextWindow: 128_000,
|
||||
DefaultMaxTokens: 50000,
|
||||
// for some reason, the groq api doesn't like the reasoningEffort parameter
|
||||
CanReason: false,
|
||||
CanReason: false,
|
||||
SupportsAttachments: false,
|
||||
},
|
||||
|
||||
Llama4Scout: {
|
||||
ID: Llama4Scout,
|
||||
Name: "Llama4Scout",
|
||||
Provider: ProviderGROQ,
|
||||
APIModel: "meta-llama/llama-4-scout-17b-16e-instruct",
|
||||
CostPer1MIn: 0.11,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0.34,
|
||||
ContextWindow: 128_000, // 10M when?
|
||||
ID: Llama4Scout,
|
||||
Name: "Llama4Scout",
|
||||
Provider: ProviderGROQ,
|
||||
APIModel: "meta-llama/llama-4-scout-17b-16e-instruct",
|
||||
CostPer1MIn: 0.11,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0.34,
|
||||
ContextWindow: 128_000, // 10M when?
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
|
||||
Llama4Maverick: {
|
||||
ID: Llama4Maverick,
|
||||
Name: "Llama4Maverick",
|
||||
Provider: ProviderGROQ,
|
||||
APIModel: "meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||
CostPer1MIn: 0.20,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0.20,
|
||||
ContextWindow: 128_000,
|
||||
ID: Llama4Maverick,
|
||||
Name: "Llama4Maverick",
|
||||
Provider: ProviderGROQ,
|
||||
APIModel: "meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||
CostPer1MIn: 0.20,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0.20,
|
||||
ContextWindow: 128_000,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
|
||||
Llama3_3_70BVersatile: {
|
||||
ID: Llama3_3_70BVersatile,
|
||||
Name: "Llama3_3_70BVersatile",
|
||||
Provider: ProviderGROQ,
|
||||
APIModel: "llama-3.3-70b-versatile",
|
||||
CostPer1MIn: 0.59,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0.79,
|
||||
ContextWindow: 128_000,
|
||||
ID: Llama3_3_70BVersatile,
|
||||
Name: "Llama3_3_70BVersatile",
|
||||
Provider: ProviderGROQ,
|
||||
APIModel: "llama-3.3-70b-versatile",
|
||||
CostPer1MIn: 0.59,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0.79,
|
||||
ContextWindow: 128_000,
|
||||
SupportsAttachments: false,
|
||||
},
|
||||
|
||||
DeepseekR1DistillLlama70b: {
|
||||
ID: DeepseekR1DistillLlama70b,
|
||||
Name: "DeepseekR1DistillLlama70b",
|
||||
Provider: ProviderGROQ,
|
||||
APIModel: "deepseek-r1-distill-llama-70b",
|
||||
CostPer1MIn: 0.75,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0.99,
|
||||
ContextWindow: 128_000,
|
||||
CanReason: true,
|
||||
ID: DeepseekR1DistillLlama70b,
|
||||
Name: "DeepseekR1DistillLlama70b",
|
||||
Provider: ProviderGROQ,
|
||||
APIModel: "deepseek-r1-distill-llama-70b",
|
||||
CostPer1MIn: 0.75,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0.99,
|
||||
ContextWindow: 128_000,
|
||||
CanReason: true,
|
||||
SupportsAttachments: false,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -8,17 +8,18 @@ type (
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
ID ModelID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Provider ModelProvider `json:"provider"`
|
||||
APIModel string `json:"api_model"`
|
||||
CostPer1MIn float64 `json:"cost_per_1m_in"`
|
||||
CostPer1MOut float64 `json:"cost_per_1m_out"`
|
||||
CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
|
||||
CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
|
||||
ContextWindow int64 `json:"context_window"`
|
||||
DefaultMaxTokens int64 `json:"default_max_tokens"`
|
||||
CanReason bool `json:"can_reason"`
|
||||
ID ModelID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Provider ModelProvider `json:"provider"`
|
||||
APIModel string `json:"api_model"`
|
||||
CostPer1MIn float64 `json:"cost_per_1m_in"`
|
||||
CostPer1MOut float64 `json:"cost_per_1m_out"`
|
||||
CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
|
||||
CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
|
||||
ContextWindow int64 `json:"context_window"`
|
||||
DefaultMaxTokens int64 `json:"default_max_tokens"`
|
||||
CanReason bool `json:"can_reason"`
|
||||
SupportsAttachments bool `json:"supports_attachments"`
|
||||
}
|
||||
|
||||
// Model IDs
|
||||
@@ -35,11 +36,13 @@ const (
|
||||
|
||||
// Providers in order of popularity
|
||||
var ProviderPopularity = map[ModelProvider]int{
|
||||
ProviderAnthropic: 1,
|
||||
ProviderOpenAI: 2,
|
||||
ProviderGemini: 3,
|
||||
ProviderGROQ: 4,
|
||||
ProviderBedrock: 5,
|
||||
ProviderAnthropic: 1,
|
||||
ProviderOpenAI: 2,
|
||||
ProviderGemini: 3,
|
||||
ProviderGROQ: 4,
|
||||
ProviderOpenRouter: 5,
|
||||
ProviderBedrock: 6,
|
||||
ProviderAzure: 7,
|
||||
}
|
||||
|
||||
var SupportedModels = map[ModelID]Model{
|
||||
@@ -69,14 +72,18 @@ var SupportedModels = map[ModelID]Model{
|
||||
//
|
||||
// // Bedrock
|
||||
BedrockClaude37Sonnet: {
|
||||
ID: BedrockClaude37Sonnet,
|
||||
Name: "Bedrock: Claude 3.7 Sonnet",
|
||||
Provider: ProviderBedrock,
|
||||
APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
CostPer1MIn: 3.0,
|
||||
CostPer1MInCached: 3.75,
|
||||
CostPer1MOutCached: 0.30,
|
||||
CostPer1MOut: 15.0,
|
||||
ID: BedrockClaude37Sonnet,
|
||||
Name: "Bedrock: Claude 3.7 Sonnet",
|
||||
Provider: ProviderBedrock,
|
||||
APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
CostPer1MIn: 3.0,
|
||||
CostPer1MInCached: 3.75,
|
||||
CostPer1MOutCached: 0.30,
|
||||
CostPer1MOut: 15.0,
|
||||
ContextWindow: 200_000,
|
||||
DefaultMaxTokens: 50_000,
|
||||
CanReason: true,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -87,4 +94,5 @@ func init() {
|
||||
maps.Copy(SupportedModels, GroqModels)
|
||||
maps.Copy(SupportedModels, AzureModels)
|
||||
maps.Copy(SupportedModels, OpenRouterModels)
|
||||
maps.Copy(SupportedModels, XAIModels)
|
||||
}
|
||||
|
||||
@@ -19,151 +19,163 @@ const (
|
||||
|
||||
var OpenAIModels = map[ModelID]Model{
|
||||
GPT41: {
|
||||
ID: GPT41,
|
||||
Name: "GPT 4.1",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4.1",
|
||||
CostPer1MIn: 2.00,
|
||||
CostPer1MInCached: 0.50,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 8.00,
|
||||
ContextWindow: 1_047_576,
|
||||
DefaultMaxTokens: 20000,
|
||||
ID: GPT41,
|
||||
Name: "GPT 4.1",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4.1",
|
||||
CostPer1MIn: 2.00,
|
||||
CostPer1MInCached: 0.50,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 8.00,
|
||||
ContextWindow: 1_047_576,
|
||||
DefaultMaxTokens: 20000,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
GPT41Mini: {
|
||||
ID: GPT41Mini,
|
||||
Name: "GPT 4.1 mini",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4.1",
|
||||
CostPer1MIn: 0.40,
|
||||
CostPer1MInCached: 0.10,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 1.60,
|
||||
ContextWindow: 200_000,
|
||||
DefaultMaxTokens: 20000,
|
||||
ID: GPT41Mini,
|
||||
Name: "GPT 4.1 mini",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4.1",
|
||||
CostPer1MIn: 0.40,
|
||||
CostPer1MInCached: 0.10,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 1.60,
|
||||
ContextWindow: 200_000,
|
||||
DefaultMaxTokens: 20000,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
GPT41Nano: {
|
||||
ID: GPT41Nano,
|
||||
Name: "GPT 4.1 nano",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4.1-nano",
|
||||
CostPer1MIn: 0.10,
|
||||
CostPer1MInCached: 0.025,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 0.40,
|
||||
ContextWindow: 1_047_576,
|
||||
DefaultMaxTokens: 20000,
|
||||
ID: GPT41Nano,
|
||||
Name: "GPT 4.1 nano",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4.1-nano",
|
||||
CostPer1MIn: 0.10,
|
||||
CostPer1MInCached: 0.025,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 0.40,
|
||||
ContextWindow: 1_047_576,
|
||||
DefaultMaxTokens: 20000,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
GPT45Preview: {
|
||||
ID: GPT45Preview,
|
||||
Name: "GPT 4.5 preview",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4.5-preview",
|
||||
CostPer1MIn: 75.00,
|
||||
CostPer1MInCached: 37.50,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 150.00,
|
||||
ContextWindow: 128_000,
|
||||
DefaultMaxTokens: 15000,
|
||||
ID: GPT45Preview,
|
||||
Name: "GPT 4.5 preview",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4.5-preview",
|
||||
CostPer1MIn: 75.00,
|
||||
CostPer1MInCached: 37.50,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 150.00,
|
||||
ContextWindow: 128_000,
|
||||
DefaultMaxTokens: 15000,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
GPT4o: {
|
||||
ID: GPT4o,
|
||||
Name: "GPT 4o",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4o",
|
||||
CostPer1MIn: 2.50,
|
||||
CostPer1MInCached: 1.25,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 10.00,
|
||||
ContextWindow: 128_000,
|
||||
DefaultMaxTokens: 4096,
|
||||
ID: GPT4o,
|
||||
Name: "GPT 4o",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4o",
|
||||
CostPer1MIn: 2.50,
|
||||
CostPer1MInCached: 1.25,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 10.00,
|
||||
ContextWindow: 128_000,
|
||||
DefaultMaxTokens: 4096,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
GPT4oMini: {
|
||||
ID: GPT4oMini,
|
||||
Name: "GPT 4o mini",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4o-mini",
|
||||
CostPer1MIn: 0.15,
|
||||
CostPer1MInCached: 0.075,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 0.60,
|
||||
ContextWindow: 128_000,
|
||||
ID: GPT4oMini,
|
||||
Name: "GPT 4o mini",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4o-mini",
|
||||
CostPer1MIn: 0.15,
|
||||
CostPer1MInCached: 0.075,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 0.60,
|
||||
ContextWindow: 128_000,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
O1: {
|
||||
ID: O1,
|
||||
Name: "O1",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "o1",
|
||||
CostPer1MIn: 15.00,
|
||||
CostPer1MInCached: 7.50,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 60.00,
|
||||
ContextWindow: 200_000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
ID: O1,
|
||||
Name: "O1",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "o1",
|
||||
CostPer1MIn: 15.00,
|
||||
CostPer1MInCached: 7.50,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 60.00,
|
||||
ContextWindow: 200_000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
O1Pro: {
|
||||
ID: O1Pro,
|
||||
Name: "o1 pro",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "o1-pro",
|
||||
CostPer1MIn: 150.00,
|
||||
CostPer1MInCached: 0.0,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 600.00,
|
||||
ContextWindow: 200_000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
ID: O1Pro,
|
||||
Name: "o1 pro",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "o1-pro",
|
||||
CostPer1MIn: 150.00,
|
||||
CostPer1MInCached: 0.0,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 600.00,
|
||||
ContextWindow: 200_000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
O1Mini: {
|
||||
ID: O1Mini,
|
||||
Name: "o1 mini",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "o1-mini",
|
||||
CostPer1MIn: 1.10,
|
||||
CostPer1MInCached: 0.55,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 4.40,
|
||||
ContextWindow: 128_000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
ID: O1Mini,
|
||||
Name: "o1 mini",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "o1-mini",
|
||||
CostPer1MIn: 1.10,
|
||||
CostPer1MInCached: 0.55,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 4.40,
|
||||
ContextWindow: 128_000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
O3: {
|
||||
ID: O3,
|
||||
Name: "o3",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "o3",
|
||||
CostPer1MIn: 10.00,
|
||||
CostPer1MInCached: 2.50,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 40.00,
|
||||
ContextWindow: 200_000,
|
||||
CanReason: true,
|
||||
ID: O3,
|
||||
Name: "o3",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "o3",
|
||||
CostPer1MIn: 10.00,
|
||||
CostPer1MInCached: 2.50,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 40.00,
|
||||
ContextWindow: 200_000,
|
||||
CanReason: true,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
O3Mini: {
|
||||
ID: O3Mini,
|
||||
Name: "o3 mini",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "o3-mini",
|
||||
CostPer1MIn: 1.10,
|
||||
CostPer1MInCached: 0.55,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 4.40,
|
||||
ContextWindow: 200_000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
ID: O3Mini,
|
||||
Name: "o3 mini",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "o3-mini",
|
||||
CostPer1MIn: 1.10,
|
||||
CostPer1MInCached: 0.55,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 4.40,
|
||||
ContextWindow: 200_000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
SupportsAttachments: false,
|
||||
},
|
||||
O4Mini: {
|
||||
ID: O4Mini,
|
||||
Name: "o4 mini",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "o4-mini",
|
||||
CostPer1MIn: 1.10,
|
||||
CostPer1MInCached: 0.275,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 4.40,
|
||||
ContextWindow: 128_000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
ID: O4Mini,
|
||||
Name: "o4 mini",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "o4-mini",
|
||||
CostPer1MIn: 1.10,
|
||||
CostPer1MInCached: 0.275,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 4.40,
|
||||
ContextWindow: 128_000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -22,12 +22,17 @@ const (
|
||||
OpenRouterClaude37Sonnet ModelID = "openrouter.claude-3.7-sonnet"
|
||||
OpenRouterClaude35Haiku ModelID = "openrouter.claude-3.5-haiku"
|
||||
OpenRouterClaude3Opus ModelID = "openrouter.claude-3-opus"
|
||||
OpenRouterQwen235B ModelID = "openrouter.qwen-3-235b"
|
||||
OpenRouterQwen32B ModelID = "openrouter.qwen-3-32b"
|
||||
OpenRouterQwen30B ModelID = "openrouter.qwen-3-30b"
|
||||
OpenRouterQwen14B ModelID = "openrouter.qwen-3-14b"
|
||||
OpenRouterQwen8B ModelID = "openrouter.qwen-3-8b"
|
||||
)
|
||||
|
||||
var OpenRouterModels = map[ModelID]Model{
|
||||
OpenRouterGPT41: {
|
||||
ID: OpenRouterGPT41,
|
||||
Name: "OpenRouter – GPT 4.1",
|
||||
Name: "OpenRouter: GPT 4.1",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/gpt-4.1",
|
||||
CostPer1MIn: OpenAIModels[GPT41].CostPer1MIn,
|
||||
@@ -39,7 +44,7 @@ var OpenRouterModels = map[ModelID]Model{
|
||||
},
|
||||
OpenRouterGPT41Mini: {
|
||||
ID: OpenRouterGPT41Mini,
|
||||
Name: "OpenRouter – GPT 4.1 mini",
|
||||
Name: "OpenRouter: GPT 4.1 mini",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/gpt-4.1-mini",
|
||||
CostPer1MIn: OpenAIModels[GPT41Mini].CostPer1MIn,
|
||||
@@ -51,7 +56,7 @@ var OpenRouterModels = map[ModelID]Model{
|
||||
},
|
||||
OpenRouterGPT41Nano: {
|
||||
ID: OpenRouterGPT41Nano,
|
||||
Name: "OpenRouter – GPT 4.1 nano",
|
||||
Name: "OpenRouter: GPT 4.1 nano",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/gpt-4.1-nano",
|
||||
CostPer1MIn: OpenAIModels[GPT41Nano].CostPer1MIn,
|
||||
@@ -63,7 +68,7 @@ var OpenRouterModels = map[ModelID]Model{
|
||||
},
|
||||
OpenRouterGPT45Preview: {
|
||||
ID: OpenRouterGPT45Preview,
|
||||
Name: "OpenRouter – GPT 4.5 preview",
|
||||
Name: "OpenRouter: GPT 4.5 preview",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/gpt-4.5-preview",
|
||||
CostPer1MIn: OpenAIModels[GPT45Preview].CostPer1MIn,
|
||||
@@ -75,7 +80,7 @@ var OpenRouterModels = map[ModelID]Model{
|
||||
},
|
||||
OpenRouterGPT4o: {
|
||||
ID: OpenRouterGPT4o,
|
||||
Name: "OpenRouter – GPT 4o",
|
||||
Name: "OpenRouter: GPT 4o",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/gpt-4o",
|
||||
CostPer1MIn: OpenAIModels[GPT4o].CostPer1MIn,
|
||||
@@ -87,7 +92,7 @@ var OpenRouterModels = map[ModelID]Model{
|
||||
},
|
||||
OpenRouterGPT4oMini: {
|
||||
ID: OpenRouterGPT4oMini,
|
||||
Name: "OpenRouter – GPT 4o mini",
|
||||
Name: "OpenRouter: GPT 4o mini",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/gpt-4o-mini",
|
||||
CostPer1MIn: OpenAIModels[GPT4oMini].CostPer1MIn,
|
||||
@@ -98,7 +103,7 @@ var OpenRouterModels = map[ModelID]Model{
|
||||
},
|
||||
OpenRouterO1: {
|
||||
ID: OpenRouterO1,
|
||||
Name: "OpenRouter – O1",
|
||||
Name: "OpenRouter: O1",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/o1",
|
||||
CostPer1MIn: OpenAIModels[O1].CostPer1MIn,
|
||||
@@ -111,7 +116,7 @@ var OpenRouterModels = map[ModelID]Model{
|
||||
},
|
||||
OpenRouterO1Pro: {
|
||||
ID: OpenRouterO1Pro,
|
||||
Name: "OpenRouter – o1 pro",
|
||||
Name: "OpenRouter: o1 pro",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/o1-pro",
|
||||
CostPer1MIn: OpenAIModels[O1Pro].CostPer1MIn,
|
||||
@@ -124,7 +129,7 @@ var OpenRouterModels = map[ModelID]Model{
|
||||
},
|
||||
OpenRouterO1Mini: {
|
||||
ID: OpenRouterO1Mini,
|
||||
Name: "OpenRouter – o1 mini",
|
||||
Name: "OpenRouter: o1 mini",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/o1-mini",
|
||||
CostPer1MIn: OpenAIModels[O1Mini].CostPer1MIn,
|
||||
@@ -137,7 +142,7 @@ var OpenRouterModels = map[ModelID]Model{
|
||||
},
|
||||
OpenRouterO3: {
|
||||
ID: OpenRouterO3,
|
||||
Name: "OpenRouter – o3",
|
||||
Name: "OpenRouter: o3",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/o3",
|
||||
CostPer1MIn: OpenAIModels[O3].CostPer1MIn,
|
||||
@@ -150,7 +155,7 @@ var OpenRouterModels = map[ModelID]Model{
|
||||
},
|
||||
OpenRouterO3Mini: {
|
||||
ID: OpenRouterO3Mini,
|
||||
Name: "OpenRouter – o3 mini",
|
||||
Name: "OpenRouter: o3 mini",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/o3-mini-high",
|
||||
CostPer1MIn: OpenAIModels[O3Mini].CostPer1MIn,
|
||||
@@ -163,7 +168,7 @@ var OpenRouterModels = map[ModelID]Model{
|
||||
},
|
||||
OpenRouterO4Mini: {
|
||||
ID: OpenRouterO4Mini,
|
||||
Name: "OpenRouter – o4 mini",
|
||||
Name: "OpenRouter: o4 mini",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/o4-mini-high",
|
||||
CostPer1MIn: OpenAIModels[O4Mini].CostPer1MIn,
|
||||
@@ -176,7 +181,7 @@ var OpenRouterModels = map[ModelID]Model{
|
||||
},
|
||||
OpenRouterGemini25Flash: {
|
||||
ID: OpenRouterGemini25Flash,
|
||||
Name: "OpenRouter – Gemini 2.5 Flash",
|
||||
Name: "OpenRouter: Gemini 2.5 Flash",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "google/gemini-2.5-flash-preview:thinking",
|
||||
CostPer1MIn: GeminiModels[Gemini25Flash].CostPer1MIn,
|
||||
@@ -188,7 +193,7 @@ var OpenRouterModels = map[ModelID]Model{
|
||||
},
|
||||
OpenRouterGemini25: {
|
||||
ID: OpenRouterGemini25,
|
||||
Name: "OpenRouter – Gemini 2.5 Pro",
|
||||
Name: "OpenRouter: Gemini 2.5 Pro",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "google/gemini-2.5-pro-preview-03-25",
|
||||
CostPer1MIn: GeminiModels[Gemini25].CostPer1MIn,
|
||||
@@ -200,7 +205,7 @@ var OpenRouterModels = map[ModelID]Model{
|
||||
},
|
||||
OpenRouterClaude35Sonnet: {
|
||||
ID: OpenRouterClaude35Sonnet,
|
||||
Name: "OpenRouter – Claude 3.5 Sonnet",
|
||||
Name: "OpenRouter: Claude 3.5 Sonnet",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "anthropic/claude-3.5-sonnet",
|
||||
CostPer1MIn: AnthropicModels[Claude35Sonnet].CostPer1MIn,
|
||||
@@ -212,7 +217,7 @@ var OpenRouterModels = map[ModelID]Model{
|
||||
},
|
||||
OpenRouterClaude3Haiku: {
|
||||
ID: OpenRouterClaude3Haiku,
|
||||
Name: "OpenRouter – Claude 3 Haiku",
|
||||
Name: "OpenRouter: Claude 3 Haiku",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "anthropic/claude-3-haiku",
|
||||
CostPer1MIn: AnthropicModels[Claude3Haiku].CostPer1MIn,
|
||||
@@ -224,7 +229,7 @@ var OpenRouterModels = map[ModelID]Model{
|
||||
},
|
||||
OpenRouterClaude37Sonnet: {
|
||||
ID: OpenRouterClaude37Sonnet,
|
||||
Name: "OpenRouter – Claude 3.7 Sonnet",
|
||||
Name: "OpenRouter: Claude 3.7 Sonnet",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "anthropic/claude-3.7-sonnet",
|
||||
CostPer1MIn: AnthropicModels[Claude37Sonnet].CostPer1MIn,
|
||||
@@ -237,7 +242,7 @@ var OpenRouterModels = map[ModelID]Model{
|
||||
},
|
||||
OpenRouterClaude35Haiku: {
|
||||
ID: OpenRouterClaude35Haiku,
|
||||
Name: "OpenRouter – Claude 3.5 Haiku",
|
||||
Name: "OpenRouter: Claude 3.5 Haiku",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "anthropic/claude-3.5-haiku",
|
||||
CostPer1MIn: AnthropicModels[Claude35Haiku].CostPer1MIn,
|
||||
@@ -249,7 +254,7 @@ var OpenRouterModels = map[ModelID]Model{
|
||||
},
|
||||
OpenRouterClaude3Opus: {
|
||||
ID: OpenRouterClaude3Opus,
|
||||
Name: "OpenRouter – Claude 3 Opus",
|
||||
Name: "OpenRouter: Claude 3 Opus",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "anthropic/claude-3-opus",
|
||||
CostPer1MIn: AnthropicModels[Claude3Opus].CostPer1MIn,
|
||||
@@ -259,4 +264,64 @@ var OpenRouterModels = map[ModelID]Model{
|
||||
ContextWindow: AnthropicModels[Claude3Opus].ContextWindow,
|
||||
DefaultMaxTokens: AnthropicModels[Claude3Opus].DefaultMaxTokens,
|
||||
},
|
||||
OpenRouterQwen235B: {
|
||||
ID: OpenRouterQwen235B,
|
||||
Name: "OpenRouter: Qwen3 235B A22B",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "qwen/qwen3-235b-a22b",
|
||||
CostPer1MIn: 0.1,
|
||||
CostPer1MInCached: 0.1,
|
||||
CostPer1MOut: 0.1,
|
||||
CostPer1MOutCached: 0.1,
|
||||
ContextWindow: 40960,
|
||||
DefaultMaxTokens: 4096,
|
||||
},
|
||||
OpenRouterQwen32B: {
|
||||
ID: OpenRouterQwen32B,
|
||||
Name: "OpenRouter: Qwen3 32B",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "qwen/qwen3-32b",
|
||||
CostPer1MIn: 0.1,
|
||||
CostPer1MInCached: 0.1,
|
||||
CostPer1MOut: 0.3,
|
||||
CostPer1MOutCached: 0.3,
|
||||
ContextWindow: 40960,
|
||||
DefaultMaxTokens: 4096,
|
||||
},
|
||||
OpenRouterQwen30B: {
|
||||
ID: OpenRouterQwen30B,
|
||||
Name: "OpenRouter: Qwen3 30B A3B",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "qwen/qwen3-30b-a3b",
|
||||
CostPer1MIn: 0.1,
|
||||
CostPer1MInCached: 0.1,
|
||||
CostPer1MOut: 0.3,
|
||||
CostPer1MOutCached: 0.3,
|
||||
ContextWindow: 40960,
|
||||
DefaultMaxTokens: 4096,
|
||||
},
|
||||
OpenRouterQwen14B: {
|
||||
ID: OpenRouterQwen14B,
|
||||
Name: "OpenRouter: Qwen3 14B",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "qwen/qwen3-14b",
|
||||
CostPer1MIn: 0.7,
|
||||
CostPer1MInCached: 0.7,
|
||||
CostPer1MOut: 0.24,
|
||||
CostPer1MOutCached: 0.24,
|
||||
ContextWindow: 40960,
|
||||
DefaultMaxTokens: 4096,
|
||||
},
|
||||
OpenRouterQwen8B: {
|
||||
ID: OpenRouterQwen8B,
|
||||
Name: "OpenRouter: Qwen3 8B",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "qwen/qwen3-8b",
|
||||
CostPer1MIn: 0.35,
|
||||
CostPer1MInCached: 0.35,
|
||||
CostPer1MOut: 0.138,
|
||||
CostPer1MOutCached: 0.138,
|
||||
ContextWindow: 128000,
|
||||
DefaultMaxTokens: 4096,
|
||||
},
|
||||
}
|
||||
|
||||
61
internal/llm/models/xai.go
Normal file
61
internal/llm/models/xai.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package models
|
||||
|
||||
const (
|
||||
ProviderXAI ModelProvider = "xai"
|
||||
|
||||
XAIGrok3Beta ModelID = "grok-3-beta"
|
||||
XAIGrok3MiniBeta ModelID = "grok-3-mini-beta"
|
||||
XAIGrok3FastBeta ModelID = "grok-3-fast-beta"
|
||||
XAiGrok3MiniFastBeta ModelID = "grok-3-mini-fast-beta"
|
||||
)
|
||||
|
||||
var XAIModels = map[ModelID]Model{
|
||||
XAIGrok3Beta: {
|
||||
ID: XAIGrok3Beta,
|
||||
Name: "Grok3 Beta",
|
||||
Provider: ProviderXAI,
|
||||
APIModel: "grok-3-beta",
|
||||
CostPer1MIn: 3.0,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOut: 15,
|
||||
CostPer1MOutCached: 0,
|
||||
ContextWindow: 131_072,
|
||||
DefaultMaxTokens: 20_000,
|
||||
},
|
||||
XAIGrok3MiniBeta: {
|
||||
ID: XAIGrok3MiniBeta,
|
||||
Name: "Grok3 Mini Beta",
|
||||
Provider: ProviderXAI,
|
||||
APIModel: "grok-3-mini-beta",
|
||||
CostPer1MIn: 0.3,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOut: 0.5,
|
||||
CostPer1MOutCached: 0,
|
||||
ContextWindow: 131_072,
|
||||
DefaultMaxTokens: 20_000,
|
||||
},
|
||||
XAIGrok3FastBeta: {
|
||||
ID: XAIGrok3FastBeta,
|
||||
Name: "Grok3 Fast Beta",
|
||||
Provider: ProviderXAI,
|
||||
APIModel: "grok-3-fast-beta",
|
||||
CostPer1MIn: 5,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOut: 25,
|
||||
CostPer1MOutCached: 0,
|
||||
ContextWindow: 131_072,
|
||||
DefaultMaxTokens: 20_000,
|
||||
},
|
||||
XAiGrok3MiniFastBeta: {
|
||||
ID: XAiGrok3MiniFastBeta,
|
||||
Name: "Grok3 Mini Fast Beta",
|
||||
Provider: ProviderXAI,
|
||||
APIModel: "grok-3-mini-fast-beta",
|
||||
CostPer1MIn: 0.6,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOut: 4.0,
|
||||
CostPer1MOutCached: 0,
|
||||
ContextWindow: 131_072,
|
||||
DefaultMaxTokens: 20_000,
|
||||
},
|
||||
}
|
||||
@@ -8,23 +8,23 @@ import (
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/llm/models"
|
||||
"github.com/opencode-ai/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
)
|
||||
|
||||
func CoderPrompt(provider models.ModelProvider) string {
|
||||
basePrompt := baseAnthropicCoderPrompt
|
||||
func PrimaryPrompt(provider models.ModelProvider) string {
|
||||
basePrompt := baseAnthropicPrimaryPrompt
|
||||
switch provider {
|
||||
case models.ProviderOpenAI:
|
||||
basePrompt = baseOpenAICoderPrompt
|
||||
basePrompt = baseOpenAIPrimaryPrompt
|
||||
}
|
||||
envInfo := getEnvironmentInfo()
|
||||
|
||||
return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
|
||||
}
|
||||
|
||||
const baseOpenAICoderPrompt = `
|
||||
const baseOpenAIPrimaryPrompt = `
|
||||
You are operating as and within the OpenCode CLI, a terminal-based agentic coding assistant built by OpenAI. It wraps OpenAI models to enable natural language interaction with a local codebase. You are expected to be precise, safe, and helpful.
|
||||
|
||||
You can:
|
||||
@@ -71,7 +71,7 @@ You MUST adhere to the following criteria when executing the task:
|
||||
- Remember the user does not see the full output of tools
|
||||
`
|
||||
|
||||
const baseAnthropicCoderPrompt = `You are OpenCode, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.
|
||||
const baseAnthropicPrimaryPrompt = `You are OpenCode, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.
|
||||
|
||||
IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames directory structure.
|
||||
|
||||
@@ -81,7 +81,7 @@ If the current working directory contains a file called OpenCode.md, it will be
|
||||
2. Recording the user's code style preferences (naming conventions, preferred libraries, etc.)
|
||||
3. Maintaining useful information about the codebase structure and organization
|
||||
|
||||
When you spend time searching for commands to typecheck, lint, build, or test, you should ask the user if it's okay to add those commands to OpenCode.md. Similarly, when learning about code style preferences or important codebase information, ask if it's okay to add that to OpenCode.md so you can remember it for next time.
|
||||
When you spend time searching for commands to typecheck, lint, build, or test, you should ask the user if it's okay to add those commands to CONTEXT.md. Similarly, when learning about code style preferences or important codebase information, ask if it's okay to add that to CONTEXT.md so you can remember it for next time.
|
||||
|
||||
# Tone and style
|
||||
You should be concise, direct, and to the point. When you run a non-trivial bash command, you should explain what the command does and why you are running it, to make sure the user understands what you are doing (this is especially important when you are running a command that will make changes to the user's system).
|
||||
@@ -103,7 +103,7 @@ assistant: 4
|
||||
|
||||
<example>
|
||||
user: is 11 a prime number?
|
||||
assistant: true
|
||||
assistant: yes
|
||||
</example>
|
||||
|
||||
<example>
|
||||
@@ -7,16 +7,16 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/llm/models"
|
||||
"github.com/opencode-ai/opencode/internal/logging"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) string {
|
||||
basePrompt := ""
|
||||
switch agentName {
|
||||
case config.AgentCoder:
|
||||
basePrompt = CoderPrompt(provider)
|
||||
case config.AgentPrimary:
|
||||
basePrompt = PrimaryPrompt(provider)
|
||||
case config.AgentTitle:
|
||||
basePrompt = TitlePrompt(provider)
|
||||
case config.AgentTask:
|
||||
@@ -25,10 +25,10 @@ func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) s
|
||||
basePrompt = "You are a helpful assistant"
|
||||
}
|
||||
|
||||
if agentName == config.AgentCoder || agentName == config.AgentTask {
|
||||
if agentName == config.AgentPrimary || agentName == config.AgentTask {
|
||||
// Add context from project-specific instruction files if they exist
|
||||
contextContent := getContextFromPaths()
|
||||
logging.Debug("Context content", "Context", contextContent)
|
||||
slog.Debug("Context content", "Context", contextContent)
|
||||
if contextContent != "" {
|
||||
return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent)
|
||||
}
|
||||
|
||||
61
internal/llm/prompt/prompt_test.go
Normal file
61
internal/llm/prompt/prompt_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package prompt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetContextFromPaths(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
lvl := new(slog.LevelVar)
|
||||
lvl.Set(slog.LevelDebug)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
_, err := config.Load(tmpDir, false, lvl)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
cfg := config.Get()
|
||||
cfg.WorkingDir = tmpDir
|
||||
cfg.ContextPaths = []string{
|
||||
"file.txt",
|
||||
"directory/",
|
||||
}
|
||||
testFiles := []string{
|
||||
"file.txt",
|
||||
"directory/file_a.txt",
|
||||
"directory/file_b.txt",
|
||||
"directory/file_c.txt",
|
||||
}
|
||||
|
||||
createTestFiles(t, tmpDir, testFiles)
|
||||
|
||||
context := getContextFromPaths()
|
||||
expectedContext := fmt.Sprintf("# From:%s/file.txt\nfile.txt: test content\n# From:%s/directory/file_a.txt\ndirectory/file_a.txt: test content\n# From:%s/directory/file_b.txt\ndirectory/file_b.txt: test content\n# From:%s/directory/file_c.txt\ndirectory/file_c.txt: test content", tmpDir, tmpDir, tmpDir, tmpDir)
|
||||
assert.Equal(t, expectedContext, context)
|
||||
}
|
||||
|
||||
func createTestFiles(t *testing.T, tmpDir string, testFiles []string) {
|
||||
t.Helper()
|
||||
for _, path := range testFiles {
|
||||
fullPath := filepath.Join(tmpDir, path)
|
||||
if path[len(path)-1] == '/' {
|
||||
err := os.MkdirAll(fullPath, 0755)
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
dir := filepath.Dir(fullPath)
|
||||
err := os.MkdirAll(dir, 0755)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(fullPath, []byte(path+": test content"), 0644)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,7 @@ package prompt
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/llm/models"
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
)
|
||||
|
||||
func TaskPrompt(_ models.ModelProvider) string {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package prompt
|
||||
|
||||
import "github.com/opencode-ai/opencode/internal/llm/models"
|
||||
import "github.com/sst/opencode/internal/llm/models"
|
||||
|
||||
func TitlePrompt(_ models.ModelProvider) string {
|
||||
return `you will generate a short title based on the first message a user begins a conversation with
|
||||
@@ -8,5 +8,6 @@ func TitlePrompt(_ models.ModelProvider) string {
|
||||
- the title should be a summary of the user's message
|
||||
- it should be one line long
|
||||
- do not use quotes or colons
|
||||
- the entire text you return will be used as the title`
|
||||
- the entire text you return will be used as the title
|
||||
- never return anything that is more than one sentence (one line) long`
|
||||
}
|
||||
|
||||
@@ -12,10 +12,12 @@ import (
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/bedrock"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/llm/tools"
|
||||
"github.com/opencode-ai/opencode/internal/logging"
|
||||
"github.com/opencode-ai/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/status"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
type anthropicOptions struct {
|
||||
@@ -70,18 +72,29 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
|
||||
Type: "ephemeral",
|
||||
}
|
||||
}
|
||||
anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))
|
||||
var contentBlocks []anthropic.ContentBlockParamUnion
|
||||
contentBlocks = append(contentBlocks, content)
|
||||
for _, binaryContent := range msg.BinaryContent() {
|
||||
base64Image := binaryContent.String(models.ProviderAnthropic)
|
||||
imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
|
||||
contentBlocks = append(contentBlocks, imageBlock)
|
||||
}
|
||||
anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(contentBlocks...))
|
||||
|
||||
case message.Assistant:
|
||||
blocks := []anthropic.ContentBlockParamUnion{}
|
||||
if msg.Content().String() != "" {
|
||||
content := anthropic.NewTextBlock(msg.Content().String())
|
||||
if cache && !a.options.disableCache {
|
||||
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
|
||||
Type: "ephemeral",
|
||||
|
||||
if msg.Content() != nil {
|
||||
content := msg.Content().String()
|
||||
if strings.TrimSpace(content) != "" {
|
||||
block := anthropic.NewTextBlock(content)
|
||||
if cache && !a.options.disableCache {
|
||||
block.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
|
||||
Type: "ephemeral",
|
||||
}
|
||||
}
|
||||
blocks = append(blocks, block)
|
||||
}
|
||||
blocks = append(blocks, content)
|
||||
}
|
||||
|
||||
for _, toolCall := range msg.ToolCalls() {
|
||||
@@ -94,7 +107,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
|
||||
}
|
||||
|
||||
if len(blocks) == 0 {
|
||||
logging.Warn("There is a message without content, investigate, this should not happen")
|
||||
slog.Warn("There is a message without content, investigate, this should not happen")
|
||||
continue
|
||||
}
|
||||
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
|
||||
@@ -196,9 +209,10 @@ func (a *anthropicClient) send(ctx context.Context, messages []message.Message,
|
||||
preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
|
||||
cfg := config.Get()
|
||||
if cfg.Debug {
|
||||
// jsonData, _ := json.Marshal(preparedMessages)
|
||||
// logging.Debug("Prepared messages", "messages", string(jsonData))
|
||||
jsonData, _ := json.Marshal(preparedMessages)
|
||||
slog.Debug("Prepared messages", "messages", string(jsonData))
|
||||
}
|
||||
|
||||
attempts := 0
|
||||
for {
|
||||
attempts++
|
||||
@@ -208,12 +222,13 @@ func (a *anthropicClient) send(ctx context.Context, messages []message.Message,
|
||||
)
|
||||
// If there is an error we are going to see if we can retry the call
|
||||
if err != nil {
|
||||
slog.Error("Error in Anthropic API call", "error", err)
|
||||
retry, after, retryErr := a.shouldRetry(attempts, err)
|
||||
if retryErr != nil {
|
||||
return nil, retryErr
|
||||
}
|
||||
if retry {
|
||||
logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
|
||||
status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
@@ -243,8 +258,8 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
|
||||
preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
|
||||
cfg := config.Get()
|
||||
if cfg.Debug {
|
||||
// jsonData, _ := json.Marshal(preparedMessages)
|
||||
// logging.Debug("Prepared messages", "messages", string(jsonData))
|
||||
jsonData, _ := json.Marshal(preparedMessages)
|
||||
slog.Debug("Prepared messages", "messages", string(jsonData))
|
||||
}
|
||||
attempts := 0
|
||||
eventChan := make(chan ProviderEvent)
|
||||
@@ -262,7 +277,7 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
|
||||
event := anthropicStream.Current()
|
||||
err := accumulatedMessage.Accumulate(event)
|
||||
if err != nil {
|
||||
logging.Warn("Error accumulating message", "error", err)
|
||||
slog.Warn("Error accumulating message", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -351,7 +366,7 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
|
||||
return
|
||||
}
|
||||
if retry {
|
||||
logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
|
||||
status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// context cancelled
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/llm/tools"
|
||||
"github.com/opencode-ai/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
)
|
||||
|
||||
type bedrockOptions struct {
|
||||
@@ -55,7 +55,7 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {
|
||||
if strings.Contains(string(opts.model.APIModel), "anthropic") {
|
||||
// Create Anthropic client with Bedrock configuration
|
||||
anthropicOpts := opts
|
||||
anthropicOpts.anthropicOptions = append(anthropicOpts.anthropicOptions,
|
||||
anthropicOpts.anthropicOptions = append(anthropicOpts.anthropicOptions,
|
||||
WithAnthropicBedrock(true),
|
||||
WithAnthropicDisableCache(),
|
||||
)
|
||||
@@ -84,7 +84,7 @@ func (b *bedrockClient) send(ctx context.Context, messages []message.Message, to
|
||||
|
||||
func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
|
||||
eventChan := make(chan ProviderEvent)
|
||||
|
||||
|
||||
if b.childProvider == nil {
|
||||
go func() {
|
||||
eventChan <- ProviderEvent{
|
||||
@@ -95,6 +95,6 @@ func (b *bedrockClient) stream(ctx context.Context, messages []message.Message,
|
||||
}()
|
||||
return eventChan
|
||||
}
|
||||
|
||||
|
||||
return b.childProvider.stream(ctx, messages, tools)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,14 +9,13 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/generative-ai-go/genai"
|
||||
"github.com/google/uuid"
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/llm/tools"
|
||||
"github.com/opencode-ai/opencode/internal/logging"
|
||||
"github.com/opencode-ai/opencode/internal/message"
|
||||
"google.golang.org/api/iterator"
|
||||
"google.golang.org/api/option"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/status"
|
||||
"google.golang.org/genai"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
type geminiOptions struct {
|
||||
@@ -39,9 +38,9 @@ func newGeminiClient(opts providerClientOptions) GeminiClient {
|
||||
o(&geminiOpts)
|
||||
}
|
||||
|
||||
client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey))
|
||||
client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
|
||||
if err != nil {
|
||||
logging.Error("Failed to create Gemini client", "error", err)
|
||||
slog.Error("Failed to create Gemini client", "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -57,27 +56,37 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case message.User:
|
||||
var parts []*genai.Part
|
||||
parts = append(parts, &genai.Part{Text: msg.Content().String()})
|
||||
for _, binaryContent := range msg.BinaryContent() {
|
||||
imageFormat := strings.Split(binaryContent.MIMEType, "/")
|
||||
parts = append(parts, &genai.Part{InlineData: &genai.Blob{
|
||||
MIMEType: imageFormat[1],
|
||||
Data: binaryContent.Data,
|
||||
}})
|
||||
}
|
||||
history = append(history, &genai.Content{
|
||||
Parts: []genai.Part{genai.Text(msg.Content().String())},
|
||||
Parts: parts,
|
||||
Role: "user",
|
||||
})
|
||||
|
||||
case message.Assistant:
|
||||
content := &genai.Content{
|
||||
Role: "model",
|
||||
Parts: []genai.Part{},
|
||||
Parts: []*genai.Part{},
|
||||
}
|
||||
|
||||
if msg.Content().String() != "" {
|
||||
content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
|
||||
content.Parts = append(content.Parts, &genai.Part{Text: msg.Content().String()})
|
||||
}
|
||||
|
||||
if len(msg.ToolCalls()) > 0 {
|
||||
for _, call := range msg.ToolCalls() {
|
||||
args, _ := parseJsonToMap(call.Input)
|
||||
content.Parts = append(content.Parts, genai.FunctionCall{
|
||||
Name: call.Name,
|
||||
Args: args,
|
||||
content.Parts = append(content.Parts, &genai.Part{
|
||||
FunctionCall: &genai.FunctionCall{
|
||||
Name: call.Name,
|
||||
Args: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -105,10 +114,14 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont
|
||||
}
|
||||
|
||||
history = append(history, &genai.Content{
|
||||
Parts: []genai.Part{genai.FunctionResponse{
|
||||
Name: toolCall.Name,
|
||||
Response: response,
|
||||
}},
|
||||
Parts: []*genai.Part{
|
||||
{
|
||||
FunctionResponse: &genai.FunctionResponse{
|
||||
Name: toolCall.Name,
|
||||
Response: response,
|
||||
},
|
||||
},
|
||||
},
|
||||
Role: "function",
|
||||
})
|
||||
}
|
||||
@@ -152,37 +165,35 @@ func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishRea
|
||||
}
|
||||
|
||||
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
|
||||
model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
|
||||
model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
|
||||
model.SystemInstruction = &genai.Content{
|
||||
Parts: []genai.Part{
|
||||
genai.Text(g.providerOptions.systemMessage),
|
||||
},
|
||||
}
|
||||
// Convert tools
|
||||
if len(tools) > 0 {
|
||||
model.Tools = g.convertTools(tools)
|
||||
}
|
||||
|
||||
// Convert messages
|
||||
geminiMessages := g.convertMessages(messages)
|
||||
|
||||
cfg := config.Get()
|
||||
if cfg.Debug {
|
||||
jsonData, _ := json.Marshal(geminiMessages)
|
||||
logging.Debug("Prepared messages", "messages", string(jsonData))
|
||||
slog.Debug("Prepared messages", "messages", string(jsonData))
|
||||
}
|
||||
|
||||
history := geminiMessages[:len(geminiMessages)-1] // All but last message
|
||||
lastMsg := geminiMessages[len(geminiMessages)-1]
|
||||
chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{
|
||||
MaxOutputTokens: int32(g.providerOptions.maxTokens),
|
||||
SystemInstruction: &genai.Content{
|
||||
Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
|
||||
},
|
||||
Tools: g.convertTools(tools),
|
||||
}, history)
|
||||
|
||||
attempts := 0
|
||||
for {
|
||||
attempts++
|
||||
var toolCalls []message.ToolCall
|
||||
chat := model.StartChat()
|
||||
chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
|
||||
|
||||
lastMsg := geminiMessages[len(geminiMessages)-1]
|
||||
|
||||
resp, err := chat.SendMessage(ctx, lastMsg.Parts...)
|
||||
var lastMsgParts []genai.Part
|
||||
for _, part := range lastMsg.Parts {
|
||||
lastMsgParts = append(lastMsgParts, *part)
|
||||
}
|
||||
resp, err := chat.SendMessage(ctx, lastMsgParts...)
|
||||
// If there is an error we are going to see if we can retry the call
|
||||
if err != nil {
|
||||
retry, after, retryErr := g.shouldRetry(attempts, err)
|
||||
@@ -190,7 +201,7 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
|
||||
return nil, retryErr
|
||||
}
|
||||
if retry {
|
||||
logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
|
||||
status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
@@ -205,15 +216,15 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
|
||||
|
||||
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
|
||||
for _, part := range resp.Candidates[0].Content.Parts {
|
||||
switch p := part.(type) {
|
||||
case genai.Text:
|
||||
content = string(p)
|
||||
case genai.FunctionCall:
|
||||
switch {
|
||||
case part.Text != "":
|
||||
content = string(part.Text)
|
||||
case part.FunctionCall != nil:
|
||||
id := "call_" + uuid.New().String()
|
||||
args, _ := json.Marshal(p.Args)
|
||||
args, _ := json.Marshal(part.FunctionCall.Args)
|
||||
toolCalls = append(toolCalls, message.ToolCall{
|
||||
ID: id,
|
||||
Name: p.Name,
|
||||
Name: part.FunctionCall.Name,
|
||||
Input: string(args),
|
||||
Type: "function",
|
||||
Finished: true,
|
||||
@@ -239,27 +250,25 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
|
||||
}
|
||||
|
||||
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
|
||||
model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
|
||||
model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
|
||||
model.SystemInstruction = &genai.Content{
|
||||
Parts: []genai.Part{
|
||||
genai.Text(g.providerOptions.systemMessage),
|
||||
},
|
||||
}
|
||||
// Convert tools
|
||||
if len(tools) > 0 {
|
||||
model.Tools = g.convertTools(tools)
|
||||
}
|
||||
|
||||
// Convert messages
|
||||
geminiMessages := g.convertMessages(messages)
|
||||
|
||||
cfg := config.Get()
|
||||
if cfg.Debug {
|
||||
jsonData, _ := json.Marshal(geminiMessages)
|
||||
logging.Debug("Prepared messages", "messages", string(jsonData))
|
||||
slog.Debug("Prepared messages", "messages", string(jsonData))
|
||||
}
|
||||
|
||||
history := geminiMessages[:len(geminiMessages)-1] // All but last message
|
||||
lastMsg := geminiMessages[len(geminiMessages)-1]
|
||||
chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{
|
||||
MaxOutputTokens: int32(g.providerOptions.maxTokens),
|
||||
SystemInstruction: &genai.Content{
|
||||
Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
|
||||
},
|
||||
Tools: g.convertTools(tools),
|
||||
}, history)
|
||||
|
||||
attempts := 0
|
||||
eventChan := make(chan ProviderEvent)
|
||||
|
||||
@@ -268,11 +277,6 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
|
||||
|
||||
for {
|
||||
attempts++
|
||||
chat := model.StartChat()
|
||||
chat.History = geminiMessages[:len(geminiMessages)-1]
|
||||
lastMsg := geminiMessages[len(geminiMessages)-1]
|
||||
|
||||
iter := chat.SendMessageStream(ctx, lastMsg.Parts...)
|
||||
|
||||
currentContent := ""
|
||||
toolCalls := []message.ToolCall{}
|
||||
@@ -280,11 +284,12 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
|
||||
|
||||
eventChan <- ProviderEvent{Type: EventContentStart}
|
||||
|
||||
for {
|
||||
resp, err := iter.Next()
|
||||
if err == iterator.Done {
|
||||
break
|
||||
}
|
||||
var lastMsgParts []genai.Part
|
||||
|
||||
for _, part := range lastMsg.Parts {
|
||||
lastMsgParts = append(lastMsgParts, *part)
|
||||
}
|
||||
for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
|
||||
if err != nil {
|
||||
retry, after, retryErr := g.shouldRetry(attempts, err)
|
||||
if retryErr != nil {
|
||||
@@ -292,7 +297,7 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
|
||||
return
|
||||
}
|
||||
if retry {
|
||||
logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
|
||||
status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if ctx.Err() != nil {
|
||||
@@ -313,9 +318,9 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
|
||||
|
||||
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
|
||||
for _, part := range resp.Candidates[0].Content.Parts {
|
||||
switch p := part.(type) {
|
||||
case genai.Text:
|
||||
delta := string(p)
|
||||
switch {
|
||||
case part.Text != "":
|
||||
delta := string(part.Text)
|
||||
if delta != "" {
|
||||
eventChan <- ProviderEvent{
|
||||
Type: EventContentDelta,
|
||||
@@ -323,12 +328,12 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
|
||||
}
|
||||
currentContent += delta
|
||||
}
|
||||
case genai.FunctionCall:
|
||||
case part.FunctionCall != nil:
|
||||
id := "call_" + uuid.New().String()
|
||||
args, _ := json.Marshal(p.Args)
|
||||
args, _ := json.Marshal(part.FunctionCall.Args)
|
||||
newCall := message.ToolCall{
|
||||
ID: id,
|
||||
Name: p.Name,
|
||||
Name: part.FunctionCall.Name,
|
||||
Input: string(args),
|
||||
Type: "function",
|
||||
Finished: true,
|
||||
@@ -416,12 +421,12 @@ func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.
|
||||
|
||||
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
|
||||
for _, part := range resp.Candidates[0].Content.Parts {
|
||||
if funcCall, ok := part.(genai.FunctionCall); ok {
|
||||
if part.FunctionCall != nil {
|
||||
id := "call_" + uuid.New().String()
|
||||
args, _ := json.Marshal(funcCall.Args)
|
||||
args, _ := json.Marshal(part.FunctionCall.Args)
|
||||
toolCalls = append(toolCalls, message.ToolCall{
|
||||
ID: id,
|
||||
Name: funcCall.Name,
|
||||
Name: part.FunctionCall.Name,
|
||||
Input: string(args),
|
||||
Type: "function",
|
||||
})
|
||||
|
||||
@@ -11,10 +11,12 @@ import (
|
||||
"github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/option"
|
||||
"github.com/openai/openai-go/shared"
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/llm/tools"
|
||||
"github.com/opencode-ai/opencode/internal/logging"
|
||||
"github.com/opencode-ai/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/status"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
type openaiOptions struct {
|
||||
@@ -71,7 +73,17 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case message.User:
|
||||
openaiMessages = append(openaiMessages, openai.UserMessage(msg.Content().String()))
|
||||
var content []openai.ChatCompletionContentPartUnionParam
|
||||
textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
|
||||
content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
|
||||
for _, binaryContent := range msg.BinaryContent() {
|
||||
imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(models.ProviderOpenAI)}
|
||||
imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
|
||||
|
||||
content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
|
||||
}
|
||||
|
||||
openaiMessages = append(openaiMessages, openai.UserMessage(content))
|
||||
|
||||
case message.Assistant:
|
||||
assistantMsg := openai.ChatCompletionAssistantMessageParam{
|
||||
@@ -171,6 +183,14 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
|
||||
params.MaxTokens = openai.Int(o.providerOptions.maxTokens)
|
||||
}
|
||||
|
||||
if o.providerOptions.model.Provider == models.ProviderOpenRouter {
|
||||
params.WithExtraFields(map[string]any{
|
||||
"provider": map[string]any{
|
||||
"require_parameters": true,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
@@ -179,7 +199,7 @@ func (o *openaiClient) send(ctx context.Context, messages []message.Message, too
|
||||
cfg := config.Get()
|
||||
if cfg.Debug {
|
||||
jsonData, _ := json.Marshal(params)
|
||||
logging.Debug("Prepared messages", "messages", string(jsonData))
|
||||
slog.Debug("Prepared messages", "messages", string(jsonData))
|
||||
}
|
||||
attempts := 0
|
||||
for {
|
||||
@@ -195,7 +215,7 @@ func (o *openaiClient) send(ctx context.Context, messages []message.Message, too
|
||||
return nil, retryErr
|
||||
}
|
||||
if retry {
|
||||
logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
|
||||
status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
@@ -236,7 +256,7 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
|
||||
cfg := config.Get()
|
||||
if cfg.Debug {
|
||||
jsonData, _ := json.Marshal(params)
|
||||
logging.Debug("Prepared messages", "messages", string(jsonData))
|
||||
slog.Debug("Prepared messages", "messages", string(jsonData))
|
||||
}
|
||||
|
||||
attempts := 0
|
||||
@@ -258,15 +278,6 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
|
||||
chunk := openaiStream.Current()
|
||||
acc.AddChunk(chunk)
|
||||
|
||||
if tool, ok := acc.JustFinishedToolCall(); ok {
|
||||
toolCalls = append(toolCalls, message.ToolCall{
|
||||
ID: tool.Id,
|
||||
Name: tool.Name,
|
||||
Input: tool.Arguments,
|
||||
Type: "function",
|
||||
})
|
||||
}
|
||||
|
||||
for _, choice := range chunk.Choices {
|
||||
if choice.Delta.Content != "" {
|
||||
eventChan <- ProviderEvent{
|
||||
@@ -282,7 +293,9 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
|
||||
if err == nil || errors.Is(err, io.EOF) {
|
||||
// Stream completed successfully
|
||||
finishReason := o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason))
|
||||
|
||||
if len(acc.ChatCompletion.Choices[0].Message.ToolCalls) > 0 {
|
||||
toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
|
||||
}
|
||||
if len(toolCalls) > 0 {
|
||||
finishReason = message.FinishReasonToolUse
|
||||
}
|
||||
@@ -308,7 +321,7 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
|
||||
return
|
||||
}
|
||||
if retry {
|
||||
logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
|
||||
status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// context cancelled
|
||||
@@ -414,7 +427,7 @@ func WithReasoningEffort(effort string) OpenAIOption {
|
||||
case "low", "medium", "high":
|
||||
defaultReasoningEffort = effort
|
||||
default:
|
||||
logging.Warn("Invalid reasoning effort, using default: medium")
|
||||
slog.Warn("Invalid reasoning effort, using default: medium")
|
||||
}
|
||||
options.reasoningEffort = defaultReasoningEffort
|
||||
}
|
||||
|
||||
@@ -4,9 +4,10 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/llm/models"
|
||||
"github.com/opencode-ai/opencode/internal/llm/tools"
|
||||
"github.com/opencode-ai/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
type EventType string
|
||||
@@ -55,6 +56,8 @@ type Provider interface {
|
||||
StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
|
||||
|
||||
Model() models.Model
|
||||
|
||||
MaxTokens() int64
|
||||
}
|
||||
|
||||
type providerClientOptions struct {
|
||||
@@ -132,6 +135,15 @@ func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption
|
||||
options: clientOptions,
|
||||
client: newOpenAIClient(clientOptions),
|
||||
}, nil
|
||||
case models.ProviderXAI:
|
||||
clientOptions.openaiOptions = append(clientOptions.openaiOptions,
|
||||
WithOpenAIBaseURL("https://api.x.ai/v1"),
|
||||
)
|
||||
return &baseProvider[OpenAIClient]{
|
||||
options: clientOptions,
|
||||
client: newOpenAIClient(clientOptions),
|
||||
}, nil
|
||||
|
||||
case models.ProviderMock:
|
||||
// TODO: implement mock client for test
|
||||
panic("not implemented")
|
||||
@@ -152,16 +164,55 @@ func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []m
|
||||
|
||||
func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
|
||||
messages = p.cleanMessages(messages)
|
||||
return p.client.send(ctx, messages, tools)
|
||||
response, err := p.client.send(ctx, messages, tools)
|
||||
if err == nil && response != nil {
|
||||
slog.Debug("API request token usage",
|
||||
"model", p.options.model.Name,
|
||||
"input_tokens", response.Usage.InputTokens,
|
||||
"output_tokens", response.Usage.OutputTokens,
|
||||
"cache_creation_tokens", response.Usage.CacheCreationTokens,
|
||||
"cache_read_tokens", response.Usage.CacheReadTokens,
|
||||
"total_tokens", response.Usage.InputTokens+response.Usage.OutputTokens)
|
||||
}
|
||||
return response, err
|
||||
}
|
||||
|
||||
func (p *baseProvider[C]) Model() models.Model {
|
||||
return p.options.model
|
||||
}
|
||||
|
||||
func (p *baseProvider[C]) MaxTokens() int64 {
|
||||
return p.options.maxTokens
|
||||
}
|
||||
|
||||
func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
|
||||
messages = p.cleanMessages(messages)
|
||||
return p.client.stream(ctx, messages, tools)
|
||||
eventChan := p.client.stream(ctx, messages, tools)
|
||||
|
||||
// Create a new channel to intercept events
|
||||
wrappedChan := make(chan ProviderEvent)
|
||||
|
||||
go func() {
|
||||
defer close(wrappedChan)
|
||||
|
||||
for event := range eventChan {
|
||||
// Pass the event through
|
||||
wrappedChan <- event
|
||||
|
||||
// Log token usage when we get the complete event
|
||||
if event.Type == EventComplete && event.Response != nil {
|
||||
slog.Debug("API streaming request token usage",
|
||||
"model", p.options.model.Name,
|
||||
"input_tokens", event.Response.Usage.InputTokens,
|
||||
"output_tokens", event.Response.Usage.OutputTokens,
|
||||
"cache_creation_tokens", event.Response.Usage.CacheCreationTokens,
|
||||
"cache_read_tokens", event.Response.Usage.CacheReadTokens,
|
||||
"total_tokens", event.Response.Usage.InputTokens+event.Response.Usage.OutputTokens)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return wrappedChan
|
||||
}
|
||||
|
||||
func WithAPIKey(apiKey string) ProviderClientOption {
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/llm/tools/shell"
|
||||
"github.com/opencode-ai/opencode/internal/permission"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/llm/tools/shell"
|
||||
"github.com/sst/opencode/internal/permission"
|
||||
)
|
||||
|
||||
type BashParams struct {
|
||||
@@ -268,6 +268,7 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
|
||||
}
|
||||
if !isSafeReadOnly {
|
||||
p := b.permissions.Request(
|
||||
ctx,
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: config.WorkingDirectory(),
|
||||
|
||||
@@ -9,12 +9,12 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/diff"
|
||||
"github.com/opencode-ai/opencode/internal/history"
|
||||
"github.com/opencode-ai/opencode/internal/logging"
|
||||
"github.com/opencode-ai/opencode/internal/lsp"
|
||||
"github.com/opencode-ai/opencode/internal/permission"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/diff"
|
||||
"github.com/sst/opencode/internal/history"
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/permission"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
type EditParams struct {
|
||||
@@ -37,7 +37,7 @@ type EditResponseMetadata struct {
|
||||
type editTool struct {
|
||||
lspClients map[string]*lsp.Client
|
||||
permissions permission.Service
|
||||
files history.Service
|
||||
history history.Service
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -95,7 +95,7 @@ func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Servi
|
||||
return &editTool{
|
||||
lspClients: lspClients,
|
||||
permissions: permissions,
|
||||
files: files,
|
||||
history: files,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -202,6 +202,7 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string)
|
||||
permissionPath = rootDir
|
||||
}
|
||||
p := e.permissions.Request(
|
||||
ctx,
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: permissionPath,
|
||||
@@ -224,17 +225,17 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string)
|
||||
}
|
||||
|
||||
// File can't be in the history so we create a new file history
|
||||
_, err = e.files.Create(ctx, sessionID, filePath, "")
|
||||
_, err = e.history.Create(ctx, sessionID, filePath, "")
|
||||
if err != nil {
|
||||
// Log error but don't fail the operation
|
||||
return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
|
||||
}
|
||||
|
||||
// Add the new content to the file history
|
||||
_, err = e.files.CreateVersion(ctx, sessionID, filePath, content)
|
||||
_, err = e.history.CreateVersion(ctx, sessionID, filePath, content)
|
||||
if err != nil {
|
||||
// Log error but don't fail the operation
|
||||
logging.Debug("Error creating file history version", "error", err)
|
||||
slog.Debug("Error creating file history version", "error", err)
|
||||
}
|
||||
|
||||
recordFileWrite(filePath)
|
||||
@@ -313,6 +314,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string
|
||||
permissionPath = rootDir
|
||||
}
|
||||
p := e.permissions.Request(
|
||||
ctx,
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: permissionPath,
|
||||
@@ -335,9 +337,9 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string
|
||||
}
|
||||
|
||||
// Check if file exists in history
|
||||
file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID)
|
||||
file, err := e.history.GetLatestByPathAndSession(ctx, filePath, sessionID)
|
||||
if err != nil {
|
||||
_, err = e.files.Create(ctx, sessionID, filePath, oldContent)
|
||||
_, err = e.history.Create(ctx, sessionID, filePath, oldContent)
|
||||
if err != nil {
|
||||
// Log error but don't fail the operation
|
||||
return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
|
||||
@@ -345,15 +347,15 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string
|
||||
}
|
||||
if file.Content != oldContent {
|
||||
// User Manually changed the content store an intermediate version
|
||||
_, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent)
|
||||
_, err = e.history.CreateVersion(ctx, sessionID, filePath, oldContent)
|
||||
if err != nil {
|
||||
logging.Debug("Error creating file history version", "error", err)
|
||||
slog.Debug("Error creating file history version", "error", err)
|
||||
}
|
||||
}
|
||||
// Store the new version
|
||||
_, err = e.files.CreateVersion(ctx, sessionID, filePath, "")
|
||||
_, err = e.history.CreateVersion(ctx, sessionID, filePath, "")
|
||||
if err != nil {
|
||||
logging.Debug("Error creating file history version", "error", err)
|
||||
slog.Debug("Error creating file history version", "error", err)
|
||||
}
|
||||
|
||||
recordFileWrite(filePath)
|
||||
@@ -433,6 +435,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
|
||||
permissionPath = rootDir
|
||||
}
|
||||
p := e.permissions.Request(
|
||||
ctx,
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: permissionPath,
|
||||
@@ -455,9 +458,9 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
|
||||
}
|
||||
|
||||
// Check if file exists in history
|
||||
file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID)
|
||||
file, err := e.history.GetLatestByPathAndSession(ctx, filePath, sessionID)
|
||||
if err != nil {
|
||||
_, err = e.files.Create(ctx, sessionID, filePath, oldContent)
|
||||
_, err = e.history.Create(ctx, sessionID, filePath, oldContent)
|
||||
if err != nil {
|
||||
// Log error but don't fail the operation
|
||||
return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
|
||||
@@ -465,15 +468,15 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
|
||||
}
|
||||
if file.Content != oldContent {
|
||||
// User Manually changed the content store an intermediate version
|
||||
_, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent)
|
||||
_, err = e.history.CreateVersion(ctx, sessionID, filePath, oldContent)
|
||||
if err != nil {
|
||||
logging.Debug("Error creating file history version", "error", err)
|
||||
slog.Debug("Error creating file history version", "error", err)
|
||||
}
|
||||
}
|
||||
// Store the new version
|
||||
_, err = e.files.CreateVersion(ctx, sessionID, filePath, newContent)
|
||||
_, err = e.history.CreateVersion(ctx, sessionID, filePath, newContent)
|
||||
if err != nil {
|
||||
logging.Debug("Error creating file history version", "error", err)
|
||||
slog.Debug("Error creating file history version", "error", err)
|
||||
}
|
||||
|
||||
recordFileWrite(filePath)
|
||||
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
|
||||
md "github.com/JohannesKaufmann/html-to-markdown"
|
||||
"github.com/PuerkitoBio/goquery"
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/permission"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/permission"
|
||||
)
|
||||
|
||||
type FetchParams struct {
|
||||
@@ -122,6 +122,7 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
|
||||
}
|
||||
|
||||
p := t.permissions.Request(
|
||||
ctx,
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: config.WorkingDirectory(),
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/bmatcuk/doublestar/v4"
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
)
|
||||
|
||||
type GrepParams struct {
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
)
|
||||
|
||||
type LSParams struct {
|
||||
|
||||
@@ -83,19 +83,19 @@ func TestLsTool_Run(t *testing.T) {
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// Check that visible directories and files are included
|
||||
assert.Contains(t, response.Content, "dir1")
|
||||
assert.Contains(t, response.Content, "dir2")
|
||||
assert.Contains(t, response.Content, "dir3")
|
||||
assert.Contains(t, response.Content, "file1.txt")
|
||||
assert.Contains(t, response.Content, "file2.txt")
|
||||
|
||||
|
||||
// Check that hidden files and directories are not included
|
||||
assert.NotContains(t, response.Content, ".hidden_dir")
|
||||
assert.NotContains(t, response.Content, ".hidden_file.txt")
|
||||
assert.NotContains(t, response.Content, ".hidden_root_file.txt")
|
||||
|
||||
|
||||
// Check that __pycache__ is not included
|
||||
assert.NotContains(t, response.Content, "__pycache__")
|
||||
})
|
||||
@@ -122,7 +122,7 @@ func TestLsTool_Run(t *testing.T) {
|
||||
t.Run("handles empty path parameter", func(t *testing.T) {
|
||||
// For this test, we need to mock the config.WorkingDirectory function
|
||||
// Since we can't easily do that, we'll just check that the response doesn't contain an error message
|
||||
|
||||
|
||||
tool := NewLsTool()
|
||||
params := LSParams{
|
||||
Path: "",
|
||||
@@ -138,7 +138,7 @@ func TestLsTool_Run(t *testing.T) {
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// The response should either contain a valid directory listing or an error
|
||||
// We'll just check that it's not empty
|
||||
assert.NotEmpty(t, response.Content)
|
||||
@@ -173,11 +173,11 @@ func TestLsTool_Run(t *testing.T) {
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// The output format is a tree, so we need to check for specific patterns
|
||||
// Check that file1.txt is not directly mentioned
|
||||
assert.NotContains(t, response.Content, "- file1.txt")
|
||||
|
||||
|
||||
// Check that dir1/ is not directly mentioned
|
||||
assert.NotContains(t, response.Content, "- dir1/")
|
||||
})
|
||||
@@ -189,12 +189,12 @@ func TestLsTool_Run(t *testing.T) {
|
||||
defer func() {
|
||||
os.Chdir(origWd)
|
||||
}()
|
||||
|
||||
|
||||
// Change to a directory above the temp directory
|
||||
parentDir := filepath.Dir(tempDir)
|
||||
err = os.Chdir(parentDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
tool := NewLsTool()
|
||||
params := LSParams{
|
||||
Path: filepath.Base(tempDir),
|
||||
@@ -210,7 +210,7 @@ func TestLsTool_Run(t *testing.T) {
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// Should list the temp directory contents
|
||||
assert.Contains(t, response.Content, "dir1")
|
||||
assert.Contains(t, response.Content, "file1.txt")
|
||||
@@ -291,22 +291,22 @@ func TestCreateFileTree(t *testing.T) {
|
||||
}
|
||||
|
||||
tree := createFileTree(paths)
|
||||
|
||||
|
||||
// Check the structure of the tree
|
||||
assert.Len(t, tree, 1) // Should have one root node
|
||||
|
||||
|
||||
// Check the root node
|
||||
rootNode := tree[0]
|
||||
assert.Equal(t, "path", rootNode.Name)
|
||||
assert.Equal(t, "directory", rootNode.Type)
|
||||
assert.Len(t, rootNode.Children, 1)
|
||||
|
||||
|
||||
// Check the "to" node
|
||||
toNode := rootNode.Children[0]
|
||||
assert.Equal(t, "to", toNode.Name)
|
||||
assert.Equal(t, "directory", toNode.Type)
|
||||
assert.Len(t, toNode.Children, 3) // file1.txt, dir1, dir2
|
||||
|
||||
|
||||
// Find the dir1 node
|
||||
var dir1Node *TreeNode
|
||||
for _, child := range toNode.Children {
|
||||
@@ -315,7 +315,7 @@ func TestCreateFileTree(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
require.NotNil(t, dir1Node)
|
||||
assert.Equal(t, "directory", dir1Node.Type)
|
||||
assert.Len(t, dir1Node.Children, 2) // file2.txt and subdir
|
||||
@@ -354,9 +354,9 @@ func TestPrintTree(t *testing.T) {
|
||||
Type: "file",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
result := printTree(tree, "/root")
|
||||
|
||||
|
||||
// Check the output format
|
||||
assert.Contains(t, result, "- /root/")
|
||||
assert.Contains(t, result, " - dir1/")
|
||||
@@ -405,7 +405,7 @@ func TestListDirectory(t *testing.T) {
|
||||
files, truncated, err := listDirectory(tempDir, []string{}, 1000)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, truncated)
|
||||
|
||||
|
||||
// Check that visible files and directories are included
|
||||
containsPath := func(paths []string, target string) bool {
|
||||
targetPath := filepath.Join(tempDir, target)
|
||||
@@ -416,12 +416,12 @@ func TestListDirectory(t *testing.T) {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
assert.True(t, containsPath(files, "dir1"))
|
||||
assert.True(t, containsPath(files, "file1.txt"))
|
||||
assert.True(t, containsPath(files, "file2.txt"))
|
||||
assert.True(t, containsPath(files, "dir1/file3.txt"))
|
||||
|
||||
|
||||
// Check that hidden files and directories are not included
|
||||
assert.False(t, containsPath(files, ".hidden_dir"))
|
||||
assert.False(t, containsPath(files, ".hidden_file.txt"))
|
||||
@@ -438,12 +438,12 @@ func TestListDirectory(t *testing.T) {
|
||||
files, truncated, err := listDirectory(tempDir, []string{"*.txt"}, 1000)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, truncated)
|
||||
|
||||
|
||||
// Check that no .txt files are included
|
||||
for _, file := range files {
|
||||
assert.False(t, strings.HasSuffix(file, ".txt"), "Found .txt file: %s", file)
|
||||
}
|
||||
|
||||
|
||||
// But directories should still be included
|
||||
containsDir := false
|
||||
for _, file := range files {
|
||||
@@ -454,4 +454,4 @@ func TestListDirectory(t *testing.T) {
|
||||
}
|
||||
assert.True(t, containsDir)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
350
internal/llm/tools/lsp_code_action.go
Normal file
350
internal/llm/tools/lsp_code_action.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
"github.com/sst/opencode/internal/lsp/util"
|
||||
)
|
||||
|
||||
type CodeActionParams struct {
|
||||
FilePath string `json:"file_path"`
|
||||
Line int `json:"line"`
|
||||
Column int `json:"column"`
|
||||
EndLine int `json:"end_line,omitempty"`
|
||||
EndColumn int `json:"end_column,omitempty"`
|
||||
ActionID int `json:"action_id,omitempty"`
|
||||
LspName string `json:"lsp_name,omitempty"`
|
||||
}
|
||||
|
||||
type codeActionTool struct {
|
||||
lspClients map[string]*lsp.Client
|
||||
}
|
||||
|
||||
const (
|
||||
CodeActionToolName = "codeAction"
|
||||
codeActionDescription = `Get available code actions at a specific position or range in a file.
|
||||
WHEN TO USE THIS TOOL:
|
||||
- Use when you need to find available fixes or refactorings for code issues
|
||||
- Helpful for resolving errors, warnings, or improving code quality
|
||||
- Great for discovering automated code transformations
|
||||
|
||||
HOW TO USE:
|
||||
- Provide the path to the file containing the code
|
||||
- Specify the line number (1-based) where the action should be applied
|
||||
- Specify the column number (1-based) where the action should be applied
|
||||
- Optionally specify end_line and end_column to define a range
|
||||
- Results show available code actions with their titles and kinds
|
||||
|
||||
TO EXECUTE A CODE ACTION:
|
||||
- After getting the list of available actions, call the tool again with the same parameters
|
||||
- Add action_id parameter with the number of the action you want to execute (e.g., 1 for the first action)
|
||||
- Add lsp_name parameter with the name of the LSP server that provided the action
|
||||
|
||||
FEATURES:
|
||||
- Finds quick fixes for errors and warnings
|
||||
- Discovers available refactorings
|
||||
- Shows code organization actions
|
||||
- Returns detailed information about each action
|
||||
- Can execute selected code actions
|
||||
|
||||
LIMITATIONS:
|
||||
- Requires a functioning LSP server for the file type
|
||||
- May not work for all code issues depending on LSP capabilities
|
||||
- Results depend on the accuracy of the LSP server
|
||||
|
||||
TIPS:
|
||||
- Use in conjunction with Diagnostics tool to find issues that can be fixed
|
||||
- First call without action_id to see available actions, then call again with action_id to execute
|
||||
`
|
||||
)
|
||||
|
||||
func NewCodeActionTool(lspClients map[string]*lsp.Client) BaseTool {
|
||||
return &codeActionTool{
|
||||
lspClients,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *codeActionTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: CodeActionToolName,
|
||||
Description: codeActionDescription,
|
||||
Parameters: map[string]any{
|
||||
"file_path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The path to the file containing the code",
|
||||
},
|
||||
"line": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "The line number (1-based) where the action should be applied",
|
||||
},
|
||||
"column": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "The column number (1-based) where the action should be applied",
|
||||
},
|
||||
"end_line": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "The ending line number (1-based) for a range (optional)",
|
||||
},
|
||||
"end_column": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "The ending column number (1-based) for a range (optional)",
|
||||
},
|
||||
"action_id": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "The ID of the code action to execute (optional)",
|
||||
},
|
||||
"lsp_name": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The name of the LSP server that provided the action (optional)",
|
||||
},
|
||||
},
|
||||
Required: []string{"file_path", "line", "column"},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *codeActionTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params CodeActionParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||
}
|
||||
|
||||
lsps := b.lspClients
|
||||
|
||||
if len(lsps) == 0 {
|
||||
return NewTextResponse("\nLSP clients are still initializing. Code actions will be available once they're ready.\n"), nil
|
||||
}
|
||||
|
||||
// Ensure file is open in LSP
|
||||
notifyLspOpenFile(ctx, params.FilePath, lsps)
|
||||
|
||||
// Convert 1-based line/column to 0-based for LSP protocol
|
||||
line := max(0, params.Line-1)
|
||||
column := max(0, params.Column-1)
|
||||
|
||||
// Handle optional end line/column
|
||||
endLine := line
|
||||
endColumn := column
|
||||
if params.EndLine > 0 {
|
||||
endLine = max(0, params.EndLine-1)
|
||||
}
|
||||
if params.EndColumn > 0 {
|
||||
endColumn = max(0, params.EndColumn-1)
|
||||
}
|
||||
|
||||
// Check if we're executing a specific action
|
||||
if params.ActionID > 0 && params.LspName != "" {
|
||||
return executeCodeAction(ctx, params.FilePath, line, column, endLine, endColumn, params.ActionID, params.LspName, lsps)
|
||||
}
|
||||
|
||||
// Otherwise, just list available actions
|
||||
output := getCodeActions(ctx, params.FilePath, line, column, endLine, endColumn, lsps)
|
||||
return NewTextResponse(output), nil
|
||||
}
|
||||
|
||||
func getCodeActions(ctx context.Context, filePath string, line, column, endLine, endColumn int, lsps map[string]*lsp.Client) string {
|
||||
var results []string
|
||||
|
||||
for lspName, client := range lsps {
|
||||
// Create code action params
|
||||
uri := fmt.Sprintf("file://%s", filePath)
|
||||
codeActionParams := protocol.CodeActionParams{
|
||||
TextDocument: protocol.TextDocumentIdentifier{
|
||||
URI: protocol.DocumentUri(uri),
|
||||
},
|
||||
Range: protocol.Range{
|
||||
Start: protocol.Position{
|
||||
Line: uint32(line),
|
||||
Character: uint32(column),
|
||||
},
|
||||
End: protocol.Position{
|
||||
Line: uint32(endLine),
|
||||
Character: uint32(endColumn),
|
||||
},
|
||||
},
|
||||
Context: protocol.CodeActionContext{
|
||||
// Request all kinds of code actions
|
||||
Only: []protocol.CodeActionKind{
|
||||
protocol.QuickFix,
|
||||
protocol.Refactor,
|
||||
protocol.RefactorExtract,
|
||||
protocol.RefactorInline,
|
||||
protocol.RefactorRewrite,
|
||||
protocol.Source,
|
||||
protocol.SourceOrganizeImports,
|
||||
protocol.SourceFixAll,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Get code actions
|
||||
codeActions, err := client.CodeAction(ctx, codeActionParams)
|
||||
if err != nil {
|
||||
results = append(results, fmt.Sprintf("Error from %s: %s", lspName, err))
|
||||
continue
|
||||
}
|
||||
|
||||
if len(codeActions) == 0 {
|
||||
results = append(results, fmt.Sprintf("No code actions found by %s", lspName))
|
||||
continue
|
||||
}
|
||||
|
||||
// Format the code actions
|
||||
results = append(results, fmt.Sprintf("Code actions found by %s:", lspName))
|
||||
for i, action := range codeActions {
|
||||
actionInfo := formatCodeAction(action, i+1)
|
||||
results = append(results, actionInfo)
|
||||
}
|
||||
}
|
||||
|
||||
if len(results) == 0 {
|
||||
return "No code actions found at the specified position."
|
||||
}
|
||||
|
||||
return strings.Join(results, "\n")
|
||||
}
|
||||
|
||||
func formatCodeAction(action protocol.Or_Result_textDocument_codeAction_Item0_Elem, index int) string {
|
||||
switch v := action.Value.(type) {
|
||||
case protocol.CodeAction:
|
||||
kind := "Unknown"
|
||||
if v.Kind != "" {
|
||||
kind = string(v.Kind)
|
||||
}
|
||||
|
||||
var details []string
|
||||
|
||||
// Add edit information if available
|
||||
if v.Edit != nil {
|
||||
numChanges := 0
|
||||
if v.Edit.Changes != nil {
|
||||
numChanges = len(v.Edit.Changes)
|
||||
}
|
||||
if v.Edit.DocumentChanges != nil {
|
||||
numChanges = len(v.Edit.DocumentChanges)
|
||||
}
|
||||
details = append(details, fmt.Sprintf("Edits: %d changes", numChanges))
|
||||
}
|
||||
|
||||
// Add command information if available
|
||||
if v.Command != nil {
|
||||
details = append(details, fmt.Sprintf("Command: %s", v.Command.Title))
|
||||
}
|
||||
|
||||
// Add diagnostics information if available
|
||||
if v.Diagnostics != nil && len(v.Diagnostics) > 0 {
|
||||
details = append(details, fmt.Sprintf("Fixes: %d diagnostics", len(v.Diagnostics)))
|
||||
}
|
||||
|
||||
detailsStr := ""
|
||||
if len(details) > 0 {
|
||||
detailsStr = " (" + strings.Join(details, ", ") + ")"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(" %d. %s [%s]%s", index, v.Title, kind, detailsStr)
|
||||
|
||||
case protocol.Command:
|
||||
return fmt.Sprintf(" %d. %s [Command]", index, v.Title)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(" %d. Unknown code action type", index)
|
||||
}
|
||||
|
||||
func executeCodeAction(ctx context.Context, filePath string, line, column, endLine, endColumn, actionID int, lspName string, lsps map[string]*lsp.Client) (ToolResponse, error) {
|
||||
client, ok := lsps[lspName]
|
||||
if !ok {
|
||||
return NewTextErrorResponse(fmt.Sprintf("LSP server '%s' not found", lspName)), nil
|
||||
}
|
||||
|
||||
// Create code action params
|
||||
uri := fmt.Sprintf("file://%s", filePath)
|
||||
codeActionParams := protocol.CodeActionParams{
|
||||
TextDocument: protocol.TextDocumentIdentifier{
|
||||
URI: protocol.DocumentUri(uri),
|
||||
},
|
||||
Range: protocol.Range{
|
||||
Start: protocol.Position{
|
||||
Line: uint32(line),
|
||||
Character: uint32(column),
|
||||
},
|
||||
End: protocol.Position{
|
||||
Line: uint32(endLine),
|
||||
Character: uint32(endColumn),
|
||||
},
|
||||
},
|
||||
Context: protocol.CodeActionContext{
|
||||
// Request all kinds of code actions
|
||||
Only: []protocol.CodeActionKind{
|
||||
protocol.QuickFix,
|
||||
protocol.Refactor,
|
||||
protocol.RefactorExtract,
|
||||
protocol.RefactorInline,
|
||||
protocol.RefactorRewrite,
|
||||
protocol.Source,
|
||||
protocol.SourceOrganizeImports,
|
||||
protocol.SourceFixAll,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Get code actions
|
||||
codeActions, err := client.CodeAction(ctx, codeActionParams)
|
||||
if err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("Error getting code actions: %s", err)), nil
|
||||
}
|
||||
|
||||
if len(codeActions) == 0 {
|
||||
return NewTextErrorResponse("No code actions found"), nil
|
||||
}
|
||||
|
||||
// Check if the requested action ID is valid
|
||||
if actionID < 1 || actionID > len(codeActions) {
|
||||
return NewTextErrorResponse(fmt.Sprintf("Invalid action ID: %d. Available actions: 1-%d", actionID, len(codeActions))), nil
|
||||
}
|
||||
|
||||
// Get the selected action (adjust for 0-based index)
|
||||
selectedAction := codeActions[actionID-1]
|
||||
|
||||
// Execute the action based on its type
|
||||
switch v := selectedAction.Value.(type) {
|
||||
case protocol.CodeAction:
|
||||
// Apply workspace edit if available
|
||||
if v.Edit != nil {
|
||||
err := util.ApplyWorkspaceEdit(*v.Edit)
|
||||
if err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("Error applying edit: %s", err)), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Execute command if available
|
||||
if v.Command != nil {
|
||||
_, err := client.ExecuteCommand(ctx, protocol.ExecuteCommandParams{
|
||||
Command: v.Command.Command,
|
||||
Arguments: v.Command.Arguments,
|
||||
})
|
||||
if err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("Error executing command: %s", err)), nil
|
||||
}
|
||||
}
|
||||
|
||||
return NewTextResponse(fmt.Sprintf("Successfully executed code action: %s", v.Title)), nil
|
||||
|
||||
case protocol.Command:
|
||||
// Execute the command
|
||||
_, err := client.ExecuteCommand(ctx, protocol.ExecuteCommandParams{
|
||||
Command: v.Command,
|
||||
Arguments: v.Arguments,
|
||||
})
|
||||
if err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("Error executing command: %s", err)), nil
|
||||
}
|
||||
|
||||
return NewTextResponse(fmt.Sprintf("Successfully executed command: %s", v.Title)), nil
|
||||
}
|
||||
|
||||
return NewTextErrorResponse("Unknown code action type"), nil
|
||||
}
|
||||
198
internal/llm/tools/lsp_definition.go
Normal file
198
internal/llm/tools/lsp_definition.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
)
|
||||
|
||||
type DefinitionParams struct {
|
||||
FilePath string `json:"file_path"`
|
||||
Line int `json:"line"`
|
||||
Column int `json:"column"`
|
||||
}
|
||||
|
||||
type definitionTool struct {
|
||||
lspClients map[string]*lsp.Client
|
||||
}
|
||||
|
||||
const (
|
||||
DefinitionToolName = "definition"
|
||||
definitionDescription = `Find the definition of a symbol at a specific position in a file.
|
||||
WHEN TO USE THIS TOOL:
|
||||
- Use when you need to find where a symbol is defined
|
||||
- Helpful for understanding code structure and relationships
|
||||
- Great for navigating between implementation and interface
|
||||
|
||||
HOW TO USE:
|
||||
- Provide the path to the file containing the symbol
|
||||
- Specify the line number (1-based) where the symbol appears
|
||||
- Specify the column number (1-based) where the symbol appears
|
||||
- Results show the location of the symbol's definition
|
||||
|
||||
FEATURES:
|
||||
- Finds definitions across files in the project
|
||||
- Works with variables, functions, classes, interfaces, etc.
|
||||
- Returns file path, line, and column of the definition
|
||||
|
||||
LIMITATIONS:
|
||||
- Requires a functioning LSP server for the file type
|
||||
- May not work for all symbols depending on LSP capabilities
|
||||
- Results depend on the accuracy of the LSP server
|
||||
|
||||
TIPS:
|
||||
- Use in conjunction with References tool to understand usage
|
||||
- Combine with View tool to examine the definition
|
||||
`
|
||||
)
|
||||
|
||||
func NewDefinitionTool(lspClients map[string]*lsp.Client) BaseTool {
|
||||
return &definitionTool{
|
||||
lspClients,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *definitionTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: DefinitionToolName,
|
||||
Description: definitionDescription,
|
||||
Parameters: map[string]any{
|
||||
"file_path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The path to the file containing the symbol",
|
||||
},
|
||||
"line": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "The line number (1-based) where the symbol appears",
|
||||
},
|
||||
"column": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "The column number (1-based) where the symbol appears",
|
||||
},
|
||||
},
|
||||
Required: []string{"file_path", "line", "column"},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *definitionTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params DefinitionParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||
}
|
||||
|
||||
lsps := b.lspClients
|
||||
|
||||
if len(lsps) == 0 {
|
||||
return NewTextResponse("\nLSP clients are still initializing. Definition lookup will be available once they're ready.\n"), nil
|
||||
}
|
||||
|
||||
// Ensure file is open in LSP
|
||||
notifyLspOpenFile(ctx, params.FilePath, lsps)
|
||||
|
||||
// Convert 1-based line/column to 0-based for LSP protocol
|
||||
line := max(0, params.Line-1)
|
||||
column := max(0, params.Column-1)
|
||||
|
||||
output := getDefinition(ctx, params.FilePath, line, column, lsps)
|
||||
|
||||
return NewTextResponse(output), nil
|
||||
}
|
||||
|
||||
func getDefinition(ctx context.Context, filePath string, line, column int, lsps map[string]*lsp.Client) string {
|
||||
var results []string
|
||||
|
||||
slog.Debug(fmt.Sprintf("Looking for definition in %s at line %d, column %d", filePath, line+1, column+1))
|
||||
slog.Debug(fmt.Sprintf("Available LSP clients: %v", getClientNames(lsps)))
|
||||
|
||||
for lspName, client := range lsps {
|
||||
slog.Debug(fmt.Sprintf("Trying LSP client: %s", lspName))
|
||||
// Create definition params
|
||||
uri := fmt.Sprintf("file://%s", filePath)
|
||||
definitionParams := protocol.DefinitionParams{
|
||||
TextDocumentPositionParams: protocol.TextDocumentPositionParams{
|
||||
TextDocument: protocol.TextDocumentIdentifier{
|
||||
URI: protocol.DocumentUri(uri),
|
||||
},
|
||||
Position: protocol.Position{
|
||||
Line: uint32(line),
|
||||
Character: uint32(column),
|
||||
},
|
||||
},
|
||||
}
|
||||
slog.Debug(fmt.Sprintf("Sending definition request with params: %+v", definitionParams))
|
||||
|
||||
// Get definition
|
||||
definition, err := client.Definition(ctx, definitionParams)
|
||||
if err != nil {
|
||||
slog.Debug(fmt.Sprintf("Error from %s: %s", lspName, err))
|
||||
results = append(results, fmt.Sprintf("Error from %s: %s", lspName, err))
|
||||
continue
|
||||
}
|
||||
slog.Debug(fmt.Sprintf("Got definition result type: %T", definition.Value))
|
||||
|
||||
// Process the definition result
|
||||
locations := processDefinitionResult(definition)
|
||||
slog.Debug(fmt.Sprintf("Processed locations count: %d", len(locations)))
|
||||
if len(locations) == 0 {
|
||||
results = append(results, fmt.Sprintf("No definition found by %s", lspName))
|
||||
continue
|
||||
}
|
||||
|
||||
// Format the locations
|
||||
for _, loc := range locations {
|
||||
path := strings.TrimPrefix(string(loc.URI), "file://")
|
||||
// Convert 0-based line/column to 1-based for display
|
||||
defLine := loc.Range.Start.Line + 1
|
||||
defColumn := loc.Range.Start.Character + 1
|
||||
slog.Debug(fmt.Sprintf("Found definition at %s:%d:%d", path, defLine, defColumn))
|
||||
results = append(results, fmt.Sprintf("Definition found by %s: %s:%d:%d", lspName, path, defLine, defColumn))
|
||||
}
|
||||
}
|
||||
|
||||
if len(results) == 0 {
|
||||
return "No definition found for the symbol at the specified position."
|
||||
}
|
||||
|
||||
return strings.Join(results, "\n")
|
||||
}
|
||||
|
||||
func processDefinitionResult(result protocol.Or_Result_textDocument_definition) []protocol.Location {
|
||||
var locations []protocol.Location
|
||||
|
||||
switch v := result.Value.(type) {
|
||||
case protocol.Location:
|
||||
locations = append(locations, v)
|
||||
case []protocol.Location:
|
||||
locations = append(locations, v...)
|
||||
case []protocol.DefinitionLink:
|
||||
for _, link := range v {
|
||||
locations = append(locations, protocol.Location{
|
||||
URI: link.TargetURI,
|
||||
Range: link.TargetRange,
|
||||
})
|
||||
}
|
||||
case protocol.Or_Definition:
|
||||
switch d := v.Value.(type) {
|
||||
case protocol.Location:
|
||||
locations = append(locations, d)
|
||||
case []protocol.Location:
|
||||
locations = append(locations, d...)
|
||||
}
|
||||
}
|
||||
|
||||
return locations
|
||||
}
|
||||
|
||||
// Helper function to get LSP client names for debugging
|
||||
func getClientNames(lsps map[string]*lsp.Client) []string {
|
||||
names := make([]string, 0, len(lsps))
|
||||
for name := range lsps {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/lsp"
|
||||
"github.com/opencode-ai/opencode/internal/lsp/protocol"
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
)
|
||||
|
||||
type DiagnosticsParams struct {
|
||||
@@ -74,7 +74,8 @@ func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
|
||||
lsps := b.lspClients
|
||||
|
||||
if len(lsps) == 0 {
|
||||
return NewTextErrorResponse("no LSP clients available"), nil
|
||||
// Return a more helpful message when LSP clients aren't ready yet
|
||||
return NewTextResponse("\n<diagnostic_summary>\nLSP clients are still initializing. Diagnostics will be available once they're ready.\n</diagnostic_summary>\n"), nil
|
||||
}
|
||||
|
||||
if params.FilePath != "" {
|
||||
204
internal/llm/tools/lsp_doc_symbols.go
Normal file
204
internal/llm/tools/lsp_doc_symbols.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
)
|
||||
|
||||
type DocSymbolsParams struct {
|
||||
FilePath string `json:"file_path"`
|
||||
}
|
||||
|
||||
type docSymbolsTool struct {
|
||||
lspClients map[string]*lsp.Client
|
||||
}
|
||||
|
||||
const (
|
||||
DocSymbolsToolName = "docSymbols"
|
||||
docSymbolsDescription = `Get document symbols for a file.
|
||||
WHEN TO USE THIS TOOL:
|
||||
- Use when you need to understand the structure of a file
|
||||
- Helpful for finding classes, functions, methods, and variables in a file
|
||||
- Great for getting an overview of a file's organization
|
||||
|
||||
HOW TO USE:
|
||||
- Provide the path to the file to get symbols for
|
||||
- Results show all symbols defined in the file with their kind and location
|
||||
|
||||
FEATURES:
|
||||
- Lists all symbols in a hierarchical structure
|
||||
- Shows symbol types (function, class, variable, etc.)
|
||||
- Provides location information for each symbol
|
||||
- Organizes symbols by their scope and relationship
|
||||
|
||||
LIMITATIONS:
|
||||
- Requires a functioning LSP server for the file type
|
||||
- Results depend on the accuracy of the LSP server
|
||||
- May not work for all file types
|
||||
|
||||
TIPS:
|
||||
- Use to quickly understand the structure of a large file
|
||||
- Combine with Definition and References tools for deeper code exploration
|
||||
`
|
||||
)
|
||||
|
||||
func NewDocSymbolsTool(lspClients map[string]*lsp.Client) BaseTool {
|
||||
return &docSymbolsTool{
|
||||
lspClients,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *docSymbolsTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: DocSymbolsToolName,
|
||||
Description: docSymbolsDescription,
|
||||
Parameters: map[string]any{
|
||||
"file_path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The path to the file to get symbols for",
|
||||
},
|
||||
},
|
||||
Required: []string{"file_path"},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *docSymbolsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params DocSymbolsParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||
}
|
||||
|
||||
lsps := b.lspClients
|
||||
|
||||
if len(lsps) == 0 {
|
||||
return NewTextResponse("\nLSP clients are still initializing. Document symbols lookup will be available once they're ready.\n"), nil
|
||||
}
|
||||
|
||||
// Ensure file is open in LSP
|
||||
notifyLspOpenFile(ctx, params.FilePath, lsps)
|
||||
|
||||
output := getDocumentSymbols(ctx, params.FilePath, lsps)
|
||||
|
||||
return NewTextResponse(output), nil
|
||||
}
|
||||
|
||||
func getDocumentSymbols(ctx context.Context, filePath string, lsps map[string]*lsp.Client) string {
|
||||
var results []string
|
||||
|
||||
for lspName, client := range lsps {
|
||||
// Create document symbol params
|
||||
uri := fmt.Sprintf("file://%s", filePath)
|
||||
symbolParams := protocol.DocumentSymbolParams{
|
||||
TextDocument: protocol.TextDocumentIdentifier{
|
||||
URI: protocol.DocumentUri(uri),
|
||||
},
|
||||
}
|
||||
|
||||
// Get document symbols
|
||||
symbolResult, err := client.DocumentSymbol(ctx, symbolParams)
|
||||
if err != nil {
|
||||
results = append(results, fmt.Sprintf("Error from %s: %s", lspName, err))
|
||||
continue
|
||||
}
|
||||
|
||||
// Process the symbol result
|
||||
symbols := processDocumentSymbolResult(symbolResult)
|
||||
if len(symbols) == 0 {
|
||||
results = append(results, fmt.Sprintf("No symbols found by %s", lspName))
|
||||
continue
|
||||
}
|
||||
|
||||
// Format the symbols
|
||||
results = append(results, fmt.Sprintf("Symbols found by %s:", lspName))
|
||||
for _, symbol := range symbols {
|
||||
results = append(results, formatSymbol(symbol, 1))
|
||||
}
|
||||
}
|
||||
|
||||
if len(results) == 0 {
|
||||
return "No symbols found in the specified file."
|
||||
}
|
||||
|
||||
return strings.Join(results, "\n")
|
||||
}
|
||||
|
||||
func processDocumentSymbolResult(result protocol.Or_Result_textDocument_documentSymbol) []SymbolInfo {
|
||||
var symbols []SymbolInfo
|
||||
|
||||
switch v := result.Value.(type) {
|
||||
case []protocol.SymbolInformation:
|
||||
for _, si := range v {
|
||||
symbols = append(symbols, SymbolInfo{
|
||||
Name: si.Name,
|
||||
Kind: symbolKindToString(si.Kind),
|
||||
Location: locationToString(si.Location),
|
||||
Children: nil,
|
||||
})
|
||||
}
|
||||
case []protocol.DocumentSymbol:
|
||||
for _, ds := range v {
|
||||
symbols = append(symbols, documentSymbolToSymbolInfo(ds))
|
||||
}
|
||||
}
|
||||
|
||||
return symbols
|
||||
}
|
||||
|
||||
// SymbolInfo represents a symbol in a document
|
||||
type SymbolInfo struct {
|
||||
Name string
|
||||
Kind string
|
||||
Location string
|
||||
Children []SymbolInfo
|
||||
}
|
||||
|
||||
func documentSymbolToSymbolInfo(symbol protocol.DocumentSymbol) SymbolInfo {
|
||||
info := SymbolInfo{
|
||||
Name: symbol.Name,
|
||||
Kind: symbolKindToString(symbol.Kind),
|
||||
Location: fmt.Sprintf("Line %d-%d",
|
||||
symbol.Range.Start.Line+1,
|
||||
symbol.Range.End.Line+1),
|
||||
Children: []SymbolInfo{},
|
||||
}
|
||||
|
||||
for _, child := range symbol.Children {
|
||||
info.Children = append(info.Children, documentSymbolToSymbolInfo(child))
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
func locationToString(location protocol.Location) string {
|
||||
return fmt.Sprintf("Line %d-%d",
|
||||
location.Range.Start.Line+1,
|
||||
location.Range.End.Line+1)
|
||||
}
|
||||
|
||||
func symbolKindToString(kind protocol.SymbolKind) string {
|
||||
if kindStr, ok := protocol.TableKindMap[kind]; ok {
|
||||
return kindStr
|
||||
}
|
||||
return "Unknown"
|
||||
}
|
||||
|
||||
func formatSymbol(symbol SymbolInfo, level int) string {
|
||||
indent := strings.Repeat(" ", level)
|
||||
result := fmt.Sprintf("%s- %s (%s) %s", indent, symbol.Name, symbol.Kind, symbol.Location)
|
||||
|
||||
var childResults []string
|
||||
for _, child := range symbol.Children {
|
||||
childResults = append(childResults, formatSymbol(child, level+1))
|
||||
}
|
||||
|
||||
if len(childResults) > 0 {
|
||||
return result + "\n" + strings.Join(childResults, "\n")
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
161
internal/llm/tools/lsp_references.go
Normal file
161
internal/llm/tools/lsp_references.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
)
|
||||
|
||||
type ReferencesParams struct {
|
||||
FilePath string `json:"file_path"`
|
||||
Line int `json:"line"`
|
||||
Column int `json:"column"`
|
||||
IncludeDeclaration bool `json:"include_declaration"`
|
||||
}
|
||||
|
||||
type referencesTool struct {
|
||||
lspClients map[string]*lsp.Client
|
||||
}
|
||||
|
||||
const (
|
||||
ReferencesToolName = "references"
|
||||
referencesDescription = `Find all references to a symbol at a specific position in a file.
|
||||
WHEN TO USE THIS TOOL:
|
||||
- Use when you need to find all places where a symbol is used
|
||||
- Helpful for understanding code usage and dependencies
|
||||
- Great for refactoring and impact analysis
|
||||
|
||||
HOW TO USE:
|
||||
- Provide the path to the file containing the symbol
|
||||
- Specify the line number (1-based) where the symbol appears
|
||||
- Specify the column number (1-based) where the symbol appears
|
||||
- Optionally set include_declaration to include the declaration in results
|
||||
- Results show all locations where the symbol is referenced
|
||||
|
||||
FEATURES:
|
||||
- Finds references across files in the project
|
||||
- Works with variables, functions, classes, interfaces, etc.
|
||||
- Returns file paths, lines, and columns of all references
|
||||
|
||||
LIMITATIONS:
|
||||
- Requires a functioning LSP server for the file type
|
||||
- May not find all references depending on LSP capabilities
|
||||
- Results depend on the accuracy of the LSP server
|
||||
|
||||
TIPS:
|
||||
- Use in conjunction with Definition tool to understand symbol origins
|
||||
- Combine with View tool to examine the references
|
||||
`
|
||||
)
|
||||
|
||||
func NewReferencesTool(lspClients map[string]*lsp.Client) BaseTool {
|
||||
return &referencesTool{
|
||||
lspClients,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *referencesTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: ReferencesToolName,
|
||||
Description: referencesDescription,
|
||||
Parameters: map[string]any{
|
||||
"file_path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The path to the file containing the symbol",
|
||||
},
|
||||
"line": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "The line number (1-based) where the symbol appears",
|
||||
},
|
||||
"column": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "The column number (1-based) where the symbol appears",
|
||||
},
|
||||
"include_declaration": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Whether to include the declaration in the results",
|
||||
},
|
||||
},
|
||||
Required: []string{"file_path", "line", "column"},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *referencesTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params ReferencesParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||
}
|
||||
|
||||
lsps := b.lspClients
|
||||
|
||||
if len(lsps) == 0 {
|
||||
return NewTextResponse("\nLSP clients are still initializing. References lookup will be available once they're ready.\n"), nil
|
||||
}
|
||||
|
||||
// Ensure file is open in LSP
|
||||
notifyLspOpenFile(ctx, params.FilePath, lsps)
|
||||
|
||||
// Convert 1-based line/column to 0-based for LSP protocol
|
||||
line := max(0, params.Line-1)
|
||||
column := max(0, params.Column-1)
|
||||
|
||||
output := getReferences(ctx, params.FilePath, line, column, params.IncludeDeclaration, lsps)
|
||||
|
||||
return NewTextResponse(output), nil
|
||||
}
|
||||
|
||||
func getReferences(ctx context.Context, filePath string, line, column int, includeDeclaration bool, lsps map[string]*lsp.Client) string {
|
||||
var results []string
|
||||
|
||||
for lspName, client := range lsps {
|
||||
// Create references params
|
||||
uri := fmt.Sprintf("file://%s", filePath)
|
||||
referencesParams := protocol.ReferenceParams{
|
||||
TextDocumentPositionParams: protocol.TextDocumentPositionParams{
|
||||
TextDocument: protocol.TextDocumentIdentifier{
|
||||
URI: protocol.DocumentUri(uri),
|
||||
},
|
||||
Position: protocol.Position{
|
||||
Line: uint32(line),
|
||||
Character: uint32(column),
|
||||
},
|
||||
},
|
||||
Context: protocol.ReferenceContext{
|
||||
IncludeDeclaration: includeDeclaration,
|
||||
},
|
||||
}
|
||||
|
||||
// Get references
|
||||
references, err := client.References(ctx, referencesParams)
|
||||
if err != nil {
|
||||
results = append(results, fmt.Sprintf("Error from %s: %s", lspName, err))
|
||||
continue
|
||||
}
|
||||
|
||||
if len(references) == 0 {
|
||||
results = append(results, fmt.Sprintf("No references found by %s", lspName))
|
||||
continue
|
||||
}
|
||||
|
||||
// Format the locations
|
||||
results = append(results, fmt.Sprintf("References found by %s:", lspName))
|
||||
for _, loc := range references {
|
||||
path := strings.TrimPrefix(string(loc.URI), "file://")
|
||||
// Convert 0-based line/column to 1-based for display
|
||||
refLine := loc.Range.Start.Line + 1
|
||||
refColumn := loc.Range.Start.Character + 1
|
||||
results = append(results, fmt.Sprintf(" %s:%d:%d", path, refLine, refColumn))
|
||||
}
|
||||
}
|
||||
|
||||
if len(results) == 0 {
|
||||
return "No references found for the symbol at the specified position."
|
||||
}
|
||||
|
||||
return strings.Join(results, "\n")
|
||||
}
|
||||
|
||||
162
internal/llm/tools/lsp_workspace_symbols.go
Normal file
162
internal/llm/tools/lsp_workspace_symbols.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
)
|
||||
|
||||
type WorkspaceSymbolsParams struct {
|
||||
Query string `json:"query"`
|
||||
}
|
||||
|
||||
type workspaceSymbolsTool struct {
|
||||
lspClients map[string]*lsp.Client
|
||||
}
|
||||
|
||||
const (
|
||||
WorkspaceSymbolsToolName = "workspaceSymbols"
|
||||
workspaceSymbolsDescription = `Find symbols across the workspace matching a query.
|
||||
WHEN TO USE THIS TOOL:
|
||||
- Use when you need to find symbols across multiple files
|
||||
- Helpful for locating classes, functions, or variables in a project
|
||||
- Great for exploring large codebases
|
||||
|
||||
HOW TO USE:
|
||||
- Provide a query string to search for symbols
|
||||
- Results show matching symbols from across the workspace
|
||||
|
||||
FEATURES:
|
||||
- Searches across all files in the workspace
|
||||
- Shows symbol types (function, class, variable, etc.)
|
||||
- Provides location information for each symbol
|
||||
- Works with partial matches and fuzzy search (depending on LSP server)
|
||||
|
||||
LIMITATIONS:
|
||||
- Requires a functioning LSP server for the file types
|
||||
- Results depend on the accuracy of the LSP server
|
||||
- Query capabilities vary by language server
|
||||
- May not work for all file types
|
||||
|
||||
TIPS:
|
||||
- Use specific queries to narrow down results
|
||||
- Combine with DocSymbols tool for detailed file exploration
|
||||
- Use with Definition tool to jump to symbol definitions
|
||||
`
|
||||
)
|
||||
|
||||
func NewWorkspaceSymbolsTool(lspClients map[string]*lsp.Client) BaseTool {
|
||||
return &workspaceSymbolsTool{
|
||||
lspClients,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *workspaceSymbolsTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: WorkspaceSymbolsToolName,
|
||||
Description: workspaceSymbolsDescription,
|
||||
Parameters: map[string]any{
|
||||
"query": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The query string to search for symbols",
|
||||
},
|
||||
},
|
||||
Required: []string{"query"},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *workspaceSymbolsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params WorkspaceSymbolsParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||
}
|
||||
|
||||
lsps := b.lspClients
|
||||
|
||||
if len(lsps) == 0 {
|
||||
return NewTextResponse("\nLSP clients are still initializing. Workspace symbols lookup will be available once they're ready.\n"), nil
|
||||
}
|
||||
|
||||
output := getWorkspaceSymbols(ctx, params.Query, lsps)
|
||||
|
||||
return NewTextResponse(output), nil
|
||||
}
|
||||
|
||||
func getWorkspaceSymbols(ctx context.Context, query string, lsps map[string]*lsp.Client) string {
|
||||
var results []string
|
||||
|
||||
for lspName, client := range lsps {
|
||||
// Create workspace symbol params
|
||||
symbolParams := protocol.WorkspaceSymbolParams{
|
||||
Query: query,
|
||||
}
|
||||
|
||||
// Get workspace symbols
|
||||
symbolResult, err := client.Symbol(ctx, symbolParams)
|
||||
if err != nil {
|
||||
results = append(results, fmt.Sprintf("Error from %s: %s", lspName, err))
|
||||
continue
|
||||
}
|
||||
|
||||
// Process the symbol result
|
||||
symbols := processWorkspaceSymbolResult(symbolResult)
|
||||
if len(symbols) == 0 {
|
||||
results = append(results, fmt.Sprintf("No symbols found by %s for query '%s'", lspName, query))
|
||||
continue
|
||||
}
|
||||
|
||||
// Format the symbols
|
||||
results = append(results, fmt.Sprintf("Symbols found by %s for query '%s':", lspName, query))
|
||||
for _, symbol := range symbols {
|
||||
results = append(results, fmt.Sprintf(" %s (%s) - %s", symbol.Name, symbol.Kind, symbol.Location))
|
||||
}
|
||||
}
|
||||
|
||||
if len(results) == 0 {
|
||||
return fmt.Sprintf("No symbols found matching query '%s'.", query)
|
||||
}
|
||||
|
||||
return strings.Join(results, "\n")
|
||||
}
|
||||
|
||||
func processWorkspaceSymbolResult(result protocol.Or_Result_workspace_symbol) []SymbolInfo {
|
||||
var symbols []SymbolInfo
|
||||
|
||||
switch v := result.Value.(type) {
|
||||
case []protocol.SymbolInformation:
|
||||
for _, si := range v {
|
||||
symbols = append(symbols, SymbolInfo{
|
||||
Name: si.Name,
|
||||
Kind: symbolKindToString(si.Kind),
|
||||
Location: formatWorkspaceLocation(si.Location),
|
||||
Children: nil,
|
||||
})
|
||||
}
|
||||
case []protocol.WorkspaceSymbol:
|
||||
for _, ws := range v {
|
||||
location := "Unknown location"
|
||||
if ws.Location.Value != nil {
|
||||
if loc, ok := ws.Location.Value.(protocol.Location); ok {
|
||||
location = formatWorkspaceLocation(loc)
|
||||
}
|
||||
}
|
||||
symbols = append(symbols, SymbolInfo{
|
||||
Name: ws.Name,
|
||||
Kind: symbolKindToString(ws.Kind),
|
||||
Location: location,
|
||||
Children: nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return symbols
|
||||
}
|
||||
|
||||
func formatWorkspaceLocation(location protocol.Location) string {
|
||||
path := strings.TrimPrefix(string(location.URI), "file://")
|
||||
return fmt.Sprintf("%s:%d:%d", path, location.Range.Start.Line+1, location.Range.Start.Character+1)
|
||||
}
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/diff"
|
||||
"github.com/opencode-ai/opencode/internal/history"
|
||||
"github.com/opencode-ai/opencode/internal/logging"
|
||||
"github.com/opencode-ai/opencode/internal/lsp"
|
||||
"github.com/opencode-ai/opencode/internal/permission"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/diff"
|
||||
"github.com/sst/opencode/internal/history"
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/permission"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
type PatchParams struct {
|
||||
@@ -193,6 +193,7 @@ func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
|
||||
dir := filepath.Dir(path)
|
||||
patchDiff, _, _ := diff.GenerateDiff("", *change.NewContent, path)
|
||||
p := p.permissions.Request(
|
||||
ctx,
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: dir,
|
||||
@@ -220,6 +221,7 @@ func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
|
||||
patchDiff, _, _ := diff.GenerateDiff(currentContent, newContent, path)
|
||||
dir := filepath.Dir(path)
|
||||
p := p.permissions.Request(
|
||||
ctx,
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: dir,
|
||||
@@ -239,6 +241,7 @@ func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
|
||||
dir := filepath.Dir(path)
|
||||
patchDiff, _, _ := diff.GenerateDiff(*change.OldContent, "", path)
|
||||
p := p.permissions.Request(
|
||||
ctx,
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: dir,
|
||||
@@ -313,12 +316,12 @@ func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
|
||||
totalRemovals += removals
|
||||
|
||||
// Update history
|
||||
file, err := p.files.GetByPathAndSession(ctx, absPath, sessionID)
|
||||
file, err := p.files.GetLatestByPathAndSession(ctx, absPath, sessionID)
|
||||
if err != nil && change.Type != diff.ActionAdd {
|
||||
// If not adding a file, create history entry for existing file
|
||||
_, err = p.files.Create(ctx, sessionID, absPath, oldContent)
|
||||
if err != nil {
|
||||
logging.Debug("Error creating file history", "error", err)
|
||||
slog.Debug("Error creating file history", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -326,7 +329,7 @@ func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
|
||||
// User manually changed content, store intermediate version
|
||||
_, err = p.files.CreateVersion(ctx, sessionID, absPath, oldContent)
|
||||
if err != nil {
|
||||
logging.Debug("Error creating file history version", "error", err)
|
||||
slog.Debug("Error creating file history version", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -337,7 +340,7 @@ func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
|
||||
_, err = p.files.CreateVersion(ctx, sessionID, absPath, newContent)
|
||||
}
|
||||
if err != nil {
|
||||
logging.Debug("Error creating file history version", "error", err)
|
||||
slog.Debug("Error creating file history version", "error", err)
|
||||
}
|
||||
|
||||
// Record file operations
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sst/opencode/internal/status"
|
||||
)
|
||||
|
||||
type PersistentShell struct {
|
||||
@@ -99,7 +101,7 @@ func newPersistentShell(cwd string) *PersistentShell {
|
||||
go func() {
|
||||
err := cmd.Wait()
|
||||
if err != nil {
|
||||
// Log the error if needed
|
||||
status.Error(fmt.Sprintf("Shell process exited with error: %v", err))
|
||||
}
|
||||
shell.isAlive = false
|
||||
close(shell.commandQueue)
|
||||
|
||||
@@ -1,383 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SourcegraphParams struct {
|
||||
Query string `json:"query"`
|
||||
Count int `json:"count,omitempty"`
|
||||
ContextWindow int `json:"context_window,omitempty"`
|
||||
Timeout int `json:"timeout,omitempty"`
|
||||
}
|
||||
|
||||
type SourcegraphResponseMetadata struct {
|
||||
NumberOfMatches int `json:"number_of_matches"`
|
||||
Truncated bool `json:"truncated"`
|
||||
}
|
||||
|
||||
type sourcegraphTool struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
const (
|
||||
SourcegraphToolName = "sourcegraph"
|
||||
sourcegraphToolDescription = `Search code across public repositories using Sourcegraph's GraphQL API.
|
||||
|
||||
WHEN TO USE THIS TOOL:
|
||||
- Use when you need to find code examples or implementations across public repositories
|
||||
- Helpful for researching how others have solved similar problems
|
||||
- Useful for discovering patterns and best practices in open source code
|
||||
|
||||
HOW TO USE:
|
||||
- Provide a search query using Sourcegraph's query syntax
|
||||
- Optionally specify the number of results to return (default: 10)
|
||||
- Optionally set a timeout for the request
|
||||
|
||||
QUERY SYNTAX:
|
||||
- Basic search: "fmt.Println" searches for exact matches
|
||||
- File filters: "file:.go fmt.Println" limits to Go files
|
||||
- Repository filters: "repo:^github\.com/golang/go$ fmt.Println" limits to specific repos
|
||||
- Language filters: "lang:go fmt.Println" limits to Go code
|
||||
- Boolean operators: "fmt.Println AND log.Fatal" for combined terms
|
||||
- Regular expressions: "fmt\.(Print|Printf|Println)" for pattern matching
|
||||
- Quoted strings: "\"exact phrase\"" for exact phrase matching
|
||||
- Exclude filters: "-file:test" or "-repo:forks" to exclude matches
|
||||
|
||||
ADVANCED FILTERS:
|
||||
- Repository filters:
|
||||
* "repo:name" - Match repositories with name containing "name"
|
||||
* "repo:^github\.com/org/repo$" - Exact repository match
|
||||
* "repo:org/repo@branch" - Search specific branch
|
||||
* "repo:org/repo rev:branch" - Alternative branch syntax
|
||||
* "-repo:name" - Exclude repositories
|
||||
* "fork:yes" or "fork:only" - Include or only show forks
|
||||
* "archived:yes" or "archived:only" - Include or only show archived repos
|
||||
* "visibility:public" or "visibility:private" - Filter by visibility
|
||||
|
||||
- File filters:
|
||||
* "file:\.js$" - Files with .js extension
|
||||
* "file:internal/" - Files in internal directory
|
||||
* "-file:test" - Exclude test files
|
||||
* "file:has.content(Copyright)" - Files containing "Copyright"
|
||||
* "file:has.contributor([email protected])" - Files with specific contributor
|
||||
|
||||
- Content filters:
|
||||
* "content:\"exact string\"" - Search for exact string
|
||||
* "-content:\"unwanted\"" - Exclude files with unwanted content
|
||||
* "case:yes" - Case-sensitive search
|
||||
|
||||
- Type filters:
|
||||
* "type:symbol" - Search for symbols (functions, classes, etc.)
|
||||
* "type:file" - Search file content only
|
||||
* "type:path" - Search filenames only
|
||||
* "type:diff" - Search code changes
|
||||
* "type:commit" - Search commit messages
|
||||
|
||||
- Commit/diff search:
|
||||
* "after:\"1 month ago\"" - Commits after date
|
||||
* "before:\"2023-01-01\"" - Commits before date
|
||||
* "author:name" - Commits by author
|
||||
* "message:\"fix bug\"" - Commits with message
|
||||
|
||||
- Result selection:
|
||||
* "select:repo" - Show only repository names
|
||||
* "select:file" - Show only file paths
|
||||
* "select:content" - Show only matching content
|
||||
* "select:symbol" - Show only matching symbols
|
||||
|
||||
- Result control:
|
||||
* "count:100" - Return up to 100 results
|
||||
* "count:all" - Return all results
|
||||
* "timeout:30s" - Set search timeout
|
||||
|
||||
EXAMPLES:
|
||||
- "file:.go context.WithTimeout" - Find Go code using context.WithTimeout
|
||||
- "lang:typescript useState type:symbol" - Find TypeScript React useState hooks
|
||||
- "repo:^github\.com/kubernetes/kubernetes$ pod list type:file" - Find Kubernetes files related to pod listing
|
||||
- "repo:sourcegraph/sourcegraph$ after:\"3 months ago\" type:diff database" - Recent changes to database code
|
||||
- "file:Dockerfile (alpine OR ubuntu) -content:alpine:latest" - Dockerfiles with specific base images
|
||||
- "repo:has.path(\.py) file:requirements.txt tensorflow" - Python projects using TensorFlow
|
||||
|
||||
BOOLEAN OPERATORS:
|
||||
- "term1 AND term2" - Results containing both terms
|
||||
- "term1 OR term2" - Results containing either term
|
||||
- "term1 NOT term2" - Results with term1 but not term2
|
||||
- "term1 and (term2 or term3)" - Grouping with parentheses
|
||||
|
||||
LIMITATIONS:
|
||||
- Only searches public repositories
|
||||
- Rate limits may apply
|
||||
- Complex queries may take longer to execute
|
||||
- Maximum of 20 results per query
|
||||
|
||||
TIPS:
|
||||
- Use specific file extensions to narrow results
|
||||
- Add repo: filters for more targeted searches
|
||||
- Use type:symbol to find function/method definitions
|
||||
- Use type:file to find relevant files`
|
||||
)
|
||||
|
||||
func NewSourcegraphTool() BaseTool {
|
||||
return &sourcegraphTool{
|
||||
client: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *sourcegraphTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: SourcegraphToolName,
|
||||
Description: sourcegraphToolDescription,
|
||||
Parameters: map[string]any{
|
||||
"query": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The Sourcegraph search query",
|
||||
},
|
||||
"count": map[string]any{
|
||||
"type": "number",
|
||||
"description": "Optional number of results to return (default: 10, max: 20)",
|
||||
},
|
||||
"context_window": map[string]any{
|
||||
"type": "number",
|
||||
"description": "The context around the match to return (default: 10 lines)",
|
||||
},
|
||||
"timeout": map[string]any{
|
||||
"type": "number",
|
||||
"description": "Optional timeout in seconds (max 120)",
|
||||
},
|
||||
},
|
||||
Required: []string{"query"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params SourcegraphParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse("Failed to parse sourcegraph parameters: " + err.Error()), nil
|
||||
}
|
||||
|
||||
if params.Query == "" {
|
||||
return NewTextErrorResponse("Query parameter is required"), nil
|
||||
}
|
||||
|
||||
if params.Count <= 0 {
|
||||
params.Count = 10
|
||||
} else if params.Count > 20 {
|
||||
params.Count = 20 // Limit to 20 results
|
||||
}
|
||||
|
||||
if params.ContextWindow <= 0 {
|
||||
params.ContextWindow = 10 // Default context window
|
||||
}
|
||||
client := t.client
|
||||
if params.Timeout > 0 {
|
||||
maxTimeout := 120 // 2 minutes
|
||||
if params.Timeout > maxTimeout {
|
||||
params.Timeout = maxTimeout
|
||||
}
|
||||
client = &http.Client{
|
||||
Timeout: time.Duration(params.Timeout) * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
type graphqlRequest struct {
|
||||
Query string `json:"query"`
|
||||
Variables struct {
|
||||
Query string `json:"query"`
|
||||
} `json:"variables"`
|
||||
}
|
||||
|
||||
request := graphqlRequest{
|
||||
Query: "query Search($query: String!) { search(query: $query, version: V2, patternType: keyword ) { results { matchCount, limitHit, resultCount, approximateResultCount, missing { name }, timedout { name }, indexUnavailable, results { __typename, ... on FileMatch { repository { name }, file { path, url, content }, lineMatches { preview, lineNumber, offsetAndLengths } } } } } }",
|
||||
}
|
||||
request.Variables.Query = params.Query
|
||||
|
||||
graphqlQueryBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("failed to marshal GraphQL request: %w", err)
|
||||
}
|
||||
graphqlQuery := string(graphqlQueryBytes)
|
||||
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
"POST",
|
||||
"https://sourcegraph.com/.api/graphql",
|
||||
bytes.NewBuffer([]byte(graphqlQuery)),
|
||||
)
|
||||
if err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", "opencode/1.0")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if len(body) > 0 {
|
||||
return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d, response: %s", resp.StatusCode, string(body))), nil
|
||||
}
|
||||
|
||||
return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
|
||||
}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
formattedResults, err := formatSourcegraphResults(result, params.ContextWindow)
|
||||
if err != nil {
|
||||
return NewTextErrorResponse("Failed to format results: " + err.Error()), nil
|
||||
}
|
||||
|
||||
return NewTextResponse(formattedResults), nil
|
||||
}
|
||||
|
||||
func formatSourcegraphResults(result map[string]any, contextWindow int) (string, error) {
|
||||
var buffer strings.Builder
|
||||
|
||||
if errors, ok := result["errors"].([]any); ok && len(errors) > 0 {
|
||||
buffer.WriteString("## Sourcegraph API Error\n\n")
|
||||
for _, err := range errors {
|
||||
if errMap, ok := err.(map[string]any); ok {
|
||||
if message, ok := errMap["message"].(string); ok {
|
||||
buffer.WriteString(fmt.Sprintf("- %s\n", message))
|
||||
}
|
||||
}
|
||||
}
|
||||
return buffer.String(), nil
|
||||
}
|
||||
|
||||
data, ok := result["data"].(map[string]any)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid response format: missing data field")
|
||||
}
|
||||
|
||||
search, ok := data["search"].(map[string]any)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid response format: missing search field")
|
||||
}
|
||||
|
||||
searchResults, ok := search["results"].(map[string]any)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid response format: missing results field")
|
||||
}
|
||||
|
||||
matchCount, _ := searchResults["matchCount"].(float64)
|
||||
resultCount, _ := searchResults["resultCount"].(float64)
|
||||
limitHit, _ := searchResults["limitHit"].(bool)
|
||||
|
||||
buffer.WriteString("# Sourcegraph Search Results\n\n")
|
||||
buffer.WriteString(fmt.Sprintf("Found %d matches across %d results\n", int(matchCount), int(resultCount)))
|
||||
|
||||
if limitHit {
|
||||
buffer.WriteString("(Result limit reached, try a more specific query)\n")
|
||||
}
|
||||
|
||||
buffer.WriteString("\n")
|
||||
|
||||
results, ok := searchResults["results"].([]any)
|
||||
if !ok || len(results) == 0 {
|
||||
buffer.WriteString("No results found. Try a different query.\n")
|
||||
return buffer.String(), nil
|
||||
}
|
||||
|
||||
maxResults := 10
|
||||
if len(results) > maxResults {
|
||||
results = results[:maxResults]
|
||||
}
|
||||
|
||||
for i, res := range results {
|
||||
fileMatch, ok := res.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
typeName, _ := fileMatch["__typename"].(string)
|
||||
if typeName != "FileMatch" {
|
||||
continue
|
||||
}
|
||||
|
||||
repo, _ := fileMatch["repository"].(map[string]any)
|
||||
file, _ := fileMatch["file"].(map[string]any)
|
||||
lineMatches, _ := fileMatch["lineMatches"].([]any)
|
||||
|
||||
if repo == nil || file == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
repoName, _ := repo["name"].(string)
|
||||
filePath, _ := file["path"].(string)
|
||||
fileURL, _ := file["url"].(string)
|
||||
fileContent, _ := file["content"].(string)
|
||||
|
||||
buffer.WriteString(fmt.Sprintf("## Result %d: %s/%s\n\n", i+1, repoName, filePath))
|
||||
|
||||
if fileURL != "" {
|
||||
buffer.WriteString(fmt.Sprintf("URL: %s\n\n", fileURL))
|
||||
}
|
||||
|
||||
if len(lineMatches) > 0 {
|
||||
for _, lm := range lineMatches {
|
||||
lineMatch, ok := lm.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
lineNumber, _ := lineMatch["lineNumber"].(float64)
|
||||
preview, _ := lineMatch["preview"].(string)
|
||||
|
||||
if fileContent != "" {
|
||||
lines := strings.Split(fileContent, "\n")
|
||||
|
||||
buffer.WriteString("```\n")
|
||||
|
||||
startLine := max(1, int(lineNumber)-contextWindow)
|
||||
|
||||
for j := startLine - 1; j < int(lineNumber)-1 && j < len(lines); j++ {
|
||||
if j >= 0 {
|
||||
buffer.WriteString(fmt.Sprintf("%d| %s\n", j+1, lines[j]))
|
||||
}
|
||||
}
|
||||
|
||||
buffer.WriteString(fmt.Sprintf("%d| %s\n", int(lineNumber), preview))
|
||||
|
||||
endLine := int(lineNumber) + contextWindow
|
||||
|
||||
for j := int(lineNumber); j < endLine && j < len(lines); j++ {
|
||||
if j < len(lines) {
|
||||
buffer.WriteString(fmt.Sprintf("%d| %s\n", j+1, lines[j]))
|
||||
}
|
||||
}
|
||||
|
||||
buffer.WriteString("```\n\n")
|
||||
} else {
|
||||
buffer.WriteString("```\n")
|
||||
buffer.WriteString(fmt.Sprintf("%d| %s\n", int(lineNumber), preview))
|
||||
buffer.WriteString("```\n\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return buffer.String(), nil
|
||||
}
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
)
|
||||
|
||||
type ViewParams struct {
|
||||
|
||||
@@ -9,12 +9,12 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/diff"
|
||||
"github.com/opencode-ai/opencode/internal/history"
|
||||
"github.com/opencode-ai/opencode/internal/logging"
|
||||
"github.com/opencode-ai/opencode/internal/lsp"
|
||||
"github.com/opencode-ai/opencode/internal/permission"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/diff"
|
||||
"github.com/sst/opencode/internal/history"
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/permission"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
type WriteParams struct {
|
||||
@@ -167,6 +167,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
|
||||
permissionPath = rootDir
|
||||
}
|
||||
p := w.permissions.Request(
|
||||
ctx,
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: permissionPath,
|
||||
@@ -189,7 +190,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
|
||||
}
|
||||
|
||||
// Check if file exists in history
|
||||
file, err := w.files.GetByPathAndSession(ctx, filePath, sessionID)
|
||||
file, err := w.files.GetLatestByPathAndSession(ctx, filePath, sessionID)
|
||||
if err != nil {
|
||||
_, err = w.files.Create(ctx, sessionID, filePath, oldContent)
|
||||
if err != nil {
|
||||
@@ -201,13 +202,13 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
|
||||
// User Manually changed the content store an intermediate version
|
||||
_, err = w.files.CreateVersion(ctx, sessionID, filePath, oldContent)
|
||||
if err != nil {
|
||||
logging.Debug("Error creating file history version", "error", err)
|
||||
slog.Debug("Error creating file history version", "error", err)
|
||||
}
|
||||
}
|
||||
// Store the new version
|
||||
_, err = w.files.CreateVersion(ctx, sessionID, filePath, params.Content)
|
||||
if err != nil {
|
||||
logging.Debug("Error creating file history version", "error", err)
|
||||
slog.Debug("Error creating file history version", "error", err)
|
||||
}
|
||||
|
||||
recordFileWrite(filePath)
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Info(msg string, args ...any) {
|
||||
slog.Info(msg, args...)
|
||||
}
|
||||
|
||||
func Debug(msg string, args ...any) {
|
||||
slog.Debug(msg, args...)
|
||||
}
|
||||
|
||||
func Warn(msg string, args ...any) {
|
||||
slog.Warn(msg, args...)
|
||||
}
|
||||
|
||||
func Error(msg string, args ...any) {
|
||||
slog.Error(msg, args...)
|
||||
}
|
||||
|
||||
func InfoPersist(msg string, args ...any) {
|
||||
args = append(args, persistKeyArg, true)
|
||||
slog.Info(msg, args...)
|
||||
}
|
||||
|
||||
func DebugPersist(msg string, args ...any) {
|
||||
args = append(args, persistKeyArg, true)
|
||||
slog.Debug(msg, args...)
|
||||
}
|
||||
|
||||
func WarnPersist(msg string, args ...any) {
|
||||
args = append(args, persistKeyArg, true)
|
||||
slog.Warn(msg, args...)
|
||||
}
|
||||
|
||||
func ErrorPersist(msg string, args ...any) {
|
||||
args = append(args, persistKeyArg, true)
|
||||
slog.Error(msg, args...)
|
||||
}
|
||||
|
||||
// RecoverPanic is a common function to handle panics gracefully.
|
||||
// It logs the error, creates a panic log file with stack trace,
|
||||
// and executes an optional cleanup function before returning.
|
||||
func RecoverPanic(name string, cleanup func()) {
|
||||
if r := recover(); r != nil {
|
||||
// Log the panic
|
||||
ErrorPersist(fmt.Sprintf("Panic in %s: %v", name, r))
|
||||
|
||||
// Create a timestamped panic log file
|
||||
timestamp := time.Now().Format("20060102-150405")
|
||||
filename := fmt.Sprintf("opencode-panic-%s-%s.log", name, timestamp)
|
||||
|
||||
file, err := os.Create(filename)
|
||||
if err != nil {
|
||||
ErrorPersist(fmt.Sprintf("Failed to create panic log: %v", err))
|
||||
} else {
|
||||
defer file.Close()
|
||||
|
||||
// Write panic information and stack trace
|
||||
fmt.Fprintf(file, "Panic in %s: %v\n\n", name, r)
|
||||
fmt.Fprintf(file, "Time: %s\n\n", time.Now().Format(time.RFC3339))
|
||||
fmt.Fprintf(file, "Stack Trace:\n%s\n", debug.Stack())
|
||||
|
||||
InfoPersist(fmt.Sprintf("Panic details written to %s", filename))
|
||||
}
|
||||
|
||||
// Execute cleanup function if provided
|
||||
if cleanup != nil {
|
||||
cleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
292
internal/logging/logging.go
Normal file
292
internal/logging/logging.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-logfmt/logfmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sst/opencode/internal/db"
|
||||
"github.com/sst/opencode/internal/pubsub"
|
||||
)
|
||||
|
||||
type Log struct {
|
||||
ID string
|
||||
SessionID string
|
||||
Timestamp time.Time
|
||||
Level string
|
||||
Message string
|
||||
Attributes map[string]string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
EventLogCreated pubsub.EventType = "log_created"
|
||||
)
|
||||
|
||||
type Service interface {
|
||||
pubsub.Subscriber[Log]
|
||||
|
||||
Create(ctx context.Context, timestamp time.Time, level, message string, attributes map[string]string, sessionID string) error
|
||||
ListBySession(ctx context.Context, sessionID string) ([]Log, error)
|
||||
ListAll(ctx context.Context, limit int) ([]Log, error)
|
||||
}
|
||||
|
||||
type service struct {
|
||||
db *db.Queries
|
||||
broker *pubsub.Broker[Log]
|
||||
}
|
||||
|
||||
var globalLoggingService *service
|
||||
|
||||
func InitService(dbConn *sql.DB) error {
|
||||
if globalLoggingService != nil {
|
||||
return fmt.Errorf("logging service already initialized")
|
||||
}
|
||||
queries := db.New(dbConn)
|
||||
broker := pubsub.NewBroker[Log]()
|
||||
|
||||
globalLoggingService = &service{
|
||||
db: queries,
|
||||
broker: broker,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetService() Service {
|
||||
if globalLoggingService == nil {
|
||||
panic("logging service not initialized. Call logging.InitService() first.")
|
||||
}
|
||||
return globalLoggingService
|
||||
}
|
||||
|
||||
func (s *service) Create(ctx context.Context, timestamp time.Time, level, message string, attributes map[string]string, sessionID string) error {
|
||||
if level == "" {
|
||||
level = "info"
|
||||
}
|
||||
|
||||
var attributesJSON sql.NullString
|
||||
if len(attributes) > 0 {
|
||||
attributesBytes, err := json.Marshal(attributes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal log attributes: %w", err)
|
||||
}
|
||||
attributesJSON = sql.NullString{String: string(attributesBytes), Valid: true}
|
||||
}
|
||||
|
||||
dbLog, err := s.db.CreateLog(ctx, db.CreateLogParams{
|
||||
ID: uuid.New().String(),
|
||||
SessionID: sql.NullString{String: sessionID, Valid: sessionID != ""},
|
||||
Timestamp: timestamp.UTC().Format(time.RFC3339Nano),
|
||||
Level: level,
|
||||
Message: message,
|
||||
Attributes: attributesJSON,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("db.CreateLog: %w", err)
|
||||
}
|
||||
|
||||
log := s.fromDBItem(dbLog)
|
||||
s.broker.Publish(EventLogCreated, log)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) ListBySession(ctx context.Context, sessionID string) ([]Log, error) {
|
||||
dbLogs, err := s.db.ListLogsBySession(ctx, sql.NullString{String: sessionID, Valid: true})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.ListLogsBySession: %w", err)
|
||||
}
|
||||
|
||||
logs := make([]Log, len(dbLogs))
|
||||
for i, dbSess := range dbLogs {
|
||||
logs[i] = s.fromDBItem(dbSess)
|
||||
}
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
func (s *service) ListAll(ctx context.Context, limit int) ([]Log, error) {
|
||||
dbLogs, err := s.db.ListAllLogs(ctx, int64(limit))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.ListAllLogs: %w", err)
|
||||
}
|
||||
logs := make([]Log, len(dbLogs))
|
||||
for i, dbSess := range dbLogs {
|
||||
logs[i] = s.fromDBItem(dbSess)
|
||||
}
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[Log] {
|
||||
return s.broker.Subscribe(ctx)
|
||||
}
|
||||
|
||||
func (s *service) fromDBItem(item db.Log) Log {
|
||||
log := Log{
|
||||
ID: item.ID,
|
||||
SessionID: item.SessionID.String,
|
||||
Level: item.Level,
|
||||
Message: item.Message,
|
||||
}
|
||||
|
||||
// Parse timestamp from ISO string
|
||||
timestamp, err := time.Parse(time.RFC3339Nano, item.Timestamp)
|
||||
if err == nil {
|
||||
log.Timestamp = timestamp
|
||||
} else {
|
||||
log.Timestamp = time.Now() // Fallback
|
||||
}
|
||||
|
||||
// Parse created_at from ISO string
|
||||
createdAt, err := time.Parse(time.RFC3339Nano, item.CreatedAt)
|
||||
if err == nil {
|
||||
log.CreatedAt = createdAt
|
||||
} else {
|
||||
log.CreatedAt = time.Now() // Fallback
|
||||
}
|
||||
|
||||
if item.Attributes.Valid && item.Attributes.String != "" {
|
||||
if err := json.Unmarshal([]byte(item.Attributes.String), &log.Attributes); err != nil {
|
||||
slog.Error("Failed to unmarshal log attributes", "log_id", item.ID, "error", err)
|
||||
log.Attributes = make(map[string]string)
|
||||
}
|
||||
} else {
|
||||
log.Attributes = make(map[string]string)
|
||||
}
|
||||
|
||||
return log
|
||||
}
|
||||
|
||||
func Create(ctx context.Context, timestamp time.Time, level, message string, attributes map[string]string, sessionID string) error {
|
||||
return GetService().Create(ctx, timestamp, level, message, attributes, sessionID)
|
||||
}
|
||||
|
||||
func ListBySession(ctx context.Context, sessionID string) ([]Log, error) {
|
||||
return GetService().ListBySession(ctx, sessionID)
|
||||
}
|
||||
|
||||
func ListAll(ctx context.Context, limit int) ([]Log, error) {
|
||||
return GetService().ListAll(ctx, limit)
|
||||
}
|
||||
|
||||
func Subscribe(ctx context.Context) <-chan pubsub.Event[Log] {
|
||||
return GetService().Subscribe(ctx)
|
||||
}
|
||||
|
||||
type slogWriter struct{}
|
||||
|
||||
func (sw *slogWriter) Write(p []byte) (n int, err error) {
|
||||
// Example: time=2024-05-09T12:34:56.789-05:00 level=INFO msg="User request" session=xyz foo=bar
|
||||
d := logfmt.NewDecoder(bytes.NewReader(p))
|
||||
for d.ScanRecord() {
|
||||
var timestamp time.Time
|
||||
var level string
|
||||
var message string
|
||||
var sessionID string
|
||||
var attributes map[string]string
|
||||
|
||||
attributes = make(map[string]string)
|
||||
hasTimestamp := false
|
||||
|
||||
for d.ScanKeyval() {
|
||||
key := string(d.Key())
|
||||
value := string(d.Value())
|
||||
|
||||
switch key {
|
||||
case "time":
|
||||
parsedTime, timeErr := time.Parse(time.RFC3339Nano, value)
|
||||
if timeErr != nil {
|
||||
parsedTime, timeErr = time.Parse(time.RFC3339, value)
|
||||
if timeErr != nil {
|
||||
slog.Error("Failed to parse time in slog writer", "value", value, "error", timeErr)
|
||||
timestamp = time.Now().UTC()
|
||||
hasTimestamp = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
timestamp = parsedTime
|
||||
hasTimestamp = true
|
||||
case "level":
|
||||
level = strings.ToLower(value)
|
||||
case "msg", "message":
|
||||
message = value
|
||||
case "session_id":
|
||||
sessionID = value
|
||||
default:
|
||||
attributes[key] = value
|
||||
}
|
||||
}
|
||||
if d.Err() != nil {
|
||||
return len(p), fmt.Errorf("logfmt.ScanRecord: %w", d.Err())
|
||||
}
|
||||
|
||||
if !hasTimestamp {
|
||||
timestamp = time.Now()
|
||||
}
|
||||
|
||||
// Create log entry via the service (non-blocking or handle error appropriately)
|
||||
// Using context.Background() as this is a low-level logging write.
|
||||
go func(timestamp time.Time, level, message string, attributes map[string]string, sessionID string) { // Run in a goroutine to avoid blocking slog
|
||||
if globalLoggingService == nil {
|
||||
// If the logging service is not initialized, log the message to stderr
|
||||
// fmt.Fprintf(os.Stderr, "ERROR [logging.slogWriter]: logging service not initialized\n")
|
||||
return
|
||||
}
|
||||
if err := Create(context.Background(), timestamp, level, message, attributes, sessionID); err != nil {
|
||||
// Log internal error using a more primitive logger to avoid loops
|
||||
fmt.Fprintf(os.Stderr, "ERROR [logging.slogWriter]: failed to persist log: %v\n", err)
|
||||
}
|
||||
}(timestamp, level, message, attributes, sessionID)
|
||||
}
|
||||
if d.Err() != nil {
|
||||
return len(p), fmt.Errorf("logfmt.ScanRecord final: %w", d.Err())
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func NewSlogWriter() io.Writer {
|
||||
return &slogWriter{}
|
||||
}
|
||||
|
||||
// RecoverPanic is a common function to handle panics gracefully.
|
||||
// It logs the error, creates a panic log file with stack trace,
|
||||
// and executes an optional cleanup function.
|
||||
func RecoverPanic(name string, cleanup func()) {
|
||||
if r := recover(); r != nil {
|
||||
errorMsg := fmt.Sprintf("Panic in %s: %v", name, r)
|
||||
// Use slog directly here, as our service might be the one panicking.
|
||||
slog.Error(errorMsg)
|
||||
// status.Error(errorMsg)
|
||||
|
||||
timestamp := time.Now().Format("20060102-150405")
|
||||
filename := fmt.Sprintf("opencode-panic-%s-%s.log", name, timestamp)
|
||||
|
||||
file, err := os.Create(filename)
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("Failed to create panic log file '%s': %v", filename, err)
|
||||
slog.Error(errMsg)
|
||||
// status.Error(errMsg)
|
||||
} else {
|
||||
defer file.Close()
|
||||
fmt.Fprintf(file, "Panic in %s: %v\n\n", name, r)
|
||||
fmt.Fprintf(file, "Time: %s\n\n", time.Now().Format(time.RFC3339))
|
||||
fmt.Fprintf(file, "Stack Trace:\n%s\n", string(debug.Stack())) // Capture stack trace
|
||||
infoMsg := fmt.Sprintf("Panic details written to %s", filename)
|
||||
slog.Info(infoMsg)
|
||||
// status.Info(infoMsg)
|
||||
}
|
||||
|
||||
if cleanup != nil {
|
||||
cleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// LogMessage is the event payload for a log message
|
||||
type LogMessage struct {
|
||||
ID string
|
||||
Time time.Time
|
||||
Level string
|
||||
Persist bool // used when we want to show the mesage in the status bar
|
||||
PersistTime time.Duration // used when we want to show the mesage in the status bar
|
||||
Message string `json:"msg"`
|
||||
Attributes []Attr
|
||||
}
|
||||
|
||||
type Attr struct {
|
||||
Key string
|
||||
Value string
|
||||
}
|
||||
@@ -1,101 +0,0 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-logfmt/logfmt"
|
||||
"github.com/opencode-ai/opencode/internal/pubsub"
|
||||
)
|
||||
|
||||
const (
|
||||
persistKeyArg = "$_persist"
|
||||
PersistTimeArg = "$_persist_time"
|
||||
)
|
||||
|
||||
type LogData struct {
|
||||
messages []LogMessage
|
||||
*pubsub.Broker[LogMessage]
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (l *LogData) Add(msg LogMessage) {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
l.messages = append(l.messages, msg)
|
||||
l.Publish(pubsub.CreatedEvent, msg)
|
||||
}
|
||||
|
||||
func (l *LogData) List() []LogMessage {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
return l.messages
|
||||
}
|
||||
|
||||
var defaultLogData = &LogData{
|
||||
messages: make([]LogMessage, 0),
|
||||
Broker: pubsub.NewBroker[LogMessage](),
|
||||
}
|
||||
|
||||
type writer struct{}
|
||||
|
||||
func (w *writer) Write(p []byte) (int, error) {
|
||||
d := logfmt.NewDecoder(bytes.NewReader(p))
|
||||
for d.ScanRecord() {
|
||||
msg := LogMessage{
|
||||
ID: fmt.Sprintf("%d", time.Now().UnixNano()),
|
||||
Time: time.Now(),
|
||||
}
|
||||
for d.ScanKeyval() {
|
||||
switch string(d.Key()) {
|
||||
case "time":
|
||||
parsed, err := time.Parse(time.RFC3339, string(d.Value()))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing time: %w", err)
|
||||
}
|
||||
msg.Time = parsed
|
||||
case "level":
|
||||
msg.Level = strings.ToLower(string(d.Value()))
|
||||
case "msg":
|
||||
msg.Message = string(d.Value())
|
||||
default:
|
||||
if string(d.Key()) == persistKeyArg {
|
||||
msg.Persist = true
|
||||
} else if string(d.Key()) == PersistTimeArg {
|
||||
parsed, err := time.ParseDuration(string(d.Value()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msg.PersistTime = parsed
|
||||
} else {
|
||||
msg.Attributes = append(msg.Attributes, Attr{
|
||||
Key: string(d.Key()),
|
||||
Value: string(d.Value()),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
defaultLogData.Add(msg)
|
||||
}
|
||||
if d.Err() != nil {
|
||||
return 0, d.Err()
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func NewWriter() *writer {
|
||||
w := &writer{}
|
||||
return w
|
||||
}
|
||||
|
||||
func Subscribe(ctx context.Context) <-chan pubsub.Event[LogMessage] {
|
||||
return defaultLogData.Subscribe(ctx)
|
||||
}
|
||||
|
||||
func List() []LogMessage {
|
||||
return defaultLogData.List()
|
||||
}
|
||||
@@ -14,9 +14,12 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/logging"
|
||||
"github.com/opencode-ai/opencode/internal/lsp/protocol"
|
||||
"log/slog"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/logging"
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
"github.com/sst/opencode/internal/status"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
@@ -96,17 +99,17 @@ func NewClient(ctx context.Context, command string, args ...string) (*Client, er
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
for scanner.Scan() {
|
||||
fmt.Fprintf(os.Stderr, "LSP Server: %s\n", scanner.Text())
|
||||
slog.Info("LSP Server", "message", scanner.Text())
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error reading stderr: %v\n", err)
|
||||
slog.Error("Error reading LSP stderr", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Start message handling loop
|
||||
go func() {
|
||||
defer logging.RecoverPanic("LSP-message-handler", func() {
|
||||
logging.ErrorPersist("LSP message handler crashed, LSP functionality may be impaired")
|
||||
status.Error("LSP message handler crashed, LSP functionality may be impaired")
|
||||
})
|
||||
client.handleMessages()
|
||||
}()
|
||||
@@ -300,7 +303,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
|
||||
defer ticker.Stop()
|
||||
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Waiting for LSP server to be ready...")
|
||||
slog.Debug("Waiting for LSP server to be ready...")
|
||||
}
|
||||
|
||||
// Determine server type for specialized initialization
|
||||
@@ -309,7 +312,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
|
||||
// For TypeScript-like servers, we need to open some key files first
|
||||
if serverType == ServerTypeTypeScript {
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("TypeScript-like server detected, opening key configuration files")
|
||||
slog.Debug("TypeScript-like server detected, opening key configuration files")
|
||||
}
|
||||
c.openKeyConfigFiles(ctx)
|
||||
}
|
||||
@@ -326,15 +329,15 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
|
||||
// Server responded successfully
|
||||
c.SetServerState(StateReady)
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("LSP server is ready")
|
||||
slog.Debug("LSP server is ready")
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
logging.Debug("LSP server not ready yet", "error", err, "serverType", serverType)
|
||||
slog.Debug("LSP server not ready yet", "error", err, "serverType", serverType)
|
||||
}
|
||||
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("LSP server not ready yet", "error", err, "serverType", serverType)
|
||||
slog.Debug("LSP server not ready yet", "error", err, "serverType", serverType)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -409,9 +412,9 @@ func (c *Client) openKeyConfigFiles(ctx context.Context) {
|
||||
if _, err := os.Stat(file); err == nil {
|
||||
// File exists, try to open it
|
||||
if err := c.OpenFile(ctx, file); err != nil {
|
||||
logging.Debug("Failed to open key config file", "file", file, "error", err)
|
||||
slog.Debug("Failed to open key config file", "file", file, "error", err)
|
||||
} else {
|
||||
logging.Debug("Opened key config file for initialization", "file", file)
|
||||
slog.Debug("Opened key config file for initialization", "file", file)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -487,7 +490,7 @@ func (c *Client) pingTypeScriptServer(ctx context.Context) error {
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
logging.Debug("Error walking directory for TypeScript files", "error", err)
|
||||
slog.Debug("Error walking directory for TypeScript files", "error", err)
|
||||
}
|
||||
|
||||
// Final fallback - just try a generic capability
|
||||
@@ -527,7 +530,7 @@ func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) {
|
||||
if err := c.OpenFile(ctx, path); err == nil {
|
||||
filesOpened++
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Opened TypeScript file for initialization", "file", path)
|
||||
slog.Debug("Opened TypeScript file for initialization", "file", path)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -536,11 +539,11 @@ func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) {
|
||||
})
|
||||
|
||||
if err != nil && cnf.DebugLSP {
|
||||
logging.Debug("Error walking directory for TypeScript files", "error", err)
|
||||
slog.Debug("Error walking directory for TypeScript files", "error", err)
|
||||
}
|
||||
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Opened TypeScript files for initialization", "count", filesOpened)
|
||||
slog.Debug("Opened TypeScript files for initialization", "count", filesOpened)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -627,6 +630,15 @@ func (c *Client) OpenFile(ctx context.Context, filepath string) error {
|
||||
func (c *Client) NotifyChange(ctx context.Context, filepath string) error {
|
||||
uri := fmt.Sprintf("file://%s", filepath)
|
||||
|
||||
// Verify file exists before attempting to read it
|
||||
if _, err := os.Stat(filepath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// File was deleted - close it in the LSP client instead of notifying change
|
||||
return c.CloseFile(ctx, filepath)
|
||||
}
|
||||
return fmt.Errorf("error checking file: %w", err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(filepath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading file: %w", err)
|
||||
@@ -681,7 +693,7 @@ func (c *Client) CloseFile(ctx context.Context, filepath string) error {
|
||||
}
|
||||
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Closing file", "file", filepath)
|
||||
slog.Debug("Closing file", "file", filepath)
|
||||
}
|
||||
if err := c.Notify(ctx, "textDocument/didClose", params); err != nil {
|
||||
return err
|
||||
@@ -720,12 +732,12 @@ func (c *Client) CloseAllFiles(ctx context.Context) {
|
||||
for _, filePath := range filesToClose {
|
||||
err := c.CloseFile(ctx, filePath)
|
||||
if err != nil && cnf.DebugLSP {
|
||||
logging.Warn("Error closing file", "file", filePath, "error", err)
|
||||
slog.Warn("Error closing file", "file", filePath, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Closed all files", "files", filesToClose)
|
||||
slog.Debug("Closed all files", "files", filesToClose)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
65
internal/lsp/discovery/integration.go
Normal file
65
internal/lsp/discovery/integration.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package discovery
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// IntegrateLSPServers discovers languages and LSP servers and integrates them into the application configuration
|
||||
func IntegrateLSPServers(workingDir string) error {
|
||||
// Get the current configuration
|
||||
cfg := config.Get()
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("config not loaded")
|
||||
}
|
||||
|
||||
// Check if this is the first run
|
||||
shouldInit, err := config.ShouldShowInitDialog()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check initialization status: %w", err)
|
||||
}
|
||||
|
||||
// Always run language detection, but log differently for first run vs. subsequent runs
|
||||
if shouldInit || len(cfg.LSP) == 0 {
|
||||
slog.Info("Running initial LSP auto-discovery...")
|
||||
} else {
|
||||
slog.Debug("Running LSP auto-discovery to detect new languages...")
|
||||
}
|
||||
|
||||
// Configure LSP servers
|
||||
servers, err := ConfigureLSPServers(workingDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure LSP servers: %w", err)
|
||||
}
|
||||
|
||||
// Update the configuration with discovered servers
|
||||
for langID, serverInfo := range servers {
|
||||
// Skip languages that already have a configured server
|
||||
if _, exists := cfg.LSP[langID]; exists {
|
||||
slog.Debug("LSP server already configured for language", "language", langID)
|
||||
continue
|
||||
}
|
||||
|
||||
if serverInfo.Available {
|
||||
// Only add servers that were found
|
||||
cfg.LSP[langID] = config.LSPConfig{
|
||||
Disabled: false,
|
||||
Command: serverInfo.Path,
|
||||
Args: serverInfo.Args,
|
||||
}
|
||||
slog.Info("Added LSP server to configuration",
|
||||
"language", langID,
|
||||
"command", serverInfo.Command,
|
||||
"path", serverInfo.Path)
|
||||
} else {
|
||||
slog.Warn("LSP server not available",
|
||||
"language", langID,
|
||||
"command", serverInfo.Command,
|
||||
"installCmd", serverInfo.InstallCmd)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
298
internal/lsp/discovery/language.go
Normal file
298
internal/lsp/discovery/language.go
Normal file
@@ -0,0 +1,298 @@
|
||||
package discovery
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// LanguageInfo stores information about a detected language
|
||||
type LanguageInfo struct {
|
||||
// Language identifier (e.g., "go", "typescript", "python")
|
||||
ID string
|
||||
|
||||
// Number of files detected for this language
|
||||
FileCount int
|
||||
|
||||
// Project files associated with this language (e.g., go.mod, package.json)
|
||||
ProjectFiles []string
|
||||
|
||||
// Whether this is likely a primary language in the project
|
||||
IsPrimary bool
|
||||
}
|
||||
|
||||
// ProjectFile represents a project configuration file
|
||||
type ProjectFile struct {
|
||||
// File name or pattern to match
|
||||
Name string
|
||||
|
||||
// Associated language ID
|
||||
LanguageID string
|
||||
|
||||
// Whether this file strongly indicates the language is primary
|
||||
IsPrimary bool
|
||||
}
|
||||
|
||||
// Common project files that indicate specific languages
|
||||
var projectFilePatterns = []ProjectFile{
|
||||
{Name: "go.mod", LanguageID: "go", IsPrimary: true},
|
||||
{Name: "go.sum", LanguageID: "go", IsPrimary: false},
|
||||
{Name: "package.json", LanguageID: "javascript", IsPrimary: true}, // Could be TypeScript too
|
||||
{Name: "tsconfig.json", LanguageID: "typescript", IsPrimary: true},
|
||||
{Name: "jsconfig.json", LanguageID: "javascript", IsPrimary: true},
|
||||
{Name: "pyproject.toml", LanguageID: "python", IsPrimary: true},
|
||||
{Name: "setup.py", LanguageID: "python", IsPrimary: true},
|
||||
{Name: "requirements.txt", LanguageID: "python", IsPrimary: true},
|
||||
{Name: "Cargo.toml", LanguageID: "rust", IsPrimary: true},
|
||||
{Name: "Cargo.lock", LanguageID: "rust", IsPrimary: false},
|
||||
{Name: "CMakeLists.txt", LanguageID: "cmake", IsPrimary: true},
|
||||
{Name: "pom.xml", LanguageID: "java", IsPrimary: true},
|
||||
{Name: "build.gradle", LanguageID: "java", IsPrimary: true},
|
||||
{Name: "build.gradle.kts", LanguageID: "kotlin", IsPrimary: true},
|
||||
{Name: "composer.json", LanguageID: "php", IsPrimary: true},
|
||||
{Name: "Gemfile", LanguageID: "ruby", IsPrimary: true},
|
||||
{Name: "Rakefile", LanguageID: "ruby", IsPrimary: true},
|
||||
{Name: "mix.exs", LanguageID: "elixir", IsPrimary: true},
|
||||
{Name: "rebar.config", LanguageID: "erlang", IsPrimary: true},
|
||||
{Name: "dune-project", LanguageID: "ocaml", IsPrimary: true},
|
||||
{Name: "stack.yaml", LanguageID: "haskell", IsPrimary: true},
|
||||
{Name: "cabal.project", LanguageID: "haskell", IsPrimary: true},
|
||||
{Name: "Makefile", LanguageID: "make", IsPrimary: false},
|
||||
{Name: "Dockerfile", LanguageID: "dockerfile", IsPrimary: false},
|
||||
}
|
||||
|
||||
// Map of file extensions to language IDs
|
||||
var extensionToLanguage = map[string]string{
|
||||
".go": "go",
|
||||
".js": "javascript",
|
||||
".jsx": "javascript",
|
||||
".ts": "typescript",
|
||||
".tsx": "typescript",
|
||||
".py": "python",
|
||||
".rs": "rust",
|
||||
".java": "java",
|
||||
".c": "c",
|
||||
".cpp": "cpp",
|
||||
".h": "c",
|
||||
".hpp": "cpp",
|
||||
".rb": "ruby",
|
||||
".php": "php",
|
||||
".cs": "csharp",
|
||||
".fs": "fsharp",
|
||||
".swift": "swift",
|
||||
".kt": "kotlin",
|
||||
".scala": "scala",
|
||||
".hs": "haskell",
|
||||
".ml": "ocaml",
|
||||
".ex": "elixir",
|
||||
".exs": "elixir",
|
||||
".erl": "erlang",
|
||||
".lua": "lua",
|
||||
".r": "r",
|
||||
".sh": "shell",
|
||||
".bash": "shell",
|
||||
".zsh": "shell",
|
||||
".html": "html",
|
||||
".css": "css",
|
||||
".scss": "scss",
|
||||
".sass": "sass",
|
||||
".less": "less",
|
||||
".json": "json",
|
||||
".xml": "xml",
|
||||
".yaml": "yaml",
|
||||
".yml": "yaml",
|
||||
".md": "markdown",
|
||||
".dart": "dart",
|
||||
}
|
||||
|
||||
// Directories to exclude from scanning
|
||||
var excludedDirs = map[string]bool{
|
||||
".git": true,
|
||||
"node_modules": true,
|
||||
"vendor": true,
|
||||
"dist": true,
|
||||
"build": true,
|
||||
"target": true,
|
||||
".idea": true,
|
||||
".vscode": true,
|
||||
".github": true,
|
||||
".gitlab": true,
|
||||
"__pycache__": true,
|
||||
".next": true,
|
||||
".nuxt": true,
|
||||
"venv": true,
|
||||
"env": true,
|
||||
".env": true,
|
||||
}
|
||||
|
||||
// DetectLanguages scans a directory to identify programming languages used in the project
|
||||
func DetectLanguages(rootDir string) (map[string]LanguageInfo, error) {
|
||||
languages := make(map[string]LanguageInfo)
|
||||
var mutex sync.Mutex
|
||||
|
||||
// Walk the directory tree
|
||||
err := filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil // Skip files that can't be accessed
|
||||
}
|
||||
|
||||
// Skip excluded directories
|
||||
if info.IsDir() {
|
||||
if excludedDirs[info.Name()] || strings.HasPrefix(info.Name(), ".") {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Skip hidden files
|
||||
if strings.HasPrefix(info.Name(), ".") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check for project files
|
||||
for _, pattern := range projectFilePatterns {
|
||||
if info.Name() == pattern.Name {
|
||||
mutex.Lock()
|
||||
lang, exists := languages[pattern.LanguageID]
|
||||
if !exists {
|
||||
lang = LanguageInfo{
|
||||
ID: pattern.LanguageID,
|
||||
FileCount: 0,
|
||||
ProjectFiles: []string{},
|
||||
IsPrimary: pattern.IsPrimary,
|
||||
}
|
||||
}
|
||||
lang.ProjectFiles = append(lang.ProjectFiles, path)
|
||||
if pattern.IsPrimary {
|
||||
lang.IsPrimary = true
|
||||
}
|
||||
languages[pattern.LanguageID] = lang
|
||||
mutex.Unlock()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check file extension
|
||||
ext := strings.ToLower(filepath.Ext(path))
|
||||
if langID, ok := extensionToLanguage[ext]; ok {
|
||||
mutex.Lock()
|
||||
lang, exists := languages[langID]
|
||||
if !exists {
|
||||
lang = LanguageInfo{
|
||||
ID: langID,
|
||||
FileCount: 0,
|
||||
ProjectFiles: []string{},
|
||||
}
|
||||
}
|
||||
lang.FileCount++
|
||||
languages[langID] = lang
|
||||
mutex.Unlock()
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Determine primary languages based on file count if not already marked
|
||||
determinePrimaryLanguages(languages)
|
||||
|
||||
// Log detected languages
|
||||
for id, info := range languages {
|
||||
if info.IsPrimary {
|
||||
slog.Debug("Detected primary language", "language", id, "files", info.FileCount, "projectFiles", len(info.ProjectFiles))
|
||||
} else {
|
||||
slog.Debug("Detected secondary language", "language", id, "files", info.FileCount)
|
||||
}
|
||||
}
|
||||
|
||||
return languages, nil
|
||||
}
|
||||
|
||||
// determinePrimaryLanguages marks languages as primary based on file count
|
||||
func determinePrimaryLanguages(languages map[string]LanguageInfo) {
|
||||
// Find the language with the most files
|
||||
var maxFiles int
|
||||
for _, info := range languages {
|
||||
if info.FileCount > maxFiles {
|
||||
maxFiles = info.FileCount
|
||||
}
|
||||
}
|
||||
|
||||
// Mark languages with at least 20% of the max files as primary
|
||||
threshold := max(maxFiles/5, 5) // At least 5 files to be considered primary
|
||||
|
||||
for id, info := range languages {
|
||||
if !info.IsPrimary && info.FileCount >= threshold {
|
||||
info.IsPrimary = true
|
||||
languages[id] = info
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetLanguageIDFromExtension returns the language ID for a given file extension
|
||||
func GetLanguageIDFromExtension(ext string) string {
|
||||
ext = strings.ToLower(ext)
|
||||
if langID, ok := extensionToLanguage[ext]; ok {
|
||||
return langID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetLanguageIDFromProtocol converts a protocol.LanguageKind to our language ID string
|
||||
func GetLanguageIDFromProtocol(langKind string) string {
|
||||
// Convert protocol language kind to our language ID
|
||||
switch langKind {
|
||||
case "go":
|
||||
return "go"
|
||||
case "typescript":
|
||||
return "typescript"
|
||||
case "typescriptreact":
|
||||
return "typescript"
|
||||
case "javascript":
|
||||
return "javascript"
|
||||
case "javascriptreact":
|
||||
return "javascript"
|
||||
case "python":
|
||||
return "python"
|
||||
case "rust":
|
||||
return "rust"
|
||||
case "java":
|
||||
return "java"
|
||||
case "c":
|
||||
return "c"
|
||||
case "cpp":
|
||||
return "cpp"
|
||||
default:
|
||||
// Try to normalize the language kind
|
||||
return strings.ToLower(langKind)
|
||||
}
|
||||
}
|
||||
|
||||
// GetLanguageIDFromPath determines the language ID from a file path
|
||||
func GetLanguageIDFromPath(path string) string {
|
||||
// Check file extension first
|
||||
ext := filepath.Ext(path)
|
||||
if langID := GetLanguageIDFromExtension(ext); langID != "" {
|
||||
return langID
|
||||
}
|
||||
|
||||
// Check if it's a known project file
|
||||
filename := filepath.Base(path)
|
||||
for _, pattern := range projectFilePatterns {
|
||||
if filename == pattern.Name {
|
||||
return pattern.LanguageID
|
||||
}
|
||||
}
|
||||
|
||||
// Use LSP's detection as a fallback
|
||||
uri := "file://" + path
|
||||
langKind := lsp.DetectLanguageID(uri)
|
||||
return GetLanguageIDFromProtocol(string(langKind))
|
||||
}
|
||||
306
internal/lsp/discovery/server.go
Normal file
306
internal/lsp/discovery/server.go
Normal file
@@ -0,0 +1,306 @@
|
||||
package discovery
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// ServerInfo contains information about an LSP server
|
||||
type ServerInfo struct {
|
||||
// Command to run the server
|
||||
Command string
|
||||
|
||||
// Arguments to pass to the command
|
||||
Args []string
|
||||
|
||||
// Command to install the server (for user guidance)
|
||||
InstallCmd string
|
||||
|
||||
// Whether this server is available
|
||||
Available bool
|
||||
|
||||
// Full path to the executable (if found)
|
||||
Path string
|
||||
}
|
||||
|
||||
// LanguageServerMap maps language IDs to their corresponding LSP servers
|
||||
var LanguageServerMap = map[string]ServerInfo{
|
||||
"go": {
|
||||
Command: "gopls",
|
||||
InstallCmd: "go install golang.org/x/tools/gopls@latest",
|
||||
},
|
||||
"typescript": {
|
||||
Command: "typescript-language-server",
|
||||
Args: []string{"--stdio"},
|
||||
InstallCmd: "npm install -g typescript-language-server typescript",
|
||||
},
|
||||
"javascript": {
|
||||
Command: "typescript-language-server",
|
||||
Args: []string{"--stdio"},
|
||||
InstallCmd: "npm install -g typescript-language-server typescript",
|
||||
},
|
||||
"python": {
|
||||
Command: "pylsp",
|
||||
InstallCmd: "pip install python-lsp-server",
|
||||
},
|
||||
"rust": {
|
||||
Command: "rust-analyzer",
|
||||
InstallCmd: "rustup component add rust-analyzer",
|
||||
},
|
||||
"java": {
|
||||
Command: "jdtls",
|
||||
InstallCmd: "Install Eclipse JDT Language Server",
|
||||
},
|
||||
"c": {
|
||||
Command: "clangd",
|
||||
InstallCmd: "Install clangd from your package manager",
|
||||
},
|
||||
"cpp": {
|
||||
Command: "clangd",
|
||||
InstallCmd: "Install clangd from your package manager",
|
||||
},
|
||||
"php": {
|
||||
Command: "intelephense",
|
||||
Args: []string{"--stdio"},
|
||||
InstallCmd: "npm install -g intelephense",
|
||||
},
|
||||
"ruby": {
|
||||
Command: "solargraph",
|
||||
Args: []string{"stdio"},
|
||||
InstallCmd: "gem install solargraph",
|
||||
},
|
||||
"lua": {
|
||||
Command: "lua-language-server",
|
||||
InstallCmd: "Install lua-language-server from your package manager",
|
||||
},
|
||||
"html": {
|
||||
Command: "vscode-html-language-server",
|
||||
Args: []string{"--stdio"},
|
||||
InstallCmd: "npm install -g vscode-langservers-extracted",
|
||||
},
|
||||
"css": {
|
||||
Command: "vscode-css-language-server",
|
||||
Args: []string{"--stdio"},
|
||||
InstallCmd: "npm install -g vscode-langservers-extracted",
|
||||
},
|
||||
"json": {
|
||||
Command: "vscode-json-language-server",
|
||||
Args: []string{"--stdio"},
|
||||
InstallCmd: "npm install -g vscode-langservers-extracted",
|
||||
},
|
||||
"yaml": {
|
||||
Command: "yaml-language-server",
|
||||
Args: []string{"--stdio"},
|
||||
InstallCmd: "npm install -g yaml-language-server",
|
||||
},
|
||||
}
|
||||
|
||||
// FindLSPServer searches for an LSP server for the given language
|
||||
func FindLSPServer(languageID string) (ServerInfo, error) {
|
||||
// Get server info for the language
|
||||
serverInfo, exists := LanguageServerMap[languageID]
|
||||
if !exists {
|
||||
return ServerInfo{}, fmt.Errorf("no LSP server defined for language: %s", languageID)
|
||||
}
|
||||
|
||||
// Check if the command is in PATH
|
||||
path, err := exec.LookPath(serverInfo.Command)
|
||||
if err == nil {
|
||||
serverInfo.Available = true
|
||||
serverInfo.Path = path
|
||||
slog.Debug("Found LSP server in PATH", "language", languageID, "command", serverInfo.Command, "path", path)
|
||||
return serverInfo, nil
|
||||
}
|
||||
|
||||
// If not in PATH, search in common installation locations
|
||||
paths := getCommonLSPPaths(languageID, serverInfo.Command)
|
||||
for _, searchPath := range paths {
|
||||
if _, err := os.Stat(searchPath); err == nil {
|
||||
// Found the server
|
||||
serverInfo.Available = true
|
||||
serverInfo.Path = searchPath
|
||||
slog.Debug("Found LSP server in common location", "language", languageID, "command", serverInfo.Command, "path", searchPath)
|
||||
return serverInfo, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Server not found
|
||||
slog.Debug("LSP server not found", "language", languageID, "command", serverInfo.Command)
|
||||
return serverInfo, fmt.Errorf("LSP server for %s not found. Install with: %s", languageID, serverInfo.InstallCmd)
|
||||
}
|
||||
|
||||
// getCommonLSPPaths returns common installation paths for LSP servers based on language and OS
|
||||
func getCommonLSPPaths(languageID, command string) []string {
|
||||
var paths []string
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
slog.Error("Failed to get user home directory", "error", err)
|
||||
return paths
|
||||
}
|
||||
|
||||
// Add platform-specific paths
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
// macOS paths
|
||||
paths = append(paths,
|
||||
fmt.Sprintf("/usr/local/bin/%s", command),
|
||||
fmt.Sprintf("/opt/homebrew/bin/%s", command),
|
||||
fmt.Sprintf("%s/.local/bin/%s", homeDir, command),
|
||||
)
|
||||
case "linux":
|
||||
// Linux paths
|
||||
paths = append(paths,
|
||||
fmt.Sprintf("/usr/bin/%s", command),
|
||||
fmt.Sprintf("/usr/local/bin/%s", command),
|
||||
fmt.Sprintf("%s/.local/bin/%s", homeDir, command),
|
||||
)
|
||||
case "windows":
|
||||
// Windows paths
|
||||
paths = append(paths,
|
||||
fmt.Sprintf("%s\\AppData\\Local\\Programs\\%s.exe", homeDir, command),
|
||||
fmt.Sprintf("C:\\Program Files\\%s\\bin\\%s.exe", command, command),
|
||||
)
|
||||
}
|
||||
|
||||
// Add language-specific paths
|
||||
switch languageID {
|
||||
case "go":
|
||||
gopath := os.Getenv("GOPATH")
|
||||
if gopath == "" {
|
||||
gopath = filepath.Join(homeDir, "go")
|
||||
}
|
||||
paths = append(paths, filepath.Join(gopath, "bin", command))
|
||||
if runtime.GOOS == "windows" {
|
||||
paths = append(paths, filepath.Join(gopath, "bin", command+".exe"))
|
||||
}
|
||||
case "typescript", "javascript", "html", "css", "json", "yaml", "php":
|
||||
// Node.js global packages
|
||||
if runtime.GOOS == "windows" {
|
||||
paths = append(paths,
|
||||
fmt.Sprintf("%s\\AppData\\Roaming\\npm\\%s.cmd", homeDir, command),
|
||||
fmt.Sprintf("%s\\AppData\\Roaming\\npm\\node_modules\\.bin\\%s.cmd", homeDir, command),
|
||||
)
|
||||
} else {
|
||||
paths = append(paths,
|
||||
fmt.Sprintf("%s/.npm-global/bin/%s", homeDir, command),
|
||||
fmt.Sprintf("%s/.nvm/versions/node/*/bin/%s", homeDir, command),
|
||||
fmt.Sprintf("/usr/local/lib/node_modules/.bin/%s", command),
|
||||
)
|
||||
}
|
||||
case "python":
|
||||
// Python paths
|
||||
if runtime.GOOS == "windows" {
|
||||
paths = append(paths,
|
||||
fmt.Sprintf("%s\\AppData\\Local\\Programs\\Python\\Python*\\Scripts\\%s.exe", homeDir, command),
|
||||
fmt.Sprintf("C:\\Python*\\Scripts\\%s.exe", command),
|
||||
)
|
||||
} else {
|
||||
paths = append(paths,
|
||||
fmt.Sprintf("%s/.local/bin/%s", homeDir, command),
|
||||
fmt.Sprintf("%s/.pyenv/shims/%s", homeDir, command),
|
||||
fmt.Sprintf("/usr/local/bin/%s", command),
|
||||
)
|
||||
}
|
||||
case "rust":
|
||||
// Rust paths
|
||||
if runtime.GOOS == "windows" {
|
||||
paths = append(paths,
|
||||
fmt.Sprintf("%s\\.rustup\\toolchains\\*\\bin\\%s.exe", homeDir, command),
|
||||
fmt.Sprintf("%s\\.cargo\\bin\\%s.exe", homeDir, command),
|
||||
)
|
||||
} else {
|
||||
paths = append(paths,
|
||||
fmt.Sprintf("%s/.rustup/toolchains/*/bin/%s", homeDir, command),
|
||||
fmt.Sprintf("%s/.cargo/bin/%s", homeDir, command),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Add VSCode extensions path
|
||||
vscodePath := getVSCodeExtensionsPath(homeDir)
|
||||
if vscodePath != "" {
|
||||
paths = append(paths, vscodePath)
|
||||
}
|
||||
|
||||
// Expand any glob patterns in paths
|
||||
var expandedPaths []string
|
||||
for _, path := range paths {
|
||||
if strings.Contains(path, "*") {
|
||||
// This is a glob pattern, expand it
|
||||
matches, err := filepath.Glob(path)
|
||||
if err == nil {
|
||||
expandedPaths = append(expandedPaths, matches...)
|
||||
}
|
||||
} else {
|
||||
expandedPaths = append(expandedPaths, path)
|
||||
}
|
||||
}
|
||||
|
||||
return expandedPaths
|
||||
}
|
||||
|
||||
// getVSCodeExtensionsPath returns the path to VSCode extensions directory
|
||||
func getVSCodeExtensionsPath(homeDir string) string {
|
||||
var basePath string
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
basePath = filepath.Join(homeDir, "Library", "Application Support", "Code", "User", "globalStorage")
|
||||
case "linux":
|
||||
basePath = filepath.Join(homeDir, ".config", "Code", "User", "globalStorage")
|
||||
case "windows":
|
||||
basePath = filepath.Join(homeDir, "AppData", "Roaming", "Code", "User", "globalStorage")
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check if the directory exists
|
||||
if _, err := os.Stat(basePath); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return basePath
|
||||
}
|
||||
|
||||
// ConfigureLSPServers detects languages and configures LSP servers
|
||||
func ConfigureLSPServers(rootDir string) (map[string]ServerInfo, error) {
|
||||
// Detect languages in the project
|
||||
languages, err := DetectLanguages(rootDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to detect languages: %w", err)
|
||||
}
|
||||
|
||||
// Find LSP servers for detected languages
|
||||
servers := make(map[string]ServerInfo)
|
||||
for langID, langInfo := range languages {
|
||||
// Prioritize primary languages but include all languages that have server definitions
|
||||
if !langInfo.IsPrimary && langInfo.FileCount < 3 {
|
||||
// Skip non-primary languages with very few files
|
||||
slog.Debug("Skipping non-primary language with few files", "language", langID, "files", langInfo.FileCount)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if we have a server for this language
|
||||
serverInfo, err := FindLSPServer(langID)
|
||||
if err != nil {
|
||||
slog.Warn("LSP server not found", "language", langID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Add to the map of configured servers
|
||||
servers[langID] = serverInfo
|
||||
if langInfo.IsPrimary {
|
||||
slog.Info("Configured LSP server for primary language", "language", langID, "command", serverInfo.Command, "path", serverInfo.Path)
|
||||
} else {
|
||||
slog.Info("Configured LSP server for secondary language", "language", langID, "command", serverInfo.Command, "path", serverInfo.Path)
|
||||
}
|
||||
}
|
||||
|
||||
return servers, nil
|
||||
}
|
||||
@@ -3,10 +3,10 @@ package lsp
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/logging"
|
||||
"github.com/opencode-ai/opencode/internal/lsp/protocol"
|
||||
"github.com/opencode-ai/opencode/internal/lsp/util"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
"github.com/sst/opencode/internal/lsp/util"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// Requests
|
||||
@@ -18,7 +18,7 @@ func HandleWorkspaceConfiguration(params json.RawMessage) (any, error) {
|
||||
func HandleRegisterCapability(params json.RawMessage) (any, error) {
|
||||
var registerParams protocol.RegistrationParams
|
||||
if err := json.Unmarshal(params, ®isterParams); err != nil {
|
||||
logging.Error("Error unmarshaling registration params", "error", err)
|
||||
slog.Error("Error unmarshaling registration params", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -28,13 +28,13 @@ func HandleRegisterCapability(params json.RawMessage) (any, error) {
|
||||
// Parse the registration options
|
||||
optionsJSON, err := json.Marshal(reg.RegisterOptions)
|
||||
if err != nil {
|
||||
logging.Error("Error marshaling registration options", "error", err)
|
||||
slog.Error("Error marshaling registration options", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
var options protocol.DidChangeWatchedFilesRegistrationOptions
|
||||
if err := json.Unmarshal(optionsJSON, &options); err != nil {
|
||||
logging.Error("Error unmarshaling registration options", "error", err)
|
||||
slog.Error("Error unmarshaling registration options", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ func HandleApplyEdit(params json.RawMessage) (any, error) {
|
||||
|
||||
err := util.ApplyWorkspaceEdit(edit.Edit)
|
||||
if err != nil {
|
||||
logging.Error("Error applying workspace edit", "error", err)
|
||||
slog.Error("Error applying workspace edit", "error", err)
|
||||
return protocol.ApplyWorkspaceEditResult{Applied: false, FailureReason: err.Error()}, nil
|
||||
}
|
||||
|
||||
@@ -89,7 +89,7 @@ func HandleServerMessage(params json.RawMessage) {
|
||||
}
|
||||
if err := json.Unmarshal(params, &msg); err == nil {
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Server message", "type", msg.Type, "message", msg.Message)
|
||||
slog.Debug("Server message", "type", msg.Type, "message", msg.Message)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -97,7 +97,7 @@ func HandleServerMessage(params json.RawMessage) {
|
||||
func HandleDiagnostics(client *Client, params json.RawMessage) {
|
||||
var diagParams protocol.PublishDiagnosticsParams
|
||||
if err := json.Unmarshal(params, &diagParams); err != nil {
|
||||
logging.Error("Error unmarshaling diagnostics params", "error", err)
|
||||
slog.Error("Error unmarshaling diagnostics params", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/lsp/protocol"
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
)
|
||||
|
||||
func DetectLanguageID(uri string) protocol.LanguageKind {
|
||||
|
||||
@@ -4,7 +4,7 @@ package lsp
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/lsp/protocol"
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
)
|
||||
|
||||
// Implementation sends a textDocument/implementation request to the LSP server.
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/logging"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// Write writes an LSP message to the given writer
|
||||
@@ -21,7 +21,7 @@ func WriteMessage(w io.Writer, msg *Message) error {
|
||||
cnf := config.Get()
|
||||
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Sending message to server", "method", msg.Method, "id", msg.ID)
|
||||
slog.Debug("Sending message to server", "method", msg.Method, "id", msg.ID)
|
||||
}
|
||||
|
||||
_, err = fmt.Fprintf(w, "Content-Length: %d\r\n\r\n", len(data))
|
||||
@@ -50,7 +50,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Received header", "line", line)
|
||||
slog.Debug("Received header", "line", line)
|
||||
}
|
||||
|
||||
if line == "" {
|
||||
@@ -66,7 +66,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
|
||||
}
|
||||
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Content-Length", "length", contentLength)
|
||||
slog.Debug("Content-Length", "length", contentLength)
|
||||
}
|
||||
|
||||
// Read content
|
||||
@@ -77,7 +77,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
|
||||
}
|
||||
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Received content", "content", string(content))
|
||||
slog.Debug("Received content", "content", string(content))
|
||||
}
|
||||
|
||||
// Parse message
|
||||
@@ -96,7 +96,7 @@ func (c *Client) handleMessages() {
|
||||
msg, err := ReadMessage(c.stdout)
|
||||
if err != nil {
|
||||
if cnf.DebugLSP {
|
||||
logging.Error("Error reading message", "error", err)
|
||||
slog.Error("Error reading message", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -104,7 +104,7 @@ func (c *Client) handleMessages() {
|
||||
// Handle server->client request (has both Method and ID)
|
||||
if msg.Method != "" && msg.ID != 0 {
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Received request from server", "method", msg.Method, "id", msg.ID)
|
||||
slog.Debug("Received request from server", "method", msg.Method, "id", msg.ID)
|
||||
}
|
||||
|
||||
response := &Message{
|
||||
@@ -144,7 +144,7 @@ func (c *Client) handleMessages() {
|
||||
|
||||
// Send response back to server
|
||||
if err := WriteMessage(c.stdin, response); err != nil {
|
||||
logging.Error("Error sending response to server", "error", err)
|
||||
slog.Error("Error sending response to server", "error", err)
|
||||
}
|
||||
|
||||
continue
|
||||
@@ -158,11 +158,11 @@ func (c *Client) handleMessages() {
|
||||
|
||||
if ok {
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Handling notification", "method", msg.Method)
|
||||
slog.Debug("Handling notification", "method", msg.Method)
|
||||
}
|
||||
go handler(msg.Params)
|
||||
} else if cnf.DebugLSP {
|
||||
logging.Debug("No handler for notification", "method", msg.Method)
|
||||
slog.Debug("No handler for notification", "method", msg.Method)
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -175,12 +175,12 @@ func (c *Client) handleMessages() {
|
||||
|
||||
if ok {
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Received response for request", "id", msg.ID)
|
||||
slog.Debug("Received response for request", "id", msg.ID)
|
||||
}
|
||||
ch <- msg
|
||||
close(ch)
|
||||
} else if cnf.DebugLSP {
|
||||
logging.Debug("No handler for response", "id", msg.ID)
|
||||
slog.Debug("No handler for response", "id", msg.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -192,7 +192,7 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
|
||||
id := c.nextID.Add(1)
|
||||
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Making call", "method", method, "id", id)
|
||||
slog.Debug("Making call", "method", method, "id", id)
|
||||
}
|
||||
|
||||
msg, err := NewRequest(id, method, params)
|
||||
@@ -218,14 +218,14 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
|
||||
}
|
||||
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Request sent", "method", method, "id", id)
|
||||
slog.Debug("Request sent", "method", method, "id", id)
|
||||
}
|
||||
|
||||
// Wait for response
|
||||
resp := <-ch
|
||||
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Received response", "id", id)
|
||||
slog.Debug("Received response", "id", id)
|
||||
}
|
||||
|
||||
if resp.Error != nil {
|
||||
@@ -251,7 +251,7 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
|
||||
func (c *Client) Notify(ctx context.Context, method string, params any) error {
|
||||
cnf := config.Get()
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Sending notification", "method", method)
|
||||
slog.Debug("Sending notification", "method", method)
|
||||
}
|
||||
|
||||
msg, err := NewNotification(method, params)
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/lsp/protocol"
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
)
|
||||
|
||||
func applyTextEdits(uri protocol.DocumentUri, edits []protocol.TextEdit) error {
|
||||
|
||||
@@ -5,16 +5,17 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bmatcuk/doublestar/v4"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/logging"
|
||||
"github.com/opencode-ai/opencode/internal/lsp"
|
||||
"github.com/opencode-ai/opencode/internal/lsp/protocol"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// WorkspaceWatcher manages LSP file watching
|
||||
@@ -45,7 +46,7 @@ func NewWorkspaceWatcher(client *lsp.Client) *WorkspaceWatcher {
|
||||
func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watchers []protocol.FileSystemWatcher) {
|
||||
cnf := config.Get()
|
||||
|
||||
logging.Debug("Adding file watcher registrations")
|
||||
slog.Debug("Adding file watcher registrations")
|
||||
w.registrationMu.Lock()
|
||||
defer w.registrationMu.Unlock()
|
||||
|
||||
@@ -54,33 +55,33 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
|
||||
|
||||
// Print detailed registration information for debugging
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Adding file watcher registrations",
|
||||
slog.Debug("Adding file watcher registrations",
|
||||
"id", id,
|
||||
"watchers", len(watchers),
|
||||
"total", len(w.registrations),
|
||||
)
|
||||
|
||||
for i, watcher := range watchers {
|
||||
logging.Debug("Registration", "index", i+1)
|
||||
slog.Debug("Registration", "index", i+1)
|
||||
|
||||
// Log the GlobPattern
|
||||
switch v := watcher.GlobPattern.Value.(type) {
|
||||
case string:
|
||||
logging.Debug("GlobPattern", "pattern", v)
|
||||
slog.Debug("GlobPattern", "pattern", v)
|
||||
case protocol.RelativePattern:
|
||||
logging.Debug("GlobPattern", "pattern", v.Pattern)
|
||||
slog.Debug("GlobPattern", "pattern", v.Pattern)
|
||||
|
||||
// Log BaseURI details
|
||||
switch u := v.BaseURI.Value.(type) {
|
||||
case string:
|
||||
logging.Debug("BaseURI", "baseURI", u)
|
||||
slog.Debug("BaseURI", "baseURI", u)
|
||||
case protocol.DocumentUri:
|
||||
logging.Debug("BaseURI", "baseURI", u)
|
||||
slog.Debug("BaseURI", "baseURI", u)
|
||||
default:
|
||||
logging.Debug("BaseURI", "baseURI", u)
|
||||
slog.Debug("BaseURI", "baseURI", u)
|
||||
}
|
||||
default:
|
||||
logging.Debug("GlobPattern", "unknown type", fmt.Sprintf("%T", v))
|
||||
slog.Debug("GlobPattern", "unknown type", fmt.Sprintf("%T", v))
|
||||
}
|
||||
|
||||
// Log WatchKind
|
||||
@@ -89,13 +90,13 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
|
||||
watchKind = *watcher.Kind
|
||||
}
|
||||
|
||||
logging.Debug("WatchKind", "kind", watchKind)
|
||||
slog.Debug("WatchKind", "kind", watchKind)
|
||||
}
|
||||
}
|
||||
|
||||
// Determine server type for specialized handling
|
||||
serverName := getServerNameFromContext(ctx)
|
||||
logging.Debug("Server type detected", "serverName", serverName)
|
||||
slog.Debug("Server type detected", "serverName", serverName)
|
||||
|
||||
// Check if this server has sent file watchers
|
||||
hasFileWatchers := len(watchers) > 0
|
||||
@@ -123,7 +124,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
|
||||
filesOpened += highPriorityFilesOpened
|
||||
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Opened high-priority files",
|
||||
slog.Debug("Opened high-priority files",
|
||||
"count", highPriorityFilesOpened,
|
||||
"serverName", serverName)
|
||||
}
|
||||
@@ -131,7 +132,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
|
||||
// If we've already opened enough high-priority files, we might not need more
|
||||
if filesOpened >= maxFilesToOpen {
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Reached file limit with high-priority files",
|
||||
slog.Debug("Reached file limit with high-priority files",
|
||||
"filesOpened", filesOpened,
|
||||
"maxFiles", maxFilesToOpen)
|
||||
}
|
||||
@@ -149,7 +150,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
|
||||
if d.IsDir() {
|
||||
if path != w.workspacePath && shouldExcludeDir(path) {
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Skipping excluded directory", "path", path)
|
||||
slog.Debug("Skipping excluded directory", "path", path)
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
@@ -177,7 +178,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
|
||||
|
||||
elapsedTime := time.Since(startTime)
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Limited workspace scan complete",
|
||||
slog.Debug("Limited workspace scan complete",
|
||||
"filesOpened", filesOpened,
|
||||
"maxFiles", maxFilesToOpen,
|
||||
"elapsedTime", elapsedTime.Seconds(),
|
||||
@@ -186,11 +187,11 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
|
||||
}
|
||||
|
||||
if err != nil && cnf.DebugLSP {
|
||||
logging.Debug("Error scanning workspace for files to open", "error", err)
|
||||
slog.Debug("Error scanning workspace for files to open", "error", err)
|
||||
}
|
||||
}()
|
||||
} else if cnf.DebugLSP {
|
||||
logging.Debug("Using on-demand file loading for server", "server", serverName)
|
||||
slog.Debug("Using on-demand file loading for server", "server", serverName)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -263,7 +264,7 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName
|
||||
matches, err := doublestar.Glob(os.DirFS(w.workspacePath), pattern)
|
||||
if err != nil {
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Error finding high-priority files", "pattern", pattern, "error", err)
|
||||
slog.Debug("Error finding high-priority files", "pattern", pattern, "error", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -281,12 +282,12 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName
|
||||
// Open the file
|
||||
if err := w.client.OpenFile(ctx, fullPath); err != nil {
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Error opening high-priority file", "path", fullPath, "error", err)
|
||||
slog.Debug("Error opening high-priority file", "path", fullPath, "error", err)
|
||||
}
|
||||
} else {
|
||||
filesOpened++
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Opened high-priority file", "path", fullPath)
|
||||
slog.Debug("Opened high-priority file", "path", fullPath)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -318,7 +319,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
|
||||
}
|
||||
|
||||
serverName := getServerNameFromContext(ctx)
|
||||
logging.Debug("Starting workspace watcher", "workspacePath", workspacePath, "serverName", serverName)
|
||||
slog.Debug("Starting workspace watcher", "workspacePath", workspacePath, "serverName", serverName)
|
||||
|
||||
// Register handler for file watcher registrations from the server
|
||||
lsp.RegisterFileWatchHandler(func(id string, watchers []protocol.FileSystemWatcher) {
|
||||
@@ -327,7 +328,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
|
||||
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
logging.Error("Error creating watcher", "error", err)
|
||||
slog.Error("Error creating watcher", "error", err)
|
||||
}
|
||||
defer watcher.Close()
|
||||
|
||||
@@ -341,7 +342,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
|
||||
if d.IsDir() && path != workspacePath {
|
||||
if shouldExcludeDir(path) {
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Skipping excluded directory", "path", path)
|
||||
slog.Debug("Skipping excluded directory", "path", path)
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
@@ -351,14 +352,14 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
|
||||
if d.IsDir() {
|
||||
err = watcher.Add(path)
|
||||
if err != nil {
|
||||
logging.Error("Error watching path", "path", path, "error", err)
|
||||
slog.Error("Error watching path", "path", path, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
logging.Error("Error walking workspace", "error", err)
|
||||
slog.Error("Error walking workspace", "error", err)
|
||||
}
|
||||
|
||||
// Event loop
|
||||
@@ -375,27 +376,37 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
|
||||
|
||||
// Add new directories to the watcher
|
||||
if event.Op&fsnotify.Create != 0 {
|
||||
if info, err := os.Stat(event.Name); err == nil {
|
||||
if info.IsDir() {
|
||||
// Skip excluded directories
|
||||
if !shouldExcludeDir(event.Name) {
|
||||
if err := watcher.Add(event.Name); err != nil {
|
||||
logging.Error("Error adding directory to watcher", "path", event.Name, "error", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// For newly created files
|
||||
if !shouldExcludeFile(event.Name) {
|
||||
w.openMatchingFile(ctx, event.Name)
|
||||
// Check if the file/directory still exists before processing
|
||||
info, err := os.Stat(event.Name)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// File was deleted between event and processing - ignore
|
||||
slog.Debug("File deleted between create event and stat", "path", event.Name)
|
||||
continue
|
||||
}
|
||||
slog.Error("Error getting file info", "path", event.Name, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
// Skip excluded directories
|
||||
if !shouldExcludeDir(event.Name) {
|
||||
if err := watcher.Add(event.Name); err != nil {
|
||||
slog.Error("Error adding directory to watcher", "path", event.Name, "error", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// For newly created files
|
||||
if !shouldExcludeFile(event.Name) {
|
||||
w.openMatchingFile(ctx, event.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Debug logging
|
||||
if cnf.DebugLSP {
|
||||
matched, kind := w.isPathWatched(event.Name)
|
||||
logging.Debug("File event",
|
||||
slog.Debug("File event",
|
||||
"path", event.Name,
|
||||
"operation", event.Op.String(),
|
||||
"watched", matched,
|
||||
@@ -416,7 +427,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
|
||||
// Just send the notification if needed
|
||||
info, err := os.Stat(event.Name)
|
||||
if err != nil {
|
||||
logging.Error("Error getting file info", "path", event.Name, "error", err)
|
||||
slog.Error("Error getting file info", "path", event.Name, "error", err)
|
||||
return
|
||||
}
|
||||
if !info.IsDir() && watchKind&protocol.WatchCreate != 0 {
|
||||
@@ -444,7 +455,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
logging.Error("Error watching file", "error", err)
|
||||
slog.Error("Error watching file", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -569,7 +580,7 @@ func matchesSimpleGlob(pattern, path string) bool {
|
||||
// Fall back to simple matching for simpler patterns
|
||||
matched, err := filepath.Match(pattern, path)
|
||||
if err != nil {
|
||||
logging.Error("Error matching pattern", "pattern", pattern, "path", path, "error", err)
|
||||
slog.Error("Error matching pattern", "pattern", pattern, "path", path, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -580,7 +591,7 @@ func matchesSimpleGlob(pattern, path string) bool {
|
||||
func (w *WorkspaceWatcher) matchesPattern(path string, pattern protocol.GlobPattern) bool {
|
||||
patternInfo, err := pattern.AsPattern()
|
||||
if err != nil {
|
||||
logging.Error("Error parsing pattern", "pattern", pattern, "error", err)
|
||||
slog.Error("Error parsing pattern", "pattern", pattern, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -605,7 +616,7 @@ func (w *WorkspaceWatcher) matchesPattern(path string, pattern protocol.GlobPatt
|
||||
// Make path relative to basePath for matching
|
||||
relPath, err := filepath.Rel(basePath, path)
|
||||
if err != nil {
|
||||
logging.Error("Error getting relative path", "path", path, "basePath", basePath, "error", err)
|
||||
slog.Error("Error getting relative path", "path", path, "basePath", basePath, "error", err)
|
||||
return false
|
||||
}
|
||||
relPath = filepath.ToSlash(relPath)
|
||||
@@ -643,19 +654,55 @@ func (w *WorkspaceWatcher) debounceHandleFileEvent(ctx context.Context, uri stri
|
||||
func (w *WorkspaceWatcher) handleFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) {
|
||||
// If the file is open and it's a change event, use didChange notification
|
||||
filePath := uri[7:] // Remove "file://" prefix
|
||||
|
||||
if changeType == protocol.FileChangeType(protocol.Deleted) {
|
||||
// Always clear diagnostics for deleted files
|
||||
w.client.ClearDiagnosticsForURI(protocol.DocumentUri(uri))
|
||||
} else if changeType == protocol.FileChangeType(protocol.Changed) && w.client.IsFileOpen(filePath) {
|
||||
err := w.client.NotifyChange(ctx, filePath)
|
||||
if err != nil {
|
||||
logging.Error("Error notifying change", "error", err)
|
||||
|
||||
// If the file was open, close it in the LSP client
|
||||
if w.client.IsFileOpen(filePath) {
|
||||
if err := w.client.CloseFile(ctx, filePath); err != nil {
|
||||
slog.Debug("Error closing deleted file in LSP client", "file", filePath, "error", err)
|
||||
// Continue anyway - the file is gone
|
||||
}
|
||||
}
|
||||
} else if changeType == protocol.FileChangeType(protocol.Changed) {
|
||||
// For changed files, verify the file still exists before notifying
|
||||
if _, err := os.Stat(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// File was deleted between the event and now - treat as delete
|
||||
slog.Debug("File deleted between change event and processing", "file", filePath)
|
||||
w.handleFileEvent(ctx, uri, protocol.FileChangeType(protocol.Deleted))
|
||||
return
|
||||
}
|
||||
slog.Error("Error getting file info", "path", filePath, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// File exists and is open, notify change
|
||||
if w.client.IsFileOpen(filePath) {
|
||||
err := w.client.NotifyChange(ctx, filePath)
|
||||
if err != nil {
|
||||
slog.Error("Error notifying change", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
} else if changeType == protocol.FileChangeType(protocol.Created) {
|
||||
// For created files, verify the file still exists before notifying
|
||||
if _, err := os.Stat(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// File was deleted between the event and now - ignore
|
||||
slog.Debug("File deleted between create event and processing", "file", filePath)
|
||||
return
|
||||
}
|
||||
slog.Error("Error getting file info", "path", filePath, "error", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Notify LSP server about the file event using didChangeWatchedFiles
|
||||
if err := w.notifyFileEvent(ctx, uri, changeType); err != nil {
|
||||
logging.Error("Error notifying LSP server about file event", "error", err)
|
||||
slog.Error("Error notifying LSP server about file event", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -663,7 +710,7 @@ func (w *WorkspaceWatcher) handleFileEvent(ctx context.Context, uri string, chan
|
||||
func (w *WorkspaceWatcher) notifyFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) error {
|
||||
cnf := config.Get()
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Notifying file event",
|
||||
slog.Debug("Notifying file event",
|
||||
"uri", uri,
|
||||
"changeType", changeType,
|
||||
)
|
||||
@@ -828,6 +875,11 @@ func shouldExcludeFile(filePath string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Skip numeric temporary files (often created by editors)
|
||||
if _, err := strconv.Atoi(fileName); err == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check file size
|
||||
info, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
@@ -838,7 +890,7 @@ func shouldExcludeFile(filePath string) bool {
|
||||
// Skip large files
|
||||
if info.Size() > maxFileSize {
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Skipping large file",
|
||||
slog.Debug("Skipping large file",
|
||||
"path", filePath,
|
||||
"size", info.Size(),
|
||||
"maxSize", maxFileSize,
|
||||
@@ -856,9 +908,19 @@ func shouldExcludeFile(filePath string) bool {
|
||||
// openMatchingFile opens a file if it matches any of the registered patterns
|
||||
func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
|
||||
cnf := config.Get()
|
||||
// Skip directories
|
||||
// Skip directories and verify file exists
|
||||
info, err := os.Stat(path)
|
||||
if err != nil || info.IsDir() {
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// File was deleted between event and processing - ignore
|
||||
slog.Debug("File deleted between event and openMatchingFile", "path", path)
|
||||
return
|
||||
}
|
||||
slog.Error("Error getting file info", "path", path, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -876,10 +938,10 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
|
||||
// This helps with project initialization for certain language servers
|
||||
if isHighPriorityFile(path, serverName) {
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Opening high-priority file", "path", path, "serverName", serverName)
|
||||
slog.Debug("Opening high-priority file", "path", path, "serverName", serverName)
|
||||
}
|
||||
if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP {
|
||||
logging.Error("Error opening high-priority file", "path", path, "error", err)
|
||||
slog.Error("Error opening high-priority file", "path", path, "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -891,7 +953,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
|
||||
// Check file size - for preloading we're more conservative
|
||||
if info.Size() > (1 * 1024 * 1024) { // 1MB limit for preloaded files
|
||||
if cnf.DebugLSP {
|
||||
logging.Debug("Skipping large file for preloading", "path", path, "size", info.Size())
|
||||
slog.Debug("Skipping large file for preloading", "path", path, "size", info.Size())
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -923,7 +985,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
|
||||
if shouldOpen {
|
||||
// Don't need to check if it's already open - the client.OpenFile handles that
|
||||
if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP {
|
||||
logging.Error("Error opening file", "path", path, "error", err)
|
||||
slog.Error("Error opening file", "path", path, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
8
internal/message/attachment.go
Normal file
8
internal/message/attachment.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package message
|
||||
|
||||
type Attachment struct {
|
||||
FilePath string
|
||||
FileName string
|
||||
MimeType string
|
||||
Content []byte
|
||||
}
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/llm/models"
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
)
|
||||
|
||||
type MessageRole string
|
||||
@@ -48,7 +48,10 @@ type TextContent struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
func (tc TextContent) String() string {
|
||||
func (tc *TextContent) String() string {
|
||||
if tc == nil {
|
||||
return ""
|
||||
}
|
||||
return tc.Text
|
||||
}
|
||||
|
||||
@@ -66,13 +69,17 @@ func (iuc ImageURLContent) String() string {
|
||||
func (ImageURLContent) isPart() {}
|
||||
|
||||
type BinaryContent struct {
|
||||
Path string
|
||||
MIMEType string
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (bc BinaryContent) String() string {
|
||||
func (bc BinaryContent) String(provider models.ModelProvider) string {
|
||||
base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
|
||||
return "data:" + bc.MIMEType + ";base64," + base64Encoded
|
||||
if provider == models.ProviderOpenAI {
|
||||
return "data:" + bc.MIMEType + ";base64," + base64Encoded
|
||||
}
|
||||
return base64Encoded
|
||||
}
|
||||
|
||||
func (BinaryContent) isPart() {}
|
||||
@@ -98,30 +105,24 @@ type ToolResult struct {
|
||||
func (ToolResult) isPart() {}
|
||||
|
||||
type Finish struct {
|
||||
Reason FinishReason `json:"reason"`
|
||||
Time time.Time `json:"time"`
|
||||
}
|
||||
|
||||
type DBFinish struct {
|
||||
Reason FinishReason `json:"reason"`
|
||||
Time int64 `json:"time"`
|
||||
}
|
||||
|
||||
func (Finish) isPart() {}
|
||||
|
||||
type Message struct {
|
||||
ID string
|
||||
Role MessageRole
|
||||
SessionID string
|
||||
Parts []ContentPart
|
||||
Model models.ModelID
|
||||
|
||||
CreatedAt int64
|
||||
UpdatedAt int64
|
||||
}
|
||||
|
||||
func (m *Message) Content() TextContent {
|
||||
func (m *Message) Content() *TextContent {
|
||||
for _, part := range m.Parts {
|
||||
if c, ok := part.(TextContent); ok {
|
||||
return c
|
||||
return &c
|
||||
}
|
||||
}
|
||||
return TextContent{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Message) ReasoningContent() ReasoningContent {
|
||||
@@ -312,7 +313,7 @@ func (m *Message) AddFinish(reason FinishReason) {
|
||||
break
|
||||
}
|
||||
}
|
||||
m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix()})
|
||||
m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now()})
|
||||
}
|
||||
|
||||
func (m *Message) AddImageURL(url, detail string) {
|
||||
|
||||
@@ -5,12 +5,31 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/opencode-ai/opencode/internal/db"
|
||||
"github.com/opencode-ai/opencode/internal/llm/models"
|
||||
"github.com/opencode-ai/opencode/internal/pubsub"
|
||||
"github.com/sst/opencode/internal/db"
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
"github.com/sst/opencode/internal/pubsub"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
ID string
|
||||
Role MessageRole
|
||||
SessionID string
|
||||
Parts []ContentPart
|
||||
Model models.ModelID
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
EventMessageCreated pubsub.EventType = "message_created"
|
||||
EventMessageUpdated pubsub.EventType = "message_updated"
|
||||
EventMessageDeleted pubsub.EventType = "message_deleted"
|
||||
)
|
||||
|
||||
type CreateMessageParams struct {
|
||||
@@ -20,145 +39,317 @@ type CreateMessageParams struct {
|
||||
}
|
||||
|
||||
type Service interface {
|
||||
pubsub.Suscriber[Message]
|
||||
pubsub.Subscriber[Message]
|
||||
|
||||
Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
|
||||
Update(ctx context.Context, message Message) error
|
||||
Update(ctx context.Context, message Message) (Message, error)
|
||||
Get(ctx context.Context, id string) (Message, error)
|
||||
List(ctx context.Context, sessionID string) ([]Message, error)
|
||||
ListAfter(ctx context.Context, sessionID string, timestamp time.Time) ([]Message, error)
|
||||
Delete(ctx context.Context, id string) error
|
||||
DeleteSessionMessages(ctx context.Context, sessionID string) error
|
||||
}
|
||||
|
||||
type service struct {
|
||||
*pubsub.Broker[Message]
|
||||
q db.Querier
|
||||
db *db.Queries
|
||||
broker *pubsub.Broker[Message]
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewService(q db.Querier) Service {
|
||||
return &service{
|
||||
Broker: pubsub.NewBroker[Message](),
|
||||
q: q,
|
||||
}
|
||||
}
|
||||
var globalMessageService *service
|
||||
|
||||
func (s *service) Delete(ctx context.Context, id string) error {
|
||||
message, err := s.Get(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
func InitService(dbConn *sql.DB) error {
|
||||
if globalMessageService != nil {
|
||||
return fmt.Errorf("message service already initialized")
|
||||
}
|
||||
err = s.q.DeleteMessage(ctx, message.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
queries := db.New(dbConn)
|
||||
broker := pubsub.NewBroker[Message]()
|
||||
|
||||
globalMessageService = &service{
|
||||
db: queries,
|
||||
broker: broker,
|
||||
}
|
||||
s.Publish(pubsub.DeletedEvent, message)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
|
||||
if params.Role != Assistant {
|
||||
params.Parts = append(params.Parts, Finish{
|
||||
Reason: "stop",
|
||||
})
|
||||
func GetService() Service {
|
||||
if globalMessageService == nil {
|
||||
panic("message service not initialized. Call message.InitService() first.")
|
||||
}
|
||||
partsJSON, err := marshallParts(params.Parts)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
return globalMessageService
|
||||
}
|
||||
|
||||
func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
isFinished := false
|
||||
for _, p := range params.Parts {
|
||||
if _, ok := p.(Finish); ok {
|
||||
isFinished = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if params.Role == User && !isFinished {
|
||||
params.Parts = append(params.Parts, Finish{Reason: FinishReasonEndTurn, Time: time.Now()})
|
||||
}
|
||||
|
||||
dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
|
||||
partsJSON, err := marshallParts(params.Parts)
|
||||
if err != nil {
|
||||
return Message{}, fmt.Errorf("failed to marshal message parts: %w", err)
|
||||
}
|
||||
|
||||
dbMsgParams := db.CreateMessageParams{
|
||||
ID: uuid.New().String(),
|
||||
SessionID: sessionID,
|
||||
Role: string(params.Role),
|
||||
Parts: string(partsJSON),
|
||||
Model: sql.NullString{String: string(params.Model), Valid: true},
|
||||
})
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
Model: sql.NullString{String: string(params.Model), Valid: params.Model != ""},
|
||||
}
|
||||
|
||||
dbMessage, err := s.db.CreateMessage(ctx, dbMsgParams)
|
||||
if err != nil {
|
||||
return Message{}, fmt.Errorf("db.CreateMessage: %w", err)
|
||||
}
|
||||
|
||||
message, err := s.fromDBItem(dbMessage)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
return Message{}, fmt.Errorf("failed to convert DB message: %w", err)
|
||||
}
|
||||
s.Publish(pubsub.CreatedEvent, message)
|
||||
|
||||
s.broker.Publish(EventMessageCreated, message)
|
||||
return message, nil
|
||||
}
|
||||
|
||||
func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
|
||||
messages, err := s.List(ctx, sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
func (s *service) Update(ctx context.Context, message Message) (Message, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if message.ID == "" {
|
||||
return Message{}, fmt.Errorf("cannot update message with empty ID")
|
||||
}
|
||||
for _, message := range messages {
|
||||
if message.SessionID == sessionID {
|
||||
err = s.Delete(ctx, message.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
partsJSON, err := marshallParts(message.Parts)
|
||||
if err != nil {
|
||||
return Message{}, fmt.Errorf("failed to marshal message parts for update: %w", err)
|
||||
}
|
||||
|
||||
var dbFinishedAt sql.NullString
|
||||
finishPart := message.FinishPart()
|
||||
if finishPart != nil && !finishPart.Time.IsZero() {
|
||||
dbFinishedAt = sql.NullString{
|
||||
String: finishPart.Time.UTC().Format(time.RFC3339Nano),
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) Update(ctx context.Context, message Message) error {
|
||||
parts, err := marshallParts(message.Parts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
finishedAt := sql.NullInt64{}
|
||||
if f := message.FinishPart(); f != nil {
|
||||
finishedAt.Int64 = f.Time
|
||||
finishedAt.Valid = true
|
||||
}
|
||||
err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
|
||||
// UpdatedAt is handled by the DB trigger (strftime('%s', 'now'))
|
||||
err = s.db.UpdateMessage(ctx, db.UpdateMessageParams{
|
||||
ID: message.ID,
|
||||
Parts: string(parts),
|
||||
FinishedAt: finishedAt,
|
||||
Parts: string(partsJSON),
|
||||
FinishedAt: dbFinishedAt,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return Message{}, fmt.Errorf("db.UpdateMessage: %w", err)
|
||||
}
|
||||
message.UpdatedAt = time.Now().Unix()
|
||||
s.Publish(pubsub.UpdatedEvent, message)
|
||||
return nil
|
||||
|
||||
dbUpdatedMessage, err := s.db.GetMessage(ctx, message.ID)
|
||||
if err != nil {
|
||||
return Message{}, fmt.Errorf("failed to fetch message after update: %w", err)
|
||||
}
|
||||
updatedMessage, err := s.fromDBItem(dbUpdatedMessage)
|
||||
if err != nil {
|
||||
return Message{}, fmt.Errorf("failed to convert updated DB message: %w", err)
|
||||
}
|
||||
|
||||
s.broker.Publish(EventMessageUpdated, updatedMessage)
|
||||
return updatedMessage, nil
|
||||
}
|
||||
|
||||
func (s *service) Get(ctx context.Context, id string) (Message, error) {
|
||||
dbMessage, err := s.q.GetMessage(ctx, id)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
dbMessage, err := s.db.GetMessage(ctx, id)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
if err == sql.ErrNoRows {
|
||||
return Message{}, fmt.Errorf("message with ID '%s' not found", id)
|
||||
}
|
||||
return Message{}, fmt.Errorf("db.GetMessage: %w", err)
|
||||
}
|
||||
return s.fromDBItem(dbMessage)
|
||||
}
|
||||
|
||||
func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
|
||||
dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
dbMessages, err := s.db.ListMessagesBySession(ctx, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.ListMessagesBySession: %w", err)
|
||||
}
|
||||
messages := make([]Message, len(dbMessages))
|
||||
for i, dbMsg := range dbMessages {
|
||||
msg, convErr := s.fromDBItem(dbMsg)
|
||||
if convErr != nil {
|
||||
return nil, fmt.Errorf("failed to convert DB message at index %d: %w", i, convErr)
|
||||
}
|
||||
messages[i] = msg
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (s *service) ListAfter(ctx context.Context, sessionID string, timestamp time.Time) ([]Message, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
dbMessages, err := s.db.ListMessagesBySessionAfter(ctx, db.ListMessagesBySessionAfterParams{
|
||||
SessionID: sessionID,
|
||||
CreatedAt: timestamp.Format(time.RFC3339Nano),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.ListMessagesBySessionAfter: %w", err)
|
||||
}
|
||||
messages := make([]Message, len(dbMessages))
|
||||
for i, dbMsg := range dbMessages {
|
||||
msg, convErr := s.fromDBItem(dbMsg)
|
||||
if convErr != nil {
|
||||
return nil, fmt.Errorf("failed to convert DB message at index %d (ListAfter): %w", i, convErr)
|
||||
}
|
||||
messages[i] = msg
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (s *service) Delete(ctx context.Context, id string) error {
|
||||
s.mu.Lock()
|
||||
messageToPublish, err := s.getServiceForPublish(ctx, id)
|
||||
s.mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
// If error was due to not found, it's not a critical failure for deletion intent
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return nil // Or return the error if strictness is required
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
err = s.db.DeleteMessage(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db.DeleteMessage: %w", err)
|
||||
}
|
||||
|
||||
if messageToPublish != nil {
|
||||
s.broker.Publish(EventMessageDeleted, *messageToPublish)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) getServiceForPublish(ctx context.Context, id string) (*Message, error) {
|
||||
dbMsg, err := s.db.GetMessage(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages := make([]Message, len(dbMessages))
|
||||
for i, dbMessage := range dbMessages {
|
||||
messages[i], err = s.fromDBItem(dbMessage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
msg, convErr := s.fromDBItem(dbMsg)
|
||||
if convErr != nil {
|
||||
return nil, fmt.Errorf("failed to convert DB message for publishing: %w", convErr)
|
||||
}
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
messagesToDelete, err := s.db.ListMessagesBySession(ctx, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list messages for deletion: %w", err)
|
||||
}
|
||||
|
||||
err = s.db.DeleteSessionMessages(ctx, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db.DeleteSessionMessages: %w", err)
|
||||
}
|
||||
|
||||
for _, dbMsg := range messagesToDelete {
|
||||
msg, convErr := s.fromDBItem(dbMsg)
|
||||
if convErr == nil {
|
||||
s.broker.Publish(EventMessageDeleted, msg)
|
||||
} else {
|
||||
slog.Error("Failed to convert DB message for delete event publishing", "id", dbMsg.ID, "error", convErr)
|
||||
}
|
||||
}
|
||||
return messages, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[Message] {
|
||||
return s.broker.Subscribe(ctx)
|
||||
}
|
||||
|
||||
func (s *service) fromDBItem(item db.Message) (Message, error) {
|
||||
parts, err := unmarshallParts([]byte(item.Parts))
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
return Message{}, fmt.Errorf("unmarshallParts for message ID %s: %w. Raw parts: %s", item.ID, err, item.Parts)
|
||||
}
|
||||
return Message{
|
||||
|
||||
// Parse timestamps from ISO strings
|
||||
createdAt, err := time.Parse(time.RFC3339Nano, item.CreatedAt)
|
||||
if err != nil {
|
||||
slog.Error("Failed to parse created_at", "value", item.CreatedAt, "error", err)
|
||||
createdAt = time.Now() // Fallback
|
||||
}
|
||||
|
||||
updatedAt, err := time.Parse(time.RFC3339Nano, item.UpdatedAt)
|
||||
if err != nil {
|
||||
slog.Error("Failed to parse created_at", "value", item.CreatedAt, "error", err)
|
||||
updatedAt = time.Now() // Fallback
|
||||
}
|
||||
|
||||
msg := Message{
|
||||
ID: item.ID,
|
||||
SessionID: item.SessionID,
|
||||
Role: MessageRole(item.Role),
|
||||
Parts: parts,
|
||||
Model: models.ModelID(item.Model.String),
|
||||
CreatedAt: item.CreatedAt,
|
||||
UpdatedAt: item.UpdatedAt,
|
||||
}, nil
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: updatedAt,
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
|
||||
return GetService().Create(ctx, sessionID, params)
|
||||
}
|
||||
|
||||
func Update(ctx context.Context, message Message) (Message, error) {
|
||||
return GetService().Update(ctx, message)
|
||||
}
|
||||
|
||||
func Get(ctx context.Context, id string) (Message, error) {
|
||||
return GetService().Get(ctx, id)
|
||||
}
|
||||
|
||||
func List(ctx context.Context, sessionID string) ([]Message, error) {
|
||||
return GetService().List(ctx, sessionID)
|
||||
}
|
||||
|
||||
func ListAfter(ctx context.Context, sessionID string, timestamp time.Time) ([]Message, error) {
|
||||
return GetService().ListAfter(ctx, sessionID, timestamp)
|
||||
}
|
||||
|
||||
func Delete(ctx context.Context, id string) error {
|
||||
return GetService().Delete(ctx, id)
|
||||
}
|
||||
|
||||
func DeleteSessionMessages(ctx context.Context, sessionID string) error {
|
||||
return GetService().DeleteSessionMessages(ctx, sessionID)
|
||||
}
|
||||
|
||||
func Subscribe(ctx context.Context) <-chan pubsub.Event[Message] {
|
||||
return GetService().Subscribe(ctx)
|
||||
}
|
||||
|
||||
type partType string
|
||||
@@ -174,109 +365,139 @@ const (
|
||||
)
|
||||
|
||||
type partWrapper struct {
|
||||
Type partType `json:"type"`
|
||||
Data ContentPart `json:"data"`
|
||||
Type partType `json:"type"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
|
||||
func marshallParts(parts []ContentPart) ([]byte, error) {
|
||||
wrappedParts := make([]partWrapper, len(parts))
|
||||
|
||||
wrappedParts := make([]json.RawMessage, len(parts))
|
||||
for i, part := range parts {
|
||||
var typ partType
|
||||
var dataBytes []byte
|
||||
var err error
|
||||
|
||||
switch part.(type) {
|
||||
switch p := part.(type) {
|
||||
case ReasoningContent:
|
||||
typ = reasoningType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
case TextContent:
|
||||
typ = textType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
case *TextContent:
|
||||
typ = textType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
case ImageURLContent:
|
||||
typ = imageURLType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
case BinaryContent:
|
||||
typ = binaryType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
case ToolCall:
|
||||
typ = toolCallType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
case ToolResult:
|
||||
typ = toolResultType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
case Finish:
|
||||
typ = finishType
|
||||
var dbFinish DBFinish
|
||||
dbFinish.Reason = p.Reason
|
||||
dbFinish.Time = p.Time.UnixMilli()
|
||||
dataBytes, err = json.Marshal(dbFinish)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown part type: %T", part)
|
||||
return nil, fmt.Errorf("unknown part type for marshalling: %T", part)
|
||||
}
|
||||
|
||||
wrappedParts[i] = partWrapper{
|
||||
Type: typ,
|
||||
Data: part,
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal part data for type %s: %w", typ, err)
|
||||
}
|
||||
wrapper := struct {
|
||||
Type partType `json:"type"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}{Type: typ, Data: dataBytes}
|
||||
wrappedBytes, err := json.Marshal(wrapper)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal part wrapper for type %s: %w", typ, err)
|
||||
}
|
||||
wrappedParts[i] = wrappedBytes
|
||||
}
|
||||
return json.Marshal(wrappedParts)
|
||||
}
|
||||
|
||||
func unmarshallParts(data []byte) ([]ContentPart, error) {
|
||||
temp := []json.RawMessage{}
|
||||
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return nil, err
|
||||
var rawMessages []json.RawMessage
|
||||
if err := json.Unmarshal(data, &rawMessages); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal parts data as array: %w. Data: %s", err, string(data))
|
||||
}
|
||||
|
||||
parts := make([]ContentPart, 0)
|
||||
|
||||
for _, rawPart := range temp {
|
||||
var wrapper struct {
|
||||
Type partType `json:"type"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
|
||||
parts := make([]ContentPart, 0, len(rawMessages))
|
||||
for _, rawPart := range rawMessages {
|
||||
var wrapper partWrapper
|
||||
if err := json.Unmarshal(rawPart, &wrapper); err != nil {
|
||||
return nil, err
|
||||
// Fallback for old format where parts might be just TextContent string
|
||||
var text string
|
||||
if errText := json.Unmarshal(rawPart, &text); errText == nil {
|
||||
parts = append(parts, TextContent{Text: text})
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("failed to unmarshal part wrapper: %w. Raw part: %s", err, string(rawPart))
|
||||
}
|
||||
|
||||
switch wrapper.Type {
|
||||
case reasoningType:
|
||||
part := ReasoningContent{}
|
||||
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
|
||||
return nil, err
|
||||
var p ReasoningContent
|
||||
if err := json.Unmarshal(wrapper.Data, &p); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal ReasoningContent: %w. Data: %s", err, string(wrapper.Data))
|
||||
}
|
||||
parts = append(parts, part)
|
||||
parts = append(parts, p)
|
||||
case textType:
|
||||
part := TextContent{}
|
||||
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
|
||||
return nil, err
|
||||
var p TextContent
|
||||
if err := json.Unmarshal(wrapper.Data, &p); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal TextContent: %w. Data: %s", err, string(wrapper.Data))
|
||||
}
|
||||
parts = append(parts, part)
|
||||
parts = append(parts, p)
|
||||
case imageURLType:
|
||||
part := ImageURLContent{}
|
||||
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
|
||||
return nil, err
|
||||
var p ImageURLContent
|
||||
if err := json.Unmarshal(wrapper.Data, &p); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal ImageURLContent: %w. Data: %s", err, string(wrapper.Data))
|
||||
}
|
||||
parts = append(parts, p)
|
||||
case binaryType:
|
||||
part := BinaryContent{}
|
||||
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
|
||||
return nil, err
|
||||
var p BinaryContent
|
||||
if err := json.Unmarshal(wrapper.Data, &p); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal BinaryContent: %w. Data: %s", err, string(wrapper.Data))
|
||||
}
|
||||
parts = append(parts, part)
|
||||
parts = append(parts, p)
|
||||
case toolCallType:
|
||||
part := ToolCall{}
|
||||
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
|
||||
return nil, err
|
||||
var p ToolCall
|
||||
if err := json.Unmarshal(wrapper.Data, &p); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal ToolCall: %w. Data: %s", err, string(wrapper.Data))
|
||||
}
|
||||
parts = append(parts, part)
|
||||
parts = append(parts, p)
|
||||
case toolResultType:
|
||||
part := ToolResult{}
|
||||
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
|
||||
return nil, err
|
||||
var p ToolResult
|
||||
if err := json.Unmarshal(wrapper.Data, &p); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal ToolResult: %w. Data: %s", err, string(wrapper.Data))
|
||||
}
|
||||
parts = append(parts, part)
|
||||
parts = append(parts, p)
|
||||
case finishType:
|
||||
part := Finish{}
|
||||
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
|
||||
return nil, err
|
||||
var p DBFinish
|
||||
if err := json.Unmarshal(wrapper.Data, &p); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal Finish: %w. Data: %s", err, string(wrapper.Data))
|
||||
}
|
||||
parts = append(parts, part)
|
||||
parts = append(parts, Finish{Reason: FinishReason(p.Reason), Time: time.UnixMilli(p.Time)})
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
|
||||
slog.Warn("Unknown part type during unmarshalling, attempting to parse as TextContent", "type", wrapper.Type, "data", string(wrapper.Data))
|
||||
// Fallback: if type is unknown or empty, try to parse data as TextContent directly
|
||||
var p TextContent
|
||||
if err := json.Unmarshal(wrapper.Data, &p); err == nil {
|
||||
parts = append(parts, p)
|
||||
} else {
|
||||
// If that also fails, log it but continue if possible, or return error
|
||||
slog.Error("Failed to unmarshal unknown part type and fallback to TextContent failed", "type", wrapper.Type, "data", string(wrapper.Data), "error", err)
|
||||
// Depending on strictness, you might return an error here:
|
||||
// return nil, fmt.Errorf("unknown part type '%s' and failed fallback: %w", wrapper.Type, err)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return parts, nil
|
||||
}
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
package permission
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/pubsub"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/pubsub"
|
||||
)
|
||||
|
||||
var ErrorPermissionDenied = errors.New("permission denied")
|
||||
@@ -32,56 +36,141 @@ type PermissionRequest struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
type PermissionResponse struct {
|
||||
Request PermissionRequest
|
||||
Granted bool
|
||||
}
|
||||
|
||||
const (
|
||||
EventPermissionRequested pubsub.EventType = "permission_requested"
|
||||
EventPermissionGranted pubsub.EventType = "permission_granted"
|
||||
EventPermissionDenied pubsub.EventType = "permission_denied"
|
||||
EventPermissionPersisted pubsub.EventType = "permission_persisted"
|
||||
)
|
||||
|
||||
type Service interface {
|
||||
pubsub.Suscriber[PermissionRequest]
|
||||
GrantPersistant(permission PermissionRequest)
|
||||
Grant(permission PermissionRequest)
|
||||
Deny(permission PermissionRequest)
|
||||
Request(opts CreatePermissionRequest) bool
|
||||
AutoApproveSession(sessionID string)
|
||||
pubsub.Subscriber[PermissionRequest]
|
||||
SubscribeToResponseEvents(ctx context.Context) <-chan pubsub.Event[PermissionResponse]
|
||||
|
||||
GrantPersistant(ctx context.Context, permission PermissionRequest)
|
||||
Grant(ctx context.Context, permission PermissionRequest)
|
||||
Deny(ctx context.Context, permission PermissionRequest)
|
||||
Request(ctx context.Context, opts CreatePermissionRequest) bool
|
||||
AutoApproveSession(ctx context.Context, sessionID string)
|
||||
IsAutoApproved(ctx context.Context, sessionID string) bool
|
||||
}
|
||||
|
||||
type permissionService struct {
|
||||
*pubsub.Broker[PermissionRequest]
|
||||
broker *pubsub.Broker[PermissionRequest]
|
||||
responseBroker *pubsub.Broker[PermissionResponse]
|
||||
|
||||
sessionPermissions []PermissionRequest
|
||||
sessionPermissions map[string][]PermissionRequest
|
||||
pendingRequests sync.Map
|
||||
autoApproveSessions []string
|
||||
autoApproveSessions map[string]bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func (s *permissionService) GrantPersistant(permission PermissionRequest) {
|
||||
var globalPermissionService *permissionService
|
||||
|
||||
func InitService() error {
|
||||
if globalPermissionService != nil {
|
||||
return fmt.Errorf("permission service already initialized")
|
||||
}
|
||||
globalPermissionService = &permissionService{
|
||||
broker: pubsub.NewBroker[PermissionRequest](),
|
||||
responseBroker: pubsub.NewBroker[PermissionResponse](),
|
||||
sessionPermissions: make(map[string][]PermissionRequest),
|
||||
autoApproveSessions: make(map[string]bool),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetService() *permissionService {
|
||||
if globalPermissionService == nil {
|
||||
panic("permission service not initialized. Call permission.InitService() first.")
|
||||
}
|
||||
return globalPermissionService
|
||||
}
|
||||
|
||||
func (s *permissionService) GrantPersistant(ctx context.Context, permission PermissionRequest) {
|
||||
s.mu.Lock()
|
||||
s.sessionPermissions[permission.SessionID] = append(s.sessionPermissions[permission.SessionID], permission)
|
||||
s.mu.Unlock()
|
||||
|
||||
respCh, ok := s.pendingRequests.Load(permission.ID)
|
||||
if ok {
|
||||
respCh.(chan bool) <- true
|
||||
select {
|
||||
case respCh.(chan bool) <- true:
|
||||
case <-ctx.Done():
|
||||
slog.Warn("Context cancelled while sending grant persistent response", "request_id", permission.ID)
|
||||
}
|
||||
}
|
||||
s.sessionPermissions = append(s.sessionPermissions, permission)
|
||||
s.responseBroker.Publish(EventPermissionPersisted, PermissionResponse{Request: permission, Granted: true})
|
||||
}
|
||||
|
||||
func (s *permissionService) Grant(permission PermissionRequest) {
|
||||
func (s *permissionService) Grant(ctx context.Context, permission PermissionRequest) {
|
||||
respCh, ok := s.pendingRequests.Load(permission.ID)
|
||||
if ok {
|
||||
respCh.(chan bool) <- true
|
||||
select {
|
||||
case respCh.(chan bool) <- true:
|
||||
case <-ctx.Done():
|
||||
slog.Warn("Context cancelled while sending grant response", "request_id", permission.ID)
|
||||
}
|
||||
}
|
||||
s.responseBroker.Publish(EventPermissionGranted, PermissionResponse{Request: permission, Granted: true})
|
||||
}
|
||||
|
||||
func (s *permissionService) Deny(permission PermissionRequest) {
|
||||
func (s *permissionService) Deny(ctx context.Context, permission PermissionRequest) {
|
||||
respCh, ok := s.pendingRequests.Load(permission.ID)
|
||||
if ok {
|
||||
respCh.(chan bool) <- false
|
||||
select {
|
||||
case respCh.(chan bool) <- false:
|
||||
case <-ctx.Done():
|
||||
slog.Warn("Context cancelled while sending deny response", "request_id", permission.ID)
|
||||
}
|
||||
}
|
||||
s.responseBroker.Publish(EventPermissionDenied, PermissionResponse{Request: permission, Granted: false})
|
||||
}
|
||||
|
||||
func (s *permissionService) Request(opts CreatePermissionRequest) bool {
|
||||
if slices.Contains(s.autoApproveSessions, opts.SessionID) {
|
||||
func (s *permissionService) Request(ctx context.Context, opts CreatePermissionRequest) bool {
|
||||
s.mu.RLock()
|
||||
if s.autoApproveSessions[opts.SessionID] {
|
||||
s.mu.RUnlock()
|
||||
return true
|
||||
}
|
||||
dir := filepath.Dir(opts.Path)
|
||||
if dir == "." {
|
||||
dir = config.WorkingDirectory()
|
||||
|
||||
requestPath := opts.Path
|
||||
if !filepath.IsAbs(requestPath) {
|
||||
requestPath = filepath.Join(config.WorkingDirectory(), requestPath)
|
||||
}
|
||||
permission := PermissionRequest{
|
||||
requestPath = filepath.Clean(requestPath)
|
||||
|
||||
if permissions, ok := s.sessionPermissions[opts.SessionID]; ok {
|
||||
for _, p := range permissions {
|
||||
storedPath := p.Path
|
||||
if !filepath.IsAbs(storedPath) {
|
||||
storedPath = filepath.Join(config.WorkingDirectory(), storedPath)
|
||||
}
|
||||
storedPath = filepath.Clean(storedPath)
|
||||
|
||||
if p.ToolName == opts.ToolName && p.Action == opts.Action &&
|
||||
(requestPath == storedPath || strings.HasPrefix(requestPath, storedPath+string(filepath.Separator))) {
|
||||
s.mu.RUnlock()
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
normalizedPath := opts.Path
|
||||
if !filepath.IsAbs(normalizedPath) {
|
||||
normalizedPath = filepath.Join(config.WorkingDirectory(), normalizedPath)
|
||||
}
|
||||
normalizedPath = filepath.Clean(normalizedPath)
|
||||
|
||||
permissionReq := PermissionRequest{
|
||||
ID: uuid.New().String(),
|
||||
Path: dir,
|
||||
Path: normalizedPath,
|
||||
SessionID: opts.SessionID,
|
||||
ToolName: opts.ToolName,
|
||||
Description: opts.Description,
|
||||
@@ -89,31 +178,69 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
|
||||
Params: opts.Params,
|
||||
}
|
||||
|
||||
for _, p := range s.sessionPermissions {
|
||||
if p.ToolName == permission.ToolName && p.Action == permission.Action && p.SessionID == permission.SessionID && p.Path == permission.Path {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
respCh := make(chan bool, 1)
|
||||
s.pendingRequests.Store(permissionReq.ID, respCh)
|
||||
defer s.pendingRequests.Delete(permissionReq.ID)
|
||||
|
||||
s.pendingRequests.Store(permission.ID, respCh)
|
||||
defer s.pendingRequests.Delete(permission.ID)
|
||||
s.broker.Publish(EventPermissionRequested, permissionReq)
|
||||
|
||||
s.Publish(pubsub.CreatedEvent, permission)
|
||||
|
||||
// Wait for the response with a timeout
|
||||
resp := <-respCh
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *permissionService) AutoApproveSession(sessionID string) {
|
||||
s.autoApproveSessions = append(s.autoApproveSessions, sessionID)
|
||||
}
|
||||
|
||||
func NewPermissionService() Service {
|
||||
return &permissionService{
|
||||
Broker: pubsub.NewBroker[PermissionRequest](),
|
||||
sessionPermissions: make([]PermissionRequest, 0),
|
||||
select {
|
||||
case resp := <-respCh:
|
||||
return resp
|
||||
case <-ctx.Done():
|
||||
slog.Warn("Permission request timed out or context cancelled", "request_id", permissionReq.ID, "tool", opts.ToolName)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *permissionService) AutoApproveSession(ctx context.Context, sessionID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.autoApproveSessions[sessionID] = true
|
||||
}
|
||||
|
||||
func (s *permissionService) IsAutoApproved(ctx context.Context, sessionID string) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.autoApproveSessions[sessionID]
|
||||
}
|
||||
|
||||
func (s *permissionService) Subscribe(ctx context.Context) <-chan pubsub.Event[PermissionRequest] {
|
||||
return s.broker.Subscribe(ctx)
|
||||
}
|
||||
|
||||
func (s *permissionService) SubscribeToResponseEvents(ctx context.Context) <-chan pubsub.Event[PermissionResponse] {
|
||||
return s.responseBroker.Subscribe(ctx)
|
||||
}
|
||||
|
||||
func GrantPersistant(ctx context.Context, permission PermissionRequest) {
|
||||
GetService().GrantPersistant(ctx, permission)
|
||||
}
|
||||
|
||||
func Grant(ctx context.Context, permission PermissionRequest) {
|
||||
GetService().Grant(ctx, permission)
|
||||
}
|
||||
|
||||
func Deny(ctx context.Context, permission PermissionRequest) {
|
||||
GetService().Deny(ctx, permission)
|
||||
}
|
||||
|
||||
func Request(ctx context.Context, opts CreatePermissionRequest) bool {
|
||||
return GetService().Request(ctx, opts)
|
||||
}
|
||||
|
||||
func AutoApproveSession(ctx context.Context, sessionID string) {
|
||||
GetService().AutoApproveSession(ctx, sessionID)
|
||||
}
|
||||
|
||||
func IsAutoApproved(ctx context.Context, sessionID string) bool {
|
||||
return GetService().IsAutoApproved(ctx, sessionID)
|
||||
}
|
||||
|
||||
func SubscribeToRequests(ctx context.Context) <-chan pubsub.Event[PermissionRequest] {
|
||||
return GetService().Subscribe(ctx)
|
||||
}
|
||||
|
||||
func SubscribeToResponses(ctx context.Context) <-chan pubsub.Event[PermissionResponse] {
|
||||
return GetService().SubscribeToResponseEvents(ctx)
|
||||
}
|
||||
|
||||
@@ -2,115 +2,112 @@ package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const bufferSize = 64
|
||||
const defaultChannelBufferSize = 100
|
||||
|
||||
type Broker[T any] struct {
|
||||
subs map[chan Event[T]]struct{}
|
||||
mu sync.RWMutex
|
||||
done chan struct{}
|
||||
subCount int
|
||||
maxEvents int
|
||||
subs map[chan Event[T]]context.CancelFunc
|
||||
mu sync.RWMutex
|
||||
isClosed bool
|
||||
}
|
||||
|
||||
func NewBroker[T any]() *Broker[T] {
|
||||
return NewBrokerWithOptions[T](bufferSize, 1000)
|
||||
}
|
||||
|
||||
func NewBrokerWithOptions[T any](channelBufferSize, maxEvents int) *Broker[T] {
|
||||
b := &Broker[T]{
|
||||
subs: make(map[chan Event[T]]struct{}),
|
||||
done: make(chan struct{}),
|
||||
subCount: 0,
|
||||
maxEvents: maxEvents,
|
||||
return &Broker[T]{
|
||||
subs: make(map[chan Event[T]]context.CancelFunc),
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Broker[T]) Shutdown() {
|
||||
select {
|
||||
case <-b.done: // Already closed
|
||||
return
|
||||
default:
|
||||
close(b.done)
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
for ch := range b.subs {
|
||||
delete(b.subs, ch)
|
||||
close(ch)
|
||||
if b.isClosed {
|
||||
b.mu.Unlock()
|
||||
return
|
||||
}
|
||||
b.isClosed = true
|
||||
|
||||
b.subCount = 0
|
||||
for ch, cancel := range b.subs {
|
||||
cancel()
|
||||
close(ch)
|
||||
delete(b.subs, ch)
|
||||
}
|
||||
b.mu.Unlock()
|
||||
slog.Debug("PubSub broker shut down", "type", fmt.Sprintf("%T", *new(T)))
|
||||
}
|
||||
|
||||
func (b *Broker[T]) Subscribe(ctx context.Context) <-chan Event[T] {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-b.done:
|
||||
ch := make(chan Event[T])
|
||||
close(ch)
|
||||
return ch
|
||||
default:
|
||||
if b.isClosed {
|
||||
closedCh := make(chan Event[T])
|
||||
close(closedCh)
|
||||
return closedCh
|
||||
}
|
||||
|
||||
sub := make(chan Event[T], bufferSize)
|
||||
b.subs[sub] = struct{}{}
|
||||
b.subCount++
|
||||
subCtx, subCancel := context.WithCancel(ctx)
|
||||
subscriberChannel := make(chan Event[T], defaultChannelBufferSize)
|
||||
b.subs[subscriberChannel] = subCancel
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
|
||||
<-subCtx.Done()
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-b.done:
|
||||
return
|
||||
default:
|
||||
if _, ok := b.subs[subscriberChannel]; ok {
|
||||
close(subscriberChannel)
|
||||
delete(b.subs, subscriberChannel)
|
||||
}
|
||||
|
||||
delete(b.subs, sub)
|
||||
close(sub)
|
||||
b.subCount--
|
||||
}()
|
||||
|
||||
return sub
|
||||
return subscriberChannel
|
||||
}
|
||||
|
||||
func (b *Broker[T]) Publish(eventType EventType, payload T) {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
if b.isClosed {
|
||||
slog.Warn("Attempted to publish on a closed pubsub broker", "type", eventType, "payload_type", fmt.Sprintf("%T", payload))
|
||||
return
|
||||
}
|
||||
|
||||
event := Event[T]{Type: eventType, Payload: payload}
|
||||
|
||||
for ch := range b.subs {
|
||||
// Non-blocking send with a fallback to a goroutine to prevent slow subscribers
|
||||
// from blocking the publisher.
|
||||
select {
|
||||
case ch <- event:
|
||||
// Successfully sent
|
||||
default:
|
||||
// Subscriber channel is full or receiver is slow.
|
||||
// Send in a new goroutine to avoid blocking the publisher.
|
||||
// This might lead to out-of-order delivery for this specific slow subscriber.
|
||||
go func(sChan chan Event[T], ev Event[T]) {
|
||||
// Re-check if broker is closed before attempting send in goroutine
|
||||
b.mu.RLock()
|
||||
isBrokerClosed := b.isClosed
|
||||
b.mu.RUnlock()
|
||||
if isBrokerClosed {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case sChan <- ev:
|
||||
case <-time.After(2 * time.Second): // Timeout for slow subscriber
|
||||
slog.Warn("PubSub: Dropped event for slow subscriber after timeout", "type", ev.Type)
|
||||
}
|
||||
}(ch, event)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker[T]) GetSubscriberCount() int {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
return b.subCount
|
||||
}
|
||||
|
||||
func (b *Broker[T]) Publish(t EventType, payload T) {
|
||||
b.mu.RLock()
|
||||
select {
|
||||
case <-b.done:
|
||||
b.mu.RUnlock()
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
subscribers := make([]chan Event[T], 0, len(b.subs))
|
||||
for sub := range b.subs {
|
||||
subscribers = append(subscribers, sub)
|
||||
}
|
||||
b.mu.RUnlock()
|
||||
|
||||
event := Event[T]{Type: t, Payload: payload}
|
||||
|
||||
for _, sub := range subscribers {
|
||||
select {
|
||||
case sub <- event:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return len(b.subs)
|
||||
}
|
||||
|
||||
144
internal/pubsub/broker_test.go
Normal file
144
internal/pubsub/broker_test.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBrokerSubscribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("with cancellable context", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
broker := NewBroker[string]()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ch := broker.Subscribe(ctx)
|
||||
assert.NotNil(t, ch)
|
||||
assert.Equal(t, 1, broker.GetSubscriberCount())
|
||||
|
||||
// Cancel the context should remove the subscription
|
||||
cancel()
|
||||
time.Sleep(10 * time.Millisecond) // Give time for goroutine to process
|
||||
assert.Equal(t, 0, broker.GetSubscriberCount())
|
||||
})
|
||||
|
||||
t.Run("with background context", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
broker := NewBroker[string]()
|
||||
|
||||
// Using context.Background() should not leak goroutines
|
||||
ch := broker.Subscribe(context.Background())
|
||||
assert.NotNil(t, ch)
|
||||
assert.Equal(t, 1, broker.GetSubscriberCount())
|
||||
|
||||
// Shutdown should clean up all subscriptions
|
||||
broker.Shutdown()
|
||||
assert.Equal(t, 0, broker.GetSubscriberCount())
|
||||
})
|
||||
}
|
||||
|
||||
func TestBrokerPublish(t *testing.T) {
|
||||
t.Parallel()
|
||||
broker := NewBroker[string]()
|
||||
ctx := t.Context()
|
||||
|
||||
ch := broker.Subscribe(ctx)
|
||||
|
||||
// Publish a message
|
||||
broker.Publish(EventTypeCreated, "test message")
|
||||
|
||||
// Verify message is received
|
||||
select {
|
||||
case event := <-ch:
|
||||
assert.Equal(t, EventTypeCreated, event.Type)
|
||||
assert.Equal(t, "test message", event.Payload)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("timeout waiting for message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerShutdown(t *testing.T) {
|
||||
t.Parallel()
|
||||
broker := NewBroker[string]()
|
||||
|
||||
// Create multiple subscribers
|
||||
ch1 := broker.Subscribe(context.Background())
|
||||
ch2 := broker.Subscribe(context.Background())
|
||||
|
||||
assert.Equal(t, 2, broker.GetSubscriberCount())
|
||||
|
||||
// Shutdown should close all channels and clean up
|
||||
broker.Shutdown()
|
||||
|
||||
// Verify channels are closed
|
||||
_, ok1 := <-ch1
|
||||
_, ok2 := <-ch2
|
||||
assert.False(t, ok1, "channel 1 should be closed")
|
||||
assert.False(t, ok2, "channel 2 should be closed")
|
||||
|
||||
// Verify subscriber count is reset
|
||||
assert.Equal(t, 0, broker.GetSubscriberCount())
|
||||
}
|
||||
|
||||
func TestBrokerConcurrency(t *testing.T) {
|
||||
t.Parallel()
|
||||
broker := NewBroker[int]()
|
||||
|
||||
// Create a large number of subscribers
|
||||
const numSubscribers = 100
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numSubscribers)
|
||||
|
||||
// Create a channel to collect received events
|
||||
receivedEvents := make(chan int, numSubscribers)
|
||||
|
||||
for i := range numSubscribers {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ch := broker.Subscribe(ctx)
|
||||
|
||||
// Receive one message then cancel
|
||||
select {
|
||||
case event := <-ch:
|
||||
receivedEvents <- event.Payload
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Errorf("timeout waiting for message %d", id)
|
||||
}
|
||||
cancel()
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Give subscribers time to set up
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Publish messages to all subscribers
|
||||
for i := range numSubscribers {
|
||||
broker.Publish(EventTypeCreated, i)
|
||||
}
|
||||
|
||||
// Wait for all subscribers to finish
|
||||
wg.Wait()
|
||||
close(receivedEvents)
|
||||
|
||||
// Give time for cleanup goroutines to run
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Verify all subscribers are cleaned up
|
||||
assert.Equal(t, 0, broker.GetSubscriberCount())
|
||||
|
||||
// Verify we received the expected number of events
|
||||
count := 0
|
||||
for range receivedEvents {
|
||||
count++
|
||||
}
|
||||
assert.Equal(t, numSubscribers, count)
|
||||
}
|
||||
@@ -2,27 +2,23 @@ package pubsub
|
||||
|
||||
import "context"
|
||||
|
||||
type EventType string
|
||||
|
||||
const (
|
||||
CreatedEvent EventType = "created"
|
||||
UpdatedEvent EventType = "updated"
|
||||
DeletedEvent EventType = "deleted"
|
||||
EventTypeCreated EventType = "created"
|
||||
EventTypeUpdated EventType = "updated"
|
||||
EventTypeDeleted EventType = "deleted"
|
||||
)
|
||||
|
||||
type Suscriber[T any] interface {
|
||||
Subscribe(context.Context) <-chan Event[T]
|
||||
type Event[T any] struct {
|
||||
Type EventType
|
||||
Payload T
|
||||
}
|
||||
|
||||
type (
|
||||
// EventType identifies the type of event
|
||||
EventType string
|
||||
type Subscriber[T any] interface {
|
||||
Subscribe(ctx context.Context) <-chan Event[T]
|
||||
}
|
||||
|
||||
// Event represents an event in the lifecycle of a resource
|
||||
Event[T any] struct {
|
||||
Type EventType
|
||||
Payload T
|
||||
}
|
||||
|
||||
Publisher[T any] interface {
|
||||
Publish(EventType, T)
|
||||
}
|
||||
)
|
||||
type Publisher[T any] interface {
|
||||
Publish(eventType EventType, payload T)
|
||||
}
|
||||
|
||||
@@ -3,10 +3,13 @@ package session
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/opencode-ai/opencode/internal/db"
|
||||
"github.com/opencode-ai/opencode/internal/pubsub"
|
||||
"github.com/sst/opencode/internal/db"
|
||||
"github.com/sst/opencode/internal/pubsub"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
@@ -17,117 +20,197 @@ type Session struct {
|
||||
PromptTokens int64
|
||||
CompletionTokens int64
|
||||
Cost float64
|
||||
CreatedAt int64
|
||||
UpdatedAt int64
|
||||
Summary string
|
||||
SummarizedAt time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
EventSessionCreated pubsub.EventType = "session_created"
|
||||
EventSessionUpdated pubsub.EventType = "session_updated"
|
||||
EventSessionDeleted pubsub.EventType = "session_deleted"
|
||||
)
|
||||
|
||||
type Service interface {
|
||||
pubsub.Suscriber[Session]
|
||||
pubsub.Subscriber[Session]
|
||||
|
||||
Create(ctx context.Context, title string) (Session, error)
|
||||
CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
|
||||
CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
|
||||
Get(ctx context.Context, id string) (Session, error)
|
||||
List(ctx context.Context) ([]Session, error)
|
||||
Save(ctx context.Context, session Session) (Session, error)
|
||||
Update(ctx context.Context, session Session) (Session, error)
|
||||
Delete(ctx context.Context, id string) error
|
||||
}
|
||||
|
||||
type service struct {
|
||||
*pubsub.Broker[Session]
|
||||
q db.Querier
|
||||
db *db.Queries
|
||||
broker *pubsub.Broker[Session]
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var globalSessionService *service
|
||||
|
||||
func InitService(dbConn *sql.DB) error {
|
||||
if globalSessionService != nil {
|
||||
return fmt.Errorf("session service already initialized")
|
||||
}
|
||||
queries := db.New(dbConn)
|
||||
broker := pubsub.NewBroker[Session]()
|
||||
|
||||
globalSessionService = &service{
|
||||
db: queries,
|
||||
broker: broker,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetService() Service {
|
||||
if globalSessionService == nil {
|
||||
panic("session service not initialized. Call session.InitService() first.")
|
||||
}
|
||||
return globalSessionService
|
||||
}
|
||||
|
||||
func (s *service) Create(ctx context.Context, title string) (Session, error) {
|
||||
dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if title == "" {
|
||||
title = "New Session - " + time.Now().Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
dbSessParams := db.CreateSessionParams{
|
||||
ID: uuid.New().String(),
|
||||
Title: title,
|
||||
})
|
||||
if err != nil {
|
||||
return Session{}, err
|
||||
}
|
||||
dbSession, err := s.db.CreateSession(ctx, dbSessParams)
|
||||
if err != nil {
|
||||
return Session{}, fmt.Errorf("db.CreateSession: %w", err)
|
||||
}
|
||||
|
||||
session := s.fromDBItem(dbSession)
|
||||
s.Publish(pubsub.CreatedEvent, session)
|
||||
s.broker.Publish(EventSessionCreated, session)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
|
||||
dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if title == "" {
|
||||
title = "Task Session - " + time.Now().Format("2006-01-02 15:04:05")
|
||||
}
|
||||
if toolCallID == "" {
|
||||
toolCallID = uuid.New().String()
|
||||
}
|
||||
|
||||
dbSessParams := db.CreateSessionParams{
|
||||
ID: toolCallID,
|
||||
ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
|
||||
ParentSessionID: sql.NullString{String: parentSessionID, Valid: parentSessionID != ""},
|
||||
Title: title,
|
||||
})
|
||||
}
|
||||
dbSession, err := s.db.CreateSession(ctx, dbSessParams)
|
||||
if err != nil {
|
||||
return Session{}, err
|
||||
return Session{}, fmt.Errorf("db.CreateTaskSession: %w", err)
|
||||
}
|
||||
session := s.fromDBItem(dbSession)
|
||||
s.Publish(pubsub.CreatedEvent, session)
|
||||
s.broker.Publish(EventSessionCreated, session)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
|
||||
dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
|
||||
ID: "title-" + parentSessionID,
|
||||
ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
|
||||
Title: "Generate a title",
|
||||
})
|
||||
if err != nil {
|
||||
return Session{}, err
|
||||
}
|
||||
session := s.fromDBItem(dbSession)
|
||||
s.Publish(pubsub.CreatedEvent, session)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *service) Delete(ctx context.Context, id string) error {
|
||||
session, err := s.Get(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.q.DeleteSession(ctx, session.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.Publish(pubsub.DeletedEvent, session)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) Get(ctx context.Context, id string) (Session, error) {
|
||||
dbSession, err := s.q.GetSessionByID(ctx, id)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
dbSession, err := s.db.GetSessionByID(ctx, id)
|
||||
if err != nil {
|
||||
return Session{}, err
|
||||
if err == sql.ErrNoRows {
|
||||
return Session{}, fmt.Errorf("session ID '%s' not found", id)
|
||||
}
|
||||
return Session{}, fmt.Errorf("db.GetSessionByID: %w", err)
|
||||
}
|
||||
return s.fromDBItem(dbSession), nil
|
||||
}
|
||||
|
||||
func (s *service) Save(ctx context.Context, session Session) (Session, error) {
|
||||
dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
|
||||
func (s *service) List(ctx context.Context) ([]Session, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
dbSessions, err := s.db.ListSessions(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.ListSessions: %w", err)
|
||||
}
|
||||
sessions := make([]Session, len(dbSessions))
|
||||
for i, dbSess := range dbSessions {
|
||||
sessions[i] = s.fromDBItem(dbSess)
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func (s *service) Update(ctx context.Context, session Session) (Session, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if session.ID == "" {
|
||||
return Session{}, fmt.Errorf("cannot update session with empty ID")
|
||||
}
|
||||
|
||||
params := db.UpdateSessionParams{
|
||||
ID: session.ID,
|
||||
Title: session.Title,
|
||||
PromptTokens: session.PromptTokens,
|
||||
CompletionTokens: session.CompletionTokens,
|
||||
Cost: session.Cost,
|
||||
})
|
||||
if err != nil {
|
||||
return Session{}, err
|
||||
Summary: sql.NullString{String: session.Summary, Valid: session.Summary != ""},
|
||||
SummarizedAt: sql.NullString{String: session.SummarizedAt.UTC().Format(time.RFC3339Nano), Valid: !session.SummarizedAt.IsZero()},
|
||||
}
|
||||
session = s.fromDBItem(dbSession)
|
||||
s.Publish(pubsub.UpdatedEvent, session)
|
||||
return session, nil
|
||||
dbSession, err := s.db.UpdateSession(ctx, params)
|
||||
if err != nil {
|
||||
return Session{}, fmt.Errorf("db.UpdateSession: %w", err)
|
||||
}
|
||||
updatedSession := s.fromDBItem(dbSession)
|
||||
s.broker.Publish(EventSessionUpdated, updatedSession)
|
||||
return updatedSession, nil
|
||||
}
|
||||
|
||||
func (s *service) List(ctx context.Context) ([]Session, error) {
|
||||
dbSessions, err := s.q.ListSessions(ctx)
|
||||
func (s *service) Delete(ctx context.Context, id string) error {
|
||||
s.mu.Lock()
|
||||
dbSess, err := s.db.GetSessionByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
s.mu.Unlock()
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("session ID '%s' not found for deletion", id)
|
||||
}
|
||||
return fmt.Errorf("db.GetSessionByID before delete: %w", err)
|
||||
}
|
||||
sessions := make([]Session, len(dbSessions))
|
||||
for i, dbSession := range dbSessions {
|
||||
sessions[i] = s.fromDBItem(dbSession)
|
||||
sessionToPublish := s.fromDBItem(dbSess)
|
||||
s.mu.Unlock()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
err = s.db.DeleteSession(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db.DeleteSession: %w", err)
|
||||
}
|
||||
return sessions, nil
|
||||
s.broker.Publish(EventSessionDeleted, sessionToPublish)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s service) fromDBItem(item db.Session) Session {
|
||||
func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[Session] {
|
||||
return s.broker.Subscribe(ctx)
|
||||
}
|
||||
|
||||
func (s *service) fromDBItem(item db.Session) Session {
|
||||
var summarizedAt time.Time
|
||||
if item.SummarizedAt.Valid {
|
||||
parsedTime, err := time.Parse(time.RFC3339Nano, item.SummarizedAt.String)
|
||||
if err == nil {
|
||||
summarizedAt = parsedTime
|
||||
}
|
||||
}
|
||||
|
||||
createdAt, _ := time.Parse(time.RFC3339Nano, item.CreatedAt)
|
||||
updatedAt, _ := time.Parse(time.RFC3339Nano, item.UpdatedAt)
|
||||
|
||||
return Session{
|
||||
ID: item.ID,
|
||||
ParentSessionID: item.ParentSessionID.String,
|
||||
@@ -136,15 +219,37 @@ func (s service) fromDBItem(item db.Session) Session {
|
||||
PromptTokens: item.PromptTokens,
|
||||
CompletionTokens: item.CompletionTokens,
|
||||
Cost: item.Cost,
|
||||
CreatedAt: item.CreatedAt,
|
||||
UpdatedAt: item.UpdatedAt,
|
||||
Summary: item.Summary.String,
|
||||
SummarizedAt: summarizedAt,
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: updatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func NewService(q db.Querier) Service {
|
||||
broker := pubsub.NewBroker[Session]()
|
||||
return &service{
|
||||
broker,
|
||||
q,
|
||||
}
|
||||
func Create(ctx context.Context, title string) (Session, error) {
|
||||
return GetService().Create(ctx, title)
|
||||
}
|
||||
|
||||
func CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
|
||||
return GetService().CreateTaskSession(ctx, toolCallID, parentSessionID, title)
|
||||
}
|
||||
|
||||
func Get(ctx context.Context, id string) (Session, error) {
|
||||
return GetService().Get(ctx, id)
|
||||
}
|
||||
|
||||
func List(ctx context.Context) ([]Session, error) {
|
||||
return GetService().List(ctx)
|
||||
}
|
||||
|
||||
func Update(ctx context.Context, session Session) (Session, error) {
|
||||
return GetService().Update(ctx, session)
|
||||
}
|
||||
|
||||
func Delete(ctx context.Context, id string) error {
|
||||
return GetService().Delete(ctx, id)
|
||||
}
|
||||
|
||||
func Subscribe(ctx context.Context) <-chan pubsub.Event[Session] {
|
||||
return GetService().Subscribe(ctx)
|
||||
}
|
||||
|
||||
117
internal/status/status.go
Normal file
117
internal/status/status.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package status
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sst/opencode/internal/pubsub"
|
||||
)
|
||||
|
||||
type Level string
|
||||
|
||||
const (
|
||||
LevelInfo Level = "info"
|
||||
LevelWarn Level = "warn"
|
||||
LevelError Level = "error"
|
||||
LevelDebug Level = "debug"
|
||||
)
|
||||
|
||||
type StatusMessage struct {
|
||||
Level Level `json:"level"`
|
||||
Message string `json:"message"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
const (
|
||||
EventStatusPublished pubsub.EventType = "status_published"
|
||||
)
|
||||
|
||||
type Service interface {
|
||||
pubsub.Subscriber[StatusMessage]
|
||||
|
||||
Info(message string)
|
||||
Warn(message string)
|
||||
Error(message string)
|
||||
Debug(message string)
|
||||
}
|
||||
|
||||
type service struct {
|
||||
broker *pubsub.Broker[StatusMessage]
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var globalStatusService *service
|
||||
|
||||
func InitService() error {
|
||||
if globalStatusService != nil {
|
||||
return fmt.Errorf("status service already initialized")
|
||||
}
|
||||
broker := pubsub.NewBroker[StatusMessage]()
|
||||
globalStatusService = &service{
|
||||
broker: broker,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetService() Service {
|
||||
if globalStatusService == nil {
|
||||
panic("status service not initialized. Call status.InitService() at application startup.")
|
||||
}
|
||||
return globalStatusService
|
||||
}
|
||||
|
||||
func (s *service) Info(message string) {
|
||||
s.publish(LevelInfo, message)
|
||||
slog.Info(message)
|
||||
}
|
||||
|
||||
func (s *service) Warn(message string) {
|
||||
s.publish(LevelWarn, message)
|
||||
slog.Warn(message)
|
||||
}
|
||||
|
||||
func (s *service) Error(message string) {
|
||||
s.publish(LevelError, message)
|
||||
slog.Error(message)
|
||||
}
|
||||
|
||||
func (s *service) Debug(message string) {
|
||||
s.publish(LevelDebug, message)
|
||||
slog.Debug(message)
|
||||
}
|
||||
|
||||
func (s *service) publish(level Level, messageText string) {
|
||||
statusMsg := StatusMessage{
|
||||
Level: level,
|
||||
Message: messageText,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
s.broker.Publish(EventStatusPublished, statusMsg)
|
||||
}
|
||||
|
||||
func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[StatusMessage] {
|
||||
return s.broker.Subscribe(ctx)
|
||||
}
|
||||
|
||||
func Info(message string) {
|
||||
GetService().Info(message)
|
||||
}
|
||||
|
||||
func Warn(message string) {
|
||||
GetService().Warn(message)
|
||||
}
|
||||
|
||||
func Error(message string) {
|
||||
GetService().Error(message)
|
||||
}
|
||||
|
||||
func Debug(message string) {
|
||||
GetService().Debug(message)
|
||||
}
|
||||
|
||||
func Subscribe(ctx context.Context) <-chan pubsub.Event[StatusMessage] {
|
||||
return GetService().Subscribe(ctx)
|
||||
}
|
||||
@@ -6,28 +6,41 @@ import (
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/charmbracelet/x/ansi"
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/session"
|
||||
"github.com/opencode-ai/opencode/internal/tui/styles"
|
||||
"github.com/opencode-ai/opencode/internal/version"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/tui/styles"
|
||||
"github.com/sst/opencode/internal/tui/theme"
|
||||
"github.com/sst/opencode/internal/version"
|
||||
)
|
||||
|
||||
type SendMsg struct {
|
||||
Text string
|
||||
Text string
|
||||
Attachments []message.Attachment
|
||||
}
|
||||
|
||||
type SessionSelectedMsg = session.Session
|
||||
|
||||
type SessionClearedMsg struct{}
|
||||
|
||||
type EditorFocusMsg bool
|
||||
func header(width int) string {
|
||||
return lipgloss.JoinVertical(
|
||||
lipgloss.Top,
|
||||
logo(width),
|
||||
repo(width),
|
||||
"",
|
||||
cwd(width),
|
||||
)
|
||||
}
|
||||
|
||||
func lspsConfigured(width int) string {
|
||||
cfg := config.Get()
|
||||
title := "LSP Configuration"
|
||||
title := "LSP Servers"
|
||||
title = ansi.Truncate(title, width, "…")
|
||||
|
||||
lsps := styles.BaseStyle.Width(width).Foreground(styles.PrimaryColor).Bold(true).Render(title)
|
||||
t := theme.CurrentTheme()
|
||||
baseStyle := styles.BaseStyle()
|
||||
|
||||
lsps := baseStyle.
|
||||
Width(width).
|
||||
Foreground(t.Primary()).
|
||||
Bold(true).
|
||||
Render(title)
|
||||
|
||||
// Get LSP names and sort them for consistent ordering
|
||||
var lspNames []string
|
||||
@@ -39,16 +52,19 @@ func lspsConfigured(width int) string {
|
||||
var lspViews []string
|
||||
for _, name := range lspNames {
|
||||
lsp := cfg.LSP[name]
|
||||
lspName := styles.BaseStyle.Foreground(styles.Forground).Render(
|
||||
fmt.Sprintf("• %s", name),
|
||||
)
|
||||
lspName := baseStyle.
|
||||
Foreground(t.Text()).
|
||||
Render(fmt.Sprintf("• %s", name))
|
||||
|
||||
cmd := lsp.Command
|
||||
cmd = ansi.Truncate(cmd, width-lipgloss.Width(lspName)-3, "…")
|
||||
lspPath := styles.BaseStyle.Foreground(styles.ForgroundDim).Render(
|
||||
fmt.Sprintf(" (%s)", cmd),
|
||||
)
|
||||
|
||||
lspPath := baseStyle.
|
||||
Foreground(t.TextMuted()).
|
||||
Render(fmt.Sprintf(" (%s)", cmd))
|
||||
|
||||
lspViews = append(lspViews,
|
||||
styles.BaseStyle.
|
||||
baseStyle.
|
||||
Width(width).
|
||||
Render(
|
||||
lipgloss.JoinHorizontal(
|
||||
@@ -59,7 +75,8 @@ func lspsConfigured(width int) string {
|
||||
),
|
||||
)
|
||||
}
|
||||
return styles.BaseStyle.
|
||||
|
||||
return baseStyle.
|
||||
Width(width).
|
||||
Render(
|
||||
lipgloss.JoinVertical(
|
||||
@@ -75,10 +92,14 @@ func lspsConfigured(width int) string {
|
||||
|
||||
func logo(width int) string {
|
||||
logo := fmt.Sprintf("%s %s", styles.OpenCodeIcon, "OpenCode")
|
||||
t := theme.CurrentTheme()
|
||||
baseStyle := styles.BaseStyle()
|
||||
|
||||
version := styles.BaseStyle.Foreground(styles.ForgroundDim).Render(version.Version)
|
||||
versionText := baseStyle.
|
||||
Foreground(t.TextMuted()).
|
||||
Render(version.Version)
|
||||
|
||||
return styles.BaseStyle.
|
||||
return baseStyle.
|
||||
Bold(true).
|
||||
Width(width).
|
||||
Render(
|
||||
@@ -86,34 +107,27 @@ func logo(width int) string {
|
||||
lipgloss.Left,
|
||||
logo,
|
||||
" ",
|
||||
version,
|
||||
versionText,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func repo(width int) string {
|
||||
repo := "https://github.com/opencode-ai/opencode"
|
||||
return styles.BaseStyle.
|
||||
Foreground(styles.ForgroundDim).
|
||||
repo := "github.com/sst/opencode"
|
||||
t := theme.CurrentTheme()
|
||||
|
||||
return styles.BaseStyle().
|
||||
Foreground(t.TextMuted()).
|
||||
Width(width).
|
||||
Render(repo)
|
||||
}
|
||||
|
||||
func cwd(width int) string {
|
||||
cwd := fmt.Sprintf("cwd: %s", config.WorkingDirectory())
|
||||
return styles.BaseStyle.
|
||||
Foreground(styles.ForgroundDim).
|
||||
t := theme.CurrentTheme()
|
||||
|
||||
return styles.BaseStyle().
|
||||
Foreground(t.TextMuted()).
|
||||
Width(width).
|
||||
Render(cwd)
|
||||
}
|
||||
|
||||
func header(width int) string {
|
||||
header := lipgloss.JoinVertical(
|
||||
lipgloss.Top,
|
||||
logo(width),
|
||||
repo(width),
|
||||
"",
|
||||
cwd(width),
|
||||
)
|
||||
return header
|
||||
}
|
||||
|
||||
@@ -1,24 +1,33 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"unicode"
|
||||
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
"github.com/charmbracelet/bubbles/textarea"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/opencode-ai/opencode/internal/app"
|
||||
"github.com/opencode-ai/opencode/internal/session"
|
||||
"github.com/opencode-ai/opencode/internal/tui/layout"
|
||||
"github.com/opencode-ai/opencode/internal/tui/styles"
|
||||
"github.com/opencode-ai/opencode/internal/tui/util"
|
||||
"github.com/sst/opencode/internal/app"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/status"
|
||||
"github.com/sst/opencode/internal/tui/components/dialog"
|
||||
"github.com/sst/opencode/internal/tui/layout"
|
||||
"github.com/sst/opencode/internal/tui/styles"
|
||||
"github.com/sst/opencode/internal/tui/theme"
|
||||
"github.com/sst/opencode/internal/tui/util"
|
||||
)
|
||||
|
||||
type editorCmp struct {
|
||||
app *app.App
|
||||
session session.Session
|
||||
textarea textarea.Model
|
||||
width int
|
||||
height int
|
||||
app *app.App
|
||||
textarea textarea.Model
|
||||
attachments []message.Attachment
|
||||
deleteMode bool
|
||||
}
|
||||
|
||||
type EditorKeyMaps struct {
|
||||
@@ -31,6 +40,11 @@ type bluredEditorKeyMaps struct {
|
||||
Focus key.Binding
|
||||
OpenEditor key.Binding
|
||||
}
|
||||
type DeleteAttachmentKeyMaps struct {
|
||||
AttachmentDeleteMode key.Binding
|
||||
Escape key.Binding
|
||||
DeleteAllAttachments key.Binding
|
||||
}
|
||||
|
||||
var editorMaps = EditorKeyMaps{
|
||||
Send: key.NewBinding(
|
||||
@@ -43,15 +57,36 @@ var editorMaps = EditorKeyMaps{
|
||||
),
|
||||
}
|
||||
|
||||
func openEditor() tea.Cmd {
|
||||
var DeleteKeyMaps = DeleteAttachmentKeyMaps{
|
||||
AttachmentDeleteMode: key.NewBinding(
|
||||
key.WithKeys("ctrl+r"),
|
||||
key.WithHelp("ctrl+r+{i}", "delete attachment at index i"),
|
||||
),
|
||||
Escape: key.NewBinding(
|
||||
key.WithKeys("esc"),
|
||||
key.WithHelp("esc", "cancel delete mode"),
|
||||
),
|
||||
DeleteAllAttachments: key.NewBinding(
|
||||
key.WithKeys("r"),
|
||||
key.WithHelp("ctrl+r+r", "delete all attchments"),
|
||||
),
|
||||
}
|
||||
|
||||
const (
|
||||
maxAttachments = 5
|
||||
)
|
||||
|
||||
func (m *editorCmp) openEditor(value string) tea.Cmd {
|
||||
editor := os.Getenv("EDITOR")
|
||||
if editor == "" {
|
||||
editor = "nvim"
|
||||
}
|
||||
|
||||
tmpfile, err := os.CreateTemp("", "msg_*.md")
|
||||
tmpfile.WriteString(value)
|
||||
if err != nil {
|
||||
return util.ReportError(err)
|
||||
status.Error(err.Error())
|
||||
return nil
|
||||
}
|
||||
tmpfile.Close()
|
||||
c := exec.Command(editor, tmpfile.Name()) //nolint:gosec
|
||||
@@ -60,18 +95,24 @@ func openEditor() tea.Cmd {
|
||||
c.Stderr = os.Stderr
|
||||
return tea.ExecProcess(c, func(err error) tea.Msg {
|
||||
if err != nil {
|
||||
return util.ReportError(err)
|
||||
status.Error(err.Error())
|
||||
return nil
|
||||
}
|
||||
content, err := os.ReadFile(tmpfile.Name())
|
||||
if err != nil {
|
||||
return util.ReportError(err)
|
||||
status.Error(err.Error())
|
||||
return nil
|
||||
}
|
||||
if len(content) == 0 {
|
||||
return util.ReportWarn("Message is empty")
|
||||
status.Warn("Message is empty")
|
||||
return nil
|
||||
}
|
||||
os.Remove(tmpfile.Name())
|
||||
attachments := m.attachments
|
||||
m.attachments = nil
|
||||
return SendMsg{
|
||||
Text: string(content),
|
||||
Text: string(content),
|
||||
Attachments: attachments,
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -81,18 +122,23 @@ func (m *editorCmp) Init() tea.Cmd {
|
||||
}
|
||||
|
||||
func (m *editorCmp) send() tea.Cmd {
|
||||
if m.app.CoderAgent.IsSessionBusy(m.session.ID) {
|
||||
return util.ReportWarn("Agent is working, please wait...")
|
||||
if m.app.PrimaryAgent.IsSessionBusy(m.app.CurrentSession.ID) {
|
||||
status.Warn("Agent is working, please wait...")
|
||||
return nil
|
||||
}
|
||||
|
||||
value := m.textarea.Value()
|
||||
m.textarea.Reset()
|
||||
attachments := m.attachments
|
||||
|
||||
m.attachments = nil
|
||||
if value == "" {
|
||||
return nil
|
||||
}
|
||||
return tea.Batch(
|
||||
util.CmdHandler(SendMsg{
|
||||
Text: value,
|
||||
Text: value,
|
||||
Attachments: attachments,
|
||||
}),
|
||||
)
|
||||
}
|
||||
@@ -100,21 +146,53 @@ func (m *editorCmp) send() tea.Cmd {
|
||||
func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
switch msg := msg.(type) {
|
||||
case SessionSelectedMsg:
|
||||
if msg.ID != m.session.ID {
|
||||
m.session = msg
|
||||
}
|
||||
case dialog.ThemeChangedMsg:
|
||||
m.textarea = CreateTextArea(&m.textarea)
|
||||
return m, nil
|
||||
case dialog.AttachmentAddedMsg:
|
||||
if len(m.attachments) >= maxAttachments {
|
||||
status.Error(fmt.Sprintf("cannot add more than %d images", maxAttachments))
|
||||
return m, cmd
|
||||
}
|
||||
m.attachments = append(m.attachments, msg.Attachment)
|
||||
case tea.KeyMsg:
|
||||
if key.Matches(msg, DeleteKeyMaps.AttachmentDeleteMode) {
|
||||
m.deleteMode = true
|
||||
return m, nil
|
||||
}
|
||||
if key.Matches(msg, DeleteKeyMaps.DeleteAllAttachments) && m.deleteMode {
|
||||
m.deleteMode = false
|
||||
m.attachments = nil
|
||||
return m, nil
|
||||
}
|
||||
if m.deleteMode && len(msg.Runes) > 0 && unicode.IsDigit(msg.Runes[0]) {
|
||||
num := int(msg.Runes[0] - '0')
|
||||
m.deleteMode = false
|
||||
if num < 10 && len(m.attachments) > num {
|
||||
if num == 0 {
|
||||
m.attachments = m.attachments[num+1:]
|
||||
} else {
|
||||
m.attachments = slices.Delete(m.attachments, num, num+1)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
if key.Matches(msg, messageKeys.PageUp) || key.Matches(msg, messageKeys.PageDown) ||
|
||||
key.Matches(msg, messageKeys.HalfPageUp) || key.Matches(msg, messageKeys.HalfPageDown) {
|
||||
return m, nil
|
||||
}
|
||||
if key.Matches(msg, editorMaps.OpenEditor) {
|
||||
if m.app.CoderAgent.IsSessionBusy(m.session.ID) {
|
||||
return m, util.ReportWarn("Agent is working, please wait...")
|
||||
if m.app.PrimaryAgent.IsSessionBusy(m.app.CurrentSession.ID) {
|
||||
status.Warn("Agent is working, please wait...")
|
||||
return m, nil
|
||||
}
|
||||
return m, openEditor()
|
||||
value := m.textarea.Value()
|
||||
m.textarea.Reset()
|
||||
return m, m.openEditor(value)
|
||||
}
|
||||
if key.Matches(msg, DeleteKeyMaps.Escape) {
|
||||
m.deleteMode = false
|
||||
return m, nil
|
||||
}
|
||||
// Handle Enter key
|
||||
if m.textarea.Focused() && key.Matches(msg, editorMaps.Send) {
|
||||
@@ -128,20 +206,38 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return m, m.send()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
m.textarea, cmd = m.textarea.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *editorCmp) View() string {
|
||||
style := lipgloss.NewStyle().Padding(0, 0, 0, 1).Bold(true)
|
||||
t := theme.CurrentTheme()
|
||||
|
||||
return lipgloss.JoinHorizontal(lipgloss.Top, style.Render(">"), m.textarea.View())
|
||||
// Style the prompt with theme colors
|
||||
style := lipgloss.NewStyle().
|
||||
Padding(0, 0, 0, 1).
|
||||
Bold(true).
|
||||
Foreground(t.Primary())
|
||||
|
||||
if len(m.attachments) == 0 {
|
||||
return lipgloss.JoinHorizontal(lipgloss.Top, style.Render(">"), m.textarea.View())
|
||||
}
|
||||
m.textarea.SetHeight(m.height - 1)
|
||||
return lipgloss.JoinVertical(lipgloss.Top,
|
||||
m.attachmentsContent(),
|
||||
lipgloss.JoinHorizontal(lipgloss.Top, style.Render(">"),
|
||||
m.textarea.View()),
|
||||
)
|
||||
}
|
||||
|
||||
func (m *editorCmp) SetSize(width, height int) tea.Cmd {
|
||||
m.width = width
|
||||
m.height = height
|
||||
m.textarea.SetWidth(width - 3) // account for the prompt and padding right
|
||||
m.textarea.SetHeight(height)
|
||||
m.textarea.SetWidth(width)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -149,29 +245,70 @@ func (m *editorCmp) GetSize() (int, int) {
|
||||
return m.textarea.Width(), m.textarea.Height()
|
||||
}
|
||||
|
||||
func (m *editorCmp) attachmentsContent() string {
|
||||
var styledAttachments []string
|
||||
t := theme.CurrentTheme()
|
||||
attachmentStyles := styles.BaseStyle().
|
||||
MarginLeft(1).
|
||||
Background(t.TextMuted()).
|
||||
Foreground(t.Text())
|
||||
for i, attachment := range m.attachments {
|
||||
var filename string
|
||||
if len(attachment.FileName) > 10 {
|
||||
filename = fmt.Sprintf(" %s %s...", styles.DocumentIcon, attachment.FileName[0:7])
|
||||
} else {
|
||||
filename = fmt.Sprintf(" %s %s", styles.DocumentIcon, attachment.FileName)
|
||||
}
|
||||
if m.deleteMode {
|
||||
filename = fmt.Sprintf("%d%s", i, filename)
|
||||
}
|
||||
styledAttachments = append(styledAttachments, attachmentStyles.Render(filename))
|
||||
}
|
||||
content := lipgloss.JoinHorizontal(lipgloss.Left, styledAttachments...)
|
||||
return content
|
||||
}
|
||||
|
||||
func (m *editorCmp) BindingKeys() []key.Binding {
|
||||
bindings := []key.Binding{}
|
||||
bindings = append(bindings, layout.KeyMapToSlice(editorMaps)...)
|
||||
bindings = append(bindings, layout.KeyMapToSlice(DeleteKeyMaps)...)
|
||||
return bindings
|
||||
}
|
||||
|
||||
func NewEditorCmp(app *app.App) tea.Model {
|
||||
ti := textarea.New()
|
||||
ti.Prompt = " "
|
||||
ti.ShowLineNumbers = false
|
||||
ti.BlurredStyle.Base = ti.BlurredStyle.Base.Background(styles.Background)
|
||||
ti.BlurredStyle.CursorLine = ti.BlurredStyle.CursorLine.Background(styles.Background)
|
||||
ti.BlurredStyle.Placeholder = ti.BlurredStyle.Placeholder.Background(styles.Background)
|
||||
ti.BlurredStyle.Text = ti.BlurredStyle.Text.Background(styles.Background)
|
||||
func CreateTextArea(existing *textarea.Model) textarea.Model {
|
||||
t := theme.CurrentTheme()
|
||||
bgColor := t.Background()
|
||||
textColor := t.Text()
|
||||
textMutedColor := t.TextMuted()
|
||||
|
||||
ti.FocusedStyle.Base = ti.FocusedStyle.Base.Background(styles.Background)
|
||||
ti.FocusedStyle.CursorLine = ti.FocusedStyle.CursorLine.Background(styles.Background)
|
||||
ti.FocusedStyle.Placeholder = ti.FocusedStyle.Placeholder.Background(styles.Background)
|
||||
ti.FocusedStyle.Text = ti.BlurredStyle.Text.Background(styles.Background)
|
||||
ti.CharLimit = -1
|
||||
ti.Focus()
|
||||
ta := textarea.New()
|
||||
ta.BlurredStyle.Base = styles.BaseStyle().Background(bgColor).Foreground(textColor)
|
||||
ta.BlurredStyle.CursorLine = styles.BaseStyle().Background(bgColor)
|
||||
ta.BlurredStyle.Placeholder = styles.BaseStyle().Background(bgColor).Foreground(textMutedColor)
|
||||
ta.BlurredStyle.Text = styles.BaseStyle().Background(bgColor).Foreground(textColor)
|
||||
ta.FocusedStyle.Base = styles.BaseStyle().Background(bgColor).Foreground(textColor)
|
||||
ta.FocusedStyle.CursorLine = styles.BaseStyle().Background(bgColor)
|
||||
ta.FocusedStyle.Placeholder = styles.BaseStyle().Background(bgColor).Foreground(textMutedColor)
|
||||
ta.FocusedStyle.Text = styles.BaseStyle().Background(bgColor).Foreground(textColor)
|
||||
|
||||
ta.Prompt = " "
|
||||
ta.ShowLineNumbers = false
|
||||
ta.CharLimit = -1
|
||||
|
||||
if existing != nil {
|
||||
ta.SetValue(existing.Value())
|
||||
ta.SetWidth(existing.Width())
|
||||
ta.SetHeight(existing.Height())
|
||||
}
|
||||
|
||||
ta.Focus()
|
||||
return ta
|
||||
}
|
||||
|
||||
func NewEditorCmp(app *app.App) tea.Model {
|
||||
ta := CreateTextArea(nil)
|
||||
return &editorCmp{
|
||||
app: app,
|
||||
textarea: ti,
|
||||
textarea: ta,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,37 +4,44 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
"github.com/charmbracelet/bubbles/spinner"
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/opencode-ai/opencode/internal/app"
|
||||
"github.com/opencode-ai/opencode/internal/message"
|
||||
"github.com/opencode-ai/opencode/internal/pubsub"
|
||||
"github.com/opencode-ai/opencode/internal/session"
|
||||
"github.com/opencode-ai/opencode/internal/tui/styles"
|
||||
"github.com/opencode-ai/opencode/internal/tui/util"
|
||||
"github.com/sst/opencode/internal/app"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/pubsub"
|
||||
"github.com/sst/opencode/internal/session"
|
||||
"github.com/sst/opencode/internal/status"
|
||||
"github.com/sst/opencode/internal/tui/components/dialog"
|
||||
"github.com/sst/opencode/internal/tui/state"
|
||||
"github.com/sst/opencode/internal/tui/styles"
|
||||
"github.com/sst/opencode/internal/tui/theme"
|
||||
)
|
||||
|
||||
type cacheItem struct {
|
||||
width int
|
||||
content []uiMessage
|
||||
}
|
||||
|
||||
type messagesCmp struct {
|
||||
app *app.App
|
||||
width, height int
|
||||
viewport viewport.Model
|
||||
session session.Session
|
||||
messages []message.Message
|
||||
uiMessages []uiMessage
|
||||
currentMsgID string
|
||||
cachedContent map[string]cacheItem
|
||||
spinner spinner.Model
|
||||
rendering bool
|
||||
app *app.App
|
||||
width, height int
|
||||
viewport viewport.Model
|
||||
messages []message.Message
|
||||
uiMessages []uiMessage
|
||||
currentMsgID string
|
||||
cachedContent map[string]cacheItem
|
||||
spinner spinner.Model
|
||||
rendering bool
|
||||
attachments viewport.Model
|
||||
showToolMessages bool
|
||||
}
|
||||
type renderFinishedMsg struct{}
|
||||
type ToggleToolMessagesMsg struct{}
|
||||
|
||||
type MessageKeys struct {
|
||||
PageDown key.Binding
|
||||
@@ -69,20 +76,23 @@ func (m *messagesCmp) Init() tea.Cmd {
|
||||
func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmds []tea.Cmd
|
||||
switch msg := msg.(type) {
|
||||
|
||||
case SessionSelectedMsg:
|
||||
if msg.ID != m.session.ID {
|
||||
cmd := m.SetSession(msg)
|
||||
return m, cmd
|
||||
}
|
||||
case dialog.ThemeChangedMsg:
|
||||
m.rerender()
|
||||
return m, nil
|
||||
case SessionClearedMsg:
|
||||
m.session = session.Session{}
|
||||
case ToggleToolMessagesMsg:
|
||||
m.showToolMessages = !m.showToolMessages
|
||||
// Clear the cache to force re-rendering of all messages
|
||||
m.cachedContent = make(map[string]cacheItem)
|
||||
m.renderView()
|
||||
return m, nil
|
||||
case state.SessionSelectedMsg:
|
||||
cmd := m.Reload(msg)
|
||||
return m, cmd
|
||||
case state.SessionClearedMsg:
|
||||
m.messages = make([]message.Message, 0)
|
||||
m.currentMsgID = ""
|
||||
m.rendering = false
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
if key.Matches(msg, messageKeys.PageUp) || key.Matches(msg, messageKeys.PageDown) ||
|
||||
key.Matches(msg, messageKeys.HalfPageUp) || key.Matches(msg, messageKeys.HalfPageDown) {
|
||||
@@ -90,15 +100,13 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.viewport = u
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
|
||||
case renderFinishedMsg:
|
||||
m.rendering = false
|
||||
m.viewport.GotoBottom()
|
||||
case pubsub.Event[message.Message]:
|
||||
needsRerender := false
|
||||
if msg.Type == pubsub.CreatedEvent {
|
||||
if msg.Payload.SessionID == m.session.ID {
|
||||
|
||||
if msg.Type == message.EventMessageCreated {
|
||||
if msg.Payload.SessionID == m.app.CurrentSession.ID {
|
||||
messageExists := false
|
||||
for _, v := range m.messages {
|
||||
if v.ID == msg.Payload.ID {
|
||||
@@ -128,7 +136,7 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID {
|
||||
} else if msg.Type == message.EventMessageUpdated && msg.Payload.SessionID == m.app.CurrentSession.ID {
|
||||
for i, v := range m.messages {
|
||||
if v.ID == msg.Payload.ID {
|
||||
m.messages[i] = msg.Payload
|
||||
@@ -141,8 +149,8 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if needsRerender {
|
||||
m.renderView()
|
||||
if len(m.messages) > 0 {
|
||||
if (msg.Type == pubsub.CreatedEvent) ||
|
||||
(msg.Type == pubsub.UpdatedEvent && msg.Payload.ID == m.messages[len(m.messages)-1].ID) {
|
||||
if (msg.Type == message.EventMessageCreated) ||
|
||||
(msg.Type == message.EventMessageUpdated && msg.Payload.ID == m.messages[len(m.messages)-1].ID) {
|
||||
m.viewport.GotoBottom()
|
||||
}
|
||||
}
|
||||
@@ -156,7 +164,7 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
|
||||
func (m *messagesCmp) IsAgentWorking() bool {
|
||||
return m.app.CoderAgent.IsSessionBusy(m.session.ID)
|
||||
return m.app.PrimaryAgent.IsSessionBusy(m.app.CurrentSession.ID)
|
||||
}
|
||||
|
||||
func formatTimeDifference(unixTime1, unixTime2 int64) string {
|
||||
@@ -174,6 +182,7 @@ func formatTimeDifference(unixTime1, unixTime2 int64) string {
|
||||
func (m *messagesCmp) renderView() {
|
||||
m.uiMessages = make([]uiMessage, 0)
|
||||
pos := 0
|
||||
baseStyle := styles.BaseStyle()
|
||||
|
||||
if m.width == 0 {
|
||||
return
|
||||
@@ -210,6 +219,7 @@ func (m *messagesCmp) renderView() {
|
||||
m.currentMsgID,
|
||||
m.width,
|
||||
pos,
|
||||
m.showToolMessages,
|
||||
)
|
||||
for _, msg := range assistantMessages {
|
||||
m.uiMessages = append(m.uiMessages, msg)
|
||||
@@ -224,16 +234,17 @@ func (m *messagesCmp) renderView() {
|
||||
|
||||
messages := make([]string, 0)
|
||||
for _, v := range m.uiMessages {
|
||||
messages = append(messages, v.content,
|
||||
styles.BaseStyle.
|
||||
messages = append(messages, lipgloss.JoinVertical(lipgloss.Left, v.content),
|
||||
baseStyle.
|
||||
Width(m.width).
|
||||
Render(
|
||||
"",
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
m.viewport.SetContent(
|
||||
styles.BaseStyle.
|
||||
baseStyle.
|
||||
Width(m.width).
|
||||
Render(
|
||||
lipgloss.JoinVertical(
|
||||
@@ -245,8 +256,10 @@ func (m *messagesCmp) renderView() {
|
||||
}
|
||||
|
||||
func (m *messagesCmp) View() string {
|
||||
baseStyle := styles.BaseStyle()
|
||||
|
||||
if m.rendering {
|
||||
return styles.BaseStyle.
|
||||
return baseStyle.
|
||||
Width(m.width).
|
||||
Render(
|
||||
lipgloss.JoinVertical(
|
||||
@@ -258,14 +271,14 @@ func (m *messagesCmp) View() string {
|
||||
)
|
||||
}
|
||||
if len(m.messages) == 0 {
|
||||
content := styles.BaseStyle.
|
||||
content := baseStyle.
|
||||
Width(m.width).
|
||||
Height(m.height - 1).
|
||||
Render(
|
||||
m.initialScreen(),
|
||||
)
|
||||
|
||||
return styles.BaseStyle.
|
||||
return baseStyle.
|
||||
Width(m.width).
|
||||
Render(
|
||||
lipgloss.JoinVertical(
|
||||
@@ -277,7 +290,7 @@ func (m *messagesCmp) View() string {
|
||||
)
|
||||
}
|
||||
|
||||
return styles.BaseStyle.
|
||||
return baseStyle.
|
||||
Width(m.width).
|
||||
Render(
|
||||
lipgloss.JoinVertical(
|
||||
@@ -328,6 +341,9 @@ func hasUnfinishedToolCalls(messages []message.Message) bool {
|
||||
func (m *messagesCmp) working() string {
|
||||
text := ""
|
||||
if m.IsAgentWorking() && len(m.messages) > 0 {
|
||||
t := theme.CurrentTheme()
|
||||
baseStyle := styles.BaseStyle()
|
||||
|
||||
task := "Thinking..."
|
||||
lastMessage := m.messages[len(m.messages)-1]
|
||||
if hasToolsWithoutResponse(m.messages) {
|
||||
@@ -338,42 +354,51 @@ func (m *messagesCmp) working() string {
|
||||
task = "Generating..."
|
||||
}
|
||||
if task != "" {
|
||||
text += styles.BaseStyle.Width(m.width).Foreground(styles.PrimaryColor).Bold(true).Render(
|
||||
fmt.Sprintf("%s %s ", m.spinner.View(), task),
|
||||
)
|
||||
text += baseStyle.
|
||||
Width(m.width).
|
||||
Foreground(t.Primary()).
|
||||
Bold(true).
|
||||
Render(fmt.Sprintf("%s %s ", m.spinner.View(), task))
|
||||
}
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func (m *messagesCmp) help() string {
|
||||
t := theme.CurrentTheme()
|
||||
baseStyle := styles.BaseStyle()
|
||||
|
||||
text := ""
|
||||
|
||||
if m.app.CoderAgent.IsBusy() {
|
||||
if m.app.PrimaryAgent.IsBusy() {
|
||||
text += lipgloss.JoinHorizontal(
|
||||
lipgloss.Left,
|
||||
styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "),
|
||||
styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("esc"),
|
||||
styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" to exit cancel"),
|
||||
baseStyle.Foreground(t.TextMuted()).Bold(true).Render("press "),
|
||||
baseStyle.Foreground(t.Text()).Bold(true).Render("esc"),
|
||||
baseStyle.Foreground(t.TextMuted()).Bold(true).Render(" to interrupt"),
|
||||
)
|
||||
} else {
|
||||
text += lipgloss.JoinHorizontal(
|
||||
lipgloss.Left,
|
||||
styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "),
|
||||
styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("enter"),
|
||||
styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" to send the message,"),
|
||||
styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" write"),
|
||||
styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render(" \\"),
|
||||
styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" and enter to add a new line"),
|
||||
baseStyle.Foreground(t.Text()).Bold(true).Render("enter"),
|
||||
baseStyle.Foreground(t.TextMuted()).Bold(true).Render(" to send,"),
|
||||
baseStyle.Foreground(t.Text()).Bold(true).Render(" \\"),
|
||||
baseStyle.Foreground(t.TextMuted()).Bold(true).Render("+"),
|
||||
baseStyle.Foreground(t.Text()).Bold(true).Render("enter"),
|
||||
baseStyle.Foreground(t.TextMuted()).Bold(true).Render(" for newline,"),
|
||||
baseStyle.Foreground(t.Text()).Bold(true).Render(" ctrl+h"),
|
||||
baseStyle.Foreground(t.TextMuted()).Bold(true).Render(" to toggle tool messages"),
|
||||
)
|
||||
}
|
||||
return styles.BaseStyle.
|
||||
return baseStyle.
|
||||
Width(m.width).
|
||||
Render(text)
|
||||
}
|
||||
|
||||
func (m *messagesCmp) initialScreen() string {
|
||||
return styles.BaseStyle.Width(m.width).Render(
|
||||
baseStyle := styles.BaseStyle()
|
||||
|
||||
return baseStyle.Width(m.width).Render(
|
||||
lipgloss.JoinVertical(
|
||||
lipgloss.Top,
|
||||
header(m.width),
|
||||
@@ -383,6 +408,13 @@ func (m *messagesCmp) initialScreen() string {
|
||||
)
|
||||
}
|
||||
|
||||
func (m *messagesCmp) rerender() {
|
||||
for _, msg := range m.messages {
|
||||
delete(m.cachedContent, msg.ID)
|
||||
}
|
||||
m.renderView()
|
||||
}
|
||||
|
||||
func (m *messagesCmp) SetSize(width, height int) tea.Cmd {
|
||||
if m.width == width && m.height == height {
|
||||
return nil
|
||||
@@ -391,11 +423,9 @@ func (m *messagesCmp) SetSize(width, height int) tea.Cmd {
|
||||
m.height = height
|
||||
m.viewport.Width = width
|
||||
m.viewport.Height = height - 2
|
||||
for _, msg := range m.messages {
|
||||
delete(m.cachedContent, msg.ID)
|
||||
}
|
||||
m.uiMessages = make([]uiMessage, 0)
|
||||
m.renderView()
|
||||
m.attachments.Width = width + 40
|
||||
m.attachments.Height = 3
|
||||
m.rerender()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -403,17 +433,16 @@ func (m *messagesCmp) GetSize() (int, int) {
|
||||
return m.width, m.height
|
||||
}
|
||||
|
||||
func (m *messagesCmp) SetSession(session session.Session) tea.Cmd {
|
||||
if m.session.ID == session.ID {
|
||||
return nil
|
||||
}
|
||||
m.session = session
|
||||
func (m *messagesCmp) Reload(session *session.Session) tea.Cmd {
|
||||
messages, err := m.app.Messages.List(context.Background(), session.ID)
|
||||
if err != nil {
|
||||
return util.ReportError(err)
|
||||
status.Error(err.Error())
|
||||
return nil
|
||||
}
|
||||
m.messages = messages
|
||||
m.currentMsgID = m.messages[len(m.messages)-1].ID
|
||||
if len(m.messages) > 0 {
|
||||
m.currentMsgID = m.messages[len(m.messages)-1].ID
|
||||
}
|
||||
delete(m.cachedContent, m.currentMsgID)
|
||||
m.rendering = true
|
||||
return func() tea.Msg {
|
||||
@@ -432,17 +461,23 @@ func (m *messagesCmp) BindingKeys() []key.Binding {
|
||||
}
|
||||
|
||||
func NewMessagesCmp(app *app.App) tea.Model {
|
||||
s := spinner.New()
|
||||
s.Spinner = spinner.Pulse
|
||||
customSpinner := spinner.Spinner{
|
||||
Frames: []string{" ", "┃", "┃"},
|
||||
FPS: time.Second / 3,
|
||||
}
|
||||
s := spinner.New(spinner.WithSpinner(customSpinner))
|
||||
vp := viewport.New(0, 0)
|
||||
attachmets := viewport.New(0, 0)
|
||||
vp.KeyMap.PageUp = messageKeys.PageUp
|
||||
vp.KeyMap.PageDown = messageKeys.PageDown
|
||||
vp.KeyMap.HalfPageUp = messageKeys.HalfPageUp
|
||||
vp.KeyMap.HalfPageDown = messageKeys.HalfPageDown
|
||||
return &messagesCmp{
|
||||
app: app,
|
||||
cachedContent: make(map[string]cacheItem),
|
||||
viewport: vp,
|
||||
spinner: s,
|
||||
app: app,
|
||||
cachedContent: make(map[string]cacheItem),
|
||||
viewport: vp,
|
||||
spinner: s,
|
||||
attachments: attachmets,
|
||||
showToolMessages: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,19 +6,18 @@ import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/glamour"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/charmbracelet/x/ansi"
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/diff"
|
||||
"github.com/opencode-ai/opencode/internal/llm/agent"
|
||||
"github.com/opencode-ai/opencode/internal/llm/models"
|
||||
"github.com/opencode-ai/opencode/internal/llm/tools"
|
||||
"github.com/opencode-ai/opencode/internal/message"
|
||||
"github.com/opencode-ai/opencode/internal/tui/styles"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/diff"
|
||||
"github.com/sst/opencode/internal/llm/agent"
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/tui/styles"
|
||||
"github.com/sst/opencode/internal/tui/theme"
|
||||
)
|
||||
|
||||
type uiMessageType int
|
||||
@@ -31,8 +30,6 @@ const (
|
||||
maxResultHeight = 10
|
||||
)
|
||||
|
||||
var diffStyle = diff.NewStyleConfig(diff.WithShowHeader(false), diff.WithShowHunkHeader(false))
|
||||
|
||||
type uiMessage struct {
|
||||
ID string
|
||||
messageType uiMessageType
|
||||
@@ -41,46 +38,37 @@ type uiMessage struct {
|
||||
content string
|
||||
}
|
||||
|
||||
type renderCache struct {
|
||||
mutex sync.Mutex
|
||||
cache map[string][]uiMessage
|
||||
}
|
||||
|
||||
func toMarkdown(content string, focused bool, width int) string {
|
||||
r, _ := glamour.NewTermRenderer(
|
||||
glamour.WithStyles(styles.MarkdownTheme(false)),
|
||||
glamour.WithWordWrap(width),
|
||||
)
|
||||
if focused {
|
||||
r, _ = glamour.NewTermRenderer(
|
||||
glamour.WithStyles(styles.MarkdownTheme(true)),
|
||||
glamour.WithWordWrap(width),
|
||||
)
|
||||
}
|
||||
r := styles.GetMarkdownRenderer(width)
|
||||
rendered, _ := r.Render(content)
|
||||
return rendered
|
||||
}
|
||||
|
||||
func renderMessage(msg string, isUser bool, isFocused bool, width int, info ...string) string {
|
||||
style := styles.BaseStyle.
|
||||
t := theme.CurrentTheme()
|
||||
|
||||
style := styles.BaseStyle().
|
||||
Width(width - 1).
|
||||
BorderLeft(true).
|
||||
Foreground(styles.ForgroundDim).
|
||||
BorderForeground(styles.PrimaryColor).
|
||||
Foreground(t.TextMuted()).
|
||||
BorderForeground(t.Primary()).
|
||||
BorderStyle(lipgloss.ThickBorder())
|
||||
|
||||
if isUser {
|
||||
style = style.
|
||||
BorderForeground(styles.Blue)
|
||||
}
|
||||
parts := []string{
|
||||
styles.ForceReplaceBackgroundWithLipgloss(toMarkdown(msg, isFocused, width), styles.Background),
|
||||
style = style.BorderForeground(t.Secondary())
|
||||
}
|
||||
|
||||
// remove newline at the end
|
||||
// Apply markdown formatting and handle background color
|
||||
parts := []string{
|
||||
styles.ForceReplaceBackgroundWithLipgloss(toMarkdown(msg, isFocused, width), t.Background()),
|
||||
}
|
||||
|
||||
// Remove newline at the end
|
||||
parts[0] = strings.TrimSuffix(parts[0], "\n")
|
||||
if len(info) > 0 {
|
||||
parts = append(parts, info...)
|
||||
}
|
||||
|
||||
rendered := style.Render(
|
||||
lipgloss.JoinVertical(
|
||||
lipgloss.Left,
|
||||
@@ -92,7 +80,41 @@ func renderMessage(msg string, isUser bool, isFocused bool, width int, info ...s
|
||||
}
|
||||
|
||||
func renderUserMessage(msg message.Message, isFocused bool, width int, position int) uiMessage {
|
||||
content := renderMessage(msg.Content().String(), true, isFocused, width)
|
||||
var styledAttachments []string
|
||||
t := theme.CurrentTheme()
|
||||
baseStyle := styles.BaseStyle()
|
||||
attachmentStyles := baseStyle.
|
||||
MarginLeft(1).
|
||||
Background(t.TextMuted()).
|
||||
Foreground(t.Text())
|
||||
for _, attachment := range msg.BinaryContent() {
|
||||
file := filepath.Base(attachment.Path)
|
||||
var filename string
|
||||
if len(file) > 10 {
|
||||
filename = fmt.Sprintf(" %s %s...", styles.DocumentIcon, file[0:7])
|
||||
} else {
|
||||
filename = fmt.Sprintf(" %s %s", styles.DocumentIcon, file)
|
||||
}
|
||||
styledAttachments = append(styledAttachments, attachmentStyles.Render(filename))
|
||||
}
|
||||
|
||||
// Add timestamp info
|
||||
info := []string{}
|
||||
timestamp := msg.CreatedAt.Local().Format("02 Jan 2006 03:04 PM")
|
||||
username, _ := config.GetUsername()
|
||||
info = append(info, baseStyle.
|
||||
Width(width-1).
|
||||
Foreground(t.TextMuted()).
|
||||
Render(fmt.Sprintf(" %s (%s)", username, timestamp)),
|
||||
)
|
||||
|
||||
content := ""
|
||||
if len(styledAttachments) > 0 {
|
||||
attachmentContent := baseStyle.Width(width).Render(lipgloss.JoinHorizontal(lipgloss.Left, styledAttachments...))
|
||||
content = renderMessage(msg.Content().String(), true, isFocused, width, append(info, attachmentContent)...)
|
||||
} else {
|
||||
content = renderMessage(msg.Content().String(), true, isFocused, width, info...)
|
||||
}
|
||||
userMsg := uiMessage{
|
||||
ID: msg.ID,
|
||||
messageType: userMessageType,
|
||||
@@ -112,37 +134,56 @@ func renderAssistantMessage(
|
||||
focusedUIMessageId string,
|
||||
width int,
|
||||
position int,
|
||||
showToolMessages bool,
|
||||
) []uiMessage {
|
||||
messages := []uiMessage{}
|
||||
content := msg.Content().String()
|
||||
content := strings.TrimSpace(msg.Content().String())
|
||||
thinking := msg.IsThinking()
|
||||
thinkingContent := msg.ReasoningContent().Thinking
|
||||
finished := msg.IsFinished()
|
||||
finishData := msg.FinishPart()
|
||||
info := []string{}
|
||||
|
||||
// Add finish info if available
|
||||
t := theme.CurrentTheme()
|
||||
baseStyle := styles.BaseStyle()
|
||||
|
||||
// Always add timestamp info
|
||||
timestamp := msg.CreatedAt.Local().Format("02 Jan 2006 03:04 PM")
|
||||
modelName := "Assistant"
|
||||
if msg.Model != "" {
|
||||
modelName = models.SupportedModels[msg.Model].Name
|
||||
}
|
||||
|
||||
info = append(info, baseStyle.
|
||||
Width(width-1).
|
||||
Foreground(t.TextMuted()).
|
||||
Render(fmt.Sprintf(" %s (%s)", modelName, timestamp)),
|
||||
)
|
||||
|
||||
if finished {
|
||||
// Add finish info if available
|
||||
switch finishData.Reason {
|
||||
case message.FinishReasonEndTurn:
|
||||
took := formatTimeDifference(msg.CreatedAt, finishData.Time)
|
||||
info = append(info, styles.BaseStyle.Width(width-1).Foreground(styles.ForgroundDim).Render(
|
||||
fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, took),
|
||||
))
|
||||
case message.FinishReasonCanceled:
|
||||
info = append(info, styles.BaseStyle.Width(width-1).Foreground(styles.ForgroundDim).Render(
|
||||
fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, "canceled"),
|
||||
))
|
||||
info = append(info, baseStyle.
|
||||
Width(width-1).
|
||||
Foreground(t.Warning()).
|
||||
Render("(canceled)"),
|
||||
)
|
||||
case message.FinishReasonError:
|
||||
info = append(info, styles.BaseStyle.Width(width-1).Foreground(styles.ForgroundDim).Render(
|
||||
fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, "error"),
|
||||
))
|
||||
info = append(info, baseStyle.
|
||||
Width(width-1).
|
||||
Foreground(t.Error()).
|
||||
Render("(error)"),
|
||||
)
|
||||
case message.FinishReasonPermissionDenied:
|
||||
info = append(info, styles.BaseStyle.Width(width-1).Foreground(styles.ForgroundDim).Render(
|
||||
fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, "permission denied"),
|
||||
))
|
||||
info = append(info, baseStyle.
|
||||
Width(width-1).
|
||||
Foreground(t.Info()).
|
||||
Render("(permission denied)"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if content != "" || (finished && finishData.Reason == message.FinishReasonEndTurn) {
|
||||
if content == "" {
|
||||
content = "*Finished without output*"
|
||||
@@ -159,23 +200,35 @@ func renderAssistantMessage(
|
||||
position += messages[0].height
|
||||
position++ // for the space
|
||||
} else if thinking && thinkingContent != "" {
|
||||
// Render the thinking content
|
||||
content = renderMessage(thinkingContent, false, msg.ID == focusedUIMessageId, width)
|
||||
// Render the thinking content with timestamp
|
||||
content = renderMessage(thinkingContent, false, msg.ID == focusedUIMessageId, width, info...)
|
||||
messages = append(messages, uiMessage{
|
||||
ID: msg.ID,
|
||||
messageType: assistantMessageType,
|
||||
position: position,
|
||||
height: lipgloss.Height(content),
|
||||
content: content,
|
||||
})
|
||||
position += lipgloss.Height(content)
|
||||
position++ // for the space
|
||||
}
|
||||
|
||||
for i, toolCall := range msg.ToolCalls() {
|
||||
toolCallContent := renderToolMessage(
|
||||
toolCall,
|
||||
allMessages,
|
||||
messagesService,
|
||||
focusedUIMessageId,
|
||||
false,
|
||||
width,
|
||||
i+1,
|
||||
)
|
||||
messages = append(messages, toolCallContent)
|
||||
position += toolCallContent.height
|
||||
position++ // for the space
|
||||
// Only render tool messages if they should be shown
|
||||
if showToolMessages {
|
||||
for i, toolCall := range msg.ToolCalls() {
|
||||
toolCallContent := renderToolMessage(
|
||||
toolCall,
|
||||
allMessages,
|
||||
messagesService,
|
||||
focusedUIMessageId,
|
||||
false,
|
||||
width,
|
||||
i+1,
|
||||
)
|
||||
messages = append(messages, toolCallContent)
|
||||
position += toolCallContent.height
|
||||
position++ // for the space
|
||||
}
|
||||
}
|
||||
return messages
|
||||
}
|
||||
@@ -207,8 +260,6 @@ func toolName(name string) string {
|
||||
return "Grep"
|
||||
case tools.LSToolName:
|
||||
return "List"
|
||||
case tools.SourcegraphToolName:
|
||||
return "Sourcegraph"
|
||||
case tools.ViewToolName:
|
||||
return "View"
|
||||
case tools.WriteToolName:
|
||||
@@ -235,8 +286,6 @@ func getToolAction(name string) string {
|
||||
return "Searching content..."
|
||||
case tools.LSToolName:
|
||||
return "Listing directory..."
|
||||
case tools.SourcegraphToolName:
|
||||
return "Searching code..."
|
||||
case tools.ViewToolName:
|
||||
return "Reading file..."
|
||||
case tools.WriteToolName:
|
||||
@@ -375,10 +424,6 @@ func renderToolParams(paramWidth int, toolCall message.ToolCall) string {
|
||||
path = "."
|
||||
}
|
||||
return renderParams(paramWidth, path)
|
||||
case tools.SourcegraphToolName:
|
||||
var params tools.SourcegraphParams
|
||||
json.Unmarshal([]byte(toolCall.Input), ¶ms)
|
||||
return renderParams(paramWidth, params.Query)
|
||||
case tools.ViewToolName:
|
||||
var params tools.ViewParams
|
||||
json.Unmarshal([]byte(toolCall.Input), ¶ms)
|
||||
@@ -414,32 +459,35 @@ func truncateHeight(content string, height int) string {
|
||||
}
|
||||
|
||||
func renderToolResponse(toolCall message.ToolCall, response message.ToolResult, width int) string {
|
||||
t := theme.CurrentTheme()
|
||||
baseStyle := styles.BaseStyle()
|
||||
|
||||
if response.IsError {
|
||||
errContent := fmt.Sprintf("Error: %s", strings.ReplaceAll(response.Content, "\n", " "))
|
||||
errContent = ansi.Truncate(errContent, width-1, "...")
|
||||
return styles.BaseStyle.
|
||||
return baseStyle.
|
||||
Width(width).
|
||||
Foreground(styles.Error).
|
||||
Foreground(t.Error()).
|
||||
Render(errContent)
|
||||
}
|
||||
|
||||
resultContent := truncateHeight(response.Content, maxResultHeight)
|
||||
switch toolCall.Name {
|
||||
case agent.AgentToolName:
|
||||
return styles.ForceReplaceBackgroundWithLipgloss(
|
||||
toMarkdown(resultContent, false, width),
|
||||
styles.Background,
|
||||
t.Background(),
|
||||
)
|
||||
case tools.BashToolName:
|
||||
resultContent = fmt.Sprintf("```bash\n%s\n```", resultContent)
|
||||
return styles.ForceReplaceBackgroundWithLipgloss(
|
||||
toMarkdown(resultContent, true, width),
|
||||
styles.Background,
|
||||
t.Background(),
|
||||
)
|
||||
case tools.EditToolName:
|
||||
metadata := tools.EditResponseMetadata{}
|
||||
json.Unmarshal([]byte(response.Metadata), &metadata)
|
||||
truncDiff := truncateHeight(metadata.Diff, maxResultHeight)
|
||||
formattedDiff, _ := diff.FormatDiff(truncDiff, diff.WithTotalWidth(width), diff.WithStyle(diffStyle))
|
||||
formattedDiff, _ := diff.FormatDiff(metadata.Diff, diff.WithTotalWidth(width))
|
||||
return formattedDiff
|
||||
case tools.FetchToolName:
|
||||
var params tools.FetchParams
|
||||
@@ -454,16 +502,14 @@ func renderToolResponse(toolCall message.ToolCall, response message.ToolResult,
|
||||
resultContent = fmt.Sprintf("```%s\n%s\n```", mdFormat, resultContent)
|
||||
return styles.ForceReplaceBackgroundWithLipgloss(
|
||||
toMarkdown(resultContent, true, width),
|
||||
styles.Background,
|
||||
t.Background(),
|
||||
)
|
||||
case tools.GlobToolName:
|
||||
return styles.BaseStyle.Width(width).Foreground(styles.ForgroundMid).Render(resultContent)
|
||||
return baseStyle.Width(width).Foreground(t.TextMuted()).Render(resultContent)
|
||||
case tools.GrepToolName:
|
||||
return styles.BaseStyle.Width(width).Foreground(styles.ForgroundMid).Render(resultContent)
|
||||
return baseStyle.Width(width).Foreground(t.TextMuted()).Render(resultContent)
|
||||
case tools.LSToolName:
|
||||
return styles.BaseStyle.Width(width).Foreground(styles.ForgroundMid).Render(resultContent)
|
||||
case tools.SourcegraphToolName:
|
||||
return styles.BaseStyle.Width(width).Foreground(styles.ForgroundMid).Render(resultContent)
|
||||
return baseStyle.Width(width).Foreground(t.TextMuted()).Render(resultContent)
|
||||
case tools.ViewToolName:
|
||||
metadata := tools.ViewResponseMetadata{}
|
||||
json.Unmarshal([]byte(response.Metadata), &metadata)
|
||||
@@ -476,7 +522,7 @@ func renderToolResponse(toolCall message.ToolCall, response message.ToolResult,
|
||||
resultContent = fmt.Sprintf("```%s\n%s\n```", ext, truncateHeight(metadata.Content, maxResultHeight))
|
||||
return styles.ForceReplaceBackgroundWithLipgloss(
|
||||
toMarkdown(resultContent, true, width),
|
||||
styles.Background,
|
||||
t.Background(),
|
||||
)
|
||||
case tools.WriteToolName:
|
||||
params := tools.WriteParams{}
|
||||
@@ -492,13 +538,13 @@ func renderToolResponse(toolCall message.ToolCall, response message.ToolResult,
|
||||
resultContent = fmt.Sprintf("```%s\n%s\n```", ext, truncateHeight(params.Content, maxResultHeight))
|
||||
return styles.ForceReplaceBackgroundWithLipgloss(
|
||||
toMarkdown(resultContent, true, width),
|
||||
styles.Background,
|
||||
t.Background(),
|
||||
)
|
||||
default:
|
||||
resultContent = fmt.Sprintf("```text\n%s\n```", resultContent)
|
||||
return styles.ForceReplaceBackgroundWithLipgloss(
|
||||
toMarkdown(resultContent, true, width),
|
||||
styles.Background,
|
||||
t.Background(),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -515,39 +561,31 @@ func renderToolMessage(
|
||||
if nested {
|
||||
width = width - 3
|
||||
}
|
||||
style := styles.BaseStyle.
|
||||
|
||||
t := theme.CurrentTheme()
|
||||
baseStyle := styles.BaseStyle()
|
||||
|
||||
style := baseStyle.
|
||||
Width(width - 1).
|
||||
BorderLeft(true).
|
||||
BorderStyle(lipgloss.ThickBorder()).
|
||||
PaddingLeft(1).
|
||||
BorderForeground(styles.ForgroundDim)
|
||||
BorderForeground(t.TextMuted())
|
||||
|
||||
response := findToolResponse(toolCall.ID, allMessages)
|
||||
toolName := styles.BaseStyle.Foreground(styles.ForgroundDim).Render(fmt.Sprintf("%s: ", toolName(toolCall.Name)))
|
||||
toolNameText := baseStyle.Foreground(t.TextMuted()).
|
||||
Render(fmt.Sprintf("%s: ", toolName(toolCall.Name)))
|
||||
|
||||
if !toolCall.Finished {
|
||||
// Get a brief description of what the tool is doing
|
||||
toolAction := getToolAction(toolCall.Name)
|
||||
|
||||
// toolInput := strings.ReplaceAll(toolCall.Input, "\n", " ")
|
||||
// truncatedInput := toolInput
|
||||
// if len(truncatedInput) > 10 {
|
||||
// truncatedInput = truncatedInput[len(truncatedInput)-10:]
|
||||
// }
|
||||
//
|
||||
// truncatedInput = styles.BaseStyle.
|
||||
// Italic(true).
|
||||
// Width(width - 2 - lipgloss.Width(toolName)).
|
||||
// Background(styles.BackgroundDim).
|
||||
// Foreground(styles.ForgroundMid).
|
||||
// Render(truncatedInput)
|
||||
|
||||
progressText := styles.BaseStyle.
|
||||
Width(width - 2 - lipgloss.Width(toolName)).
|
||||
Foreground(styles.ForgroundDim).
|
||||
progressText := baseStyle.
|
||||
Width(width - 2 - lipgloss.Width(toolNameText)).
|
||||
Foreground(t.TextMuted()).
|
||||
Render(fmt.Sprintf("%s", toolAction))
|
||||
|
||||
content := style.Render(lipgloss.JoinHorizontal(lipgloss.Left, toolName, progressText))
|
||||
content := style.Render(lipgloss.JoinHorizontal(lipgloss.Left, toolNameText, progressText))
|
||||
toolMsg := uiMessage{
|
||||
messageType: toolMessageType,
|
||||
position: position,
|
||||
@@ -556,37 +594,39 @@ func renderToolMessage(
|
||||
}
|
||||
return toolMsg
|
||||
}
|
||||
params := renderToolParams(width-2-lipgloss.Width(toolName), toolCall)
|
||||
|
||||
params := renderToolParams(width-1-lipgloss.Width(toolNameText), toolCall)
|
||||
responseContent := ""
|
||||
if response != nil {
|
||||
responseContent = renderToolResponse(toolCall, *response, width-2)
|
||||
responseContent = strings.TrimSuffix(responseContent, "\n")
|
||||
} else {
|
||||
responseContent = styles.BaseStyle.
|
||||
responseContent = baseStyle.
|
||||
Italic(true).
|
||||
Width(width - 2).
|
||||
Foreground(styles.ForgroundDim).
|
||||
Foreground(t.TextMuted()).
|
||||
Render("Waiting for response...")
|
||||
}
|
||||
|
||||
parts := []string{}
|
||||
if !nested {
|
||||
params := styles.BaseStyle.
|
||||
Width(width - 2 - lipgloss.Width(toolName)).
|
||||
Foreground(styles.ForgroundDim).
|
||||
formattedParams := baseStyle.
|
||||
Width(width - 2 - lipgloss.Width(toolNameText)).
|
||||
Foreground(t.TextMuted()).
|
||||
Render(params)
|
||||
|
||||
parts = append(parts, lipgloss.JoinHorizontal(lipgloss.Left, toolName, params))
|
||||
parts = append(parts, lipgloss.JoinHorizontal(lipgloss.Left, toolNameText, formattedParams))
|
||||
} else {
|
||||
prefix := styles.BaseStyle.
|
||||
Foreground(styles.ForgroundDim).
|
||||
prefix := baseStyle.
|
||||
Foreground(t.TextMuted()).
|
||||
Render(" └ ")
|
||||
params := styles.BaseStyle.
|
||||
Width(width - 2 - lipgloss.Width(toolName)).
|
||||
Foreground(styles.ForgroundMid).
|
||||
formattedParams := baseStyle.
|
||||
Width(width - 2 - lipgloss.Width(toolNameText)).
|
||||
Foreground(t.TextMuted()).
|
||||
Render(params)
|
||||
parts = append(parts, lipgloss.JoinHorizontal(lipgloss.Left, prefix, toolName, params))
|
||||
parts = append(parts, lipgloss.JoinHorizontal(lipgloss.Left, prefix, toolNameText, formattedParams))
|
||||
}
|
||||
|
||||
if toolCall.Name == agent.AgentToolName {
|
||||
taskMessages, _ := messagesService.List(context.Background(), toolCall.ID)
|
||||
toolCalls := []message.ToolCall{}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user