diff --git a/enterprise/cmd/frontend/internal/context/BUILD.bazel b/enterprise/cmd/frontend/internal/context/BUILD.bazel index 659cc6ff0d9..cc826500e64 100644 --- a/enterprise/cmd/frontend/internal/context/BUILD.bazel +++ b/enterprise/cmd/frontend/internal/context/BUILD.bazel @@ -10,9 +10,12 @@ go_library( "//enterprise/cmd/frontend/internal/context/resolvers", "//internal/codeintel", "//internal/codycontext:context", + "//internal/conf", "//internal/conf/conftypes", "//internal/database", "//internal/embeddings", + "//internal/embeddings/db", + "//internal/grpc/defaults", "//internal/observation", "//internal/search/client", ], diff --git a/enterprise/cmd/frontend/internal/context/init.go b/enterprise/cmd/frontend/internal/context/init.go index a764b162c20..24e22cbcc6e 100644 --- a/enterprise/cmd/frontend/internal/context/init.go +++ b/enterprise/cmd/frontend/internal/context/init.go @@ -7,9 +7,12 @@ import ( "github.com/sourcegraph/sourcegraph/enterprise/cmd/frontend/internal/context/resolvers" "github.com/sourcegraph/sourcegraph/internal/codeintel" codycontext "github.com/sourcegraph/sourcegraph/internal/codycontext" + "github.com/sourcegraph/sourcegraph/internal/conf" "github.com/sourcegraph/sourcegraph/internal/conf/conftypes" "github.com/sourcegraph/sourcegraph/internal/database" "github.com/sourcegraph/sourcegraph/internal/embeddings" + vdb "github.com/sourcegraph/sourcegraph/internal/embeddings/db" + "github.com/sourcegraph/sourcegraph/internal/grpc/defaults" "github.com/sourcegraph/sourcegraph/internal/observation" "github.com/sourcegraph/sourcegraph/internal/search/client" ) @@ -24,11 +27,20 @@ func Init( ) error { embeddingsClient := embeddings.NewDefaultClient() searchClient := client.New(observationCtx.Logger, db) + qdrantSearcher := vdb.NewDisabledDB() + if addr := conf.ServiceConnections().Qdrant; addr != "" { + conn, err := defaults.Dial(addr, observationCtx.Logger) + if err != nil { + return err + } + qdrantSearcher = vdb.NewQdrantDBFromConn(conn) + } contextClient := codycontext.NewCodyContextClient( observationCtx, db, embeddingsClient, searchClient, + qdrantSearcher, ) enterpriseServices.CodyContextResolver = resolvers.NewResolver( db, diff --git a/enterprise/cmd/frontend/internal/context/resolvers/context_test.go b/enterprise/cmd/frontend/internal/context/resolvers/context_test.go index b08aba684a1..fe7ae0218de 100644 --- a/enterprise/cmd/frontend/internal/context/resolvers/context_test.go +++ b/enterprise/cmd/frontend/internal/context/resolvers/context_test.go @@ -130,6 +130,7 @@ func TestContextResolver(t *testing.T) { db, mockEmbeddingsClient, mockSearchClient, + nil, ) resolver := NewResolver( diff --git a/enterprise/cmd/worker/internal/embeddings/repo/worker.go b/enterprise/cmd/worker/internal/embeddings/repo/worker.go index 8d3d9db4daf..f8adb083425 100644 --- a/enterprise/cmd/worker/internal/embeddings/repo/worker.go +++ b/enterprise/cmd/worker/internal/embeddings/repo/worker.go @@ -55,7 +55,7 @@ func (s *repoEmbeddingJob) Routines(_ context.Context, observationCtx *observati return nil, err } - qdrantInserter := vdb.NewNoopInserter() + qdrantInserter := vdb.NewNoopDB() if qdrantAddr := conf.Get().ServiceConnections().Qdrant; qdrantAddr != "" { conn, err := defaults.Dial(qdrantAddr, observationCtx.Logger) if err != nil { diff --git a/internal/codycontext/BUILD.bazel b/internal/codycontext/BUILD.bazel index bf8e2a8f218..a375d1fa365 100644 --- a/internal/codycontext/BUILD.bazel +++ b/internal/codycontext/BUILD.bazel @@ -7,9 +7,12 @@ go_library( visibility = ["//:__subpackages__"], deps = [ "//internal/api", + "//internal/conf", "//internal/database", "//internal/embeddings", + "//internal/embeddings/db", "//internal/embeddings/embed", + "//internal/featureflag", "//internal/metrics", "//internal/observation", "//internal/search", @@ -18,6 +21,7 @@ go_library( "//internal/search/result", "//internal/search/streaming", "//internal/types", + "//lib/errors", "@com_github_sourcegraph_conc//pool", "@com_github_sourcegraph_log//:log", "@io_opentelemetry_go_otel//attribute", diff --git a/internal/codycontext/context.go b/internal/codycontext/context.go index c7a87d0f329..3df0337197b 100644 --- a/internal/codycontext/context.go +++ b/internal/codycontext/context.go @@ -14,9 +14,12 @@ import ( "go.opentelemetry.io/otel/attribute" "github.com/sourcegraph/sourcegraph/internal/api" + "github.com/sourcegraph/sourcegraph/internal/conf" "github.com/sourcegraph/sourcegraph/internal/database" "github.com/sourcegraph/sourcegraph/internal/embeddings" + vdb "github.com/sourcegraph/sourcegraph/internal/embeddings/db" "github.com/sourcegraph/sourcegraph/internal/embeddings/embed" + "github.com/sourcegraph/sourcegraph/internal/featureflag" "github.com/sourcegraph/sourcegraph/internal/metrics" "github.com/sourcegraph/sourcegraph/internal/observation" "github.com/sourcegraph/sourcegraph/internal/search" @@ -25,6 +28,7 @@ import ( "github.com/sourcegraph/sourcegraph/internal/search/result" "github.com/sourcegraph/sourcegraph/internal/search/streaming" "github.com/sourcegraph/sourcegraph/internal/types" + "github.com/sourcegraph/sourcegraph/lib/errors" ) type FileChunkContext struct { @@ -36,7 +40,7 @@ type FileChunkContext struct { EndLine int } -func NewCodyContextClient(obsCtx *observation.Context, db database.DB, embeddingsClient embeddings.Client, searchClient client.SearchClient) *CodyContextClient { +func NewCodyContextClient(obsCtx *observation.Context, db database.DB, embeddingsClient embeddings.Client, searchClient client.SearchClient, qdrantSearcher vdb.VectorSearcher) *CodyContextClient { redMetrics := metrics.NewREDMetrics( obsCtx.Registerer, "codycontext_client", @@ -58,6 +62,7 @@ func NewCodyContextClient(obsCtx *observation.Context, db database.DB, embedding db: db, embeddingsClient: embeddingsClient, searchClient: searchClient, + qdrantSearcher: qdrantSearcher, obsCtx: obsCtx, getCodyContextOp: op("getCodyContext"), @@ -70,6 +75,7 @@ type CodyContextClient struct { db database.DB embeddingsClient embeddings.Client searchClient client.SearchClient + qdrantSearcher vdb.VectorSearcher obsCtx *observation.Context getCodyContextOp *observation.Operation @@ -84,6 +90,14 @@ type GetContextArgs struct { TextResultsCount int32 } +func (a *GetContextArgs) RepoIDs() []api.RepoID { + res := make([]api.RepoID, 0, len(a.Repos)) + for _, repo := range a.Repos { + res = append(res, repo.ID) + } + return res +} + func (a *GetContextArgs) Attrs() []attribute.KeyValue { return []attribute.KeyValue{ attribute.Int("numRepos", len(a.Repos)), @@ -170,6 +184,10 @@ func (c *CodyContextClient) getEmbeddingsContext(ctx context.Context, args GetCo return nil, nil } + if featureflag.FromContext(ctx).GetBoolOr("qdrant", false) { + return c.getEmbeddingsContextFromQdrant(ctx, args) + } + repoNames := make([]api.RepoName, len(args.Repos)) repoIDs := make([]api.RepoID, len(args.Repos)) for i, repo := range args.Repos { @@ -308,6 +326,48 @@ func (c *CodyContextClient) getKeywordContext(ctx context.Context, args GetConte return append(results[0], results[1]...), nil } +func (c *CodyContextClient) getEmbeddingsContextFromQdrant(ctx context.Context, args GetContextArgs) (_ []FileChunkContext, err error) { + embeddingsConf := conf.GetEmbeddingsConfig(conf.Get().SiteConfig()) + if c == nil { + return nil, errors.New("embeddings not configured or disabled") + } + client, err := embed.NewEmbeddingsClient(embeddingsConf) + if err != nil { + return nil, errors.Wrap(err, "getting embeddings client") + } + + resp, err := client.GetQueryEmbedding(ctx, args.Query) + if err != nil || len(resp.Failed) > 0 { + return nil, errors.Wrap(err, "getting query embedding") + } + query := resp.Embeddings + + params := vdb.SearchParams{ + ModelID: client.GetModelIdentifier(), + RepoIDs: args.RepoIDs(), + Query: query, + CodeLimit: int(args.CodeResultsCount), + TextLimit: int(args.TextResultsCount), + } + chunks, err := c.qdrantSearcher.Search(ctx, params) + if err != nil { + return nil, errors.Wrap(err, "searching vector DB") + } + + res := make([]FileChunkContext, 0, len(chunks)) + for _, chunk := range chunks { + res = append(res, FileChunkContext{ + RepoName: chunk.Point.Payload.RepoName, + RepoID: chunk.Point.Payload.RepoID, + CommitID: chunk.Point.Payload.Revision, + Path: chunk.Point.Payload.FilePath, + StartLine: int(chunk.Point.Payload.StartLine), + EndLine: int(chunk.Point.Payload.EndLine), + }) + } + return res, nil +} + func fileMatchToContextMatches(fm *result.FileMatch) []FileChunkContext { if len(fm.ChunkMatches) == 0 { return nil diff --git a/internal/embeddings/db/noop.go b/internal/embeddings/db/noop.go index b2556ead3d2..503acefd11e 100644 --- a/internal/embeddings/db/noop.go +++ b/internal/embeddings/db/noop.go @@ -4,28 +4,55 @@ import ( "context" "github.com/sourcegraph/sourcegraph/internal/api" + "github.com/sourcegraph/sourcegraph/lib/errors" ) -func NewNoopInserter() VectorInserter { - return noopInserter{} +func NewNoopDB() VectorDB { + return noopDB{} } -var _ VectorDB = noopInserter{} +var _ VectorDB = noopDB{} -type noopInserter struct{} +type noopDB struct{} -func (noopInserter) Search(context.Context, SearchParams) ([]ChunkResult, error) { +func (noopDB) Search(context.Context, SearchParams) ([]ChunkResult, error) { return nil, nil } -func (noopInserter) PrepareUpdate(ctx context.Context, modelID string, modelDims uint64) error { +func (noopDB) PrepareUpdate(ctx context.Context, modelID string, modelDims uint64) error { return nil } -func (noopInserter) HasIndex(ctx context.Context, modelID string, repoID api.RepoID, revision api.CommitID) (bool, error) { +func (noopDB) HasIndex(ctx context.Context, modelID string, repoID api.RepoID, revision api.CommitID) (bool, error) { return false, nil } -func (noopInserter) InsertChunks(context.Context, InsertParams) error { +func (noopDB) InsertChunks(context.Context, InsertParams) error { return nil } -func (noopInserter) FinalizeUpdate(context.Context, FinalizeUpdateParams) error { +func (noopDB) FinalizeUpdate(context.Context, FinalizeUpdateParams) error { return nil } + +var ErrDisabled = errors.New("Qdrant is disabled. Enable by setting QDRANT_ENDPOINT") + +func NewDisabledDB() VectorDB { + return disabledDB{} +} + +var _ VectorDB = disabledDB{} + +type disabledDB struct{} + +func (disabledDB) Search(context.Context, SearchParams) ([]ChunkResult, error) { + return nil, ErrDisabled +} +func (disabledDB) PrepareUpdate(ctx context.Context, modelID string, modelDims uint64) error { + return ErrDisabled +} +func (disabledDB) HasIndex(ctx context.Context, modelID string, repoID api.RepoID, revision api.CommitID) (bool, error) { + return false, ErrDisabled +} +func (disabledDB) InsertChunks(context.Context, InsertParams) error { + return ErrDisabled +} +func (disabledDB) FinalizeUpdate(context.Context, FinalizeUpdateParams) error { + return ErrDisabled +} diff --git a/internal/embeddings/db/qdrant.go b/internal/embeddings/db/qdrant.go index e866c4f88b8..a864ff7646f 100644 --- a/internal/embeddings/db/qdrant.go +++ b/internal/embeddings/db/qdrant.go @@ -29,10 +29,22 @@ type qdrantDB struct { var _ VectorDB = (*qdrantDB)(nil) type SearchParams struct { - ModelID string - RepoIDs []api.RepoID - Query []float32 + // RepoIDs is the set of repos to search. + // If empty, all repos are searched. + RepoIDs []api.RepoID + + // The ID of the model that the query was embedded with. + // Embeddings for other models will not be searched. + ModelID string + + // Query is the embedding for the search query. + // Its dimensions must match the model dimensions. + Query []float32 + + // The maximum number of code results to return CodeLimit int + + // The maximum number of text results to return TextLimit int } diff --git a/internal/embeddings/embed/embed_test.go b/internal/embeddings/embed/embed_test.go index 861cdca86ed..4fd56673354 100644 --- a/internal/embeddings/embed/embed_test.go +++ b/internal/embeddings/embed/embed_test.go @@ -38,7 +38,7 @@ func TestEmbedRepo(t *testing.T) { } revision := api.CommitID("deadbeef") embeddingsClient := NewMockEmbeddingsClient() - inserter := db.NewNoopInserter() + inserter := db.NewNoopDB() contextService := NewMockContextService() contextService.SplitIntoEmbeddableChunksFunc.SetDefaultHook(defaultSplitter) splitOptions := codeintelContext.SplitOptions{ChunkTokensThreshold: 8} @@ -385,7 +385,7 @@ func TestEmbedRepo_ExcludeChunkOnError(t *testing.T) { repoIDName := types.RepoIDName{Name: repoName} embeddingsClient := NewMockEmbeddingsClient() contextService := NewMockContextService() - inserter := db.NewNoopInserter() + inserter := db.NewNoopDB() contextService.SplitIntoEmbeddableChunksFunc.SetDefaultHook(defaultSplitter) splitOptions := codeintelContext.SplitOptions{ChunkTokensThreshold: 8} mockFiles := map[string][]byte{