mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 12:51:55 +00:00
Cody Gateway: Add Gemini models to PLG and Enterprise users (#63053)
CLOSE https://github.com/sourcegraph/cody-issues/issues/211 & https://github.com/sourcegraph/cody-issues/issues/412
UNBLOCK https://github.com/sourcegraph/cody/pull/4360

* Add support for Google Gemini AI models as a chat completions provider
* Add a new `google` package to handle the Google Generative AI client
* Update `client.go` and `codygateway.go` to handle the new Google provider
* Set default models for chat, fast chat, and completions when Google is the configured provider
* Add gemini-pro to the allowed list

## Test plan

For Enterprise instances using Google as the provider:

1. In your local Sourcegraph instance's site config, add the following:

```
"accessToken": "REDACTED",
"chatModel": "gemini-1.5-pro-latest",
"provider": "google",
```

Note: You can get the access token for the Gemini API in 1Password.

2. After saving the site config with the above change, run the following curl command (double quotes around the header so that `$LOCAL_INSTANCE_TOKEN` expands):

```
curl 'https://sourcegraph.test:3443/.api/completions/stream' -i \
  -X POST \
  -H "authorization: token $LOCAL_INSTANCE_TOKEN" \
  --data-raw '{"messages":[{"speaker":"human","text":"Who are you?"}],"maxTokensToSample":30,"temperature":0,"stopSequences":[],"timeoutMs":5000,"stream":true,"model":"gemini-1.5-pro-latest"}'
```

3. Expected output:

```
❯ curl 'https://sourcegraph.test:3443/.api/completions/stream' -i \
    -X POST \
    -H 'authorization: token <REDACTED>' \
    --data-raw '{"messages":[{"speaker":"human","text":"Who are you?"}],"maxTokensToSample":30,"temperature":0,"stopSequences":[],"timeoutMs":5000,"stream":true,"model":"gemini-1.5-pro-latest"}'
HTTP/2 200
access-control-allow-credentials: true
access-control-allow-origin:
alt-svc: h3=":3443"; ma=2592000
cache-control: no-cache
content-type: text/event-stream
date: Tue, 04 Jun 2024 05:45:33 GMT
server: Caddy
server: Caddy
vary: Accept-Encoding, Authorization, Cookie, Authorization, X-Requested-With, Cookie
x-accel-buffering: no
x-content-type-options: nosniff
x-frame-options: DENY
x-powered-by: Express
x-trace: d4b1f02a3e2882a3d52331335d217b03
x-trace-span: 728ec33860d3b5e6
x-trace-url: https://sourcegraph.test:3443/-/debug/jaeger/trace/d4b1f02a3e2882a3d52331335d217b03
x-xss-protection: 1; mode=block

event: completion
data: {"completion":"I","stopReason":"STOP"}

event: completion
data: {"completion":"I am a large language model, trained by Google. \n\nThink of me as","stopReason":"STOP"}

event: completion
data: {"completion":"I am a large language model, trained by Google. \n\nThink of me as a computer program that can understand and generate human-like text.","stopReason":"MAX_TOKENS"}

event: done
data: {}
```

Verified locally.

#### Before

Cody Gateway returns `no client known for upstream provider google`:

```sh
curl -X 'POST' -d '{"messages":[{"speaker":"human","text":"Who are you?"}],"maxTokensToSample":30,"temperature":0,"stopSequences":[],"timeoutMs":5000,"stream":true,"model":"google/gemini-1.5-pro-latest"}' -H 'Accept: application/json' -H "Authorization: token $YOUR_DOTCOM_TOKEN" -H 'Content-Type: application/json' 'https://sourcegraph.com/.api/completions/stream'

event: error
data: {"error":"no client known for upstream provider google"}

event: done
data: {
```

## Changelog

Added support for Google as an LLM provider for Cody, with the following models available through Cody Gateway: Gemini Pro (`gemini-pro-latest`), Gemini 1.5 Flash (`gemini-1.5-flash-latest`), and Gemini 1.5 Pro (`gemini-1.5-pro-latest`).
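For reference, the same streaming request can be driven from Go. This is a minimal sketch, not part of the change: it assumes the instance URL and token from the test plan above (the token exported as `LOCAL_INSTANCE_TOKEN`) and simply prints the raw SSE lines. Against a local instance with a self-signed certificate you would additionally need to configure TLS on the client.

```go
package main

import (
	"bufio"
	"fmt"
	"net/http"
	"os"
	"strings"
)

func main() {
	// Same request body as the curl command in the test plan above.
	body := `{"messages":[{"speaker":"human","text":"Who are you?"}],` +
		`"maxTokensToSample":30,"temperature":0,"stopSequences":[],` +
		`"timeoutMs":5000,"stream":true,"model":"gemini-1.5-pro-latest"}`

	req, err := http.NewRequest("POST", "https://sourcegraph.test:3443/.api/completions/stream", strings.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Authorization", "token "+os.Getenv("LOCAL_INSTANCE_TOKEN"))
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Print each SSE line; "event: completion" lines carry incremental text,
	// and "event: done" terminates the stream.
	sc := bufio.NewScanner(resp.Body)
	for sc.Scan() {
		fmt.Println(sc.Text())
	}
}
```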
This commit is contained in:
parent
f952ceb8da
commit
f2590cbb36
@@ -22,6 +22,7 @@ All notable changes to Sourcegraph are documented in this file.

 - A feature flag for Cody, `completions.smartContextWindow`, is added and set to "enabled" by default. It allows clients to adjust the context window based on the name of the chat model. When smartContextWindow is enabled, the `completions.chatModelMaxTokens` value is ignored. ([#62802](https://github.com/sourcegraph/sourcegraph/pull/62802))
 - Code Insights: When facing the "incomplete datapoints" warning, you can now use GraphQL to discover which repositories had problems. The schemas for `TimeoutDatapointAlert` and `GenericIncompleteDatapointAlert` now contain an additional `repositories` field. ([#62756](https://github.com/sourcegraph/sourcegraph/pull/62756))
 - Users will now be presented with a modal that reminds them to connect any external code host accounts that are required for permissions. Without these accounts connected, users may be unable to view repositories that they otherwise have access to. [#62983](https://github.com/sourcegraph/sourcegraph/pull/62983)
+- Added support for Google as an LLM provider for Cody, with the following models available through Cody Gateway: Gemini Pro (`gemini-pro-latest`), Gemini 1.5 Flash (`gemini-1.5-flash-latest`), and Gemini 1.5 Pro (`gemini-1.5-pro-latest`). [#63053](https://github.com/sourcegraph/sourcegraph/pull/63053)

 ### Changed
@@ -1,4 +1,4 @@
-import React from 'react'
+import type React from 'react'

 import { Badge } from '@sourcegraph/wildcard'
@@ -51,8 +51,11 @@ function modelBadgeVariant(model: string, mode: 'completions' | 'embeddings'): '
        case 'openai/gpt-4o':
        case 'openai/gpt-4-turbo':
        case 'openai/gpt-4-turbo-preview':
        // For currently available Google Gemini models,
        // see: https://ai.google.dev/gemini-api/docs/models/gemini
        case 'google/gemini-1.5-flash-latest':
        case 'google/gemini-1.5-pro-latest':
        case 'google/gemini-pro-latest':
        // Virtual models that are translated by Cody Gateway and allow access to all StarCoder
        // models hosted for us by Fireworks.
        case 'fireworks/starcoder':
@@ -421,7 +421,7 @@ export const SiteAdminPingsPage: React.FunctionComponent<React.PropsWithChildren
             <ul>
                 <li>
                     Provider (e.g., "sourcegraph", "anthropic", "openai", "azure-openai",
-                    "fireworks", "aws-bedrock", etc.)
+                    "fireworks", "aws-bedrock", "google", etc.)
                 </li>
                 <li>Chat model (included only for "sourcegraph" provider)</li>
                 <li>Fast chat model (included only for "sourcegraph" provider)</li>
@@ -38,6 +38,15 @@ func NewGoogleHandler(baseLogger log.Logger, eventLogger events.Logger, rs limit
 	)
 }

+// The request body for Google completions.
+// Ref: https://ai.google.dev/api/rest/v1/models/generateContent#request-body
+type googleRequest struct {
+	Model            string                 `json:"model"`
+	Contents         []googleContentMessage `json:"contents"`
+	GenerationConfig googleGenerationConfig `json:"generationConfig,omitempty"`
+	SafetySettings   []googleSafetySettings `json:"safetySettings,omitempty"`
+}
+
 type googleContentMessage struct {
 	Role  string `json:"role"`
 	Parts []struct {
@@ -45,12 +54,22 @@ type googleContentMessage struct {
 	} `json:"parts"`
 }

-type googleRequest struct {
-	Model    string                 `json:"model"`
-	Contents []googleContentMessage `json:"contents"`
-	GenerationConfig struct {
-		MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
-	} `json:"generationConfig,omitempty"`
-}
+// Configuration options for model generation and outputs.
+// Ref: https://ai.google.dev/api/rest/v1/GenerationConfig
+type googleGenerationConfig struct {
+	Temperature     float32  `json:"temperature,omitempty"`     // request.Temperature
+	TopP            float32  `json:"topP,omitempty"`            // request.TopP
+	TopK            int      `json:"topK,omitempty"`            // request.TopK
+	StopSequences   []string `json:"stopSequences,omitempty"`   // request.StopSequences
+	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"` // request.MaxTokensToSample
+	CandidateCount  int      `json:"candidateCount,omitempty"`  // request.CandidateCount
+}
+
+// Safety setting, affecting the safety-blocking behavior.
+// Ref: https://ai.google.dev/gemini-api/docs/safety-settings
+type googleSafetySettings struct {
+	Category  string `json:"category"`
+	Threshold string `json:"threshold"`
+}

 func (r googleRequest) ShouldStream() bool {
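The JSON these types produce is the generateContent request body. A standalone sketch, with the structs re-declared locally and trimmed to the fields used, of how a payload serializes:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local, trimmed-down copies of the types in the diff above.
type part struct {
	Text string `json:"text"`
}

type content struct {
	Role  string `json:"role"`
	Parts []part `json:"parts"`
}

type generationConfig struct {
	Temperature     float32 `json:"temperature,omitempty"`
	MaxOutputTokens int     `json:"maxOutputTokens,omitempty"`
}

type request struct {
	Contents         []content        `json:"contents"`
	GenerationConfig generationConfig `json:"generationConfig,omitempty"`
}

func main() {
	r := request{
		Contents: []content{
			{Role: "user", Parts: []part{{Text: "Who are you?"}}},
		},
		GenerationConfig: generationConfig{MaxOutputTokens: 30},
	}
	b, _ := json.MarshalIndent(r, "", "  ")
	// Prints the JSON body for POST .../models/<model>:generateContent.
	fmt.Println(string(b))
}
```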
@@ -152,11 +171,8 @@ func (*GoogleHandlerMethods) parseResponseAndUsage(logger log.Logger, reqBody go
 	}

 	// Otherwise, we have to parse the event stream.
-	promptUsage.tokens = -1
-	promptUsage.tokenizerTokens = -1
-	completionUsage.tokens = -1
-	completionUsage.tokenizerTokens = -1
+	promptUsage.tokens, completionUsage.tokens = -1, -1
+	promptUsage.tokenizerTokens, completionUsage.tokenizerTokens = -1, -1
 	promptTokens, completionTokens, err := parseGoogleTokenUsage(r, logger)
 	if err != nil {
 		logger.Error("failed to decode Google streaming response", log.Error(err))
@@ -173,35 +189,22 @@ const maxPayloadSize = 10 * 1024 * 1024 // 10mb

 func parseGoogleTokenUsage(r io.Reader, logger log.Logger) (promptTokens int, completionTokens int, err error) {
 	scanner := bufio.NewScanner(r)
-	scanner.Buffer(make([]byte, 0, 4096), maxPayloadSize)
-	// bufio.ScanLines, except we look for \r\n\r\n which separate events.
-	split := func(data []byte, atEOF bool) (int, []byte, error) {
-		if atEOF && len(data) == 0 {
-			return 0, nil, nil
-		}
-		if i := bytes.Index(data, []byte("\r\n\r\n")); i >= 0 {
-			return i + 4, data[:i], nil
-		}
-		if i := bytes.Index(data, []byte("\n\n")); i >= 0 {
-			return i + 2, data[:i], nil
-		}
-		// If we're at EOF, we have a final, non-terminated event. This should
-		// be empty.
-		if atEOF {
-			return len(data), data, nil
-		}
-		// Request more data.
-		return 0, nil, nil
-	}
-	scanner.Split(split)
-	// skip to the last event
-	var lastEvent []byte
+	scanner.Split(bufio.ScanLines)
+
+	var lastLine []byte
 	for scanner.Scan() {
-		lastEvent = scanner.Bytes()
+		lastLine = scanner.Bytes()
 	}
-	var res googleResponse
-	if err := json.NewDecoder(bytes.NewReader(lastEvent[5:])).Decode(&res); err != nil {
-		logger.Error("failed to parse Google response as JSON", log.Error(err))
-		return -1, -1, err
-	}
-	return res.UsageMetadata.PromptTokenCount, res.UsageMetadata.CompletionTokenCount, nil
+
+	if bytes.HasPrefix(bytes.TrimSpace(lastLine), []byte("data: ")) {
+		event := lastLine[5:]
+		var res googleResponse
+		if err := json.NewDecoder(bytes.NewReader(event)).Decode(&res); err != nil {
+			logger.Error("failed to parse Google response as JSON", log.Error(err))
+			return -1, -1, err
+		}
+		return res.UsageMetadata.PromptTokenCount, res.UsageMetadata.CompletionTokenCount, nil
+	}
+
+	return -1, -1, errors.New("no Google response found")
 }
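Because Google reports cumulative usage on every streamed event, only the last `data:` line matters, which is what the rewritten parseGoogleTokenUsage above relies on. A self-contained sketch of that idea:

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// usage mirrors the usageMetadata object in Google streaming responses.
type usage struct {
	PromptTokenCount     int `json:"promptTokenCount"`
	CompletionTokenCount int `json:"candidatesTokenCount"`
}

func main() {
	// Each event carries the running totals; the final event has the
	// authoritative counts.
	stream := `data: {"usageMetadata": {"promptTokenCount": 21, "candidatesTokenCount": 63}}
data: {"usageMetadata": {"promptTokenCount": 21, "candidatesTokenCount": 87}}`

	lines := strings.Split(stream, "\n")
	last := strings.TrimSpace(lines[len(lines)-1])

	var res struct {
		UsageMetadata usage `json:"usageMetadata"`
	}
	if err := json.Unmarshal([]byte(strings.TrimPrefix(last, "data: ")), &res); err != nil {
		panic(err)
	}
	fmt.Println(res.UsageMetadata.PromptTokenCount, res.UsageMetadata.CompletionTokenCount) // 21 87
}
```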
@@ -4,6 +4,8 @@ import (
 	"strings"
 	"testing"

+	"bytes"
+
 	"github.com/sourcegraph/log/logtest"
 	"github.com/stretchr/testify/assert"
 )
@@ -31,3 +33,104 @@ data: {"candidates": [{"content": {"parts": [{"text": "\n for i in range(n-1):\
data: {"candidates": [{"content": {"parts": [{"text": " range(n-i-1):\n if list1[j] \u003e list1[j+1]:\n list1[j], list"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 63,"totalTokenCount": 84}}

data: {"candidates": [{"content": {"parts": [{"text": "1[j+1] = list1[j+1], list1[j]\n return list1\n"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}],"citationMetadata": {"citationSources": [{"startIndex": 1,"endIndex": 185,"uri": "https://github.com/Feng080412/Searches-and-sorts","license": ""}]}}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 87,"totalTokenCount": 108}}`

func TestParseGoogleTokenUsage(t *testing.T) {
	tests := []struct {
		name    string
		input   string
		want    *googleResponse
		wantErr bool
	}{
		{
			name:  "valid response",
			input: `data: {"candidates": [{"content": {"parts": [{"text": "def"}],"role": "model"},"finishReason": "STOP","index": 0}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 1,"totalTokenCount": 22}}`,
			want: &googleResponse{
				UsageMetadata: googleUsage{
					PromptTokenCount:     21,
					CompletionTokenCount: 1,
					TotalTokenCount:      0,
				},
			},
			wantErr: false,
		},
		{
			name:  "valid response - with candidates",
			input: `data: {"usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 20}}`,
			want: &googleResponse{
				UsageMetadata: googleUsage{
					PromptTokenCount:     10,
					CompletionTokenCount: 20,
					TotalTokenCount:      0,
				},
			},
			wantErr: false,
		},
		{
			name:    "invalid JSON",
			input:   `data: {"usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 20}`,
			want:    nil,
			wantErr: true,
		},
		{
			name:    "no prefix",
			input:   `{"usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 20}}`,
			want:    nil,
			wantErr: true,
		},
		{
			name:    "empty input",
			input:   ``,
			want:    nil,
			wantErr: true,
		},
		{
			name: "multiple lines with one valid",
			input: `data: {"usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 15}}

data: {"usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 20}}`,
			want: &googleResponse{
				UsageMetadata: googleUsage{
					PromptTokenCount:     10,
					CompletionTokenCount: 20,
					TotalTokenCount:      0,
				},
			},
			wantErr: false,
		},
		{
			name:    "non-JSON data",
			input:   `data: not-a-json`,
			want:    nil,
			wantErr: true,
		},
		{
			name:    "partial data",
			input:   `data: {"usageMetadata": {"promptTokenCount": 10`,
			want:    nil,
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			r := bytes.NewReader([]byte(tt.input))
			logger := logtest.Scoped(t)
			promptTokens, completionTokens, err := parseGoogleTokenUsage(r, logger)
			if (err != nil) != tt.wantErr {
				t.Errorf("parseGoogleTokenUsage() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if tt.want != nil {
				got := &googleResponse{
					UsageMetadata: googleUsage{
						PromptTokenCount:     promptTokens,
						CompletionTokenCount: completionTokens,
					},
				}
				if !assert.ObjectsAreEqual(got, tt.want) {
					t.Errorf("parseGoogleTokenUsage() mismatch: got %+v, want %+v", got, tt.want)
				}
			}
		})
	}
}
@@ -12,6 +12,7 @@ go_library(
        "//internal/collections",
        "//internal/completions/client/anthropic",
        "//internal/completions/client/fireworks",
+       "//internal/completions/client/google",
        "//internal/env",
        "//internal/trace/policy",
        "//lib/errors",
@@ -12,6 +12,7 @@ import (
 	"github.com/sourcegraph/sourcegraph/internal/collections"
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/anthropic"
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks"
+	"github.com/sourcegraph/sourcegraph/internal/completions/client/google"
 	"github.com/sourcegraph/sourcegraph/internal/env"
 	"github.com/sourcegraph/sourcegraph/internal/trace/policy"
 	"github.com/sourcegraph/sourcegraph/lib/errors"
@@ -282,6 +283,20 @@ func (c *Config) Load() {
 	c.Fireworks.StarcoderCommunitySingleTenantPercent = c.GetPercent("CODY_GATEWAY_FIREWORKS_STARCODER_COMMUNITY_SINGLE_TENANT_PERCENT", "0", "The percentage of community traffic for Starcoder to be redirected to the single-tenant deployment.")
 	c.Fireworks.StarcoderEnterpriseSingleTenantPercent = c.GetPercent("CODY_GATEWAY_FIREWORKS_STARCODER_ENTERPRISE_SINGLE_TENANT_PERCENT", "100", "The percentage of Enterprise traffic for Starcoder to be redirected to the single-tenant deployment.")

+	// Configuration for Google Gemini models.
+	c.Google.AccessToken = c.GetOptional("CODY_GATEWAY_GOOGLE_ACCESS_TOKEN", "The Google AI Studio access token to be used.")
+	c.Google.AllowedModels = splitMaybe(c.Get("CODY_GATEWAY_GOOGLE_ALLOWED_MODELS",
+		strings.Join([]string{
+			google.Gemini15FlashLatest,
+			google.Gemini15ProLatest,
+			google.GeminiProLatest,
+		}, ","),
+		"Google models that can be used."),
+	)
+	if c.Google.AccessToken != "" && len(c.Google.AllowedModels) == 0 {
+		c.AddError(errors.New("must provide allowed models for Google"))
+	}
+
 	c.AllowedEmbeddingsModels = splitMaybe(c.Get("CODY_GATEWAY_ALLOWED_EMBEDDINGS_MODELS", strings.Join([]string{string(embeddings.ModelNameOpenAIAda), string(embeddings.ModelNameSourcegraphSTMultiQA)}, ","), "The models allowed for embeddings generation."))
 	if len(c.AllowedEmbeddingsModels) == 0 {
 		c.AddError(errors.New("must provide allowed models for embeddings generation"))
@@ -325,16 +340,6 @@ func (c *Config) Load() {
 	c.SAMSClientConfig.ClientSecret = c.GetOptional("SAMS_CLIENT_SECRET", "SAMS OAuth client secret")

 	c.Environment = c.Get("CODY_GATEWAY_ENVIRONMENT", "dev", "Environment name.")
-
-	c.Google.AccessToken = c.GetOptional("CODY_GATEWAY_GOOGLE_ACCESS_TOKEN", "The Google AI Studio access token to be used.")
-	c.Google.AllowedModels = splitMaybe(c.Get("CODY_GATEWAY_GOOGLE_ALLOWED_MODELS",
-		strings.Join([]string{
-			"gemini-1.5-pro-latest",
-			"gemini-1.5-flash-latest",
-		}, ","),
-		"Google models that can be used."),
-	)
 }

 // loadFlaggingConfig loads the common set of flagging-related environment variables for
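`splitMaybe` is a gateway-internal helper not shown in this diff; assuming it splits a comma-separated value and yields nothing for an empty string (which the "must provide allowed models" check above implies), its behavior is roughly:

```go
package main

import (
	"fmt"
	"strings"
)

// splitModels approximates what a helper like splitMaybe presumably does:
// split a comma-separated env value, returning nil for an empty string so
// the config-error check above can fire. This is an assumption, not the
// actual implementation.
func splitModels(v string) []string {
	if v == "" {
		return nil
	}
	return strings.Split(v, ",")
}

func main() {
	fmt.Println(splitModels("gemini-1.5-flash-latest,gemini-1.5-pro-latest,gemini-pro-latest"))
	fmt.Println(splitModels("") == nil) // true — triggers the error if a token is set
}
```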
@@ -38,6 +38,7 @@ go_library(
        "//internal/codygateway",
        "//internal/completions/client/anthropic",
        "//internal/completions/client/fireworks",
+       "//internal/completions/client/google",
        "//internal/completions/types",
        "//internal/conf",
        "//internal/conf/conftypes",
@@ -21,6 +21,7 @@ import (
 	"github.com/sourcegraph/sourcegraph/internal/codygateway"
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/anthropic"
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks"
+	"github.com/sourcegraph/sourcegraph/internal/completions/client/google"
 	"github.com/sourcegraph/sourcegraph/internal/completions/types"
 	"github.com/sourcegraph/sourcegraph/internal/conf"
 	"github.com/sourcegraph/sourcegraph/internal/database"
@@ -382,8 +383,9 @@ func allowedModels(scope types.CompletionsFeature, isProUser bool) []string {
 			"openai/gpt-4o",
 			"openai/gpt-4-turbo",
 			"openai/gpt-4-turbo-preview",
-			"google/gemini-1.5-pro-latest",
-			"google/gemini-1.5-flash-latest",
+			"google/" + google.Gemini15FlashLatest,
+			"google/" + google.Gemini15ProLatest,
+			"google/" + google.GeminiProLatest,

 			// Remove after the Claude 3 rollout is complete
 			"anthropic/claude-2",
@@ -22,6 +22,7 @@ go_library(
        "//internal/completions/client",
        "//internal/completions/client/anthropic",
        "//internal/completions/client/fireworks",
+       "//internal/completions/client/google",
        "//internal/completions/types",
        "//internal/conf",
        "//internal/conf/conftypes",
@@ -13,6 +13,7 @@ import (

 	"github.com/sourcegraph/sourcegraph/internal/completions/client/anthropic"
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks"
+	"github.com/sourcegraph/sourcegraph/internal/completions/client/google"
 	"github.com/sourcegraph/sourcegraph/internal/completions/types"
 	"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
 	"github.com/sourcegraph/sourcegraph/internal/database"
@@ -85,8 +86,9 @@ func isAllowedCustomChatModel(model string, isProUser bool) bool {
 		"openai/gpt-4o",
 		"openai/gpt-4-turbo",
 		"openai/gpt-4-turbo-preview",
-		"google/gemini-1.5-flash-latest",
-		"google/gemini-1.5-pro-latest",
+		"google/" + google.Gemini15FlashLatest,
+		"google/" + google.Gemini15ProLatest,
+		"google/" + google.GeminiProLatest,

 		// Remove after the Claude 3 rollout is complete
 		"anthropic/claude-2",
@@ -15,6 +15,7 @@ go_library(
        "//internal/completions/client/azureopenai",
        "//internal/completions/client/codygateway",
        "//internal/completions/client/fireworks",
+       "//internal/completions/client/google",
        "//internal/completions/client/openai",
        "//internal/completions/tokenusage",
        "//internal/completions/types",
@@ -8,6 +8,7 @@ import (
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/azureopenai"
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/codygateway"
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks"
+	"github.com/sourcegraph/sourcegraph/internal/completions/client/google"
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/openai"
 	"github.com/sourcegraph/sourcegraph/internal/completions/tokenusage"
 	"github.com/sourcegraph/sourcegraph/internal/completions/types"
@@ -40,6 +41,8 @@ func getBasic(endpoint string, provider conftypes.CompletionsProviderName, acces
 		return openai.NewClient(httpcli.UncachedExternalDoer, endpoint, accessToken, *tokenManager), nil
 	case conftypes.CompletionsProviderNameAzureOpenAI:
 		return azureopenai.NewClient(azureopenai.GetAPIClient, endpoint, accessToken, *tokenManager)
+	case conftypes.CompletionsProviderNameGoogle:
+		return google.NewClient(httpcli.UncachedExternalDoer, endpoint, accessToken), nil
 	case conftypes.CompletionsProviderNameSourcegraph:
 		return codygateway.NewClient(httpcli.CodyGatewayDoer, endpoint, accessToken, *tokenManager)
 	case conftypes.CompletionsProviderNameFireworks:
@@ -12,6 +12,7 @@ go_library(
        "//internal/codygateway",
        "//internal/completions/client/anthropic",
        "//internal/completions/client/fireworks",
+       "//internal/completions/client/google",
        "//internal/completions/client/openai",
        "//internal/completions/tokenusage",
        "//internal/completions/types",
@@ -15,6 +15,7 @@ import (
 	"github.com/sourcegraph/sourcegraph/internal/codygateway"
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/anthropic"
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks"
+	"github.com/sourcegraph/sourcegraph/internal/completions/client/google"
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/openai"
 	"github.com/sourcegraph/sourcegraph/internal/completions/tokenusage"
 	"github.com/sourcegraph/sourcegraph/internal/completions/types"
@@ -89,6 +90,8 @@ func (c *codyGatewayClient) clientForParams(feature types.CompletionsFeature, re
 		return openai.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/openai"), "", "", c.tokenizer), nil
 	case string(conftypes.CompletionsProviderNameFireworks):
 		return fireworks.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/fireworks"), "", ""), nil
+	case string(conftypes.CompletionsProviderNameGoogle):
+		return google.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/google"), "", ""), nil
 	case "":
 		return nil, errors.Newf("no provider provided in model %s - a model in the format '$PROVIDER/$MODEL_NAME' is expected", model)
 	default:
internal/completions/client/google/BUILD.bazel (new file, 38 lines)
@@ -0,0 +1,38 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("//dev:go_defs.bzl", "go_test")

go_library(
    name = "google",
    srcs = [
        "decoder.go",
        "google.go",
        "models.go",
        "prompt.go",
        "types.go",
    ],
    importpath = "github.com/sourcegraph/sourcegraph/internal/completions/client/google",
    visibility = ["//:__subpackages__"],
    deps = [
        "//internal/completions/types",
        "//internal/httpcli",
        "//lib/errors",
        "@com_github_sourcegraph_log//:log",
    ],
)

go_test(
    name = "google_test",
    srcs = [
        "decoder_test.go",
        "google_test.go",
        "prompt_test.go",
    ],
    embed = [":google"],
    deps = [
        "//internal/completions/types",
        "@com_github_hexops_autogold_v2//:autogold",
        "@com_github_sourcegraph_log//:log",
        "@com_github_stretchr_testify//assert",
        "@com_github_stretchr_testify//require",
    ],
)
internal/completions/client/google/decoder.go (new file, 96 lines)
@@ -0,0 +1,96 @@
package google

import (
	"bufio"
	"bytes"
	"io"

	"github.com/sourcegraph/sourcegraph/lib/errors"
)

const maxPayloadSize = 10 * 1024 * 1024 // 10mb

var doneBytes = []byte("[DONE]")

// decoder decodes streaming events from a Server Sent Event stream.
type decoder struct {
	scanner *bufio.Scanner
	done    bool
	data    []byte
	err     error
}

func NewDecoder(r io.Reader) *decoder {
	scanner := bufio.NewScanner(r)
	scanner.Buffer(make([]byte, 0, 4096), maxPayloadSize)
	// bufio.ScanLines, except we look for \r\n\r\n which separate events.
	split := func(data []byte, atEOF bool) (int, []byte, error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := bytes.Index(data, []byte("\r\n\r\n")); i >= 0 {
			return i + 4, data[:i], nil
		}
		if i := bytes.Index(data, []byte("\n\n")); i >= 0 {
			return i + 2, data[:i], nil
		}
		// If we're at EOF, we have a final, non-terminated event. This should
		// be empty.
		if atEOF {
			return len(data), data, nil
		}
		// Request more data.
		return 0, nil, nil
	}
	scanner.Split(split)
	return &decoder{
		scanner: scanner,
	}
}

// Scan advances the decoder to the next event in the stream. It returns
// false when it either hits the end of the stream or an error.
func (d *decoder) Scan() bool {
	if d.done {
		return false
	}
	for d.scanner.Scan() {
		line := d.scanner.Bytes()
		typ, data := splitColon(line)
		switch {
		case bytes.Equal(typ, []byte("data")):
			d.data = data
			// Check for special sentinel value used by the Google API to
			// indicate that the stream is done.
			if bytes.Equal(data, doneBytes) {
				d.done = true
				return false
			}
			return true
		default:
			d.err = errors.Errorf("malformed data, expected data: %s", typ)
			return false
		}
	}

	d.err = d.scanner.Err()
	return false
}

// Data returns the event data of the last decoded event.
func (d *decoder) Data() []byte {
	return d.data
}

// Err returns the last encountered error.
func (d *decoder) Err() error {
	return d.err
}

func splitColon(data []byte) ([]byte, []byte) {
	i := bytes.Index(data, []byte(":"))
	if i < 0 {
		return bytes.TrimSpace(data), nil
	}
	return bytes.TrimSpace(data[:i]), bytes.TrimSpace(data[i+1:])
}
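A minimal usage sketch for this decoder, written as a hypothetical in-package example test (`ExampleNewDecoder` is illustrative, not part of the change); the `[DONE]` sentinel ends the loop without surfacing an error:

```go
package google

import (
	"fmt"
	"strings"
)

// ExampleNewDecoder sketches driving the SSE decoder over a canned stream.
func ExampleNewDecoder() {
	stream := "data: one\n\ndata: two\n\ndata: [DONE]\n\n"
	dec := NewDecoder(strings.NewReader(stream))
	for dec.Scan() {
		fmt.Println(string(dec.Data()))
	}
	if err := dec.Err(); err != nil {
		fmt.Println("error:", err)
	}
	// Output:
	// one
	// two
}
```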
internal/completions/client/google/decoder_test.go (new file, 56 lines)
@@ -0,0 +1,56 @@
package google

import (
	"strings"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestDecoder(t *testing.T) {
	t.Parallel()

	type event struct {
		data string
	}

	decodeAll := func(input string) ([]event, error) {
		dec := NewDecoder(strings.NewReader(input))
		var events []event
		for dec.Scan() {
			events = append(events, event{
				data: string(dec.Data()),
			})
		}
		return events, dec.Err()
	}

	t.Run("Single", func(t *testing.T) {
		events, err := decodeAll("data:b\n\n")
		require.NoError(t, err)
		require.Equal(t, events, []event{{data: "b"}})
	})

	t.Run("Multiple", func(t *testing.T) {
		events, err := decodeAll("data:b\n\ndata:c\n\ndata: [DONE]\n\n")
		require.NoError(t, err)
		require.Equal(t, events, []event{{data: "b"}, {data: "c"}})
	})

	t.Run("Multiple with new line within data", func(t *testing.T) {
		events, err := decodeAll("data:line1\nline2\nline3\n\ndata:second-data\n\ndata: [DONE]\n\n")
		require.NoError(t, err)
		require.Equal(t, events, []event{{data: "line1\nline2\nline3"}, {data: "second-data"}})
	})

	t.Run("ErrExpectedData", func(t *testing.T) {
		_, err := decodeAll("datas:b\n\n")
		require.Contains(t, err.Error(), "malformed data, expected data")
	})

	t.Run("Ends after done", func(t *testing.T) {
		events, err := decodeAll("data:b\n\ndata:c\n\ndata: [DONE]\n\ndata:d\n\n")
		require.NoError(t, err)
		require.Equal(t, events, []event{{data: "b"}, {data: "c"}})
	})
}
internal/completions/client/google/google.go (new file, 245 lines)
@@ -0,0 +1,245 @@
package google

import (
	"bytes"
	"context"
	"encoding/json"
	"net/http"
	"net/url"
	"path"

	"github.com/sourcegraph/log"

	"github.com/sourcegraph/sourcegraph/internal/completions/types"
	"github.com/sourcegraph/sourcegraph/internal/httpcli"
	"github.com/sourcegraph/sourcegraph/lib/errors"
)

func NewClient(cli httpcli.Doer, endpoint, accessToken string) types.CompletionsClient {
	return &googleCompletionStreamClient{
		cli:         cli,
		accessToken: accessToken,
		endpoint:    endpoint,
	}
}

func (c *googleCompletionStreamClient) Complete(
	ctx context.Context,
	feature types.CompletionsFeature,
	_ types.CompletionsVersion,
	requestParams types.CompletionRequestParameters,
	logger log.Logger,
) (*types.CompletionResponse, error) {
	if !isSupportedFeature(feature) {
		return nil, errors.Newf("feature %q is currently not supported for Google", feature)
	}

	var resp *http.Response
	var err error
	defer (func() {
		if resp != nil {
			resp.Body.Close()
		}
	})()

	resp, err = c.makeRequest(ctx, requestParams, false)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var response googleResponse
	if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
		return nil, err
	}

	if len(response.Candidates) == 0 {
		// Empty response.
		return &types.CompletionResponse{}, nil
	}

	if len(response.Candidates[0].Content.Parts) == 0 {
		// Empty response.
		return &types.CompletionResponse{}, nil
	}

	// NOTE: Candidates can be used to get multiple completions when CandidateCount is set,
	// which is not currently supported by Cody. For now, we only return the first completion.
	return &types.CompletionResponse{
		Completion: response.Candidates[0].Content.Parts[0].Text,
		StopReason: response.Candidates[0].FinishReason,
	}, nil
}

func (c *googleCompletionStreamClient) Stream(
	ctx context.Context,
	feature types.CompletionsFeature,
	_ types.CompletionsVersion,
	requestParams types.CompletionRequestParameters,
	sendEvent types.SendCompletionEvent,
	logger log.Logger,
) error {
	if !isSupportedFeature(feature) {
		return errors.Newf("feature %q is currently not supported for Google", feature)
	}

	var resp *http.Response
	var err error

	defer (func() {
		if resp != nil {
			resp.Body.Close()
		}
	})()

	resp, err = c.makeRequest(ctx, requestParams, true)
	if err != nil {
		return err
	}

	dec := NewDecoder(resp.Body)
	var content string
	var ev types.CompletionResponse

	for dec.Scan() {
		if ctx.Err() != nil && ctx.Err() == context.Canceled {
			return nil
		}

		data := dec.Data()
		// Gracefully skip over any data that isn't JSON-like.
		if !bytes.HasPrefix(data, []byte("{")) {
			continue
		}

		var event googleResponse
		if err := json.Unmarshal(data, &event); err != nil {
			return errors.Errorf("failed to decode event payload: %w - body: %s", err, string(data))
		}

		if len(event.Candidates) > 0 && len(event.Candidates[0].Content.Parts) > 0 {
			content += event.Candidates[0].Content.Parts[0].Text

			ev = types.CompletionResponse{
				Completion: content,
				StopReason: event.Candidates[0].FinishReason,
			}
			err = sendEvent(ev)
			if err != nil {
				return err
			}
		}
	}
	if dec.Err() != nil {
		return dec.Err()
	}

	return nil
}

// makeRequest formats the request and calls the generateContent (or
// streamGenerateContent) endpoint for chat requests.
func (c *googleCompletionStreamClient) makeRequest(ctx context.Context, requestParams types.CompletionRequestParameters, stream bool) (*http.Response, error) {
	// Ensure TopK and TopP are non-negative.
	requestParams.TopK = max(0, requestParams.TopK)
	requestParams.TopP = max(0, requestParams.TopP)

	// Generate the prompt.
	prompt, err := getPrompt(requestParams.Messages)
	if err != nil {
		return nil, err
	}

	payload := googleRequest{
		Contents: prompt,
		GenerationConfig: googleGenerationConfig{
			Temperature:     requestParams.Temperature,
			TopP:            requestParams.TopP,
			TopK:            requestParams.TopK,
			MaxOutputTokens: requestParams.MaxTokensToSample,
			StopSequences:   requestParams.StopSequences,
		},
	}

	reqBody, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}

	apiURL := c.getAPIURL(requestParams, stream)

	req, err := http.NewRequestWithContext(ctx, "POST", apiURL.String(), bytes.NewReader(reqBody))
	if err != nil {
		return nil, err
	}

	req.Header.Set("Content-Type", "application/json")

	// The Vertex AI API requires an Authorization header with the access token.
	// Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini#sample-requests
	if !isDefaultAPIEndpoint(apiURL) {
		req.Header.Set("Authorization", "Bearer "+c.accessToken)
	}

	resp, err := c.cli.Do(req)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode != http.StatusOK {
		return nil, types.NewErrStatusNotOK("Google", resp)
	}

	return resp, nil
}

// Per the latest API docs, the model name and API key must be used with the default API endpoint URL.
// Ref: https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=rest#gemini_and_content_based_apis
func (c *googleCompletionStreamClient) getAPIURL(requestParams types.CompletionRequestParameters, stream bool) *url.URL {
	apiURL, err := url.Parse(c.endpoint)
	if err != nil {
		apiURL = &url.URL{
			Scheme: "https",
			Host:   defaultAPIHost,
			Path:   defaultAPIPath,
		}
	}

	apiURL.Path = path.Join(apiURL.Path, requestParams.Model) + ":" + getgRPCMethod(stream)

	// We need to append the API key to the default API endpoint URL.
	if isDefaultAPIEndpoint(apiURL) {
		query := apiURL.Query()
		query.Set("key", c.accessToken)
		if stream {
			query.Set("alt", "sse")
		}
		apiURL.RawQuery = query.Encode()
	}

	return apiURL
}

// getgRPCMethod returns the gRPC method name based on the stream flag.
func getgRPCMethod(stream bool) string {
	if stream {
		return "streamGenerateContent"
	}
	return "generateContent"
}

// isSupportedFeature checks if the given CompletionsFeature is supported.
// Currently, only the CompletionsFeatureChat feature is supported.
func isSupportedFeature(feature types.CompletionsFeature) bool {
	switch feature {
	case types.CompletionsFeatureChat:
		return true
	default:
		return false
	}
}

// isDefaultAPIEndpoint checks if the given API endpoint URL is the default API endpoint.
// The default API endpoint is determined by the defaultAPIHost constant.
func isDefaultAPIEndpoint(endpoint *url.URL) bool {
	return endpoint.Host == defaultAPIHost
}
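A hypothetical in-package example of the URL shapes getAPIURL produces: the default Gemini API host gets `key` and `alt=sse` query parameters, while other hosts (such as Vertex AI) keep a clean URL and are authenticated via the Bearer header in makeRequest instead.

```go
package google

import (
	"fmt"

	"github.com/sourcegraph/sourcegraph/internal/completions/types"
)

// ExampleGetAPIURL is an illustrative sketch (not part of the change) of
// the streaming URL built against the default Gemini API host.
func ExampleGetAPIURL() {
	c := &googleCompletionStreamClient{
		endpoint:    "https://generativelanguage.googleapis.com/v1beta/models",
		accessToken: "KEY",
	}
	params := types.CompletionRequestParameters{Model: "gemini-1.5-pro-latest"}
	fmt.Println(c.getAPIURL(params, true))
	// Output:
	// https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-latest:streamGenerateContent?alt=sse&key=KEY
}
```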
internal/completions/client/google/google_test.go (new file, 129 lines)
@@ -0,0 +1,129 @@
package google

import (
	"bytes"
	"context"
	"io"
	"net/http"
	"testing"

	"github.com/hexops/autogold/v2"
	"github.com/sourcegraph/log"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/sourcegraph/sourcegraph/internal/completions/types"
)

type mockDoer struct {
	do func(*http.Request) (*http.Response, error)
}

func (c *mockDoer) Do(r *http.Request) (*http.Response, error) {
	return c.do(r)
}

func TestErrStatusNotOK(t *testing.T) {
	mockClient := NewClient(&mockDoer{
		func(r *http.Request) (*http.Response, error) {
			return &http.Response{
				StatusCode: http.StatusTooManyRequests,
				Body:       io.NopCloser(bytes.NewReader([]byte("oh no, please slow down!"))),
			}, nil
		},
	}, "", "")

	t.Run("Complete", func(t *testing.T) {
		logger := log.Scoped("completions")
		resp, err := mockClient.Complete(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{}, logger)
		require.Error(t, err)
		assert.Nil(t, resp)

		autogold.Expect("Google: unexpected status code 429: oh no, please slow down!").Equal(t, err.Error())
		_, ok := types.IsErrStatusNotOK(err)
		assert.True(t, ok)
	})

	t.Run("Stream", func(t *testing.T) {
		logger := log.Scoped("completions")
		err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{}, func(event types.CompletionResponse) error { return nil }, logger)
		require.Error(t, err)

		autogold.Expect("Google: unexpected status code 429: oh no, please slow down!").Equal(t, err.Error())
		_, ok := types.IsErrStatusNotOK(err)
		assert.True(t, ok)
	})
}

func TestGetAPIURL(t *testing.T) {
	t.Parallel()

	client := &googleCompletionStreamClient{
		endpoint:    "https://generativelanguage.googleapis.com/v1/models",
		accessToken: "test-token",
	}

	t.Run("valid v1 endpoint", func(t *testing.T) {
		params := types.CompletionRequestParameters{
			Model: "test-model",
		}
		url := client.getAPIURL(params, false).String()
		expected := "https://generativelanguage.googleapis.com/v1/models/test-model:generateContent?key=test-token"
		require.Equal(t, expected, url)
	})

	t.Run("valid endpoint for Vertex AI", func(t *testing.T) {
		params := types.CompletionRequestParameters{
			Model: "gemini-1.5-pro",
		}
		c := &googleCompletionStreamClient{
			endpoint:    "https://vertex-ai.example.com/v1/projects/PROJECT_ID/locations/LOCATION/publishers/google/models",
			accessToken: "test-token",
		}
		url := c.getAPIURL(params, true).String()
		expected := "https://vertex-ai.example.com/v1/projects/PROJECT_ID/locations/LOCATION/publishers/google/models/gemini-1.5-pro:streamGenerateContent"
		require.Equal(t, expected, url)
	})

	t.Run("valid custom endpoint", func(t *testing.T) {
		params := types.CompletionRequestParameters{
			Model: "test-model",
		}
		c := &googleCompletionStreamClient{
			endpoint:    "https://example.com/api/models",
			accessToken: "test-token",
		}
		url := c.getAPIURL(params, true).String()
		expected := "https://example.com/api/models/test-model:streamGenerateContent"
		require.Equal(t, expected, url)
	})

	t.Run("invalid endpoint", func(t *testing.T) {
		client.endpoint = "://invalid"
		params := types.CompletionRequestParameters{
			Model: "test-model",
		}
		url := client.getAPIURL(params, false).String()
		expected := "https://generativelanguage.googleapis.com/v1beta/models/test-model:generateContent?key=test-token"
		require.Equal(t, expected, url)
	})

	t.Run("streaming", func(t *testing.T) {
		params := types.CompletionRequestParameters{
			Model: "test-model",
		}
		url := client.getAPIURL(params, true).String()
		expected := "https://generativelanguage.googleapis.com/v1beta/models/test-model:streamGenerateContent?alt=sse&key=test-token"
		require.Equal(t, expected, url)
	})

	t.Run("empty model", func(t *testing.T) {
		params := types.CompletionRequestParameters{
			Model: "",
		}
		url := client.getAPIURL(params, false).String()
		expected := "https://generativelanguage.googleapis.com/v1beta/models:generateContent?key=test-token"
		require.Equal(t, expected, url)
	})
}
internal/completions/client/google/models.go (new file, 21 lines)
@@ -0,0 +1,21 @@
package google

// For the latest available Google Gemini models,
// see: https://ai.google.dev/gemini-api/docs/models/gemini
const providerName = "google"

// Default API endpoint URL.
const (
	defaultAPIHost = "generativelanguage.googleapis.com"
	defaultAPIPath = "/v1beta/models"
)

// Latest stable versions.
const Gemini15Flash = "gemini-1.5-flash"
const Gemini15Pro = "gemini-1.5-pro"
const GeminiPro = "gemini-pro"

// Latest versions.
const Gemini15FlashLatest = "gemini-1.5-flash-latest"
const Gemini15ProLatest = "gemini-1.5-pro-latest"
const GeminiProLatest = "gemini-pro-latest"
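Gateway-facing model IDs combine these constants with the provider prefix, as the allowlist changes earlier in this diff do. A small illustrative sketch (the `ExampleModelNames` function is hypothetical):

```go
package google

import "fmt"

// ExampleModelNames shows the "<provider>/<model>" IDs that the dotcom
// allowlists build from these constants.
func ExampleModelNames() {
	for _, m := range []string{Gemini15FlashLatest, Gemini15ProLatest, GeminiProLatest} {
		fmt.Println(providerName + "/" + m)
	}
	// Output:
	// google/gemini-1.5-flash-latest
	// google/gemini-1.5-pro-latest
	// google/gemini-pro-latest
}
```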
internal/completions/client/google/prompt.go (new file, 42 lines)
@@ -0,0 +1,42 @@
package google

import (
	"github.com/sourcegraph/sourcegraph/internal/completions/types"
	"github.com/sourcegraph/sourcegraph/lib/errors"
)

func getPrompt(messages []types.Message) ([]googleContentMessage, error) {
	googleMessages := make([]googleContentMessage, 0, len(messages))

	for i, message := range messages {
		var googleRole string

		switch message.Speaker {
		case types.SYSTEM_MESSAGE_SPEAKER:
			if i != 0 {
				return nil, errors.New("system role can only be used in the first message")
			}
			googleRole = message.Speaker
		case types.ASSISTANT_MESSAGE_SPEAKER:
			if i == 0 {
				return nil, errors.New("assistant role cannot be used in the first message")
			}
			googleRole = "model"
		case types.HUMAN_MESSAGE_SPEAKER:
			googleRole = "user"
		default:
			return nil, errors.Errorf("unexpected role: %s", message.Speaker)
		}

		if message.Text == "" {
			return nil, errors.New("message content cannot be empty")
		}

		googleMessages = append(googleMessages, googleContentMessage{
			Role:  googleRole,
			Parts: []googleContentMessagePart{{Text: message.Text}},
		})
	}

	return googleMessages, nil
}
internal/completions/client/google/prompt_test.go (new file, 70 lines)
@@ -0,0 +1,70 @@
package google

import (
	"testing"

	"github.com/sourcegraph/sourcegraph/internal/completions/types"
)

func TestGetPrompt(t *testing.T) {
	t.Run("invalid speaker", func(t *testing.T) {
		_, err := getPrompt([]types.Message{{Speaker: "invalid", Text: "hello"}})
		if err == nil {
			t.Errorf("expected error for invalid speaker, got nil")
		}
	})

	t.Run("empty text", func(t *testing.T) {
		_, err := getPrompt([]types.Message{{Speaker: types.HUMAN_MESSAGE_SPEAKER, Text: ""}})
		if err == nil {
			t.Errorf("expected error for empty text, got nil")
		}
	})

	t.Run("multiple system messages", func(t *testing.T) {
		_, err := getPrompt([]types.Message{
			{Speaker: types.SYSTEM_MESSAGE_SPEAKER, Text: "system"},
			{Speaker: types.HUMAN_MESSAGE_SPEAKER, Text: "hello"},
			{Speaker: types.SYSTEM_MESSAGE_SPEAKER, Text: "system"},
		})
		if err == nil {
			t.Errorf("expected error for multiple system messages, got nil")
		}
	})

	t.Run("invalid prompt starts with assistant", func(t *testing.T) {
		_, err := getPrompt([]types.Message{
			{Speaker: types.ASSISTANT_MESSAGE_SPEAKER, Text: "assistant"},
			{Speaker: types.HUMAN_MESSAGE_SPEAKER, Text: "hello"},
			{Speaker: types.ASSISTANT_MESSAGE_SPEAKER, Text: "assistant"},
		})
		if err == nil {
			t.Errorf("expected error for messages starting with assistant, got nil")
		}
	})

	t.Run("valid prompt", func(t *testing.T) {
		messages := []types.Message{
			{Speaker: types.SYSTEM_MESSAGE_SPEAKER, Text: "system"},
			{Speaker: types.HUMAN_MESSAGE_SPEAKER, Text: "hello"},
			{Speaker: types.ASSISTANT_MESSAGE_SPEAKER, Text: "hi"},
		}
		prompt, err := getPrompt(messages)
		if err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		expected := []googleContentMessage{
			{Role: "system", Parts: []googleContentMessagePart{{Text: "system"}}},
			{Role: "user", Parts: []googleContentMessagePart{{Text: "hello"}}},
			{Role: "model", Parts: []googleContentMessagePart{{Text: "hi"}}},
		}
		if len(prompt) != len(expected) {
			t.Errorf("unexpected prompt length, got %d, want %d", len(prompt), len(expected))
		}
		for i := range prompt {
			if prompt[i].Parts[0].Text != expected[i].Parts[0].Text {
				t.Errorf("unexpected prompt message at index %d, got %v, want %v", i, prompt[i], expected[i])
			}
		}
	})
}
internal/completions/client/google/types.go (new file, 60 lines)
@@ -0,0 +1,60 @@
package google

import "github.com/sourcegraph/sourcegraph/internal/httpcli"

type googleCompletionStreamClient struct {
	cli         httpcli.Doer
	accessToken string
	endpoint    string
}

type googleRequest struct {
	Contents         []googleContentMessage `json:"contents"`
	GenerationConfig googleGenerationConfig `json:"generationConfig,omitempty"`
	SafetySettings   []googleSafetySettings `json:"safetySettings,omitempty"`
}

type googleContentMessage struct {
	Role  string                     `json:"role"`
	Parts []googleContentMessagePart `json:"parts"`
}

type googleContentMessagePart struct {
	Text string `json:"text"`
}

// Configuration options for model generation and outputs.
// Ref: https://ai.google.dev/api/rest/v1/GenerationConfig
type googleGenerationConfig struct {
	Temperature     float32  `json:"temperature,omitempty"`     // request.Temperature
	TopP            float32  `json:"topP,omitempty"`            // request.TopP
	TopK            int      `json:"topK,omitempty"`            // request.TopK
	StopSequences   []string `json:"stopSequences,omitempty"`   // request.StopSequences
	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"` // request.MaxTokensToSample
	CandidateCount  int      `json:"candidateCount,omitempty"`  // request.CandidateCount
}

type googleResponse struct {
	Model      string `json:"model"`
	Candidates []struct {
		Content      googleContentMessage
		FinishReason string `json:"finishReason"`
	} `json:"candidates"`

	UsageMetadata  googleUsage            `json:"usageMetadata"`
	SafetySettings []googleSafetySettings `json:"safetySettings,omitempty"`
}

// Safety setting, affecting the safety-blocking behavior.
// Ref: https://ai.google.dev/gemini-api/docs/safety-settings
type googleSafetySettings struct {
	Category  string `json:"category"`
	Threshold string `json:"threshold"`
}

type googleUsage struct {
	PromptTokenCount int `json:"promptTokenCount"`
	// Use the same name we use elsewhere (completion instead of candidates).
	CompletionTokenCount int `json:"candidatesTokenCount"`
	TotalTokenCount      int `json:"totalTokenCount"`
}
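A sketch of decoding a trimmed-down generateContent response into these types; note how `candidatesTokenCount` is mapped onto the `CompletionTokenCount` field. The `ExampleGoogleResponseDecode` function is hypothetical, written as if it sat in this package.

```go
package google

import (
	"encoding/json"
	"fmt"
)

// ExampleGoogleResponseDecode decodes a simplified response payload.
func ExampleGoogleResponseDecode() {
	payload := `{
	  "candidates": [{"content": {"role": "model", "parts": [{"text": "hi"}]}, "finishReason": "STOP"}],
	  "usageMetadata": {"promptTokenCount": 21, "candidatesTokenCount": 1, "totalTokenCount": 22}
	}`
	var res googleResponse
	if err := json.Unmarshal([]byte(payload), &res); err != nil {
		panic(err)
	}
	fmt.Println(res.Candidates[0].Content.Parts[0].Text, res.UsageMetadata.CompletionTokenCount)
	// Output: hi 1
}
```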
@@ -28,6 +28,7 @@ go_library(
    deps = [
        "//internal/api/internalapi",
        "//internal/completions/client/anthropic",
+       "//internal/completions/client/google",
        "//internal/conf/confdefaults",
        "//internal/conf/conftypes",
        "//internal/conf/deploy",
@@ -10,6 +10,7 @@ import (
 	"github.com/hashicorp/cronexpr"

 	"github.com/sourcegraph/sourcegraph/internal/completions/client/anthropic"
+	"github.com/sourcegraph/sourcegraph/internal/completions/client/google"
 	"github.com/sourcegraph/sourcegraph/internal/conf/confdefaults"
 	"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
 	"github.com/sourcegraph/sourcegraph/internal/conf/deploy"
@@ -867,6 +868,32 @@ func GetCompletionsConfig(siteConfig schema.SiteConfiguration) (c *conftypes.Com
 		completionsModelRef := conftypes.NewBedrockModelRefFromModelID(completionsConfig.CompletionModel)
 		completionsConfig.CompletionModel = completionsModelRef.CanonicalizedModelID()
 		canonicalized = true
+	} else if completionsConfig.Provider == string(conftypes.CompletionsProviderNameGoogle) {
+		// If no endpoint is configured, use a default value.
+		if completionsConfig.Endpoint == "" {
+			completionsConfig.Endpoint = "https://generativelanguage.googleapis.com/v1beta/models"
+		}
+
+		// If no access token is set, we cannot talk to Google. Bail.
+		if completionsConfig.AccessToken == "" {
+			return nil
+		}
+
+		// Set a default chat model.
+		if completionsConfig.ChatModel == "" {
+			completionsConfig.ChatModel = google.Gemini15ProLatest
+		}
+
+		// Set a default fast chat model.
+		if completionsConfig.FastChatModel == "" {
+			completionsConfig.FastChatModel = google.Gemini15FlashLatest
+		}
+
+		// Set a default completions model.
+		if completionsConfig.CompletionModel == "" {
+			// Code completion is not supported by Google.
+			completionsConfig.CompletionModel = google.Gemini15FlashLatest
+		}
+	}

 	// only apply canonicalization if not already applied. Not all model IDs can simply be lowercased
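The defaulting logic above boils down to "an explicit site-config value wins, otherwise fall back to a Gemini model". A standalone sketch of that behavior (the `defaultModel` helper is illustrative, not from the change):

```go
package main

import "fmt"

// defaultModel mirrors the pattern in the google branch above: keep the
// configured value if present, otherwise use the provider default.
func defaultModel(configured, fallback string) string {
	if configured != "" {
		return configured
	}
	return fallback
}

func main() {
	fmt.Println(defaultModel("", "gemini-1.5-pro-latest"))             // chat default
	fmt.Println(defaultModel("", "gemini-1.5-flash-latest"))           // fast chat default
	fmt.Println(defaultModel("gemini-pro", "gemini-1.5-flash-latest")) // explicit config wins
}
```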
@@ -50,6 +50,7 @@ func NewCodyGatewayChatRateLimit(plan Plan, userCount *int) CodyGatewayRateLimit
 		"google/gemini-1.5-pro-latest",
 		"google/gemini-1.5-flash-latest",
+		"google/gemini-pro-latest",
 	}
 	switch plan {
 	// TODO: This is just an example for now.
@@ -38,6 +38,7 @@ func TestNewCodyGatewayChatRateLimit(t *testing.T) {
 				"openai/gpt-4-turbo-preview",
 				"google/gemini-1.5-pro-latest",
 				"google/gemini-1.5-flash-latest",
+				"google/gemini-pro-latest",
 			},
 			Limit:           2500,
 			IntervalSeconds: 60 * 60 * 24,
@@ -65,6 +66,7 @@ func TestNewCodyGatewayChatRateLimit(t *testing.T) {
 				"openai/gpt-4-turbo-preview",
 				"google/gemini-1.5-pro-latest",
 				"google/gemini-1.5-flash-latest",
+				"google/gemini-pro-latest",
 			},
 			Limit:           50,
 			IntervalSeconds: 60 * 60 * 24,
@@ -92,6 +94,7 @@ func TestNewCodyGatewayChatRateLimit(t *testing.T) {
 				"openai/gpt-4-turbo-preview",
 				"google/gemini-1.5-pro-latest",
 				"google/gemini-1.5-flash-latest",
+				"google/gemini-pro-latest",
 			},
 			Limit:           10,
 			IntervalSeconds: 60 * 60 * 24,
@@ -2890,7 +2890,7 @@
       "type": "string",
       "description": "The external completions provider. Defaults to 'sourcegraph'.",
       "default": "sourcegraph",
-      "enum": ["anthropic", "openai", "sourcegraph", "azure-openai", "aws-bedrock", "fireworks"]
+      "enum": ["anthropic", "openai", "sourcegraph", "azure-openai", "aws-bedrock", "fireworks", "google"]
     },
     "endpoint": {
       "type": "string",