Cody Gateway: Add Gemini models to PLG and Enterprise users (#63053)

CLOSE https://github.com/sourcegraph/cody-issues/issues/211 &
https://github.com/sourcegraph/cody-issues/issues/412
UNBLOCK https://github.com/sourcegraph/cody/pull/4360

* Add support for Google Gemini models as a chat completions provider
* Add a new `google` package that implements the Google Generative AI client
* Update `client.go` and `codygateway.go` to handle the new Google provider
* Set default models for chat, fast chat, and completions when Google is the configured provider
* Add `gemini-pro` to the allowed models list (see the example gateway configuration after this list)
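
As a rough sketch (placeholder values, not real credentials), a Cody Gateway deployment can enable the new provider through the environment variables introduced in this PR:

```sh
# Enable the Google provider in Cody Gateway (placeholder token).
export CODY_GATEWAY_GOOGLE_ACCESS_TOKEN="<google-ai-studio-api-key>"
# Optional override; defaults to the three Gemini models allowed by this PR.
export CODY_GATEWAY_GOOGLE_ALLOWED_MODELS="gemini-1.5-flash-latest,gemini-1.5-pro-latest,gemini-pro-latest"
```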



## Test plan


For Enterprise instances using `google` as the completions provider:

1. In your Sourcegraph local instance's site configuration, add the following to the `"completions"` block:

```
    "accessToken": "REDACTED",
    "chatModel": "gemini-1.5-pro-latest",
    "provider": "google",
```

Note: You can get the accessToken for Gemini API in 1Password.
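
For reference, a complete `"completions"` block might look like the sketch below. The `fastChatModel` and `completionModel` keys are optional; with this change they default to `gemini-1.5-flash-latest` when Google is the configured provider.

```
"completions": {
  "provider": "google",
  "accessToken": "REDACTED",
  "chatModel": "gemini-1.5-pro-latest",
  "fastChatModel": "gemini-1.5-flash-latest",
  "completionModel": "gemini-1.5-flash-latest"
}
```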

2. After saving the site config with the above change, run the following curl command:

```sh
curl 'https://sourcegraph.test:3443/.api/completions/stream' -i \
-X POST \
-H "authorization: token $LOCAL_INSTANCE_TOKEN" \
--data-raw '{"messages":[{"speaker":"human","text":"Who are you?"}],"maxTokensToSample":30,"temperature":0,"stopSequences":[],"timeoutMs":5000,"stream":true,"model":"gemini-1.5-pro-latest"}'
```

3. Expected Output:

```
❯ curl 'https://sourcegraph.test:3443/.api/completions/stream' -i \
-X POST \
-H 'authorization: token <REDACTED>' \
--data-raw '{"messages":[{"speaker":"human","text":"Who are you?"}],"maxTokensToSample":30,"temperature":0,"stopSequences":[],"timeoutMs":5000,"stream":true,"model":"gemini-1.5-pro-latest"}'

HTTP/2 200
access-control-allow-credentials: true
access-control-allow-origin:
alt-svc: h3=":3443"; ma=2592000
cache-control: no-cache
content-type: text/event-stream
date: Tue, 04 Jun 2024 05:45:33 GMT
server: Caddy
server: Caddy
vary: Accept-Encoding, Authorization, Cookie, Authorization, X-Requested-With, Cookie
x-accel-buffering: no
x-content-type-options: nosniff
x-frame-options: DENY
x-powered-by: Express
x-trace: d4b1f02a3e2882a3d52331335d217b03
x-trace-span: 728ec33860d3b5e6
x-trace-url: https://sourcegraph.test:3443/-/debug/jaeger/trace/d4b1f02a3e2882a3d52331335d217b03
x-xss-protection: 1; mode=block

event: completion
data: {"completion":"I","stopReason":"STOP"}

event: completion
data: {"completion":"I am a large language model, trained by Google. \n\nThink of me as","stopReason":"STOP"}

event: completion
data: {"completion":"I am a large language model, trained by Google. \n\nThink of me as a computer program that can understand and generate human-like text.","stopReason":"MAX_TOKENS"}

event: done
data: {}
```

Verified locally:


![image](https://github.com/sourcegraph/sourcegraph/assets/68532117/2e6c914d-7a77-4484-b693-16bbc394518c)

#### Before

Cody Gateway returns `no client known for upstream provider google`

```sh
curl -X 'POST' -d '{"messages":[{"speaker":"human","text":"Who are you?"}],"maxTokensToSample":30,"temperature":0,"stopSequences":[],"timeoutMs":5000,"stream":true,"model":"google/gemini-1.5-pro-latest"}' -H 'Accept: application/json' -H "Authorization: token $YOUR_DOTCOM_TOKEN" -H 'Content-Type: application/json' 'https://sourcegraph.com/.api/completions/stream'

event: error
data: {"error":"no client known for upstream provider google"}

event: done
data: {
```

## Changelog


Added support for Google as an LLM provider for Cody, with the following
models available through Cody Gateway: Gemini Pro (`gemini-pro-latest`),
Gemini 1.5 Flash (`gemini-1.5-flash-latest`), and Gemini 1.5 Pro
(`gemini-1.5-pro-latest`).
Commit f2590cbb36 (parent f952ceb8da), authored by Beatrix on 2024-06-04 16:46:36 -07:00 and committed via GitHub.
29 changed files with 975 additions and 56 deletions

View File

@ -22,6 +22,7 @@ All notable changes to Sourcegraph are documented in this file.
- A feature flag for Cody, `completions.smartContextWindow` is added and set to "enabled" by default. It allows clients to adjust the context window based on the name of the chat model. When smartContextWindow is enabled, the `completions.chatModelMaxTokens` value is ignored. ([#62802](https://github.com/sourcegraph/sourcegraph/pull/62802))
- Code Insights: When facing the "incomplete datapoints" warning, you can now use GraphQL to discover which repositories had problems. The schemas for `TimeoutDatapointAlert` and `GenericIncompleteDatapointAlert` now contain an additional `repositories` field. ([#62756](https://github.com/sourcegraph/sourcegraph/pull/62756)).
- Users will now be presented with a modal that reminds them to connect any external code host accounts that's required for permissions. Without these accounts connected, users may be unable to view repositories that they otherwise have access to. [#62983](https://github.com/sourcegraph/sourcegraph/pull/62983)
- Added support for Google as an LLM provider for Cody, with the following models available through Cody Gateway: Gemini Pro (`gemini-pro-latest`), Gemini 1.5 Flash (`gemini-1.5-flash-latest`), and Gemini 1.5 Pro (`gemini-1.5-pro-latest`). [#63053](https://github.com/sourcegraph/sourcegraph/pull/63053)
### Changed

View File

@ -1,4 +1,4 @@
import React from 'react'
import type React from 'react'
import { Badge } from '@sourcegraph/wildcard'
@ -51,8 +51,11 @@ function modelBadgeVariant(model: string, mode: 'completions' | 'embeddings'): '
case 'openai/gpt-4o':
case 'openai/gpt-4-turbo':
case 'openai/gpt-4-turbo-preview':
// For currently available Google Gemini models,
// see: https://ai.google.dev/gemini-api/docs/models/gemini
case 'google/gemini-1.5-flash-latest':
case 'google/gemini-1.5-pro-latest':
case 'google/gemini-pro-latest':
// Virtual models that are translated by Cody Gateway and allow access to all StarCoder
// models hosted for us by Fireworks.
case 'fireworks/starcoder':

View File

@ -421,7 +421,7 @@ export const SiteAdminPingsPage: React.FunctionComponent<React.PropsWithChildren
<ul>
<li>
Provider (e.g., "sourcegraph", "anthropic", "openai", "azure-openai",
"fireworks", "aws-bedrock", etc.)
"fireworks", "aws-bedrock", "google", etc.)
</li>
<li>Chat model (included only for "sourcegraph" provider)</li>
<li>Fast chat model (included only for "sourcegraph" provider)</li>

View File

@ -38,6 +38,15 @@ func NewGoogleHandler(baseLogger log.Logger, eventLogger events.Logger, rs limit
)
}
// The request body for Google completions.
// Ref: https://ai.google.dev/api/rest/v1/models/generateContent#request-body
type googleRequest struct {
Model string `json:"model"`
Contents []googleContentMessage `json:"contents"`
GenerationConfig googleGenerationConfig `json:"generationConfig,omitempty"`
SafetySettings []googleSafetySettings `json:"safetySettings,omitempty"`
}
type googleContentMessage struct {
Role string `json:"role"`
Parts []struct {
@ -45,12 +54,22 @@ type googleContentMessage struct {
} `json:"parts"`
}
type googleRequest struct {
Model string `json:"model"`
Contents []googleContentMessage `json:"contents"`
GenerationConfig struct {
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
} `json:"generationConfig,omitempty"`
// Configuration options for model generation and outputs.
// Ref: https://ai.google.dev/api/rest/v1/GenerationConfig
type googleGenerationConfig struct {
Temperature float32 `json:"temperature,omitempty"` // request.Temperature
TopP float32 `json:"topP,omitempty"` // request.TopP
TopK int `json:"topK,omitempty"` // request.TopK
StopSequences []string `json:"stopSequences,omitempty"` // request.StopSequences
MaxOutputTokens int `json:"maxOutputTokens,omitempty"` // request.MaxTokensToSample
CandidateCount int `json:"candidateCount,omitempty"` // request.CandidateCount
}
// Safety setting, affecting the safety-blocking behavior.
// Ref: https://ai.google.dev/gemini-api/docs/safety-settings
type googleSafetySettings struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
func (r googleRequest) ShouldStream() bool {
@ -152,11 +171,8 @@ func (*GoogleHandlerMethods) parseResponseAndUsage(logger log.Logger, reqBody go
}
// Otherwise, we have to parse the event stream.
promptUsage.tokens = -1
promptUsage.tokenizerTokens = -1
completionUsage.tokens = -1
completionUsage.tokenizerTokens = -1
promptUsage.tokens, completionUsage.tokens = -1, -1
promptUsage.tokenizerTokens, completionUsage.tokenizerTokens = -1, -1
promptTokens, completionTokens, err := parseGoogleTokenUsage(r, logger)
if err != nil {
logger.Error("failed to decode Google streaming response", log.Error(err))
@ -173,35 +189,22 @@ const maxPayloadSize = 10 * 1024 * 1024 // 10mb
func parseGoogleTokenUsage(r io.Reader, logger log.Logger) (promptTokens int, completionTokens int, err error) {
scanner := bufio.NewScanner(r)
scanner.Buffer(make([]byte, 0, 4096), maxPayloadSize)
// bufio.ScanLines, except we look for \r\n\r\n which separate events.
split := func(data []byte, atEOF bool) (int, []byte, error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := bytes.Index(data, []byte("\r\n\r\n")); i >= 0 {
return i + 4, data[:i], nil
}
if i := bytes.Index(data, []byte("\n\n")); i >= 0 {
return i + 2, data[:i], nil
}
// If we're at EOF, we have a final, non-terminated event. This should
// be empty.
if atEOF {
return len(data), data, nil
}
// Request more data.
return 0, nil, nil
}
scanner.Split(split)
// skip to the last event
var lastEvent []byte
scanner.Split(bufio.ScanLines)
var lastLine []byte
for scanner.Scan() {
lastEvent = scanner.Bytes()
lastLine = scanner.Bytes()
}
var res googleResponse
if err := json.NewDecoder(bytes.NewReader(lastEvent[5:])).Decode(&res); err != nil {
logger.Error("failed to parse Google response as JSON", log.Error(err))
return -1, -1, err
if bytes.HasPrefix(bytes.TrimSpace(lastLine), []byte("data: ")) {
event := lastLine[5:]
var res googleResponse
if err := json.NewDecoder(bytes.NewReader(event)).Decode(&res); err != nil {
logger.Error("failed to parse Google response as JSON", log.Error(err))
return -1, -1, err
}
return res.UsageMetadata.PromptTokenCount, res.UsageMetadata.CompletionTokenCount, nil
}
return res.UsageMetadata.PromptTokenCount, res.UsageMetadata.CompletionTokenCount, nil
return -1, -1, errors.New("no Google response found")
}

View File

@ -4,6 +4,8 @@ import (
"strings"
"testing"
"bytes"
"github.com/sourcegraph/log/logtest"
"github.com/stretchr/testify/assert"
)
@ -31,3 +33,104 @@ data: {"candidates": [{"content": {"parts": [{"text": "\n for i in range(n-1):\
data: {"candidates": [{"content": {"parts": [{"text": " range(n-i-1):\n if list1[j] \u003e list1[j+1]:\n list1[j], list"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 63,"totalTokenCount": 84}}
data: {"candidates": [{"content": {"parts": [{"text": "1[j+1] = list1[j+1], list1[j]\n return list1\n"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}],"citationMetadata": {"citationSources": [{"startIndex": 1,"endIndex": 185,"uri": "https://github.com/Feng080412/Searches-and-sorts","license": ""}]}}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 87,"totalTokenCount": 108}}`
func TestParseGoogleTokenUsage(t *testing.T) {
tests := []struct {
name string
input string
want *googleResponse
wantErr bool
}{
{
name: "valid response",
input: `data: {"candidates": [{"content": {"parts": [{"text": "def"}],"role": "model"},"finishReason": "STOP","index": 0}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 1,"totalTokenCount": 22}}`,
want: &googleResponse{
UsageMetadata: googleUsage{
PromptTokenCount: 21,
CompletionTokenCount: 1,
TotalTokenCount: 0,
},
},
wantErr: false,
},
{
name: "valid response - with candidates",
input: `data: {"usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 20}}`,
want: &googleResponse{
UsageMetadata: googleUsage{
PromptTokenCount: 10,
CompletionTokenCount: 20,
TotalTokenCount: 0,
},
},
wantErr: false,
},
{
name: "invalid JSON",
input: `data: {"usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 20}`,
want: nil,
wantErr: true,
},
{
name: "no prefix",
input: `{"usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 20}}`,
want: nil,
wantErr: true,
},
{
name: "empty input",
input: ``,
want: nil,
wantErr: true,
},
{
name: "multiple lines with one valid",
input: `data: {"usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 15}}
data: {"usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 20}}`,
want: &googleResponse{
UsageMetadata: googleUsage{
PromptTokenCount: 10,
CompletionTokenCount: 20,
TotalTokenCount: 0,
},
},
wantErr: false,
},
{
name: "non-JSON data",
input: `data: not-a-json`,
want: nil,
wantErr: true,
},
{
name: "partial data",
input: `data: {"usageMetadata": {"promptTokenCount": 10`,
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := bytes.NewReader([]byte(tt.input))
logger := logtest.Scoped(t)
promptTokens, completionTokens, err := parseGoogleTokenUsage(r, logger)
if (err != nil) != tt.wantErr {
t.Errorf("parseGoogleTokenUsage() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.want != nil {
got := &googleResponse{
UsageMetadata: googleUsage{
PromptTokenCount: promptTokens,
CompletionTokenCount: completionTokens,
},
}
if !assert.ObjectsAreEqual(got, tt.want) {
t.Errorf("parseGoogleTokenUsage() mismatch (-want +got):\n%v", assert.ObjectsAreEqual(got, tt.want))
}
}
})
}
}

View File

@ -12,6 +12,7 @@ go_library(
"//internal/collections",
"//internal/completions/client/anthropic",
"//internal/completions/client/fireworks",
"//internal/completions/client/google",
"//internal/env",
"//internal/trace/policy",
"//lib/errors",

View File

@ -12,6 +12,7 @@ import (
"github.com/sourcegraph/sourcegraph/internal/collections"
"github.com/sourcegraph/sourcegraph/internal/completions/client/anthropic"
"github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks"
"github.com/sourcegraph/sourcegraph/internal/completions/client/google"
"github.com/sourcegraph/sourcegraph/internal/env"
"github.com/sourcegraph/sourcegraph/internal/trace/policy"
"github.com/sourcegraph/sourcegraph/lib/errors"
@ -282,6 +283,20 @@ func (c *Config) Load() {
c.Fireworks.StarcoderCommunitySingleTenantPercent = c.GetPercent("CODY_GATEWAY_FIREWORKS_STARCODER_COMMUNITY_SINGLE_TENANT_PERCENT", "0", "The percentage of community traffic for Starcoder to be redirected to the single-tenant deployment.")
c.Fireworks.StarcoderEnterpriseSingleTenantPercent = c.GetPercent("CODY_GATEWAY_FIREWORKS_STARCODER_ENTERPRISE_SINGLE_TENANT_PERCENT", "100", "The percentage of Enterprise traffic for Starcoder to be redirected to the single-tenant deployment.")
// Configurations for Google Gemini models.
c.Google.AccessToken = c.GetOptional("CODY_GATEWAY_GOOGLE_ACCESS_TOKEN", "The Google AI Studio access token to be used.")
c.Google.AllowedModels = splitMaybe(c.Get("CODY_GATEWAY_GOOGLE_ALLOWED_MODELS",
strings.Join([]string{
google.Gemini15FlashLatest,
google.Gemini15ProLatest,
google.GeminiProLatest,
}, ","),
"Google models that can to be used."),
)
if c.Google.AccessToken != "" && len(c.Google.AllowedModels) == 0 {
c.AddError(errors.New("must provide allowed models for Google"))
}
c.AllowedEmbeddingsModels = splitMaybe(c.Get("CODY_GATEWAY_ALLOWED_EMBEDDINGS_MODELS", strings.Join([]string{string(embeddings.ModelNameOpenAIAda), string(embeddings.ModelNameSourcegraphSTMultiQA)}, ","), "The models allowed for embeddings generation."))
if len(c.AllowedEmbeddingsModels) == 0 {
c.AddError(errors.New("must provide allowed models for embeddings generation"))
@ -325,16 +340,6 @@ func (c *Config) Load() {
c.SAMSClientConfig.ClientSecret = c.GetOptional("SAMS_CLIENT_SECRET", "SAMS OAuth client secret")
c.Environment = c.Get("CODY_GATEWAY_ENVIRONMENT", "dev", "Environment name.")
c.Google.AccessToken = c.GetOptional("CODY_GATEWAY_GOOGLE_ACCESS_TOKEN", "The Google AI Studio access token to be used.")
c.Google.AllowedModels = splitMaybe(c.Get("CODY_GATEWAY_GOOGLE_ALLOWED_MODELS",
strings.Join([]string{
"gemini-1.5-pro-latest",
"gemini-1.5-flash-latest",
}, ","),
"Google models that can to be used."),
)
}
// loadFlaggingConfig loads the common set of flagging-related environment variables for

View File

@ -38,6 +38,7 @@ go_library(
"//internal/codygateway",
"//internal/completions/client/anthropic",
"//internal/completions/client/fireworks",
"//internal/completions/client/google",
"//internal/completions/types",
"//internal/conf",
"//internal/conf/conftypes",

View File

@ -21,6 +21,7 @@ import (
"github.com/sourcegraph/sourcegraph/internal/codygateway"
"github.com/sourcegraph/sourcegraph/internal/completions/client/anthropic"
"github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks"
"github.com/sourcegraph/sourcegraph/internal/completions/client/google"
"github.com/sourcegraph/sourcegraph/internal/completions/types"
"github.com/sourcegraph/sourcegraph/internal/conf"
"github.com/sourcegraph/sourcegraph/internal/database"
@ -382,8 +383,9 @@ func allowedModels(scope types.CompletionsFeature, isProUser bool) []string {
"openai/gpt-4o",
"openai/gpt-4-turbo",
"openai/gpt-4-turbo-preview",
"google/gemini-1.5-pro-latest",
"google/gemini-1.5-flash-latest",
"google/" + google.Gemini15FlashLatest,
"google/" + google.Gemini15ProLatest,
"google/" + google.GeminiProLatest,
// Remove after the Claude 3 rollout is complete
"anthropic/claude-2",

View File

@ -22,6 +22,7 @@ go_library(
"//internal/completions/client",
"//internal/completions/client/anthropic",
"//internal/completions/client/fireworks",
"//internal/completions/client/google",
"//internal/completions/types",
"//internal/conf",
"//internal/conf/conftypes",

View File

@ -13,6 +13,7 @@ import (
"github.com/sourcegraph/sourcegraph/internal/completions/client/anthropic"
"github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks"
"github.com/sourcegraph/sourcegraph/internal/completions/client/google"
"github.com/sourcegraph/sourcegraph/internal/completions/types"
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
"github.com/sourcegraph/sourcegraph/internal/database"
@ -85,8 +86,9 @@ func isAllowedCustomChatModel(model string, isProUser bool) bool {
"openai/gpt-4o",
"openai/gpt-4-turbo",
"openai/gpt-4-turbo-preview",
"google/gemini-1.5-flash-latest",
"google/gemini-1.5-pro-latest",
"google" + google.Gemini15FlashLatest,
"google" + google.Gemini15ProLatest,
"google" + google.GeminiProLatest,
// Remove after the Claude 3 rollout is complete
"anthropic/claude-2",

View File

@ -15,6 +15,7 @@ go_library(
"//internal/completions/client/azureopenai",
"//internal/completions/client/codygateway",
"//internal/completions/client/fireworks",
"//internal/completions/client/google",
"//internal/completions/client/openai",
"//internal/completions/tokenusage",
"//internal/completions/types",

View File

@ -8,6 +8,7 @@ import (
"github.com/sourcegraph/sourcegraph/internal/completions/client/azureopenai"
"github.com/sourcegraph/sourcegraph/internal/completions/client/codygateway"
"github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks"
"github.com/sourcegraph/sourcegraph/internal/completions/client/google"
"github.com/sourcegraph/sourcegraph/internal/completions/client/openai"
"github.com/sourcegraph/sourcegraph/internal/completions/tokenusage"
"github.com/sourcegraph/sourcegraph/internal/completions/types"
@ -40,6 +41,8 @@ func getBasic(endpoint string, provider conftypes.CompletionsProviderName, acces
return openai.NewClient(httpcli.UncachedExternalDoer, endpoint, accessToken, *tokenManager), nil
case conftypes.CompletionsProviderNameAzureOpenAI:
return azureopenai.NewClient(azureopenai.GetAPIClient, endpoint, accessToken, *tokenManager)
case conftypes.CompletionsProviderNameGoogle:
return google.NewClient(httpcli.UncachedExternalDoer, endpoint, accessToken), nil
case conftypes.CompletionsProviderNameSourcegraph:
return codygateway.NewClient(httpcli.CodyGatewayDoer, endpoint, accessToken, *tokenManager)
case conftypes.CompletionsProviderNameFireworks:

View File

@ -12,6 +12,7 @@ go_library(
"//internal/codygateway",
"//internal/completions/client/anthropic",
"//internal/completions/client/fireworks",
"//internal/completions/client/google",
"//internal/completions/client/openai",
"//internal/completions/tokenusage",
"//internal/completions/types",

View File

@ -15,6 +15,7 @@ import (
"github.com/sourcegraph/sourcegraph/internal/codygateway"
"github.com/sourcegraph/sourcegraph/internal/completions/client/anthropic"
"github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks"
"github.com/sourcegraph/sourcegraph/internal/completions/client/google"
"github.com/sourcegraph/sourcegraph/internal/completions/client/openai"
"github.com/sourcegraph/sourcegraph/internal/completions/tokenusage"
"github.com/sourcegraph/sourcegraph/internal/completions/types"
@ -89,6 +90,8 @@ func (c *codyGatewayClient) clientForParams(feature types.CompletionsFeature, re
return openai.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/openai"), "", "", c.tokenizer), nil
case string(conftypes.CompletionsProviderNameFireworks):
return fireworks.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/fireworks"), "", ""), nil
case string(conftypes.CompletionsProviderNameGoogle):
return google.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/google"), "", ""), nil
case "":
return nil, errors.Newf("no provider provided in model %s - a model in the format '$PROVIDER/$MODEL_NAME' is expected", model)
default:

View File

@ -0,0 +1,38 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("//dev:go_defs.bzl", "go_test")
go_library(
name = "google",
srcs = [
"decoder.go",
"google.go",
"models.go",
"prompt.go",
"types.go",
],
importpath = "github.com/sourcegraph/sourcegraph/internal/completions/client/google",
visibility = ["//:__subpackages__"],
deps = [
"//internal/completions/types",
"//internal/httpcli",
"//lib/errors",
"@com_github_sourcegraph_log//:log",
],
)
go_test(
name = "google_test",
srcs = [
"decoder_test.go",
"google_test.go",
"prompt_test.go",
],
embed = [":google"],
deps = [
"//internal/completions/types",
"@com_github_hexops_autogold_v2//:autogold",
"@com_github_sourcegraph_log//:log",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",
],
)

View File

@ -0,0 +1,96 @@
package google
import (
"bufio"
"bytes"
"io"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
const maxPayloadSize = 10 * 1024 * 1024 // 10mb
var doneBytes = []byte("[DONE]")
// decoder decodes streaming events from a Server Sent Event stream.
type decoder struct {
scanner *bufio.Scanner
done bool
data []byte
err error
}
func NewDecoder(r io.Reader) *decoder {
scanner := bufio.NewScanner(r)
scanner.Buffer(make([]byte, 0, 4096), maxPayloadSize)
// bufio.ScanLines, except we look for \r\n\r\n which separate events.
split := func(data []byte, atEOF bool) (int, []byte, error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := bytes.Index(data, []byte("\r\n\r\n")); i >= 0 {
return i + 4, data[:i], nil
}
if i := bytes.Index(data, []byte("\n\n")); i >= 0 {
return i + 2, data[:i], nil
}
// If we're at EOF, we have a final, non-terminated event. This should
// be empty.
if atEOF {
return len(data), data, nil
}
// Request more data.
return 0, nil, nil
}
scanner.Split(split)
return &decoder{
scanner: scanner,
}
}
// Scan advances the decoder to the next event in the stream. It returns
// false when it either hits the end of the stream or an error.
func (d *decoder) Scan() bool {
if d.done {
return false
}
for d.scanner.Scan() {
line := d.scanner.Bytes()
typ, data := splitColon(line)
switch {
case bytes.Equal(typ, []byte("data")):
d.data = data
// Check for special sentinel value used by the Google API to
// indicate that the stream is done.
if bytes.Equal(data, doneBytes) {
d.done = true
return false
}
return true
default:
d.err = errors.Errorf("malformed data, expected data: %s", typ)
return false
}
}
d.err = d.scanner.Err()
return false
}
// Data returns the event data of the last decoded event
func (d *decoder) Data() []byte {
return d.data
}
// Err returns the last encountered error
func (d *decoder) Err() error {
return d.err
}
func splitColon(data []byte) ([]byte, []byte) {
i := bytes.Index(data, []byte(":"))
if i < 0 {
return bytes.TrimSpace(data), nil
}
return bytes.TrimSpace(data[:i]), bytes.TrimSpace(data[i+1:])
}

View File

@ -0,0 +1,56 @@
package google
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestDecoder(t *testing.T) {
t.Parallel()
type event struct {
data string
}
decodeAll := func(input string) ([]event, error) {
dec := NewDecoder(strings.NewReader(input))
var events []event
for dec.Scan() {
events = append(events, event{
data: string(dec.Data()),
})
}
return events, dec.Err()
}
t.Run("Single", func(t *testing.T) {
events, err := decodeAll("data:b\n\n")
require.NoError(t, err)
require.Equal(t, events, []event{{data: "b"}})
})
t.Run("Multiple", func(t *testing.T) {
events, err := decodeAll("data:b\n\ndata:c\n\ndata: [DONE]\n\n")
require.NoError(t, err)
require.Equal(t, events, []event{{data: "b"}, {data: "c"}})
})
t.Run("Multiple with new line within data", func(t *testing.T) {
events, err := decodeAll("data:line1\nline2\nline3\n\ndata:second-data\n\ndata: [DONE]\n\n")
require.NoError(t, err)
require.Equal(t, events, []event{{data: "line1\nline2\nline3"}, {data: "second-data"}})
})
t.Run("ErrExpectedData", func(t *testing.T) {
_, err := decodeAll("datas:b\n\n")
require.Contains(t, err.Error(), "malformed data, expected data")
})
t.Run("Ends after done", func(t *testing.T) {
events, err := decodeAll("data:b\n\ndata:c\n\ndata: [DONE]\n\ndata:d\n\n")
require.NoError(t, err)
require.Equal(t, events, []event{{data: "b"}, {data: "c"}})
})
}

View File

@ -0,0 +1,245 @@
package google
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/url"
"path"
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/internal/completions/types"
"github.com/sourcegraph/sourcegraph/internal/httpcli"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
func NewClient(cli httpcli.Doer, endpoint, accessToken string) types.CompletionsClient {
return &googleCompletionStreamClient{
cli: cli,
accessToken: accessToken,
endpoint: endpoint,
}
}
func (c *googleCompletionStreamClient) Complete(
ctx context.Context,
feature types.CompletionsFeature,
_ types.CompletionsVersion,
requestParams types.CompletionRequestParameters,
logger log.Logger,
) (*types.CompletionResponse, error) {
if !isSupportedFeature(feature) {
return nil, errors.Newf("feature %q is currently not supported for Google", feature)
}
var resp *http.Response
var err error
defer (func() {
if resp != nil {
resp.Body.Close()
}
})()
resp, err = c.makeRequest(ctx, requestParams, false)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var response googleResponse
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
return nil, err
}
if len(response.Candidates) == 0 {
// Empty response.
return &types.CompletionResponse{}, nil
}
if len(response.Candidates[0].Content.Parts) == 0 {
// Empty response.
return &types.CompletionResponse{}, nil
}
// NOTE: Candidates can be used to get multiple completions when CandidateCount is set,
// which is not currently supported by Cody. For now, we only return the first completion.
return &types.CompletionResponse{
Completion: response.Candidates[0].Content.Parts[0].Text,
StopReason: response.Candidates[0].FinishReason,
}, nil
}
func (c *googleCompletionStreamClient) Stream(
ctx context.Context,
feature types.CompletionsFeature,
_ types.CompletionsVersion,
requestParams types.CompletionRequestParameters,
sendEvent types.SendCompletionEvent,
logger log.Logger,
) error {
if !isSupportedFeature(feature) {
return errors.Newf("feature %q is currently not supported for Google", feature)
}
var resp *http.Response
var err error
defer (func() {
if resp != nil {
resp.Body.Close()
}
})()
resp, err = c.makeRequest(ctx, requestParams, true)
if err != nil {
return err
}
dec := NewDecoder(resp.Body)
var content string
var ev types.CompletionResponse
for dec.Scan() {
if ctx.Err() != nil && ctx.Err() == context.Canceled {
return nil
}
data := dec.Data()
// Gracefully skip over any data that isn't JSON-like.
if !bytes.HasPrefix(data, []byte("{")) {
continue
}
var event googleResponse
if err := json.Unmarshal(data, &event); err != nil {
return errors.Errorf("failed to decode event payload: %w - body: %s", err, string(data))
}
if len(event.Candidates) > 0 && len(event.Candidates[0].Content.Parts) > 0 {
content += event.Candidates[0].Content.Parts[0].Text
ev = types.CompletionResponse{
Completion: content,
StopReason: event.Candidates[0].FinishReason,
}
err = sendEvent(ev)
if err != nil {
return err
}
}
}
if dec.Err() != nil {
return dec.Err()
}
return nil
}
// makeRequest builds the request payload and calls the Gemini generateContent (or streamGenerateContent) endpoint.
func (c *googleCompletionStreamClient) makeRequest(ctx context.Context, requestParams types.CompletionRequestParameters, stream bool) (*http.Response, error) {
// Ensure TopK and TopP are non-negative
requestParams.TopK = max(0, requestParams.TopK)
requestParams.TopP = max(0, requestParams.TopP)
// Generate the prompt
prompt, err := getPrompt(requestParams.Messages)
if err != nil {
return nil, err
}
payload := googleRequest{
Contents: prompt,
GenerationConfig: googleGenerationConfig{
Temperature: requestParams.Temperature,
TopP: requestParams.TopP,
TopK: requestParams.TopK,
MaxOutputTokens: requestParams.MaxTokensToSample,
StopSequences: requestParams.StopSequences,
},
}
reqBody, err := json.Marshal(payload)
if err != nil {
return nil, err
}
apiURL := c.getAPIURL(requestParams, stream)
req, err := http.NewRequestWithContext(ctx, "POST", apiURL.String(), bytes.NewReader(reqBody))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
// Vertex AI API requires an Authorization header with the access token.
// Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini#sample-requests
if !isDefaultAPIEndpoint(apiURL) {
req.Header.Set("Authorization", "Bearer "+c.accessToken)
}
resp, err := c.cli.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, types.NewErrStatusNotOK("Google", resp)
}
return resp, nil
}
// In the latest API Docs, the model name and API key must be used with the default API endpoint URL.
// Ref: https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=rest#gemini_and_content_based_apis
func (c *googleCompletionStreamClient) getAPIURL(requestParams types.CompletionRequestParameters, stream bool) *url.URL {
apiURL, err := url.Parse(c.endpoint)
if err != nil {
apiURL = &url.URL{
Scheme: "https",
Host: defaultAPIHost,
Path: defaultAPIPath,
}
}
apiURL.Path = path.Join(apiURL.Path, requestParams.Model) + ":" + getgRPCMethod(stream)
// We need to append the API key to the default API endpoint URL.
if isDefaultAPIEndpoint(apiURL) {
query := apiURL.Query()
query.Set("key", c.accessToken)
if stream {
query.Set("alt", "sse")
}
apiURL.RawQuery = query.Encode()
}
return apiURL
}
// getgRPCMethod returns the gRPC method name based on the stream flag.
func getgRPCMethod(stream bool) string {
if stream {
return "streamGenerateContent"
}
return "generateContent"
}
// isSupportedFeature checks if the given CompletionsFeature is supported.
// Currently, only the CompletionsFeatureChat feature is supported.
func isSupportedFeature(feature types.CompletionsFeature) bool {
switch feature {
case types.CompletionsFeatureChat:
return true
default:
return false
}
}
// isDefaultAPIEndpoint checks if the given API endpoint URL is the default API endpoint.
// The default API endpoint is determined by the defaultAPIHost constant.
func isDefaultAPIEndpoint(endpoint *url.URL) bool {
return endpoint.Host == defaultAPIHost
}

View File

@ -0,0 +1,129 @@
package google
import (
"bytes"
"context"
"io"
"net/http"
"testing"
"github.com/hexops/autogold/v2"
"github.com/sourcegraph/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/sourcegraph/sourcegraph/internal/completions/types"
)
type mockDoer struct {
do func(*http.Request) (*http.Response, error)
}
func (c *mockDoer) Do(r *http.Request) (*http.Response, error) {
return c.do(r)
}
func TestErrStatusNotOK(t *testing.T) {
mockClient := NewClient(&mockDoer{
func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusTooManyRequests,
Body: io.NopCloser(bytes.NewReader([]byte("oh no, please slow down!"))),
}, nil
},
}, "", "")
t.Run("Complete", func(t *testing.T) {
logger := log.Scoped("completions")
resp, err := mockClient.Complete(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{}, logger)
require.Error(t, err)
assert.Nil(t, resp)
autogold.Expect("Google: unexpected status code 429: oh no, please slow down!").Equal(t, err.Error())
_, ok := types.IsErrStatusNotOK(err)
assert.True(t, ok)
})
t.Run("Stream", func(t *testing.T) {
logger := log.Scoped("completions")
err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{}, func(event types.CompletionResponse) error { return nil }, logger)
require.Error(t, err)
autogold.Expect("Google: unexpected status code 429: oh no, please slow down!").Equal(t, err.Error())
_, ok := types.IsErrStatusNotOK(err)
assert.True(t, ok)
})
}
func TestGetAPIURL(t *testing.T) {
t.Parallel()
client := &googleCompletionStreamClient{
endpoint: "https://generativelanguage.googleapis.com/v1/models",
accessToken: "test-token",
}
t.Run("valid v1 endpoint", func(t *testing.T) {
params := types.CompletionRequestParameters{
Model: "test-model",
}
url := client.getAPIURL(params, false).String()
expected := "https://generativelanguage.googleapis.com/v1/models/test-model:generateContent?key=test-token"
require.Equal(t, expected, url)
})
//
t.Run("valid endpoint for Vertex AI", func(t *testing.T) {
params := types.CompletionRequestParameters{
Model: "gemini-1.5-pro",
}
c := &googleCompletionStreamClient{
endpoint: "https://vertex-ai.example.com/v1/projects/PROJECT_ID/locations/LOCATION/publishers/google/models",
accessToken: "test-token",
}
url := c.getAPIURL(params, true).String()
expected := "https://vertex-ai.example.com/v1/projects/PROJECT_ID/locations/LOCATION/publishers/google/models/gemini-1.5-pro:streamGenerateContent"
require.Equal(t, expected, url)
})
t.Run("valid custom endpoint", func(t *testing.T) {
params := types.CompletionRequestParameters{
Model: "test-model",
}
c := &googleCompletionStreamClient{
endpoint: "https://example.com/api/models",
accessToken: "test-token",
}
url := c.getAPIURL(params, true).String()
expected := "https://example.com/api/models/test-model:streamGenerateContent"
require.Equal(t, expected, url)
})
t.Run("invalid endpoint", func(t *testing.T) {
client.endpoint = "://invalid"
params := types.CompletionRequestParameters{
Model: "test-model",
}
url := client.getAPIURL(params, false).String()
expected := "https://generativelanguage.googleapis.com/v1beta/models/test-model:generateContent?key=test-token"
require.Equal(t, expected, url)
})
t.Run("streaming", func(t *testing.T) {
params := types.CompletionRequestParameters{
Model: "test-model",
}
url := client.getAPIURL(params, true).String()
expected := "https://generativelanguage.googleapis.com/v1beta/models/test-model:streamGenerateContent?alt=sse&key=test-token"
require.Equal(t, expected, url)
})
t.Run("empty model", func(t *testing.T) {
params := types.CompletionRequestParameters{
Model: "",
}
url := client.getAPIURL(params, false).String()
expected := "https://generativelanguage.googleapis.com/v1beta/models:generateContent?key=test-token"
require.Equal(t, expected, url)
})
}

View File

@ -0,0 +1,21 @@
package google
// For latest available Google Gemini models,
// See: https://ai.google.dev/gemini-api/docs/models/gemini
const providerName = "google"
// Default API endpoint URL
const (
defaultAPIHost = "generativelanguage.googleapis.com"
defaultAPIPath = "/v1beta/models"
)
// Latest stable versions
const Gemini15Flash = "gemini-1.5-flash"
const Gemini15Pro = "gemini-1.5-pro"
const GeminiPro = "gemini-pro"
// Latest versions
const Gemini15FlashLatest = "gemini-1.5-flash-latest"
const Gemini15ProLatest = "gemini-1.5-pro-latest"
const GeminiProLatest = "gemini-pro-latest"

View File

@ -0,0 +1,42 @@
package google
import (
"github.com/sourcegraph/sourcegraph/internal/completions/types"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
func getPrompt(messages []types.Message) ([]googleContentMessage, error) {
googleMessages := make([]googleContentMessage, 0, len(messages))
for i, message := range messages {
var googleRole string
switch message.Speaker {
case types.SYSTEM_MESSAGE_SPEAKER:
if i != 0 {
return nil, errors.New("system role can only be used in the first message")
}
googleRole = message.Speaker
case types.ASSISTANT_MESSAGE_SPEAKER:
if i == 0 {
return nil, errors.New("assistant role cannot be used in the first message")
}
googleRole = "model"
case types.HUMAN_MESSAGE_SPEAKER:
googleRole = "user"
default:
return nil, errors.Errorf("unexpected role: %s", message.Text)
}
if message.Text == "" {
return nil, errors.New("message content cannot be empty")
}
googleMessages = append(googleMessages, googleContentMessage{
Role: googleRole,
Parts: []googleContentMessagePart{{Text: message.Text}},
})
}
return googleMessages, nil
}

View File

@ -0,0 +1,70 @@
package google
import (
"testing"
"github.com/sourcegraph/sourcegraph/internal/completions/types"
)
func TestGetPrompt(t *testing.T) {
t.Run("invalid speaker", func(t *testing.T) {
_, err := getPrompt([]types.Message{{Speaker: "invalid", Text: "hello"}})
if err == nil {
t.Errorf("expected error for invalid speaker, got nil")
}
})
t.Run("empty text", func(t *testing.T) {
_, err := getPrompt([]types.Message{{Speaker: types.HUMAN_MESSAGE_SPEAKER, Text: ""}})
if err == nil {
t.Errorf("expected error for empty text, got nil")
}
})
t.Run("multiple system messages", func(t *testing.T) {
_, err := getPrompt([]types.Message{
{Speaker: types.SYSTEM_MESSAGE_SPEAKER, Text: "system"},
{Speaker: types.HUMAN_MESSAGE_SPEAKER, Text: "hello"},
{Speaker: types.SYSTEM_MESSAGE_SPEAKER, Text: "system"},
})
if err == nil {
t.Errorf("expected error for multiple system messages, got nil")
}
})
t.Run("invalid prompt starts with assistant", func(t *testing.T) {
_, err := getPrompt([]types.Message{
{Speaker: types.ASSISTANT_MESSAGE_SPEAKER, Text: "assistant"},
{Speaker: types.HUMAN_MESSAGE_SPEAKER, Text: "hello"},
{Speaker: types.ASSISTANT_MESSAGE_SPEAKER, Text: "assistant"},
})
if err == nil {
t.Errorf("expected error for messages starts with assistant, got nil")
}
})
t.Run("valid prompt", func(t *testing.T) {
messages := []types.Message{
{Speaker: types.SYSTEM_MESSAGE_SPEAKER, Text: "system"},
{Speaker: types.HUMAN_MESSAGE_SPEAKER, Text: "hello"},
{Speaker: types.ASSISTANT_MESSAGE_SPEAKER, Text: "hi"},
}
prompt, err := getPrompt(messages)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
expected := []googleContentMessage{
{Role: "system", Parts: []googleContentMessagePart{{Text: "system"}}},
{Role: "user", Parts: []googleContentMessagePart{{Text: "hello"}}},
{Role: "model", Parts: []googleContentMessagePart{{Text: "hi"}}},
}
if len(prompt) != len(expected) {
t.Errorf("unexpected prompt length, got %d, want %d", len(prompt), len(expected))
}
for i := range prompt {
if prompt[i].Parts[0].Text != expected[i].Parts[0].Text {
t.Errorf("unexpected prompt message at index %d, got %v, want %v", i, prompt[i], expected[i])
}
}
})
}

View File

@ -0,0 +1,60 @@
package google
import "github.com/sourcegraph/sourcegraph/internal/httpcli"
type googleCompletionStreamClient struct {
cli httpcli.Doer
accessToken string
endpoint string
}
type googleRequest struct {
Contents []googleContentMessage `json:"contents"`
GenerationConfig googleGenerationConfig `json:"generationConfig,omitempty"`
SafetySettings []googleSafetySettings `json:"safetySettings,omitempty"`
}
type googleContentMessage struct {
Role string `json:"role"`
Parts []googleContentMessagePart `json:"parts"`
}
type googleContentMessagePart struct {
Text string `json:"text"`
}
// Configuration options for model generation and outputs.
// Ref: https://ai.google.dev/api/rest/v1/GenerationConfig
type googleGenerationConfig struct {
Temperature float32 `json:"temperature,omitempty"` // request.Temperature
TopP float32 `json:"topP,omitempty"` // request.TopP
TopK int `json:"topK,omitempty"` // request.TopK
StopSequences []string `json:"stopSequences,omitempty"` // request.StopSequences
MaxOutputTokens int `json:"maxOutputTokens,omitempty"` // request.MaxTokensToSample
CandidateCount int `json:"candidateCount,omitempty"` // request.CandidateCount
}
type googleResponse struct {
Model string `json:"model"`
Candidates []struct {
Content googleContentMessage
FinishReason string `json:"finishReason"`
} `json:"candidates"`
UsageMetadata googleUsage `json:"usageMetadata"`
SafetySettings []googleSafetySettings `json:"safetySettings,omitempty"`
}
// Safety setting, affecting the safety-blocking behavior.
// Ref: https://ai.google.dev/gemini-api/docs/safety-settings
type googleSafetySettings struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
type googleUsage struct {
PromptTokenCount int `json:"promptTokenCount"`
// Use the same name we use elsewhere (completion instead of candidates)
CompletionTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
}
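
To make the wire format these types produce more concrete, a chat request serialized from `googleRequest` would look roughly like the sketch below (illustrative values; zero-valued `generationConfig` fields, such as a temperature of 0, are dropped by `omitempty`):

```
{
  "contents": [
    { "role": "user", "parts": [{ "text": "Who are you?" }] }
  ],
  "generationConfig": {
    "maxOutputTokens": 30
  }
}
```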

View File

@ -28,6 +28,7 @@ go_library(
deps = [
"//internal/api/internalapi",
"//internal/completions/client/anthropic",
"//internal/completions/client/google",
"//internal/conf/confdefaults",
"//internal/conf/conftypes",
"//internal/conf/deploy",

View File

@ -10,6 +10,7 @@ import (
"github.com/hashicorp/cronexpr"
"github.com/sourcegraph/sourcegraph/internal/completions/client/anthropic"
"github.com/sourcegraph/sourcegraph/internal/completions/client/google"
"github.com/sourcegraph/sourcegraph/internal/conf/confdefaults"
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
"github.com/sourcegraph/sourcegraph/internal/conf/deploy"
@ -867,6 +868,32 @@ func GetCompletionsConfig(siteConfig schema.SiteConfiguration) (c *conftypes.Com
completionsModelRef := conftypes.NewBedrockModelRefFromModelID(completionsConfig.CompletionModel)
completionsConfig.CompletionModel = completionsModelRef.CanonicalizedModelID()
canonicalized = true
} else if completionsConfig.Provider == string(conftypes.CompletionsProviderNameGoogle) {
// If no endpoint is configured, use a default value.
if completionsConfig.Endpoint == "" {
completionsConfig.Endpoint = "https://generativelanguage.googleapis.com/v1beta/models"
}
// If no access token is set, we cannot talk to Google. Bail.
if completionsConfig.AccessToken == "" {
return nil
}
// Set a default chat model.
if completionsConfig.ChatModel == "" {
completionsConfig.ChatModel = google.Gemini15ProLatest
}
// Set a default fast chat model.
if completionsConfig.FastChatModel == "" {
completionsConfig.FastChatModel = google.Gemini15FlashLatest
}
// Set a default completions model.
if completionsConfig.CompletionModel == "" {
// Code completion is not supported by Google
completionsConfig.CompletionModel = google.Gemini15FlashLatest
}
}
// only apply canonicalization if not already applied. Not all model IDs can simply be lowercased

View File

@ -50,6 +50,7 @@ func NewCodyGatewayChatRateLimit(plan Plan, userCount *int) CodyGatewayRateLimit
"google/gemini-1.5-pro-latest",
"google/gemini-1.5-flash-latest",
"google/gemini-pro-latest",
}
switch plan {
// TODO: This is just an example for now.

View File

@ -38,6 +38,7 @@ func TestNewCodyGatewayChatRateLimit(t *testing.T) {
"openai/gpt-4-turbo-preview",
"google/gemini-1.5-pro-latest",
"google/gemini-1.5-flash-latest",
"google/gemini-pro-latest",
},
Limit: 2500,
IntervalSeconds: 60 * 60 * 24,
@ -65,6 +66,7 @@ func TestNewCodyGatewayChatRateLimit(t *testing.T) {
"openai/gpt-4-turbo-preview",
"google/gemini-1.5-pro-latest",
"google/gemini-1.5-flash-latest",
"google/gemini-pro-latest",
},
Limit: 50,
IntervalSeconds: 60 * 60 * 24,
@ -92,6 +94,7 @@ func TestNewCodyGatewayChatRateLimit(t *testing.T) {
"openai/gpt-4-turbo-preview",
"google/gemini-1.5-pro-latest",
"google/gemini-1.5-flash-latest",
"google/gemini-pro-latest",
},
Limit: 10,
IntervalSeconds: 60 * 60 * 24,

View File

@ -2890,7 +2890,7 @@
"type": "string",
"description": "The external completions provider. Defaults to 'sourcegraph'.",
"default": "sourcegraph",
"enum": ["anthropic", "openai", "sourcegraph", "azure-openai", "aws-bedrock", "fireworks"]
"enum": ["anthropic", "openai", "sourcegraph", "azure-openai", "aws-bedrock", "fireworks", "google"]
},
"endpoint": {
"type": "string",