mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 17:31:43 +00:00
Cody Gateway embeddings: powering with generated metadata (#63000)
Connected to https://github.com/sourcegraph/bfg-private/pull/189 and https://github.com/sourcegraph/cody/pull/4414. We're introducing a hacky MVP to enable embeddings being powered by metadata that's generated from code. This PR is the bare minimum to make this work on CG. We plan to trigger metadata generation only if we're using a new (fake) model (this comes in via a feature flag) and if the request isn't a real-time query, but is a background indexing request. The implementation is really hacky, but is also really minimal. ## Test plan Testing locally through a feature flag.
This commit is contained in:
parent
507668fefd
commit
4327bf8fc1
@ -27,6 +27,8 @@ go_library(
|
||||
"//cmd/cody-gateway/shared/config",
|
||||
"//internal/authbearer",
|
||||
"//internal/collections",
|
||||
"//internal/completions/client/fireworks",
|
||||
"//internal/conf",
|
||||
"//internal/httpcli",
|
||||
"//internal/instrumentation",
|
||||
"//internal/redispool",
|
||||
|
||||
@ -5,6 +5,7 @@ go_library(
|
||||
name = "embeddings",
|
||||
srcs = [
|
||||
"handler.go",
|
||||
"metadata.go",
|
||||
"models.go",
|
||||
"openai.go",
|
||||
"sourcegraph.go",
|
||||
@ -21,12 +22,15 @@ go_library(
|
||||
"//cmd/cody-gateway/internal/notify",
|
||||
"//cmd/cody-gateway/internal/response",
|
||||
"//internal/codygateway",
|
||||
"//internal/completions/client/fireworks",
|
||||
"//internal/completions/types",
|
||||
"//internal/httpcli",
|
||||
"//internal/trace",
|
||||
"//lib/errors",
|
||||
"@com_github_go_json_experiment_json//:json",
|
||||
"@com_github_google_uuid//:uuid",
|
||||
"@com_github_json_iterator_go//:go",
|
||||
"@com_github_sourcegraph_conc//iter",
|
||||
"@com_github_sourcegraph_log//:log",
|
||||
"@io_opentelemetry_go_otel//attribute",
|
||||
"@io_opentelemetry_go_otel_trace//:trace",
|
||||
|
||||
@ -12,15 +12,15 @@ import (
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/overhead"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/actor"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/events"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/featurelimiter"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/overhead"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/limiter"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/notify"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/response"
|
||||
"github.com/sourcegraph/sourcegraph/internal/codygateway"
|
||||
"github.com/sourcegraph/sourcegraph/internal/completions/types"
|
||||
sgtrace "github.com/sourcegraph/sourcegraph/internal/trace"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
)
|
||||
@ -34,6 +34,7 @@ func NewHandler(
|
||||
rateLimitNotifier notify.RateLimitNotifier,
|
||||
mf ModelFactory,
|
||||
prefixedAllowedModels []string,
|
||||
completionsClient types.CompletionsClient,
|
||||
) http.Handler {
|
||||
baseLogger = baseLogger.Scoped("embeddingshandler")
|
||||
|
||||
@ -118,6 +119,7 @@ func NewHandler(
|
||||
}
|
||||
return characters
|
||||
}(),
|
||||
"is_query": body.IsQuery,
|
||||
},
|
||||
},
|
||||
)
|
||||
@ -126,6 +128,25 @@ func NewHandler(
|
||||
}
|
||||
}()
|
||||
|
||||
// Hacky experiment: Replace embedding model input with generated metadata text when indexing.
|
||||
if body.Model == string(ModelNameSourcegraphMetadataGen) {
|
||||
newInput := body.Input
|
||||
// Generate metadata if we are indexing, not querying.
|
||||
if !body.IsQuery {
|
||||
var err error
|
||||
newInput, err = generateMetadata(r.Context(), body, logger, completionsClient)
|
||||
if err != nil {
|
||||
logger.Error("failed to generate metadata", log.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
body = codygateway.EmbeddingsRequest{
|
||||
Model: string(ModelNameSourcegraphSTMultiQA),
|
||||
Input: newInput,
|
||||
IsQuery: body.IsQuery,
|
||||
}
|
||||
}
|
||||
|
||||
resp, ut, err := c.GenerateEmbeddings(r.Context(), body)
|
||||
usedTokens = ut
|
||||
upstreamFinished = time.Since(upstreamStarted)
|
||||
|
||||
55
cmd/cody-gateway/internal/httpapi/embeddings/metadata.go
Normal file
55
cmd/cody-gateway/internal/httpapi/embeddings/metadata.go
Normal file
@ -0,0 +1,55 @@
|
||||
package embeddings
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sourcegraph/conc/iter"
|
||||
"github.com/sourcegraph/log"
|
||||
"github.com/sourcegraph/sourcegraph/internal/codygateway"
|
||||
"github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks"
|
||||
"github.com/sourcegraph/sourcegraph/internal/completions/types"
|
||||
)
|
||||
|
||||
type metadataClient struct {
|
||||
logger log.Logger
|
||||
completionsClient types.CompletionsClient
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func generateMetadata(ctx context.Context, req codygateway.EmbeddingsRequest, logger log.Logger, completionsClient types.CompletionsClient) ([]string, error) {
|
||||
c := metadataClient{
|
||||
logger: logger.Scoped("metadata_gen"),
|
||||
completionsClient: completionsClient,
|
||||
ctx: ctx,
|
||||
}
|
||||
|
||||
mapper := iter.Mapper[string, string]{MaxGoroutines: 15}
|
||||
return mapper.MapErr(req.Input, c.generateMetadataForChunk)
|
||||
}
|
||||
|
||||
func (c *metadataClient) generateMetadataForChunk(input *string) (string, error) {
|
||||
promptText := `Here is a section of code.
|
||||
Please write a paragraph of documentation for each high-level class, struct, function or similar.
|
||||
Be concise, write no more than a few sentences for each entry.
|
||||
Return your response in text format. Each entry name should be followed by a newline, then its documentation.
|
||||
Respond with nothing else, only the entry names and the documentation. Code: ` +
|
||||
"````" + *input + "```"
|
||||
|
||||
resp, err := c.completionsClient.Complete(c.ctx, types.CompletionsFeatureChat, types.CompletionsVersionLegacy,
|
||||
types.CompletionRequestParameters{
|
||||
Messages: []types.Message{{
|
||||
Speaker: "user",
|
||||
Text: promptText,
|
||||
}},
|
||||
MaxTokensToSample: 2000,
|
||||
Temperature: 0,
|
||||
TopP: 1,
|
||||
Model: fireworks.Llama370bInstruct,
|
||||
}, c.logger)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return resp.Completion, nil
|
||||
}
|
||||
@ -13,8 +13,9 @@ import (
|
||||
type PrefixedModelName string
|
||||
|
||||
const (
|
||||
ModelNameOpenAIAda PrefixedModelName = "openai/text-embedding-ada-002"
|
||||
ModelNameSourcegraphSTMultiQA PrefixedModelName = "sourcegraph/st-multi-qa-mpnet-base-dot-v1"
|
||||
ModelNameOpenAIAda PrefixedModelName = "openai/text-embedding-ada-002"
|
||||
ModelNameSourcegraphSTMultiQA PrefixedModelName = "sourcegraph/st-multi-qa-mpnet-base-dot-v1"
|
||||
ModelNameSourcegraphMetadataGen PrefixedModelName = "sourcegraph/st-multi-qa-mpnet-metadata"
|
||||
)
|
||||
|
||||
type EmbeddingsClient interface {
|
||||
@ -59,6 +60,12 @@ func NewListHandler(prefixedAllowedModels []string) http.Handler {
|
||||
Dimensions: 768,
|
||||
Deprecated: false,
|
||||
},
|
||||
{
|
||||
Enabled: modelEnabled(ModelNameSourcegraphMetadataGen),
|
||||
Name: string(ModelNameSourcegraphMetadataGen),
|
||||
Dimensions: 768,
|
||||
Deprecated: false,
|
||||
},
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(models)
|
||||
})
|
||||
|
||||
@ -13,19 +13,19 @@ import (
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/overhead"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/auth"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/events"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/attribution"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/completions"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/embeddings"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/featurelimiter"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/overhead"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/requestlogger"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/limiter"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/notify"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
|
||||
"github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks"
|
||||
"github.com/sourcegraph/sourcegraph/internal/conf"
|
||||
"github.com/sourcegraph/sourcegraph/internal/httpcli"
|
||||
"github.com/sourcegraph/sourcegraph/internal/instrumentation"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
@ -162,16 +162,22 @@ func NewHandler(
|
||||
embeddings.NewListHandler(config.EmbeddingsAllowedModels))
|
||||
|
||||
factoryMap := embeddings.ModelFactoryMap{
|
||||
embeddings.ModelNameOpenAIAda: embeddings.NewOpenAIClient(httpClient, config.OpenAI.AccessToken),
|
||||
embeddings.ModelNameSourcegraphSTMultiQA: embeddings.NewSourcegraphClient(httpClient, config.Sourcegraph.EmbeddingsAPIURL, config.Sourcegraph.EmbeddingsAPIToken),
|
||||
embeddings.ModelNameOpenAIAda: embeddings.NewOpenAIClient(httpClient, config.OpenAI.AccessToken),
|
||||
embeddings.ModelNameSourcegraphSTMultiQA: embeddings.NewSourcegraphClient(httpClient, config.Sourcegraph.EmbeddingsAPIURL, config.Sourcegraph.EmbeddingsAPIToken),
|
||||
embeddings.ModelNameSourcegraphMetadataGen: embeddings.NewSourcegraphClient(httpClient, config.Sourcegraph.EmbeddingsAPIURL, config.Sourcegraph.EmbeddingsAPIToken),
|
||||
}
|
||||
|
||||
completionsConfig := conf.GetCompletionsConfig(conf.Get().SiteConfig())
|
||||
fireworksClient := fireworks.NewClient(httpcli.UncachedExternalDoer, completionsConfig.Endpoint, completionsConfig.AccessToken)
|
||||
|
||||
embeddingsHandler := embeddings.NewHandler(
|
||||
logger,
|
||||
eventLogger,
|
||||
rs,
|
||||
config.RateLimitNotifier,
|
||||
factoryMap,
|
||||
config.EmbeddingsAllowedModels)
|
||||
config.EmbeddingsAllowedModels,
|
||||
fireworksClient)
|
||||
// TODO: If embeddings.ModelFactoryMap includes more than just OpenAI, we might want to
|
||||
// revisit how we count concurrent requests into the handler. (Instead of assuming they are
|
||||
// all OpenAI-related requests. (i.e. maybe we should use something other than
|
||||
|
||||
@ -297,7 +297,7 @@ func (c *Config) Load() {
|
||||
c.AddError(errors.New("must provide allowed models for Google"))
|
||||
}
|
||||
|
||||
c.AllowedEmbeddingsModels = splitMaybe(c.Get("CODY_GATEWAY_ALLOWED_EMBEDDINGS_MODELS", strings.Join([]string{string(embeddings.ModelNameOpenAIAda), string(embeddings.ModelNameSourcegraphSTMultiQA)}, ","), "The models allowed for embeddings generation."))
|
||||
c.AllowedEmbeddingsModels = splitMaybe(c.Get("CODY_GATEWAY_ALLOWED_EMBEDDINGS_MODELS", strings.Join([]string{string(embeddings.ModelNameOpenAIAda), string(embeddings.ModelNameSourcegraphSTMultiQA), string(embeddings.ModelNameSourcegraphMetadataGen)}, ","), "The models allowed for embeddings generation."))
|
||||
if len(c.AllowedEmbeddingsModels) == 0 {
|
||||
c.AddError(errors.New("must provide allowed models for embeddings generation"))
|
||||
}
|
||||
|
||||
@ -46,6 +46,8 @@ type EmbeddingsRequest struct {
|
||||
Model string `json:"model"`
|
||||
// Input is the list of strings to generate embeddings for.
|
||||
Input []string `json:"input"`
|
||||
// IsQuery is true if the request is used for querying, false if it used for indexing.
|
||||
IsQuery bool `json:"isQuery"`
|
||||
}
|
||||
|
||||
type Embedding struct {
|
||||
|
||||
@ -28,6 +28,7 @@ const Llama27bCode = "accounts/fireworks/models/llama-v2-7b-code"
|
||||
const Llama213bCode = "accounts/fireworks/models/llama-v2-13b-code"
|
||||
const Llama213bCodeInstruct = "accounts/fireworks/models/llama-v2-13b-code-instruct"
|
||||
const Llama234bCodeInstruct = "accounts/fireworks/models/llama-v2-34b-code-instruct"
|
||||
const Llama370bInstruct = "accounts/fireworks/models/llama-v3-70b-instruct"
|
||||
const Mistral7bInstruct = "accounts/fireworks/models/mistral-7b-instruct-4k"
|
||||
const Mixtral8x7bInstruct = "accounts/fireworks/models/mixtral-8x7b-instruct"
|
||||
const Mixtral8x22Instruct = "accounts/fireworks/models/mixtral-8x22b-instruct"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user