Cody Gateway: Add support for Google non-streaming endpoint (#63166)

Add support for non-streaming requests for the Google Gemini provider.

- Added a `Stream` field to the `googleRequest` struct to control whether the
completion is streamed (see the sketch after this list)
- Added a `SymtemInstruction` field to the `googleRequest` struct to allow
setting system instructions
- Updated `GoogleHandlerMethods.validateRequest` to reject `FeatureEmbeddings`
instead of `FeatureCodeCompletions`, so code completions are now allowed
- Updated `GoogleHandlerMethods.getRequestMetadata` to report the value of the
`Stream` field
- Updated `GoogleGatewayFeatureClient.GetRequest` to build streaming and
non-streaming requests for both `FeatureCodeCompletions` and
`FeatureChatCompletions`
- Removed the no-longer-needed `isSupportedFeature` checks in
`googleCompletionStreamClient`
- Added Gemini 1.5 Flash and Gemini 1.0 Pro to the autocomplete allowed-model
list (not yet supported by clients)
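
For context, here is a minimal Go sketch of the behavior change, condensed from the diff further down. It is only a sketch: unrelated fields and the surrounding handler code are elided, and `main` exists solely to make the snippet self-contained.

```
package main

import "fmt"

// Minimal sketch of the updated request type (unrelated fields elided);
// the JSON tags mirror the gateway handler diff below.
type googleRequest struct {
	Model             string `json:"model"`
	Stream            bool   `json:"stream,omitempty"`
	SymtemInstruction string `json:"systemInstruction,omitempty"`
}

// ShouldStream previously hard-coded true; it now follows the request body,
// which is what lets the gateway take the non-streaming generateContent path.
func (r googleRequest) ShouldStream() bool {
	return r.Stream
}

func main() {
	req := googleRequest{Model: "gemini-1.5-pro-latest", Stream: false}
	fmt.Println(req.Model, req.ShouldStream()) // gemini-1.5-pro-latest false
}
```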


## Test plan


Unit tests were updated to cover the non-streaming request path.

To manually test this:

1. In your local Sourcegraph instance's Site Config, add the following:

```
  "completions": {
    "accessToken": "REDACTED",
    "chatModel": "gemini-1.5-pro-latest",
    "completionModel": "google/gemini-1.5-flash-latest",
    "provider": "google",
```

Note: You can get the `accessToken` for the Gemini API from 1Password.

2. After saving the site config with the above change, run the following curl
command, which sends a non-streaming (`"stream": false`) request to the
completions endpoint:

```
curl 'https://sourcegraph.test:3443/.api/completions/stream' -i \
-X POST \
-H "authorization: token $YOUR_LOCAL_TOKEN" \
--data-raw '{"messages":[{"speaker":"human","text":"Who are you?"}],"maxTokensToSample":30,"temperature":0,"stopSequences":[],"timeoutMs":5000,"stream":false,"model":"gemini-1.5-pro-latest"}'
```

Output:
```
❯ curl 'https://sourcegraph.test:3443/.api/completions/stream' -i \
-X POST \
-H 'authorization: token $YOUR_LOCAL_TOKEN' \
--data-raw '{"messages":[{"speaker":"human","text":"Who are you?"}],"maxTokensToSample":30,"temperature":0,"stopSequences":[],"timeoutMs":5000,"stream":false,"model":"gemini-1.5-pro-latest"}'
HTTP/2 200
access-control-allow-credentials: true
access-control-allow-origin:
alt-svc: h3=":3443"; ma=2592000
cache-control: no-cache, max-age=0
content-type: text/plain; charset=utf-8
date: Tue, 11 Jun 2024 17:02:19 GMT
server: Caddy
server: Caddy
vary: Accept-Encoding, Authorization, Cookie, Authorization, X-Requested-With, Cookie
x-content-type-options: nosniff
x-frame-options: DENY
x-powered-by: Express
x-trace: e11a2ce292639414dd2ccdfcbfa89611
x-trace-span: 9457aa0dd0e09b6c
x-trace-url: https://sourcegraph.test:3443/-/debug/jaeger/trace/e11a2ce292639414dd2ccdfcbfa89611
x-xss-protection: 1; mode=block
content-length: 154

{"completion":"I am a large language model, trained by Google. \n\nHere's what that means:\n\n* **I am a computer program:** I","stopReason":"MAX_TOKENS"}%
```
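
For reference, a rough Go equivalent of the curl call above. This is only a sketch: `SRC_ACCESS_TOKEN` is an assumed environment variable holding your local access token, and TLS verification is skipped solely because `sourcegraph.test` uses a locally issued certificate.

```
package main

import (
	"crypto/tls"
	"fmt"
	"io"
	"net/http"
	"os"
	"strings"
)

func main() {
	// Same payload as the curl example above: stream=false exercises the
	// non-streaming Google path.
	payload := `{"messages":[{"speaker":"human","text":"Who are you?"}],"maxTokensToSample":30,"temperature":0,"stopSequences":[],"timeoutMs":5000,"stream":false,"model":"gemini-1.5-pro-latest"}`

	req, err := http.NewRequest("POST", "https://sourcegraph.test:3443/.api/completions/stream", strings.NewReader(payload))
	if err != nil {
		panic(err)
	}
	// SRC_ACCESS_TOKEN is an assumed env var holding your local access token.
	req.Header.Set("Authorization", "token "+os.Getenv("SRC_ACCESS_TOKEN"))

	// sourcegraph.test uses a locally issued cert; skip verification only for
	// local testing.
	client := &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}}
	resp, err := client.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status)
	fmt.Println(string(body)) // {"completion":"...","stopReason":"MAX_TOKENS"}
}
```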

## Changelog

- Cody Gateway: added support for non-streaming completion requests for the Google Gemini provider.
Commit `e1551657b1` (parent `549beac5ec`), authored by Beatrix on 2024-06-11. 7 changed files with 92 additions and 35 deletions:

Cody Gateway Google handler (`googleRequest`, `GoogleHandlerMethods`):

```
@@ -41,10 +41,12 @@ func NewGoogleHandler(baseLogger log.Logger, eventLogger events.Logger, rs limit
 // The request body for Google completions.
 // Ref: https://ai.google.dev/api/rest/v1/models/generateContent#request-body
 type googleRequest struct {
-    Model            string                 `json:"model"`
-    Contents         []googleContentMessage `json:"contents"`
-    GenerationConfig googleGenerationConfig `json:"generationConfig,omitempty"`
-    SafetySettings   []googleSafetySettings `json:"safetySettings,omitempty"`
+    Model             string                 `json:"model"`
+    Stream            bool                   `json:"stream,omitempty"`
+    Contents          []googleContentMessage `json:"contents"`
+    GenerationConfig  googleGenerationConfig `json:"generationConfig,omitempty"`
+    SafetySettings    []googleSafetySettings `json:"safetySettings,omitempty"`
+    SymtemInstruction string                 `json:"systemInstruction,omitempty"`
 }
 
 type googleContentMessage struct {
@@ -73,7 +75,7 @@ type googleSafetySettings struct {
 }
 
 func (r googleRequest) ShouldStream() bool {
-    return true
+    return r.Stream
 }
 
 func (r googleRequest) GetModel() string {
@@ -119,7 +121,7 @@ func (g *GoogleHandlerMethods) getAPIURL(_ codygateway.Feature, req googleReques
 }
 
 func (*GoogleHandlerMethods) validateRequest(_ context.Context, _ log.Logger, feature codygateway.Feature, _ googleRequest) error {
-    if feature == codygateway.FeatureCodeCompletions {
+    if feature == codygateway.FeatureEmbeddings {
         return errors.Newf("feature %q is currently not supported for Google", feature)
     }
     return nil
@@ -141,7 +143,7 @@ func (*GoogleHandlerMethods) transformBody(_ *googleRequest, _ string) {
 }
 
 func (*GoogleHandlerMethods) getRequestMetadata(body googleRequest) (model string, additionalMetadata map[string]any) {
-    return body.Model, map[string]any{"stream": body.ShouldStream()}
+    return body.Model, map[string]any{"stream": body.Stream}
 }
 
 func (o *GoogleHandlerMethods) transformRequest(r *http.Request) {
@@ -156,7 +158,6 @@ func (*GoogleHandlerMethods) parseResponseAndUsage(logger log.Logger, reqBody go
     // it as JSON.
     if !reqBody.ShouldStream() {
         var res googleResponse
-
         if err := json.NewDecoder(r).Decode(&res); err != nil {
             logger.Error("failed to parse Google response as JSON", log.Error(err))
             return promptUsage, completionUsage
```

Cody Gateway Google handler tests:

```
@@ -14,7 +14,7 @@ func TestGoogleRequestGetTokenCount(t *testing.T) {
     logger := logtest.Scoped(t)
 
     t.Run("streaming", func(t *testing.T) {
-        req := googleRequest{}
+        req := googleRequest{Stream: true}
         r := strings.NewReader(googleStreamingResponse)
         handler := &GoogleHandlerMethods{}
         promptUsage, completionUsage := handler.parseResponseAndUsage(logger, req, r)
@@ -22,6 +22,16 @@ func TestGoogleRequestGetTokenCount(t *testing.T) {
         assert.Equal(t, 21, promptUsage.tokens)
         assert.Equal(t, 87, completionUsage.tokens)
     })
+
+    t.Run("non-streaming", func(t *testing.T) {
+        req := googleRequest{Stream: false}
+        r := strings.NewReader(googleNonStreamingResponse)
+        handler := &GoogleHandlerMethods{}
+        promptUsage, completionUsage := handler.parseResponseAndUsage(logger, req, r)
+
+        assert.Equal(t, 59, promptUsage.tokens)
+        assert.Equal(t, 54, completionUsage.tokens)
+    })
 }
 
 var googleStreamingResponse = `data: {"candidates": [{"content": {"parts": [{"text": "def"}],"role": "model"},"finishReason": "STOP","index": 0}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 1,"totalTokenCount": 22}}
@@ -36,6 +46,47 @@ data: {"candidates": [{"content": {"parts": [{"text": "1[j+1] = list1[j+1], list
 
 `
 
+var googleNonStreamingResponse = `{
+  "candidates": [
+    {
+      "content": {
+        "parts": [
+          {
+            "text": "The cobblestone path, worn smooth by centuries of weary feet, led to a humble cottage nestled within the quiet village of Saint-Martin, where a young boy named Pierre discovered a weathered, leather backpack tucked beneath the gnarled oak tree in his grandmother's garden. \n"
+          }
+        ],
+        "role": "model"
+      },
+      "finishReason": "STOP",
+      "index": 0,
+      "safetyRatings": [
+        {
+          "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+          "probability": "NEGLIGIBLE"
+        },
+        {
+          "category": "HARM_CATEGORY_HATE_SPEECH",
+          "probability": "NEGLIGIBLE"
+        },
+        {
+          "category": "HARM_CATEGORY_HARASSMENT",
+          "probability": "NEGLIGIBLE"
+        },
+        {
+          "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+          "probability": "NEGLIGIBLE"
+        }
+      ]
+    }
+  ],
+  "usageMetadata": {
+    "promptTokenCount": 59,
+    "candidatesTokenCount": 54,
+    "totalTokenCount": 113
+  }
+}
+`
+
 func TestParseGoogleTokenUsage(t *testing.T) {
     tests := []struct {
         name string
```

Cody Gateway HTTP API imports (regrouped):

```
@@ -7,12 +7,13 @@ import (
     "github.com/Khan/genqlient/graphql"
     "github.com/gorilla/mux"
     "github.com/sourcegraph/log"
-    "github.com/sourcegraph/sourcegraph/internal/collections"
     "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
     "go.opentelemetry.io/otel"
     "go.opentelemetry.io/otel/attribute"
     "go.opentelemetry.io/otel/metric"
 
+    "github.com/sourcegraph/sourcegraph/internal/collections"
+
     "github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/auth"
     "github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/events"
     "github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/attribution"
```

Cody Gateway Google feature client (`GoogleGatewayFeatureClient.GetRequest`):

```
@@ -14,10 +14,24 @@ type GoogleGatewayFeatureClient struct{}
 func (o GoogleGatewayFeatureClient) GetRequest(f codygateway.Feature, req *http.Request, stream bool) (*http.Request, error) {
     if f == codygateway.FeatureCodeCompletions {
-        return nil, errNotImplemented
+        body := fmt.Sprintf(`{
+            "model":"gemini-1.5-flash-latest",
+            "contents":[{"parts":[{"text":"You are Cody"}],"role":"user"},{"parts":[{"text":"Ok I am Cody"}],"role":"model"},{"parts":[{"text":"Write Bubble sort in Python"}],"role":"user"}],
+            "generationConfig":{"temperature":0.2,"topP":0.95,"topP":0.95,"maxOutputTokens":1000},
+            "stream":%t
+        }`, stream)
+        req.Method = "POST"
+        req.URL.Path = "/v1/completions/google"
+        req.Body = io.NopCloser(strings.NewReader(body))
+        return req, nil
     }
 
     if f == codygateway.FeatureChatCompletions {
-        body := fmt.Sprintf(`{"model":"gemini-1.5-flash-latest","contents":[{"parts":[{"text":"You are Cody"}],"role":"user"},{"parts":[{"text":"Ok I am Cody"}],"role":"model"},{"parts":[{"text":"Write Bubble sort in Python"}],"role":"user"}],"stream":%t}`, stream)
+        body := fmt.Sprintf(`{
+            "model":"gemini-1.5-pro-latest",
+            "contents":[{"parts":[{"text":"You are Cody"}],"role":"user"},{"parts":[{"text":"Ok I am Cody"}],"role":"model"},{"parts":[{"text":"Write Bubble sort in Python"}],"role":"user"}],
+            "generationConfig":{"temperature":0.2,"topP":0.95,"topP":0.95,"maxOutputTokens":1000},
+            "stream":%t
+        }`, stream)
         req.Method = "POST"
         req.URL.Path = "/v1/completions/google"
         req.Body = io.NopCloser(strings.NewReader(body))
```

Autocomplete allowed-model list (`allowedCustomModel`):

```
@@ -7,6 +7,7 @@ import (
     "github.com/sourcegraph/log"
 
     "github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks"
+    "github.com/sourcegraph/sourcegraph/internal/completions/client/google"
     "github.com/sourcegraph/sourcegraph/internal/completions/types"
     "github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
     "github.com/sourcegraph/sourcegraph/internal/database"
@@ -67,6 +68,8 @@ func allowedCustomModel(model string) string {
         "anthropic/claude-instant-v1",
         "anthropic/claude-instant-1",
         "anthropic/claude-instant-1.2-cyan",
+        "google/" + google.Gemini15Flash,
+        "google/" + google.GeminiPro,
         "fireworks/accounts/sourcegraph/models/starcoder-7b",
         "fireworks/accounts/sourcegraph/models/starcoder-16b",
         "fireworks/accounts/fireworks/models/starcoder-3b-w8a16",
```

Google completions client (`googleCompletionStreamClient`):

```
@@ -30,10 +30,6 @@ func (c *googleCompletionStreamClient) Complete(
     requestParams types.CompletionRequestParameters,
     logger log.Logger,
 ) (*types.CompletionResponse, error) {
-    if !isSupportedFeature(feature) {
-        return nil, errors.Newf("feature %q is currently not supported for Google", feature)
-    }
-
     var resp *http.Response
     var err error
     defer (func() {
@@ -79,10 +75,6 @@ func (c *googleCompletionStreamClient) Stream(
     sendEvent types.SendCompletionEvent,
     logger log.Logger,
 ) error {
-    if !isSupportedFeature(feature) {
-        return errors.Newf("feature %q is currently not supported for Google", feature)
-    }
-
     var resp *http.Response
     var err error
 
@@ -151,6 +143,7 @@ func (c *googleCompletionStreamClient) makeRequest(ctx context.Context, requestP
 
     payload := googleRequest{
         Model: requestParams.Model,
+        Stream: stream,
         Contents: prompt,
         GenerationConfig: googleGenerationConfig{
             Temperature: requestParams.Temperature,
@@ -228,17 +221,6 @@ func getgRPCMethod(stream bool) string {
     return "generateContent"
 }
 
-// isSupportedFeature checks if the given CompletionsFeature is supported.
-// Currently, only the CompletionsFeatureChat feature is supported.
-func isSupportedFeature(feature types.CompletionsFeature) bool {
-    switch feature {
-    case types.CompletionsFeatureChat:
-        return true
-    default:
-        return false
-    }
-}
-
 // isDefaultAPIEndpoint checks if the given API endpoint URL is the default API endpoint.
 // The default API endpoint is determined by the defaultAPIHost constant.
 func isDefaultAPIEndpoint(endpoint *url.URL) bool {
```

Google completions client request types (`googleRequest`):

```
@@ -8,11 +8,16 @@ type googleCompletionStreamClient struct {
     endpoint string
 }
 
+// The request body for the completion stream endpoint.
+// Ref: https://ai.google.dev/api/rest/v1beta/models/generateContent
+// Ref: https://ai.google.dev/api/rest/v1beta/models/streamGenerateContent
 type googleRequest struct {
-    Model            string                 `json:"model"`
-    Contents         []googleContentMessage `json:"contents"`
-    GenerationConfig googleGenerationConfig `json:"generationConfig,omitempty"`
-    SafetySettings   []googleSafetySettings `json:"safetySettings,omitempty"`
+    Model             string                 `json:"model"`
+    Stream            bool                   `json:"stream,omitempty"`
+    Contents          []googleContentMessage `json:"contents"`
+    GenerationConfig  googleGenerationConfig `json:"generationConfig,omitempty"`
+    SafetySettings    []googleSafetySettings `json:"safetySettings,omitempty"`
+    SymtemInstruction string                 `json:"systemInstruction,omitempty"`
 }
 
 type googleContentMessage struct {
```