Embeddings: minimal end-to-end searching with qdrant (#55772)

This commit is contained in:
Camden Cheek 2023-08-15 13:29:40 -06:00 committed by GitHub
parent 2036e42251
commit c98cac6035
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 135 additions and 16 deletions

View File

@ -10,9 +10,12 @@ go_library(
"//enterprise/cmd/frontend/internal/context/resolvers",
"//internal/codeintel",
"//internal/codycontext:context",
"//internal/conf",
"//internal/conf/conftypes",
"//internal/database",
"//internal/embeddings",
"//internal/embeddings/db",
"//internal/grpc/defaults",
"//internal/observation",
"//internal/search/client",
],

View File

@ -7,9 +7,12 @@ import (
"github.com/sourcegraph/sourcegraph/enterprise/cmd/frontend/internal/context/resolvers"
"github.com/sourcegraph/sourcegraph/internal/codeintel"
codycontext "github.com/sourcegraph/sourcegraph/internal/codycontext"
"github.com/sourcegraph/sourcegraph/internal/conf"
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
"github.com/sourcegraph/sourcegraph/internal/database"
"github.com/sourcegraph/sourcegraph/internal/embeddings"
vdb "github.com/sourcegraph/sourcegraph/internal/embeddings/db"
"github.com/sourcegraph/sourcegraph/internal/grpc/defaults"
"github.com/sourcegraph/sourcegraph/internal/observation"
"github.com/sourcegraph/sourcegraph/internal/search/client"
)
@ -24,11 +27,20 @@ func Init(
) error {
embeddingsClient := embeddings.NewDefaultClient()
searchClient := client.New(observationCtx.Logger, db)
qdrantSearcher := vdb.NewDisabledDB()
if addr := conf.ServiceConnections().Qdrant; addr != "" {
conn, err := defaults.Dial(addr, observationCtx.Logger)
if err != nil {
return err
}
qdrantSearcher = vdb.NewQdrantDBFromConn(conn)
}
contextClient := codycontext.NewCodyContextClient(
observationCtx,
db,
embeddingsClient,
searchClient,
qdrantSearcher,
)
enterpriseServices.CodyContextResolver = resolvers.NewResolver(
db,

View File

@ -130,6 +130,7 @@ func TestContextResolver(t *testing.T) {
db,
mockEmbeddingsClient,
mockSearchClient,
nil,
)
resolver := NewResolver(

View File

@ -55,7 +55,7 @@ func (s *repoEmbeddingJob) Routines(_ context.Context, observationCtx *observati
return nil, err
}
qdrantInserter := vdb.NewNoopInserter()
qdrantInserter := vdb.NewNoopDB()
if qdrantAddr := conf.Get().ServiceConnections().Qdrant; qdrantAddr != "" {
conn, err := defaults.Dial(qdrantAddr, observationCtx.Logger)
if err != nil {

View File

@ -7,9 +7,12 @@ go_library(
visibility = ["//:__subpackages__"],
deps = [
"//internal/api",
"//internal/conf",
"//internal/database",
"//internal/embeddings",
"//internal/embeddings/db",
"//internal/embeddings/embed",
"//internal/featureflag",
"//internal/metrics",
"//internal/observation",
"//internal/search",
@ -18,6 +21,7 @@ go_library(
"//internal/search/result",
"//internal/search/streaming",
"//internal/types",
"//lib/errors",
"@com_github_sourcegraph_conc//pool",
"@com_github_sourcegraph_log//:log",
"@io_opentelemetry_go_otel//attribute",

View File

@ -14,9 +14,12 @@ import (
"go.opentelemetry.io/otel/attribute"
"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/internal/conf"
"github.com/sourcegraph/sourcegraph/internal/database"
"github.com/sourcegraph/sourcegraph/internal/embeddings"
vdb "github.com/sourcegraph/sourcegraph/internal/embeddings/db"
"github.com/sourcegraph/sourcegraph/internal/embeddings/embed"
"github.com/sourcegraph/sourcegraph/internal/featureflag"
"github.com/sourcegraph/sourcegraph/internal/metrics"
"github.com/sourcegraph/sourcegraph/internal/observation"
"github.com/sourcegraph/sourcegraph/internal/search"
@ -25,6 +28,7 @@ import (
"github.com/sourcegraph/sourcegraph/internal/search/result"
"github.com/sourcegraph/sourcegraph/internal/search/streaming"
"github.com/sourcegraph/sourcegraph/internal/types"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
type FileChunkContext struct {
@ -36,7 +40,7 @@ type FileChunkContext struct {
EndLine int
}
func NewCodyContextClient(obsCtx *observation.Context, db database.DB, embeddingsClient embeddings.Client, searchClient client.SearchClient) *CodyContextClient {
func NewCodyContextClient(obsCtx *observation.Context, db database.DB, embeddingsClient embeddings.Client, searchClient client.SearchClient, qdrantSearcher vdb.VectorSearcher) *CodyContextClient {
redMetrics := metrics.NewREDMetrics(
obsCtx.Registerer,
"codycontext_client",
@ -58,6 +62,7 @@ func NewCodyContextClient(obsCtx *observation.Context, db database.DB, embedding
db: db,
embeddingsClient: embeddingsClient,
searchClient: searchClient,
qdrantSearcher: qdrantSearcher,
obsCtx: obsCtx,
getCodyContextOp: op("getCodyContext"),
@ -70,6 +75,7 @@ type CodyContextClient struct {
db database.DB
embeddingsClient embeddings.Client
searchClient client.SearchClient
qdrantSearcher vdb.VectorSearcher
obsCtx *observation.Context
getCodyContextOp *observation.Operation
@ -84,6 +90,14 @@ type GetContextArgs struct {
TextResultsCount int32
}
func (a *GetContextArgs) RepoIDs() []api.RepoID {
res := make([]api.RepoID, 0, len(a.Repos))
for _, repo := range a.Repos {
res = append(res, repo.ID)
}
return res
}
func (a *GetContextArgs) Attrs() []attribute.KeyValue {
return []attribute.KeyValue{
attribute.Int("numRepos", len(a.Repos)),
@ -170,6 +184,10 @@ func (c *CodyContextClient) getEmbeddingsContext(ctx context.Context, args GetCo
return nil, nil
}
if featureflag.FromContext(ctx).GetBoolOr("qdrant", false) {
return c.getEmbeddingsContextFromQdrant(ctx, args)
}
repoNames := make([]api.RepoName, len(args.Repos))
repoIDs := make([]api.RepoID, len(args.Repos))
for i, repo := range args.Repos {
@ -308,6 +326,48 @@ func (c *CodyContextClient) getKeywordContext(ctx context.Context, args GetConte
return append(results[0], results[1]...), nil
}
func (c *CodyContextClient) getEmbeddingsContextFromQdrant(ctx context.Context, args GetContextArgs) (_ []FileChunkContext, err error) {
embeddingsConf := conf.GetEmbeddingsConfig(conf.Get().SiteConfig())
if c == nil {
return nil, errors.New("embeddings not configured or disabled")
}
client, err := embed.NewEmbeddingsClient(embeddingsConf)
if err != nil {
return nil, errors.Wrap(err, "getting embeddings client")
}
resp, err := client.GetQueryEmbedding(ctx, args.Query)
if err != nil || len(resp.Failed) > 0 {
return nil, errors.Wrap(err, "getting query embedding")
}
query := resp.Embeddings
params := vdb.SearchParams{
ModelID: client.GetModelIdentifier(),
RepoIDs: args.RepoIDs(),
Query: query,
CodeLimit: int(args.CodeResultsCount),
TextLimit: int(args.TextResultsCount),
}
chunks, err := c.qdrantSearcher.Search(ctx, params)
if err != nil {
return nil, errors.Wrap(err, "searching vector DB")
}
res := make([]FileChunkContext, 0, len(chunks))
for _, chunk := range chunks {
res = append(res, FileChunkContext{
RepoName: chunk.Point.Payload.RepoName,
RepoID: chunk.Point.Payload.RepoID,
CommitID: chunk.Point.Payload.Revision,
Path: chunk.Point.Payload.FilePath,
StartLine: int(chunk.Point.Payload.StartLine),
EndLine: int(chunk.Point.Payload.EndLine),
})
}
return res, nil
}
func fileMatchToContextMatches(fm *result.FileMatch) []FileChunkContext {
if len(fm.ChunkMatches) == 0 {
return nil

View File

@ -4,28 +4,55 @@ import (
"context"
"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
func NewNoopInserter() VectorInserter {
return noopInserter{}
func NewNoopDB() VectorDB {
return noopDB{}
}
var _ VectorDB = noopInserter{}
var _ VectorDB = noopDB{}
type noopInserter struct{}
type noopDB struct{}
func (noopInserter) Search(context.Context, SearchParams) ([]ChunkResult, error) {
func (noopDB) Search(context.Context, SearchParams) ([]ChunkResult, error) {
return nil, nil
}
func (noopInserter) PrepareUpdate(ctx context.Context, modelID string, modelDims uint64) error {
func (noopDB) PrepareUpdate(ctx context.Context, modelID string, modelDims uint64) error {
return nil
}
func (noopInserter) HasIndex(ctx context.Context, modelID string, repoID api.RepoID, revision api.CommitID) (bool, error) {
func (noopDB) HasIndex(ctx context.Context, modelID string, repoID api.RepoID, revision api.CommitID) (bool, error) {
return false, nil
}
func (noopInserter) InsertChunks(context.Context, InsertParams) error {
func (noopDB) InsertChunks(context.Context, InsertParams) error {
return nil
}
func (noopInserter) FinalizeUpdate(context.Context, FinalizeUpdateParams) error {
func (noopDB) FinalizeUpdate(context.Context, FinalizeUpdateParams) error {
return nil
}
var ErrDisabled = errors.New("Qdrant is disabled. Enable by setting QDRANT_ENDPOINT")
func NewDisabledDB() VectorDB {
return disabledDB{}
}
var _ VectorDB = disabledDB{}
type disabledDB struct{}
func (disabledDB) Search(context.Context, SearchParams) ([]ChunkResult, error) {
return nil, ErrDisabled
}
func (disabledDB) PrepareUpdate(ctx context.Context, modelID string, modelDims uint64) error {
return ErrDisabled
}
func (disabledDB) HasIndex(ctx context.Context, modelID string, repoID api.RepoID, revision api.CommitID) (bool, error) {
return false, ErrDisabled
}
func (disabledDB) InsertChunks(context.Context, InsertParams) error {
return ErrDisabled
}
func (disabledDB) FinalizeUpdate(context.Context, FinalizeUpdateParams) error {
return ErrDisabled
}

View File

@ -29,10 +29,22 @@ type qdrantDB struct {
var _ VectorDB = (*qdrantDB)(nil)
type SearchParams struct {
ModelID string
RepoIDs []api.RepoID
Query []float32
// RepoIDs is the set of repos to search.
// If empty, all repos are searched.
RepoIDs []api.RepoID
// The ID of the model that the query was embedded with.
// Embeddings for other models will not be searched.
ModelID string
// Query is the embedding for the search query.
// Its dimensions must match the model dimensions.
Query []float32
// The maximum number of code results to return
CodeLimit int
// The maximum number of text results to return
TextLimit int
}

View File

@ -38,7 +38,7 @@ func TestEmbedRepo(t *testing.T) {
}
revision := api.CommitID("deadbeef")
embeddingsClient := NewMockEmbeddingsClient()
inserter := db.NewNoopInserter()
inserter := db.NewNoopDB()
contextService := NewMockContextService()
contextService.SplitIntoEmbeddableChunksFunc.SetDefaultHook(defaultSplitter)
splitOptions := codeintelContext.SplitOptions{ChunkTokensThreshold: 8}
@ -385,7 +385,7 @@ func TestEmbedRepo_ExcludeChunkOnError(t *testing.T) {
repoIDName := types.RepoIDName{Name: repoName}
embeddingsClient := NewMockEmbeddingsClient()
contextService := NewMockContextService()
inserter := db.NewNoopInserter()
inserter := db.NewNoopDB()
contextService.SplitIntoEmbeddableChunksFunc.SetDefaultHook(defaultSplitter)
splitOptions := codeintelContext.SplitOptions{ChunkTokensThreshold: 8}
mockFiles := map[string][]byte{