From f2590cbb36d39b0cbda32b809f10459b45046326 Mon Sep 17 00:00:00 2001
From: Beatrix <68532117+abeatrix@users.noreply.github.com>
Date: Tue, 4 Jun 2024 16:46:36 -0700
Subject: [PATCH] Cody Gateway: Add Gemini models to PLG and Enterprise users
 (#63053)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

CLOSE https://github.com/sourcegraph/cody-issues/issues/211 & https://github.com/sourcegraph/cody-issues/issues/412

UNBLOCK https://github.com/sourcegraph/cody/pull/4360

* Add support for Google Gemini AI models as a chat completions provider
* Add a new `google` package that implements the Google Generative AI client
* Update `client.go` and `codygateway.go` to handle the new Google provider
* Set default models for chat, fast chat, and completions when Google is the configured provider
* Add `gemini-pro` to the allowed model list

## Test plan

For Enterprise instances using Google as the provider:

1. In your Sourcegraph local instance's Site Config, add the following inside the top-level `"completions"` object:

```
"accessToken": "REDACTED",
"chatModel": "gemini-1.5-pro-latest",
"provider": "google",
```

Note: You can get the accessToken for the Gemini API in 1Password.

2. After saving the site config with the above change, run the following curl command (the authorization header uses double quotes so that `$LOCAL_INSTANCE_TOKEN` is expanded by the shell):

```
curl 'https://sourcegraph.test:3443/.api/completions/stream' -i \
  -X POST \
  -H "authorization: token $LOCAL_INSTANCE_TOKEN" \
  --data-raw '{"messages":[{"speaker":"human","text":"Who are you?"}],"maxTokensToSample":30,"temperature":0,"stopSequences":[],"timeoutMs":5000,"stream":true,"model":"gemini-1.5-pro-latest"}'
```

3. Expected Output:

```
❯ curl 'https://sourcegraph.test:3443/.api/completions/stream' -i \
  -X POST \
  -H 'authorization: token ' \
  --data-raw '{"messages":[{"speaker":"human","text":"Who are you?"}],"maxTokensToSample":30,"temperature":0,"stopSequences":[],"timeoutMs":5000,"stream":true,"model":"gemini-1.5-pro-latest"}'
HTTP/2 200
access-control-allow-credentials: true
access-control-allow-origin:
alt-svc: h3=":3443"; ma=2592000
cache-control: no-cache
content-type: text/event-stream
date: Tue, 04 Jun 2024 05:45:33 GMT
server: Caddy
server: Caddy
vary: Accept-Encoding, Authorization, Cookie, Authorization, X-Requested-With, Cookie
x-accel-buffering: no
x-content-type-options: nosniff
x-frame-options: DENY
x-powered-by: Express
x-trace: d4b1f02a3e2882a3d52331335d217b03
x-trace-span: 728ec33860d3b5e6
x-trace-url: https://sourcegraph.test:3443/-/debug/jaeger/trace/d4b1f02a3e2882a3d52331335d217b03
x-xss-protection: 1; mode=block

event: completion
data: {"completion":"I","stopReason":"STOP"}

event: completion
data: {"completion":"I am a large language model, trained by Google. \n\nThink of me as","stopReason":"STOP"}

event: completion
data: {"completion":"I am a large language model, trained by Google. \n\nThink of me as a computer program that can understand and generate human-like text.","stopReason":"MAX_TOKENS"}

event: done
data: {}
```

Verified locally:

![image](https://github.com/sourcegraph/sourcegraph/assets/68532117/2e6c914d-7a77-4484-b693-16bbc394518c)

#### Before

Cody Gateway returns `no client known for upstream provider google`:

```sh
curl -X 'POST' -d '{"messages":[{"speaker":"human","text":"Who are you?"}],"maxTokensToSample":30,"temperature":0,"stopSequences":[],"timeoutMs":5000,"stream":true,"model":"google/gemini-1.5-pro-latest"}' -H 'Accept: application/json' -H "Authorization: token $YOUR_DOTCOM_TOKEN" -H 'Content-Type: application/json' 'https://sourcegraph.com/.api/completions/stream'

event: error
data: {"error":"no client known for upstream provider google"}

event: done
data: {
```

## Changelog

Added support for Google as an LLM provider for Cody, with the following models available through Cody Gateway: Gemini Pro (`gemini-pro-latest`), Gemini 1.5 Flash (`gemini-1.5-flash-latest`), and Gemini 1.5 Pro (`gemini-1.5-pro-latest`).
---
 CHANGELOG.md                                     |   1 +
 .../productSubscriptions/ModelBadges.tsx         |   5 +-
 .../web/src/site-admin/SiteAdminPingsPage.tsx    |   2 +-
 .../internal/httpapi/completions/google.go       |  81 +++---
 .../httpapi/completions/google_test.go           | 103 ++++++++
 cmd/cody-gateway/shared/config/BUILD.bazel       |   1 +
 cmd/cody-gateway/shared/config/config.go         |  25 +-
 .../dotcom/productsubscription/BUILD.bazel       |   1 +
 .../codygateway_dotcom_user.go                   |   6 +-
 .../internal/httpapi/completions/BUILD.bazel     |   1 +
 .../internal/httpapi/completions/chat.go         |   6 +-
 internal/completions/client/BUILD.bazel          |   1 +
 internal/completions/client/client.go            |   3 +
 .../client/codygateway/BUILD.bazel               |   1 +
 .../client/codygateway/codygateway.go            |   3 +
 .../completions/client/google/BUILD.bazel        |  38 +++
 internal/completions/client/google/decoder.go    |  96 +++++++
 .../completions/client/google/decoder_test.go    |  56 ++++
 internal/completions/client/google/google.go     | 245 ++++++++++++++++++
 .../completions/client/google/google_test.go     | 129 +++++++++
 internal/completions/client/google/models.go     |  21 ++
 internal/completions/client/google/prompt.go     |  42 +++
 .../completions/client/google/prompt_test.go     |  70 +++++
 internal/completions/client/google/types.go      |  60 +++++
 internal/conf/BUILD.bazel                        |   1 +
 internal/conf/computed.go                        |  27 ++
 internal/licensing/codygateway.go                |   1 +
 internal/licensing/codygateway_test.go           |   3 +
 schema/site.schema.json                          |   2 +-
 29 files changed, 975 insertions(+), 56 deletions(-)
 create mode 100644 internal/completions/client/google/BUILD.bazel
 create mode 100644 internal/completions/client/google/decoder.go
 create mode 100644 internal/completions/client/google/decoder_test.go
 create mode 100644 internal/completions/client/google/google.go
 create mode 100644 internal/completions/client/google/google_test.go
 create mode 100644 internal/completions/client/google/models.go
 create mode 100644 internal/completions/client/google/prompt.go
 create mode 100644 internal/completions/client/google/prompt_test.go
 create mode 100644 internal/completions/client/google/types.go

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 57f9407ee0f..e0b2c73fc1e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,6 +22,7 @@ All notable changes to Sourcegraph are documented in this file.
 - A feature flag for Cody, `completions.smartContextWindow` is added and set to "enabled" by default. It allows clients to adjust the context window based on the name of the chat model.
When smartContextWindow is enabled, the `completions.chatModelMaxTokens` value is ignored. ([#62802](https://github.com/sourcegraph/sourcegraph/pull/62802)) - Code Insights: When facing the "incomplete datapoints" warning, you can now use GraphQL to discover which repositories had problems. The schemas for `TimeoutDatapointAlert` and `GenericIncompleteDatapointAlert` now contain an additional `repositories` field. ([#62756](https://github.com/sourcegraph/sourcegraph/pull/62756)). - Users will now be presented with a modal that reminds them to connect any external code host accounts that's required for permissions. Without these accounts connected, users may be unable to view repositories that they otherwise have access to. [#62983](https://github.com/sourcegraph/sourcegraph/pull/62983) +- Added support for Google as an LLM provider for Cody, with the following models available through Cody Gateway: Gemini Pro (`gemini-pro-latest`), Gemini 1.5 Flash (`gemini-1.5-flash-latest`), and Gemini 1.5 Pro (`gemini-1.5-pro-latest`). [#63053](https://github.com/sourcegraph/sourcegraph/pull/63053) ### Changed diff --git a/client/web/src/enterprise/site-admin/dotcom/productSubscriptions/ModelBadges.tsx b/client/web/src/enterprise/site-admin/dotcom/productSubscriptions/ModelBadges.tsx index b8c50e41d35..3d672643ddc 100644 --- a/client/web/src/enterprise/site-admin/dotcom/productSubscriptions/ModelBadges.tsx +++ b/client/web/src/enterprise/site-admin/dotcom/productSubscriptions/ModelBadges.tsx @@ -1,4 +1,4 @@ -import React from 'react' +import type React from 'react' import { Badge } from '@sourcegraph/wildcard' @@ -51,8 +51,11 @@ function modelBadgeVariant(model: string, mode: 'completions' | 'embeddings'): ' case 'openai/gpt-4o': case 'openai/gpt-4-turbo': case 'openai/gpt-4-turbo-preview': + // For currently available Google Gemini models, + // see: https://ai.google.dev/gemini-api/docs/models/gemini case 'google/gemini-1.5-flash-latest': case 'google/gemini-1.5-pro-latest': + case 'google/gemini-pro-latest': // Virtual models that are translated by Cody Gateway and allow access to all StarCoder // models hosted for us by Fireworks. case 'fireworks/starcoder': diff --git a/client/web/src/site-admin/SiteAdminPingsPage.tsx b/client/web/src/site-admin/SiteAdminPingsPage.tsx index 8e947994246..01af9523263 100644 --- a/client/web/src/site-admin/SiteAdminPingsPage.tsx +++ b/client/web/src/site-admin/SiteAdminPingsPage.tsx @@ -421,7 +421,7 @@ export const SiteAdminPingsPage: React.FunctionComponent
   • Provider (e.g., "sourcegraph", "anthropic", "openai", "azure-openai",
-  "fireworks", "aws-bedrock", etc.)
+  "fireworks", "aws-bedrock", "google", etc.)
  • Chat model (included only for "sourcegraph" provider)
  • Fast chat model (included only for "sourcegraph" provider)
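The next file in the diff, `cmd/cody-gateway/internal/httpapi/completions/google.go`, models Google's `generateContent` request body. For orientation, here is a minimal, illustrative sketch of the JSON the gateway sends upstream — the struct shapes are trimmed copies of the ones added below, and the concrete values are made up:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed copies of the gateway structs introduced in the diff below,
// reproduced here only to illustrate the wire format.
type googleContentMessage struct {
	Role  string `json:"role"`
	Parts []struct {
		Text string `json:"text"`
	} `json:"parts"`
}

type googleGenerationConfig struct {
	Temperature     float32  `json:"temperature,omitempty"`
	StopSequences   []string `json:"stopSequences,omitempty"`
	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"`
}

type googleRequest struct {
	Model            string                 `json:"model"`
	Contents         []googleContentMessage `json:"contents"`
	GenerationConfig googleGenerationConfig `json:"generationConfig,omitempty"`
}

func main() {
	req := googleRequest{
		Model: "gemini-1.5-pro-latest",
		Contents: []googleContentMessage{{
			Role: "user",
			Parts: []struct {
				Text string `json:"text"`
			}{{Text: "Who are you?"}},
		}},
		GenerationConfig: googleGenerationConfig{MaxOutputTokens: 30},
	}
	b, _ := json.Marshal(req)
	fmt.Println(string(b))
	// {"model":"gemini-1.5-pro-latest","contents":[{"role":"user","parts":[{"text":"Who are you?"}]}],"generationConfig":{"maxOutputTokens":30}}
}
```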
  • diff --git a/cmd/cody-gateway/internal/httpapi/completions/google.go b/cmd/cody-gateway/internal/httpapi/completions/google.go index 55c9e13a8e9..17a493dca91 100644 --- a/cmd/cody-gateway/internal/httpapi/completions/google.go +++ b/cmd/cody-gateway/internal/httpapi/completions/google.go @@ -38,6 +38,15 @@ func NewGoogleHandler(baseLogger log.Logger, eventLogger events.Logger, rs limit ) } +// The request body for Google completions. +// Ref: https://ai.google.dev/api/rest/v1/models/generateContent#request-body +type googleRequest struct { + Model string `json:"model"` + Contents []googleContentMessage `json:"contents"` + GenerationConfig googleGenerationConfig `json:"generationConfig,omitempty"` + SafetySettings []googleSafetySettings `json:"safetySettings,omitempty"` +} + type googleContentMessage struct { Role string `json:"role"` Parts []struct { @@ -45,12 +54,22 @@ type googleContentMessage struct { } `json:"parts"` } -type googleRequest struct { - Model string `json:"model"` - Contents []googleContentMessage `json:"contents"` - GenerationConfig struct { - MaxOutputTokens int `json:"maxOutputTokens,omitempty"` - } `json:"generationConfig,omitempty"` +// Configuration options for model generation and outputs. +// Ref: https://ai.google.dev/api/rest/v1/GenerationConfig +type googleGenerationConfig struct { + Temperature float32 `json:"temperature,omitempty"` // request.Temperature + TopP float32 `json:"topP,omitempty"` // request.TopP + TopK int `json:"topK,omitempty"` // request.TopK + StopSequences []string `json:"stopSequences,omitempty"` // request.StopSequences + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` // request.MaxTokensToSample + CandidateCount int `json:"candidateCount,omitempty"` // request.CandidateCount +} + +// Safety setting, affecting the safety-blocking behavior. +// Ref: https://ai.google.dev/gemini-api/docs/safety-settings +type googleSafetySettings struct { + Category string `json:"category"` + Threshold string `json:"threshold"` } func (r googleRequest) ShouldStream() bool { @@ -152,11 +171,8 @@ func (*GoogleHandlerMethods) parseResponseAndUsage(logger log.Logger, reqBody go } // Otherwise, we have to parse the event stream. - promptUsage.tokens = -1 - promptUsage.tokenizerTokens = -1 - completionUsage.tokens = -1 - completionUsage.tokenizerTokens = -1 - + promptUsage.tokens, completionUsage.tokens = -1, -1 + promptUsage.tokenizerTokens, completionUsage.tokenizerTokens = -1, -1 promptTokens, completionTokens, err := parseGoogleTokenUsage(r, logger) if err != nil { logger.Error("failed to decode Google streaming response", log.Error(err)) @@ -173,35 +189,22 @@ const maxPayloadSize = 10 * 1024 * 1024 // 10mb func parseGoogleTokenUsage(r io.Reader, logger log.Logger) (promptTokens int, completionTokens int, err error) { scanner := bufio.NewScanner(r) scanner.Buffer(make([]byte, 0, 4096), maxPayloadSize) - // bufio.ScanLines, except we look for \r\n\r\n which separate events. - split := func(data []byte, atEOF bool) (int, []byte, error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := bytes.Index(data, []byte("\r\n\r\n")); i >= 0 { - return i + 4, data[:i], nil - } - if i := bytes.Index(data, []byte("\n\n")); i >= 0 { - return i + 2, data[:i], nil - } - // If we're at EOF, we have a final, non-terminated event. This should - // be empty. - if atEOF { - return len(data), data, nil - } - // Request more data. 
- return 0, nil, nil - } - scanner.Split(split) - // skip to the last event - var lastEvent []byte + scanner.Split(bufio.ScanLines) + + var lastLine []byte for scanner.Scan() { - lastEvent = scanner.Bytes() + lastLine = scanner.Bytes() } - var res googleResponse - if err := json.NewDecoder(bytes.NewReader(lastEvent[5:])).Decode(&res); err != nil { - logger.Error("failed to parse Google response as JSON", log.Error(err)) - return -1, -1, err + + if bytes.HasPrefix(bytes.TrimSpace(lastLine), []byte("data: ")) { + event := lastLine[5:] + var res googleResponse + if err := json.NewDecoder(bytes.NewReader(event)).Decode(&res); err != nil { + logger.Error("failed to parse Google response as JSON", log.Error(err)) + return -1, -1, err + } + return res.UsageMetadata.PromptTokenCount, res.UsageMetadata.CompletionTokenCount, nil } - return res.UsageMetadata.PromptTokenCount, res.UsageMetadata.CompletionTokenCount, nil + + return -1, -1, errors.New("no Google response found") } diff --git a/cmd/cody-gateway/internal/httpapi/completions/google_test.go b/cmd/cody-gateway/internal/httpapi/completions/google_test.go index 9d40cae413f..80512e531d3 100644 --- a/cmd/cody-gateway/internal/httpapi/completions/google_test.go +++ b/cmd/cody-gateway/internal/httpapi/completions/google_test.go @@ -4,6 +4,8 @@ import ( "strings" "testing" + "bytes" + "github.com/sourcegraph/log/logtest" "github.com/stretchr/testify/assert" ) @@ -31,3 +33,104 @@ data: {"candidates": [{"content": {"parts": [{"text": "\n for i in range(n-1):\ data: {"candidates": [{"content": {"parts": [{"text": " range(n-i-1):\n if list1[j] \u003e list1[j+1]:\n list1[j], list"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 63,"totalTokenCount": 84}} data: {"candidates": [{"content": {"parts": [{"text": "1[j+1] = list1[j+1], list1[j]\n return list1\n"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}],"citationMetadata": {"citationSources": [{"startIndex": 1,"endIndex": 185,"uri": "https://github.com/Feng080412/Searches-and-sorts","license": ""}]}}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 87,"totalTokenCount": 108}}` + +func TestParseGoogleTokenUsage(t *testing.T) { + tests := []struct { + name string + input string + want *googleResponse + wantErr bool + }{ + { + name: "valid response", + input: `data: {"candidates": [{"content": {"parts": [{"text": "def"}],"role": "model"},"finishReason": "STOP","index": 0}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 1,"totalTokenCount": 22}}`, + want: &googleResponse{ + UsageMetadata: googleUsage{ + PromptTokenCount: 21, + CompletionTokenCount: 1, + TotalTokenCount: 0, + }, + }, + wantErr: false, + }, + { + name: "valid response - with candidates", + input: `data: {"usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 20}}`, + want: &googleResponse{ + UsageMetadata: 
googleUsage{ + PromptTokenCount: 10, + CompletionTokenCount: 20, + TotalTokenCount: 0, + }, + }, + wantErr: false, + }, + { + name: "invalid JSON", + input: `data: {"usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 20}`, + want: nil, + wantErr: true, + }, + { + name: "no prefix", + input: `{"usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 20}}`, + want: nil, + wantErr: true, + }, + { + name: "empty input", + input: ``, + want: nil, + wantErr: true, + }, + { + name: "multiple lines with one valid", + input: `data: {"usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 15}} + +data: {"usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 20}}`, + want: &googleResponse{ + UsageMetadata: googleUsage{ + PromptTokenCount: 10, + CompletionTokenCount: 20, + TotalTokenCount: 0, + }, + }, + wantErr: false, + }, + { + name: "non-JSON data", + input: `data: not-a-json`, + want: nil, + wantErr: true, + }, + { + name: "partial data", + input: `data: {"usageMetadata": {"promptTokenCount": 10`, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := bytes.NewReader([]byte(tt.input)) + logger := logtest.Scoped(t) + promptTokens, completionTokens, err := parseGoogleTokenUsage(r, logger) + if (err != nil) != tt.wantErr { + t.Errorf("parseGoogleTokenUsage() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.want != nil { + got := &googleResponse{ + UsageMetadata: googleUsage{ + PromptTokenCount: promptTokens, + CompletionTokenCount: completionTokens, + }, + } + if !assert.ObjectsAreEqual(got, tt.want) { + t.Errorf("parseGoogleTokenUsage() mismatch (-want +got):\n%v", assert.ObjectsAreEqual(got, tt.want)) + } + } + }) + } +} diff --git a/cmd/cody-gateway/shared/config/BUILD.bazel b/cmd/cody-gateway/shared/config/BUILD.bazel index 4321101e97c..850c4ce05af 100644 --- a/cmd/cody-gateway/shared/config/BUILD.bazel +++ b/cmd/cody-gateway/shared/config/BUILD.bazel @@ -12,6 +12,7 @@ go_library( "//internal/collections", "//internal/completions/client/anthropic", "//internal/completions/client/fireworks", + "//internal/completions/client/google", "//internal/env", "//internal/trace/policy", "//lib/errors", diff --git a/cmd/cody-gateway/shared/config/config.go b/cmd/cody-gateway/shared/config/config.go index 89b9ac3b391..0ab12098034 100644 --- a/cmd/cody-gateway/shared/config/config.go +++ b/cmd/cody-gateway/shared/config/config.go @@ -12,6 +12,7 @@ import ( "github.com/sourcegraph/sourcegraph/internal/collections" "github.com/sourcegraph/sourcegraph/internal/completions/client/anthropic" "github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks" + "github.com/sourcegraph/sourcegraph/internal/completions/client/google" "github.com/sourcegraph/sourcegraph/internal/env" "github.com/sourcegraph/sourcegraph/internal/trace/policy" "github.com/sourcegraph/sourcegraph/lib/errors" @@ -282,6 +283,20 @@ func (c *Config) Load() { c.Fireworks.StarcoderCommunitySingleTenantPercent = c.GetPercent("CODY_GATEWAY_FIREWORKS_STARCODER_COMMUNITY_SINGLE_TENANT_PERCENT", "0", "The percentage of community traffic for Starcoder to be redirected to the single-tenant deployment.") c.Fireworks.StarcoderEnterpriseSingleTenantPercent = c.GetPercent("CODY_GATEWAY_FIREWORKS_STARCODER_ENTERPRISE_SINGLE_TENANT_PERCENT", "100", "The percentage of Enterprise traffic for Starcoder to be redirected to the single-tenant deployment.") + // Configurations for Google Gemini models. 
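+ // As with the other provider settings above, the access token and the comma-separated model allow-list come from the environment and are validated together below.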
+	c.Google.AccessToken = c.GetOptional("CODY_GATEWAY_GOOGLE_ACCESS_TOKEN", "The Google AI Studio access token to be used.")
+	c.Google.AllowedModels = splitMaybe(c.Get("CODY_GATEWAY_GOOGLE_ALLOWED_MODELS",
+		strings.Join([]string{
+			google.Gemini15FlashLatest,
+			google.Gemini15ProLatest,
+			google.GeminiProLatest,
+		}, ","),
+		"Google models that can be used."),
+	)
+	if c.Google.AccessToken != "" && len(c.Google.AllowedModels) == 0 {
+		c.AddError(errors.New("must provide allowed models for Google"))
+	}
+
 	c.AllowedEmbeddingsModels = splitMaybe(c.Get("CODY_GATEWAY_ALLOWED_EMBEDDINGS_MODELS", strings.Join([]string{string(embeddings.ModelNameOpenAIAda), string(embeddings.ModelNameSourcegraphSTMultiQA)}, ","), "The models allowed for embeddings generation."))
 	if len(c.AllowedEmbeddingsModels) == 0 {
 		c.AddError(errors.New("must provide allowed models for embeddings generation"))
@@ -325,16 +340,6 @@ func (c *Config) Load() {
 	c.SAMSClientConfig.ClientSecret = c.GetOptional("SAMS_CLIENT_SECRET", "SAMS OAuth client secret")
 
 	c.Environment = c.Get("CODY_GATEWAY_ENVIRONMENT", "dev", "Environment name.")
-
-	c.Google.AccessToken = c.GetOptional("CODY_GATEWAY_GOOGLE_ACCESS_TOKEN", "The Google AI Studio access token to be used.")
-	c.Google.AllowedModels = splitMaybe(c.Get("CODY_GATEWAY_GOOGLE_ALLOWED_MODELS",
-		strings.Join([]string{
-			"gemini-1.5-pro-latest",
-			"gemini-1.5-flash-latest",
-		}, ","),
-		"Google models that can to be used."),
-	)
-
 }
 
 // loadFlaggingConfig loads the common set of flagging-related environment variables for
diff --git a/cmd/frontend/internal/dotcom/productsubscription/BUILD.bazel b/cmd/frontend/internal/dotcom/productsubscription/BUILD.bazel
index 8115a940cc4..bd85643e76f 100644
--- a/cmd/frontend/internal/dotcom/productsubscription/BUILD.bazel
+++ b/cmd/frontend/internal/dotcom/productsubscription/BUILD.bazel
@@ -38,6 +38,7 @@ go_library(
         "//internal/codygateway",
         "//internal/completions/client/anthropic",
         "//internal/completions/client/fireworks",
+        "//internal/completions/client/google",
         "//internal/completions/types",
         "//internal/conf",
         "//internal/conf/conftypes",
diff --git a/cmd/frontend/internal/dotcom/productsubscription/codygateway_dotcom_user.go b/cmd/frontend/internal/dotcom/productsubscription/codygateway_dotcom_user.go
index 2a3aa7466d4..3715338b8f9 100644
--- a/cmd/frontend/internal/dotcom/productsubscription/codygateway_dotcom_user.go
+++ b/cmd/frontend/internal/dotcom/productsubscription/codygateway_dotcom_user.go
@@ -21,6 +21,7 @@ import (
 	"github.com/sourcegraph/sourcegraph/internal/codygateway"
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/anthropic"
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks"
+	"github.com/sourcegraph/sourcegraph/internal/completions/client/google"
 	"github.com/sourcegraph/sourcegraph/internal/completions/types"
 	"github.com/sourcegraph/sourcegraph/internal/conf"
 	"github.com/sourcegraph/sourcegraph/internal/database"
@@ -382,8 +383,9 @@ func allowedModels(scope types.CompletionsFeature, isProUser bool) []string {
 			"openai/gpt-4o",
 			"openai/gpt-4-turbo",
 			"openai/gpt-4-turbo-preview",
-			"google/gemini-1.5-pro-latest",
-			"google/gemini-1.5-flash-latest",
+			"google/" + google.Gemini15FlashLatest,
+			"google/" + google.Gemini15ProLatest,
+			"google/" + google.GeminiProLatest,
 
 			// Remove after the Claude 3 rollout is complete
 			"anthropic/claude-2",
diff --git a/cmd/frontend/internal/httpapi/completions/BUILD.bazel b/cmd/frontend/internal/httpapi/completions/BUILD.bazel
index 16aed82c7d6..64b2d73628a 100644
--- a/cmd/frontend/internal/httpapi/completions/BUILD.bazel
+++ b/cmd/frontend/internal/httpapi/completions/BUILD.bazel
@@ -22,6 +22,7 @@ go_library(
         "//internal/completions/client",
         "//internal/completions/client/anthropic",
         "//internal/completions/client/fireworks",
+        "//internal/completions/client/google",
         "//internal/completions/types",
         "//internal/conf",
         "//internal/conf/conftypes",
diff --git a/cmd/frontend/internal/httpapi/completions/chat.go b/cmd/frontend/internal/httpapi/completions/chat.go
index 397f062cc94..6f07738d5e0 100644
--- a/cmd/frontend/internal/httpapi/completions/chat.go
+++ b/cmd/frontend/internal/httpapi/completions/chat.go
@@ -13,6 +13,7 @@ import (
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/anthropic"
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks"
+	"github.com/sourcegraph/sourcegraph/internal/completions/client/google"
 	"github.com/sourcegraph/sourcegraph/internal/completions/types"
 	"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
 	"github.com/sourcegraph/sourcegraph/internal/database"
@@ -85,8 +86,9 @@ func isAllowedCustomChatModel(model string, isProUser bool) bool {
 		"openai/gpt-4o",
 		"openai/gpt-4-turbo",
 		"openai/gpt-4-turbo-preview",
-		"google/gemini-1.5-flash-latest",
-		"google/gemini-1.5-pro-latest",
+		"google/" + google.Gemini15FlashLatest,
+		"google/" + google.Gemini15ProLatest,
+		"google/" + google.GeminiProLatest,
 
 		// Remove after the Claude 3 rollout is complete
 		"anthropic/claude-2",
diff --git a/internal/completions/client/BUILD.bazel b/internal/completions/client/BUILD.bazel
index 0a7351f26ec..14964739228 100644
--- a/internal/completions/client/BUILD.bazel
+++ b/internal/completions/client/BUILD.bazel
@@ -15,6 +15,7 @@ go_library(
         "//internal/completions/client/azureopenai",
         "//internal/completions/client/codygateway",
         "//internal/completions/client/fireworks",
+        "//internal/completions/client/google",
         "//internal/completions/client/openai",
         "//internal/completions/tokenusage",
         "//internal/completions/types",
diff --git a/internal/completions/client/client.go b/internal/completions/client/client.go
index fdb684cce08..9871951a439 100644
--- a/internal/completions/client/client.go
+++ b/internal/completions/client/client.go
@@ -8,6 +8,7 @@ import (
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/azureopenai"
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/codygateway"
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks"
+	"github.com/sourcegraph/sourcegraph/internal/completions/client/google"
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/openai"
 	"github.com/sourcegraph/sourcegraph/internal/completions/tokenusage"
 	"github.com/sourcegraph/sourcegraph/internal/completions/types"
@@ -40,6 +41,8 @@ func getBasic(endpoint string, provider conftypes.CompletionsProviderName, acces
 		return openai.NewClient(httpcli.UncachedExternalDoer, endpoint, accessToken, *tokenManager), nil
 	case conftypes.CompletionsProviderNameAzureOpenAI:
 		return azureopenai.NewClient(azureopenai.GetAPIClient, endpoint, accessToken, *tokenManager)
+	case conftypes.CompletionsProviderNameGoogle:
+		return google.NewClient(httpcli.UncachedExternalDoer, endpoint, accessToken), nil
 	case conftypes.CompletionsProviderNameSourcegraph:
 		return codygateway.NewClient(httpcli.CodyGatewayDoer, endpoint, accessToken, *tokenManager)
 	case conftypes.CompletionsProviderNameFireworks:
diff --git a/internal/completions/client/codygateway/BUILD.bazel
b/internal/completions/client/codygateway/BUILD.bazel index 826e547cb64..9593f03aa22 100644 --- a/internal/completions/client/codygateway/BUILD.bazel +++ b/internal/completions/client/codygateway/BUILD.bazel @@ -12,6 +12,7 @@ go_library( "//internal/codygateway", "//internal/completions/client/anthropic", "//internal/completions/client/fireworks", + "//internal/completions/client/google", "//internal/completions/client/openai", "//internal/completions/tokenusage", "//internal/completions/types", diff --git a/internal/completions/client/codygateway/codygateway.go b/internal/completions/client/codygateway/codygateway.go index d4769d1a0ee..5152dd39d2d 100644 --- a/internal/completions/client/codygateway/codygateway.go +++ b/internal/completions/client/codygateway/codygateway.go @@ -15,6 +15,7 @@ import ( "github.com/sourcegraph/sourcegraph/internal/codygateway" "github.com/sourcegraph/sourcegraph/internal/completions/client/anthropic" "github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks" + "github.com/sourcegraph/sourcegraph/internal/completions/client/google" "github.com/sourcegraph/sourcegraph/internal/completions/client/openai" "github.com/sourcegraph/sourcegraph/internal/completions/tokenusage" "github.com/sourcegraph/sourcegraph/internal/completions/types" @@ -89,6 +90,8 @@ func (c *codyGatewayClient) clientForParams(feature types.CompletionsFeature, re return openai.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/openai"), "", "", c.tokenizer), nil case string(conftypes.CompletionsProviderNameFireworks): return fireworks.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/fireworks"), "", ""), nil + case string(conftypes.CompletionsProviderNameGoogle): + return google.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/google"), "", ""), nil case "": return nil, errors.Newf("no provider provided in model %s - a model in the format '$PROVIDER/$MODEL_NAME' is expected", model) default: diff --git a/internal/completions/client/google/BUILD.bazel b/internal/completions/client/google/BUILD.bazel new file mode 100644 index 00000000000..37a8d604908 --- /dev/null +++ b/internal/completions/client/google/BUILD.bazel @@ -0,0 +1,38 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//dev:go_defs.bzl", "go_test") + +go_library( + name = "google", + srcs = [ + "decoder.go", + "google.go", + "models.go", + "prompt.go", + "types.go", + ], + importpath = "github.com/sourcegraph/sourcegraph/internal/completions/client/google", + visibility = ["//:__subpackages__"], + deps = [ + "//internal/completions/types", + "//internal/httpcli", + "//lib/errors", + "@com_github_sourcegraph_log//:log", + ], +) + +go_test( + name = "google_test", + srcs = [ + "decoder_test.go", + "google_test.go", + "prompt_test.go", + ], + embed = [":google"], + deps = [ + "//internal/completions/types", + "@com_github_hexops_autogold_v2//:autogold", + "@com_github_sourcegraph_log//:log", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + ], +) diff --git a/internal/completions/client/google/decoder.go b/internal/completions/client/google/decoder.go new file mode 100644 index 00000000000..65b327ede5d --- /dev/null +++ b/internal/completions/client/google/decoder.go @@ -0,0 +1,96 @@ +package google + +import ( + "bufio" + "bytes" + "io" + + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +const maxPayloadSize = 10 * 1024 * 1024 // 10mb + +var 
doneBytes = []byte("[DONE]") + +// decoder decodes streaming events from a Server Sent Event stream. +type decoder struct { + scanner *bufio.Scanner + done bool + data []byte + err error +} + +func NewDecoder(r io.Reader) *decoder { + scanner := bufio.NewScanner(r) + scanner.Buffer(make([]byte, 0, 4096), maxPayloadSize) + // bufio.ScanLines, except we look for \r\n\r\n which separate events. + split := func(data []byte, atEOF bool) (int, []byte, error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := bytes.Index(data, []byte("\r\n\r\n")); i >= 0 { + return i + 4, data[:i], nil + } + if i := bytes.Index(data, []byte("\n\n")); i >= 0 { + return i + 2, data[:i], nil + } + // If we're at EOF, we have a final, non-terminated event. This should + // be empty. + if atEOF { + return len(data), data, nil + } + // Request more data. + return 0, nil, nil + } + scanner.Split(split) + return &decoder{ + scanner: scanner, + } +} + +// Scan advances the decoder to the next event in the stream. It returns +// false when it either hits the end of the stream or an error. +func (d *decoder) Scan() bool { + if d.done { + return false + } + for d.scanner.Scan() { + line := d.scanner.Bytes() + typ, data := splitColon(line) + switch { + case bytes.Equal(typ, []byte("data")): + d.data = data + // Check for special sentinel value used by the Google API to + // indicate that the stream is done. + if bytes.Equal(data, doneBytes) { + d.done = true + return false + } + return true + default: + d.err = errors.Errorf("malformed data, expected data: %s", typ) + return false + } + } + + d.err = d.scanner.Err() + return false +} + +// Event returns the event data of the last decoded event +func (d *decoder) Data() []byte { + return d.data +} + +// Err returns the last encountered error +func (d *decoder) Err() error { + return d.err +} + +func splitColon(data []byte) ([]byte, []byte) { + i := bytes.Index(data, []byte(":")) + if i < 0 { + return bytes.TrimSpace(data), nil + } + return bytes.TrimSpace(data[:i]), bytes.TrimSpace(data[i+1:]) +} diff --git a/internal/completions/client/google/decoder_test.go b/internal/completions/client/google/decoder_test.go new file mode 100644 index 00000000000..2267f7010a0 --- /dev/null +++ b/internal/completions/client/google/decoder_test.go @@ -0,0 +1,56 @@ +package google + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDecoder(t *testing.T) { + t.Parallel() + + type event struct { + data string + } + + decodeAll := func(input string) ([]event, error) { + dec := NewDecoder(strings.NewReader(input)) + var events []event + for dec.Scan() { + events = append(events, event{ + data: string(dec.Data()), + }) + } + return events, dec.Err() + } + + t.Run("Single", func(t *testing.T) { + events, err := decodeAll("data:b\n\n") + require.NoError(t, err) + require.Equal(t, events, []event{{data: "b"}}) + }) + + t.Run("Multiple", func(t *testing.T) { + events, err := decodeAll("data:b\n\ndata:c\n\ndata: [DONE]\n\n") + require.NoError(t, err) + require.Equal(t, events, []event{{data: "b"}, {data: "c"}}) + }) + + t.Run("Multiple with new line within data", func(t *testing.T) { + events, err := decodeAll("data:line1\nline2\nline3\n\ndata:second-data\n\ndata: [DONE]\n\n") + require.NoError(t, err) + require.Equal(t, events, []event{{data: "line1\nline2\nline3"}, {data: "second-data"}}) + }) + + t.Run("ErrExpectedData", func(t *testing.T) { + _, err := decodeAll("datas:b\n\n") + require.Contains(t, err.Error(), "malformed data, expected 
data") + }) + + t.Run("Ends after done", func(t *testing.T) { + events, err := decodeAll("data:b\n\ndata:c\n\ndata: [DONE]\n\ndata:d\n\n") + require.NoError(t, err) + require.Equal(t, events, []event{{data: "b"}, {data: "c"}}) + }) +} diff --git a/internal/completions/client/google/google.go b/internal/completions/client/google/google.go new file mode 100644 index 00000000000..6e2b2b530e7 --- /dev/null +++ b/internal/completions/client/google/google.go @@ -0,0 +1,245 @@ +package google + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/url" + "path" + + "github.com/sourcegraph/log" + + "github.com/sourcegraph/sourcegraph/internal/completions/types" + "github.com/sourcegraph/sourcegraph/internal/httpcli" + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +func NewClient(cli httpcli.Doer, endpoint, accessToken string) types.CompletionsClient { + return &googleCompletionStreamClient{ + cli: cli, + accessToken: accessToken, + endpoint: endpoint, + } +} + +func (c *googleCompletionStreamClient) Complete( + ctx context.Context, + feature types.CompletionsFeature, + _ types.CompletionsVersion, + requestParams types.CompletionRequestParameters, + logger log.Logger, +) (*types.CompletionResponse, error) { + if !isSupportedFeature(feature) { + return nil, errors.Newf("feature %q is currently not supported for Google", feature) + } + + var resp *http.Response + var err error + defer (func() { + if resp != nil { + resp.Body.Close() + } + })() + + resp, err = c.makeRequest(ctx, requestParams, false) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var response googleResponse + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + return nil, err + } + + if len(response.Candidates) == 0 { + // Empty response. + return &types.CompletionResponse{}, nil + } + + if len(response.Candidates[0].Content.Parts) == 0 { + // Empty response. + return &types.CompletionResponse{}, nil + } + + // NOTE: Candidates can be used to get multiple completions when CandidateCount is set, + // which is not currently supported by Cody. For now, we only return the first completion. + return &types.CompletionResponse{ + Completion: response.Candidates[0].Content.Parts[0].Text, + StopReason: response.Candidates[0].FinishReason, + }, nil +} + +func (c *googleCompletionStreamClient) Stream( + ctx context.Context, + feature types.CompletionsFeature, + _ types.CompletionsVersion, + requestParams types.CompletionRequestParameters, + sendEvent types.SendCompletionEvent, + logger log.Logger, +) error { + if !isSupportedFeature(feature) { + return errors.Newf("feature %q is currently not supported for Google", feature) + } + + var resp *http.Response + var err error + + defer (func() { + if resp != nil { + resp.Body.Close() + } + })() + + resp, err = c.makeRequest(ctx, requestParams, true) + if err != nil { + return err + } + + dec := NewDecoder(resp.Body) + var content string + var ev types.CompletionResponse + + for dec.Scan() { + if ctx.Err() != nil && ctx.Err() == context.Canceled { + return nil + } + + data := dec.Data() + // Gracefully skip over any data that isn't JSON-like. 
+ if !bytes.HasPrefix(data, []byte("{")) { + continue + } + + var event googleResponse + if err := json.Unmarshal(data, &event); err != nil { + return errors.Errorf("failed to decode event payload: %w - body: %s", err, string(data)) + } + + if len(event.Candidates) > 0 && len(event.Candidates[0].Content.Parts) > 0 { + content += event.Candidates[0].Content.Parts[0].Text + + ev = types.CompletionResponse{ + Completion: content, + StopReason: event.Candidates[0].FinishReason, + } + err = sendEvent(ev) + if err != nil { + return err + } + } + } + if dec.Err() != nil { + return dec.Err() + } + + return nil +} + +// makeRequest formats the request and calls the chat/completions endpoint for code_completion requests +func (c *googleCompletionStreamClient) makeRequest(ctx context.Context, requestParams types.CompletionRequestParameters, stream bool) (*http.Response, error) { + // Ensure TopK and TopP are non-negative + requestParams.TopK = max(0, requestParams.TopK) + requestParams.TopP = max(0, requestParams.TopP) + + // Generate the prompt + prompt, err := getPrompt(requestParams.Messages) + if err != nil { + return nil, err + } + + payload := googleRequest{ + Contents: prompt, + GenerationConfig: googleGenerationConfig{ + Temperature: requestParams.Temperature, + TopP: requestParams.TopP, + TopK: requestParams.TopK, + MaxOutputTokens: requestParams.MaxTokensToSample, + StopSequences: requestParams.StopSequences, + }, + } + + reqBody, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + apiURL := c.getAPIURL(requestParams, stream) + + req, err := http.NewRequestWithContext(ctx, "POST", apiURL.String(), bytes.NewReader(reqBody)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + + // Vertex AI API requires an Authorization header with the access token. + // Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini#sample-requests + if !isDefaultAPIEndpoint(apiURL) { + req.Header.Set("Authorization", "Bearer "+c.accessToken) + } + + resp, err := c.cli.Do(req) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return nil, types.NewErrStatusNotOK("Google", resp) + } + + return resp, nil +} + +// In the latest API Docs, the model name and API key must be used with the default API endpoint URL. +// Ref: https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=rest#gemini_and_content_based_apis +func (c *googleCompletionStreamClient) getAPIURL(requestParams types.CompletionRequestParameters, stream bool) *url.URL { + apiURL, err := url.Parse(c.endpoint) + if err != nil { + apiURL = &url.URL{ + Scheme: "https", + Host: defaultAPIHost, + Path: defaultAPIPath, + } + } + + apiURL.Path = path.Join(apiURL.Path, requestParams.Model) + ":" + getgRPCMethod(stream) + + // We need to append the API key to the default API endpoint URL. + if isDefaultAPIEndpoint(apiURL) { + query := apiURL.Query() + query.Set("key", c.accessToken) + if stream { + query.Set("alt", "sse") + } + apiURL.RawQuery = query.Encode() + } + + return apiURL +} + +// getgRPCMethod returns the gRPC method name based on the stream flag. +func getgRPCMethod(stream bool) string { + if stream { + return "streamGenerateContent" + } + return "generateContent" +} + +// isSupportedFeature checks if the given CompletionsFeature is supported. +// Currently, only the CompletionsFeatureChat feature is supported. 
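+// Code-completion traffic is therefore expected to use a different provider; see internal/conf/computed.go, where the Google CompletionModel falls back to the fast chat model.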
+func isSupportedFeature(feature types.CompletionsFeature) bool { + switch feature { + case types.CompletionsFeatureChat: + return true + default: + return false + } +} + +// isDefaultAPIEndpoint checks if the given API endpoint URL is the default API endpoint. +// The default API endpoint is determined by the defaultAPIHost constant. +func isDefaultAPIEndpoint(endpoint *url.URL) bool { + return endpoint.Host == defaultAPIHost +} diff --git a/internal/completions/client/google/google_test.go b/internal/completions/client/google/google_test.go new file mode 100644 index 00000000000..3e27682c635 --- /dev/null +++ b/internal/completions/client/google/google_test.go @@ -0,0 +1,129 @@ +package google + +import ( + "bytes" + "context" + "io" + "net/http" + "testing" + + "github.com/hexops/autogold/v2" + "github.com/sourcegraph/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/sourcegraph/sourcegraph/internal/completions/types" +) + +type mockDoer struct { + do func(*http.Request) (*http.Response, error) +} + +func (c *mockDoer) Do(r *http.Request) (*http.Response, error) { + return c.do(r) +} + +func TestErrStatusNotOK(t *testing.T) { + mockClient := NewClient(&mockDoer{ + func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Body: io.NopCloser(bytes.NewReader([]byte("oh no, please slow down!"))), + }, nil + }, + }, "", "") + + t.Run("Complete", func(t *testing.T) { + logger := log.Scoped("completions") + resp, err := mockClient.Complete(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{}, logger) + require.Error(t, err) + assert.Nil(t, resp) + + autogold.Expect("Google: unexpected status code 429: oh no, please slow down!").Equal(t, err.Error()) + _, ok := types.IsErrStatusNotOK(err) + assert.True(t, ok) + }) + + t.Run("Stream", func(t *testing.T) { + logger := log.Scoped("completions") + err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{}, func(event types.CompletionResponse) error { return nil }, logger) + require.Error(t, err) + + autogold.Expect("Google: unexpected status code 429: oh no, please slow down!").Equal(t, err.Error()) + _, ok := types.IsErrStatusNotOK(err) + assert.True(t, ok) + }) +} + +func TestGetAPIURL(t *testing.T) { + t.Parallel() + + client := &googleCompletionStreamClient{ + endpoint: "https://generativelanguage.googleapis.com/v1/models", + accessToken: "test-token", + } + + t.Run("valid v1 endpoint", func(t *testing.T) { + params := types.CompletionRequestParameters{ + Model: "test-model", + } + url := client.getAPIURL(params, false).String() + expected := "https://generativelanguage.googleapis.com/v1/models/test-model:generateContent?key=test-token" + require.Equal(t, expected, url) + }) + + // + t.Run("valid endpoint for Vertex AI", func(t *testing.T) { + params := types.CompletionRequestParameters{ + Model: "gemini-1.5-pro", + } + c := &googleCompletionStreamClient{ + endpoint: "https://vertex-ai.example.com/v1/projects/PROJECT_ID/locations/LOCATION/publishers/google/models", + accessToken: "test-token", + } + url := c.getAPIURL(params, true).String() + expected := "https://vertex-ai.example.com/v1/projects/PROJECT_ID/locations/LOCATION/publishers/google/models/gemini-1.5-pro:streamGenerateContent" + require.Equal(t, expected, url) + }) + + t.Run("valid custom endpoint", func(t *testing.T) { + 
params := types.CompletionRequestParameters{ + Model: "test-model", + } + c := &googleCompletionStreamClient{ + endpoint: "https://example.com/api/models", + accessToken: "test-token", + } + url := c.getAPIURL(params, true).String() + expected := "https://example.com/api/models/test-model:streamGenerateContent" + require.Equal(t, expected, url) + }) + + t.Run("invalid endpoint", func(t *testing.T) { + client.endpoint = "://invalid" + params := types.CompletionRequestParameters{ + Model: "test-model", + } + url := client.getAPIURL(params, false).String() + expected := "https://generativelanguage.googleapis.com/v1beta/models/test-model:generateContent?key=test-token" + require.Equal(t, expected, url) + }) + + t.Run("streaming", func(t *testing.T) { + params := types.CompletionRequestParameters{ + Model: "test-model", + } + url := client.getAPIURL(params, true).String() + expected := "https://generativelanguage.googleapis.com/v1beta/models/test-model:streamGenerateContent?alt=sse&key=test-token" + require.Equal(t, expected, url) + }) + + t.Run("empty model", func(t *testing.T) { + params := types.CompletionRequestParameters{ + Model: "", + } + url := client.getAPIURL(params, false).String() + expected := "https://generativelanguage.googleapis.com/v1beta/models:generateContent?key=test-token" + require.Equal(t, expected, url) + }) +} diff --git a/internal/completions/client/google/models.go b/internal/completions/client/google/models.go new file mode 100644 index 00000000000..e1ab40c9a51 --- /dev/null +++ b/internal/completions/client/google/models.go @@ -0,0 +1,21 @@ +package google + +// For latest available Google Gemini models, +// See: https://ai.google.dev/gemini-api/docs/models/gemini +const providerName = "google" + +// Default API endpoint URL +const ( + defaultAPIHost = "generativelanguage.googleapis.com" + defaultAPIPath = "/v1beta/models" +) + +// Latest stable versions +const Gemini15Flash = "gemini-1.5-flash" +const Gemini15Pro = "gemini-1.5-pro" +const GeminiPro = "gemini-pro" + +// Latest versions +const Gemini15FlashLatest = "gemini-1.5-flash-latest" +const Gemini15ProLatest = "gemini-1.5-pro-latest" +const GeminiProLatest = "gemini-pro-latest" diff --git a/internal/completions/client/google/prompt.go b/internal/completions/client/google/prompt.go new file mode 100644 index 00000000000..7b6a63904f4 --- /dev/null +++ b/internal/completions/client/google/prompt.go @@ -0,0 +1,42 @@ +package google + +import ( + "github.com/sourcegraph/sourcegraph/internal/completions/types" + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +func getPrompt(messages []types.Message) ([]googleContentMessage, error) { + googleMessages := make([]googleContentMessage, 0, len(messages)) + + for i, message := range messages { + var googleRole string + + switch message.Speaker { + case types.SYSTEM_MESSAGE_SPEAKER: + if i != 0 { + return nil, errors.New("system role can only be used in the first message") + } + googleRole = message.Speaker + case types.ASSISTANT_MESSAGE_SPEAKER: + if i == 0 { + return nil, errors.New("assistant role cannot be used in the first message") + } + googleRole = "model" + case types.HUMAN_MESSAGE_SPEAKER: + googleRole = "user" + default: + return nil, errors.Errorf("unexpected role: %s", message.Text) + } + + if message.Text == "" { + return nil, errors.New("message content cannot be empty") + } + + googleMessages = append(googleMessages, googleContentMessage{ + Role: googleRole, + Parts: []googleContentMessagePart{{Text: message.Text}}, + }) + } + + return googleMessages, 
nil +} diff --git a/internal/completions/client/google/prompt_test.go b/internal/completions/client/google/prompt_test.go new file mode 100644 index 00000000000..16b797299de --- /dev/null +++ b/internal/completions/client/google/prompt_test.go @@ -0,0 +1,70 @@ +package google + +import ( + "testing" + + "github.com/sourcegraph/sourcegraph/internal/completions/types" +) + +func TestGetPrompt(t *testing.T) { + t.Run("invalid speaker", func(t *testing.T) { + _, err := getPrompt([]types.Message{{Speaker: "invalid", Text: "hello"}}) + if err == nil { + t.Errorf("expected error for invalid speaker, got nil") + } + }) + + t.Run("empty text", func(t *testing.T) { + _, err := getPrompt([]types.Message{{Speaker: types.HUMAN_MESSAGE_SPEAKER, Text: ""}}) + if err == nil { + t.Errorf("expected error for empty text, got nil") + } + }) + + t.Run("multiple system messages", func(t *testing.T) { + _, err := getPrompt([]types.Message{ + {Speaker: types.SYSTEM_MESSAGE_SPEAKER, Text: "system"}, + {Speaker: types.HUMAN_MESSAGE_SPEAKER, Text: "hello"}, + {Speaker: types.SYSTEM_MESSAGE_SPEAKER, Text: "system"}, + }) + if err == nil { + t.Errorf("expected error for multiple system messages, got nil") + } + }) + + t.Run("invalid prompt starts with assistant", func(t *testing.T) { + _, err := getPrompt([]types.Message{ + {Speaker: types.ASSISTANT_MESSAGE_SPEAKER, Text: "assistant"}, + {Speaker: types.HUMAN_MESSAGE_SPEAKER, Text: "hello"}, + {Speaker: types.ASSISTANT_MESSAGE_SPEAKER, Text: "assistant"}, + }) + if err == nil { + t.Errorf("expected error for messages starts with assistant, got nil") + } + }) + + t.Run("valid prompt", func(t *testing.T) { + messages := []types.Message{ + {Speaker: types.SYSTEM_MESSAGE_SPEAKER, Text: "system"}, + {Speaker: types.HUMAN_MESSAGE_SPEAKER, Text: "hello"}, + {Speaker: types.ASSISTANT_MESSAGE_SPEAKER, Text: "hi"}, + } + prompt, err := getPrompt(messages) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + expected := []googleContentMessage{ + {Role: "system", Parts: []googleContentMessagePart{{Text: "system"}}}, + {Role: "user", Parts: []googleContentMessagePart{{Text: "hello"}}}, + {Role: "model", Parts: []googleContentMessagePart{{Text: "hi"}}}, + } + if len(prompt) != len(expected) { + t.Errorf("unexpected prompt length, got %d, want %d", len(prompt), len(expected)) + } + for i := range prompt { + if prompt[i].Parts[0].Text != expected[i].Parts[0].Text { + t.Errorf("unexpected prompt message at index %d, got %v, want %v", i, prompt[i], expected[i]) + } + } + }) +} diff --git a/internal/completions/client/google/types.go b/internal/completions/client/google/types.go new file mode 100644 index 00000000000..4119ecf5099 --- /dev/null +++ b/internal/completions/client/google/types.go @@ -0,0 +1,60 @@ +package google + +import "github.com/sourcegraph/sourcegraph/internal/httpcli" + +type googleCompletionStreamClient struct { + cli httpcli.Doer + accessToken string + endpoint string +} + +type googleRequest struct { + Contents []googleContentMessage `json:"contents"` + GenerationConfig googleGenerationConfig `json:"generationConfig,omitempty"` + SafetySettings []googleSafetySettings `json:"safetySettings,omitempty"` +} + +type googleContentMessage struct { + Role string `json:"role"` + Parts []googleContentMessagePart `json:"parts"` +} + +type googleContentMessagePart struct { + Text string `json:"text"` +} + +// Configuration options for model generation and outputs. 
+// Ref: https://ai.google.dev/api/rest/v1/GenerationConfig
+type googleGenerationConfig struct {
+	Temperature     float32  `json:"temperature,omitempty"`     // request.Temperature
+	TopP            float32  `json:"topP,omitempty"`            // request.TopP
+	TopK            int      `json:"topK,omitempty"`            // request.TopK
+	StopSequences   []string `json:"stopSequences,omitempty"`   // request.StopSequences
+	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"` // request.MaxTokensToSample
+	CandidateCount  int      `json:"candidateCount,omitempty"`  // request.CandidateCount
+}
+
+type googleResponse struct {
+	Model      string `json:"model"`
+	Candidates []struct {
+		Content      googleContentMessage
+		FinishReason string `json:"finishReason"`
+	} `json:"candidates"`
+
+	UsageMetadata  googleUsage            `json:"usageMetadata"`
+	SafetySettings []googleSafetySettings `json:"safetySettings,omitempty"`
+}
+
+// Safety setting, affecting the safety-blocking behavior.
+// Ref: https://ai.google.dev/gemini-api/docs/safety-settings
+type googleSafetySettings struct {
+	Category  string `json:"category"`
+	Threshold string `json:"threshold"`
+}
+
+type googleUsage struct {
+	PromptTokenCount int `json:"promptTokenCount"`
+	// Use the same name we use elsewhere (completion instead of candidates)
+	CompletionTokenCount int `json:"candidatesTokenCount"`
+	TotalTokenCount      int `json:"totalTokenCount"`
+}
diff --git a/internal/conf/BUILD.bazel b/internal/conf/BUILD.bazel
index 738ea68196f..3dc3da493aa 100644
--- a/internal/conf/BUILD.bazel
+++ b/internal/conf/BUILD.bazel
@@ -28,6 +28,7 @@ go_library(
     deps = [
         "//internal/api/internalapi",
         "//internal/completions/client/anthropic",
+        "//internal/completions/client/google",
         "//internal/conf/confdefaults",
         "//internal/conf/conftypes",
         "//internal/conf/deploy",
diff --git a/internal/conf/computed.go b/internal/conf/computed.go
index a66b95bb3e0..572988f4bf2 100644
--- a/internal/conf/computed.go
+++ b/internal/conf/computed.go
@@ -10,6 +10,7 @@ import (
 	"github.com/hashicorp/cronexpr"
 
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/anthropic"
+	"github.com/sourcegraph/sourcegraph/internal/completions/client/google"
 	"github.com/sourcegraph/sourcegraph/internal/conf/confdefaults"
 	"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
 	"github.com/sourcegraph/sourcegraph/internal/conf/deploy"
@@ -867,6 +868,32 @@ func GetCompletionsConfig(siteConfig schema.SiteConfiguration) (c *conftypes.Com
 		completionsModelRef := conftypes.NewBedrockModelRefFromModelID(completionsConfig.CompletionModel)
 		completionsConfig.CompletionModel = completionsModelRef.CanonicalizedModelID()
 		canonicalized = true
+	} else if completionsConfig.Provider == string(conftypes.CompletionsProviderNameGoogle) {
+		// If no endpoint is configured, use a default value.
+		if completionsConfig.Endpoint == "" {
+			completionsConfig.Endpoint = "https://generativelanguage.googleapis.com/v1beta/models"
+		}
+
+		// If no access token is set, we cannot talk to Google. Bail.
+		if completionsConfig.AccessToken == "" {
+			return nil
+		}
+
+		// Set a default chat model.
+		if completionsConfig.ChatModel == "" {
+			completionsConfig.ChatModel = google.Gemini15ProLatest
+		}
+
+		// Set a default fast chat model.
+		if completionsConfig.FastChatModel == "" {
+			completionsConfig.FastChatModel = google.Gemini15FlashLatest
+		}
+
+		// Set a default completions model.
+ if completionsConfig.CompletionModel == "" { + // Code completion is not supported by Google + completionsConfig.CompletionModel = google.Gemini15FlashLatest + } } // only apply canonicalization if not already applied. Not all model IDs can simply be lowercased diff --git a/internal/licensing/codygateway.go b/internal/licensing/codygateway.go index 10728570d2b..19bd545a05b 100644 --- a/internal/licensing/codygateway.go +++ b/internal/licensing/codygateway.go @@ -50,6 +50,7 @@ func NewCodyGatewayChatRateLimit(plan Plan, userCount *int) CodyGatewayRateLimit "google/gemini-1.5-pro-latest", "google/gemini-1.5-flash-latest", + "google/gemini-pro-latest", } switch plan { // TODO: This is just an example for now. diff --git a/internal/licensing/codygateway_test.go b/internal/licensing/codygateway_test.go index 8e74d7d413a..2e3984ce87c 100644 --- a/internal/licensing/codygateway_test.go +++ b/internal/licensing/codygateway_test.go @@ -38,6 +38,7 @@ func TestNewCodyGatewayChatRateLimit(t *testing.T) { "openai/gpt-4-turbo-preview", "google/gemini-1.5-pro-latest", "google/gemini-1.5-flash-latest", + "google/gemini-pro-latest", }, Limit: 2500, IntervalSeconds: 60 * 60 * 24, @@ -65,6 +66,7 @@ func TestNewCodyGatewayChatRateLimit(t *testing.T) { "openai/gpt-4-turbo-preview", "google/gemini-1.5-pro-latest", "google/gemini-1.5-flash-latest", + "google/gemini-pro-latest", }, Limit: 50, IntervalSeconds: 60 * 60 * 24, @@ -92,6 +94,7 @@ func TestNewCodyGatewayChatRateLimit(t *testing.T) { "openai/gpt-4-turbo-preview", "google/gemini-1.5-pro-latest", "google/gemini-1.5-flash-latest", + "google/gemini-pro-latest", }, Limit: 10, IntervalSeconds: 60 * 60 * 24, diff --git a/schema/site.schema.json b/schema/site.schema.json index a6e7e34d4de..a4dbdc15fe0 100644 --- a/schema/site.schema.json +++ b/schema/site.schema.json @@ -2890,7 +2890,7 @@ "type": "string", "description": "The external completions provider. Defaults to 'sourcegraph'.", "default": "sourcegraph", - "enum": ["anthropic", "openai", "sourcegraph", "azure-openai", "aws-bedrock", "fireworks"] + "enum": ["anthropic", "openai", "sourcegraph", "azure-openai", "aws-bedrock", "fireworks", "google"] }, "endpoint": { "type": "string",
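Beyond the curl-based test plan above, the new `internal/completions/client/google` package can also be driven directly. A minimal sketch using only the signatures introduced in this patch — the endpoint value and the `GEMINI_API_KEY` placeholder are assumptions for illustration, not part of the change:

```go
package main

import (
	"context"
	"fmt"

	"github.com/sourcegraph/log"

	"github.com/sourcegraph/sourcegraph/internal/completions/client/google"
	"github.com/sourcegraph/sourcegraph/internal/completions/types"
	"github.com/sourcegraph/sourcegraph/internal/httpcli"
)

func main() {
	// The client appends "<model>:generateContent" (or ":streamGenerateContent"
	// with alt=sse) plus a ?key= query parameter when talking to the default
	// Gemini API host.
	client := google.NewClient(
		httpcli.UncachedExternalDoer,
		"https://generativelanguage.googleapis.com/v1beta/models",
		"GEMINI_API_KEY", // placeholder
	)

	resp, err := client.Complete(
		context.Background(),
		types.CompletionsFeatureChat, // only chat is supported by this client
		types.CompletionsVersionLegacy,
		types.CompletionRequestParameters{
			Model:             google.Gemini15ProLatest,
			MaxTokensToSample: 30,
			Messages: []types.Message{
				{Speaker: types.HUMAN_MESSAGE_SPEAKER, Text: "Who are you?"},
			},
		},
		log.Scoped("example"),
	)
	if err != nil {
		panic(err)
	}
	fmt.Println(resp.Completion, resp.StopReason)
}
```

Note the auth split in `makeRequest`/`getAPIURL`: against the default Gemini host the access token travels as a `?key=` query parameter, while any other endpoint (e.g. Vertex AI) receives it as a `Bearer` Authorization header.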