mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 18:51:59 +00:00
Cody Gateway: Add support for Google non-streaming endpoint (#63166)
Add support for non-stream request for Google Gemini provider - Added `Stream` field to `googleRequest` struct to enable streaming completions - Added `SymtemInstruction` field to `googleRequest` struct to allow setting system instructions - Updated `GoogleHandlerMethods.validateRequest` to reject `FeatureEmbeddings` instead of `FeatureCodeCompletions` - Updated `GoogleHandlerMethods.getRequestMetadata` to return the `Stream` field - Updated `GoogleGatewayFeatureClient.GetRequest` to handle streaming for both `FeatureCodeCompletions` and `FeatureChatCompletions` - Removed unsupported feature checks in `googleCompletionStreamClient` - Added Gemini 1.5 Flash and Gemini 1.0 Pro to autocomplete allowed list (but not supported by clients atm) <!-- 💡 To write a useful PR description, make sure that your description covers: - WHAT this PR is changing: - How was it PREVIOUSLY. - How it will be from NOW on. - WHY this PR is needed. - CONTEXT, i.e. to which initiative, project or RFC it belongs. The structure of the description doesn't matter as much as covering these points, so use your best judgement based on your context. Learn how to write good pull request description: https://www.notion.so/sourcegraph/Write-a-good-pull-request-description-610a7fd3e613496eb76f450db5a49b6e?pvs=4 --> ## Test plan <!-- All pull requests REQUIRE a test plan: https://docs-legacy.sourcegraph.com/dev/background-information/testing_principles --> Unit tests updated for non-stream request. To manually test this: 1. In your Sourcegraph local instance's Site Config, add the following: ``` "completions": { "accessToken": "REDACTED", "chatModel": "gemini-1.5-pro-latest", "completionModel": "google/gemini-1.5-flash-latest", "provider": "google", ``` Note: You can get the accessToken for Gemini API in 1Password. 2. 
After saving the site config with the above change, run the following curl command that hits the code endpoint: ``` curl 'https://sourcegraph.test:3443/.api/completions/stream' -i \ -X POST \ -H 'authorization: token $YOUR_LOCAL_TOKEN' \ --data-raw '{"messages":[{"speaker":"human","text":"Who are you?"}],"maxTokensToSample":30,"temperature":0,"stopSequences":[],"timeoutMs":5000,"stream":false,"model":"gemini-1.5-pro-latest"}' ``` Output: ``` ❯ curl 'https://sourcegraph.test:3443/.api/completions/stream' -i \ -X POST \ -H 'authorization: token $YOUR_LOCAL_TOKEN' \ --data-raw '{"messages":[{"speaker":"human","text":"Who are you?"}],"maxTokensToSample":30,"temperature":0,"stopSequences":[],"timeoutMs":5000,"stream":false,"model":"gemini-1.5-pro-latest"}' HTTP/2 200 access-control-allow-credentials: true access-control-allow-origin: alt-svc: h3=":3443"; ma=2592000 cache-control: no-cache, max-age=0 content-type: text/plain; charset=utf-8 date: Tue, 11 Jun 2024 17:02:19 GMT server: Caddy server: Caddy vary: Accept-Encoding, Authorization, Cookie, Authorization, X-Requested-With, Cookie x-content-type-options: nosniff x-frame-options: DENY x-powered-by: Express x-trace: e11a2ce292639414dd2ccdfcbfa89611 x-trace-span: 9457aa0dd0e09b6c x-trace-url: https://sourcegraph.test:3443/-/debug/jaeger/trace/e11a2ce292639414dd2ccdfcbfa89611 x-xss-protection: 1; mode=block content-length: 154 {"completion":"I am a large language model, trained by Google. \n\nHere's what that means:\n\n* **I am a computer program:** I","stopReason":"MAX_TOKENS"}% ``` ## Changelog <!-- 1. Ensure your pull request title is formatted as: $type($domain): $what 2. Add bullet list items for each additional detail you want to cover (see example below) 3. You can edit this after the pull request was merged, as long as release shipping it hasn't been promoted to the public. 4. 
For more information, please see this how-to https://www.notion.so/sourcegraph/Writing-a-changelog-entry-dd997f411d524caabf0d8d38a24a878c? Audience: TS/CSE > Customers > Teammates (in that order). Cheat sheet: $type = chore|fix|feat $domain: source|search|ci|release|plg|cody|local|... --> <!-- Example: Title: fix(search): parse quotes with the appropriate context Changelog section: ## Changelog - When a quote is used with regexp pattern type, then ... - Refactored underlying code. -->
This commit is contained in:
parent
549beac5ec
commit
e1551657b1
@ -41,10 +41,12 @@ func NewGoogleHandler(baseLogger log.Logger, eventLogger events.Logger, rs limit
|
||||
// The request body for Google completions.
|
||||
// Ref: https://ai.google.dev/api/rest/v1/models/generateContent#request-body
|
||||
// The field list below appears twice because this is a diff hunk: the first
// four fields are the pre-change rows, and the six fields after them are the
// post-change rows, which add Stream and SymtemInstruction.
type googleRequest struct {
|
||||
Model string `json:"model"`
|
||||
Contents []googleContentMessage `json:"contents"`
|
||||
GenerationConfig googleGenerationConfig `json:"generationConfig,omitempty"`
|
||||
SafetySettings []googleSafetySettings `json:"safetySettings,omitempty"`
|
||||
Model string `json:"model"`
|
||||
// Stream is the value ShouldStream returns; it is omitted from the JSON
// encoding when false.
Stream bool `json:"stream,omitempty"`
|
||||
Contents []googleContentMessage `json:"contents"`
|
||||
GenerationConfig googleGenerationConfig `json:"generationConfig,omitempty"`
|
||||
SafetySettings []googleSafetySettings `json:"safetySettings,omitempty"`
|
||||
// NOTE(review): the Go identifier is misspelled ("Symtem"); only the JSON
// tag carries the intended "systemInstruction" spelling.
// NOTE(review): Google's API reference appears to describe systemInstruction
// as a Content object rather than a plain string — TODO confirm.
SymtemInstruction string `json:"systemInstruction,omitempty"`
|
||||
}
|
||||
|
||||
type googleContentMessage struct {
|
||||
@ -73,7 +75,7 @@ type googleSafetySettings struct {
|
||||
}
|
||||
|
||||
// ShouldStream reports whether the request asked for a streamed response.
// This hunk replaces the hard-coded `return true` with the request's own
// Stream field, so both rows are shown below.
func (r googleRequest) ShouldStream() bool {
|
||||
return true
|
||||
return r.Stream
|
||||
}
|
||||
|
||||
func (r googleRequest) GetModel() string {
|
||||
@ -119,7 +121,7 @@ func (g *GoogleHandlerMethods) getAPIURL(_ codygateway.Feature, req googleReques
|
||||
}
|
||||
|
||||
func (*GoogleHandlerMethods) validateRequest(_ context.Context, _ log.Logger, feature codygateway.Feature, _ googleRequest) error {
|
||||
if feature == codygateway.FeatureCodeCompletions {
|
||||
if feature == codygateway.FeatureEmbeddings {
|
||||
return errors.Newf("feature %q is currently not supported for Google", feature)
|
||||
}
|
||||
return nil
|
||||
@ -141,7 +143,7 @@ func (*GoogleHandlerMethods) transformBody(_ *googleRequest, _ string) {
|
||||
}
|
||||
|
||||
// getRequestMetadata returns the request's model name together with a
// metadata map whose "stream" entry records the request's streaming flag.
// This hunk swaps the body.ShouldStream() call for a direct read of
// body.Stream, so both return rows are shown below.
func (*GoogleHandlerMethods) getRequestMetadata(body googleRequest) (model string, additionalMetadata map[string]any) {
|
||||
return body.Model, map[string]any{"stream": body.ShouldStream()}
|
||||
return body.Model, map[string]any{"stream": body.Stream}
|
||||
}
|
||||
|
||||
func (o *GoogleHandlerMethods) transformRequest(r *http.Request) {
|
||||
@ -156,7 +158,6 @@ func (*GoogleHandlerMethods) parseResponseAndUsage(logger log.Logger, reqBody go
|
||||
// it as JSON.
|
||||
if !reqBody.ShouldStream() {
|
||||
var res googleResponse
|
||||
|
||||
if err := json.NewDecoder(r).Decode(&res); err != nil {
|
||||
logger.Error("failed to parse Google response as JSON", log.Error(err))
|
||||
return promptUsage, completionUsage
|
||||
|
||||
@ -14,7 +14,7 @@ func TestGoogleRequestGetTokenCount(t *testing.T) {
|
||||
logger := logtest.Scoped(t)
|
||||
|
||||
t.Run("streaming", func(t *testing.T) {
|
||||
req := googleRequest{}
|
||||
req := googleRequest{Stream: true}
|
||||
r := strings.NewReader(googleStreamingResponse)
|
||||
handler := &GoogleHandlerMethods{}
|
||||
promptUsage, completionUsage := handler.parseResponseAndUsage(logger, req, r)
|
||||
@ -22,6 +22,16 @@ func TestGoogleRequestGetTokenCount(t *testing.T) {
|
||||
assert.Equal(t, 21, promptUsage.tokens)
|
||||
assert.Equal(t, 87, completionUsage.tokens)
|
||||
})
|
||||
|
||||
t.Run("non-streaming", func(t *testing.T) {
|
||||
req := googleRequest{Stream: false}
|
||||
r := strings.NewReader(googleNonStreamingResponse)
|
||||
handler := &GoogleHandlerMethods{}
|
||||
promptUsage, completionUsage := handler.parseResponseAndUsage(logger, req, r)
|
||||
|
||||
assert.Equal(t, 59, promptUsage.tokens)
|
||||
assert.Equal(t, 54, completionUsage.tokens)
|
||||
})
|
||||
}
|
||||
|
||||
var googleStreamingResponse = `data: {"candidates": [{"content": {"parts": [{"text": "def"}],"role": "model"},"finishReason": "STOP","index": 0}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 1,"totalTokenCount": 22}}
|
||||
@ -36,6 +46,47 @@ data: {"candidates": [{"content": {"parts": [{"text": "1[j+1] = list1[j+1], list
|
||||
|
||||
`
|
||||
|
||||
// googleNonStreamingResponse is a sample non-streaming Google response body
// used as the fixture for the "non-streaming" subtest. Its usageMetadata
// counts (59 prompt tokens, 54 candidate tokens) are the values that subtest
// asserts.
var googleNonStreamingResponse = `{
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": "The cobblestone path, worn smooth by centuries of weary feet, led to a humble cottage nestled within the quiet village of Saint-Martin, where a young boy named Pierre discovered a weathered, leather backpack tucked beneath the gnarled oak tree in his grandmother's garden. \n"
|
||||
}
|
||||
],
|
||||
"role": "model"
|
||||
},
|
||||
"finishReason": "STOP",
|
||||
"index": 0,
|
||||
"safetyRatings": [
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"probability": "NEGLIGIBLE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"probability": "NEGLIGIBLE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"probability": "NEGLIGIBLE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"probability": "NEGLIGIBLE"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 59,
|
||||
"candidatesTokenCount": 54,
|
||||
"totalTokenCount": 113
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
func TestParseGoogleTokenUsage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@ -7,12 +7,13 @@ import (
|
||||
"github.com/Khan/genqlient/graphql"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/sourcegraph/log"
|
||||
"github.com/sourcegraph/sourcegraph/internal/collections"
|
||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/collections"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/auth"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/events"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/attribution"
|
||||
|
||||
@ -14,10 +14,24 @@ type GoogleGatewayFeatureClient struct{}
|
||||
|
||||
func (o GoogleGatewayFeatureClient) GetRequest(f codygateway.Feature, req *http.Request, stream bool) (*http.Request, error) {
|
||||
if f == codygateway.FeatureCodeCompletions {
|
||||
return nil, errNotImplemented
|
||||
body := fmt.Sprintf(`{
|
||||
"model":"gemini-1.5-flash-latest",
|
||||
"contents":[{"parts":[{"text":"You are Cody"}],"role":"user"},{"parts":[{"text":"Ok I am Cody"}],"role":"model"},{"parts":[{"text":"Write Bubble sort in Python"}],"role":"user"}],
|
||||
"generationConfig":{"temperature":0.2,"topP":0.95,"topP":0.95,"maxOutputTokens":1000},
|
||||
"stream":%t
|
||||
}`, stream)
|
||||
req.Method = "POST"
|
||||
req.URL.Path = "/v1/completions/google"
|
||||
req.Body = io.NopCloser(strings.NewReader(body))
|
||||
return req, nil
|
||||
}
|
||||
if f == codygateway.FeatureChatCompletions {
|
||||
body := fmt.Sprintf(`{"model":"gemini-1.5-flash-latest","contents":[{"parts":[{"text":"You are Cody"}],"role":"user"},{"parts":[{"text":"Ok I am Cody"}],"role":"model"},{"parts":[{"text":"Write Bubble sort in Python"}],"role":"user"}],"stream":%t}`, stream)
|
||||
body := fmt.Sprintf(`{
|
||||
"model":"gemini-1.5-pro-latest",
|
||||
"contents":[{"parts":[{"text":"You are Cody"}],"role":"user"},{"parts":[{"text":"Ok I am Cody"}],"role":"model"},{"parts":[{"text":"Write Bubble sort in Python"}],"role":"user"}],
|
||||
"generationConfig":{"temperature":0.2,"topP":0.95,"topP":0.95,"maxOutputTokens":1000},
|
||||
"stream":%t
|
||||
}`, stream)
|
||||
req.Method = "POST"
|
||||
req.URL.Path = "/v1/completions/google"
|
||||
req.Body = io.NopCloser(strings.NewReader(body))
|
||||
|
||||
@ -7,6 +7,7 @@ import (
|
||||
"github.com/sourcegraph/log"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks"
|
||||
"github.com/sourcegraph/sourcegraph/internal/completions/client/google"
|
||||
"github.com/sourcegraph/sourcegraph/internal/completions/types"
|
||||
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
|
||||
"github.com/sourcegraph/sourcegraph/internal/database"
|
||||
@ -67,6 +68,8 @@ func allowedCustomModel(model string) string {
|
||||
"anthropic/claude-instant-v1",
|
||||
"anthropic/claude-instant-1",
|
||||
"anthropic/claude-instant-1.2-cyan",
|
||||
"google/" + google.Gemini15Flash,
|
||||
"google/" + google.GeminiPro,
|
||||
"fireworks/accounts/sourcegraph/models/starcoder-7b",
|
||||
"fireworks/accounts/sourcegraph/models/starcoder-16b",
|
||||
"fireworks/accounts/fireworks/models/starcoder-3b-w8a16",
|
||||
|
||||
@ -30,10 +30,6 @@ func (c *googleCompletionStreamClient) Complete(
|
||||
requestParams types.CompletionRequestParameters,
|
||||
logger log.Logger,
|
||||
) (*types.CompletionResponse, error) {
|
||||
if !isSupportedFeature(feature) {
|
||||
return nil, errors.Newf("feature %q is currently not supported for Google", feature)
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
var err error
|
||||
defer (func() {
|
||||
@ -79,10 +75,6 @@ func (c *googleCompletionStreamClient) Stream(
|
||||
sendEvent types.SendCompletionEvent,
|
||||
logger log.Logger,
|
||||
) error {
|
||||
if !isSupportedFeature(feature) {
|
||||
return errors.Newf("feature %q is currently not supported for Google", feature)
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
var err error
|
||||
|
||||
@ -151,6 +143,7 @@ func (c *googleCompletionStreamClient) makeRequest(ctx context.Context, requestP
|
||||
|
||||
payload := googleRequest{
|
||||
Model: requestParams.Model,
|
||||
Stream: stream,
|
||||
Contents: prompt,
|
||||
GenerationConfig: googleGenerationConfig{
|
||||
Temperature: requestParams.Temperature,
|
||||
@ -228,17 +221,6 @@ func getgRPCMethod(stream bool) string {
|
||||
return "generateContent"
|
||||
}
|
||||
|
||||
// isSupportedFeature checks if the given CompletionsFeature is supported.
|
||||
// Currently, only the CompletionsFeatureChat feature is supported.
|
||||
// This commit deletes the function along with the guards in Complete and
// Stream that called it.
func isSupportedFeature(feature types.CompletionsFeature) bool {
|
||||
switch feature {
|
||||
case types.CompletionsFeatureChat:
|
||||
return true
|
||||
// Every feature other than chat is reported as unsupported.
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// isDefaultAPIEndpoint checks if the given API endpoint URL is the default API endpoint.
|
||||
// The default API endpoint is determined by the defaultAPIHost constant.
|
||||
func isDefaultAPIEndpoint(endpoint *url.URL) bool {
|
||||
|
||||
@ -8,11 +8,16 @@ type googleCompletionStreamClient struct {
|
||||
endpoint string
|
||||
}
|
||||
|
||||
// The request body for the completion stream endpoint.
|
||||
// Ref: https://ai.google.dev/api/rest/v1beta/models/generateContent
|
||||
// Ref: https://ai.google.dev/api/rest/v1beta/models/streamGenerateContent
|
||||
// The field list below appears twice because this is a diff hunk: the first
// four fields are the pre-change rows, and the six fields after them are the
// post-change rows, which add Stream and SymtemInstruction.
type googleRequest struct {
|
||||
Model string `json:"model"`
|
||||
Contents []googleContentMessage `json:"contents"`
|
||||
GenerationConfig googleGenerationConfig `json:"generationConfig,omitempty"`
|
||||
SafetySettings []googleSafetySettings `json:"safetySettings,omitempty"`
|
||||
Model string `json:"model"`
|
||||
// Stream is set from makeRequest's stream argument; it is omitted from the
// JSON encoding when false.
Stream bool `json:"stream,omitempty"`
|
||||
Contents []googleContentMessage `json:"contents"`
|
||||
GenerationConfig googleGenerationConfig `json:"generationConfig,omitempty"`
|
||||
SafetySettings []googleSafetySettings `json:"safetySettings,omitempty"`
|
||||
// NOTE(review): the Go identifier is misspelled ("Symtem"); only the JSON
// tag carries the intended "systemInstruction" spelling.
// NOTE(review): Google's API reference appears to describe systemInstruction
// as a Content object rather than a plain string — TODO confirm.
SymtemInstruction string `json:"systemInstruction,omitempty"`
|
||||
}
|
||||
|
||||
type googleContentMessage struct {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user