fix(cody-gateway): streaming google endpoint (#63306)


Issue: Currently, the ShouldStream() method always returns false because the Stream value is stripped from the request before it is passed into the Handler.

To fix this, we now preserve the original googleRequest.Stream value when it is true so that ShouldStream() returns the correct value, and we use the transformBody method to remove the Stream field before sending the request to the Google API.

Here is the expected behaviour after the fix:


https://github.com/sourcegraph/sourcegraph/assets/68532117/8324fb8c-0625-4579-b0e9-0abfc9858961


Also confirmed it works with both Cody Gateway and BYOK:


![image](https://github.com/sourcegraph/sourcegraph/assets/68532117/9fe60423-a05b-412d-812a-f34cd812d9dc)


## Test plan


Note: Cody Gateway requests for Google Gemini models are always streamed, as the Code Completion feature has not yet been implemented on the client side.

### Non-streaming request

```
❯ curl 'https://sourcegraph.test:3443/.api/completions/code' -i \
-X POST \
-H 'authorization: token LOCALTOKEN' \
--data-raw '{"messages":[{"speaker":"human","text":"Who are you?"}],"maxTokensToSample":30,"temperature":0,"stopSequences":[],"timeoutMs":5000}'
HTTP/2 200
access-control-allow-credentials: true
access-control-allow-origin:
alt-svc: h3=":3443"; ma=2592000
cache-control: no-cache, max-age=0
content-type: text/plain; charset=utf-8
date: Tue, 18 Jun 2024 21:05:38 GMT
server: Caddy
server: Caddy
set-cookie: sourcegraphDeviceId=d4fa7789-2442-472a-b425-a68372d27944; Expires=Wed, 18 Jun 2025 21:05:36 GMT; Secure
vary: Cookie, Accept-Encoding, Authorization, Cookie, Authorization, X-Requested-With, Cookie
x-content-type-options: nosniff
x-frame-options: DENY
x-powered-by: Express
x-trace: 00f998a2a2e1b6895687ad7cc567b41c
x-trace-span: da9c93d16415b94f
x-trace-url: https://sourcegraph.test:3443/-/debug/jaeger/trace/00f998a2a2e1b6895687ad7cc567b41c
x-xss-protection: 1; mode=block
content-length: 147

{"completion":"I am a large language model, trained by Google. \n\nHere's what that means:\n\n* **Large Language Model:** I'm","stopReason":"STOP"}%
```

### Streaming request

```
❯ curl 'https://sourcegraph.test:3443/.api/completions/stream' -i \
-X POST \
-H "authorization: token $LOCALTOKEN" \
--data-raw '{"stream":true,"messages":[{"speaker":"human","text":"Who are you?"}],"maxTokensToSample":1000,"temperature":0,"stopSequences":[],"timeoutMs":5000}'

HTTP/2 200
access-control-allow-credentials: true
access-control-allow-origin:
alt-svc: h3=":3443"; ma=2592000
cache-control: no-cache
content-type: text/event-stream
date: Tue, 18 Jun 2024 21:07:02 GMT
server: Caddy
server: Caddy
set-cookie: sourcegraphDeviceId=38b45f36-d237-4f8d-8242-a63fcc801a32; Expires=Wed, 18 Jun 2025 21:06:59 GMT; Secure
vary: Cookie, Accept-Encoding, Authorization, Cookie, Authorization, X-Requested-With, Cookie
x-accel-buffering: no
x-content-type-options: nosniff
x-frame-options: DENY
x-powered-by: Express
x-trace: 984932973626e14f7cb0ce7e8e470717
x-trace-span: d285179cfb744e08
x-trace-url: https://sourcegraph.test:3443/-/debug/jaeger/trace/984932973626e14f7cb0ce7e8e470717
x-xss-protection: 1; mode=block

event: completion
data: {"completion":"I","stopReason":"STOP"}

event: completion
data: {"completion":"I am a large language model, trained by Google. \n\nHere's what","stopReason":"STOP"}

event: completion
data: {"completion":"I am a large language model, trained by Google. \n\nHere's what that means:\n\n* **I am not a person.** I am a computer","stopReason":"STOP"}

event: completion
data: {"completion":"I am a large language model, trained by Google. \n\nHere's what that means:\n\n* **I am not a person.** I am a computer program designed to process and generate human-like text. \n* **I learn from data.** I was trained on a massive dataset of text and code,","stopReason":"STOP"}

event: completion
data: {"completion":"I am a large language model, trained by Google. \n\nHere's what that means:\n\n* **I am not a person.** I am a computer program designed to process and generate human-like text. \n* **I learn from data.** I was trained on a massive dataset of text and code, which allows me to generate text, translate languages, write different kinds of creative content, and answer your questions in an informative way.\n* **I am still","stopReason":"STOP"}

event: completion
data: {"completion":"I am a large language model, trained by Google. \n\nHere's what that means:\n\n* **I am not a person.** I am a computer program designed to process and generate human-like text. \n* **I learn from data.** I was trained on a massive dataset of text and code, which allows me to generate text, translate languages, write different kinds of creative content, and answer your questions in an informative way.\n* **I am still under development.** I am constantly learning and improving, but I am not perfect and can sometimes make mistakes.\n\nHow can I help you today? \n","stopReason":"STOP"}

event: done
data: {}
```
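The `event:`/`data:` pairs above follow standard `text/event-stream` framing: each event block ends at a blank line. As a rough illustrative sketch (not the actual `NewDecoder` used by the Sourcegraph client), such a stream can be consumed like this:

```go
package main

import (
	"bufio"
	"fmt"
	"strings"
)

// parseSSE collects (event, data) pairs from a text/event-stream payload.
// Each event block is terminated by a blank line, as in the transcript above.
func parseSSE(stream string) [][2]string {
	var events [][2]string
	var event, data string
	sc := bufio.NewScanner(strings.NewReader(stream))
	for sc.Scan() {
		line := sc.Text()
		switch {
		case strings.HasPrefix(line, "event: "):
			event = strings.TrimPrefix(line, "event: ")
		case strings.HasPrefix(line, "data: "):
			data = strings.TrimPrefix(line, "data: ")
		case line == "" && event != "":
			events = append(events, [2]string{event, data})
			event, data = "", ""
		}
	}
	return events
}

func main() {
	raw := "event: completion\ndata: {\"completion\":\"I\"}\n\nevent: done\ndata: {}\n\n"
	for _, e := range parseSSE(raw) {
		fmt.Printf("%s -> %s\n", e[0], e[1])
		// completion -> {"completion":"I"}
		// done -> {}
	}
}
```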

## Changelog

- Fixed streaming responses for Google Gemini models: requests with `"stream": true` now reach the streaming Google endpoint when routed through Cody Gateway.
Commit 0c777bac41 by Beatrix, 2024-06-18 (parent a333771bd4) · 7 changed files with 45 additions and 49 deletions

```diff
@@ -63,9 +63,11 @@ func (r googleRequest) BuildPrompt() string {
 func (g *GoogleHandlerMethods) getAPIURL(feature codygateway.Feature, req googleRequest) string {
 	rpc := "generateContent"
 	sseSuffix := ""
-	if feature == codygateway.FeatureChatCompletions {
+	// If we're streaming, we need to use the stream endpoint.
+	if feature == codygateway.FeatureChatCompletions || req.ShouldStream() {
 		rpc = "streamGenerateContent"
 		sseSuffix = "&alt=sse"
+		req.Stream = true
 	}
 	return fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:%s?key=%s%s", req.Model, rpc, g.config.AccessToken, sseSuffix)
 }
@@ -93,10 +95,13 @@ func (g *GoogleHandlerMethods) shouldFlagRequest(ctx context.Context, logger log
 }

 // Used to modify the request body before it is sent to upstream.
-func (*GoogleHandlerMethods) transformBody(*googleRequest, string) {}
+func (*GoogleHandlerMethods) transformBody(gr *googleRequest, _ string) {
+	// Remove Stream from the request body before sending it to Google.
+	gr.Stream = false
+}

 func (*GoogleHandlerMethods) getRequestMetadata(body googleRequest) (model string, additionalMetadata map[string]any) {
-	return body.Model, map[string]any{"stream": body.Stream}
+	return body.Model, map[string]any{"stream": body.ShouldStream()}
 }

 func (o *GoogleHandlerMethods) transformRequest(r *http.Request) {
@@ -106,10 +111,9 @@ func (o *GoogleHandlerMethods) transformRequest(r *http.Request) {
 func (*GoogleHandlerMethods) parseResponseAndUsage(logger log.Logger, reqBody googleRequest, r io.Reader) (promptUsage, completionUsage usageStats) {
 	// First, extract prompt usage details from the request.
 	promptUsage.characters = len(reqBody.BuildPrompt())

 	// Try to parse the request we saw, if it was non-streaming, we can simply parse
 	// it as JSON.
-	if !reqBody.ShouldStream() {
+	if !reqBody.Stream && !reqBody.ShouldStream() {
 		var res googleResponse
 		if err := json.NewDecoder(r).Decode(&res); err != nil {
 			logger.Error("failed to parse Google response as JSON", log.Error(err))
```

```diff
@@ -11,8 +11,8 @@ type googleRequest struct {
 	SymtemInstruction string `json:"systemInstruction,omitempty"`
 	// Stream is used for our internal routing of the Google Request, and is not part
-	// of the Google API shape. So we make sure to not include it when marshaling into JSON.
-	Stream bool `json:"-"` // This field will not be marshaled into JSON
+	// of the Google API shape.
+	Stream bool `json:"stream,omitempty"`
 }

 type googleContentMessage struct {
@@ -36,12 +36,13 @@ type googleGenerationConfig struct {
 }

 type googleResponse struct {
-	Candidates []struct {
-		Content      googleContentMessage `json:"content,omitempty"`
-		FinishReason string               `json:"finishReason,omitempty"`
-	} `json:"candidates"`
+	Candidates    []googleCandidates `json:"candidates,omitempty"`
 	UsageMetadata googleUsage `json:"usageMetadata,omitempty"`
 }
+
+type googleCandidates struct {
+	Content       googleContentMessage  `json:"content,omitempty"`
+	FinishReason  string                `json:"finishReason,omitempty"`
+	SafetyRatings []googleSafetyRatings `json:"safetyRatings,omitempty"`
+}
```


```diff
@@ -42,7 +42,7 @@ func getBasic(endpoint string, provider conftypes.CompletionsProviderName, acces
 	case conftypes.CompletionsProviderNameAzureOpenAI:
 		return azureopenai.NewClient(azureopenai.GetAPIClient, endpoint, accessToken, *tokenManager)
 	case conftypes.CompletionsProviderNameGoogle:
-		return google.NewClient(httpcli.UncachedExternalDoer, endpoint, accessToken), nil
+		return google.NewClient(httpcli.UncachedExternalDoer, endpoint, accessToken, false), nil
 	case conftypes.CompletionsProviderNameSourcegraph:
 		return codygateway.NewClient(httpcli.CodyGatewayDoer, endpoint, accessToken, *tokenManager)
 	case conftypes.CompletionsProviderNameFireworks:
```


```diff
@@ -91,7 +91,7 @@ func (c *codyGatewayClient) clientForParams(feature types.CompletionsFeature, re
 	case string(conftypes.CompletionsProviderNameFireworks):
 		return fireworks.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/fireworks"), "", ""), nil
 	case string(conftypes.CompletionsProviderNameGoogle):
-		return google.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/google"), "", ""), nil
+		return google.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/google"), "", "", true), nil
 	case "":
 		return nil, errors.Newf("no provider provided in model %s - a model in the format '$PROVIDER/$MODEL_NAME' is expected", model)
 	default:
```


```diff
@@ -15,11 +15,12 @@ import (
 	"github.com/sourcegraph/sourcegraph/lib/errors"
 )

-func NewClient(cli httpcli.Doer, endpoint, accessToken string) types.CompletionsClient {
+func NewClient(cli httpcli.Doer, endpoint, accessToken string, viaGateway bool) types.CompletionsClient {
 	return &googleCompletionStreamClient{
 		cli:         cli,
 		accessToken: accessToken,
 		endpoint:    endpoint,
+		viaGateway:  viaGateway,
 	}
 }
@@ -30,15 +31,7 @@ func (c *googleCompletionStreamClient) Complete(
 	requestParams types.CompletionRequestParameters,
 	logger log.Logger,
 ) (*types.CompletionResponse, error) {
-	var resp *http.Response
-	var err error
-
-	defer (func() {
-		if resp != nil {
-			resp.Body.Close()
-		}
-	})()
-
-	resp, err = c.makeRequest(ctx, requestParams, false)
+	resp, err := c.makeRequest(ctx, requestParams, false)
 	if err != nil {
 		return nil, err
 	}
@@ -75,19 +68,11 @@ func (c *googleCompletionStreamClient) Stream(
 	sendEvent types.SendCompletionEvent,
 	logger log.Logger,
 ) error {
-	var resp *http.Response
-	var err error
-
-	defer (func() {
-		if resp != nil {
-			resp.Body.Close()
-		}
-	})()
-
-	resp, err = c.makeRequest(ctx, requestParams, true)
+	resp, err := c.makeRequest(ctx, requestParams, true)
 	if err != nil {
 		return err
 	}
+	defer resp.Body.Close()

 	dec := NewDecoder(resp.Body)
 	var content string
@@ -131,19 +116,19 @@ func (c *googleCompletionStreamClient) Stream(
 // makeRequest formats the request and calls the chat/completions endpoint for code_completion requests
 func (c *googleCompletionStreamClient) makeRequest(ctx context.Context, requestParams types.CompletionRequestParameters, stream bool) (*http.Response, error) {
+	apiURL := c.getAPIURL(requestParams, stream)
+	endpointURL := apiURL.String()
+
 	// Ensure TopK and TopP are non-negative
 	requestParams.TopK = max(0, requestParams.TopK)
 	requestParams.TopP = max(0, requestParams.TopP)

 	// Generate the prompt
 	prompt, err := getPrompt(requestParams.Messages)
 	if err != nil {
 		return nil, err
 	}

 	payload := googleRequest{
 		Model:    requestParams.Model,
-		Stream:   stream,
 		Contents: prompt,
 		GenerationConfig: googleGenerationConfig{
 			Temperature: requestParams.Temperature,
@@ -153,14 +138,18 @@ func (c *googleCompletionStreamClient) makeRequest(ctx context.Context, requestP
 			StopSequences: requestParams.StopSequences,
 		},
 	}
+
+	if c.viaGateway {
+		endpointURL = c.endpoint
+		// Add the Stream value to the payload if this is a Cody Gateway request,
+		// as it is used for internal routing but not part of the Google API shape.
+		payload.Stream = stream
+	}
+
 	reqBody, err := json.Marshal(payload)
 	if err != nil {
 		return nil, err
 	}

-	apiURL := c.getAPIURL(requestParams, stream)
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL.String(), bytes.NewReader(reqBody))
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, bytes.NewReader(reqBody))
 	if err != nil {
 		return nil, err
 	}
@@ -169,7 +158,7 @@ func (c *googleCompletionStreamClient) makeRequest(ctx context.Context, requestP
 	// Vertex AI API requires an Authorization header with the access token.
 	// Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini#sample-requests
-	if !isDefaultAPIEndpoint(apiURL) {
+	if !c.viaGateway && !isDefaultAPIEndpoint(apiURL) {
 		req.Header.Set("Authorization", "Bearer "+c.accessToken)
 	}
```

```diff
@@ -31,7 +31,7 @@ func TestErrStatusNotOK(t *testing.T) {
 				Body: io.NopCloser(bytes.NewReader([]byte("oh no, please slow down!"))),
 			}, nil
 		},
-	}, "", "")
+	}, "", "", false)

 	t.Run("Complete", func(t *testing.T) {
 		logger := log.Scoped("completions")
```


```diff
@@ -6,6 +6,7 @@ type googleCompletionStreamClient struct {
 	cli         httpcli.Doer
 	accessToken string
 	endpoint    string
+	viaGateway  bool
 }

 // The request body for the completion stream endpoint.
@@ -19,8 +20,8 @@ type googleRequest struct {
 	SymtemInstruction string `json:"systemInstruction,omitempty"`
 	// Stream is used for our internal routing of the Google Request, and is not part
-	// of the Google API shape. So we make sure to not include it when marshaling into JSON.
-	Stream bool `json:"-"` // This field will not be marshaled into JSON
+	// of the Google API shape.
+	Stream bool `json:"stream,omitempty"`
 }

 type googleContentMessage struct {
@@ -44,12 +45,13 @@ type googleGenerationConfig struct {
 }

 type googleResponse struct {
-	Candidates []struct {
-		Content      googleContentMessage `json:"content,omitempty"`
-		FinishReason string               `json:"finishReason,omitempty"`
-	} `json:"candidates"`
+	Candidates    []googleCandidates `json:"candidates,omitempty"`
 	UsageMetadata googleUsage `json:"usageMetadata,omitempty"`
 }
+
+type googleCandidates struct {
+	Content       googleContentMessage  `json:"content,omitempty"`
+	FinishReason  string                `json:"finishReason,omitempty"`
+	SafetyRatings []googleSafetyRatings `json:"safetyRatings,omitempty"`
+}
```