Cody Gateway embeddings: powering with generated metadata (#63000)

Connected to https://github.com/sourcegraph/bfg-private/pull/189 and
https://github.com/sourcegraph/cody/pull/4414.

We're introducing a hacky MVP that powers embeddings with metadata
generated from code. This PR is the bare minimum needed to make this
work on Cody Gateway. Metadata generation is triggered only when the
request uses a new (fake) model (selected via a feature flag) and is a
background indexing request rather than a real-time query; see the
request sketch below. The implementation is hacky, but it is also
minimal.
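For reference, a request that takes the new path might look like the following minimal sketch (the model name, the `isQuery` field, and the `EmbeddingsRequest` shape are all defined in the diff below; the input string is made up):

```go
package main

import (
	"fmt"

	"github.com/sourcegraph/sourcegraph/internal/codygateway"
)

func main() {
	// A background indexing request for the new "fake" model: the gateway
	// first generates metadata for each input chunk via a completions call,
	// then embeds that metadata with the real st-multi-qa model.
	req := codygateway.EmbeddingsRequest{
		Model:   "sourcegraph/st-multi-qa-mpnet-metadata",
		Input:   []string{"func Add(a, b int) int { return a + b }"},
		IsQuery: false, // a real-time query (true) skips metadata generation
	}
	fmt.Printf("%+v\n", req)
}
```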

## Test plan
Tested locally behind a feature flag.
Jan Hartman 2024-06-05 13:33:10 +02:00 committed by GitHub
parent 507668fefd
commit 4327bf8fc1
9 changed files with 110 additions and 12 deletions

View File

@@ -27,6 +27,8 @@ go_library(
"//cmd/cody-gateway/shared/config",
"//internal/authbearer",
"//internal/collections",
"//internal/completions/client/fireworks",
"//internal/conf",
"//internal/httpcli",
"//internal/instrumentation",
"//internal/redispool",

View File

@@ -5,6 +5,7 @@ go_library(
name = "embeddings",
srcs = [
"handler.go",
"metadata.go",
"models.go",
"openai.go",
"sourcegraph.go",
@@ -21,12 +22,15 @@ go_library(
"//cmd/cody-gateway/internal/notify",
"//cmd/cody-gateway/internal/response",
"//internal/codygateway",
"//internal/completions/client/fireworks",
"//internal/completions/types",
"//internal/httpcli",
"//internal/trace",
"//lib/errors",
"@com_github_go_json_experiment_json//:json",
"@com_github_google_uuid//:uuid",
"@com_github_json_iterator_go//:go",
"@com_github_sourcegraph_conc//iter",
"@com_github_sourcegraph_log//:log",
"@io_opentelemetry_go_otel//attribute",
"@io_opentelemetry_go_otel_trace//:trace",

View File

@@ -12,15 +12,15 @@ import (
"go.opentelemetry.io/otel/attribute"
oteltrace "go.opentelemetry.io/otel/trace"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/overhead"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/actor"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/events"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/featurelimiter"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/overhead"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/limiter"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/notify"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/response"
"github.com/sourcegraph/sourcegraph/internal/codygateway"
"github.com/sourcegraph/sourcegraph/internal/completions/types"
sgtrace "github.com/sourcegraph/sourcegraph/internal/trace"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
@@ -34,6 +34,7 @@ func NewHandler(
rateLimitNotifier notify.RateLimitNotifier,
mf ModelFactory,
prefixedAllowedModels []string,
completionsClient types.CompletionsClient,
) http.Handler {
baseLogger = baseLogger.Scoped("embeddingshandler")
@@ -118,6 +119,7 @@
}
return characters
}(),
"is_query": body.IsQuery,
},
},
)
@@ -126,6 +128,25 @@
}
}()
// Hacky experiment: Replace embedding model input with generated metadata text when indexing.
if body.Model == string(ModelNameSourcegraphMetadataGen) {
	newInput := body.Input
	// Generate metadata if we are indexing, not querying.
	if !body.IsQuery {
		var err error
		newInput, err = generateMetadata(r.Context(), body, logger, completionsClient)
		if err != nil {
			logger.Error("failed to generate metadata", log.Error(err))
			return
		}
	}
	body = codygateway.EmbeddingsRequest{
		Model:   string(ModelNameSourcegraphSTMultiQA),
		Input:   newInput,
		IsQuery: body.IsQuery,
	}
}
resp, ut, err := c.GenerateEmbeddings(r.Context(), body)
usedTokens = ut
upstreamFinished = time.Since(upstreamStarted)

View File

@@ -0,0 +1,55 @@
package embeddings

import (
	"context"

	"github.com/sourcegraph/conc/iter"
	"github.com/sourcegraph/log"

	"github.com/sourcegraph/sourcegraph/internal/codygateway"
	"github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks"
	"github.com/sourcegraph/sourcegraph/internal/completions/types"
)

type metadataClient struct {
	logger            log.Logger
	completionsClient types.CompletionsClient
	ctx               context.Context
}

func generateMetadata(ctx context.Context, req codygateway.EmbeddingsRequest, logger log.Logger, completionsClient types.CompletionsClient) ([]string, error) {
	c := metadataClient{
		logger:            logger.Scoped("metadata_gen"),
		completionsClient: completionsClient,
		ctx:               ctx,
	}
	mapper := iter.Mapper[string, string]{MaxGoroutines: 15}
	return mapper.MapErr(req.Input, c.generateMetadataForChunk)
}

func (c *metadataClient) generateMetadataForChunk(input *string) (string, error) {
	promptText := `Here is a section of code.
Please write a paragraph of documentation for each high-level class, struct, function or similar.
Be concise, write no more than a few sentences for each entry.
Return your response in text format. Each entry name should be followed by a newline, then its documentation.
Respond with nothing else, only the entry names and the documentation. Code: ` +
		"```" + *input + "```"
	resp, err := c.completionsClient.Complete(c.ctx, types.CompletionsFeatureChat, types.CompletionsVersionLegacy,
		types.CompletionRequestParameters{
			Messages: []types.Message{{
				Speaker: "user",
				Text:    promptText,
			}},
			MaxTokensToSample: 2000,
			Temperature:       0,
			TopP:              1,
			Model:             fireworks.Llama370bInstruct,
		}, c.logger)
	if err != nil {
		return "", err
	}
	return resp.Completion, nil
}
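A note on the fan-out above: `iter.Mapper` from sourcegraph/conc applies the function to every input element concurrently (at most `MaxGoroutines` goroutines at a time) and returns the results in input order. Because the handler aborts when `generateMetadata` returns an error, a failure on any single chunk fails the whole indexing request. A self-contained sketch of the same pattern:

```go
package main

import (
	"fmt"
	"strings"

	"github.com/sourcegraph/conc/iter"
)

func main() {
	// Bounded concurrency, results in input order: this mirrors how each
	// input chunk is mapped to its generated metadata string.
	mapper := iter.Mapper[string, string]{MaxGoroutines: 15}
	out, err := mapper.MapErr([]string{"chunk one", "chunk two"}, func(s *string) (string, error) {
		// Stand-in for the completions call made per chunk.
		return strings.ToUpper(*s), nil
	})
	if err != nil {
		fmt.Println("at least one chunk failed:", err)
		return
	}
	fmt.Println(out) // [CHUNK ONE CHUNK TWO]
}
```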

View File

@@ -13,8 +13,9 @@ import (
type PrefixedModelName string
const (
ModelNameOpenAIAda PrefixedModelName = "openai/text-embedding-ada-002"
ModelNameSourcegraphSTMultiQA PrefixedModelName = "sourcegraph/st-multi-qa-mpnet-base-dot-v1"
ModelNameSourcegraphMetadataGen PrefixedModelName = "sourcegraph/st-multi-qa-mpnet-metadata"
)
type EmbeddingsClient interface {
@@ -59,6 +60,12 @@ func NewListHandler(prefixedAllowedModels []string) http.Handler {
Dimensions: 768,
Deprecated: false,
},
{
Enabled: modelEnabled(ModelNameSourcegraphMetadataGen),
Name: string(ModelNameSourcegraphMetadataGen),
Dimensions: 768,
Deprecated: false,
},
}
_ = json.NewEncoder(w).Encode(models)
})

View File

@@ -13,19 +13,19 @@ import (
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/overhead"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/auth"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/events"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/attribution"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/completions"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/embeddings"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/featurelimiter"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/overhead"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/requestlogger"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/limiter"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/notify"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
"github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks"
"github.com/sourcegraph/sourcegraph/internal/conf"
"github.com/sourcegraph/sourcegraph/internal/httpcli"
"github.com/sourcegraph/sourcegraph/internal/instrumentation"
"github.com/sourcegraph/sourcegraph/lib/errors"
@@ -162,16 +162,22 @@ func NewHandler(
embeddings.NewListHandler(config.EmbeddingsAllowedModels))
factoryMap := embeddings.ModelFactoryMap{
embeddings.ModelNameOpenAIAda: embeddings.NewOpenAIClient(httpClient, config.OpenAI.AccessToken),
embeddings.ModelNameSourcegraphSTMultiQA: embeddings.NewSourcegraphClient(httpClient, config.Sourcegraph.EmbeddingsAPIURL, config.Sourcegraph.EmbeddingsAPIToken),
embeddings.ModelNameSourcegraphMetadataGen: embeddings.NewSourcegraphClient(httpClient, config.Sourcegraph.EmbeddingsAPIURL, config.Sourcegraph.EmbeddingsAPIToken),
}
completionsConfig := conf.GetCompletionsConfig(conf.Get().SiteConfig())
fireworksClient := fireworks.NewClient(httpcli.UncachedExternalDoer, completionsConfig.Endpoint, completionsConfig.AccessToken)
embeddingsHandler := embeddings.NewHandler(
logger,
eventLogger,
rs,
config.RateLimitNotifier,
factoryMap,
config.EmbeddingsAllowedModels,
fireworksClient)
// TODO: If embeddings.ModelFactoryMap includes more than just OpenAI, we might want to
// revisit how we count concurrent requests into the handler, instead of assuming they are
// all OpenAI-related requests (i.e. maybe we should use something other than

View File

@@ -297,7 +297,7 @@ func (c *Config) Load() {
c.AddError(errors.New("must provide allowed models for Google"))
}
c.AllowedEmbeddingsModels = splitMaybe(c.Get("CODY_GATEWAY_ALLOWED_EMBEDDINGS_MODELS", strings.Join([]string{string(embeddings.ModelNameOpenAIAda), string(embeddings.ModelNameSourcegraphSTMultiQA), string(embeddings.ModelNameSourcegraphMetadataGen)}, ","), "The models allowed for embeddings generation."))
if len(c.AllowedEmbeddingsModels) == 0 {
c.AddError(errors.New("must provide allowed models for embeddings generation"))
}
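Since the new model is now in the default allow-list, a deployment that wants to keep the previous behavior would presumably override it through the same environment variable. A hypothetical sketch (the variable name comes from the code above; the values are the two pre-existing model names):

```go
package main

import (
	"fmt"
	"os"
)

func main() {
	// Hypothetical override: allow only the pre-existing embeddings models
	// by leaving the new metadata-gen model out of the comma-separated list.
	os.Setenv("CODY_GATEWAY_ALLOWED_EMBEDDINGS_MODELS",
		"openai/text-embedding-ada-002,sourcegraph/st-multi-qa-mpnet-base-dot-v1")
	fmt.Println(os.Getenv("CODY_GATEWAY_ALLOWED_EMBEDDINGS_MODELS"))
}
```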

View File

@@ -46,6 +46,8 @@ type EmbeddingsRequest struct {
Model string `json:"model"`
// Input is the list of strings to generate embeddings for.
Input []string `json:"input"`
// IsQuery is true if the request is used for querying, false if it is used for indexing.
IsQuery bool `json:"isQuery"`
}
type Embedding struct {
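One consequence of the new field worth noting: Go unmarshals a missing JSON field to its zero value, so a client that omits `isQuery` entirely is treated as a background indexing request. A minimal sketch, assuming only the `EmbeddingsRequest` struct above:

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/sourcegraph/sourcegraph/internal/codygateway"
)

func main() {
	// A payload without "isQuery" decodes to the zero value false, which
	// the handler interprets as indexing (and thus metadata generation).
	payload := []byte(`{"model": "sourcegraph/st-multi-qa-mpnet-metadata", "input": ["package main"]}`)
	var req codygateway.EmbeddingsRequest
	if err := json.Unmarshal(payload, &req); err != nil {
		panic(err)
	}
	fmt.Println(req.IsQuery) // false
}
```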

View File

@@ -28,6 +28,7 @@ const Llama27bCode = "accounts/fireworks/models/llama-v2-7b-code"
const Llama213bCode = "accounts/fireworks/models/llama-v2-13b-code"
const Llama213bCodeInstruct = "accounts/fireworks/models/llama-v2-13b-code-instruct"
const Llama234bCodeInstruct = "accounts/fireworks/models/llama-v2-34b-code-instruct"
const Llama370bInstruct = "accounts/fireworks/models/llama-v3-70b-instruct"
const Mistral7bInstruct = "accounts/fireworks/models/mistral-7b-instruct-4k"
const Mixtral8x7bInstruct = "accounts/fireworks/models/mixtral-8x7b-instruct"
const Mixtral8x22Instruct = "accounts/fireworks/models/mixtral-8x22b-instruct"