mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 18:51:59 +00:00
fix(cody-gateway): streaming google endpoint (#63306)
<!-- 💡 To write a useful PR description, make sure that your description covers: - WHAT this PR is changing: - How was it PREVIOUSLY. - How it will be from NOW on. - WHY this PR is needed. - CONTEXT, i.e. to which initiative, project or RFC it belongs. The structure of the description doesn't matter as much as covering these points, so use your best judgement based on your context. Learn how to write a good pull request description: https://www.notion.so/sourcegraph/Write-a-good-pull-request-description-610a7fd3e613496eb76f450db5a49b6e?pvs=4 --> Issue: Currently, the ShouldStream() method will always return false because the Stream value is removed before it is passed into the Handler. To fix this, we will store the original googleRequest.Stream value if it's true so that ShouldStream() will return the correct Stream value. We will also use the transformBody method to remove the Stream value before we send it to the Google API. Here is the expected behaviour after the stream is fixed: https://github.com/sourcegraph/sourcegraph/assets/68532117/8324fb8c-0625-4579-b0e9-0abfc9858961 Also confirmed it works with both Cody Gateway and BYOK:  ## Test plan <!-- All pull requests REQUIRE a test plan: https://docs-legacy.sourcegraph.com/dev/background-information/testing_principles --> Always stream Cody Gateway's requests for Google Gemini models, as we haven't implemented the Code Completion feature on the client side. 
### Non-stream request ``` ❯ curl 'https://sourcegraph.test:3443/.api/completions/code' -i \ -X POST \ -H 'authorization: token LOCALTOKEN' \ --data-raw '{"messages":[{"speaker":"human","text":"Who are you?"}],"maxTokensToSample":30,"temperature":0,"stopSequences":[],"timeoutMs":5000}' HTTP/2 200 access-control-allow-credentials: true access-control-allow-origin: alt-svc: h3=":3443"; ma=2592000 cache-control: no-cache, max-age=0 content-type: text/plain; charset=utf-8 date: Tue, 18 Jun 2024 21:05:38 GMT server: Caddy server: Caddy set-cookie: sourcegraphDeviceId=d4fa7789-2442-472a-b425-a68372d27944; Expires=Wed, 18 Jun 2025 21:05:36 GMT; Secure vary: Cookie, Accept-Encoding, Authorization, Cookie, Authorization, X-Requested-With, Cookie x-content-type-options: nosniff x-frame-options: DENY x-powered-by: Express x-trace: 00f998a2a2e1b6895687ad7cc567b41c x-trace-span: da9c93d16415b94f x-trace-url: https://sourcegraph.test:3443/-/debug/jaeger/trace/00f998a2a2e1b6895687ad7cc567b41c x-xss-protection: 1; mode=block content-length: 147 {"completion":"I am a large language model, trained by Google. 
\n\nHere's what that means:\n\n* **Large Language Model:** I'm","stopReason":"STOP"}% ``` ### Streaming request: ``` ❯ curl 'https://sourcegraph.test:3443/.api/completions/stream' -i \ -X POST \ -H 'authorization: token $LOCALTOKEN' \ --data-raw '{"stream":true,"messages":[{"speaker":"human","text":"Who are you?"}],"maxTokensToSample":1000,"temperature":0,"stopSequences":[],"timeoutMs":5000}' HTTP/2 200 access-control-allow-credentials: true access-control-allow-origin: alt-svc: h3=":3443"; ma=2592000 cache-control: no-cache content-type: text/event-stream date: Tue, 18 Jun 2024 21:07:02 GMT server: Caddy server: Caddy set-cookie: sourcegraphDeviceId=38b45f36-d237-4f8d-8242-a63fcc801a32; Expires=Wed, 18 Jun 2025 21:06:59 GMT; Secure vary: Cookie, Accept-Encoding, Authorization, Cookie, Authorization, X-Requested-With, Cookie x-accel-buffering: no x-content-type-options: nosniff x-frame-options: DENY x-powered-by: Express x-trace: 984932973626e14f7cb0ce7e8e470717 x-trace-span: d285179cfb744e08 x-trace-url: https://sourcegraph.test:3443/-/debug/jaeger/trace/984932973626e14f7cb0ce7e8e470717 x-xss-protection: 1; mode=block event: completion data: {"completion":"I","stopReason":"STOP"} event: completion data: {"completion":"I am a large language model, trained by Google. \n\nHere's what","stopReason":"STOP"} event: completion data: {"completion":"I am a large language model, trained by Google. \n\nHere's what that means:\n\n* **I am not a person.** I am a computer","stopReason":"STOP"} event: completion data: {"completion":"I am a large language model, trained by Google. \n\nHere's what that means:\n\n* **I am not a person.** I am a computer program designed to process and generate human-like text. \n* **I learn from data.** I was trained on a massive dataset of text and code,","stopReason":"STOP"} event: completion data: {"completion":"I am a large language model, trained by Google. 
\n\nHere's what that means:\n\n* **I am not a person.** I am a computer program designed to process and generate human-like text. \n* **I learn from data.** I was trained on a massive dataset of text and code, which allows me to generate text, translate languages, write different kinds of creative content, and answer your questions in an informative way.\n* **I am still","stopReason":"STOP"} event: completion data: {"completion":"I am a large language model, trained by Google. \n\nHere's what that means:\n\n* **I am not a person.** I am a computer program designed to process and generate human-like text. \n* **I learn from data.** I was trained on a massive dataset of text and code, which allows me to generate text, translate languages, write different kinds of creative content, and answer your questions in an informative way.\n* **I am still under development.** I am constantly learning and improving, but I am not perfect and can sometimes make mistakes.\n\nHow can I help you today? \n","stopReason":"STOP"} event: done data: {} ``` ## Changelog <!-- 1. Ensure your pull request title is formatted as: $type($domain): $what 2. Add bullet list items for each additional detail you want to cover (see example below) 3. You can edit this after the pull request was merged, as long as release shipping it hasn't been promoted to the public. 4. For more information, please see this how-to https://www.notion.so/sourcegraph/Writing-a-changelog-entry-dd997f411d524caabf0d8d38a24a878c? Audience: TS/CSE > Customers > Teammates (in that order). Cheat sheet: $type = chore|fix|feat $domain: source|search|ci|release|plg|cody|local|... --> <!-- Example: Title: fix(search): parse quotes with the appropriate context Changelog section: ## Changelog - When a quote is used with regexp pattern type, then ... - Refactored underlying code. -->
This commit is contained in:
parent
a333771bd4
commit
0c777bac41
@ -63,9 +63,11 @@ func (r googleRequest) BuildPrompt() string {
|
||||
func (g *GoogleHandlerMethods) getAPIURL(feature codygateway.Feature, req googleRequest) string {
|
||||
rpc := "generateContent"
|
||||
sseSuffix := ""
|
||||
if feature == codygateway.FeatureChatCompletions {
|
||||
// If we're streaming, we need to use the stream endpoint.
|
||||
if feature == codygateway.FeatureChatCompletions || req.ShouldStream() {
|
||||
rpc = "streamGenerateContent"
|
||||
sseSuffix = "&alt=sse"
|
||||
req.Stream = true
|
||||
}
|
||||
return fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:%s?key=%s%s", req.Model, rpc, g.config.AccessToken, sseSuffix)
|
||||
}
|
||||
@ -93,10 +95,13 @@ func (g *GoogleHandlerMethods) shouldFlagRequest(ctx context.Context, logger log
|
||||
}
|
||||
|
||||
// Used to modify the request body before it is sent to upstream.
|
||||
func (*GoogleHandlerMethods) transformBody(*googleRequest, string) {}
|
||||
func (*GoogleHandlerMethods) transformBody(gr *googleRequest, _ string) {
|
||||
// Remove Stream from the request body before sending it to Google.
|
||||
gr.Stream = false
|
||||
}
|
||||
|
||||
func (*GoogleHandlerMethods) getRequestMetadata(body googleRequest) (model string, additionalMetadata map[string]any) {
|
||||
return body.Model, map[string]any{"stream": body.Stream}
|
||||
return body.Model, map[string]any{"stream": body.ShouldStream()}
|
||||
}
|
||||
|
||||
func (o *GoogleHandlerMethods) transformRequest(r *http.Request) {
|
||||
@ -106,10 +111,9 @@ func (o *GoogleHandlerMethods) transformRequest(r *http.Request) {
|
||||
func (*GoogleHandlerMethods) parseResponseAndUsage(logger log.Logger, reqBody googleRequest, r io.Reader) (promptUsage, completionUsage usageStats) {
|
||||
// First, extract prompt usage details from the request.
|
||||
promptUsage.characters = len(reqBody.BuildPrompt())
|
||||
|
||||
// Try to parse the request we saw, if it was non-streaming, we can simply parse
|
||||
// it as JSON.
|
||||
if !reqBody.ShouldStream() {
|
||||
if !reqBody.Stream && !reqBody.ShouldStream() {
|
||||
var res googleResponse
|
||||
if err := json.NewDecoder(r).Decode(&res); err != nil {
|
||||
logger.Error("failed to parse Google response as JSON", log.Error(err))
|
||||
|
||||
@ -11,8 +11,8 @@ type googleRequest struct {
|
||||
SymtemInstruction string `json:"systemInstruction,omitempty"`
|
||||
|
||||
// Stream is used for our internal routing of the Google Request, and is not part
|
||||
// of the Google API shape. So we make sure to not include it when marshaling into JSON.
|
||||
Stream bool `json:"-"` // This field will not be marshaled into JSON
|
||||
// of the Google API shape.
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type googleContentMessage struct {
|
||||
@ -36,12 +36,13 @@ type googleGenerationConfig struct {
|
||||
}
|
||||
|
||||
type googleResponse struct {
|
||||
Candidates []struct {
|
||||
Content googleContentMessage `json:"content,omitempty"`
|
||||
FinishReason string `json:"finishReason,omitempty"`
|
||||
} `json:"candidates"`
|
||||
Candidates []googleCandidates `json:"candidates,omitempty"`
|
||||
UsageMetadata googleUsage `json:"usageMetadata,omitempty"`
|
||||
}
|
||||
|
||||
UsageMetadata googleUsage `json:"usageMetadata,omitempty"`
|
||||
type googleCandidates struct {
|
||||
Content googleContentMessage `json:"content,omitempty"`
|
||||
FinishReason string `json:"finishReason,omitempty"`
|
||||
SafetyRatings []googleSafetyRatings `json:"safetyRatings,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
@ -42,7 +42,7 @@ func getBasic(endpoint string, provider conftypes.CompletionsProviderName, acces
|
||||
case conftypes.CompletionsProviderNameAzureOpenAI:
|
||||
return azureopenai.NewClient(azureopenai.GetAPIClient, endpoint, accessToken, *tokenManager)
|
||||
case conftypes.CompletionsProviderNameGoogle:
|
||||
return google.NewClient(httpcli.UncachedExternalDoer, endpoint, accessToken), nil
|
||||
return google.NewClient(httpcli.UncachedExternalDoer, endpoint, accessToken, false), nil
|
||||
case conftypes.CompletionsProviderNameSourcegraph:
|
||||
return codygateway.NewClient(httpcli.CodyGatewayDoer, endpoint, accessToken, *tokenManager)
|
||||
case conftypes.CompletionsProviderNameFireworks:
|
||||
|
||||
@ -91,7 +91,7 @@ func (c *codyGatewayClient) clientForParams(feature types.CompletionsFeature, re
|
||||
case string(conftypes.CompletionsProviderNameFireworks):
|
||||
return fireworks.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/fireworks"), "", ""), nil
|
||||
case string(conftypes.CompletionsProviderNameGoogle):
|
||||
return google.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/google"), "", ""), nil
|
||||
return google.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/google"), "", "", true), nil
|
||||
case "":
|
||||
return nil, errors.Newf("no provider provided in model %s - a model in the format '$PROVIDER/$MODEL_NAME' is expected", model)
|
||||
default:
|
||||
|
||||
@ -15,11 +15,12 @@ import (
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
)
|
||||
|
||||
func NewClient(cli httpcli.Doer, endpoint, accessToken string) types.CompletionsClient {
|
||||
func NewClient(cli httpcli.Doer, endpoint, accessToken string, viaGateway bool) types.CompletionsClient {
|
||||
return &googleCompletionStreamClient{
|
||||
cli: cli,
|
||||
accessToken: accessToken,
|
||||
endpoint: endpoint,
|
||||
viaGateway: viaGateway,
|
||||
}
|
||||
}
|
||||
|
||||
@ -30,15 +31,7 @@ func (c *googleCompletionStreamClient) Complete(
|
||||
requestParams types.CompletionRequestParameters,
|
||||
logger log.Logger,
|
||||
) (*types.CompletionResponse, error) {
|
||||
var resp *http.Response
|
||||
var err error
|
||||
defer (func() {
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
})()
|
||||
|
||||
resp, err = c.makeRequest(ctx, requestParams, false)
|
||||
resp, err := c.makeRequest(ctx, requestParams, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -75,19 +68,11 @@ func (c *googleCompletionStreamClient) Stream(
|
||||
sendEvent types.SendCompletionEvent,
|
||||
logger log.Logger,
|
||||
) error {
|
||||
var resp *http.Response
|
||||
var err error
|
||||
|
||||
defer (func() {
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
})()
|
||||
|
||||
resp, err = c.makeRequest(ctx, requestParams, true)
|
||||
resp, err := c.makeRequest(ctx, requestParams, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
dec := NewDecoder(resp.Body)
|
||||
var content string
|
||||
@ -131,19 +116,19 @@ func (c *googleCompletionStreamClient) Stream(
|
||||
|
||||
// makeRequest formats the request and calls the chat/completions endpoint for code_completion requests
|
||||
func (c *googleCompletionStreamClient) makeRequest(ctx context.Context, requestParams types.CompletionRequestParameters, stream bool) (*http.Response, error) {
|
||||
apiURL := c.getAPIURL(requestParams, stream)
|
||||
endpointURL := apiURL.String()
|
||||
|
||||
// Ensure TopK and TopP are non-negative
|
||||
requestParams.TopK = max(0, requestParams.TopK)
|
||||
requestParams.TopP = max(0, requestParams.TopP)
|
||||
|
||||
// Generate the prompt
|
||||
prompt, err := getPrompt(requestParams.Messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload := googleRequest{
|
||||
Model: requestParams.Model,
|
||||
Stream: stream,
|
||||
Contents: prompt,
|
||||
GenerationConfig: googleGenerationConfig{
|
||||
Temperature: requestParams.Temperature,
|
||||
@ -153,14 +138,18 @@ func (c *googleCompletionStreamClient) makeRequest(ctx context.Context, requestP
|
||||
StopSequences: requestParams.StopSequences,
|
||||
},
|
||||
}
|
||||
if c.viaGateway {
|
||||
endpointURL = c.endpoint
|
||||
// Add the Stream value to the payload if this is a Cody Gateway request,
|
||||
// as it is used for internal routing but not part of the Google API shape.
|
||||
payload.Stream = stream
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
apiURL := c.getAPIURL(requestParams, stream)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL.String(), bytes.NewReader(reqBody))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -169,7 +158,7 @@ func (c *googleCompletionStreamClient) makeRequest(ctx context.Context, requestP
|
||||
|
||||
// Vertex AI API requires an Authorization header with the access token.
|
||||
// Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini#sample-requests
|
||||
if !isDefaultAPIEndpoint(apiURL) {
|
||||
if !c.viaGateway && !isDefaultAPIEndpoint(apiURL) {
|
||||
req.Header.Set("Authorization", "Bearer "+c.accessToken)
|
||||
}
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@ func TestErrStatusNotOK(t *testing.T) {
|
||||
Body: io.NopCloser(bytes.NewReader([]byte("oh no, please slow down!"))),
|
||||
}, nil
|
||||
},
|
||||
}, "", "")
|
||||
}, "", "", false)
|
||||
|
||||
t.Run("Complete", func(t *testing.T) {
|
||||
logger := log.Scoped("completions")
|
||||
|
||||
@ -6,6 +6,7 @@ type googleCompletionStreamClient struct {
|
||||
cli httpcli.Doer
|
||||
accessToken string
|
||||
endpoint string
|
||||
viaGateway bool
|
||||
}
|
||||
|
||||
// The request body for the completion stream endpoint.
|
||||
@ -19,8 +20,8 @@ type googleRequest struct {
|
||||
SymtemInstruction string `json:"systemInstruction,omitempty"`
|
||||
|
||||
// Stream is used for our internal routing of the Google Request, and is not part
|
||||
// of the Google API shape. So we make sure to not include it when marshaling into JSON.
|
||||
Stream bool `json:"-"` // This field will not be marshaled into JSON
|
||||
// of the Google API shape.
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type googleContentMessage struct {
|
||||
@ -44,12 +45,13 @@ type googleGenerationConfig struct {
|
||||
}
|
||||
|
||||
type googleResponse struct {
|
||||
Candidates []struct {
|
||||
Content googleContentMessage `json:"content,omitempty"`
|
||||
FinishReason string `json:"finishReason,omitempty"`
|
||||
} `json:"candidates"`
|
||||
Candidates []googleCandidates `json:"candidates,omitempty"`
|
||||
UsageMetadata googleUsage `json:"usageMetadata,omitempty"`
|
||||
}
|
||||
|
||||
UsageMetadata googleUsage `json:"usageMetadata,omitempty"`
|
||||
type googleCandidates struct {
|
||||
Content googleContentMessage `json:"content,omitempty"`
|
||||
FinishReason string `json:"finishReason,omitempty"`
|
||||
SafetyRatings []googleSafetyRatings `json:"safetyRatings,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user