mirror of
https://github.com/anomalyco/opencode.git
synced 2026-05-03 11:16:46 +00:00
replace github.com/google/generative-ai-go with github.com/googleapis/go-genai (#138)
* replace to github.com/googleapis/go-genai * fix history logic * small fixes --------- Co-authored-by: Kujtim Hoxha <kujtimii.h@gmail.com>
This commit is contained in:
@@ -9,14 +9,12 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/generative-ai-go/genai"
|
||||
"github.com/google/uuid"
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/llm/tools"
|
||||
"github.com/opencode-ai/opencode/internal/logging"
|
||||
"github.com/opencode-ai/opencode/internal/message"
|
||||
"google.golang.org/api/iterator"
|
||||
"google.golang.org/api/option"
|
||||
"google.golang.org/genai"
|
||||
)
|
||||
|
||||
type geminiOptions struct {
|
||||
@@ -39,7 +37,7 @@ func newGeminiClient(opts providerClientOptions) GeminiClient {
|
||||
o(&geminiOpts)
|
||||
}
|
||||
|
||||
client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey))
|
||||
client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
|
||||
if err != nil {
|
||||
logging.Error("Failed to create Gemini client", "error", err)
|
||||
return nil
|
||||
@@ -57,11 +55,14 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case message.User:
|
||||
var parts []genai.Part
|
||||
parts = append(parts, genai.Text(msg.Content().String()))
|
||||
var parts []*genai.Part
|
||||
parts = append(parts, &genai.Part{Text: msg.Content().String()})
|
||||
for _, binaryContent := range msg.BinaryContent() {
|
||||
imageFormat := strings.Split(binaryContent.MIMEType, "/")
|
||||
parts = append(parts, genai.ImageData(imageFormat[1], binaryContent.Data))
|
||||
parts = append(parts, &genai.Part{InlineData: &genai.Blob{
|
||||
MIMEType: imageFormat[1],
|
||||
Data: binaryContent.Data,
|
||||
}})
|
||||
}
|
||||
history = append(history, &genai.Content{
|
||||
Parts: parts,
|
||||
@@ -70,19 +71,21 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont
|
||||
case message.Assistant:
|
||||
content := &genai.Content{
|
||||
Role: "model",
|
||||
Parts: []genai.Part{},
|
||||
Parts: []*genai.Part{},
|
||||
}
|
||||
|
||||
if msg.Content().String() != "" {
|
||||
content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
|
||||
content.Parts = append(content.Parts, &genai.Part{Text: msg.Content().String()})
|
||||
}
|
||||
|
||||
if len(msg.ToolCalls()) > 0 {
|
||||
for _, call := range msg.ToolCalls() {
|
||||
args, _ := parseJsonToMap(call.Input)
|
||||
content.Parts = append(content.Parts, genai.FunctionCall{
|
||||
Name: call.Name,
|
||||
Args: args,
|
||||
content.Parts = append(content.Parts, &genai.Part{
|
||||
FunctionCall: &genai.FunctionCall{
|
||||
Name: call.Name,
|
||||
Args: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -110,10 +113,14 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont
|
||||
}
|
||||
|
||||
history = append(history, &genai.Content{
|
||||
Parts: []genai.Part{genai.FunctionResponse{
|
||||
Name: toolCall.Name,
|
||||
Response: response,
|
||||
}},
|
||||
Parts: []*genai.Part{
|
||||
{
|
||||
FunctionResponse: &genai.FunctionResponse{
|
||||
Name: toolCall.Name,
|
||||
Response: response,
|
||||
},
|
||||
},
|
||||
},
|
||||
Role: "function",
|
||||
})
|
||||
}
|
||||
@@ -157,18 +164,6 @@ func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishRea
|
||||
}
|
||||
|
||||
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
|
||||
model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
|
||||
model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
|
||||
model.SystemInstruction = &genai.Content{
|
||||
Parts: []genai.Part{
|
||||
genai.Text(g.providerOptions.systemMessage),
|
||||
},
|
||||
}
|
||||
// Convert tools
|
||||
if len(tools) > 0 {
|
||||
model.Tools = g.convertTools(tools)
|
||||
}
|
||||
|
||||
// Convert messages
|
||||
geminiMessages := g.convertMessages(messages)
|
||||
|
||||
@@ -178,16 +173,26 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
|
||||
logging.Debug("Prepared messages", "messages", string(jsonData))
|
||||
}
|
||||
|
||||
history := geminiMessages[:len(geminiMessages)-1] // All but last message
|
||||
lastMsg := geminiMessages[len(geminiMessages)-1]
|
||||
chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{
|
||||
MaxOutputTokens: int32(g.providerOptions.maxTokens),
|
||||
SystemInstruction: &genai.Content{
|
||||
Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
|
||||
},
|
||||
Tools: g.convertTools(tools),
|
||||
}, history)
|
||||
|
||||
attempts := 0
|
||||
for {
|
||||
attempts++
|
||||
var toolCalls []message.ToolCall
|
||||
chat := model.StartChat()
|
||||
chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
|
||||
|
||||
lastMsg := geminiMessages[len(geminiMessages)-1]
|
||||
|
||||
resp, err := chat.SendMessage(ctx, lastMsg.Parts...)
|
||||
var lastMsgParts []genai.Part
|
||||
for _, part := range lastMsg.Parts {
|
||||
lastMsgParts = append(lastMsgParts, *part)
|
||||
}
|
||||
resp, err := chat.SendMessage(ctx, lastMsgParts...)
|
||||
// If there is an error we are going to see if we can retry the call
|
||||
if err != nil {
|
||||
retry, after, retryErr := g.shouldRetry(attempts, err)
|
||||
@@ -210,15 +215,15 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
|
||||
|
||||
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
|
||||
for _, part := range resp.Candidates[0].Content.Parts {
|
||||
switch p := part.(type) {
|
||||
case genai.Text:
|
||||
content = string(p)
|
||||
case genai.FunctionCall:
|
||||
switch {
|
||||
case part.Text != "":
|
||||
content = string(part.Text)
|
||||
case part.FunctionCall != nil:
|
||||
id := "call_" + uuid.New().String()
|
||||
args, _ := json.Marshal(p.Args)
|
||||
args, _ := json.Marshal(part.FunctionCall.Args)
|
||||
toolCalls = append(toolCalls, message.ToolCall{
|
||||
ID: id,
|
||||
Name: p.Name,
|
||||
Name: part.FunctionCall.Name,
|
||||
Input: string(args),
|
||||
Type: "function",
|
||||
Finished: true,
|
||||
@@ -244,18 +249,6 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
|
||||
}
|
||||
|
||||
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
|
||||
model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
|
||||
model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
|
||||
model.SystemInstruction = &genai.Content{
|
||||
Parts: []genai.Part{
|
||||
genai.Text(g.providerOptions.systemMessage),
|
||||
},
|
||||
}
|
||||
// Convert tools
|
||||
if len(tools) > 0 {
|
||||
model.Tools = g.convertTools(tools)
|
||||
}
|
||||
|
||||
// Convert messages
|
||||
geminiMessages := g.convertMessages(messages)
|
||||
|
||||
@@ -265,6 +258,16 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
|
||||
logging.Debug("Prepared messages", "messages", string(jsonData))
|
||||
}
|
||||
|
||||
history := geminiMessages[:len(geminiMessages)-1] // All but last message
|
||||
lastMsg := geminiMessages[len(geminiMessages)-1]
|
||||
chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{
|
||||
MaxOutputTokens: int32(g.providerOptions.maxTokens),
|
||||
SystemInstruction: &genai.Content{
|
||||
Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
|
||||
},
|
||||
Tools: g.convertTools(tools),
|
||||
}, history)
|
||||
|
||||
attempts := 0
|
||||
eventChan := make(chan ProviderEvent)
|
||||
|
||||
@@ -273,11 +276,6 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
|
||||
|
||||
for {
|
||||
attempts++
|
||||
chat := model.StartChat()
|
||||
chat.History = geminiMessages[:len(geminiMessages)-1]
|
||||
lastMsg := geminiMessages[len(geminiMessages)-1]
|
||||
|
||||
iter := chat.SendMessageStream(ctx, lastMsg.Parts...)
|
||||
|
||||
currentContent := ""
|
||||
toolCalls := []message.ToolCall{}
|
||||
@@ -285,11 +283,12 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
|
||||
|
||||
eventChan <- ProviderEvent{Type: EventContentStart}
|
||||
|
||||
for {
|
||||
resp, err := iter.Next()
|
||||
if err == iterator.Done {
|
||||
break
|
||||
}
|
||||
var lastMsgParts []genai.Part
|
||||
|
||||
for _, part := range lastMsg.Parts {
|
||||
lastMsgParts = append(lastMsgParts, *part)
|
||||
}
|
||||
for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
|
||||
if err != nil {
|
||||
retry, after, retryErr := g.shouldRetry(attempts, err)
|
||||
if retryErr != nil {
|
||||
@@ -318,9 +317,9 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
|
||||
|
||||
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
|
||||
for _, part := range resp.Candidates[0].Content.Parts {
|
||||
switch p := part.(type) {
|
||||
case genai.Text:
|
||||
delta := string(p)
|
||||
switch {
|
||||
case part.Text != "":
|
||||
delta := string(part.Text)
|
||||
if delta != "" {
|
||||
eventChan <- ProviderEvent{
|
||||
Type: EventContentDelta,
|
||||
@@ -328,12 +327,12 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
|
||||
}
|
||||
currentContent += delta
|
||||
}
|
||||
case genai.FunctionCall:
|
||||
case part.FunctionCall != nil:
|
||||
id := "call_" + uuid.New().String()
|
||||
args, _ := json.Marshal(p.Args)
|
||||
args, _ := json.Marshal(part.FunctionCall.Args)
|
||||
newCall := message.ToolCall{
|
||||
ID: id,
|
||||
Name: p.Name,
|
||||
Name: part.FunctionCall.Name,
|
||||
Input: string(args),
|
||||
Type: "function",
|
||||
Finished: true,
|
||||
@@ -421,12 +420,12 @@ func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.
|
||||
|
||||
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
|
||||
for _, part := range resp.Candidates[0].Content.Parts {
|
||||
if funcCall, ok := part.(genai.FunctionCall); ok {
|
||||
if part.FunctionCall != nil {
|
||||
id := "call_" + uuid.New().String()
|
||||
args, _ := json.Marshal(funcCall.Args)
|
||||
args, _ := json.Marshal(part.FunctionCall.Args)
|
||||
toolCalls = append(toolCalls, message.ToolCall{
|
||||
ID: id,
|
||||
Name: funcCall.Name,
|
||||
Name: part.FunctionCall.Name,
|
||||
Input: string(args),
|
||||
Type: "function",
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user