Revert "Embeddings: multi-repo search" (#51969)

Reverts sourcegraph/sourcegraph#51662
This commit is contained in:
Camden Cheek 2023-05-15 22:44:26 -06:00 committed by GitHub
parent c30b33ff0a
commit 03da2be83d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 340 additions and 1382 deletions

View File

@ -10,6 +10,6 @@ export class SourcegraphEmbeddingsSearchClient implements EmbeddingsSearch {
codeResultsCount: number,
textResultsCount: number
): Promise<EmbeddingsSearchResults | Error> {
return this.client.searchEmbeddings([this.repoId], query, codeResultsCount, textResultsCount)
return this.client.searchEmbeddings(this.repoId, query, codeResultsCount, textResultsCount)
}
}

View File

@ -141,13 +141,13 @@ export class SourcegraphGraphQLAPIClient {
}
public async searchEmbeddings(
repos: string[],
repo: string,
query: string,
codeResultsCount: number,
textResultsCount: number
): Promise<EmbeddingsSearchResults | Error> {
return this.fetchSourcegraphAPI<APIResponse<EmbeddingsSearchResponse>>(SEARCH_EMBEDDINGS_QUERY, {
repos,
repo,
query,
codeResultsCount,
textResultsCount,

View File

@ -21,8 +21,8 @@ query Repository($name: String!) {
}`
export const SEARCH_EMBEDDINGS_QUERY = `
query EmbeddingsSearch($repos: [ID!]!, $query: String!, $codeResultsCount: Int!, $textResultsCount: Int!) {
embeddingsMultiSearch(repos: $repos, query: $query, codeResultsCount: $codeResultsCount, textResultsCount: $textResultsCount) {
query EmbeddingsSearch($repo: ID!, $query: String!, $codeResultsCount: Int!, $textResultsCount: Int!) {
embeddingsSearch(repo: $repo, query: $query, codeResultsCount: $codeResultsCount, textResultsCount: $textResultsCount) {
codeResults {
fileName
startLine

View File

@ -11,7 +11,6 @@ import (
type EmbeddingsResolver interface {
EmbeddingsSearch(ctx context.Context, args EmbeddingsSearchInputArgs) (EmbeddingsSearchResultsResolver, error)
EmbeddingsMultiSearch(ctx context.Context, args EmbeddingsMultiSearchInputArgs) (EmbeddingsSearchResultsResolver, error)
IsContextRequiredForChatQuery(ctx context.Context, args IsContextRequiredForChatQueryInputArgs) (bool, error)
RepoEmbeddingJobs(ctx context.Context, args ListRepoEmbeddingJobsArgs) (*graphqlutil.ConnectionResolver[RepoEmbeddingJobResolver], error)
@ -34,16 +33,9 @@ type EmbeddingsSearchInputArgs struct {
TextResultsCount int32
}
type EmbeddingsMultiSearchInputArgs struct {
Repos []graphql.ID
Query string
CodeResultsCount int32
TextResultsCount int32
}
type EmbeddingsSearchResultsResolver interface {
CodeResults(ctx context.Context) ([]EmbeddingsSearchResultResolver, error)
TextResults(ctx context.Context) ([]EmbeddingsSearchResultResolver, error)
CodeResults(ctx context.Context) []EmbeddingsSearchResultResolver
TextResults(ctx context.Context) []EmbeddingsSearchResultResolver
}
type EmbeddingsSearchResultResolver interface {

View File

@ -22,31 +22,6 @@ extend type Query {
"""
textResultsCount: Int!
): EmbeddingsSearchResults!
"""
Experimental: Searches a set of repositories for similar code and text results using embeddings.
We separated code and text results because text results tended to always feature at the top of the combined results,
and didn't leave room for the code.
"""
embeddingsMultiSearch(
"""
The repositories to search.
"""
repos: [ID!]!
"""
The query used for embeddings search.
"""
query: String!
"""
The number of code results to return.
"""
codeResultsCount: Int!
"""
The number of text results to return. Text results contain Markdown files and similar file types primarily used for writing documentation.
"""
textResultsCount: Int!
): EmbeddingsSearchResults!
"""
Experimental: Determines whether the given query requires further context before it can be answered.
For example:

View File

@ -8,7 +8,6 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//enterprise/internal/embeddings",
"//internal/api",
"//lib/errors",
],
)

View File

@ -11,7 +11,6 @@ import (
"strings"
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
@ -22,7 +21,7 @@ import (
var fs embed.FS
type embeddingsSearcher interface {
Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingCombinedSearchResults, error)
Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error)
}
// Run runs the evaluation and returns recall for the test data.
@ -46,10 +45,11 @@ func Run(searcher embeddingsSearcher) (float64, error) {
relevantFile := fields[1]
args := embeddings.EmbeddingsSearchParameters{
RepoNames: []api.RepoName{"github.com/sourcegraph/sourcegraph"},
RepoName: "github.com/sourcegraph/sourcegraph",
Query: query,
CodeResultsCount: 20,
TextResultsCount: 2,
Debug: true,
}
results, err := searcher.Search(args)
@ -70,7 +70,11 @@ func Run(searcher embeddingsSearcher) (float64, error) {
fmt.Printf(" ")
}
fmt.Printf("%d. %s", i+1, result.FileName)
fmt.Printf(" (%s)\n", result.ScoreDetails.String())
if result.Debug != "" {
fmt.Printf(" (%s)\n", result.Debug)
} else {
fmt.Print("\n")
}
}
fmt.Println()
if fileFound {
@ -99,7 +103,7 @@ func NewClient(url string) *client {
}
}
func (c *client) Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingCombinedSearchResults, error) {
func (c *client) Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error) {
b, err := json.Marshal(args)
if err != nil {
return nil, err
@ -126,7 +130,7 @@ func (c *client) Search(args embeddings.EmbeddingsSearchParameters) (*embeddings
return nil, err
}
res := embeddings.EmbeddingCombinedSearchResults{}
res := embeddings.EmbeddingSearchResults{}
err = json.Unmarshal(body, &res)
if err != nil {
return nil, errors.Wrap(err, "failed to unmarshal response")

View File

@ -35,6 +35,7 @@ go_library(
"//internal/env",
"//internal/errcode",
"//internal/featureflag",
"//internal/gitserver",
"//internal/goroutine",
"//internal/honey",
"//internal/httpserver",
@ -78,8 +79,8 @@ go_test(
srcs = [
"context_detection_test.go",
"context_qa_test.go",
"main_test.go",
"repo_embedding_index_cache_test.go",
"search_test.go",
],
embed = [":shared"],
embedsrcs = [
@ -92,11 +93,10 @@ go_test(
"//enterprise/internal/embeddings/background/repo",
"//internal/api",
"//internal/database",
"//internal/endpoint",
"//internal/types",
"//internal/uploadstore/mocks",
"//lib/errors",
"@com_github_sourcegraph_log//logtest",
"@com_github_sourcegraph_log//:log",
"@com_github_stretchr_testify//require",
],
)

View File

@ -13,6 +13,8 @@ import (
"path/filepath"
"testing"
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/enterprise/cmd/embeddings/qa"
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
"github.com/sourcegraph/sourcegraph/internal/api"
@ -32,6 +34,7 @@ func TestRecall(t *testing.T) {
}
ctx := context.Background()
logger := log.NoOp()
// Set up mock functions
queryEmbeddings, err := loadQueryEmbeddings(t)
@ -56,13 +59,20 @@ func TestRecall(t *testing.T) {
return embeddings.DownloadRepoEmbeddingIndex(context.Background(), mockStore, string(key))
}
// We only care about the file names in this test.
mockReadFile := func(ctx context.Context, repoName api.RepoName, revision api.CommitID, fileName string) ([]byte, error) {
return []byte{}, nil
}
// Weaviate is disabled by default. We don't need it for this test.
weaviate := &weaviateClient{}
searcher := func(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingCombinedSearchResults, error) {
return searchRepoEmbeddingIndexes(
searcher := func(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error) {
return searchRepoEmbeddingIndex(
ctx,
logger,
args,
mockReadFile,
getRepoEmbeddingIndex,
getQueryEmbedding,
weaviate,
@ -110,8 +120,8 @@ func loadQueryEmbeddings(t *testing.T) (map[string][]float32, error) {
return m, nil
}
type embeddingsSearcherFunc func(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingCombinedSearchResults, error)
type embeddingsSearcherFunc func(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error)
func (f embeddingsSearcherFunc) Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingCombinedSearchResults, error) {
func (f embeddingsSearcherFunc) Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error) {
return f(args)
}

View File

@ -19,6 +19,7 @@ import (
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/repo"
"github.com/sourcegraph/sourcegraph/internal/actor"
"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/internal/authz"
"github.com/sourcegraph/sourcegraph/internal/conf"
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
@ -26,6 +27,7 @@ import (
connections "github.com/sourcegraph/sourcegraph/internal/database/connections/live"
"github.com/sourcegraph/sourcegraph/internal/errcode"
"github.com/sourcegraph/sourcegraph/internal/featureflag"
"github.com/sourcegraph/sourcegraph/internal/gitserver"
"github.com/sourcegraph/sourcegraph/internal/goroutine"
"github.com/sourcegraph/sourcegraph/internal/honey"
"github.com/sourcegraph/sourcegraph/internal/httpserver"
@ -56,6 +58,7 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
repoEmbeddingJobsStore := repo.NewRepoEmbeddingJobsStore(db)
// Run setup
gitserverClient := gitserver.NewClient()
uploadStore, err := embeddings.NewEmbeddingsUploadStore(ctx, observationCtx, config.EmbeddingsUploadStoreConfig)
if err != nil {
return err
@ -66,6 +69,10 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
return errors.Wrap(err, "creating sub-repo client")
}
readFile := func(ctx context.Context, repoName api.RepoName, revision api.CommitID, fileName string) ([]byte, error) {
return gitserverClient.ReadFile(ctx, authz.DefaultSubRepoPermsChecker, repoName, revision, fileName)
}
indexGetter, err := NewCachedEmbeddingIndexGetter(
repoStore,
repoEmbeddingJobsStore,
@ -85,6 +92,7 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
weaviate := newWeaviateClient(
logger,
readFile,
getQueryEmbedding,
config.WeaviateURL,
)
@ -92,7 +100,7 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
getContextDetectionEmbeddingIndex := getCachedContextDetectionEmbeddingIndex(uploadStore)
// Create HTTP server
handler := NewHandler(logger, indexGetter.Get, getQueryEmbedding, weaviate, getContextDetectionEmbeddingIndex)
handler := NewHandler(logger, readFile, indexGetter.Get, getQueryEmbedding, weaviate, getContextDetectionEmbeddingIndex)
handler = handlePanic(logger, handler)
handler = featureflag.Middleware(db.FeatureFlags(), handler)
handler = trace.HTTPMiddleware(logger, handler, conf.DefaultClient())
@ -114,6 +122,7 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
func NewHandler(
logger log.Logger,
readFile readFileFn,
getRepoEmbeddingIndex getRepoEmbeddingIndexFn,
getQueryEmbedding getQueryEmbeddingFn,
weaviate *weaviateClient,
@ -134,7 +143,7 @@ func NewHandler(
return
}
res, err := searchRepoEmbeddingIndexes(r.Context(), args, getRepoEmbeddingIndex, getQueryEmbedding, weaviate)
res, err := searchRepoEmbeddingIndex(r.Context(), logger, args, readFile, getRepoEmbeddingIndex, getQueryEmbedding, weaviate)
if errcode.IsNotFound(err) {
http.Error(w, err.Error(), http.StatusBadRequest)
return

View File

@ -1,235 +0,0 @@
package shared
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/sourcegraph/log/logtest"
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/internal/endpoint"
"github.com/stretchr/testify/require"
)
// TestEmbeddingsSearch runs embeddings searches through the HTTP handler and
// the embeddings client against a set of in-memory repo indexes, and checks
// that results from multiple repos are merged and ranked by score.
func TestEmbeddingsSearch(t *testing.T) {
	logger := logtest.Scoped(t)

	// makeIndex builds an index for the given repo whose code and text
	// embeddings are each a 4x4 diagonal matrix scaled by w, with one row per
	// file. A larger w therefore yields a larger score for the matching file.
	makeIndex := func(name api.RepoName, w int8) *embeddings.RepoEmbeddingIndex {
		return &embeddings.RepoEmbeddingIndex{
			RepoName: name,
			Revision: "",
			CodeIndex: embeddings.EmbeddingIndex{
				Embeddings: []int8{
					w, 0, 0, 0,
					0, w, 0, 0,
					0, 0, w, 0,
					0, 0, 0, w,
				},
				ColumnDimension: 4,
				RowMetadata: []embeddings.RepoEmbeddingRowMetadata{
					{FileName: "codefile1", StartLine: 0, EndLine: 1},
					{FileName: "codefile2", StartLine: 0, EndLine: 1},
					{FileName: "codefile3", StartLine: 0, EndLine: 1},
					{FileName: "codefile4", StartLine: 0, EndLine: 1},
				},
			},
			TextIndex: embeddings.EmbeddingIndex{
				Embeddings: []int8{
					w, 0, 0, 0,
					0, w, 0, 0,
					0, 0, w, 0,
					0, 0, 0, w,
				},
				ColumnDimension: 4,
				RowMetadata: []embeddings.RepoEmbeddingRowMetadata{
					{FileName: "textfile1", StartLine: 0, EndLine: 1},
					{FileName: "textfile2", StartLine: 0, EndLine: 1},
					{FileName: "textfile3", StartLine: 0, EndLine: 1},
					{FileName: "textfile4", StartLine: 0, EndLine: 1},
				},
			},
		}
	}

	indexes := map[api.RepoName]*embeddings.RepoEmbeddingIndex{
		"repo1": makeIndex("repo1", 1),
		"repo2": makeIndex("repo2", 2),
		"repo3": makeIndex("repo3", 3),
		"repo4": makeIndex("repo4", 4),
	}

	getRepoEmbeddingIndex := func(_ context.Context, repoName api.RepoName) (*embeddings.RepoEmbeddingIndex, error) {
		return indexes[repoName], nil
	}

	// Each known query maps to a fixed embedding so that it lines up with a
	// specific row of the diagonal indexes built above.
	getQueryEmbedding := func(_ context.Context, query string) ([]float32, error) {
		switch query {
		case "one":
			return []float32{1, 0, 0, 0}, nil
		case "two":
			return []float32{0, 1, 0, 0}, nil
		case "three":
			return []float32{0, 0, 1, 0}, nil
		case "four":
			// NOTE(review): unlike the other queries this vector is not
			// one-hot, and no case below uses it - confirm it is intended.
			return []float32{0, 0, 1, 1}, nil
		default:
			panic("unknown")
		}
	}

	getContextDetectionEmbeddingIndex := func(context.Context) (*embeddings.ContextDetectionEmbeddingIndex, error) {
		return nil, nil
	}

	// newServer starts a test server backed by the shared mocks and registers
	// its shutdown with the test, so listeners are not leaked once the test
	// finishes.
	newServer := func() *httptest.Server {
		srv := httptest.NewServer(NewHandler(
			logger,
			getRepoEmbeddingIndex,
			getQueryEmbedding,
			nil,
			getContextDetectionEmbeddingIndex,
		))
		t.Cleanup(srv.Close)
		return srv
	}
	server1 := newServer()
	server2 := newServer()

	client := embeddings.NewClient(endpoint.Static(server1.URL, server2.URL), http.DefaultClient)

	cases := []struct {
		name   string
		params embeddings.EmbeddingsSearchParameters
		want   *embeddings.EmbeddingCombinedSearchResults
	}{
		{
			// We should return results for file1 based on the query. The
			// rankings should have repo4 highest because it has the largest
			// weighted embeddings.
			name: "all repos",
			params: embeddings.EmbeddingsSearchParameters{
				RepoNames:        []api.RepoName{"repo1", "repo2", "repo3", "repo4"},
				RepoIDs:          []api.RepoID{1, 2, 3, 4},
				Query:            "one",
				CodeResultsCount: 2,
				TextResultsCount: 2,
				UseDocumentRanks: false,
			},
			want: &embeddings.EmbeddingCombinedSearchResults{
				CodeResults: embeddings.EmbeddingSearchResults{{
					RepoName:     "repo4",
					FileName:     "codefile1",
					StartLine:    0,
					EndLine:      1,
					ScoreDetails: embeddings.SearchScoreDetails{Score: 1016, SimilarityScore: 1016},
				}, {
					RepoName:     "repo3",
					FileName:     "codefile1",
					StartLine:    0,
					EndLine:      1,
					ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
				}},
				TextResults: embeddings.EmbeddingSearchResults{{
					RepoName:     "repo4",
					FileName:     "textfile1",
					StartLine:    0,
					EndLine:      1,
					ScoreDetails: embeddings.SearchScoreDetails{Score: 1016, SimilarityScore: 1016},
				}, {
					RepoName:     "repo3",
					FileName:     "textfile1",
					StartLine:    0,
					EndLine:      1,
					ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
				}},
			},
		},
		{
			// Providing a subset of repos should only search those repos.
			// RepoIDs is kept parallel to RepoNames (one ID per name).
			name: "subset of repos",
			params: embeddings.EmbeddingsSearchParameters{
				RepoNames:        []api.RepoName{"repo1", "repo3"},
				RepoIDs:          []api.RepoID{1, 3},
				Query:            "one",
				CodeResultsCount: 2,
				TextResultsCount: 2,
				UseDocumentRanks: false,
			},
			want: &embeddings.EmbeddingCombinedSearchResults{
				CodeResults: embeddings.EmbeddingSearchResults{{
					RepoName:     "repo3",
					FileName:     "codefile1",
					StartLine:    0,
					EndLine:      1,
					ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
				}, {
					RepoName:     "repo1",
					FileName:     "codefile1",
					StartLine:    0,
					EndLine:      1,
					ScoreDetails: embeddings.SearchScoreDetails{Score: 254, SimilarityScore: 254},
				}},
				TextResults: embeddings.EmbeddingSearchResults{{
					RepoName:     "repo3",
					FileName:     "textfile1",
					StartLine:    0,
					EndLine:      1,
					ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
				}, {
					RepoName:     "repo1",
					FileName:     "textfile1",
					StartLine:    0,
					EndLine:      1,
					ScoreDetails: embeddings.SearchScoreDetails{Score: 254, SimilarityScore: 254},
				}},
			},
		},
		{
			// Try a different file just to be safe.
			name: "different file",
			params: embeddings.EmbeddingsSearchParameters{
				RepoNames:        []api.RepoName{"repo1", "repo2", "repo3", "repo4"},
				RepoIDs:          []api.RepoID{1, 2, 3, 4},
				Query:            "three",
				CodeResultsCount: 2,
				TextResultsCount: 2,
				UseDocumentRanks: false,
			},
			want: &embeddings.EmbeddingCombinedSearchResults{
				CodeResults: embeddings.EmbeddingSearchResults{{
					RepoName:     "repo4",
					FileName:     "codefile3",
					StartLine:    0,
					EndLine:      1,
					ScoreDetails: embeddings.SearchScoreDetails{Score: 1016, SimilarityScore: 1016},
				}, {
					RepoName:     "repo3",
					FileName:     "codefile3",
					StartLine:    0,
					EndLine:      1,
					ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
				}},
				TextResults: embeddings.EmbeddingSearchResults{{
					RepoName:     "repo4",
					FileName:     "textfile3",
					StartLine:    0,
					EndLine:      1,
					ScoreDetails: embeddings.SearchScoreDetails{Score: 1016, SimilarityScore: 1016},
				}, {
					RepoName:     "repo3",
					FileName:     "textfile3",
					StartLine:    0,
					EndLine:      1,
					ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
				}},
			},
		},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			results, err := client.Search(context.Background(), tc.params)
			require.NoError(t, err)
			require.Equal(t, tc.want, results)
		})
	}
}

View File

@ -2,66 +2,145 @@ package shared
import (
"context"
"fmt"
"os"
"runtime"
"strings"
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
const SIMILARITY_SEARCH_MIN_ROWS_TO_SPLIT = 1000
type readFileFn func(ctx context.Context, repoName api.RepoName, revision api.CommitID, fileName string) ([]byte, error)
type getRepoEmbeddingIndexFn func(ctx context.Context, repoName api.RepoName) (*embeddings.RepoEmbeddingIndex, error)
type getQueryEmbeddingFn func(ctx context.Context, query string) ([]float32, error)
func searchRepoEmbeddingIndexes(
func searchRepoEmbeddingIndex(
ctx context.Context,
logger log.Logger,
params embeddings.EmbeddingsSearchParameters,
readFile readFileFn,
getRepoEmbeddingIndex getRepoEmbeddingIndexFn,
getQueryEmbedding getQueryEmbeddingFn,
weaviate *weaviateClient,
) (*embeddings.EmbeddingCombinedSearchResults, error) {
) (*embeddings.EmbeddingSearchResults, error) {
if weaviate.Use(ctx) {
return weaviate.Search(ctx, params)
}
embeddingIndex, err := getRepoEmbeddingIndex(ctx, params.RepoName)
if err != nil {
return nil, errors.Wrapf(err, "getting repo embedding index for repo %q", params.RepoName)
}
floatQuery, err := getQueryEmbedding(ctx, params.Query)
if err != nil {
return nil, errors.Wrap(err, "getting query embedding")
}
embeddedQuery := embeddings.Quantize(floatQuery)
workerOpts := embeddings.WorkerOptions{
NumWorkers: runtime.GOMAXPROCS(0),
MinRowsToSplit: SIMILARITY_SEARCH_MIN_ROWS_TO_SPLIT,
}
searchOpts := embeddings.SearchOptions{
opts := embeddings.SearchOptions{
Debug: params.Debug,
UseDocumentRanks: params.UseDocumentRanks,
}
var result embeddings.EmbeddingCombinedSearchResults
codeResults := searchEmbeddingIndex(ctx, logger, embeddingIndex.RepoName, embeddingIndex.Revision, &embeddingIndex.CodeIndex, readFile, embeddedQuery, params.CodeResultsCount, opts)
textResults := searchEmbeddingIndex(ctx, logger, embeddingIndex.RepoName, embeddingIndex.Revision, &embeddingIndex.TextIndex, readFile, embeddedQuery, params.TextResultsCount, opts)
for i, repoName := range params.RepoNames {
if weaviate.Use(ctx) {
codeResults, textResults, err := weaviate.Search(ctx, repoName, params.RepoIDs[i], params.Query, params.CodeResultsCount, params.TextResultsCount)
if err != nil {
return nil, err
return &embeddings.EmbeddingSearchResults{CodeResults: codeResults, TextResults: textResults}, nil
}
const SIMILARITY_SEARCH_MIN_ROWS_TO_SPLIT = 1000
// searchEmbeddingIndex runs a similarity search for query against index,
// passing nResults and opts through to SimilaritySearch, and then hands the
// matches to filterAndHydrateContent, which reads each file via readFile.
// The search is given runtime.GOMAXPROCS(0) workers and only splits work once
// an index has at least SIMILARITY_SEARCH_MIN_ROWS_TO_SPLIT rows.
func searchEmbeddingIndex(
	ctx context.Context,
	logger log.Logger,
	repoName api.RepoName,
	revision api.CommitID,
	index *embeddings.EmbeddingIndex,
	readFile readFileFn,
	query []int8,
	nResults int,
	opts embeddings.SearchOptions,
) []embeddings.EmbeddingSearchResult {
	workerOpts := embeddings.WorkerOptions{
		NumWorkers:     runtime.GOMAXPROCS(0),
		MinRowsToSplit: SIMILARITY_SEARCH_MIN_ROWS_TO_SPLIT,
	}
	matches := index.SimilaritySearch(query, nResults, workerOpts, opts)

	return filterAndHydrateContent(ctx, logger, repoName, revision, readFile, opts.Debug, matches)
}
// filterAndHydrateContent builds a new slice of results whose Content field
// is populated from the file contents; unfiltered itself is not modified. If
// we fail to read a file (eg permission issues) the result is dropped. As such
// the returned slice should be used.
func filterAndHydrateContent(
ctx context.Context,
logger log.Logger,
repoName api.RepoName,
revision api.CommitID,
readFile readFileFn,
debug bool,
unfiltered []embeddings.SimilaritySearchResult,
) []embeddings.EmbeddingSearchResult {
filtered := make([]embeddings.EmbeddingSearchResult, 0, len(unfiltered))
for _, result := range unfiltered {
fileContent, err := readFile(ctx, repoName, revision, result.FileName)
if err != nil {
if !os.IsNotExist(err) {
logger.Error("error reading file", log.String("repoName", string(repoName)), log.String("revision", string(revision)), log.String("fileName", result.FileName), log.Error(err))
}
result.CodeResults.MergeTruncate(codeResults, params.CodeResultsCount)
result.TextResults.MergeTruncate(textResults, params.TextResultsCount)
continue
}
lines := strings.Split(string(fileContent), "\n")
embeddingIndex, err := getRepoEmbeddingIndex(ctx, repoName)
if err != nil {
return nil, errors.Wrapf(err, "getting repo embedding index for repo %q", repoName)
// Sanity check: check that startLine and endLine are within 0 and len(lines).
startLine := max(0, min(len(lines), result.StartLine))
endLine := max(0, min(len(lines), result.EndLine))
content := strings.Join(lines[startLine:endLine], "\n")
var debugString string
if debug {
debugString = fmt.Sprintf("score:%d, similarity:%d, rank:%d", result.Score(), result.SimilarityScore, result.RankScore)
}
codeResults := embeddingIndex.CodeIndex.SimilaritySearch(embeddedQuery, params.CodeResultsCount, workerOpts, searchOpts, embeddingIndex.RepoName, embeddingIndex.Revision)
textResults := embeddingIndex.TextIndex.SimilaritySearch(embeddedQuery, params.TextResultsCount, workerOpts, searchOpts, embeddingIndex.RepoName, embeddingIndex.Revision)
result.CodeResults.MergeTruncate(codeResults, params.CodeResultsCount)
result.TextResults.MergeTruncate(textResults, params.TextResultsCount)
filtered = append(filtered, embeddings.EmbeddingSearchResult{
RepoName: repoName,
Revision: revision,
RepoEmbeddingRowMetadata: embeddings.RepoEmbeddingRowMetadata{
FileName: result.FileName,
StartLine: startLine,
EndLine: endLine,
},
Debug: debugString,
Content: content,
})
}
return &result, nil
return filtered
}
// min returns the smaller of a and b; when they are equal, that value is
// returned.
func min(a, b int) int {
	smaller := b
	if a < b {
		smaller = a
	}
	return smaller
}
// max returns the larger of a and b; when they are equal, that value is
// returned.
func max(a, b int) int {
	larger := b
	if a > b {
		larger = a
	}
	return larger
}

View File

@ -0,0 +1,38 @@
package shared
import (
"context"
"testing"
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
"github.com/sourcegraph/sourcegraph/internal/api"
)
// TestFilterAndHydrateContent_emptyFile checks that a search result whose
// line range (5-20) lies past the end of an empty file is still returned by
// filterAndHydrateContent.
func TestFilterAndHydrateContent_emptyFile(t *testing.T) {
	// Mock readFile that succeeds and yields a file with no content.
	readEmpty := func(context.Context, api.RepoName, api.CommitID, string) ([]byte, error) {
		return []byte{}, nil
	}

	// A single match pointing at lines that do not exist in an empty file.
	matches := []embeddings.SimilaritySearchResult{{
		RepoEmbeddingRowMetadata: embeddings.RepoEmbeddingRowMetadata{
			FileName:  "file.txt",
			StartLine: 5,
			EndLine:   20,
		},
	}}

	got := filterAndHydrateContent(
		context.Background(),
		log.NoOp(),
		api.RepoName("example/repo"),
		api.CommitID("abc123"),
		readEmpty,
		false, // debug output is not needed here
		matches,
	)

	if n := len(got); n != 1 {
		t.Errorf("Expected 1 filtered result, but got %d elements", n)
	}
}

View File

@ -19,6 +19,7 @@ import (
type weaviateClient struct {
logger log.Logger
readFile readFileFn
getQueryEmbedding getQueryEmbeddingFn
client *weaviate.Client
@ -27,6 +28,7 @@ type weaviateClient struct {
func newWeaviateClient(
logger log.Logger,
readFile readFileFn,
getQueryEmbedding getQueryEmbeddingFn,
url *url.URL,
) *weaviateClient {
@ -43,6 +45,7 @@ func newWeaviateClient(
return &weaviateClient{
logger: logger.Scoped("weaviate", "client for weaviate embedding index"),
readFile: readFile,
getQueryEmbedding: getQueryEmbedding,
client: client,
clientErr: err,
@ -53,14 +56,14 @@ func (w *weaviateClient) Use(ctx context.Context) bool {
return featureflag.FromContext(ctx).GetBoolOr("search-weaviate", false)
}
func (w *weaviateClient) Search(ctx context.Context, repoName api.RepoName, repoID api.RepoID, query string, codeResultsCount, textResultsCount int) (codeResults, textResults []embeddings.EmbeddingSearchResult, _ error) {
func (w *weaviateClient) Search(ctx context.Context, params embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error) {
if w.clientErr != nil {
return nil, nil, w.clientErr
return nil, w.clientErr
}
embeddedQuery, err := w.getQueryEmbedding(ctx, query)
embeddedQuery, err := w.getQueryEmbedding(ctx, params.Query)
if err != nil {
return nil, nil, errors.Wrap(err, "getting query embedding")
return nil, errors.Wrap(err, "getting query embedding")
}
queryBuilder := func(klass string, limit int) *graphql.GetBuilder {
@ -72,9 +75,6 @@ func (w *weaviateClient) Search(ctx context.Context, repoName api.RepoName, repo
{Name: "start_line"},
{Name: "end_line"},
{Name: "revision"},
{Name: "_additional", Fields: []graphql.Field{
{Name: "distance"},
}},
}...).
WithLimit(limit)
}
@ -86,7 +86,7 @@ func (w *weaviateClient) Search(ctx context.Context, repoName api.RepoName, repo
return nil
}
srs := make([]embeddings.EmbeddingSearchResult, 0, len(code))
srs := make([]embeddings.SimilaritySearchResult, 0, len(code))
revision := ""
for _, c := range code {
cMap := c.(map[string]any)
@ -96,48 +96,50 @@ func (w *weaviateClient) Search(ctx context.Context, repoName api.RepoName, repo
if revision == "" {
revision = rev
} else {
w.logger.Warn("inconsistent revisions returned for an embedded repository", log.Int("repoid", int(repoID)), log.String("filename", fileName), log.String("revision1", revision), log.String("revision2", rev))
w.logger.Warn("inconsistent revisions returned for an embedded repository", log.Int("repoid", int(params.RepoID)), log.String("filename", fileName), log.String("revision1", revision), log.String("revision2", rev))
}
}
// multiply by half max int32 since distance will always be between 0 and 2
similarity := int32(cMap["_additional"].(map[string]any)["distance"].(float64) * (1073741823))
srs = append(srs, embeddings.EmbeddingSearchResult{
RepoName: repoName,
Revision: api.CommitID(revision),
FileName: fileName,
StartLine: int(cMap["start_line"].(float64)),
EndLine: int(cMap["end_line"].(float64)),
ScoreDetails: embeddings.SearchScoreDetails{
Score: similarity,
SimilarityScore: similarity,
srs = append(srs, embeddings.SimilaritySearchResult{
RepoEmbeddingRowMetadata: embeddings.RepoEmbeddingRowMetadata{
FileName: fileName,
StartLine: int(cMap["start_line"].(float64)),
EndLine: int(cMap["end_line"].(float64)),
},
})
}
return srs
commit := api.CommitID(revision)
if commit == "" {
w.logger.Warn("no revision set for an embedded repository", log.Int("repoid", int(params.RepoID)))
commit = api.CommitID("HEAD")
}
return filterAndHydrateContent(ctx, w.logger, params.RepoName, commit, w.readFile, false, srs)
}
// We partition the indexes by type and repository. Each class in
// weaviate is its own index, so we achieve partitioning by a class
// per repo and type.
codeClass := fmt.Sprintf("Code_%d", repoID)
textClass := fmt.Sprintf("Text_%d", repoID)
codeClass := fmt.Sprintf("Code_%d", params.RepoID)
textClass := fmt.Sprintf("Text_%d", params.RepoID)
res, err := w.client.GraphQL().MultiClassGet().
AddQueryClass(queryBuilder(codeClass, codeResultsCount)).
AddQueryClass(queryBuilder(textClass, textResultsCount)).
AddQueryClass(queryBuilder(codeClass, params.CodeResultsCount)).
AddQueryClass(queryBuilder(textClass, params.TextResultsCount)).
Do(ctx)
if err != nil {
return nil, nil, errors.Wrap(err, "doing weaviate request")
return nil, errors.Wrap(err, "doing weaviate request")
}
if len(res.Errors) > 0 {
return nil, nil, weaviateGraphQLError(res.Errors)
return nil, weaviateGraphQLError(res.Errors)
}
return extractResults(res, codeClass), extractResults(res, textClass), nil
return &embeddings.EmbeddingSearchResults{
CodeResults: extractResults(res, codeClass),
TextResults: extractResults(res, textClass),
}, nil
}
type weaviateGraphQLError []*models.GraphQLError

View File

@ -26,7 +26,8 @@ func Init(
repoEmbeddingsStore := repo.NewRepoEmbeddingJobsStore(db)
contextDetectionEmbeddingsStore := contextdetection.NewContextDetectionEmbeddingJobsStore(db)
gitserverClient := gitserver.NewClient()
embeddingsClient := embeddings.NewDefaultClient()
embeddingsClient := embeddings.NewClient()
enterpriseServices.EmbeddingsResolver = resolvers.NewResolver(
db,
observationCtx.Logger,

View File

@ -1,4 +1,4 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "resolvers",
@ -17,7 +17,6 @@ go_library(
"//enterprise/internal/embeddings/background/repo",
"//internal/api",
"//internal/auth",
"//internal/authz",
"//internal/cody",
"//internal/conf",
"//internal/database",
@ -27,31 +26,6 @@ go_library(
"//lib/errors",
"@com_github_graph_gophers_graphql_go//:graphql-go",
"@com_github_graph_gophers_graphql_go//relay",
"@com_github_sourcegraph_conc//pool",
"@com_github_sourcegraph_log//:log",
],
)
go_test(
name = "resolvers_test",
srcs = ["resolvers_test.go"],
embed = [":resolvers"],
deps = [
"//cmd/frontend/graphqlbackend",
"//enterprise/internal/embeddings",
"//enterprise/internal/embeddings/background/contextdetection",
"//enterprise/internal/embeddings/background/repo",
"//internal/actor",
"//internal/api",
"//internal/authz",
"//internal/conf",
"//internal/database",
"//internal/featureflag",
"//internal/gitserver",
"//internal/types",
"//schema",
"@com_github_graph_gophers_graphql_go//:graphql-go",
"@com_github_sourcegraph_log//logtest",
"@com_github_stretchr_testify//require",
],
)

View File

@ -1,12 +1,8 @@
package resolvers
import (
"bytes"
"context"
"os"
"github.com/graph-gophers/graphql-go"
"github.com/sourcegraph/conc/pool"
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/lib/errors"
@ -19,7 +15,6 @@ import (
repobg "github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/repo"
"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/internal/auth"
"github.com/sourcegraph/sourcegraph/internal/authz"
"github.com/sourcegraph/sourcegraph/internal/cody"
"github.com/sourcegraph/sourcegraph/internal/conf"
"github.com/sourcegraph/sourcegraph/internal/database"
@ -30,7 +25,7 @@ func NewResolver(
db database.DB,
logger log.Logger,
gitserverClient gitserver.Client,
embeddingsClient embeddings.Client,
embeddingsClient *embeddings.Client,
repoStore repobg.RepoEmbeddingJobsStore,
contextDetectionStore contextdetectionbg.ContextDetectionEmbeddingJobsStore,
) graphqlbackend.EmbeddingsResolver {
@ -48,22 +43,13 @@ type Resolver struct {
db database.DB
logger log.Logger
gitserverClient gitserver.Client
embeddingsClient embeddings.Client
embeddingsClient *embeddings.Client
repoEmbeddingJobsStore repobg.RepoEmbeddingJobsStore
contextDetectionJobsStore contextdetectionbg.ContextDetectionEmbeddingJobsStore
emails backend.UserEmailsService
}
// EmbeddingsSearch runs an embeddings search over a single repository. It
// wraps args.Repo in a one-element list, copies the remaining arguments
// unchanged, and delegates to EmbeddingsMultiSearch.
func (r *Resolver) EmbeddingsSearch(ctx context.Context, args graphqlbackend.EmbeddingsSearchInputArgs) (graphqlbackend.EmbeddingsSearchResultsResolver, error) {
	multiArgs := graphqlbackend.EmbeddingsMultiSearchInputArgs{
		Repos:            []graphql.ID{args.Repo},
		Query:            args.Query,
		CodeResultsCount: args.CodeResultsCount,
		TextResultsCount: args.TextResultsCount,
	}
	return r.EmbeddingsMultiSearch(ctx, multiArgs)
}
func (r *Resolver) EmbeddingsMultiSearch(ctx context.Context, args graphqlbackend.EmbeddingsMultiSearchInputArgs) (graphqlbackend.EmbeddingsSearchResultsResolver, error) {
if !conf.EmbeddingsEnabled() {
return nil, errors.New("embeddings are not configured or disabled")
}
@ -76,28 +62,19 @@ func (r *Resolver) EmbeddingsMultiSearch(ctx context.Context, args graphqlbacken
return nil, err
}
repoIDs := make([]api.RepoID, len(args.Repos))
for i, repo := range args.Repos {
repoID, err := graphqlbackend.UnmarshalRepositoryID(repo)
if err != nil {
return nil, err
}
repoIDs[i] = repoID
}
repos, err := r.db.Repos().GetByIDs(ctx, repoIDs...)
repoID, err := graphqlbackend.UnmarshalRepositoryID(args.Repo)
if err != nil {
return nil, err
}
repoNames := make([]api.RepoName, len(repos))
for i, repo := range repos {
repoNames[i] = repo.Name
repo, err := r.db.Repos().Get(ctx, repoID)
if err != nil {
return nil, err
}
results, err := r.embeddingsClient.Search(ctx, embeddings.EmbeddingsSearchParameters{
RepoNames: repoNames,
RepoIDs: repoIDs,
RepoName: repo.Name,
RepoID: repoID,
Query: args.Query,
CodeResultsCount: int(args.CodeResultsCount),
TextResultsCount: int(args.TextResultsCount),
@ -106,11 +83,7 @@ func (r *Resolver) EmbeddingsMultiSearch(ctx context.Context, args graphqlbacken
return nil, err
}
return &embeddingsSearchResultsResolver{
results: results,
gitserver: r.gitserverClient,
logger: r.logger,
}, nil
return &embeddingsSearchResultsResolver{results}, nil
}
func (r *Resolver) IsContextRequiredForChatQuery(ctx context.Context, args graphqlbackend.IsContextRequiredForChatQueryInputArgs) (bool, error) {
@ -210,91 +183,27 @@ func (r *Resolver) ScheduleContextDetectionForEmbedding(ctx context.Context) (*g
}
type embeddingsSearchResultsResolver struct {
results *embeddings.EmbeddingCombinedSearchResults
gitserver gitserver.Client
logger log.Logger
results *embeddings.EmbeddingSearchResults
}
// CodeResults converts the code matches held in r.results into GraphQL result
// resolvers by way of embeddingsSearchResultsToResolvers.
func (r *embeddingsSearchResultsResolver) CodeResults(ctx context.Context) ([]graphqlbackend.EmbeddingsSearchResultResolver, error) {
	codeMatches := r.results.CodeResults
	return embeddingsSearchResultsToResolvers(ctx, r.logger, r.gitserver, codeMatches)
}
// TextResults converts the text matches held in r.results into GraphQL result
// resolvers by way of embeddingsSearchResultsToResolvers.
func (r *embeddingsSearchResultsResolver) TextResults(ctx context.Context) ([]graphqlbackend.EmbeddingsSearchResultResolver, error) {
	textMatches := r.results.TextResults
	return embeddingsSearchResultsToResolvers(ctx, r.logger, r.gitserver, textMatches)
}
func embeddingsSearchResultsToResolvers(
ctx context.Context,
logger log.Logger,
gs gitserver.Client,
results []embeddings.EmbeddingSearchResult,
) ([]graphqlbackend.EmbeddingsSearchResultResolver, error) {
allContents := make([][]byte, len(results))
allErrors := make([]error, len(results))
{ // Fetch contents in parallel because fetching them serially can be slow.
p := pool.New().WithMaxGoroutines(8)
for i, result := range results {
i, result := i, result
p.Go(func() {
content, err := gs.ReadFile(ctx, authz.DefaultSubRepoPermsChecker, result.RepoName, result.Revision, result.FileName)
allContents[i] = content
allErrors[i] = err
})
}
p.Wait()
func (r *embeddingsSearchResultsResolver) CodeResults(ctx context.Context) []graphqlbackend.EmbeddingsSearchResultResolver {
codeResults := make([]graphqlbackend.EmbeddingsSearchResultResolver, len(r.results.CodeResults))
for idx, result := range r.results.CodeResults {
codeResults[idx] = &embeddingsSearchResultResolver{result}
}
resolvers := make([]graphqlbackend.EmbeddingsSearchResultResolver, 0, len(results))
{ // Merge the results with their contents, skipping any that errored when fetching the context.
for i, result := range results {
contents := allContents[i]
err := allErrors[i]
if err != nil {
if !os.IsNotExist(err) {
logger.Error(
"error reading file",
log.String("repoName", string(result.RepoName)),
log.String("revision", string(result.Revision)),
log.String("fileName", result.FileName),
log.Error(err),
)
}
continue
}
resolvers = append(resolvers, &embeddingsSearchResultResolver{
result: result,
content: string(extractLineRange(contents, result.StartLine, result.EndLine)),
})
}
}
return resolvers, nil
return codeResults
}
// extractLineRange returns the lines of content in the zero-based, half-open
// range [startLine, endLine), joined with "\n". Both bounds are clamped to
// [0, number of lines]. A range that is inverted after clamping
// (startLine > endLine) yields an empty slice instead of panicking.
func extractLineRange(content []byte, startLine, endLine int) []byte {
	lines := bytes.Split(content, []byte("\n"))

	// Sanity check: keep startLine and endLine within 0 and len(lines).
	startLine = clamp(startLine, 0, len(lines))
	endLine = clamp(endLine, 0, len(lines))

	// Slicing with low > high panics, so treat an inverted range as empty.
	if startLine > endLine {
		return []byte{}
	}
	return bytes.Join(lines[startLine:endLine], []byte("\n"))
}
func clamp(input, min, max int) int {
if input > max {
return max
} else if input < min {
return min
func (r *embeddingsSearchResultsResolver) TextResults(ctx context.Context) []graphqlbackend.EmbeddingsSearchResultResolver {
textResults := make([]graphqlbackend.EmbeddingsSearchResultResolver, len(r.results.TextResults))
for idx, result := range r.results.TextResults {
textResults[idx] = &embeddingsSearchResultResolver{result}
}
return input
return textResults
}
type embeddingsSearchResultResolver struct {
result embeddings.EmbeddingSearchResult
content string
result embeddings.EmbeddingSearchResult
}
func (r *embeddingsSearchResultResolver) FileName(ctx context.Context) string {
@ -310,5 +219,5 @@ func (r *embeddingsSearchResultResolver) EndLine(ctx context.Context) int32 {
}
func (r *embeddingsSearchResultResolver) Content(ctx context.Context) string {
return r.content
return r.result.Content
}

View File

@ -1,122 +0,0 @@
package resolvers
import (
"context"
"os"
"testing"
"github.com/graph-gophers/graphql-go"
"github.com/sourcegraph/log/logtest"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/contextdetection"
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/repo"
"github.com/sourcegraph/sourcegraph/internal/actor"
"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/internal/authz"
"github.com/sourcegraph/sourcegraph/internal/conf"
"github.com/sourcegraph/sourcegraph/internal/database"
"github.com/sourcegraph/sourcegraph/internal/featureflag"
"github.com/sourcegraph/sourcegraph/internal/gitserver"
"github.com/sourcegraph/sourcegraph/internal/types"
"github.com/sourcegraph/sourcegraph/schema"
"github.com/stretchr/testify/require"
)
// TestEmbeddingSearchResolver wires a Resolver to mocked dependencies and
// checks that EmbeddingsMultiSearch returns only the results whose file could
// be read, with Content cut down to the requested line range.
func TestEmbeddingSearchResolver(t *testing.T) {
logger := logtest.Scoped(t)
mockDB := database.NewMockDB()
// The repo store returns the same single repo for every GetByIDs call.
mockRepos := database.NewMockRepoStore()
mockRepos.GetByIDsFunc.SetDefaultReturn([]*types.Repo{{ID: 1, Name: "repo1"}}, nil)
mockDB.ReposFunc.SetDefaultReturn(mockRepos)
// Only "testfile" can be read; every other file name reports os.ErrNotExist.
mockGitserver := gitserver.NewMockClient()
mockGitserver.ReadFileFunc.SetDefaultHook(func(_ context.Context, _ authz.SubRepoPermissionChecker, _ api.RepoName, _ api.CommitID, fileName string) ([]byte, error) {
if fileName == "testfile" {
return []byte("test\nfirst\nfour\nlines\nplus\nsome\nmore"), nil
}
return nil, os.ErrNotExist
})
// The embeddings service reports two code results: one for the readable
// file and one ("censored") for a file the mocked gitserver cannot read.
mockEmbeddingsClient := embeddings.NewMockClient()
mockEmbeddingsClient.SearchFunc.SetDefaultReturn(&embeddings.EmbeddingCombinedSearchResults{
CodeResults: embeddings.EmbeddingSearchResults{{
FileName: "testfile",
StartLine: 0,
EndLine: 4,
}, {
FileName: "censored",
StartLine: 0,
EndLine: 4,
}},
}, nil)
repoEmbeddingJobsStore := repo.NewMockRepoEmbeddingJobsStore()
contextDetectionJobsStore := contextdetection.NewMockContextDetectionEmbeddingJobsStore()
resolver := NewResolver(
mockDB,
logger,
mockGitserver,
mockEmbeddingsClient,
repoEmbeddingJobsStore,
contextDetectionJobsStore,
)
// Enable embeddings and completions in the mocked site configuration.
conf.Mock(&conf.Unified{
SiteConfiguration: schema.SiteConfiguration{
Embeddings: &schema.Embeddings{Enabled: true},
Completions: &schema.Completions{Enabled: true},
},
})
// Run as mock user 1 with the "cody-experimental" feature flag switched on.
ctx := actor.WithActor(context.Background(), actor.FromMockUser(1))
ffs := featureflag.NewMemoryStore(map[string]bool{"cody-experimental": true}, nil, nil)
ctx = featureflag.WithFlags(ctx, ffs)
results, err := resolver.EmbeddingsMultiSearch(ctx, graphqlbackend.EmbeddingsMultiSearchInputArgs{
Repos: []graphql.ID{graphqlbackend.MarshalRepositoryID(3)},
Query: "test",
CodeResultsCount: 1,
TextResultsCount: 1,
})
require.NoError(t, err)
codeResults, err := results.CodeResults(ctx)
require.NoError(t, err)
// The unreadable "censored" result is expected to be absent, leaving the
// first four lines of "testfile" as the only result.
require.Len(t, codeResults, 1)
require.Equal(t, "test\nfirst\nfour\nlines", codeResults[0].Content(ctx))
}
// Test_extractLineRange checks that extractLineRange returns the zero-based,
// half-open line range [start, end) and clamps bounds that fall outside the
// input.
func Test_extractLineRange(t *testing.T) {
	cases := []struct {
		name       string
		input      []byte
		start, end int
		output     []byte
	}{{
		"leading lines",
		[]byte("zero\none\ntwo\nthree\n"),
		0, 2,
		[]byte("zero\none"),
	}, {
		"single inner line",
		[]byte("zero\none\ntwo\nthree\n"),
		1, 2,
		[]byte("one"),
	}, {
		// Replaces a table entry that duplicated the previous case.
		"end past last line is clamped",
		[]byte("zero\none\ntwo\nthree\n"),
		2, 10,
		[]byte("two\nthree\n"),
	}, {
		"empty input",
		[]byte(""),
		1, 2,
		[]byte(""),
	}}
	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			got := extractLineRange(tc.input, tc.start, tc.end)
			require.Equal(t, tc.output, got)
		})
	}
}

View File

@ -12,7 +12,6 @@ go_library(
"dot_portable.go",
"index_name.go",
"index_storage.go",
"mocks_temp.go",
"quantize.go",
"similarity_search.go",
"tokens.go",
@ -32,7 +31,6 @@ go_library(
"//internal/uploadstore",
"//lib/errors",
"@com_github_sourcegraph_conc//:conc",
"@com_github_sourcegraph_conc//pool",
"@com_github_sourcegraph_log//:log",
"@org_golang_x_sync//errgroup",
] + select({

View File

@ -3,7 +3,6 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "contextdetection",
srcs = [
"mocks_temp.go",
"store.go",
"types.go",
],

View File

@ -1,292 +0,0 @@
// Code generated by go-mockgen 1.3.7; DO NOT EDIT.
//
// This file was generated by running `sg generate` (or `go-mockgen`) at the root of
// this repository. To add additional mocks to this or another package, add a new entry
// to the mockgen.yaml file in the root of this repository.
package contextdetection
import (
"context"
"sync"
basestore "github.com/sourcegraph/sourcegraph/internal/database/basestore"
)
// MockContextDetectionEmbeddingJobsStore is a mock implementation of the
// ContextDetectionEmbeddingJobsStore interface (from the package
// github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/contextdetection)
// used for unit testing. Each field configures the hooks for, and records the
// calls to, the mocked method of the same name.
type MockContextDetectionEmbeddingJobsStore struct {
// CreateContextDetectionEmbeddingJobFunc is an instance of a mock
// function object controlling the behavior of the method
// CreateContextDetectionEmbeddingJob.
CreateContextDetectionEmbeddingJobFunc *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc
// HandleFunc is an instance of a mock function object controlling the
// behavior of the method Handle.
HandleFunc *ContextDetectionEmbeddingJobsStoreHandleFunc
}
// NewMockContextDetectionEmbeddingJobsStore builds a mock of the
// ContextDetectionEmbeddingJobsStore interface. Until a hook is installed,
// every method returns the zero value of each of its results.
func NewMockContextDetectionEmbeddingJobsStore() *MockContextDetectionEmbeddingJobsStore {
	createJob := func(context.Context) (int, error) {
		return 0, nil
	}
	handle := func() basestore.TransactableHandle {
		var zero basestore.TransactableHandle
		return zero
	}
	return &MockContextDetectionEmbeddingJobsStore{
		CreateContextDetectionEmbeddingJobFunc: &ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc{
			defaultHook: createJob,
		},
		HandleFunc: &ContextDetectionEmbeddingJobsStoreHandleFunc{
			defaultHook: handle,
		},
	}
}
// NewStrictMockContextDetectionEmbeddingJobsStore builds a mock of the
// ContextDetectionEmbeddingJobsStore interface whose methods panic when
// invoked, unless a hook has been installed for them.
func NewStrictMockContextDetectionEmbeddingJobsStore() *MockContextDetectionEmbeddingJobsStore {
	createJob := func(context.Context) (int, error) {
		panic("unexpected invocation of MockContextDetectionEmbeddingJobsStore.CreateContextDetectionEmbeddingJob")
	}
	handle := func() basestore.TransactableHandle {
		panic("unexpected invocation of MockContextDetectionEmbeddingJobsStore.Handle")
	}
	return &MockContextDetectionEmbeddingJobsStore{
		CreateContextDetectionEmbeddingJobFunc: &ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc{
			defaultHook: createJob,
		},
		HandleFunc: &ContextDetectionEmbeddingJobsStoreHandleFunc{
			defaultHook: handle,
		},
	}
}
// NewMockContextDetectionEmbeddingJobsStoreFrom builds a mock of the
// ContextDetectionEmbeddingJobsStore interface. Every method delegates to
// the corresponding method of i until a hook overrides it.
func NewMockContextDetectionEmbeddingJobsStoreFrom(i ContextDetectionEmbeddingJobsStore) *MockContextDetectionEmbeddingJobsStore {
	mock := &MockContextDetectionEmbeddingJobsStore{}
	mock.CreateContextDetectionEmbeddingJobFunc = &ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc{
		defaultHook: i.CreateContextDetectionEmbeddingJob,
	}
	mock.HandleFunc = &ContextDetectionEmbeddingJobsStoreHandleFunc{
		defaultHook: i.Handle,
	}
	return mock
}
// ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc
// describes the behavior when the CreateContextDetectionEmbeddingJob method
// of the parent MockContextDetectionEmbeddingJobsStore instance is invoked.
type ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc struct {
// defaultHook is what nextHook returns when the hooks queue is empty.
defaultHook func(context.Context) (int, error)
// hooks is a first-in, first-out queue of one-shot hooks.
hooks []func(context.Context) (int, error)
// history holds one entry per recorded invocation, in call order.
history []ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall
// mutex is held while hooks and history are read or modified.
mutex sync.Mutex
}
// CreateContextDetectionEmbeddingJob invokes the next hook in the queue with
// v0, records the argument and both results in the call history, and returns
// the hook's results.
func (m *MockContextDetectionEmbeddingJobsStore) CreateContextDetectionEmbeddingJob(v0 context.Context) (int, error) {
	hook := m.CreateContextDetectionEmbeddingJobFunc.nextHook()
	r0, r1 := hook(v0)
	call := ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall{
		Arg0:    v0,
		Result0: r0,
		Result1: r1,
	}
	m.CreateContextDetectionEmbeddingJobFunc.appendCall(call)
	return r0, r1
}
// SetDefaultHook sets the function that is called when the
// CreateContextDetectionEmbeddingJob method of the parent
// MockContextDetectionEmbeddingJobsStore instance is invoked and the hook
// queue is empty.
//
// NOTE(review): the assignment below does not take f.mutex, whereas nextHook
// reads defaultHook while holding it.
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) SetDefaultHook(hook func(context.Context) (int, error)) {
f.defaultHook = hook
}
// PushHook appends hook to the end of the hook queue. Each invocation of the
// CreateContextDetectionEmbeddingJob method of the parent
// MockContextDetectionEmbeddingJobsStore instance consumes the hook at the
// front of the queue. Once the queue is empty, the default hook serves all
// further invocations.
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) PushHook(hook func(context.Context) (int, error)) {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	f.hooks = append(f.hooks, hook)
}
// SetDefaultReturn installs, via SetDefaultHook, a default hook that ignores
// its argument and always returns r0 and r1.
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) SetDefaultReturn(r0 int, r1 error) {
	fixed := func(context.Context) (int, error) {
		return r0, r1
	}
	f.SetDefaultHook(fixed)
}
// PushReturn queues, via PushHook, a one-shot hook that ignores its argument
// and returns r0 and r1.
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) PushReturn(r0 int, r1 error) {
	fixed := func(context.Context) (int, error) {
		return r0, r1
	}
	f.PushHook(fixed)
}
// nextHook removes and returns the hook at the front of the queue, or returns
// the default hook when the queue is empty.
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) nextHook() func(context.Context) (int, error) {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	if len(f.hooks) > 0 {
		next := f.hooks[0]
		f.hooks = f.hooks[1:]
		return next
	}
	return f.defaultHook
}
// appendCall records r0 at the end of the call history while holding the
// mutex.
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) appendCall(r0 ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall) {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	f.history = append(f.history, r0)
}
// History returns a copy of the recorded
// ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall
// objects, one per invocation of this function, in call order.
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) History() []ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	snapshot := make([]ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall, len(f.history))
	copy(snapshot, f.history)
	return snapshot
}
// ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall
// is an object that describes an invocation of method
// CreateContextDetectionEmbeddingJob on an instance of
// MockContextDetectionEmbeddingJobsStore. The mock method fills it in after
// the hook has returned.
type ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall struct {
// Arg0 is the value of the 1st argument passed to this method
// invocation.
Arg0 context.Context
// Result0 is the value of the 1st result returned from this method
// invocation.
Result0 int
// Result1 is the value of the 2nd result returned from this method
// invocation.
Result1 error
}
// Args returns the argument of this invocation as a one-element slice of
// empty-interface values.
func (c ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall) Args() []interface{} {
	args := []interface{}{c.Arg0}
	return args
}
// Results returns both results of this invocation, in declaration order, as
// a slice of empty-interface values.
func (c ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall) Results() []interface{} {
	results := []interface{}{c.Result0, c.Result1}
	return results
}
// ContextDetectionEmbeddingJobsStoreHandleFunc describes the behavior when
// the Handle method of the parent MockContextDetectionEmbeddingJobsStore
// instance is invoked.
type ContextDetectionEmbeddingJobsStoreHandleFunc struct {
// defaultHook is what nextHook returns when the hooks queue is empty.
defaultHook func() basestore.TransactableHandle
// hooks is a first-in, first-out queue of one-shot hooks.
hooks []func() basestore.TransactableHandle
// history holds one entry per recorded invocation, in call order.
history []ContextDetectionEmbeddingJobsStoreHandleFuncCall
// mutex is held while hooks and history are read or modified.
mutex sync.Mutex
}
// Handle invokes the next hook in the queue, records its result in the call
// history, and returns that result.
func (m *MockContextDetectionEmbeddingJobsStore) Handle() basestore.TransactableHandle {
	hook := m.HandleFunc.nextHook()
	r0 := hook()
	call := ContextDetectionEmbeddingJobsStoreHandleFuncCall{Result0: r0}
	m.HandleFunc.appendCall(call)
	return r0
}
// SetDefaultHook sets the function that is called when the Handle method of
// the parent MockContextDetectionEmbeddingJobsStore instance is invoked and
// the hook queue is empty.
//
// NOTE(review): the assignment below does not take f.mutex, whereas nextHook
// reads defaultHook while holding it.
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) SetDefaultHook(hook func() basestore.TransactableHandle) {
f.defaultHook = hook
}
// PushHook appends hook to the end of the hook queue. Each invocation of the
// Handle method of the parent MockContextDetectionEmbeddingJobsStore
// instance consumes the hook at the front of the queue. Once the queue is
// empty, the default hook serves all further invocations.
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) PushHook(hook func() basestore.TransactableHandle) {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	f.hooks = append(f.hooks, hook)
}
// SetDefaultReturn installs, via SetDefaultHook, a default hook that always
// returns r0.
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) SetDefaultReturn(r0 basestore.TransactableHandle) {
	fixed := func() basestore.TransactableHandle {
		return r0
	}
	f.SetDefaultHook(fixed)
}
// PushReturn queues, via PushHook, a one-shot hook that returns r0.
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) PushReturn(r0 basestore.TransactableHandle) {
	fixed := func() basestore.TransactableHandle {
		return r0
	}
	f.PushHook(fixed)
}
// nextHook removes and returns the hook at the front of the queue, or returns
// the default hook when the queue is empty.
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) nextHook() func() basestore.TransactableHandle {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	if len(f.hooks) > 0 {
		next := f.hooks[0]
		f.hooks = f.hooks[1:]
		return next
	}
	return f.defaultHook
}
// appendCall records r0 at the end of the call history while holding the
// mutex.
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) appendCall(r0 ContextDetectionEmbeddingJobsStoreHandleFuncCall) {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	f.history = append(f.history, r0)
}
// History returns a copy of the recorded
// ContextDetectionEmbeddingJobsStoreHandleFuncCall objects, one per
// invocation of this function, in call order.
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) History() []ContextDetectionEmbeddingJobsStoreHandleFuncCall {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	snapshot := make([]ContextDetectionEmbeddingJobsStoreHandleFuncCall, len(f.history))
	copy(snapshot, f.history)
	return snapshot
}
// ContextDetectionEmbeddingJobsStoreHandleFuncCall is an object that
// describes an invocation of method Handle on an instance of
// MockContextDetectionEmbeddingJobsStore. Handle takes no arguments, so only
// the result is recorded.
type ContextDetectionEmbeddingJobsStoreHandleFuncCall struct {
// Result0 is the value of the 1st result returned from this method
// invocation.
Result0 basestore.TransactableHandle
}
// Args returns an empty, non-nil slice because Handle takes no arguments.
func (c ContextDetectionEmbeddingJobsStoreHandleFuncCall) Args() []interface{} {
	args := []interface{}{}
	return args
}
// Results returns the single result of this invocation as a one-element
// slice of empty-interface values.
func (c ContextDetectionEmbeddingJobsStoreHandleFuncCall) Results() []interface{} {
	results := []interface{}{c.Result0}
	return results
}

View File

@ -4,12 +4,10 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/sourcegraph/conc/pool"
"github.com/sourcegraph/sourcegraph/lib/errors"
"github.com/sourcegraph/sourcegraph/internal/api"
@ -32,23 +30,14 @@ var defaultDoer = func() httpcli.Doer {
return d
}()
// NewDefaultClient returns a Client built from the package's default
// endpoints and default HTTP doer.
func NewDefaultClient() Client {
	endpoints := defaultEndpoints()
	return NewClient(endpoints, defaultDoer)
}
func NewClient(endpoints *endpoint.Map, doer httpcli.Doer) Client {
return &client{
Endpoints: endpoints,
HTTPClient: doer,
func NewClient() *Client {
return &Client{
Endpoints: defaultEndpoints(),
HTTPClient: defaultDoer,
}
}
// Client is the set of operations callers use to talk to the embeddings
// service.
type Client interface {
// Search runs an embeddings search for the given parameters and returns
// the combined code and text results.
Search(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error)
// IsContextRequiredForChatQuery reports the embeddings service's answer
// on whether context is required for the given chat query.
IsContextRequiredForChatQuery(context.Context, IsContextRequiredForChatQueryParameters) (bool, error)
}
type client struct {
type Client struct {
// Endpoints to embeddings service.
Endpoints *endpoint.Map
@ -57,13 +46,15 @@ type client struct {
}
type EmbeddingsSearchParameters struct {
RepoNames []api.RepoName `json:"repoNames"`
RepoIDs []api.RepoID `json:"repoIDs"`
Query string `json:"query"`
CodeResultsCount int `json:"codeResultsCount"`
TextResultsCount int `json:"textResultsCount"`
RepoName api.RepoName `json:"repoName"`
RepoID api.RepoID `json:"repoID"`
Query string `json:"query"`
CodeResultsCount int `json:"codeResultsCount"`
TextResultsCount int `json:"textResultsCount"`
UseDocumentRanks bool `json:"useDocumentRanks"`
// If set to "True", EmbeddingSearchResult.Debug will contain useful information about scoring.
Debug bool `json:"debug"`
}
type IsContextRequiredForChatQueryParameters struct {
@ -74,43 +65,8 @@ type IsContextRequiredForChatQueryResult struct {
IsRequired bool `json:"isRequired"`
}
// Search splits the requested repos by the endpoint that serves them, sends
// one search request per endpoint concurrently, and merges the per-endpoint
// results into a single result set, truncating code and text results to
// args.CodeResultsCount and args.TextResultsCount. If partitioning or any
// request fails, the error is returned and no results are.
func (c *client) Search(ctx context.Context, args EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
partitions, err := c.partition(args.RepoNames, args.RepoIDs)
if err != nil {
return nil, err
}
p := pool.NewWithResults[*EmbeddingCombinedSearchResults]().WithContext(ctx)
for endpoint, partition := range partitions {
// Shadow the loop variable so each closure captures its own endpoint.
endpoint := endpoint
// make a copy for this request
args := args
args.RepoNames = partition.repoNames
args.RepoIDs = partition.repoIDs
p.Go(func(ctx context.Context) (*EmbeddingCombinedSearchResults, error) {
return c.searchPartition(ctx, endpoint, args)
})
}
allResults, err := p.Wait()
if err != nil {
return nil, err
}
// Fold every endpoint's results into one set. The args used here are the
// caller's original parameters, not the per-partition copies.
var combinedResult EmbeddingCombinedSearchResults
for _, result := range allResults {
combinedResult.CodeResults.MergeTruncate(result.CodeResults, args.CodeResultsCount)
combinedResult.TextResults.MergeTruncate(result.TextResults, args.TextResultsCount)
}
return &combinedResult, nil
}
func (c *client) searchPartition(ctx context.Context, endpoint string, args EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
resp, err := c.httpPost(ctx, "search", endpoint, args)
func (c *Client) Search(ctx context.Context, args EmbeddingsSearchParameters) (*EmbeddingSearchResults, error) {
resp, err := c.httpPost(ctx, "search", args.RepoName, args)
if err != nil {
return nil, err
}
@ -126,16 +82,15 @@ func (c *client) searchPartition(ctx context.Context, endpoint string, args Embe
)
}
var response EmbeddingCombinedSearchResults
var response EmbeddingSearchResults
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
return nil, err
}
fmt.Printf("%s, %#v\n", endpoint, response)
return &response, nil
}
func (c *client) IsContextRequiredForChatQuery(ctx context.Context, args IsContextRequiredForChatQueryParameters) (bool, error) {
func (c *Client) IsContextRequiredForChatQuery(ctx context.Context, args IsContextRequiredForChatQueryParameters) (bool, error) {
resp, err := c.httpPost(ctx, "isContextRequiredForChatQuery", "", args)
if err != nil {
return false, err
@ -160,50 +115,24 @@ func (c *client) IsContextRequiredForChatQuery(ctx context.Context, args IsConte
return response.IsRequired, nil
}
func (c *client) url(repo api.RepoName) (string, error) {
func (c *Client) url(repo api.RepoName) (string, error) {
if c.Endpoints == nil {
return "", errors.New("an embeddings service has not been configured")
}
return c.Endpoints.Get(string(repo))
}
// repoPartition is the subset of requested repos routed to one endpoint.
// repoNames and repoIDs are filled in lockstep by partition, so entries at
// the same index refer to the same repo.
type repoPartition struct {
repoNames []api.RepoName
repoIDs []api.RepoID
}
// returns a partition of the input repos by the endpoint their requests should be routed to
func (c *client) partition(repos []api.RepoName, repoIDs []api.RepoID) (map[string]repoPartition, error) {
if c.Endpoints == nil {
return nil, errors.New("an embeddings service has not been configured")
}
repoStrings := make([]string, len(repos))
for i, repo := range repos {
repoStrings[i] = string(repo)
}
endpoints, err := c.Endpoints.GetMany(repoStrings...)
func (c *Client) httpPost(
ctx context.Context,
method string,
repo api.RepoName,
payload any,
) (resp *http.Response, err error) {
url, err := c.url(repo)
if err != nil {
return nil, err
}
res := make(map[string]repoPartition)
for i, endpoint := range endpoints {
res[endpoint] = repoPartition{
repoNames: append(res[endpoint].repoNames, repos[i]),
repoIDs: append(res[endpoint].repoIDs, repoIDs[i]),
}
}
return res, nil
}
func (c *client) httpPost(
ctx context.Context,
method string,
url string,
payload any,
) (resp *http.Response, err error) {
reqBody, err := json.Marshal(payload)
if err != nil {
return nil, err

View File

@ -1,291 +0,0 @@
// Code generated by go-mockgen 1.3.7; DO NOT EDIT.
//
// This file was generated by running `sg generate` (or `go-mockgen`) at the root of
// this repository. To add additional mocks to this or another package, add a new entry
// to the mockgen.yaml file in the root of this repository.
package embeddings
import (
"context"
"sync"
)
// MockClient is a mock implementation of the Client interface (from the
// package
// github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings) used
// for unit testing. Each field configures the hooks for, and records the
// calls to, the mocked method of the same name.
type MockClient struct {
// IsContextRequiredForChatQueryFunc is an instance of a mock function
// object controlling the behavior of the method
// IsContextRequiredForChatQuery.
IsContextRequiredForChatQueryFunc *ClientIsContextRequiredForChatQueryFunc
// SearchFunc is an instance of a mock function object controlling the
// behavior of the method Search.
SearchFunc *ClientSearchFunc
}
// NewMockClient builds a mock of the Client interface. Until a hook is
// installed, every method returns the zero value of each of its results.
func NewMockClient() *MockClient {
	isContextRequired := func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error) {
		return false, nil
	}
	search := func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
		return nil, nil
	}
	return &MockClient{
		IsContextRequiredForChatQueryFunc: &ClientIsContextRequiredForChatQueryFunc{
			defaultHook: isContextRequired,
		},
		SearchFunc: &ClientSearchFunc{
			defaultHook: search,
		},
	}
}
// NewStrictMockClient builds a mock of the Client interface whose methods
// panic when invoked, unless a hook has been installed for them.
func NewStrictMockClient() *MockClient {
	isContextRequired := func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error) {
		panic("unexpected invocation of MockClient.IsContextRequiredForChatQuery")
	}
	search := func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
		panic("unexpected invocation of MockClient.Search")
	}
	return &MockClient{
		IsContextRequiredForChatQueryFunc: &ClientIsContextRequiredForChatQueryFunc{
			defaultHook: isContextRequired,
		},
		SearchFunc: &ClientSearchFunc{
			defaultHook: search,
		},
	}
}
// NewMockClientFrom builds a mock of the Client interface. Every method
// delegates to the corresponding method of i until a hook overrides it.
func NewMockClientFrom(i Client) *MockClient {
	mock := &MockClient{}
	mock.IsContextRequiredForChatQueryFunc = &ClientIsContextRequiredForChatQueryFunc{
		defaultHook: i.IsContextRequiredForChatQuery,
	}
	mock.SearchFunc = &ClientSearchFunc{
		defaultHook: i.Search,
	}
	return mock
}
// ClientIsContextRequiredForChatQueryFunc describes the behavior when the
// IsContextRequiredForChatQuery method of the parent MockClient instance is
// invoked.
type ClientIsContextRequiredForChatQueryFunc struct {
// defaultHook is what nextHook returns when the hooks queue is empty.
defaultHook func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error)
// hooks is a first-in, first-out queue of one-shot hooks.
hooks []func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error)
// history holds one entry per recorded invocation, in call order.
history []ClientIsContextRequiredForChatQueryFuncCall
// mutex is held while hooks and history are read or modified.
mutex sync.Mutex
}
// IsContextRequiredForChatQuery invokes the next hook in the queue with v0
// and v1, records the arguments and both results in the call history, and
// returns the hook's results.
func (m *MockClient) IsContextRequiredForChatQuery(v0 context.Context, v1 IsContextRequiredForChatQueryParameters) (bool, error) {
	hook := m.IsContextRequiredForChatQueryFunc.nextHook()
	r0, r1 := hook(v0, v1)
	call := ClientIsContextRequiredForChatQueryFuncCall{
		Arg0:    v0,
		Arg1:    v1,
		Result0: r0,
		Result1: r1,
	}
	m.IsContextRequiredForChatQueryFunc.appendCall(call)
	return r0, r1
}
// SetDefaultHook sets the function that is called when the
// IsContextRequiredForChatQuery method of the parent MockClient instance is
// invoked and the hook queue is empty.
//
// NOTE(review): the assignment below does not take f.mutex, whereas nextHook
// reads defaultHook while holding it.
func (f *ClientIsContextRequiredForChatQueryFunc) SetDefaultHook(hook func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error)) {
f.defaultHook = hook
}
// PushHook appends hook to the end of the hook queue. Each invocation of the
// IsContextRequiredForChatQuery method of the parent MockClient instance
// consumes the hook at the front of the queue. Once the queue is empty, the
// default hook serves all further invocations.
func (f *ClientIsContextRequiredForChatQueryFunc) PushHook(hook func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error)) {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	f.hooks = append(f.hooks, hook)
}
// SetDefaultReturn installs, via SetDefaultHook, a default hook that ignores
// its arguments and always returns r0 and r1.
func (f *ClientIsContextRequiredForChatQueryFunc) SetDefaultReturn(r0 bool, r1 error) {
	fixed := func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error) {
		return r0, r1
	}
	f.SetDefaultHook(fixed)
}
// PushReturn queues, via PushHook, a one-shot hook that ignores its
// arguments and returns r0 and r1.
func (f *ClientIsContextRequiredForChatQueryFunc) PushReturn(r0 bool, r1 error) {
	fixed := func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error) {
		return r0, r1
	}
	f.PushHook(fixed)
}
// nextHook removes and returns the hook at the front of the queue, or returns
// the default hook when the queue is empty.
func (f *ClientIsContextRequiredForChatQueryFunc) nextHook() func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error) {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	if len(f.hooks) > 0 {
		next := f.hooks[0]
		f.hooks = f.hooks[1:]
		return next
	}
	return f.defaultHook
}
// appendCall records r0 at the end of the call history while holding the
// mutex.
func (f *ClientIsContextRequiredForChatQueryFunc) appendCall(r0 ClientIsContextRequiredForChatQueryFuncCall) {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	f.history = append(f.history, r0)
}
// History returns a copy of the recorded
// ClientIsContextRequiredForChatQueryFuncCall objects, one per invocation of
// this function, in call order.
func (f *ClientIsContextRequiredForChatQueryFunc) History() []ClientIsContextRequiredForChatQueryFuncCall {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	snapshot := make([]ClientIsContextRequiredForChatQueryFuncCall, len(f.history))
	copy(snapshot, f.history)
	return snapshot
}
// ClientIsContextRequiredForChatQueryFuncCall is an object that describes
// an invocation of method IsContextRequiredForChatQuery on an instance of
// MockClient. It pairs the arguments the method received with the results
// it returned.
//
// NOTE(review): the sibling Search mock builds its call record with a
// positional struct literal; presumably this type is constructed the same
// way, so the field order should be kept as is — confirm before reordering.
type ClientIsContextRequiredForChatQueryFuncCall struct {
// Arg0 is the value of the 1st argument passed to this method
// invocation.
Arg0 context.Context
// Arg1 is the value of the 2nd argument passed to this method
// invocation.
Arg1 IsContextRequiredForChatQueryParameters
// Result0 is the value of the 1st result returned from this method
// invocation.
Result0 bool
// Result1 is the value of the 2nd result returned from this method
// invocation.
Result1 error
}
// Args returns the arguments of this invocation, in positional order, as a
// slice of empty interfaces.
func (c ClientIsContextRequiredForChatQueryFuncCall) Args() []interface{} {
	args := []interface{}{
		c.Arg0,
		c.Arg1,
	}
	return args
}
// Results returns the results of this invocation, in positional order, as a
// slice of empty interfaces.
func (c ClientIsContextRequiredForChatQueryFuncCall) Results() []interface{} {
	results := []interface{}{
		c.Result0,
		c.Result1,
	}
	return results
}
// ClientSearchFunc describes the behavior when the Search method of the
// parent MockClient instance is invoked.
type ClientSearchFunc struct {
// defaultHook is what nextHook returns once the hooks queue is empty.
defaultHook func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error)
// hooks is a FIFO queue: PushHook appends to the end and nextHook removes
// from the front.
hooks []func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error)
// history holds the call records added by appendCall; History returns a
// copy of it.
history []ClientSearchFuncCall
// mutex is held while hooks and history are read or modified.
mutex sync.Mutex
}
// Search runs the next hook from SearchFunc with the given arguments,
// records the arguments and results in the invocation history, and returns
// the hook's results unchanged.
func (m *MockClient) Search(v0 context.Context, v1 EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
	hook := m.SearchFunc.nextHook()
	results, err := hook(v0, v1)

	m.SearchFunc.appendCall(ClientSearchFuncCall{
		Arg0:    v0,
		Arg1:    v1,
		Result0: results,
		Result1: err,
	})

	return results, err
}
// SetDefaultHook sets the function that is called when the Search method of
// the parent MockClient instance is invoked and the hook queue is empty.
//
// The assignment happens while holding the mutex: nextHook reads defaultHook
// under that same mutex, so an unguarded write here could race with a
// concurrent invocation of the mocked method.
func (f *ClientSearchFunc) SetDefaultHook(hook func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error)) {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	f.defaultHook = hook
}
// PushHook enqueues hook at the tail of the hook queue. Every call to the
// Search method of the parent MockClient instance consumes the hook at the
// head of the queue; once the queue has been drained, the default hook
// handles all further calls.
func (f *ClientSearchFunc) PushHook(hook func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error)) {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	f.hooks = append(f.hooks, hook)
}
// SetDefaultReturn installs, via SetDefaultHook, a default hook that ignores
// its arguments and always returns r0 and r1.
func (f *ClientSearchFunc) SetDefaultReturn(r0 *EmbeddingCombinedSearchResults, r1 error) {
	constant := func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
		return r0, r1
	}
	f.SetDefaultHook(constant)
}
// PushReturn enqueues, via PushHook, a one-shot hook that ignores its
// arguments and returns r0 and r1.
func (f *ClientSearchFunc) PushReturn(r0 *EmbeddingCombinedSearchResults, r1 error) {
	constant := func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
		return r0, r1
	}
	f.PushHook(constant)
}
// nextHook selects the hook for the next invocation: it pops and returns the
// hook at the head of the queue, or returns the default hook when the queue
// is empty. The mutex is held for the duration of the selection.
func (f *ClientSearchFunc) nextHook() func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	if len(f.hooks) > 0 {
		head, rest := f.hooks[0], f.hooks[1:]
		f.hooks = rest
		return head
	}

	return f.defaultHook
}
// appendCall records r0 at the end of the invocation history while holding
// the mutex.
func (f *ClientSearchFunc) appendCall(r0 ClientSearchFuncCall) {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	f.history = append(f.history, r0)
}
// History returns a copy of the recorded ClientSearchFuncCall objects, one
// per invocation of this function. The returned slice does not share
// storage with the internal history.
func (f *ClientSearchFunc) History() []ClientSearchFuncCall {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	snapshot := make([]ClientSearchFuncCall, len(f.history))
	copy(snapshot, f.history)
	return snapshot
}
// ClientSearchFuncCall is an object that describes an invocation of method
// Search on an instance of MockClient. It pairs the arguments the method
// received with the results it returned.
//
// MockClient.Search constructs this type with a positional struct literal,
// so the field order below must not be changed.
type ClientSearchFuncCall struct {
// Arg0 is the value of the 1st argument passed to this method
// invocation.
Arg0 context.Context
// Arg1 is the value of the 2nd argument passed to this method
// invocation.
Arg1 EmbeddingsSearchParameters
// Result0 is the value of the 1st result returned from this method
// invocation.
Result0 *EmbeddingCombinedSearchResults
// Result1 is the value of the 2nd result returned from this method
// invocation.
Result1 error
}
// Args returns the arguments of this invocation, in positional order, as a
// slice of empty interfaces.
func (c ClientSearchFuncCall) Args() []interface{} {
	args := []interface{}{
		c.Arg0,
		c.Arg1,
	}
	return args
}
// Results returns the results of this invocation, in positional order, as a
// slice of empty interfaces.
func (c ClientSearchFuncCall) Results() []interface{} {
	results := []interface{}{
		c.Result0,
		c.Result1,
	}
	return results
}

View File

@ -2,16 +2,17 @@ package embeddings
import (
"container/heap"
"fmt"
"math"
"sort"
"github.com/sourcegraph/conc"
"github.com/sourcegraph/sourcegraph/internal/api"
)
type nearestNeighbor struct {
index int
scoreDetails SearchScoreDetails
index int
score int32
debug searchDebugInfo
}
type nearestNeighborsHeap struct {
@ -21,7 +22,7 @@ type nearestNeighborsHeap struct {
func (nn *nearestNeighborsHeap) Len() int { return len(nn.neighbors) }
func (nn *nearestNeighborsHeap) Less(i, j int) bool {
return nn.neighbors[i].scoreDetails.Score < nn.neighbors[j].scoreDetails.Score
return nn.neighbors[i].score < nn.neighbors[j].score
}
func (nn *nearestNeighborsHeap) Swap(i, j int) {
@ -80,16 +81,19 @@ type WorkerOptions struct {
MinRowsToSplit int
}
// SimilaritySearchResult is a single row returned by
// EmbeddingIndex.SimilaritySearch: the metadata of the matched row together
// with the two components of its score.
type SimilaritySearchResult struct {
// RepoEmbeddingRowMetadata is the index's row metadata for the matched row.
RepoEmbeddingRowMetadata
// SimilarityScore is the similarity component of the row's score.
SimilarityScore int32
// RankScore is the rank component of the row's score.
RankScore int32
}
// Score returns the combined score of the result: the sum of its similarity
// and rank components.
func (r *SimilaritySearchResult) Score() int32 {
	total := r.RankScore + r.SimilarityScore
	return total
}
// SimilaritySearch finds the `nResults` most similar rows to a query vector. It uses the cosine similarity metric.
// IMPORTANT: The vectors in the embedding index have to be normalized for similarity search to work correctly.
func (index *EmbeddingIndex) SimilaritySearch(
query []int8,
numResults int,
workerOptions WorkerOptions,
opts SearchOptions,
repoName api.RepoName,
revision api.CommitID,
) []EmbeddingSearchResult {
func (index *EmbeddingIndex) SimilaritySearch(query []int8, numResults int, workerOptions WorkerOptions, opts SearchOptions) []SimilaritySearchResult {
if numResults == 0 || len(index.Embeddings) == 0 {
return nil
}
@ -127,20 +131,16 @@ func (index *EmbeddingIndex) SimilaritySearch(
}
}
// And re-sort it according to the score (descending).
sort.Slice(neighbors, func(i, j int) bool { return neighbors[i].scoreDetails.Score > neighbors[j].scoreDetails.Score })
sort.Slice(neighbors, func(i, j int) bool { return neighbors[i].score > neighbors[j].score })
// Take top neighbors and return them as results.
results := make([]EmbeddingSearchResult, numResults)
results := make([]SimilaritySearchResult, numResults)
for idx := 0; idx < min(numResults, len(neighbors)); idx++ {
metadata := index.RowMetadata[neighbors[idx].index]
results[idx] = EmbeddingSearchResult{
RepoName: repoName,
Revision: revision,
FileName: metadata.FileName,
StartLine: metadata.StartLine,
EndLine: metadata.EndLine,
ScoreDetails: neighbors[idx].scoreDetails,
results[idx] = SimilaritySearchResult{
RepoEmbeddingRowMetadata: index.RowMetadata[neighbors[idx].index],
SimilarityScore: neighbors[idx].debug.similarity,
RankScore: neighbors[idx].debug.rank,
}
}
@ -156,17 +156,17 @@ func (index *EmbeddingIndex) partialSimilaritySearch(query []int8, numResults in
nnHeap := newNearestNeighborsHeap()
for i := partialRows.start; i < partialRows.start+numResults; i++ {
scoreDetails := index.score(query, i, opts)
heap.Push(nnHeap, nearestNeighbor{index: i, scoreDetails: scoreDetails})
score, debugInfo := index.score(query, i, opts)
heap.Push(nnHeap, nearestNeighbor{index: i, score: score, debug: debugInfo})
}
for i := partialRows.start + numResults; i < partialRows.end; i++ {
scoreDetails := index.score(query, i, opts)
score, debugInfo := index.score(query, i, opts)
// Add row if it has greater similarity than the smallest similarity in the heap.
// This way we ensure that we keep a set of the highest similarities in the heap.
if scoreDetails.Score > nnHeap.Peek().scoreDetails.Score {
if score > nnHeap.Peek().score {
heap.Pop(nnHeap)
heap.Push(nnHeap, nearestNeighbor{index: i, scoreDetails: scoreDetails})
heap.Push(nnHeap, nearestNeighbor{index: i, score: score, debug: debugInfo})
}
}
@ -178,7 +178,7 @@ const (
scoreSimilarityWeight int32 = 2
)
func (index *EmbeddingIndex) score(query []int8, i int, opts SearchOptions) SearchScoreDetails {
func (index *EmbeddingIndex) score(query []int8, i int, opts SearchOptions) (score int32, debugInfo searchDebugInfo) {
similarityScore := scoreSimilarityWeight * Dot(index.Row(i), query)
// handle missing ranks
@ -195,11 +195,20 @@ func (index *EmbeddingIndex) score(query []int8, i int, opts SearchOptions) Sear
rankScore = int32(float32(scoreFileRankWeight) * normalizedRank)
}
return SearchScoreDetails{
Score: similarityScore + rankScore,
SimilarityScore: similarityScore,
RankScore: rankScore,
return similarityScore + rankScore, searchDebugInfo{similarity: similarityScore, rank: rankScore, enabled: opts.Debug}
}
// searchDebugInfo carries the breakdown of a row's score as computed by
// EmbeddingIndex.score.
type searchDebugInfo struct {
// similarity is the similarity component of the score.
similarity int32
// rank is the rank component of the score.
rank int32
// enabled mirrors SearchOptions.Debug; String returns "" when it is false.
enabled bool
}
// String renders the score breakdown as "score:S, similarity:X, rank:Y",
// where S is the sum of the similarity and rank components. It returns the
// empty string when debug output is not enabled.
func (i *searchDebugInfo) String() string {
// Debug output is opt-in: report nothing unless it was requested.
if !i.enabled {
return ""
}
return fmt.Sprintf("score:%d, similarity:%d, rank:%d", i.similarity+i.rank, i.similarity, i.rank)
}
func min(a, b int) int {
@ -217,5 +226,6 @@ func max(a, b int) int {
}
type SearchOptions struct {
Debug bool
UseDocumentRanks bool
}

View File

@ -64,10 +64,10 @@ func TestSimilaritySearch(t *testing.T) {
for q := 0; q < numQueries; q++ {
t.Run(fmt.Sprintf("find nearest neighbors query=%d numResults=%d numWorkers=%d", q, numResults, numWorkers), func(t *testing.T) {
query := queries[q*columnDimension : (q+1)*columnDimension]
results := index.SimilaritySearch(query, numResults, WorkerOptions{NumWorkers: numWorkers, MinRowsToSplit: 0}, SearchOptions{}, "", "")
results := index.SimilaritySearch(query, numResults, WorkerOptions{NumWorkers: numWorkers, MinRowsToSplit: 0}, SearchOptions{})
resultRowNums := make([]int, len(results))
for i, r := range results {
resultRowNums[i], _ = strconv.Atoi(r.FileName)
resultRowNums[i], _ = strconv.Atoi(r.RepoEmbeddingRowMetadata.FileName)
}
expectedResults := ranks[q]
require.Equal(t, expectedResults[:min(numResults, len(expectedResults))], resultRowNums)
@ -152,7 +152,7 @@ func BenchmarkSimilaritySearch(b *testing.B) {
b.Run(fmt.Sprintf("numWorkers=%d", numWorkers), func(b *testing.B) {
start := time.Now()
for n := 0; n < b.N; n++ {
_ = index.SimilaritySearch(query, numResults, WorkerOptions{NumWorkers: numWorkers}, SearchOptions{}, "", "")
_ = index.SimilaritySearch(query, numResults, WorkerOptions{NumWorkers: numWorkers}, SearchOptions{})
}
m := float64(numRows) * float64(b.N) / time.Since(start).Seconds()
b.ReportMetric(m, "embeddings/s")
@ -175,11 +175,23 @@ func TestScore(t *testing.T) {
}
// embeddings[0] = 64, 83, 70,
// queries[0:3] = 53, 61, 97,
scoreDetails := index.score(queries[0:columnDimension], 0, SearchOptions{UseDocumentRanks: true})
score, debugInfo := index.score(queries[0:columnDimension], 0, SearchOptions{Debug: true, UseDocumentRanks: true})
// Check that the score is correct
expectedScore := scoreSimilarityWeight * ((64 * 53) + (83 * 61) + (70 * 97))
if math.Abs(float64(scoreDetails.Score-expectedScore)) > 0.0001 {
t.Fatalf("Expected score %d, but got %d", expectedScore, scoreDetails.Score)
if math.Abs(float64(score-expectedScore)) > 0.0001 {
t.Fatalf("Expected score %d, but got %d", expectedScore, score)
}
if debugInfo.String() == "" {
t.Fatal("Expected a non-empty debug")
}
}
// simpleCosineSimilarity returns the dot product of a and b, accumulated as
// int32. It walks the length of a, so b must have at least as many elements
// as a.
func simpleCosineSimilarity(a, b []int8) int32 {
	var dot int32
	for idx, x := range a {
		dot += int32(x) * int32(b[idx])
	}
	return dot
}

View File

@ -1,8 +1,6 @@
package embeddings
import (
"fmt"
"sort"
"time"
"github.com/sourcegraph/log"
@ -47,46 +45,18 @@ type ContextDetectionEmbeddingIndex struct {
MessagesWithoutAdditionalContextMeanEmbedding []float32
}
type EmbeddingCombinedSearchResults struct {
CodeResults EmbeddingSearchResults `json:"codeResults"`
TextResults EmbeddingSearchResults `json:"textResults"`
}
type EmbeddingSearchResults []EmbeddingSearchResult
// MergeTruncate merges other into the search results, keeping only max results with the highest scores
func (esrs *EmbeddingSearchResults) MergeTruncate(other EmbeddingSearchResults, max int) {
self := *esrs
self = append(self, other...)
sort.Slice(self, func(i, j int) bool { return self[i].Score() > self[j].Score() })
*esrs = self[:max]
type EmbeddingSearchResults struct {
CodeResults []EmbeddingSearchResult `json:"codeResults"`
TextResults []EmbeddingSearchResult `json:"textResults"`
}
type EmbeddingSearchResult struct {
RepoName api.RepoName `json:"repoName"`
Revision api.CommitID `json:"revision"`
FileName string `json:"fileName"`
StartLine int `json:"startLine"`
EndLine int `json:"endLine"`
ScoreDetails SearchScoreDetails `json:"scoreDetails"`
}
func (esr *EmbeddingSearchResult) Score() int32 {
return esr.ScoreDetails.RankScore + esr.ScoreDetails.SimilarityScore
}
type SearchScoreDetails struct {
Score int32 `json:"score"`
// Breakdown
SimilarityScore int32 `json:"similarityScore"`
RankScore int32 `json:"rankScore"`
}
func (s *SearchScoreDetails) String() string {
return fmt.Sprintf("score:%d, similarity:%d, rank:%d", s.Score, s.SimilarityScore, s.RankScore)
RepoName api.RepoName
Revision api.CommitID
RepoEmbeddingRowMetadata
Content string `json:"content"`
// Experimental: Clients should not rely on any particular format of debug
Debug string `json:"debug,omitempty"`
}
// DEPRECATED: to support decoding old indexes, we need a struct

View File

@ -123,15 +123,3 @@
path: github.com/sourcegraph/sourcegraph/enterprise/internal/github_apps/store
interfaces:
- GitHubAppsStore
- filename: enterprise/internal/embeddings/mocks_temp.go
path: github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings
interfaces:
- Client
- filename: enterprise/internal/embeddings/background/repo/mocks_temp.go
path: github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/repo
interfaces:
- RepoEmbeddingJobsStore
- filename: enterprise/internal/embeddings/background/contextdetection/mocks_temp.go
path: github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/contextdetection
interfaces:
- ContextDetectionEmbeddingJobsStore