mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 14:51:44 +00:00
Embeddings: minimal end-to-end searching with qdrant (#55772)
This commit is contained in:
parent
2036e42251
commit
c98cac6035
@ -10,9 +10,12 @@ go_library(
|
||||
"//enterprise/cmd/frontend/internal/context/resolvers",
|
||||
"//internal/codeintel",
|
||||
"//internal/codycontext:context",
|
||||
"//internal/conf",
|
||||
"//internal/conf/conftypes",
|
||||
"//internal/database",
|
||||
"//internal/embeddings",
|
||||
"//internal/embeddings/db",
|
||||
"//internal/grpc/defaults",
|
||||
"//internal/observation",
|
||||
"//internal/search/client",
|
||||
],
|
||||
|
||||
@ -7,9 +7,12 @@ import (
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/cmd/frontend/internal/context/resolvers"
|
||||
"github.com/sourcegraph/sourcegraph/internal/codeintel"
|
||||
codycontext "github.com/sourcegraph/sourcegraph/internal/codycontext"
|
||||
"github.com/sourcegraph/sourcegraph/internal/conf"
|
||||
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
|
||||
"github.com/sourcegraph/sourcegraph/internal/database"
|
||||
"github.com/sourcegraph/sourcegraph/internal/embeddings"
|
||||
vdb "github.com/sourcegraph/sourcegraph/internal/embeddings/db"
|
||||
"github.com/sourcegraph/sourcegraph/internal/grpc/defaults"
|
||||
"github.com/sourcegraph/sourcegraph/internal/observation"
|
||||
"github.com/sourcegraph/sourcegraph/internal/search/client"
|
||||
)
|
||||
@ -24,11 +27,20 @@ func Init(
|
||||
) error {
|
||||
embeddingsClient := embeddings.NewDefaultClient()
|
||||
searchClient := client.New(observationCtx.Logger, db)
|
||||
qdrantSearcher := vdb.NewDisabledDB()
|
||||
if addr := conf.ServiceConnections().Qdrant; addr != "" {
|
||||
conn, err := defaults.Dial(addr, observationCtx.Logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
qdrantSearcher = vdb.NewQdrantDBFromConn(conn)
|
||||
}
|
||||
contextClient := codycontext.NewCodyContextClient(
|
||||
observationCtx,
|
||||
db,
|
||||
embeddingsClient,
|
||||
searchClient,
|
||||
qdrantSearcher,
|
||||
)
|
||||
enterpriseServices.CodyContextResolver = resolvers.NewResolver(
|
||||
db,
|
||||
|
||||
@ -130,6 +130,7 @@ func TestContextResolver(t *testing.T) {
|
||||
db,
|
||||
mockEmbeddingsClient,
|
||||
mockSearchClient,
|
||||
nil,
|
||||
)
|
||||
|
||||
resolver := NewResolver(
|
||||
|
||||
@ -55,7 +55,7 @@ func (s *repoEmbeddingJob) Routines(_ context.Context, observationCtx *observati
|
||||
return nil, err
|
||||
}
|
||||
|
||||
qdrantInserter := vdb.NewNoopInserter()
|
||||
qdrantInserter := vdb.NewNoopDB()
|
||||
if qdrantAddr := conf.Get().ServiceConnections().Qdrant; qdrantAddr != "" {
|
||||
conn, err := defaults.Dial(qdrantAddr, observationCtx.Logger)
|
||||
if err != nil {
|
||||
|
||||
@ -7,9 +7,12 @@ go_library(
|
||||
visibility = ["//:__subpackages__"],
|
||||
deps = [
|
||||
"//internal/api",
|
||||
"//internal/conf",
|
||||
"//internal/database",
|
||||
"//internal/embeddings",
|
||||
"//internal/embeddings/db",
|
||||
"//internal/embeddings/embed",
|
||||
"//internal/featureflag",
|
||||
"//internal/metrics",
|
||||
"//internal/observation",
|
||||
"//internal/search",
|
||||
@ -18,6 +21,7 @@ go_library(
|
||||
"//internal/search/result",
|
||||
"//internal/search/streaming",
|
||||
"//internal/types",
|
||||
"//lib/errors",
|
||||
"@com_github_sourcegraph_conc//pool",
|
||||
"@com_github_sourcegraph_log//:log",
|
||||
"@io_opentelemetry_go_otel//attribute",
|
||||
|
||||
@ -14,9 +14,12 @@ import (
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
"github.com/sourcegraph/sourcegraph/internal/conf"
|
||||
"github.com/sourcegraph/sourcegraph/internal/database"
|
||||
"github.com/sourcegraph/sourcegraph/internal/embeddings"
|
||||
vdb "github.com/sourcegraph/sourcegraph/internal/embeddings/db"
|
||||
"github.com/sourcegraph/sourcegraph/internal/embeddings/embed"
|
||||
"github.com/sourcegraph/sourcegraph/internal/featureflag"
|
||||
"github.com/sourcegraph/sourcegraph/internal/metrics"
|
||||
"github.com/sourcegraph/sourcegraph/internal/observation"
|
||||
"github.com/sourcegraph/sourcegraph/internal/search"
|
||||
@ -25,6 +28,7 @@ import (
|
||||
"github.com/sourcegraph/sourcegraph/internal/search/result"
|
||||
"github.com/sourcegraph/sourcegraph/internal/search/streaming"
|
||||
"github.com/sourcegraph/sourcegraph/internal/types"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
)
|
||||
|
||||
type FileChunkContext struct {
|
||||
@ -36,7 +40,7 @@ type FileChunkContext struct {
|
||||
EndLine int
|
||||
}
|
||||
|
||||
func NewCodyContextClient(obsCtx *observation.Context, db database.DB, embeddingsClient embeddings.Client, searchClient client.SearchClient) *CodyContextClient {
|
||||
func NewCodyContextClient(obsCtx *observation.Context, db database.DB, embeddingsClient embeddings.Client, searchClient client.SearchClient, qdrantSearcher vdb.VectorSearcher) *CodyContextClient {
|
||||
redMetrics := metrics.NewREDMetrics(
|
||||
obsCtx.Registerer,
|
||||
"codycontext_client",
|
||||
@ -58,6 +62,7 @@ func NewCodyContextClient(obsCtx *observation.Context, db database.DB, embedding
|
||||
db: db,
|
||||
embeddingsClient: embeddingsClient,
|
||||
searchClient: searchClient,
|
||||
qdrantSearcher: qdrantSearcher,
|
||||
|
||||
obsCtx: obsCtx,
|
||||
getCodyContextOp: op("getCodyContext"),
|
||||
@ -70,6 +75,7 @@ type CodyContextClient struct {
|
||||
db database.DB
|
||||
embeddingsClient embeddings.Client
|
||||
searchClient client.SearchClient
|
||||
qdrantSearcher vdb.VectorSearcher
|
||||
|
||||
obsCtx *observation.Context
|
||||
getCodyContextOp *observation.Operation
|
||||
@ -84,6 +90,14 @@ type GetContextArgs struct {
|
||||
TextResultsCount int32
|
||||
}
|
||||
|
||||
func (a *GetContextArgs) RepoIDs() []api.RepoID {
|
||||
res := make([]api.RepoID, 0, len(a.Repos))
|
||||
for _, repo := range a.Repos {
|
||||
res = append(res, repo.ID)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (a *GetContextArgs) Attrs() []attribute.KeyValue {
|
||||
return []attribute.KeyValue{
|
||||
attribute.Int("numRepos", len(a.Repos)),
|
||||
@ -170,6 +184,10 @@ func (c *CodyContextClient) getEmbeddingsContext(ctx context.Context, args GetCo
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if featureflag.FromContext(ctx).GetBoolOr("qdrant", false) {
|
||||
return c.getEmbeddingsContextFromQdrant(ctx, args)
|
||||
}
|
||||
|
||||
repoNames := make([]api.RepoName, len(args.Repos))
|
||||
repoIDs := make([]api.RepoID, len(args.Repos))
|
||||
for i, repo := range args.Repos {
|
||||
@ -308,6 +326,48 @@ func (c *CodyContextClient) getKeywordContext(ctx context.Context, args GetConte
|
||||
return append(results[0], results[1]...), nil
|
||||
}
|
||||
|
||||
func (c *CodyContextClient) getEmbeddingsContextFromQdrant(ctx context.Context, args GetContextArgs) (_ []FileChunkContext, err error) {
|
||||
embeddingsConf := conf.GetEmbeddingsConfig(conf.Get().SiteConfig())
|
||||
if c == nil {
|
||||
return nil, errors.New("embeddings not configured or disabled")
|
||||
}
|
||||
client, err := embed.NewEmbeddingsClient(embeddingsConf)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getting embeddings client")
|
||||
}
|
||||
|
||||
resp, err := client.GetQueryEmbedding(ctx, args.Query)
|
||||
if err != nil || len(resp.Failed) > 0 {
|
||||
return nil, errors.Wrap(err, "getting query embedding")
|
||||
}
|
||||
query := resp.Embeddings
|
||||
|
||||
params := vdb.SearchParams{
|
||||
ModelID: client.GetModelIdentifier(),
|
||||
RepoIDs: args.RepoIDs(),
|
||||
Query: query,
|
||||
CodeLimit: int(args.CodeResultsCount),
|
||||
TextLimit: int(args.TextResultsCount),
|
||||
}
|
||||
chunks, err := c.qdrantSearcher.Search(ctx, params)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "searching vector DB")
|
||||
}
|
||||
|
||||
res := make([]FileChunkContext, 0, len(chunks))
|
||||
for _, chunk := range chunks {
|
||||
res = append(res, FileChunkContext{
|
||||
RepoName: chunk.Point.Payload.RepoName,
|
||||
RepoID: chunk.Point.Payload.RepoID,
|
||||
CommitID: chunk.Point.Payload.Revision,
|
||||
Path: chunk.Point.Payload.FilePath,
|
||||
StartLine: int(chunk.Point.Payload.StartLine),
|
||||
EndLine: int(chunk.Point.Payload.EndLine),
|
||||
})
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func fileMatchToContextMatches(fm *result.FileMatch) []FileChunkContext {
|
||||
if len(fm.ChunkMatches) == 0 {
|
||||
return nil
|
||||
|
||||
@ -4,28 +4,55 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
)
|
||||
|
||||
func NewNoopInserter() VectorInserter {
|
||||
return noopInserter{}
|
||||
func NewNoopDB() VectorDB {
|
||||
return noopDB{}
|
||||
}
|
||||
|
||||
var _ VectorDB = noopInserter{}
|
||||
var _ VectorDB = noopDB{}
|
||||
|
||||
type noopInserter struct{}
|
||||
type noopDB struct{}
|
||||
|
||||
func (noopInserter) Search(context.Context, SearchParams) ([]ChunkResult, error) {
|
||||
func (noopDB) Search(context.Context, SearchParams) ([]ChunkResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (noopInserter) PrepareUpdate(ctx context.Context, modelID string, modelDims uint64) error {
|
||||
func (noopDB) PrepareUpdate(ctx context.Context, modelID string, modelDims uint64) error {
|
||||
return nil
|
||||
}
|
||||
func (noopInserter) HasIndex(ctx context.Context, modelID string, repoID api.RepoID, revision api.CommitID) (bool, error) {
|
||||
func (noopDB) HasIndex(ctx context.Context, modelID string, repoID api.RepoID, revision api.CommitID) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (noopInserter) InsertChunks(context.Context, InsertParams) error {
|
||||
func (noopDB) InsertChunks(context.Context, InsertParams) error {
|
||||
return nil
|
||||
}
|
||||
func (noopInserter) FinalizeUpdate(context.Context, FinalizeUpdateParams) error {
|
||||
func (noopDB) FinalizeUpdate(context.Context, FinalizeUpdateParams) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrDisabled = errors.New("Qdrant is disabled. Enable by setting QDRANT_ENDPOINT")
|
||||
|
||||
func NewDisabledDB() VectorDB {
|
||||
return disabledDB{}
|
||||
}
|
||||
|
||||
var _ VectorDB = disabledDB{}
|
||||
|
||||
type disabledDB struct{}
|
||||
|
||||
func (disabledDB) Search(context.Context, SearchParams) ([]ChunkResult, error) {
|
||||
return nil, ErrDisabled
|
||||
}
|
||||
func (disabledDB) PrepareUpdate(ctx context.Context, modelID string, modelDims uint64) error {
|
||||
return ErrDisabled
|
||||
}
|
||||
func (disabledDB) HasIndex(ctx context.Context, modelID string, repoID api.RepoID, revision api.CommitID) (bool, error) {
|
||||
return false, ErrDisabled
|
||||
}
|
||||
func (disabledDB) InsertChunks(context.Context, InsertParams) error {
|
||||
return ErrDisabled
|
||||
}
|
||||
func (disabledDB) FinalizeUpdate(context.Context, FinalizeUpdateParams) error {
|
||||
return ErrDisabled
|
||||
}
|
||||
|
||||
@ -29,10 +29,22 @@ type qdrantDB struct {
|
||||
var _ VectorDB = (*qdrantDB)(nil)
|
||||
|
||||
type SearchParams struct {
|
||||
ModelID string
|
||||
RepoIDs []api.RepoID
|
||||
Query []float32
|
||||
// RepoIDs is the set of repos to search.
|
||||
// If empty, all repos are searched.
|
||||
RepoIDs []api.RepoID
|
||||
|
||||
// The ID of the model that the query was embedded with.
|
||||
// Embeddings for other models will not be searched.
|
||||
ModelID string
|
||||
|
||||
// Query is the embedding for the search query.
|
||||
// Its dimensions must match the model dimensions.
|
||||
Query []float32
|
||||
|
||||
// The maximum number of code results to return
|
||||
CodeLimit int
|
||||
|
||||
// The maximum number of text results to return
|
||||
TextLimit int
|
||||
}
|
||||
|
||||
|
||||
@ -38,7 +38,7 @@ func TestEmbedRepo(t *testing.T) {
|
||||
}
|
||||
revision := api.CommitID("deadbeef")
|
||||
embeddingsClient := NewMockEmbeddingsClient()
|
||||
inserter := db.NewNoopInserter()
|
||||
inserter := db.NewNoopDB()
|
||||
contextService := NewMockContextService()
|
||||
contextService.SplitIntoEmbeddableChunksFunc.SetDefaultHook(defaultSplitter)
|
||||
splitOptions := codeintelContext.SplitOptions{ChunkTokensThreshold: 8}
|
||||
@ -385,7 +385,7 @@ func TestEmbedRepo_ExcludeChunkOnError(t *testing.T) {
|
||||
repoIDName := types.RepoIDName{Name: repoName}
|
||||
embeddingsClient := NewMockEmbeddingsClient()
|
||||
contextService := NewMockContextService()
|
||||
inserter := db.NewNoopInserter()
|
||||
inserter := db.NewNoopDB()
|
||||
contextService.SplitIntoEmbeddableChunksFunc.SetDefaultHook(defaultSplitter)
|
||||
splitOptions := codeintelContext.SplitOptions{ChunkTokensThreshold: 8}
|
||||
mockFiles := map[string][]byte{
|
||||
|
||||
Loading…
Reference in New Issue
Block a user