From 0c777bac4192b2cb85f29aae3bcfea1dcd9be377 Mon Sep 17 00:00:00 2001
From: Beatrix <68532117+abeatrix@users.noreply.github.com>
Date: Tue, 18 Jun 2024 15:25:26 -0700
Subject: [PATCH] fix(cody-gateway): streaming google endpoint (#63306)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Issue: Currently, the ShouldStream() method always returns false because the
Stream value is removed before the request is passed into the Handler.

To fix this, we keep the original googleRequest.Stream value so that
ShouldStream() returns the correct value, and we use the transformBody method
to remove the Stream field before the request is sent to the Google API.
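For reviewers, here is a minimal, self-contained sketch of how these pieces fit together after this change. The type and method names mirror the diff below, but the bodies are simplified (the chat-completions feature check and API key handling are omitted) and the model name is only an example:

```
package main

import "fmt"

// googleRequest is a trimmed-down stand-in for the gateway's request type in
// cmd/cody-gateway/internal/httpapi/completions/google_types.go.
type googleRequest struct {
	Model  string `json:"model"`
	Stream bool   `json:"stream,omitempty"`
}

// ShouldStream now reflects the client's original value because Stream is no
// longer dropped when the request body is decoded.
func (r googleRequest) ShouldStream() bool { return r.Stream }

// getAPIURL picks the streaming RPC whenever the request should stream.
// (The real method also forces streaming for the chat completions feature and
// appends the configured API key; both are omitted here.)
func getAPIURL(req *googleRequest) string {
	rpc, sseSuffix := "generateContent", ""
	if req.ShouldStream() {
		rpc, sseSuffix = "streamGenerateContent", "&alt=sse"
		req.Stream = true
	}
	return fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:%s?key=<key>%s", req.Model, rpc, sseSuffix)
}

// transformBody strips Stream before the body is forwarded upstream, since the
// field is only used for internal routing and is not part of the Google API shape.
func transformBody(req *googleRequest) { req.Stream = false }

func main() {
	// Model name is illustrative only.
	req := googleRequest{Model: "gemini-1.5-flash", Stream: true}
	url := getAPIURL(&req) // the routing decision reads Stream while it is still set
	transformBody(&req)    // the upstream body no longer carries the internal flag
	fmt.Println(url, req.Stream)
}
```

The key point is that the routing decision reads Stream before transformBody strips it from the outgoing body.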
Here is the expected behaviour after streaming is fixed:
https://github.com/sourcegraph/sourcegraph/assets/68532117/8324fb8c-0625-4579-b0e9-0abfc9858961

Also confirmed it works with both Cody Gateway and BYOK:
![image](https://github.com/sourcegraph/sourcegraph/assets/68532117/9fe60423-a05b-412d-812a-f34cd812d9dc)

## Test plan

We always stream Cody Gateway requests for Google Gemini models, as the Code
Completion feature has not been implemented on the client side yet.

### Non-stream request

```
❯ curl 'https://sourcegraph.test:3443/.api/completions/code' -i \
-X POST \
-H 'authorization: token LOCALTOKEN' \
--data-raw '{"messages":[{"speaker":"human","text":"Who are you?"}],"maxTokensToSample":30,"temperature":0,"stopSequences":[],"timeoutMs":5000}'
HTTP/2 200
access-control-allow-credentials: true
access-control-allow-origin:
alt-svc: h3=":3443"; ma=2592000
cache-control: no-cache, max-age=0
content-type: text/plain; charset=utf-8
date: Tue, 18 Jun 2024 21:05:38 GMT
server: Caddy
server: Caddy
set-cookie: sourcegraphDeviceId=d4fa7789-2442-472a-b425-a68372d27944; Expires=Wed, 18 Jun 2025 21:05:36 GMT; Secure
vary: Cookie, Accept-Encoding, Authorization, Cookie, Authorization, X-Requested-With, Cookie
x-content-type-options: nosniff
x-frame-options: DENY
x-powered-by: Express
x-trace: 00f998a2a2e1b6895687ad7cc567b41c
x-trace-span: da9c93d16415b94f
x-trace-url: https://sourcegraph.test:3443/-/debug/jaeger/trace/00f998a2a2e1b6895687ad7cc567b41c
x-xss-protection: 1; mode=block
content-length: 147

{"completion":"I am a large language model, trained by Google. \n\nHere's what that means:\n\n* **Large Language Model:** I'm","stopReason":"STOP"}%
```

### Streaming request

```
❯ curl 'https://sourcegraph.test:3443/.api/completions/stream' -i \
-X POST \
-H 'authorization: token $LOCALTOKEN' \
--data-raw '{"stream":true,"messages":[{"speaker":"human","text":"Who are you?"}],"maxTokensToSample":1000,"temperature":0,"stopSequences":[],"timeoutMs":5000}'
HTTP/2 200
access-control-allow-credentials: true
access-control-allow-origin:
alt-svc: h3=":3443"; ma=2592000
cache-control: no-cache
content-type: text/event-stream
date: Tue, 18 Jun 2024 21:07:02 GMT
server: Caddy
server: Caddy
set-cookie: sourcegraphDeviceId=38b45f36-d237-4f8d-8242-a63fcc801a32; Expires=Wed, 18 Jun 2025 21:06:59 GMT; Secure
vary: Cookie, Accept-Encoding, Authorization, Cookie, Authorization, X-Requested-With, Cookie
x-accel-buffering: no
x-content-type-options: nosniff
x-frame-options: DENY
x-powered-by: Express
x-trace: 984932973626e14f7cb0ce7e8e470717
x-trace-span: d285179cfb744e08
x-trace-url: https://sourcegraph.test:3443/-/debug/jaeger/trace/984932973626e14f7cb0ce7e8e470717
x-xss-protection: 1; mode=block

event: completion
data: {"completion":"I","stopReason":"STOP"}

event: completion
data: {"completion":"I am a large language model, trained by Google. \n\nHere's what","stopReason":"STOP"}

event: completion
data: {"completion":"I am a large language model, trained by Google. \n\nHere's what that means:\n\n* **I am not a person.** I am a computer","stopReason":"STOP"}

event: completion
data: {"completion":"I am a large language model, trained by Google. \n\nHere's what that means:\n\n* **I am not a person.** I am a computer program designed to process and generate human-like text. \n* **I learn from data.** I was trained on a massive dataset of text and code,","stopReason":"STOP"}

event: completion
data: {"completion":"I am a large language model, trained by Google. \n\nHere's what that means:\n\n* **I am not a person.** I am a computer program designed to process and generate human-like text. \n* **I learn from data.** I was trained on a massive dataset of text and code, which allows me to generate text, translate languages, write different kinds of creative content, and answer your questions in an informative way.\n* **I am still","stopReason":"STOP"}

event: completion
data: {"completion":"I am a large language model, trained by Google. \n\nHere's what that means:\n\n* **I am not a person.** I am a computer program designed to process and generate human-like text. \n* **I learn from data.** I was trained on a massive dataset of text and code, which allows me to generate text, translate languages, write different kinds of creative content, and answer your questions in an informative way.\n* **I am still under development.** I am constantly learning and improving, but I am not perfect and can sometimes make mistakes.\n\nHow can I help you today? \n","stopReason":"STOP"}

event: done
data: {}
```

## Changelog

---
 .../internal/httpapi/completions/google.go    | 14 +++---
 .../httpapi/completions/google_types.go       | 15 ++++---
 internal/completions/client/client.go         |  2 +-
 .../client/codygateway/codygateway.go         |  2 +-
 internal/completions/client/google/google.go  | 43 +++++++------------
 .../completions/client/google/google_test.go  |  2 +-
 internal/completions/client/google/types.go   | 16 ++++---
 7 files changed, 45 insertions(+), 49 deletions(-)

diff --git a/cmd/cody-gateway/internal/httpapi/completions/google.go b/cmd/cody-gateway/internal/httpapi/completions/google.go
index 68ca4a55a95..9a700930a80 100644
--- a/cmd/cody-gateway/internal/httpapi/completions/google.go
+++ b/cmd/cody-gateway/internal/httpapi/completions/google.go
@@ -63,9 +63,11 @@ func (r googleRequest) BuildPrompt() string {
 func (g *GoogleHandlerMethods) getAPIURL(feature codygateway.Feature, req googleRequest) string {
     rpc := "generateContent"
     sseSuffix := ""
-    if feature == codygateway.FeatureChatCompletions {
+    // If we're streaming, we need to use the stream endpoint.
+    if feature == codygateway.FeatureChatCompletions || req.ShouldStream() {
         rpc = "streamGenerateContent"
         sseSuffix = "&alt=sse"
+        req.Stream = true
     }
     return fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:%s?key=%s%s", req.Model, rpc, g.config.AccessToken, sseSuffix)
 }
@@ -93,10 +95,13 @@ func (g *GoogleHandlerMethods) shouldFlagRequest(ctx context.Context, logger log
 }

 // Used to modify the request body before it is sent to upstream.
-func (*GoogleHandlerMethods) transformBody(*googleRequest, string) {}
+func (*GoogleHandlerMethods) transformBody(gr *googleRequest, _ string) {
+    // Remove Stream from the request body before sending it to Google.
+    gr.Stream = false
+}

 func (*GoogleHandlerMethods) getRequestMetadata(body googleRequest) (model string, additionalMetadata map[string]any) {
-    return body.Model, map[string]any{"stream": body.Stream}
+    return body.Model, map[string]any{"stream": body.ShouldStream()}
 }

 func (o *GoogleHandlerMethods) transformRequest(r *http.Request) {
@@ -106,10 +111,9 @@ func (o *GoogleHandlerMethods) transformRequest(r *http.Request) {
 func (*GoogleHandlerMethods) parseResponseAndUsage(logger log.Logger, reqBody googleRequest, r io.Reader) (promptUsage, completionUsage usageStats) {
     // First, extract prompt usage details from the request.
     promptUsage.characters = len(reqBody.BuildPrompt())
-
     // Try to parse the request we saw, if it was non-streaming, we can simply parse
     // it as JSON.
-    if !reqBody.ShouldStream() {
+    if !reqBody.Stream && !reqBody.ShouldStream() {
         var res googleResponse
         if err := json.NewDecoder(r).Decode(&res); err != nil {
             logger.Error("failed to parse Google response as JSON", log.Error(err))
diff --git a/cmd/cody-gateway/internal/httpapi/completions/google_types.go b/cmd/cody-gateway/internal/httpapi/completions/google_types.go
index e164deac926..356df1dd8ac 100644
--- a/cmd/cody-gateway/internal/httpapi/completions/google_types.go
+++ b/cmd/cody-gateway/internal/httpapi/completions/google_types.go
@@ -11,8 +11,8 @@ type googleRequest struct {
     SymtemInstruction string `json:"systemInstruction,omitempty"`

     // Stream is used for our internal routing of the Google Request, and is not part
-    // of the Google API shape. So we make sure to not include it when marshaling into JSON.
-    Stream bool `json:"-"` // This field will not be marshaled into JSON
+    // of the Google API shape.
+    Stream bool `json:"stream,omitempty"`
 }

 type googleContentMessage struct {
@@ -36,12 +36,13 @@ type googleGenerationConfig struct {
 }

 type googleResponse struct {
-    Candidates []struct {
-        Content      googleContentMessage `json:"content,omitempty"`
-        FinishReason string               `json:"finishReason,omitempty"`
-    } `json:"candidates"`
+    Candidates    []googleCandidates `json:"candidates,omitempty"`
+    UsageMetadata googleUsage        `json:"usageMetadata,omitempty"`
+}

-    UsageMetadata googleUsage           `json:"usageMetadata,omitempty"`
+type googleCandidates struct {
+    Content       googleContentMessage  `json:"content,omitempty"`
+    FinishReason  string                `json:"finishReason,omitempty"`
     SafetyRatings []googleSafetyRatings `json:"safetyRatings,omitempty"`
 }

diff --git a/internal/completions/client/client.go b/internal/completions/client/client.go
index 9871951a439..a97e501cc6a 100644
--- a/internal/completions/client/client.go
+++ b/internal/completions/client/client.go
@@ -42,7 +42,7 @@ func getBasic(endpoint string, provider conftypes.CompletionsProviderName, acces
     case conftypes.CompletionsProviderNameAzureOpenAI:
         return azureopenai.NewClient(azureopenai.GetAPIClient, endpoint, accessToken, *tokenManager)
     case conftypes.CompletionsProviderNameGoogle:
-        return google.NewClient(httpcli.UncachedExternalDoer, endpoint, accessToken), nil
+        return google.NewClient(httpcli.UncachedExternalDoer, endpoint, accessToken, false), nil
     case conftypes.CompletionsProviderNameSourcegraph:
         return codygateway.NewClient(httpcli.CodyGatewayDoer, endpoint, accessToken, *tokenManager)
     case conftypes.CompletionsProviderNameFireworks:
diff --git a/internal/completions/client/codygateway/codygateway.go b/internal/completions/client/codygateway/codygateway.go
index 5152dd39d2d..5354488989d 100644
--- a/internal/completions/client/codygateway/codygateway.go
+++ b/internal/completions/client/codygateway/codygateway.go
@@ -91,7 +91,7 @@ func (c *codyGatewayClient) clientForParams(feature types.CompletionsFeature, re
     case string(conftypes.CompletionsProviderNameFireworks):
         return fireworks.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/fireworks"), "", ""), nil
     case string(conftypes.CompletionsProviderNameGoogle):
-        return google.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/google"), "", ""), nil
+        return google.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/google"), "", "", true), nil
     case "":
         return nil, errors.Newf("no provider provided in model %s - a model in the format '$PROVIDER/$MODEL_NAME' is expected", model)
     default:
diff --git a/internal/completions/client/google/google.go b/internal/completions/client/google/google.go
index 05d02cb2879..c321fe20924 100644
--- a/internal/completions/client/google/google.go
+++ b/internal/completions/client/google/google.go
@@ -15,11 +15,12 @@ import (
     "github.com/sourcegraph/sourcegraph/lib/errors"
 )

-func NewClient(cli httpcli.Doer, endpoint, accessToken string) types.CompletionsClient {
+func NewClient(cli httpcli.Doer, endpoint, accessToken string, viaGateway bool) types.CompletionsClient {
     return &googleCompletionStreamClient{
         cli:         cli,
         accessToken: accessToken,
         endpoint:    endpoint,
+        viaGateway:  viaGateway,
     }
 }

@@ -30,15 +31,7 @@ func (c *googleCompletionStreamClient) Complete(
     requestParams types.CompletionRequestParameters,
     logger log.Logger,
 ) (*types.CompletionResponse, error) {
-    var resp *http.Response
-    var err error
-    defer (func() {
-        if resp != nil {
-            resp.Body.Close()
-        }
-    })()
-
-    resp, err = c.makeRequest(ctx, requestParams, false)
+    resp, err := c.makeRequest(ctx, requestParams, false)
     if err != nil {
         return nil, err
     }
@@ -75,19 +68,11 @@ func (c *googleCompletionStreamClient) Stream(
     sendEvent types.SendCompletionEvent,
     logger log.Logger,
 ) error {
-    var resp *http.Response
-    var err error
-
-    defer (func() {
-        if resp != nil {
-            resp.Body.Close()
-        }
-    })()
-
-    resp, err = c.makeRequest(ctx, requestParams, true)
+    resp, err := c.makeRequest(ctx, requestParams, true)
     if err != nil {
         return err
     }
+    defer resp.Body.Close()

     dec := NewDecoder(resp.Body)
     var content string
@@ -131,19 +116,19 @@ func (c *googleCompletionStreamClient) Stream(

 // makeRequest formats the request and calls the chat/completions endpoint for code_completion requests
 func (c *googleCompletionStreamClient) makeRequest(ctx context.Context, requestParams types.CompletionRequestParameters, stream bool) (*http.Response, error) {
+    apiURL := c.getAPIURL(requestParams, stream)
+    endpointURL := apiURL.String()
+
     // Ensure TopK and TopP are non-negative
     requestParams.TopK = max(0, requestParams.TopK)
     requestParams.TopP = max(0, requestParams.TopP)
-
     // Generate the prompt
     prompt, err := getPrompt(requestParams.Messages)
     if err != nil {
         return nil, err
     }
-
     payload := googleRequest{
         Model:    requestParams.Model,
-        Stream:   stream,
         Contents: prompt,
         GenerationConfig: googleGenerationConfig{
             Temperature: requestParams.Temperature,
@@ -153,14 +138,18 @@ func (c *googleCompletionStreamClient) makeRequest(ctx context.Context, requestP
             StopSequences:   requestParams.StopSequences,
         },
     }
+    if c.viaGateway {
+        endpointURL = c.endpoint
+        // Add the Stream value to the payload if this is a Cody Gateway request,
+        // as it is used for internal routing but not part of the Google API shape.
+        payload.Stream = stream
+    }

     reqBody, err := json.Marshal(payload)
     if err != nil {
         return nil, err
     }
-
-    apiURL := c.getAPIURL(requestParams, stream)
-    req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL.String(), bytes.NewReader(reqBody))
+    req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, bytes.NewReader(reqBody))
     if err != nil {
         return nil, err
     }
@@ -169,7 +158,7 @@ func (c *googleCompletionStreamClient) makeRequest(ctx context.Context, requestP

     // Vertex AI API requires an Authorization header with the access token.
     // Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini#sample-requests
-    if !isDefaultAPIEndpoint(apiURL) {
+    if !c.viaGateway && !isDefaultAPIEndpoint(apiURL) {
         req.Header.Set("Authorization", "Bearer "+c.accessToken)
     }

diff --git a/internal/completions/client/google/google_test.go b/internal/completions/client/google/google_test.go
index 3e27682c635..0a8264e45a5 100644
--- a/internal/completions/client/google/google_test.go
+++ b/internal/completions/client/google/google_test.go
@@ -31,7 +31,7 @@ func TestErrStatusNotOK(t *testing.T) {
                 Body:       io.NopCloser(bytes.NewReader([]byte("oh no, please slow down!"))),
             }, nil
         },
-    }, "", "")
+    }, "", "", false)

     t.Run("Complete", func(t *testing.T) {
         logger := log.Scoped("completions")
diff --git a/internal/completions/client/google/types.go b/internal/completions/client/google/types.go
index 347da99ba3d..91af7917a55 100644
--- a/internal/completions/client/google/types.go
+++ b/internal/completions/client/google/types.go
@@ -6,6 +6,7 @@ type googleCompletionStreamClient struct {
     cli         httpcli.Doer
     accessToken string
     endpoint    string
+    viaGateway  bool
 }

 // The request body for the completion stream endpoint.
@@ -19,8 +20,8 @@ type googleRequest struct {
     SymtemInstruction string `json:"systemInstruction,omitempty"`

     // Stream is used for our internal routing of the Google Request, and is not part
-    // of the Google API shape. So we make sure to not include it when marshaling into JSON.
-    Stream bool `json:"-"` // This field will not be marshaled into JSON
+    // of the Google API shape.
+    Stream bool `json:"stream,omitempty"`
 }

 type googleContentMessage struct {
@@ -44,12 +45,13 @@ type googleGenerationConfig struct {
 }

 type googleResponse struct {
-    Candidates []struct {
-        Content      googleContentMessage `json:"content,omitempty"`
-        FinishReason string               `json:"finishReason,omitempty"`
-    } `json:"candidates"`
+    Candidates    []googleCandidates `json:"candidates,omitempty"`
+    UsageMetadata googleUsage        `json:"usageMetadata,omitempty"`
+}

-    UsageMetadata googleUsage           `json:"usageMetadata,omitempty"`
+type googleCandidates struct {
+    Content       googleContentMessage  `json:"content,omitempty"`
+    FinishReason  string                `json:"finishReason,omitempty"`
     SafetyRatings []googleSafetyRatings `json:"safetyRatings,omitempty"`
 }