Embeddings: multi-repo search take 2 (#52019)

This re-opens the reverted #51662.

There were three independent issues with the deployed version:
1) There was a stray printf from debugging. It wasn't breaking anything,
but it was pretty spammy.
2) I made a change to `httpPost` that made it take an endpoint rather
than a repo name, but forgot to update the callsite in
`IsContextRequiredForChatQuery`, so it was using an empty string as the
URL. This was not caught by tests because the API was not covered by
tests before this PR, and I only added tests for the `EmbeddingsSearch`
endpoint, not the `IsContextRequiredForChatQuery` endpoint. I've added a
new test for this endpoint as part of this un-revert.
3) Follow-up changes in the client to use the new endpoints were using
types that did not match the actual response shape from GraphQL, and this
was causing undefined field accesses. This was not caught by the type
system because the problem was at the JSON -> TypeScript type layer,
which is an unchecked type assertion.

I squashed the original PR into the first commit, so the second and
third commits are the only changes relative to the original PR.

## Test plan

Added a test for `IsContextRequiredForChatQuery` API
This commit is contained in:
Camden Cheek 2023-05-16 16:42:15 -06:00 committed by GitHub
parent 9578c413db
commit 78d95e61d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 1410 additions and 348 deletions

View File

@ -11,7 +11,7 @@ export class SourcegraphEmbeddingsSearchClient implements EmbeddingsSearch {
textResultsCount: number
): Promise<EmbeddingsSearchResults | Error> {
if (this.web) {
return this.client.searchEmbeddings(this.repoId, query, codeResultsCount, textResultsCount)
return this.client.searchEmbeddings([this.repoId], query, codeResultsCount, textResultsCount)
}
return this.client.legacySearchEmbeddings(this.repoId, query, codeResultsCount, textResultsCount)

View File

@ -36,6 +36,10 @@ interface EmbeddingsSearchResponse {
embeddingsSearch: EmbeddingsSearchResults
}
/**
 * Shape of the `data` payload for the `embeddingsMultiSearch` GraphQL query:
 * the combined search results are nested under the field of the same name.
 */
interface EmbeddingsMultiSearchResponse {
embeddingsMultiSearch: EmbeddingsSearchResults
}
interface LogEventResponse {}
export interface EmbeddingsSearchResult {
@ -144,17 +148,17 @@ export class SourcegraphGraphQLAPIClient {
}
public async searchEmbeddings(
repo: string,
repos: string[],
query: string,
codeResultsCount: number,
textResultsCount: number
): Promise<EmbeddingsSearchResults | Error> {
return this.fetchSourcegraphAPI<APIResponse<EmbeddingsSearchResponse>>(SEARCH_EMBEDDINGS_QUERY, {
repo,
return this.fetchSourcegraphAPI<APIResponse<EmbeddingsMultiSearchResponse>>(SEARCH_EMBEDDINGS_QUERY, {
repos,
query,
codeResultsCount,
textResultsCount,
}).then(response => extractDataOrError(response, data => data.embeddingsSearch))
}).then(response => extractDataOrError(response, data => data.embeddingsMultiSearch))
}
// (Naman): This is a temporary workaround for supporting VS Code Cody integrated with older versions of Sourcegraph that do not support the latest searchEmbeddings query.

View File

@ -21,19 +21,19 @@ query Repository($name: String!) {
}`
export const SEARCH_EMBEDDINGS_QUERY = `
query EmbeddingsSearch($repo: ID!, $query: String!, $codeResultsCount: Int!, $textResultsCount: Int!) {
embeddingsSearch(repo: $repo, query: $query, codeResultsCount: $codeResultsCount, textResultsCount: $textResultsCount) {
query EmbeddingsSearch($repos: [ID!]!, $query: String!, $codeResultsCount: Int!, $textResultsCount: Int!) {
embeddingsMultiSearch(repos: $repos, query: $query, codeResultsCount: $codeResultsCount, textResultsCount: $textResultsCount) {
codeResults {
repoName
revision
repoName
revision
fileName
startLine
endLine
content
}
textResults {
repoName
revision
repoName
revision
fileName
startLine
endLine

View File

@ -11,6 +11,7 @@ import (
type EmbeddingsResolver interface {
EmbeddingsSearch(ctx context.Context, args EmbeddingsSearchInputArgs) (EmbeddingsSearchResultsResolver, error)
EmbeddingsMultiSearch(ctx context.Context, args EmbeddingsMultiSearchInputArgs) (EmbeddingsSearchResultsResolver, error)
IsContextRequiredForChatQuery(ctx context.Context, args IsContextRequiredForChatQueryInputArgs) (bool, error)
RepoEmbeddingJobs(ctx context.Context, args ListRepoEmbeddingJobsArgs) (*graphqlutil.ConnectionResolver[RepoEmbeddingJobResolver], error)
@ -33,9 +34,16 @@ type EmbeddingsSearchInputArgs struct {
TextResultsCount int32
}
// EmbeddingsMultiSearchInputArgs carries the arguments accepted by
// EmbeddingsResolver.EmbeddingsMultiSearch. It mirrors
// EmbeddingsSearchInputArgs except that it takes a list of repositories
// instead of a single one.
type EmbeddingsMultiSearchInputArgs struct {
// Repos lists the GraphQL IDs of the repositories to search.
Repos []graphql.ID
// Query is the text used for the embeddings search.
Query string
// CodeResultsCount is the number of code results to return.
CodeResultsCount int32
// TextResultsCount is the number of text results to return.
TextResultsCount int32
}
type EmbeddingsSearchResultsResolver interface {
CodeResults(ctx context.Context) []EmbeddingsSearchResultResolver
TextResults(ctx context.Context) []EmbeddingsSearchResultResolver
CodeResults(ctx context.Context) ([]EmbeddingsSearchResultResolver, error)
TextResults(ctx context.Context) ([]EmbeddingsSearchResultResolver, error)
}
type EmbeddingsSearchResultResolver interface {

View File

@ -22,6 +22,31 @@ extend type Query {
"""
textResultsCount: Int!
): EmbeddingsSearchResults!
"""
Experimental: Searches a set of repositories for similar code and text results using embeddings.
We separated code and text results because text results tended to always feature at the top of the combined results,
and didn't leave room for the code.
"""
embeddingsMultiSearch(
"""
The repositories to search.
"""
repos: [ID!]!
"""
The query used for embeddings search.
"""
query: String!
"""
The number of code results to return.
"""
codeResultsCount: Int!
"""
The number of text results to return. Text results contain Markdown files and similar file types primarily used for writing documentation.
"""
textResultsCount: Int!
): EmbeddingsSearchResults!
"""
Experimental: Determines whether the given query requires further context before it can be answered.
For example:

View File

@ -8,6 +8,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//enterprise/internal/embeddings",
"//internal/api",
"//lib/errors",
],
)

View File

@ -11,6 +11,7 @@ import (
"strings"
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
@ -21,7 +22,7 @@ import (
var fs embed.FS
type embeddingsSearcher interface {
Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error)
Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingCombinedSearchResults, error)
}
// Run runs the evaluation and returns recall for the test data.
@ -45,11 +46,10 @@ func Run(searcher embeddingsSearcher) (float64, error) {
relevantFile := fields[1]
args := embeddings.EmbeddingsSearchParameters{
RepoName: "github.com/sourcegraph/sourcegraph",
RepoNames: []api.RepoName{"github.com/sourcegraph/sourcegraph"},
Query: query,
CodeResultsCount: 20,
TextResultsCount: 2,
Debug: true,
}
results, err := searcher.Search(args)
@ -70,11 +70,7 @@ func Run(searcher embeddingsSearcher) (float64, error) {
fmt.Printf(" ")
}
fmt.Printf("%d. %s", i+1, result.FileName)
if result.Debug != "" {
fmt.Printf(" (%s)\n", result.Debug)
} else {
fmt.Print("\n")
}
fmt.Printf(" (%s)\n", result.ScoreDetails.String())
}
fmt.Println()
if fileFound {
@ -103,7 +99,7 @@ func NewClient(url string) *client {
}
}
func (c *client) Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error) {
func (c *client) Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingCombinedSearchResults, error) {
b, err := json.Marshal(args)
if err != nil {
return nil, err
@ -130,7 +126,7 @@ func (c *client) Search(args embeddings.EmbeddingsSearchParameters) (*embeddings
return nil, err
}
res := embeddings.EmbeddingSearchResults{}
res := embeddings.EmbeddingCombinedSearchResults{}
err = json.Unmarshal(body, &res)
if err != nil {
return nil, errors.Wrap(err, "failed to unmarshal response")

View File

@ -35,7 +35,6 @@ go_library(
"//internal/env",
"//internal/errcode",
"//internal/featureflag",
"//internal/gitserver",
"//internal/goroutine",
"//internal/honey",
"//internal/httpserver",
@ -79,8 +78,8 @@ go_test(
srcs = [
"context_detection_test.go",
"context_qa_test.go",
"main_test.go",
"repo_embedding_index_cache_test.go",
"search_test.go",
],
embed = [":shared"],
embedsrcs = [
@ -93,10 +92,11 @@ go_test(
"//enterprise/internal/embeddings/background/repo",
"//internal/api",
"//internal/database",
"//internal/endpoint",
"//internal/types",
"//internal/uploadstore/mocks",
"//lib/errors",
"@com_github_sourcegraph_log//:log",
"@com_github_sourcegraph_log//logtest",
"@com_github_stretchr_testify//require",
],
)

View File

@ -13,8 +13,6 @@ import (
"path/filepath"
"testing"
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/enterprise/cmd/embeddings/qa"
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
"github.com/sourcegraph/sourcegraph/internal/api"
@ -34,7 +32,6 @@ func TestRecall(t *testing.T) {
}
ctx := context.Background()
logger := log.NoOp()
// Set up mock functions
queryEmbeddings, err := loadQueryEmbeddings(t)
@ -59,20 +56,13 @@ func TestRecall(t *testing.T) {
return embeddings.DownloadRepoEmbeddingIndex(context.Background(), mockStore, string(key))
}
// We only care about the file names in this test.
mockReadFile := func(ctx context.Context, repoName api.RepoName, revision api.CommitID, fileName string) ([]byte, error) {
return []byte{}, nil
}
// Weaviate is disabled per default. We don't need it for this test.
weaviate := &weaviateClient{}
searcher := func(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error) {
return searchRepoEmbeddingIndex(
searcher := func(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingCombinedSearchResults, error) {
return searchRepoEmbeddingIndexes(
ctx,
logger,
args,
mockReadFile,
getRepoEmbeddingIndex,
getQueryEmbedding,
weaviate,
@ -120,8 +110,8 @@ func loadQueryEmbeddings(t *testing.T) (map[string][]float32, error) {
return m, nil
}
type embeddingsSearcherFunc func(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error)
type embeddingsSearcherFunc func(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingCombinedSearchResults, error)
func (f embeddingsSearcherFunc) Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error) {
func (f embeddingsSearcherFunc) Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingCombinedSearchResults, error) {
return f(args)
}

View File

@ -19,7 +19,6 @@ import (
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/repo"
"github.com/sourcegraph/sourcegraph/internal/actor"
"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/internal/authz"
"github.com/sourcegraph/sourcegraph/internal/conf"
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
@ -27,7 +26,6 @@ import (
connections "github.com/sourcegraph/sourcegraph/internal/database/connections/live"
"github.com/sourcegraph/sourcegraph/internal/errcode"
"github.com/sourcegraph/sourcegraph/internal/featureflag"
"github.com/sourcegraph/sourcegraph/internal/gitserver"
"github.com/sourcegraph/sourcegraph/internal/goroutine"
"github.com/sourcegraph/sourcegraph/internal/honey"
"github.com/sourcegraph/sourcegraph/internal/httpserver"
@ -58,7 +56,6 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
repoEmbeddingJobsStore := repo.NewRepoEmbeddingJobsStore(db)
// Run setup
gitserverClient := gitserver.NewClient()
uploadStore, err := embeddings.NewEmbeddingsUploadStore(ctx, observationCtx, config.EmbeddingsUploadStoreConfig)
if err != nil {
return err
@ -69,10 +66,6 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
return errors.Wrap(err, "creating sub-repo client")
}
readFile := func(ctx context.Context, repoName api.RepoName, revision api.CommitID, fileName string) ([]byte, error) {
return gitserverClient.ReadFile(ctx, authz.DefaultSubRepoPermsChecker, repoName, revision, fileName)
}
indexGetter, err := NewCachedEmbeddingIndexGetter(
repoStore,
repoEmbeddingJobsStore,
@ -92,7 +85,6 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
weaviate := newWeaviateClient(
logger,
readFile,
getQueryEmbedding,
config.WeaviateURL,
)
@ -100,7 +92,7 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
getContextDetectionEmbeddingIndex := getCachedContextDetectionEmbeddingIndex(uploadStore)
// Create HTTP server
handler := NewHandler(logger, readFile, indexGetter.Get, getQueryEmbedding, weaviate, getContextDetectionEmbeddingIndex)
handler := NewHandler(logger, indexGetter.Get, getQueryEmbedding, weaviate, getContextDetectionEmbeddingIndex)
handler = handlePanic(logger, handler)
handler = featureflag.Middleware(db.FeatureFlags(), handler)
handler = trace.HTTPMiddleware(logger, handler, conf.DefaultClient())
@ -122,7 +114,6 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
func NewHandler(
logger log.Logger,
readFile readFileFn,
getRepoEmbeddingIndex getRepoEmbeddingIndexFn,
getQueryEmbedding getQueryEmbeddingFn,
weaviate *weaviateClient,
@ -143,7 +134,7 @@ func NewHandler(
return
}
res, err := searchRepoEmbeddingIndex(r.Context(), logger, args, readFile, getRepoEmbeddingIndex, getQueryEmbedding, weaviate)
res, err := searchRepoEmbeddingIndexes(r.Context(), args, getRepoEmbeddingIndex, getQueryEmbedding, weaviate)
if errcode.IsNotFound(err) {
http.Error(w, err.Error(), http.StatusBadRequest)
return

View File

@ -0,0 +1,248 @@
package shared
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/sourcegraph/log/logtest"
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/internal/endpoint"
"github.com/stretchr/testify/require"
)
// TestEmbeddingsSearch exercises the embeddings HTTP handler end to end through
// the embeddings client, using two in-process servers and small hand-built
// indexes.
//
// Every repository gets an identical 4x4 diagonal code index and text index
// scaled by a per-repo weight w (repo1 -> 1 ... repo4 -> 4). In the expectations
// below the score of a matching row is 254*w (254, 762, 1016), so repositories
// with a larger weight must rank higher.
func TestEmbeddingsSearch(t *testing.T) {
	logger := logtest.Scoped(t)

	// makeIndex builds a repo index whose i-th code/text row is the i-th unit
	// vector scaled by w, so a unit-vector query matches exactly one file.
	makeIndex := func(name api.RepoName, w int8) *embeddings.RepoEmbeddingIndex {
		return &embeddings.RepoEmbeddingIndex{
			RepoName: name,
			Revision: "",
			CodeIndex: embeddings.EmbeddingIndex{
				Embeddings: []int8{
					w, 0, 0, 0,
					0, w, 0, 0,
					0, 0, w, 0,
					0, 0, 0, w,
				},
				ColumnDimension: 4,
				RowMetadata: []embeddings.RepoEmbeddingRowMetadata{
					{FileName: "codefile1", StartLine: 0, EndLine: 1},
					{FileName: "codefile2", StartLine: 0, EndLine: 1},
					{FileName: "codefile3", StartLine: 0, EndLine: 1},
					{FileName: "codefile4", StartLine: 0, EndLine: 1},
				},
			},
			TextIndex: embeddings.EmbeddingIndex{
				Embeddings: []int8{
					w, 0, 0, 0,
					0, w, 0, 0,
					0, 0, w, 0,
					0, 0, 0, w,
				},
				ColumnDimension: 4,
				RowMetadata: []embeddings.RepoEmbeddingRowMetadata{
					{FileName: "textfile1", StartLine: 0, EndLine: 1},
					{FileName: "textfile2", StartLine: 0, EndLine: 1},
					{FileName: "textfile3", StartLine: 0, EndLine: 1},
					{FileName: "textfile4", StartLine: 0, EndLine: 1},
				},
			},
		}
	}

	indexes := map[api.RepoName]*embeddings.RepoEmbeddingIndex{
		"repo1": makeIndex("repo1", 1),
		"repo2": makeIndex("repo2", 2),
		"repo3": makeIndex("repo3", 3),
		"repo4": makeIndex("repo4", 4),
	}

	getRepoEmbeddingIndex := func(_ context.Context, repoName api.RepoName) (*embeddings.RepoEmbeddingIndex, error) {
		return indexes[repoName], nil
	}

	// Fixed query embeddings; any query not listed here is a test bug.
	getQueryEmbedding := func(_ context.Context, query string) ([]float32, error) {
		switch query {
		case "one":
			return []float32{1, 0, 0, 0}, nil
		case "two":
			return []float32{0, 1, 0, 0}, nil
		case "three":
			return []float32{0, 0, 1, 0}, nil
		case "four":
			// NOTE(review): unlike the other queries this is not a unit vector
			// (it also matches row 3). No scenario below uses "four" - confirm
			// whether {0, 0, 0, 1} was intended.
			return []float32{0, 0, 1, 1}, nil
		case "context detection":
			return []float32{2, 4, 6, 8}, nil
		default:
			panic("unknown")
		}
	}

	getContextDetectionEmbeddingIndex := func(context.Context) (*embeddings.ContextDetectionEmbeddingIndex, error) {
		return &embeddings.ContextDetectionEmbeddingIndex{
			MessagesWithAdditionalContextMeanEmbedding:    []float32{1, 2, 3, 4},
			MessagesWithoutAdditionalContextMeanEmbedding: []float32{4, 3, 2, 1},
		}, nil
	}

	// Two identical servers so the client has more than one endpoint to talk
	// to. The weaviate client is nil: only the in-memory index path is tested.
	server1 := httptest.NewServer(NewHandler(
		logger,
		getRepoEmbeddingIndex,
		getQueryEmbedding,
		nil,
		getContextDetectionEmbeddingIndex,
	))
	// Shut the listener down when the test ends so it does not leak.
	defer server1.Close()

	server2 := httptest.NewServer(NewHandler(
		logger,
		getRepoEmbeddingIndex,
		getQueryEmbedding,
		nil,
		getContextDetectionEmbeddingIndex,
	))
	defer server2.Close()

	client := embeddings.NewClient(endpoint.Static(server1.URL, server2.URL), http.DefaultClient)

	{
		// First test: query "one" should match codefile1/textfile1 in every
		// repo. The rankings should have repo4 highest because it has the
		// largest weighted embeddings.
		params := embeddings.EmbeddingsSearchParameters{
			RepoNames:        []api.RepoName{"repo1", "repo2", "repo3", "repo4"},
			RepoIDs:          []api.RepoID{1, 2, 3, 4},
			Query:            "one",
			CodeResultsCount: 2,
			TextResultsCount: 2,
			UseDocumentRanks: false,
		}
		results, err := client.Search(context.Background(), params)
		require.NoError(t, err)
		require.Equal(t, &embeddings.EmbeddingCombinedSearchResults{
			CodeResults: embeddings.EmbeddingSearchResults{{
				RepoName:     "repo4",
				FileName:     "codefile1",
				StartLine:    0,
				EndLine:      1,
				ScoreDetails: embeddings.SearchScoreDetails{Score: 1016, SimilarityScore: 1016},
			}, {
				RepoName:     "repo3",
				FileName:     "codefile1",
				StartLine:    0,
				EndLine:      1,
				ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
			}},
			TextResults: embeddings.EmbeddingSearchResults{{
				RepoName:     "repo4",
				FileName:     "textfile1",
				StartLine:    0,
				EndLine:      1,
				ScoreDetails: embeddings.SearchScoreDetails{Score: 1016, SimilarityScore: 1016},
			}, {
				RepoName:     "repo3",
				FileName:     "textfile1",
				StartLine:    0,
				EndLine:      1,
				ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
			}},
		}, results)
	}

	{
		// Second test: providing a subset of repos should only search those repos.
		// RepoIDs is paired with RepoNames by index, so it lists only the IDs
		// of the two requested repositories.
		params := embeddings.EmbeddingsSearchParameters{
			RepoNames:        []api.RepoName{"repo1", "repo3"},
			RepoIDs:          []api.RepoID{1, 3},
			Query:            "one",
			CodeResultsCount: 2,
			TextResultsCount: 2,
			UseDocumentRanks: false,
		}
		results, err := client.Search(context.Background(), params)
		require.NoError(t, err)
		require.Equal(t, &embeddings.EmbeddingCombinedSearchResults{
			CodeResults: embeddings.EmbeddingSearchResults{{
				RepoName:     "repo3",
				FileName:     "codefile1",
				StartLine:    0,
				EndLine:      1,
				ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
			}, {
				RepoName:     "repo1",
				FileName:     "codefile1",
				StartLine:    0,
				EndLine:      1,
				ScoreDetails: embeddings.SearchScoreDetails{Score: 254, SimilarityScore: 254},
			}},
			TextResults: embeddings.EmbeddingSearchResults{{
				RepoName:     "repo3",
				FileName:     "textfile1",
				StartLine:    0,
				EndLine:      1,
				ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
			}, {
				RepoName:     "repo1",
				FileName:     "textfile1",
				StartLine:    0,
				EndLine:      1,
				ScoreDetails: embeddings.SearchScoreDetails{Score: 254, SimilarityScore: 254},
			}},
		}, results)
	}

	{
		// Third test: try a different file just to be safe
		params := embeddings.EmbeddingsSearchParameters{
			RepoNames:        []api.RepoName{"repo1", "repo2", "repo3", "repo4"},
			RepoIDs:          []api.RepoID{1, 2, 3, 4},
			Query:            "three",
			CodeResultsCount: 2,
			TextResultsCount: 2,
			UseDocumentRanks: false,
		}
		results, err := client.Search(context.Background(), params)
		require.NoError(t, err)
		require.Equal(t, &embeddings.EmbeddingCombinedSearchResults{
			CodeResults: embeddings.EmbeddingSearchResults{{
				RepoName:     "repo4",
				FileName:     "codefile3",
				StartLine:    0,
				EndLine:      1,
				ScoreDetails: embeddings.SearchScoreDetails{Score: 1016, SimilarityScore: 1016},
			}, {
				RepoName:     "repo3",
				FileName:     "codefile3",
				StartLine:    0,
				EndLine:      1,
				ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
			}},
			TextResults: embeddings.EmbeddingSearchResults{{
				RepoName:     "repo4",
				FileName:     "textfile3",
				StartLine:    0,
				EndLine:      1,
				ScoreDetails: embeddings.SearchScoreDetails{Score: 1016, SimilarityScore: 1016},
			}, {
				RepoName:     "repo3",
				FileName:     "textfile3",
				StartLine:    0,
				EndLine:      1,
				ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
			}},
		}, results)
	}

	// The "context detection" query embeds to {2, 4, 6, 8}, a positive multiple
	// of the with-additional-context mean embedding {1, 2, 3, 4}, so the
	// endpoint is expected to report that context is required.
	t.Run("IsContextRequiredForChatQuery", func(t *testing.T) {
		res, err := client.IsContextRequiredForChatQuery(context.Background(), embeddings.IsContextRequiredForChatQueryParameters{
			Query: "context detection",
		})
		require.NoError(t, err)
		require.True(t, res)
	})
}

View File

@ -2,145 +2,66 @@ package shared
import (
"context"
"fmt"
"os"
"runtime"
"strings"
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
type readFileFn func(ctx context.Context, repoName api.RepoName, revision api.CommitID, fileName string) ([]byte, error)
const SIMILARITY_SEARCH_MIN_ROWS_TO_SPLIT = 1000
type getRepoEmbeddingIndexFn func(ctx context.Context, repoName api.RepoName) (*embeddings.RepoEmbeddingIndex, error)
type getQueryEmbeddingFn func(ctx context.Context, query string) ([]float32, error)
func searchRepoEmbeddingIndex(
func searchRepoEmbeddingIndexes(
ctx context.Context,
logger log.Logger,
params embeddings.EmbeddingsSearchParameters,
readFile readFileFn,
getRepoEmbeddingIndex getRepoEmbeddingIndexFn,
getQueryEmbedding getQueryEmbeddingFn,
weaviate *weaviateClient,
) (*embeddings.EmbeddingSearchResults, error) {
if weaviate.Use(ctx) {
return weaviate.Search(ctx, params)
}
embeddingIndex, err := getRepoEmbeddingIndex(ctx, params.RepoName)
if err != nil {
return nil, errors.Wrapf(err, "getting repo embedding index for repo %q", params.RepoName)
}
) (*embeddings.EmbeddingCombinedSearchResults, error) {
floatQuery, err := getQueryEmbedding(ctx, params.Query)
if err != nil {
return nil, errors.Wrap(err, "getting query embedding")
}
embeddedQuery := embeddings.Quantize(floatQuery)
opts := embeddings.SearchOptions{
Debug: params.Debug,
workerOpts := embeddings.WorkerOptions{
NumWorkers: runtime.GOMAXPROCS(0),
MinRowsToSplit: SIMILARITY_SEARCH_MIN_ROWS_TO_SPLIT,
}
searchOpts := embeddings.SearchOptions{
UseDocumentRanks: params.UseDocumentRanks,
}
codeResults := searchEmbeddingIndex(ctx, logger, embeddingIndex.RepoName, embeddingIndex.Revision, &embeddingIndex.CodeIndex, readFile, embeddedQuery, params.CodeResultsCount, opts)
textResults := searchEmbeddingIndex(ctx, logger, embeddingIndex.RepoName, embeddingIndex.Revision, &embeddingIndex.TextIndex, readFile, embeddedQuery, params.TextResultsCount, opts)
var result embeddings.EmbeddingCombinedSearchResults
return &embeddings.EmbeddingSearchResults{CodeResults: codeResults, TextResults: textResults}, nil
}
const SIMILARITY_SEARCH_MIN_ROWS_TO_SPLIT = 1000
func searchEmbeddingIndex(
ctx context.Context,
logger log.Logger,
repoName api.RepoName,
revision api.CommitID,
index *embeddings.EmbeddingIndex,
readFile readFileFn,
query []int8,
nResults int,
opts embeddings.SearchOptions,
) []embeddings.EmbeddingSearchResult {
numWorkers := runtime.GOMAXPROCS(0)
results := index.SimilaritySearch(query, nResults, embeddings.WorkerOptions{NumWorkers: numWorkers, MinRowsToSplit: SIMILARITY_SEARCH_MIN_ROWS_TO_SPLIT}, opts)
return filterAndHydrateContent(
ctx,
logger,
repoName,
revision,
readFile,
opts.Debug,
results,
)
}
// filterAndHydrateContent will mutate unfiltered to populate the Content
// field. If we fail to read a file (eg permission issues) we will remove the
// result. As such the returned slice should be used.
func filterAndHydrateContent(
ctx context.Context,
logger log.Logger,
repoName api.RepoName,
revision api.CommitID,
readFile readFileFn,
debug bool,
unfiltered []embeddings.SimilaritySearchResult,
) []embeddings.EmbeddingSearchResult {
filtered := make([]embeddings.EmbeddingSearchResult, 0, len(unfiltered))
for _, result := range unfiltered {
fileContent, err := readFile(ctx, repoName, revision, result.FileName)
if err != nil {
if !os.IsNotExist(err) {
logger.Error("error reading file", log.String("repoName", string(repoName)), log.String("revision", string(revision)), log.String("fileName", result.FileName), log.Error(err))
for i, repoName := range params.RepoNames {
if weaviate.Use(ctx) {
codeResults, textResults, err := weaviate.Search(ctx, repoName, params.RepoIDs[i], params.Query, params.CodeResultsCount, params.TextResultsCount)
if err != nil {
return nil, err
}
result.CodeResults.MergeTruncate(codeResults, params.CodeResultsCount)
result.TextResults.MergeTruncate(textResults, params.TextResultsCount)
continue
}
lines := strings.Split(string(fileContent), "\n")
// Sanity check: check that startLine and endLine are within 0 and len(lines).
startLine := max(0, min(len(lines), result.StartLine))
endLine := max(0, min(len(lines), result.EndLine))
content := strings.Join(lines[startLine:endLine], "\n")
var debugString string
if debug {
debugString = fmt.Sprintf("score:%d, similarity:%d, rank:%d", result.Score(), result.SimilarityScore, result.RankScore)
embeddingIndex, err := getRepoEmbeddingIndex(ctx, repoName)
if err != nil {
return nil, errors.Wrapf(err, "getting repo embedding index for repo %q", repoName)
}
filtered = append(filtered, embeddings.EmbeddingSearchResult{
RepoName: repoName,
Revision: revision,
RepoEmbeddingRowMetadata: embeddings.RepoEmbeddingRowMetadata{
FileName: result.FileName,
StartLine: startLine,
EndLine: endLine,
},
Debug: debugString,
Content: content,
})
codeResults := embeddingIndex.CodeIndex.SimilaritySearch(embeddedQuery, params.CodeResultsCount, workerOpts, searchOpts, embeddingIndex.RepoName, embeddingIndex.Revision)
textResults := embeddingIndex.TextIndex.SimilaritySearch(embeddedQuery, params.TextResultsCount, workerOpts, searchOpts, embeddingIndex.RepoName, embeddingIndex.Revision)
result.CodeResults.MergeTruncate(codeResults, params.CodeResultsCount)
result.TextResults.MergeTruncate(textResults, params.TextResultsCount)
}
return filtered
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func max(a, b int) int {
if a > b {
return a
}
return b
return &result, nil
}

View File

@ -1,38 +0,0 @@
package shared
import (
"context"
"testing"
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
"github.com/sourcegraph/sourcegraph/internal/api"
)
func TestFilterAndHydrateContent_emptyFile(t *testing.T) {
// Set up test data
repoName := api.RepoName("example/repo")
revision := api.CommitID("abc123")
debug := false
unfiltered := []embeddings.SimilaritySearchResult{
{
RepoEmbeddingRowMetadata: embeddings.RepoEmbeddingRowMetadata{
FileName: "file.txt",
StartLine: 5,
EndLine: 20,
},
},
}
// Define a mock readFile function that returns an empty string
readFile := func(ctx context.Context, repoName api.RepoName, revision api.CommitID, fileName string) ([]byte, error) {
return []byte{}, nil
}
filtered := filterAndHydrateContent(context.Background(), log.NoOp(), repoName, revision, readFile, debug, unfiltered)
if len(filtered) != 1 {
t.Errorf("Expected 1 filtered result, but got %d elements", len(filtered))
}
}

View File

@ -19,7 +19,6 @@ import (
type weaviateClient struct {
logger log.Logger
readFile readFileFn
getQueryEmbedding getQueryEmbeddingFn
client *weaviate.Client
@ -28,7 +27,6 @@ type weaviateClient struct {
func newWeaviateClient(
logger log.Logger,
readFile readFileFn,
getQueryEmbedding getQueryEmbeddingFn,
url *url.URL,
) *weaviateClient {
@ -45,7 +43,6 @@ func newWeaviateClient(
return &weaviateClient{
logger: logger.Scoped("weaviate", "client for weaviate embedding index"),
readFile: readFile,
getQueryEmbedding: getQueryEmbedding,
client: client,
clientErr: err,
@ -56,14 +53,14 @@ func (w *weaviateClient) Use(ctx context.Context) bool {
return featureflag.FromContext(ctx).GetBoolOr("search-weaviate", false)
}
func (w *weaviateClient) Search(ctx context.Context, params embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error) {
func (w *weaviateClient) Search(ctx context.Context, repoName api.RepoName, repoID api.RepoID, query string, codeResultsCount, textResultsCount int) (codeResults, textResults []embeddings.EmbeddingSearchResult, _ error) {
if w.clientErr != nil {
return nil, w.clientErr
return nil, nil, w.clientErr
}
embeddedQuery, err := w.getQueryEmbedding(ctx, params.Query)
embeddedQuery, err := w.getQueryEmbedding(ctx, query)
if err != nil {
return nil, errors.Wrap(err, "getting query embedding")
return nil, nil, errors.Wrap(err, "getting query embedding")
}
queryBuilder := func(klass string, limit int) *graphql.GetBuilder {
@ -75,6 +72,9 @@ func (w *weaviateClient) Search(ctx context.Context, params embeddings.Embedding
{Name: "start_line"},
{Name: "end_line"},
{Name: "revision"},
{Name: "_additional", Fields: []graphql.Field{
{Name: "distance"},
}},
}...).
WithLimit(limit)
}
@ -86,7 +86,7 @@ func (w *weaviateClient) Search(ctx context.Context, params embeddings.Embedding
return nil
}
srs := make([]embeddings.SimilaritySearchResult, 0, len(code))
srs := make([]embeddings.EmbeddingSearchResult, 0, len(code))
revision := ""
for _, c := range code {
cMap := c.(map[string]any)
@ -96,50 +96,48 @@ func (w *weaviateClient) Search(ctx context.Context, params embeddings.Embedding
if revision == "" {
revision = rev
} else {
w.logger.Warn("inconsistent revisions returned for an embedded repository", log.Int("repoid", int(params.RepoID)), log.String("filename", fileName), log.String("revision1", revision), log.String("revision2", rev))
w.logger.Warn("inconsistent revisions returned for an embedded repository", log.Int("repoid", int(repoID)), log.String("filename", fileName), log.String("revision1", revision), log.String("revision2", rev))
}
}
srs = append(srs, embeddings.SimilaritySearchResult{
RepoEmbeddingRowMetadata: embeddings.RepoEmbeddingRowMetadata{
FileName: fileName,
StartLine: int(cMap["start_line"].(float64)),
EndLine: int(cMap["end_line"].(float64)),
// multiply by half max int32 since distance will always be between 0 and 2
similarity := int32(cMap["_additional"].(map[string]any)["distance"].(float64) * (1073741823))
srs = append(srs, embeddings.EmbeddingSearchResult{
RepoName: repoName,
Revision: api.CommitID(revision),
FileName: fileName,
StartLine: int(cMap["start_line"].(float64)),
EndLine: int(cMap["end_line"].(float64)),
ScoreDetails: embeddings.SearchScoreDetails{
Score: similarity,
SimilarityScore: similarity,
},
})
}
commit := api.CommitID(revision)
if commit == "" {
w.logger.Warn("no revision set for an embedded repository", log.Int("repoid", int(params.RepoID)))
commit = api.CommitID("HEAD")
}
return filterAndHydrateContent(ctx, w.logger, params.RepoName, commit, w.readFile, false, srs)
return srs
}
// We partition the indexes by type and repository. Each class in
// weaviate is its own index, so we achieve partitioning by a class
// per repo and type.
codeClass := fmt.Sprintf("Code_%d", params.RepoID)
textClass := fmt.Sprintf("Text_%d", params.RepoID)
codeClass := fmt.Sprintf("Code_%d", repoID)
textClass := fmt.Sprintf("Text_%d", repoID)
res, err := w.client.GraphQL().MultiClassGet().
AddQueryClass(queryBuilder(codeClass, params.CodeResultsCount)).
AddQueryClass(queryBuilder(textClass, params.TextResultsCount)).
AddQueryClass(queryBuilder(codeClass, codeResultsCount)).
AddQueryClass(queryBuilder(textClass, textResultsCount)).
Do(ctx)
if err != nil {
return nil, errors.Wrap(err, "doing weaviate request")
return nil, nil, errors.Wrap(err, "doing weaviate request")
}
if len(res.Errors) > 0 {
return nil, weaviateGraphQLError(res.Errors)
return nil, nil, weaviateGraphQLError(res.Errors)
}
return &embeddings.EmbeddingSearchResults{
CodeResults: extractResults(res, codeClass),
TextResults: extractResults(res, textClass),
}, nil
return extractResults(res, codeClass), extractResults(res, textClass), nil
}
type weaviateGraphQLError []*models.GraphQLError

View File

@ -26,8 +26,7 @@ func Init(
repoEmbeddingsStore := repo.NewRepoEmbeddingJobsStore(db)
contextDetectionEmbeddingsStore := contextdetection.NewContextDetectionEmbeddingJobsStore(db)
gitserverClient := gitserver.NewClient()
embeddingsClient := embeddings.NewClient()
embeddingsClient := embeddings.NewDefaultClient()
enterpriseServices.EmbeddingsResolver = resolvers.NewResolver(
db,
observationCtx.Logger,

View File

@ -1,4 +1,4 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
go_library(
name = "resolvers",
@ -17,6 +17,7 @@ go_library(
"//enterprise/internal/embeddings/background/repo",
"//internal/api",
"//internal/auth",
"//internal/authz",
"//internal/cody",
"//internal/conf",
"//internal/database",
@ -26,6 +27,31 @@ go_library(
"//lib/errors",
"@com_github_graph_gophers_graphql_go//:graphql-go",
"@com_github_graph_gophers_graphql_go//relay",
"@com_github_sourcegraph_conc//pool",
"@com_github_sourcegraph_log//:log",
],
)
go_test(
name = "resolvers_test",
srcs = ["resolvers_test.go"],
embed = [":resolvers"],
deps = [
"//cmd/frontend/graphqlbackend",
"//enterprise/internal/embeddings",
"//enterprise/internal/embeddings/background/contextdetection",
"//enterprise/internal/embeddings/background/repo",
"//internal/actor",
"//internal/api",
"//internal/authz",
"//internal/conf",
"//internal/database",
"//internal/featureflag",
"//internal/gitserver",
"//internal/types",
"//schema",
"@com_github_graph_gophers_graphql_go//:graphql-go",
"@com_github_sourcegraph_log//logtest",
"@com_github_stretchr_testify//require",
],
)

View File

@ -1,8 +1,12 @@
package resolvers
import (
"bytes"
"context"
"os"
"github.com/graph-gophers/graphql-go"
"github.com/sourcegraph/conc/pool"
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/lib/errors"
@ -15,6 +19,7 @@ import (
repobg "github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/repo"
"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/internal/auth"
"github.com/sourcegraph/sourcegraph/internal/authz"
"github.com/sourcegraph/sourcegraph/internal/cody"
"github.com/sourcegraph/sourcegraph/internal/conf"
"github.com/sourcegraph/sourcegraph/internal/database"
@ -25,7 +30,7 @@ func NewResolver(
db database.DB,
logger log.Logger,
gitserverClient gitserver.Client,
embeddingsClient *embeddings.Client,
embeddingsClient embeddings.Client,
repoStore repobg.RepoEmbeddingJobsStore,
contextDetectionStore contextdetectionbg.ContextDetectionEmbeddingJobsStore,
) graphqlbackend.EmbeddingsResolver {
@ -43,13 +48,22 @@ type Resolver struct {
db database.DB
logger log.Logger
gitserverClient gitserver.Client
embeddingsClient *embeddings.Client
embeddingsClient embeddings.Client
repoEmbeddingJobsStore repobg.RepoEmbeddingJobsStore
contextDetectionJobsStore contextdetectionbg.ContextDetectionEmbeddingJobsStore
emails backend.UserEmailsService
}
// EmbeddingsSearch runs an embeddings search against a single repository. It
// is a thin adapter: the single repo is wrapped in a one-element list and the
// request is handed to EmbeddingsMultiSearch unchanged otherwise.
func (r *Resolver) EmbeddingsSearch(ctx context.Context, args graphqlbackend.EmbeddingsSearchInputArgs) (graphqlbackend.EmbeddingsSearchResultsResolver, error) {
	multiArgs := graphqlbackend.EmbeddingsMultiSearchInputArgs{
		Repos:            []graphql.ID{args.Repo},
		Query:            args.Query,
		CodeResultsCount: args.CodeResultsCount,
		TextResultsCount: args.TextResultsCount,
	}
	return r.EmbeddingsMultiSearch(ctx, multiArgs)
}
func (r *Resolver) EmbeddingsMultiSearch(ctx context.Context, args graphqlbackend.EmbeddingsMultiSearchInputArgs) (graphqlbackend.EmbeddingsSearchResultsResolver, error) {
if !conf.EmbeddingsEnabled() {
return nil, errors.New("embeddings are not configured or disabled")
}
@ -62,19 +76,28 @@ func (r *Resolver) EmbeddingsSearch(ctx context.Context, args graphqlbackend.Emb
return nil, err
}
repoID, err := graphqlbackend.UnmarshalRepositoryID(args.Repo)
repoIDs := make([]api.RepoID, len(args.Repos))
for i, repo := range args.Repos {
repoID, err := graphqlbackend.UnmarshalRepositoryID(repo)
if err != nil {
return nil, err
}
repoIDs[i] = repoID
}
repos, err := r.db.Repos().GetByIDs(ctx, repoIDs...)
if err != nil {
return nil, err
}
repo, err := r.db.Repos().Get(ctx, repoID)
if err != nil {
return nil, err
repoNames := make([]api.RepoName, len(repos))
for i, repo := range repos {
repoNames[i] = repo.Name
}
results, err := r.embeddingsClient.Search(ctx, embeddings.EmbeddingsSearchParameters{
RepoName: repo.Name,
RepoID: repoID,
RepoNames: repoNames,
RepoIDs: repoIDs,
Query: args.Query,
CodeResultsCount: int(args.CodeResultsCount),
TextResultsCount: int(args.TextResultsCount),
@ -83,7 +106,11 @@ func (r *Resolver) EmbeddingsSearch(ctx context.Context, args graphqlbackend.Emb
return nil, err
}
return &embeddingsSearchResultsResolver{results}, nil
return &embeddingsSearchResultsResolver{
results: results,
gitserver: r.gitserverClient,
logger: r.logger,
}, nil
}
func (r *Resolver) IsContextRequiredForChatQuery(ctx context.Context, args graphqlbackend.IsContextRequiredForChatQueryInputArgs) (bool, error) {
@ -183,27 +210,91 @@ func (r *Resolver) ScheduleContextDetectionForEmbedding(ctx context.Context) (*g
}
type embeddingsSearchResultsResolver struct {
results *embeddings.EmbeddingSearchResults
results *embeddings.EmbeddingCombinedSearchResults
gitserver gitserver.Client
logger log.Logger
}
func (r *embeddingsSearchResultsResolver) CodeResults(ctx context.Context) []graphqlbackend.EmbeddingsSearchResultResolver {
codeResults := make([]graphqlbackend.EmbeddingsSearchResultResolver, len(r.results.CodeResults))
for idx, result := range r.results.CodeResults {
codeResults[idx] = &embeddingsSearchResultResolver{result}
}
return codeResults
// CodeResults converts the code matches of the search into resolvers via
// embeddingsSearchResultsToResolvers.
func (r *embeddingsSearchResultsResolver) CodeResults(ctx context.Context) ([]graphqlbackend.EmbeddingsSearchResultResolver, error) {
	codeMatches := r.results.CodeResults
	return embeddingsSearchResultsToResolvers(ctx, r.logger, r.gitserver, codeMatches)
}
func (r *embeddingsSearchResultsResolver) TextResults(ctx context.Context) []graphqlbackend.EmbeddingsSearchResultResolver {
textResults := make([]graphqlbackend.EmbeddingsSearchResultResolver, len(r.results.TextResults))
for idx, result := range r.results.TextResults {
textResults[idx] = &embeddingsSearchResultResolver{result}
// TextResults converts the text matches of the search into resolvers via
// embeddingsSearchResultsToResolvers.
func (r *embeddingsSearchResultsResolver) TextResults(ctx context.Context) ([]graphqlbackend.EmbeddingsSearchResultResolver, error) {
	textMatches := r.results.TextResults
	return embeddingsSearchResultsToResolvers(ctx, r.logger, r.gitserver, textMatches)
}
// embeddingsSearchResultsToResolvers converts raw embedding search results
// into GraphQL resolvers, attaching to each one the text of its matched line
// range.
//
// The file behind every result is read through gs.ReadFile using
// authz.DefaultSubRepoPermsChecker, with the reads fanned out over a pool
// limited to 8 goroutines. Results whose file could not be read are left out
// of the returned slice; the read error is logged unless os.IsNotExist
// classifies it as a missing file. The returned error is always nil.
func embeddingsSearchResultsToResolvers(
ctx context.Context,
logger log.Logger,
gs gitserver.Client,
results []embeddings.EmbeddingSearchResult,
) ([]graphqlbackend.EmbeddingsSearchResultResolver, error) {
// allContents[i] and allErrors[i] hold the outcome of reading results[i].
// Each goroutine writes only to its own index, so no locking is needed.
allContents := make([][]byte, len(results))
allErrors := make([]error, len(results))
{ // Fetch contents in parallel because fetching them serially can be slow.
p := pool.New().WithMaxGoroutines(8)
for i, result := range results {
// Shadow the loop variables so every closure captures its own copy.
i, result := i, result
p.Go(func() {
content, err := gs.ReadFile(ctx, authz.DefaultSubRepoPermsChecker, result.RepoName, result.Revision, result.FileName)
allContents[i] = content
allErrors[i] = err
})
}
p.Wait()
}
return textResults
// Capacity is len(results), but the final length can be smaller because
// unreadable files are skipped below.
resolvers := make([]graphqlbackend.EmbeddingsSearchResultResolver, 0, len(results))
{ // Merge the results with their contents, skipping any that errored when fetching the content.
for i, result := range results {
contents := allContents[i]
err := allErrors[i]
if err != nil {
// Missing files are skipped silently; any other read error is logged.
// NOTE(review): os.IsNotExist does not unwrap wrapped errors, whereas
// errors.Is(err, os.ErrNotExist) would; confirm which form ReadFile returns.
if !os.IsNotExist(err) {
logger.Error(
"error reading file",
log.String("repoName", string(result.RepoName)),
log.String("revision", string(result.Revision)),
log.String("fileName", result.FileName),
log.Error(err),
)
}
continue
}
// Only the [StartLine, EndLine) slice of the file is exposed as content.
resolvers = append(resolvers, &embeddingsSearchResultResolver{
result: result,
content: string(extractLineRange(contents, result.StartLine, result.EndLine)),
})
}
}
return resolvers, nil
}
// extractLineRange returns the lines of content in the half-open range
// [startLine, endLine), joined with "\n". Lines are zero-indexed and content
// is split on "\n" only.
//
// Out-of-range bounds are clamped to [0, number of lines]. An inverted range
// (endLine before startLine) yields empty output rather than panicking on an
// invalid slice expression.
func extractLineRange(content []byte, startLine, endLine int) []byte {
	lines := bytes.Split(content, []byte("\n"))

	// Sanity check: keep both bounds within [0, len(lines)], and never let the
	// end precede the start, because lines[start:end] panics when start > end.
	startLine = clamp(startLine, 0, len(lines))
	endLine = clamp(endLine, startLine, len(lines))

	return bytes.Join(lines[startLine:endLine], []byte("\n"))
}
// clamp limits input to the range [min, max]. The upper bound is checked
// first, so max wins whenever input exceeds it, even if min > max.
func clamp(input, min, max int) int {
	switch {
	case input > max:
		return max
	case input < min:
		return min
	default:
		return input
	}
}
type embeddingsSearchResultResolver struct {
result embeddings.EmbeddingSearchResult
result embeddings.EmbeddingSearchResult
content string
}
func (r *embeddingsSearchResultResolver) RepoName(ctx context.Context) string {
@ -227,5 +318,5 @@ func (r *embeddingsSearchResultResolver) EndLine(ctx context.Context) int32 {
}
func (r *embeddingsSearchResultResolver) Content(ctx context.Context) string {
return r.result.Content
return r.content
}

View File

@ -0,0 +1,122 @@
package resolvers
import (
"context"
"os"
"testing"
"github.com/graph-gophers/graphql-go"
"github.com/sourcegraph/log/logtest"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/contextdetection"
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/repo"
"github.com/sourcegraph/sourcegraph/internal/actor"
"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/internal/authz"
"github.com/sourcegraph/sourcegraph/internal/conf"
"github.com/sourcegraph/sourcegraph/internal/database"
"github.com/sourcegraph/sourcegraph/internal/featureflag"
"github.com/sourcegraph/sourcegraph/internal/gitserver"
"github.com/sourcegraph/sourcegraph/internal/types"
"github.com/sourcegraph/sourcegraph/schema"
"github.com/stretchr/testify/require"
)
// TestEmbeddingSearchResolver checks that EmbeddingsMultiSearch hydrates
// search results with file content read from gitserver, and that results
// whose file cannot be read are dropped from the response.
func TestEmbeddingSearchResolver(t *testing.T) {
	logger := logtest.Scoped(t)

	mockDB := database.NewMockDB()
	mockRepos := database.NewMockRepoStore()
	mockRepos.GetByIDsFunc.SetDefaultReturn([]*types.Repo{{ID: 1, Name: "repo1"}}, nil)
	mockDB.ReposFunc.SetDefaultReturn(mockRepos)

	// Only "testfile" can be read; every other path is reported as missing.
	mockGitserver := gitserver.NewMockClient()
	mockGitserver.ReadFileFunc.SetDefaultHook(func(_ context.Context, _ authz.SubRepoPermissionChecker, _ api.RepoName, _ api.CommitID, fileName string) ([]byte, error) {
		if fileName == "testfile" {
			return []byte("test\nfirst\nfour\nlines\nplus\nsome\nmore"), nil
		}
		return nil, os.ErrNotExist
	})

	// Two code results come back from the embeddings service: one readable,
	// one ("censored") that the gitserver mock refuses to serve.
	mockEmbeddingsClient := embeddings.NewMockClient()
	mockEmbeddingsClient.SearchFunc.SetDefaultReturn(&embeddings.EmbeddingCombinedSearchResults{
		CodeResults: embeddings.EmbeddingSearchResults{{
			FileName:  "testfile",
			StartLine: 0,
			EndLine:   4,
		}, {
			FileName:  "censored",
			StartLine: 0,
			EndLine:   4,
		}},
	}, nil)

	repoEmbeddingJobsStore := repo.NewMockRepoEmbeddingJobsStore()
	contextDetectionJobsStore := contextdetection.NewMockContextDetectionEmbeddingJobsStore()

	resolver := NewResolver(
		mockDB,
		logger,
		mockGitserver,
		mockEmbeddingsClient,
		repoEmbeddingJobsStore,
		contextDetectionJobsStore,
	)

	conf.Mock(&conf.Unified{
		SiteConfiguration: schema.SiteConfiguration{
			Embeddings:  &schema.Embeddings{Enabled: true},
			Completions: &schema.Completions{Enabled: true},
		},
	})
	// Reset the mocked site configuration so it does not leak into other
	// tests that run in the same package.
	t.Cleanup(func() { conf.Mock(nil) })

	ctx := actor.WithActor(context.Background(), actor.FromMockUser(1))
	ffs := featureflag.NewMemoryStore(map[string]bool{"cody-experimental": true}, nil, nil)
	ctx = featureflag.WithFlags(ctx, ffs)

	results, err := resolver.EmbeddingsMultiSearch(ctx, graphqlbackend.EmbeddingsMultiSearchInputArgs{
		Repos:            []graphql.ID{graphqlbackend.MarshalRepositoryID(3)},
		Query:            "test",
		CodeResultsCount: 1,
		TextResultsCount: 1,
	})
	require.NoError(t, err)

	codeResults, err := results.CodeResults(ctx)
	require.NoError(t, err)

	// The unreadable "censored" result is skipped, leaving only "testfile",
	// whose content is limited to lines [0, 4).
	require.Len(t, codeResults, 1)
	require.Equal(t, "test\nfirst\nfour\nlines", codeResults[0].Content(ctx))
}
// Test_extractLineRange exercises extractLineRange on in-range requests, on
// bounds that must be clamped, and on empty input. Each case is a named
// subtest so a failure identifies the scenario directly.
func Test_extractLineRange(t *testing.T) {
	// Splitting on "\n" yields five elements: the four words plus a trailing
	// empty line.
	const fourLines = "zero\none\ntwo\nthree\n"

	cases := []struct {
		name       string
		input      []byte
		start, end int
		output     []byte
	}{{
		name:   "first two lines",
		input:  []byte(fourLines),
		start:  0,
		end:    2,
		output: []byte("zero\none"),
	}, {
		name:   "single middle line",
		input:  []byte(fourLines),
		start:  1,
		end:    2,
		output: []byte("one"),
	}, {
		name:   "end past last line is clamped",
		input:  []byte(fourLines),
		start:  2,
		end:    100,
		output: []byte("two\nthree\n"),
	}, {
		name:   "negative start is clamped",
		input:  []byte(fourLines),
		start:  -3,
		end:    1,
		output: []byte("zero"),
	}, {
		name:   "empty input",
		input:  []byte(""),
		start:  1,
		end:    2,
		output: []byte(""),
	}}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := extractLineRange(tc.input, tc.start, tc.end)
			require.Equal(t, tc.output, got)
		})
	}
}

View File

@ -12,6 +12,7 @@ go_library(
"dot_portable.go",
"index_name.go",
"index_storage.go",
"mocks_temp.go",
"quantize.go",
"similarity_search.go",
"tokens.go",
@ -31,6 +32,7 @@ go_library(
"//internal/uploadstore",
"//lib/errors",
"@com_github_sourcegraph_conc//:conc",
"@com_github_sourcegraph_conc//pool",
"@com_github_sourcegraph_log//:log",
"@org_golang_x_sync//errgroup",
] + select({

View File

@ -3,6 +3,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "contextdetection",
srcs = [
"mocks_temp.go",
"store.go",
"types.go",
],

View File

@ -0,0 +1,292 @@
// Code generated by go-mockgen 1.3.7; DO NOT EDIT.
//
// This file was generated by running `sg generate` (or `go-mockgen`) at the root of
// this repository. To add additional mocks to this or another package, add a new entry
// to the mockgen.yaml file in the root of this repository.
package contextdetection
import (
"context"
"sync"
basestore "github.com/sourcegraph/sourcegraph/internal/database/basestore"
)
// MockContextDetectionEmbeddingJobsStore is a mock implementation of the
// ContextDetectionEmbeddingJobsStore interface (from the package
// github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/contextdetection)
// used for unit testing.
type MockContextDetectionEmbeddingJobsStore struct {
// CreateContextDetectionEmbeddingJobFunc is an instance of a mock
// function object controlling the behavior of the method
// CreateContextDetectionEmbeddingJob.
CreateContextDetectionEmbeddingJobFunc *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc
// HandleFunc is an instance of a mock function object controlling the
// behavior of the method Handle.
HandleFunc *ContextDetectionEmbeddingJobsStoreHandleFunc
}
// NewMockContextDetectionEmbeddingJobsStore creates a new mock of the
// ContextDetectionEmbeddingJobsStore interface. All methods return zero
// values for all results, unless overwritten.
func NewMockContextDetectionEmbeddingJobsStore() *MockContextDetectionEmbeddingJobsStore {
return &MockContextDetectionEmbeddingJobsStore{
CreateContextDetectionEmbeddingJobFunc: &ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc{
defaultHook: func(context.Context) (r0 int, r1 error) {
return
},
},
HandleFunc: &ContextDetectionEmbeddingJobsStoreHandleFunc{
defaultHook: func() (r0 basestore.TransactableHandle) {
return
},
},
}
}
// NewStrictMockContextDetectionEmbeddingJobsStore creates a new mock of the
// ContextDetectionEmbeddingJobsStore interface. All methods panic on
// invocation, unless overwritten.
func NewStrictMockContextDetectionEmbeddingJobsStore() *MockContextDetectionEmbeddingJobsStore {
return &MockContextDetectionEmbeddingJobsStore{
CreateContextDetectionEmbeddingJobFunc: &ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc{
defaultHook: func(context.Context) (int, error) {
panic("unexpected invocation of MockContextDetectionEmbeddingJobsStore.CreateContextDetectionEmbeddingJob")
},
},
HandleFunc: &ContextDetectionEmbeddingJobsStoreHandleFunc{
defaultHook: func() basestore.TransactableHandle {
panic("unexpected invocation of MockContextDetectionEmbeddingJobsStore.Handle")
},
},
}
}
// NewMockContextDetectionEmbeddingJobsStoreFrom creates a new mock of the
// MockContextDetectionEmbeddingJobsStore interface. All methods delegate to
// the given implementation, unless overwritten.
func NewMockContextDetectionEmbeddingJobsStoreFrom(i ContextDetectionEmbeddingJobsStore) *MockContextDetectionEmbeddingJobsStore {
return &MockContextDetectionEmbeddingJobsStore{
CreateContextDetectionEmbeddingJobFunc: &ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc{
defaultHook: i.CreateContextDetectionEmbeddingJob,
},
HandleFunc: &ContextDetectionEmbeddingJobsStoreHandleFunc{
defaultHook: i.Handle,
},
}
}
// ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc
// describes the behavior when the CreateContextDetectionEmbeddingJob method
// of the parent MockContextDetectionEmbeddingJobsStore instance is invoked.
type ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc struct {
defaultHook func(context.Context) (int, error)
hooks []func(context.Context) (int, error)
history []ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall
mutex sync.Mutex
}
// CreateContextDetectionEmbeddingJob delegates to the next hook function in
// the queue and stores the parameter and result values of this invocation.
func (m *MockContextDetectionEmbeddingJobsStore) CreateContextDetectionEmbeddingJob(v0 context.Context) (int, error) {
r0, r1 := m.CreateContextDetectionEmbeddingJobFunc.nextHook()(v0)
m.CreateContextDetectionEmbeddingJobFunc.appendCall(ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall{v0, r0, r1})
return r0, r1
}
// SetDefaultHook sets function that is called when the
// CreateContextDetectionEmbeddingJob method of the parent
// MockContextDetectionEmbeddingJobsStore instance is invoked and the hook
// queue is empty.
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) SetDefaultHook(hook func(context.Context) (int, error)) {
f.defaultHook = hook
}
// PushHook adds a function to the end of hook queue. Each invocation of the
// CreateContextDetectionEmbeddingJob method of the parent
// MockContextDetectionEmbeddingJobsStore instance invokes the hook at the
// front of the queue and discards it. After the queue is empty, the default
// hook function is invoked for any future action.
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) PushHook(hook func(context.Context) (int, error)) {
f.mutex.Lock()
f.hooks = append(f.hooks, hook)
f.mutex.Unlock()
}
// SetDefaultReturn calls SetDefaultHook with a function that returns the
// given values.
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) SetDefaultReturn(r0 int, r1 error) {
f.SetDefaultHook(func(context.Context) (int, error) {
return r0, r1
})
}
// PushReturn calls PushHook with a function that returns the given values.
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) PushReturn(r0 int, r1 error) {
f.PushHook(func(context.Context) (int, error) {
return r0, r1
})
}
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) nextHook() func(context.Context) (int, error) {
f.mutex.Lock()
defer f.mutex.Unlock()
if len(f.hooks) == 0 {
return f.defaultHook
}
hook := f.hooks[0]
f.hooks = f.hooks[1:]
return hook
}
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) appendCall(r0 ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall) {
f.mutex.Lock()
f.history = append(f.history, r0)
f.mutex.Unlock()
}
// History returns a sequence of
// ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall
// objects describing the invocations of this function.
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) History() []ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall {
f.mutex.Lock()
history := make([]ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall, len(f.history))
copy(history, f.history)
f.mutex.Unlock()
return history
}
// ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall
// is an object that describes an invocation of method
// CreateContextDetectionEmbeddingJob on an instance of
// MockContextDetectionEmbeddingJobsStore.
type ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall struct {
// Arg0 is the value of the 1st argument passed to this method
// invocation.
Arg0 context.Context
// Result0 is the value of the 1st result returned from this method
// invocation.
Result0 int
// Result1 is the value of the 2nd result returned from this method
// invocation.
Result1 error
}
// Args returns an interface slice containing the arguments of this
// invocation.
func (c ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall) Args() []interface{} {
return []interface{}{c.Arg0}
}
// Results returns an interface slice containing the results of this
// invocation.
func (c ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall) Results() []interface{} {
return []interface{}{c.Result0, c.Result1}
}
// ContextDetectionEmbeddingJobsStoreHandleFunc describes the behavior when
// the Handle method of the parent MockContextDetectionEmbeddingJobsStore
// instance is invoked.
type ContextDetectionEmbeddingJobsStoreHandleFunc struct {
defaultHook func() basestore.TransactableHandle
hooks []func() basestore.TransactableHandle
history []ContextDetectionEmbeddingJobsStoreHandleFuncCall
mutex sync.Mutex
}
// Handle delegates to the next hook function in the queue and stores the
// parameter and result values of this invocation.
func (m *MockContextDetectionEmbeddingJobsStore) Handle() basestore.TransactableHandle {
r0 := m.HandleFunc.nextHook()()
m.HandleFunc.appendCall(ContextDetectionEmbeddingJobsStoreHandleFuncCall{r0})
return r0
}
// SetDefaultHook sets function that is called when the Handle method of the
// parent MockContextDetectionEmbeddingJobsStore instance is invoked and the
// hook queue is empty.
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) SetDefaultHook(hook func() basestore.TransactableHandle) {
f.defaultHook = hook
}
// PushHook adds a function to the end of hook queue. Each invocation of the
// Handle method of the parent MockContextDetectionEmbeddingJobsStore
// instance invokes the hook at the front of the queue and discards it.
// After the queue is empty, the default hook function is invoked for any
// future action.
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) PushHook(hook func() basestore.TransactableHandle) {
f.mutex.Lock()
f.hooks = append(f.hooks, hook)
f.mutex.Unlock()
}
// SetDefaultReturn calls SetDefaultHook with a function that returns the
// given values.
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) SetDefaultReturn(r0 basestore.TransactableHandle) {
f.SetDefaultHook(func() basestore.TransactableHandle {
return r0
})
}
// PushReturn calls PushHook with a function that returns the given values.
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) PushReturn(r0 basestore.TransactableHandle) {
f.PushHook(func() basestore.TransactableHandle {
return r0
})
}
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) nextHook() func() basestore.TransactableHandle {
f.mutex.Lock()
defer f.mutex.Unlock()
if len(f.hooks) == 0 {
return f.defaultHook
}
hook := f.hooks[0]
f.hooks = f.hooks[1:]
return hook
}
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) appendCall(r0 ContextDetectionEmbeddingJobsStoreHandleFuncCall) {
f.mutex.Lock()
f.history = append(f.history, r0)
f.mutex.Unlock()
}
// History returns a sequence of
// ContextDetectionEmbeddingJobsStoreHandleFuncCall objects describing the
// invocations of this function.
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) History() []ContextDetectionEmbeddingJobsStoreHandleFuncCall {
f.mutex.Lock()
history := make([]ContextDetectionEmbeddingJobsStoreHandleFuncCall, len(f.history))
copy(history, f.history)
f.mutex.Unlock()
return history
}
// ContextDetectionEmbeddingJobsStoreHandleFuncCall is an object that
// describes an invocation of method Handle on an instance of
// MockContextDetectionEmbeddingJobsStore.
type ContextDetectionEmbeddingJobsStoreHandleFuncCall struct {
// Result0 is the value of the 1st result returned from this method
// invocation.
Result0 basestore.TransactableHandle
}
// Args returns an interface slice containing the arguments of this
// invocation.
func (c ContextDetectionEmbeddingJobsStoreHandleFuncCall) Args() []interface{} {
return []interface{}{}
}
// Results returns an interface slice containing the results of this
// invocation.
func (c ContextDetectionEmbeddingJobsStoreHandleFuncCall) Results() []interface{} {
return []interface{}{c.Result0}
}

View File

@ -8,6 +8,7 @@ import (
"net/http"
"strings"
"github.com/sourcegraph/conc/pool"
"github.com/sourcegraph/sourcegraph/lib/errors"
"github.com/sourcegraph/sourcegraph/internal/api"
@ -30,14 +31,23 @@ var defaultDoer = func() httpcli.Doer {
return d
}()
func NewClient() *Client {
return &Client{
Endpoints: defaultEndpoints(),
HTTPClient: defaultDoer,
// NewDefaultClient builds a Client from the package defaults: the endpoints
// returned by defaultEndpoints and the shared defaultDoer.
func NewDefaultClient() Client {
	endpoints, doer := defaultEndpoints(), defaultDoer
	return NewClient(endpoints, doer)
}
// NewClient builds a Client that sends its requests to the given endpoints
// using the given HTTP doer.
func NewClient(endpoints *endpoint.Map, doer httpcli.Doer) Client {
	c := new(client)
	c.Endpoints = endpoints
	c.HTTPClient = doer
	return c
}
type Client struct {
// Client is the interface to the embeddings service. It is implemented by the
// unexported client type in this package.
type Client interface {
// Search runs an embeddings search for the repos and query in the given
// parameters and returns the combined code and text results.
Search(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error)
// IsContextRequiredForChatQuery returns the isRequired answer that the
// embeddings service gives for the supplied chat query parameters.
IsContextRequiredForChatQuery(context.Context, IsContextRequiredForChatQueryParameters) (bool, error)
}
type client struct {
// Endpoints to embeddings service.
Endpoints *endpoint.Map
@ -46,15 +56,13 @@ type Client struct {
}
type EmbeddingsSearchParameters struct {
RepoName api.RepoName `json:"repoName"`
RepoID api.RepoID `json:"repoID"`
Query string `json:"query"`
CodeResultsCount int `json:"codeResultsCount"`
TextResultsCount int `json:"textResultsCount"`
RepoNames []api.RepoName `json:"repoNames"`
RepoIDs []api.RepoID `json:"repoIDs"`
Query string `json:"query"`
CodeResultsCount int `json:"codeResultsCount"`
TextResultsCount int `json:"textResultsCount"`
UseDocumentRanks bool `json:"useDocumentRanks"`
// If set to "True", EmbeddingSearchResult.Debug will contain useful information about scoring.
Debug bool `json:"debug"`
}
type IsContextRequiredForChatQueryParameters struct {
@ -65,8 +73,43 @@ type IsContextRequiredForChatQueryResult struct {
IsRequired bool `json:"isRequired"`
}
func (c *Client) Search(ctx context.Context, args EmbeddingsSearchParameters) (*EmbeddingSearchResults, error) {
resp, err := c.httpPost(ctx, "search", args.RepoName, args)
// Search fans an embeddings search out to every endpoint that serves one of
// the requested repos and merges the per-endpoint answers.
//
// The repos are first grouped by endpoint (see partition). Each group is then
// searched concurrently with a copy of args narrowed to that group's repos.
// If partitioning or any request fails, the error is returned and no results
// are produced. Otherwise the per-endpoint results are merged, with code and
// text results truncated to args.CodeResultsCount and args.TextResultsCount.
func (c *client) Search(ctx context.Context, args EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
	byEndpoint, err := c.partition(args.RepoNames, args.RepoIDs)
	if err != nil {
		return nil, err
	}

	searches := pool.NewWithResults[*EmbeddingCombinedSearchResults]().WithContext(ctx)
	for target, group := range byEndpoint {
		// Per-iteration copies: the closure below must not observe later
		// iterations, and the caller's args must stay untouched.
		target := target
		groupArgs := args
		groupArgs.RepoNames = group.repoNames
		groupArgs.RepoIDs = group.repoIDs

		searches.Go(func(ctx context.Context) (*EmbeddingCombinedSearchResults, error) {
			return c.searchPartition(ctx, target, groupArgs)
		})
	}

	partials, err := searches.Wait()
	if err != nil {
		return nil, err
	}

	merged := new(EmbeddingCombinedSearchResults)
	for _, partial := range partials {
		merged.CodeResults.MergeTruncate(partial.CodeResults, args.CodeResultsCount)
		merged.TextResults.MergeTruncate(partial.TextResults, args.TextResultsCount)
	}
	return merged, nil
}
func (c *client) searchPartition(ctx context.Context, endpoint string, args EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
resp, err := c.httpPost(ctx, "search", endpoint, args)
if err != nil {
return nil, err
}
@ -82,7 +125,7 @@ func (c *Client) Search(ctx context.Context, args EmbeddingsSearchParameters) (*
)
}
var response EmbeddingSearchResults
var response EmbeddingCombinedSearchResults
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
return nil, err
@ -90,8 +133,13 @@ func (c *Client) Search(ctx context.Context, args EmbeddingsSearchParameters) (*
return &response, nil
}
func (c *Client) IsContextRequiredForChatQuery(ctx context.Context, args IsContextRequiredForChatQueryParameters) (bool, error) {
resp, err := c.httpPost(ctx, "isContextRequiredForChatQuery", "", args)
func (c *client) IsContextRequiredForChatQuery(ctx context.Context, args IsContextRequiredForChatQueryParameters) (bool, error) {
endpoint, err := c.url("")
if err != nil {
return false, err
}
resp, err := c.httpPost(ctx, "isContextRequiredForChatQuery", endpoint, args)
if err != nil {
return false, err
}
@ -115,24 +163,50 @@ func (c *Client) IsContextRequiredForChatQuery(ctx context.Context, args IsConte
return response.IsRequired, nil
}
func (c *Client) url(repo api.RepoName) (string, error) {
// url looks up the endpoint for repo in the configured endpoint map. It fails
// when no endpoint map has been configured on the client.
func (c *client) url(repo api.RepoName) (string, error) {
	if endpoints := c.Endpoints; endpoints != nil {
		return endpoints.Get(string(repo))
	}
	return "", errors.New("an embeddings service has not been configured")
}
func (c *Client) httpPost(
ctx context.Context,
method string,
repo api.RepoName,
payload any,
) (resp *http.Response, err error) {
url, err := c.url(repo)
// repoPartition is the group of repos whose requests go to one embeddings
// endpoint. partition appends to repoNames and repoIDs in lockstep, so
// position i in both slices comes from the same input position.
type repoPartition struct {
repoNames []api.RepoName
repoIDs []api.RepoID
}
// partition groups the input repos by the endpoint their requests should be
// routed to. The result maps each endpoint to the names and IDs of the repos
// it serves.
//
// repos and repoIDs are parallel slices: repoIDs[i] must be the ID of
// repos[i]. An error is returned when no embeddings service is configured,
// when the two slices differ in length, or when the endpoint lookup fails.
func (c *client) partition(repos []api.RepoName, repoIDs []api.RepoID) (map[string]repoPartition, error) {
	if c.Endpoints == nil {
		return nil, errors.New("an embeddings service has not been configured")
	}

	// The loop below indexes repoIDs by position in repos. Without this guard
	// a shorter repoIDs panics with an index out of range, and a longer one
	// silently drops IDs and mispairs the rest.
	if len(repos) != len(repoIDs) {
		return nil, errors.New("mismatched number of repo names and repo IDs")
	}

	repoStrings := make([]string, len(repos))
	for i, repo := range repos {
		repoStrings[i] = string(repo)
	}

	endpoints, err := c.Endpoints.GetMany(repoStrings...)
	if err != nil {
		return nil, err
	}

	res := make(map[string]repoPartition)
	for i, endpoint := range endpoints {
		// Map values are not addressable, so update a copy and store it back.
		group := res[endpoint]
		group.repoNames = append(group.repoNames, repos[i])
		group.repoIDs = append(group.repoIDs, repoIDs[i])
		res[endpoint] = group
	}
	return res, nil
}
func (c *client) httpPost(
ctx context.Context,
method string,
url string,
payload any,
) (resp *http.Response, err error) {
reqBody, err := json.Marshal(payload)
if err != nil {
return nil, err

View File

@ -0,0 +1,291 @@
// Code generated by go-mockgen 1.3.7; DO NOT EDIT.
//
// This file was generated by running `sg generate` (or `go-mockgen`) at the root of
// this repository. To add additional mocks to this or another package, add a new entry
// to the mockgen.yaml file in the root of this repository.
package embeddings
import (
"context"
"sync"
)
// MockClient is a mock implementation of the Client interface (from the
// package
// github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings) used
// for unit testing.
type MockClient struct {
// IsContextRequiredForChatQueryFunc is an instance of a mock function
// object controlling the behavior of the method
// IsContextRequiredForChatQuery.
IsContextRequiredForChatQueryFunc *ClientIsContextRequiredForChatQueryFunc
// SearchFunc is an instance of a mock function object controlling the
// behavior of the method Search.
SearchFunc *ClientSearchFunc
}
// NewMockClient creates a new mock of the Client interface. All methods
// return zero values for all results, unless overwritten.
func NewMockClient() *MockClient {
return &MockClient{
IsContextRequiredForChatQueryFunc: &ClientIsContextRequiredForChatQueryFunc{
defaultHook: func(context.Context, IsContextRequiredForChatQueryParameters) (r0 bool, r1 error) {
return
},
},
SearchFunc: &ClientSearchFunc{
defaultHook: func(context.Context, EmbeddingsSearchParameters) (r0 *EmbeddingCombinedSearchResults, r1 error) {
return
},
},
}
}
// NewStrictMockClient creates a new mock of the Client interface. All
// methods panic on invocation, unless overwritten.
func NewStrictMockClient() *MockClient {
return &MockClient{
IsContextRequiredForChatQueryFunc: &ClientIsContextRequiredForChatQueryFunc{
defaultHook: func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error) {
panic("unexpected invocation of MockClient.IsContextRequiredForChatQuery")
},
},
SearchFunc: &ClientSearchFunc{
defaultHook: func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
panic("unexpected invocation of MockClient.Search")
},
},
}
}
// NewMockClientFrom creates a new mock of the MockClient interface. All
// methods delegate to the given implementation, unless overwritten.
func NewMockClientFrom(i Client) *MockClient {
return &MockClient{
IsContextRequiredForChatQueryFunc: &ClientIsContextRequiredForChatQueryFunc{
defaultHook: i.IsContextRequiredForChatQuery,
},
SearchFunc: &ClientSearchFunc{
defaultHook: i.Search,
},
}
}
// ClientIsContextRequiredForChatQueryFunc describes the behavior when the
// IsContextRequiredForChatQuery method of the parent MockClient instance is
// invoked. It holds the hooks that decide what an invocation returns and a
// record of every invocation made so far.
type ClientIsContextRequiredForChatQueryFunc struct {
// defaultHook is the function nextHook hands out when the hooks queue is
// empty.
defaultHook func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error)
// hooks is a FIFO queue of one-shot hooks: PushHook appends to the end and
// nextHook removes from the front.
hooks []func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error)
// history records one entry per invocation, appended by appendCall.
history []ClientIsContextRequiredForChatQueryFuncCall
// mutex is held while hooks and history are read or modified.
mutex sync.Mutex
}
// IsContextRequiredForChatQuery runs the next hook in the queue (or the
// default hook when the queue is empty) with the given arguments, records
// the arguments and results in the call history, and returns the results.
func (m *MockClient) IsContextRequiredForChatQuery(v0 context.Context, v1 IsContextRequiredForChatQueryParameters) (bool, error) {
	fn := m.IsContextRequiredForChatQueryFunc
	hook := fn.nextHook()
	required, err := hook(v0, v1)
	fn.appendCall(ClientIsContextRequiredForChatQueryFuncCall{
		Arg0:    v0,
		Arg1:    v1,
		Result0: required,
		Result1: err,
	})
	return required, err
}
// SetDefaultHook replaces the fallback hook: the function that is called
// when the IsContextRequiredForChatQuery method of the parent MockClient
// instance is invoked while the hook queue is empty.
func (f *ClientIsContextRequiredForChatQueryFunc) SetDefaultHook(hook func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error)) {
	(*f).defaultHook = hook
}
// PushHook enqueues a one-shot hook at the end of the hook queue. Each
// invocation of the IsContextRequiredForChatQuery method of the parent
// MockClient instance consumes the hook at the front of the queue; once the
// queue is drained, the default hook serves all further invocations.
func (f *ClientIsContextRequiredForChatQueryFunc) PushHook(hook func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error)) {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	f.hooks = append(f.hooks, hook)
}
// SetDefaultReturn installs, via SetDefaultHook, a default hook that ignores
// its arguments and always returns the given values.
func (f *ClientIsContextRequiredForChatQueryFunc) SetDefaultReturn(r0 bool, r1 error) {
	constant := func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error) {
		return r0, r1
	}
	f.SetDefaultHook(constant)
}
// PushReturn enqueues, via PushHook, a one-shot hook that ignores its
// arguments and returns the given values.
func (f *ClientIsContextRequiredForChatQueryFunc) PushReturn(r0 bool, r1 error) {
	constant := func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error) {
		return r0, r1
	}
	f.PushHook(constant)
}
// nextHook removes and returns the hook at the front of the queue, or
// returns the default hook when the queue is empty.
func (f *ClientIsContextRequiredForChatQueryFunc) nextHook() func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error) {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	if len(f.hooks) > 0 {
		front := f.hooks[0]
		f.hooks = f.hooks[1:]
		return front
	}
	return f.defaultHook
}
// appendCall records one invocation at the end of the call history.
func (f *ClientIsContextRequiredForChatQueryFunc) appendCall(r0 ClientIsContextRequiredForChatQueryFuncCall) {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	f.history = append(f.history, r0)
}
// History returns a copy of the recorded
// ClientIsContextRequiredForChatQueryFuncCall objects, one per invocation of
// this function, in the order the invocations were recorded.
func (f *ClientIsContextRequiredForChatQueryFunc) History() []ClientIsContextRequiredForChatQueryFuncCall {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	// Hand back a copy so callers cannot alias the internal slice.
	snapshot := make([]ClientIsContextRequiredForChatQueryFuncCall, len(f.history))
	copy(snapshot, f.history)
	return snapshot
}
// ClientIsContextRequiredForChatQueryFuncCall is an object that describes
// an invocation of method IsContextRequiredForChatQuery on an instance of
// MockClient: the arguments it received and the results it returned.
//
// NOTE: the field order is significant. The mock method builds this struct
// with a positional (unkeyed) composite literal, so fields must not be
// reordered.
type ClientIsContextRequiredForChatQueryFuncCall struct {
// Arg0 is the value of the 1st argument (the context) passed to this
// method invocation.
Arg0 context.Context
// Arg1 is the value of the 2nd argument (the query parameters) passed to
// this method invocation.
Arg1 IsContextRequiredForChatQueryParameters
// Result0 is the value of the 1st result (the bool) returned from this
// method invocation.
Result0 bool
// Result1 is the value of the 2nd result (the error) returned from this
// method invocation.
Result1 error
}
// Args returns the arguments of this invocation, in positional order, as a
// slice of empty interfaces.
func (c ClientIsContextRequiredForChatQueryFuncCall) Args() []interface{} {
	args := make([]interface{}, 2)
	args[0] = c.Arg0
	args[1] = c.Arg1
	return args
}
// Results returns the results of this invocation, in positional order, as a
// slice of empty interfaces.
func (c ClientIsContextRequiredForChatQueryFuncCall) Results() []interface{} {
	results := make([]interface{}, 2)
	results[0] = c.Result0
	results[1] = c.Result1
	return results
}
// ClientSearchFunc describes the behavior when the Search method of the
// parent MockClient instance is invoked. It holds the hooks that decide
// what an invocation returns and a record of every invocation made so far.
type ClientSearchFunc struct {
// defaultHook is the function nextHook hands out when the hooks queue is
// empty.
defaultHook func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error)
// hooks is a FIFO queue of one-shot hooks: PushHook appends to the end and
// nextHook removes from the front.
hooks []func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error)
// history records one entry per invocation, appended by appendCall.
history []ClientSearchFuncCall
// mutex is held while hooks and history are read or modified.
mutex sync.Mutex
}
// Search runs the next hook in the queue (or the default hook when the
// queue is empty) with the given arguments, records the arguments and
// results in the call history, and returns the results.
func (m *MockClient) Search(v0 context.Context, v1 EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
	fn := m.SearchFunc
	hook := fn.nextHook()
	results, err := hook(v0, v1)
	fn.appendCall(ClientSearchFuncCall{
		Arg0:    v0,
		Arg1:    v1,
		Result0: results,
		Result1: err,
	})
	return results, err
}
// SetDefaultHook replaces the fallback hook: the function that is called
// when the Search method of the parent MockClient instance is invoked while
// the hook queue is empty.
func (f *ClientSearchFunc) SetDefaultHook(hook func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error)) {
	(*f).defaultHook = hook
}
// PushHook enqueues a one-shot hook at the end of the hook queue. Each
// invocation of the Search method of the parent MockClient instance
// consumes the hook at the front of the queue; once the queue is drained,
// the default hook serves all further invocations.
func (f *ClientSearchFunc) PushHook(hook func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error)) {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	f.hooks = append(f.hooks, hook)
}
// SetDefaultReturn installs, via SetDefaultHook, a default hook that ignores
// its arguments and always returns the given values.
func (f *ClientSearchFunc) SetDefaultReturn(r0 *EmbeddingCombinedSearchResults, r1 error) {
	constant := func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
		return r0, r1
	}
	f.SetDefaultHook(constant)
}
// PushReturn enqueues, via PushHook, a one-shot hook that ignores its
// arguments and returns the given values.
func (f *ClientSearchFunc) PushReturn(r0 *EmbeddingCombinedSearchResults, r1 error) {
	constant := func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
		return r0, r1
	}
	f.PushHook(constant)
}
// nextHook removes and returns the hook at the front of the queue, or
// returns the default hook when the queue is empty.
func (f *ClientSearchFunc) nextHook() func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	if len(f.hooks) > 0 {
		front := f.hooks[0]
		f.hooks = f.hooks[1:]
		return front
	}
	return f.defaultHook
}
// appendCall records one invocation at the end of the call history.
func (f *ClientSearchFunc) appendCall(r0 ClientSearchFuncCall) {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	f.history = append(f.history, r0)
}
// History returns a copy of the recorded ClientSearchFuncCall objects, one
// per invocation of this function, in the order the invocations were
// recorded.
func (f *ClientSearchFunc) History() []ClientSearchFuncCall {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	// Hand back a copy so callers cannot alias the internal slice.
	snapshot := make([]ClientSearchFuncCall, len(f.history))
	copy(snapshot, f.history)
	return snapshot
}
// ClientSearchFuncCall is an object that describes an invocation of method
// Search on an instance of MockClient: the arguments it received and the
// results it returned.
//
// NOTE: the field order is significant. The mock method builds this struct
// with a positional (unkeyed) composite literal, so fields must not be
// reordered.
type ClientSearchFuncCall struct {
// Arg0 is the value of the 1st argument (the context) passed to this
// method invocation.
Arg0 context.Context
// Arg1 is the value of the 2nd argument (the search parameters) passed to
// this method invocation.
Arg1 EmbeddingsSearchParameters
// Result0 is the value of the 1st result (the combined search results)
// returned from this method invocation.
Result0 *EmbeddingCombinedSearchResults
// Result1 is the value of the 2nd result (the error) returned from this
// method invocation.
Result1 error
}
// Args returns the arguments of this invocation, in positional order, as a
// slice of empty interfaces.
func (c ClientSearchFuncCall) Args() []interface{} {
	args := make([]interface{}, 2)
	args[0] = c.Arg0
	args[1] = c.Arg1
	return args
}
// Results returns the results of this invocation, in positional order, as a
// slice of empty interfaces.
func (c ClientSearchFuncCall) Results() []interface{} {
	results := make([]interface{}, 2)
	results[0] = c.Result0
	results[1] = c.Result1
	return results
}

View File

@ -2,17 +2,16 @@ package embeddings
import (
"container/heap"
"fmt"
"math"
"sort"
"github.com/sourcegraph/conc"
"github.com/sourcegraph/sourcegraph/internal/api"
)
type nearestNeighbor struct {
index int
score int32
debug searchDebugInfo
index int
scoreDetails SearchScoreDetails
}
type nearestNeighborsHeap struct {
@ -22,7 +21,7 @@ type nearestNeighborsHeap struct {
func (nn *nearestNeighborsHeap) Len() int { return len(nn.neighbors) }
func (nn *nearestNeighborsHeap) Less(i, j int) bool {
return nn.neighbors[i].score < nn.neighbors[j].score
return nn.neighbors[i].scoreDetails.Score < nn.neighbors[j].scoreDetails.Score
}
func (nn *nearestNeighborsHeap) Swap(i, j int) {
@ -81,19 +80,16 @@ type WorkerOptions struct {
MinRowsToSplit int
}
type SimilaritySearchResult struct {
RepoEmbeddingRowMetadata
SimilarityScore int32
RankScore int32
}
func (r *SimilaritySearchResult) Score() int32 {
return r.SimilarityScore + r.RankScore
}
// SimilaritySearch finds the `nResults` most similar rows to a query vector. It uses the cosine similarity metric.
// IMPORTANT: The vectors in the embedding index have to be normalized for similarity search to work correctly.
func (index *EmbeddingIndex) SimilaritySearch(query []int8, numResults int, workerOptions WorkerOptions, opts SearchOptions) []SimilaritySearchResult {
func (index *EmbeddingIndex) SimilaritySearch(
query []int8,
numResults int,
workerOptions WorkerOptions,
opts SearchOptions,
repoName api.RepoName,
revision api.CommitID,
) []EmbeddingSearchResult {
if numResults == 0 || len(index.Embeddings) == 0 {
return nil
}
@ -131,16 +127,20 @@ func (index *EmbeddingIndex) SimilaritySearch(query []int8, numResults int, work
}
}
// And re-sort it according to the score (descending).
sort.Slice(neighbors, func(i, j int) bool { return neighbors[i].score > neighbors[j].score })
sort.Slice(neighbors, func(i, j int) bool { return neighbors[i].scoreDetails.Score > neighbors[j].scoreDetails.Score })
// Take top neighbors and return them as results.
results := make([]SimilaritySearchResult, numResults)
results := make([]EmbeddingSearchResult, numResults)
for idx := 0; idx < min(numResults, len(neighbors)); idx++ {
results[idx] = SimilaritySearchResult{
RepoEmbeddingRowMetadata: index.RowMetadata[neighbors[idx].index],
SimilarityScore: neighbors[idx].debug.similarity,
RankScore: neighbors[idx].debug.rank,
metadata := index.RowMetadata[neighbors[idx].index]
results[idx] = EmbeddingSearchResult{
RepoName: repoName,
Revision: revision,
FileName: metadata.FileName,
StartLine: metadata.StartLine,
EndLine: metadata.EndLine,
ScoreDetails: neighbors[idx].scoreDetails,
}
}
@ -156,17 +156,17 @@ func (index *EmbeddingIndex) partialSimilaritySearch(query []int8, numResults in
nnHeap := newNearestNeighborsHeap()
for i := partialRows.start; i < partialRows.start+numResults; i++ {
score, debugInfo := index.score(query, i, opts)
heap.Push(nnHeap, nearestNeighbor{index: i, score: score, debug: debugInfo})
scoreDetails := index.score(query, i, opts)
heap.Push(nnHeap, nearestNeighbor{index: i, scoreDetails: scoreDetails})
}
for i := partialRows.start + numResults; i < partialRows.end; i++ {
score, debugInfo := index.score(query, i, opts)
scoreDetails := index.score(query, i, opts)
// Add row if it has greater similarity than the smallest similarity in the heap.
// This way we ensure keep a set of the highest similarities in the heap.
if score > nnHeap.Peek().score {
if scoreDetails.Score > nnHeap.Peek().scoreDetails.Score {
heap.Pop(nnHeap)
heap.Push(nnHeap, nearestNeighbor{index: i, score: score, debug: debugInfo})
heap.Push(nnHeap, nearestNeighbor{index: i, scoreDetails: scoreDetails})
}
}
@ -178,7 +178,7 @@ const (
scoreSimilarityWeight int32 = 2
)
func (index *EmbeddingIndex) score(query []int8, i int, opts SearchOptions) (score int32, debugInfo searchDebugInfo) {
func (index *EmbeddingIndex) score(query []int8, i int, opts SearchOptions) SearchScoreDetails {
similarityScore := scoreSimilarityWeight * Dot(index.Row(i), query)
// handle missing ranks
@ -195,20 +195,11 @@ func (index *EmbeddingIndex) score(query []int8, i int, opts SearchOptions) (sco
rankScore = int32(float32(scoreFileRankWeight) * normalizedRank)
}
return similarityScore + rankScore, searchDebugInfo{similarity: similarityScore, rank: rankScore, enabled: opts.Debug}
}
type searchDebugInfo struct {
similarity int32
rank int32
enabled bool
}
func (i *searchDebugInfo) String() string {
if !i.enabled {
return ""
return SearchScoreDetails{
Score: similarityScore + rankScore,
SimilarityScore: similarityScore,
RankScore: rankScore,
}
return fmt.Sprintf("score:%d, similarity:%d, rank:%d", i.similarity+i.rank, i.similarity, i.rank)
}
func min(a, b int) int {
@ -226,6 +217,5 @@ func max(a, b int) int {
}
type SearchOptions struct {
Debug bool
UseDocumentRanks bool
}

View File

@ -64,10 +64,10 @@ func TestSimilaritySearch(t *testing.T) {
for q := 0; q < numQueries; q++ {
t.Run(fmt.Sprintf("find nearest neighbors query=%d numResults=%d numWorkers=%d", q, numResults, numWorkers), func(t *testing.T) {
query := queries[q*columnDimension : (q+1)*columnDimension]
results := index.SimilaritySearch(query, numResults, WorkerOptions{NumWorkers: numWorkers, MinRowsToSplit: 0}, SearchOptions{})
results := index.SimilaritySearch(query, numResults, WorkerOptions{NumWorkers: numWorkers, MinRowsToSplit: 0}, SearchOptions{}, "", "")
resultRowNums := make([]int, len(results))
for i, r := range results {
resultRowNums[i], _ = strconv.Atoi(r.RepoEmbeddingRowMetadata.FileName)
resultRowNums[i], _ = strconv.Atoi(r.FileName)
}
expectedResults := ranks[q]
require.Equal(t, expectedResults[:min(numResults, len(expectedResults))], resultRowNums)
@ -152,7 +152,7 @@ func BenchmarkSimilaritySearch(b *testing.B) {
b.Run(fmt.Sprintf("numWorkers=%d", numWorkers), func(b *testing.B) {
start := time.Now()
for n := 0; n < b.N; n++ {
_ = index.SimilaritySearch(query, numResults, WorkerOptions{NumWorkers: numWorkers}, SearchOptions{})
_ = index.SimilaritySearch(query, numResults, WorkerOptions{NumWorkers: numWorkers}, SearchOptions{}, "", "")
}
m := float64(numRows) * float64(b.N) / time.Since(start).Seconds()
b.ReportMetric(m, "embeddings/s")
@ -175,23 +175,11 @@ func TestScore(t *testing.T) {
}
// embeddings[0] = 64, 83, 70,
// queries[0:3] = 53, 61, 97,
score, debugInfo := index.score(queries[0:columnDimension], 0, SearchOptions{Debug: true, UseDocumentRanks: true})
scoreDetails := index.score(queries[0:columnDimension], 0, SearchOptions{UseDocumentRanks: true})
// Check that the score is correct
expectedScore := scoreSimilarityWeight * ((64 * 53) + (83 * 61) + (70 * 97))
if math.Abs(float64(score-expectedScore)) > 0.0001 {
t.Fatalf("Expected score %d, but got %d", expectedScore, score)
}
if debugInfo.String() == "" {
t.Fatal("Expected a non-empty debug")
if math.Abs(float64(scoreDetails.Score-expectedScore)) > 0.0001 {
t.Fatalf("Expected score %d, but got %d", expectedScore, scoreDetails.Score)
}
}
// simpleCosineSimilarity returns the dot product of a and b, widening each
// int8 component to int32 before multiplying. It reads one element of b for
// every element of a, so b must be at least as long as a.
func simpleCosineSimilarity(a, b []int8) int32 {
	var total int32
	for pos, component := range a {
		total += int32(component) * int32(b[pos])
	}
	return total
}

View File

@ -1,6 +1,8 @@
package embeddings
import (
"fmt"
"sort"
"time"
"github.com/sourcegraph/log"
@ -45,18 +47,46 @@ type ContextDetectionEmbeddingIndex struct {
MessagesWithoutAdditionalContextMeanEmbedding []float32
}
type EmbeddingSearchResults struct {
CodeResults []EmbeddingSearchResult `json:"codeResults"`
TextResults []EmbeddingSearchResult `json:"textResults"`
type EmbeddingCombinedSearchResults struct {
CodeResults EmbeddingSearchResults `json:"codeResults"`
TextResults EmbeddingSearchResults `json:"textResults"`
}
type EmbeddingSearchResults []EmbeddingSearchResult
// MergeTruncate merges other into the search results, keeping at most max
// results with the highest scores, ordered by descending Score().
//
// The receiver is overwritten with the merged, sorted and truncated slice.
// When the merged set holds fewer than max results, all of them are kept; a
// negative max is treated as zero.
func (esrs *EmbeddingSearchResults) MergeTruncate(other EmbeddingSearchResults, max int) {
	merged := append(*esrs, other...)
	sort.Slice(merged, func(i, j int) bool { return merged[i].Score() > merged[j].Score() })

	// Clamp the limit to the number of merged results. Re-slicing past len
	// would either panic (limit > cap) or silently expose zero-valued entries
	// from the spare capacity that append left in the backing array.
	limit := max
	if limit > len(merged) {
		limit = len(merged)
	}
	if limit < 0 {
		limit = 0
	}
	*esrs = merged[:limit]
}
type EmbeddingSearchResult struct {
RepoName api.RepoName
Revision api.CommitID
RepoEmbeddingRowMetadata
Content string `json:"content"`
// Experimental: Clients should not rely on any particular format of debug
Debug string `json:"debug,omitempty"`
RepoName api.RepoName `json:"repoName"`
Revision api.CommitID `json:"revision"`
FileName string `json:"fileName"`
StartLine int `json:"startLine"`
EndLine int `json:"endLine"`
ScoreDetails SearchScoreDetails `json:"scoreDetails"`
}
// Score returns the combined score of this result: the sum of the rank and
// similarity components recorded in ScoreDetails.
func (esr *EmbeddingSearchResult) Score() int32 {
	details := esr.ScoreDetails
	return details.RankScore + details.SimilarityScore
}
// SearchScoreDetails carries the overall score of a search result together
// with the components it is broken down into. It is serialized to JSON
// under the field names given in the struct tags.
type SearchScoreDetails struct {
// Score is the overall score of the result.
Score int32 `json:"score"`
// Breakdown
// SimilarityScore is the similarity component of the score.
SimilarityScore int32 `json:"similarityScore"`
// RankScore is the rank component of the score.
RankScore int32 `json:"rankScore"`
}
// String renders the score and its breakdown as a single human-readable
// line of the form "score:N, similarity:N, rank:N".
func (s *SearchScoreDetails) String() string {
	const layout = "score:%d, similarity:%d, rank:%d"
	return fmt.Sprintf(layout, s.Score, s.SimilarityScore, s.RankScore)
}
// DEPRECATED: to support decoding old indexes, we need a struct

View File

@ -123,3 +123,15 @@
path: github.com/sourcegraph/sourcegraph/enterprise/internal/github_apps/store
interfaces:
- GitHubAppsStore
- filename: enterprise/internal/embeddings/mocks_temp.go
path: github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings
interfaces:
- Client
- filename: enterprise/internal/embeddings/background/repo/mocks_temp.go
path: github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/repo
interfaces:
- RepoEmbeddingJobsStore
- filename: enterprise/internal/embeddings/background/contextdetection/mocks_temp.go
path: github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/contextdetection
interfaces:
- ContextDetectionEmbeddingJobsStore