mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 19:21:50 +00:00
Revert "Embeddings: multi-repo search" (#51969)
Reverts sourcegraph/sourcegraph#51662
This commit is contained in:
parent
c30b33ff0a
commit
03da2be83d
@ -10,6 +10,6 @@ export class SourcegraphEmbeddingsSearchClient implements EmbeddingsSearch {
|
||||
codeResultsCount: number,
|
||||
textResultsCount: number
|
||||
): Promise<EmbeddingsSearchResults | Error> {
|
||||
return this.client.searchEmbeddings([this.repoId], query, codeResultsCount, textResultsCount)
|
||||
return this.client.searchEmbeddings(this.repoId, query, codeResultsCount, textResultsCount)
|
||||
}
|
||||
}
|
||||
|
||||
@ -141,13 +141,13 @@ export class SourcegraphGraphQLAPIClient {
|
||||
}
|
||||
|
||||
public async searchEmbeddings(
|
||||
repos: string[],
|
||||
repo: string,
|
||||
query: string,
|
||||
codeResultsCount: number,
|
||||
textResultsCount: number
|
||||
): Promise<EmbeddingsSearchResults | Error> {
|
||||
return this.fetchSourcegraphAPI<APIResponse<EmbeddingsSearchResponse>>(SEARCH_EMBEDDINGS_QUERY, {
|
||||
repos,
|
||||
repo,
|
||||
query,
|
||||
codeResultsCount,
|
||||
textResultsCount,
|
||||
|
||||
@ -21,8 +21,8 @@ query Repository($name: String!) {
|
||||
}`
|
||||
|
||||
export const SEARCH_EMBEDDINGS_QUERY = `
|
||||
query EmbeddingsSearch($repos: [ID!]!, $query: String!, $codeResultsCount: Int!, $textResultsCount: Int!) {
|
||||
embeddingsMultiSearch(repos: $repos, query: $query, codeResultsCount: $codeResultsCount, textResultsCount: $textResultsCount) {
|
||||
query EmbeddingsSearch($repo: ID!, $query: String!, $codeResultsCount: Int!, $textResultsCount: Int!) {
|
||||
embeddingsSearch(repo: $repo, query: $query, codeResultsCount: $codeResultsCount, textResultsCount: $textResultsCount) {
|
||||
codeResults {
|
||||
fileName
|
||||
startLine
|
||||
|
||||
@ -11,7 +11,6 @@ import (
|
||||
|
||||
type EmbeddingsResolver interface {
|
||||
EmbeddingsSearch(ctx context.Context, args EmbeddingsSearchInputArgs) (EmbeddingsSearchResultsResolver, error)
|
||||
EmbeddingsMultiSearch(ctx context.Context, args EmbeddingsMultiSearchInputArgs) (EmbeddingsSearchResultsResolver, error)
|
||||
IsContextRequiredForChatQuery(ctx context.Context, args IsContextRequiredForChatQueryInputArgs) (bool, error)
|
||||
RepoEmbeddingJobs(ctx context.Context, args ListRepoEmbeddingJobsArgs) (*graphqlutil.ConnectionResolver[RepoEmbeddingJobResolver], error)
|
||||
|
||||
@ -34,16 +33,9 @@ type EmbeddingsSearchInputArgs struct {
|
||||
TextResultsCount int32
|
||||
}
|
||||
|
||||
type EmbeddingsMultiSearchInputArgs struct {
|
||||
Repos []graphql.ID
|
||||
Query string
|
||||
CodeResultsCount int32
|
||||
TextResultsCount int32
|
||||
}
|
||||
|
||||
type EmbeddingsSearchResultsResolver interface {
|
||||
CodeResults(ctx context.Context) ([]EmbeddingsSearchResultResolver, error)
|
||||
TextResults(ctx context.Context) ([]EmbeddingsSearchResultResolver, error)
|
||||
CodeResults(ctx context.Context) []EmbeddingsSearchResultResolver
|
||||
TextResults(ctx context.Context) []EmbeddingsSearchResultResolver
|
||||
}
|
||||
|
||||
type EmbeddingsSearchResultResolver interface {
|
||||
|
||||
@ -22,31 +22,6 @@ extend type Query {
|
||||
"""
|
||||
textResultsCount: Int!
|
||||
): EmbeddingsSearchResults!
|
||||
|
||||
"""
|
||||
Experimental: Searches a set of repositories for similar code and text results using embeddings.
|
||||
We separated code and text results because text results tended to always feature at the top of the combined results,
|
||||
and didn't leave room for the code.
|
||||
"""
|
||||
embeddingsMultiSearch(
|
||||
"""
|
||||
The repository to search.
|
||||
"""
|
||||
repos: [ID!]!
|
||||
"""
|
||||
The query used for embeddings search.
|
||||
"""
|
||||
query: String!
|
||||
"""
|
||||
The number of code results to return.
|
||||
"""
|
||||
codeResultsCount: Int!
|
||||
"""
|
||||
The number of text results to return. Text results contain Markdown files and similar file types primarily used for writing documentation.
|
||||
"""
|
||||
textResultsCount: Int!
|
||||
): EmbeddingsSearchResults!
|
||||
|
||||
"""
|
||||
Experimental: Determines whether the given query requires further context before it can be answered.
|
||||
For example:
|
||||
|
||||
1
enterprise/cmd/embeddings/qa/BUILD.bazel
generated
1
enterprise/cmd/embeddings/qa/BUILD.bazel
generated
@ -8,7 +8,6 @@ go_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//enterprise/internal/embeddings",
|
||||
"//internal/api",
|
||||
"//lib/errors",
|
||||
],
|
||||
)
|
||||
|
||||
@ -11,7 +11,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
)
|
||||
|
||||
@ -22,7 +21,7 @@ import (
|
||||
var fs embed.FS
|
||||
|
||||
type embeddingsSearcher interface {
|
||||
Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingCombinedSearchResults, error)
|
||||
Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error)
|
||||
}
|
||||
|
||||
// Run runs the evaluation and returns recall for the test data.
|
||||
@ -46,10 +45,11 @@ func Run(searcher embeddingsSearcher) (float64, error) {
|
||||
relevantFile := fields[1]
|
||||
|
||||
args := embeddings.EmbeddingsSearchParameters{
|
||||
RepoNames: []api.RepoName{"github.com/sourcegraph/sourcegraph"},
|
||||
RepoName: "github.com/sourcegraph/sourcegraph",
|
||||
Query: query,
|
||||
CodeResultsCount: 20,
|
||||
TextResultsCount: 2,
|
||||
Debug: true,
|
||||
}
|
||||
|
||||
results, err := searcher.Search(args)
|
||||
@ -70,7 +70,11 @@ func Run(searcher embeddingsSearcher) (float64, error) {
|
||||
fmt.Printf(" ")
|
||||
}
|
||||
fmt.Printf("%d. %s", i+1, result.FileName)
|
||||
fmt.Printf(" (%s)\n", result.ScoreDetails.String())
|
||||
if result.Debug != "" {
|
||||
fmt.Printf(" (%s)\n", result.Debug)
|
||||
} else {
|
||||
fmt.Print("\n")
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
if fileFound {
|
||||
@ -99,7 +103,7 @@ func NewClient(url string) *client {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingCombinedSearchResults, error) {
|
||||
func (c *client) Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error) {
|
||||
b, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -126,7 +130,7 @@ func (c *client) Search(args embeddings.EmbeddingsSearchParameters) (*embeddings
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := embeddings.EmbeddingCombinedSearchResults{}
|
||||
res := embeddings.EmbeddingSearchResults{}
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal response")
|
||||
|
||||
6
enterprise/cmd/embeddings/shared/BUILD.bazel
generated
6
enterprise/cmd/embeddings/shared/BUILD.bazel
generated
@ -35,6 +35,7 @@ go_library(
|
||||
"//internal/env",
|
||||
"//internal/errcode",
|
||||
"//internal/featureflag",
|
||||
"//internal/gitserver",
|
||||
"//internal/goroutine",
|
||||
"//internal/honey",
|
||||
"//internal/httpserver",
|
||||
@ -78,8 +79,8 @@ go_test(
|
||||
srcs = [
|
||||
"context_detection_test.go",
|
||||
"context_qa_test.go",
|
||||
"main_test.go",
|
||||
"repo_embedding_index_cache_test.go",
|
||||
"search_test.go",
|
||||
],
|
||||
embed = [":shared"],
|
||||
embedsrcs = [
|
||||
@ -92,11 +93,10 @@ go_test(
|
||||
"//enterprise/internal/embeddings/background/repo",
|
||||
"//internal/api",
|
||||
"//internal/database",
|
||||
"//internal/endpoint",
|
||||
"//internal/types",
|
||||
"//internal/uploadstore/mocks",
|
||||
"//lib/errors",
|
||||
"@com_github_sourcegraph_log//logtest",
|
||||
"@com_github_sourcegraph_log//:log",
|
||||
"@com_github_stretchr_testify//require",
|
||||
],
|
||||
)
|
||||
|
||||
@ -13,6 +13,8 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/log"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/cmd/embeddings/qa"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
@ -32,6 +34,7 @@ func TestRecall(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
logger := log.NoOp()
|
||||
|
||||
// Set up mock functions
|
||||
queryEmbeddings, err := loadQueryEmbeddings(t)
|
||||
@ -56,13 +59,20 @@ func TestRecall(t *testing.T) {
|
||||
return embeddings.DownloadRepoEmbeddingIndex(context.Background(), mockStore, string(key))
|
||||
}
|
||||
|
||||
// We only care about the file names in this test.
|
||||
mockReadFile := func(ctx context.Context, repoName api.RepoName, revision api.CommitID, fileName string) ([]byte, error) {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
// Weaviate is disabled per default. We don't need it for this test.
|
||||
weaviate := &weaviateClient{}
|
||||
|
||||
searcher := func(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingCombinedSearchResults, error) {
|
||||
return searchRepoEmbeddingIndexes(
|
||||
searcher := func(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error) {
|
||||
return searchRepoEmbeddingIndex(
|
||||
ctx,
|
||||
logger,
|
||||
args,
|
||||
mockReadFile,
|
||||
getRepoEmbeddingIndex,
|
||||
getQueryEmbedding,
|
||||
weaviate,
|
||||
@ -110,8 +120,8 @@ func loadQueryEmbeddings(t *testing.T) (map[string][]float32, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
type embeddingsSearcherFunc func(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingCombinedSearchResults, error)
|
||||
type embeddingsSearcherFunc func(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error)
|
||||
|
||||
func (f embeddingsSearcherFunc) Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingCombinedSearchResults, error) {
|
||||
func (f embeddingsSearcherFunc) Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error) {
|
||||
return f(args)
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@ import (
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/repo"
|
||||
"github.com/sourcegraph/sourcegraph/internal/actor"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
"github.com/sourcegraph/sourcegraph/internal/authz"
|
||||
"github.com/sourcegraph/sourcegraph/internal/conf"
|
||||
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
|
||||
@ -26,6 +27,7 @@ import (
|
||||
connections "github.com/sourcegraph/sourcegraph/internal/database/connections/live"
|
||||
"github.com/sourcegraph/sourcegraph/internal/errcode"
|
||||
"github.com/sourcegraph/sourcegraph/internal/featureflag"
|
||||
"github.com/sourcegraph/sourcegraph/internal/gitserver"
|
||||
"github.com/sourcegraph/sourcegraph/internal/goroutine"
|
||||
"github.com/sourcegraph/sourcegraph/internal/honey"
|
||||
"github.com/sourcegraph/sourcegraph/internal/httpserver"
|
||||
@ -56,6 +58,7 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
|
||||
repoEmbeddingJobsStore := repo.NewRepoEmbeddingJobsStore(db)
|
||||
|
||||
// Run setup
|
||||
gitserverClient := gitserver.NewClient()
|
||||
uploadStore, err := embeddings.NewEmbeddingsUploadStore(ctx, observationCtx, config.EmbeddingsUploadStoreConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -66,6 +69,10 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
|
||||
return errors.Wrap(err, "creating sub-repo client")
|
||||
}
|
||||
|
||||
readFile := func(ctx context.Context, repoName api.RepoName, revision api.CommitID, fileName string) ([]byte, error) {
|
||||
return gitserverClient.ReadFile(ctx, authz.DefaultSubRepoPermsChecker, repoName, revision, fileName)
|
||||
}
|
||||
|
||||
indexGetter, err := NewCachedEmbeddingIndexGetter(
|
||||
repoStore,
|
||||
repoEmbeddingJobsStore,
|
||||
@ -85,6 +92,7 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
|
||||
|
||||
weaviate := newWeaviateClient(
|
||||
logger,
|
||||
readFile,
|
||||
getQueryEmbedding,
|
||||
config.WeaviateURL,
|
||||
)
|
||||
@ -92,7 +100,7 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
|
||||
getContextDetectionEmbeddingIndex := getCachedContextDetectionEmbeddingIndex(uploadStore)
|
||||
|
||||
// Create HTTP server
|
||||
handler := NewHandler(logger, indexGetter.Get, getQueryEmbedding, weaviate, getContextDetectionEmbeddingIndex)
|
||||
handler := NewHandler(logger, readFile, indexGetter.Get, getQueryEmbedding, weaviate, getContextDetectionEmbeddingIndex)
|
||||
handler = handlePanic(logger, handler)
|
||||
handler = featureflag.Middleware(db.FeatureFlags(), handler)
|
||||
handler = trace.HTTPMiddleware(logger, handler, conf.DefaultClient())
|
||||
@ -114,6 +122,7 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
|
||||
|
||||
func NewHandler(
|
||||
logger log.Logger,
|
||||
readFile readFileFn,
|
||||
getRepoEmbeddingIndex getRepoEmbeddingIndexFn,
|
||||
getQueryEmbedding getQueryEmbeddingFn,
|
||||
weaviate *weaviateClient,
|
||||
@ -134,7 +143,7 @@ func NewHandler(
|
||||
return
|
||||
}
|
||||
|
||||
res, err := searchRepoEmbeddingIndexes(r.Context(), args, getRepoEmbeddingIndex, getQueryEmbedding, weaviate)
|
||||
res, err := searchRepoEmbeddingIndex(r.Context(), logger, args, readFile, getRepoEmbeddingIndex, getQueryEmbedding, weaviate)
|
||||
if errcode.IsNotFound(err) {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
|
||||
@ -1,235 +0,0 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/log/logtest"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
"github.com/sourcegraph/sourcegraph/internal/endpoint"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEmbeddingsSearch(t *testing.T) {
|
||||
logger := logtest.Scoped(t)
|
||||
|
||||
makeIndex := func(name api.RepoName, w int8) *embeddings.RepoEmbeddingIndex {
|
||||
return &embeddings.RepoEmbeddingIndex{
|
||||
RepoName: name,
|
||||
Revision: "",
|
||||
CodeIndex: embeddings.EmbeddingIndex{
|
||||
Embeddings: []int8{
|
||||
w, 0, 0, 0,
|
||||
0, w, 0, 0,
|
||||
0, 0, w, 0,
|
||||
0, 0, 0, w,
|
||||
},
|
||||
ColumnDimension: 4,
|
||||
RowMetadata: []embeddings.RepoEmbeddingRowMetadata{
|
||||
{FileName: "codefile1", StartLine: 0, EndLine: 1},
|
||||
{FileName: "codefile2", StartLine: 0, EndLine: 1},
|
||||
{FileName: "codefile3", StartLine: 0, EndLine: 1},
|
||||
{FileName: "codefile4", StartLine: 0, EndLine: 1},
|
||||
},
|
||||
},
|
||||
TextIndex: embeddings.EmbeddingIndex{
|
||||
Embeddings: []int8{
|
||||
w, 0, 0, 0,
|
||||
0, w, 0, 0,
|
||||
0, 0, w, 0,
|
||||
0, 0, 0, w,
|
||||
},
|
||||
ColumnDimension: 4,
|
||||
RowMetadata: []embeddings.RepoEmbeddingRowMetadata{
|
||||
{FileName: "textfile1", StartLine: 0, EndLine: 1},
|
||||
{FileName: "textfile2", StartLine: 0, EndLine: 1},
|
||||
{FileName: "textfile3", StartLine: 0, EndLine: 1},
|
||||
{FileName: "textfile4", StartLine: 0, EndLine: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
indexes := map[api.RepoName]*embeddings.RepoEmbeddingIndex{
|
||||
"repo1": makeIndex("repo1", 1),
|
||||
"repo2": makeIndex("repo2", 2),
|
||||
"repo3": makeIndex("repo3", 3),
|
||||
"repo4": makeIndex("repo4", 4),
|
||||
}
|
||||
|
||||
getRepoEmbeddingIndex := func(_ context.Context, repoName api.RepoName) (*embeddings.RepoEmbeddingIndex, error) {
|
||||
return indexes[repoName], nil
|
||||
}
|
||||
getQueryEmbedding := func(_ context.Context, query string) ([]float32, error) {
|
||||
switch query {
|
||||
case "one":
|
||||
return []float32{1, 0, 0, 0}, nil
|
||||
case "two":
|
||||
return []float32{0, 1, 0, 0}, nil
|
||||
case "three":
|
||||
return []float32{0, 0, 1, 0}, nil
|
||||
case "four":
|
||||
return []float32{0, 0, 1, 1}, nil
|
||||
default:
|
||||
panic("unknown")
|
||||
}
|
||||
}
|
||||
getContextDetectionEmbeddingIndex := func(context.Context) (*embeddings.ContextDetectionEmbeddingIndex, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
server1 := httptest.NewServer(NewHandler(
|
||||
logger,
|
||||
getRepoEmbeddingIndex,
|
||||
getQueryEmbedding,
|
||||
nil,
|
||||
getContextDetectionEmbeddingIndex,
|
||||
))
|
||||
|
||||
server2 := httptest.NewServer(NewHandler(
|
||||
logger,
|
||||
getRepoEmbeddingIndex,
|
||||
getQueryEmbedding,
|
||||
nil,
|
||||
getContextDetectionEmbeddingIndex,
|
||||
))
|
||||
|
||||
client := embeddings.NewClient(endpoint.Static(server1.URL, server2.URL), http.DefaultClient)
|
||||
|
||||
{
|
||||
// First test: we should return results for file1 based on the query.
|
||||
// The rankings should have repo4 highest because it has the largest weighted
|
||||
// embeddings.
|
||||
params := embeddings.EmbeddingsSearchParameters{
|
||||
RepoNames: []api.RepoName{"repo1", "repo2", "repo3", "repo4"},
|
||||
RepoIDs: []api.RepoID{1, 2, 3, 4},
|
||||
Query: "one",
|
||||
CodeResultsCount: 2,
|
||||
TextResultsCount: 2,
|
||||
UseDocumentRanks: false,
|
||||
}
|
||||
|
||||
results, err := client.Search(context.Background(), params)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, &embeddings.EmbeddingCombinedSearchResults{
|
||||
CodeResults: embeddings.EmbeddingSearchResults{{
|
||||
RepoName: "repo4",
|
||||
FileName: "codefile1",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 1016, SimilarityScore: 1016},
|
||||
}, {
|
||||
RepoName: "repo3",
|
||||
FileName: "codefile1",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
|
||||
}},
|
||||
TextResults: embeddings.EmbeddingSearchResults{{
|
||||
RepoName: "repo4",
|
||||
FileName: "textfile1",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 1016, SimilarityScore: 1016},
|
||||
}, {
|
||||
RepoName: "repo3",
|
||||
FileName: "textfile1",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
|
||||
}},
|
||||
}, results)
|
||||
}
|
||||
|
||||
{
|
||||
// Second test: providing a subset of repos should only search those repos
|
||||
params := embeddings.EmbeddingsSearchParameters{
|
||||
RepoNames: []api.RepoName{"repo1", "repo3"},
|
||||
RepoIDs: []api.RepoID{1, 2, 3, 4},
|
||||
Query: "one",
|
||||
CodeResultsCount: 2,
|
||||
TextResultsCount: 2,
|
||||
UseDocumentRanks: false,
|
||||
}
|
||||
|
||||
results, err := client.Search(context.Background(), params)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, &embeddings.EmbeddingCombinedSearchResults{
|
||||
CodeResults: embeddings.EmbeddingSearchResults{{
|
||||
RepoName: "repo3",
|
||||
FileName: "codefile1",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
|
||||
}, {
|
||||
RepoName: "repo1",
|
||||
FileName: "codefile1",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 254, SimilarityScore: 254},
|
||||
}},
|
||||
TextResults: embeddings.EmbeddingSearchResults{{
|
||||
RepoName: "repo3",
|
||||
FileName: "textfile1",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
|
||||
}, {
|
||||
RepoName: "repo1",
|
||||
FileName: "textfile1",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 254, SimilarityScore: 254},
|
||||
}},
|
||||
}, results)
|
||||
}
|
||||
|
||||
{
|
||||
// Third test: try a different file just to be safe
|
||||
params := embeddings.EmbeddingsSearchParameters{
|
||||
RepoNames: []api.RepoName{"repo1", "repo2", "repo3", "repo4"},
|
||||
RepoIDs: []api.RepoID{1, 2, 3, 4},
|
||||
Query: "three",
|
||||
CodeResultsCount: 2,
|
||||
TextResultsCount: 2,
|
||||
UseDocumentRanks: false,
|
||||
}
|
||||
|
||||
results, err := client.Search(context.Background(), params)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, &embeddings.EmbeddingCombinedSearchResults{
|
||||
CodeResults: embeddings.EmbeddingSearchResults{{
|
||||
RepoName: "repo4",
|
||||
FileName: "codefile3",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 1016, SimilarityScore: 1016},
|
||||
}, {
|
||||
RepoName: "repo3",
|
||||
FileName: "codefile3",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
|
||||
}},
|
||||
TextResults: embeddings.EmbeddingSearchResults{{
|
||||
RepoName: "repo4",
|
||||
FileName: "textfile3",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 1016, SimilarityScore: 1016},
|
||||
}, {
|
||||
RepoName: "repo3",
|
||||
FileName: "textfile3",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
|
||||
}},
|
||||
}, results)
|
||||
}
|
||||
}
|
||||
@ -2,66 +2,145 @@ package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/sourcegraph/log"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
)
|
||||
|
||||
const SIMILARITY_SEARCH_MIN_ROWS_TO_SPLIT = 1000
|
||||
|
||||
type readFileFn func(ctx context.Context, repoName api.RepoName, revision api.CommitID, fileName string) ([]byte, error)
|
||||
type getRepoEmbeddingIndexFn func(ctx context.Context, repoName api.RepoName) (*embeddings.RepoEmbeddingIndex, error)
|
||||
type getQueryEmbeddingFn func(ctx context.Context, query string) ([]float32, error)
|
||||
|
||||
func searchRepoEmbeddingIndexes(
|
||||
func searchRepoEmbeddingIndex(
|
||||
ctx context.Context,
|
||||
logger log.Logger,
|
||||
params embeddings.EmbeddingsSearchParameters,
|
||||
readFile readFileFn,
|
||||
getRepoEmbeddingIndex getRepoEmbeddingIndexFn,
|
||||
getQueryEmbedding getQueryEmbeddingFn,
|
||||
weaviate *weaviateClient,
|
||||
) (*embeddings.EmbeddingCombinedSearchResults, error) {
|
||||
) (*embeddings.EmbeddingSearchResults, error) {
|
||||
if weaviate.Use(ctx) {
|
||||
return weaviate.Search(ctx, params)
|
||||
}
|
||||
|
||||
embeddingIndex, err := getRepoEmbeddingIndex(ctx, params.RepoName)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "getting repo embedding index for repo %q", params.RepoName)
|
||||
}
|
||||
|
||||
floatQuery, err := getQueryEmbedding(ctx, params.Query)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getting query embedding")
|
||||
}
|
||||
embeddedQuery := embeddings.Quantize(floatQuery)
|
||||
|
||||
workerOpts := embeddings.WorkerOptions{
|
||||
NumWorkers: runtime.GOMAXPROCS(0),
|
||||
MinRowsToSplit: SIMILARITY_SEARCH_MIN_ROWS_TO_SPLIT,
|
||||
}
|
||||
|
||||
searchOpts := embeddings.SearchOptions{
|
||||
opts := embeddings.SearchOptions{
|
||||
Debug: params.Debug,
|
||||
UseDocumentRanks: params.UseDocumentRanks,
|
||||
}
|
||||
|
||||
var result embeddings.EmbeddingCombinedSearchResults
|
||||
codeResults := searchEmbeddingIndex(ctx, logger, embeddingIndex.RepoName, embeddingIndex.Revision, &embeddingIndex.CodeIndex, readFile, embeddedQuery, params.CodeResultsCount, opts)
|
||||
textResults := searchEmbeddingIndex(ctx, logger, embeddingIndex.RepoName, embeddingIndex.Revision, &embeddingIndex.TextIndex, readFile, embeddedQuery, params.TextResultsCount, opts)
|
||||
|
||||
for i, repoName := range params.RepoNames {
|
||||
if weaviate.Use(ctx) {
|
||||
codeResults, textResults, err := weaviate.Search(ctx, repoName, params.RepoIDs[i], params.Query, params.CodeResultsCount, params.TextResultsCount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return &embeddings.EmbeddingSearchResults{CodeResults: codeResults, TextResults: textResults}, nil
|
||||
}
|
||||
|
||||
const SIMILARITY_SEARCH_MIN_ROWS_TO_SPLIT = 1000
|
||||
|
||||
func searchEmbeddingIndex(
|
||||
ctx context.Context,
|
||||
logger log.Logger,
|
||||
repoName api.RepoName,
|
||||
revision api.CommitID,
|
||||
index *embeddings.EmbeddingIndex,
|
||||
readFile readFileFn,
|
||||
query []int8,
|
||||
nResults int,
|
||||
opts embeddings.SearchOptions,
|
||||
) []embeddings.EmbeddingSearchResult {
|
||||
numWorkers := runtime.GOMAXPROCS(0)
|
||||
results := index.SimilaritySearch(query, nResults, embeddings.WorkerOptions{NumWorkers: numWorkers, MinRowsToSplit: SIMILARITY_SEARCH_MIN_ROWS_TO_SPLIT}, opts)
|
||||
|
||||
return filterAndHydrateContent(
|
||||
ctx,
|
||||
logger,
|
||||
repoName,
|
||||
revision,
|
||||
readFile,
|
||||
opts.Debug,
|
||||
results,
|
||||
)
|
||||
}
|
||||
|
||||
// filterAndHydrateContent will mutate unfiltered to populate the Content
|
||||
// field. If we fail to read a file (eg permission issues) we will remove the
|
||||
// result. As such the returned slice should be used.
|
||||
func filterAndHydrateContent(
|
||||
ctx context.Context,
|
||||
logger log.Logger,
|
||||
repoName api.RepoName,
|
||||
revision api.CommitID,
|
||||
readFile readFileFn,
|
||||
debug bool,
|
||||
unfiltered []embeddings.SimilaritySearchResult,
|
||||
) []embeddings.EmbeddingSearchResult {
|
||||
filtered := make([]embeddings.EmbeddingSearchResult, 0, len(unfiltered))
|
||||
|
||||
for _, result := range unfiltered {
|
||||
fileContent, err := readFile(ctx, repoName, revision, result.FileName)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
logger.Error("error reading file", log.String("repoName", string(repoName)), log.String("revision", string(revision)), log.String("fileName", result.FileName), log.Error(err))
|
||||
}
|
||||
|
||||
result.CodeResults.MergeTruncate(codeResults, params.CodeResultsCount)
|
||||
result.TextResults.MergeTruncate(textResults, params.TextResultsCount)
|
||||
continue
|
||||
}
|
||||
lines := strings.Split(string(fileContent), "\n")
|
||||
|
||||
embeddingIndex, err := getRepoEmbeddingIndex(ctx, repoName)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "getting repo embedding index for repo %q", repoName)
|
||||
// Sanity check: check that startLine and endLine are within 0 and len(lines).
|
||||
startLine := max(0, min(len(lines), result.StartLine))
|
||||
endLine := max(0, min(len(lines), result.EndLine))
|
||||
|
||||
content := strings.Join(lines[startLine:endLine], "\n")
|
||||
|
||||
var debugString string
|
||||
if debug {
|
||||
debugString = fmt.Sprintf("score:%d, similarity:%d, rank:%d", result.Score(), result.SimilarityScore, result.RankScore)
|
||||
}
|
||||
|
||||
codeResults := embeddingIndex.CodeIndex.SimilaritySearch(embeddedQuery, params.CodeResultsCount, workerOpts, searchOpts, embeddingIndex.RepoName, embeddingIndex.Revision)
|
||||
textResults := embeddingIndex.TextIndex.SimilaritySearch(embeddedQuery, params.TextResultsCount, workerOpts, searchOpts, embeddingIndex.RepoName, embeddingIndex.Revision)
|
||||
|
||||
result.CodeResults.MergeTruncate(codeResults, params.CodeResultsCount)
|
||||
result.TextResults.MergeTruncate(textResults, params.TextResultsCount)
|
||||
|
||||
filtered = append(filtered, embeddings.EmbeddingSearchResult{
|
||||
RepoName: repoName,
|
||||
Revision: revision,
|
||||
RepoEmbeddingRowMetadata: embeddings.RepoEmbeddingRowMetadata{
|
||||
FileName: result.FileName,
|
||||
StartLine: startLine,
|
||||
EndLine: endLine,
|
||||
},
|
||||
Debug: debugString,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
return filtered
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
38
enterprise/cmd/embeddings/shared/search_test.go
Normal file
38
enterprise/cmd/embeddings/shared/search_test.go
Normal file
@ -0,0 +1,38 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/log"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
)
|
||||
|
||||
func TestFilterAndHydrateContent_emptyFile(t *testing.T) {
|
||||
// Set up test data
|
||||
repoName := api.RepoName("example/repo")
|
||||
revision := api.CommitID("abc123")
|
||||
debug := false
|
||||
unfiltered := []embeddings.SimilaritySearchResult{
|
||||
{
|
||||
RepoEmbeddingRowMetadata: embeddings.RepoEmbeddingRowMetadata{
|
||||
FileName: "file.txt",
|
||||
StartLine: 5,
|
||||
EndLine: 20,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Define a mock readFile function that returns an empty string
|
||||
readFile := func(ctx context.Context, repoName api.RepoName, revision api.CommitID, fileName string) ([]byte, error) {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
filtered := filterAndHydrateContent(context.Background(), log.NoOp(), repoName, revision, readFile, debug, unfiltered)
|
||||
|
||||
if len(filtered) != 1 {
|
||||
t.Errorf("Expected 1 filtered result, but got %d elements", len(filtered))
|
||||
}
|
||||
}
|
||||
@ -19,6 +19,7 @@ import (
|
||||
|
||||
type weaviateClient struct {
|
||||
logger log.Logger
|
||||
readFile readFileFn
|
||||
getQueryEmbedding getQueryEmbeddingFn
|
||||
|
||||
client *weaviate.Client
|
||||
@ -27,6 +28,7 @@ type weaviateClient struct {
|
||||
|
||||
func newWeaviateClient(
|
||||
logger log.Logger,
|
||||
readFile readFileFn,
|
||||
getQueryEmbedding getQueryEmbeddingFn,
|
||||
url *url.URL,
|
||||
) *weaviateClient {
|
||||
@ -43,6 +45,7 @@ func newWeaviateClient(
|
||||
|
||||
return &weaviateClient{
|
||||
logger: logger.Scoped("weaviate", "client for weaviate embedding index"),
|
||||
readFile: readFile,
|
||||
getQueryEmbedding: getQueryEmbedding,
|
||||
client: client,
|
||||
clientErr: err,
|
||||
@ -53,14 +56,14 @@ func (w *weaviateClient) Use(ctx context.Context) bool {
|
||||
return featureflag.FromContext(ctx).GetBoolOr("search-weaviate", false)
|
||||
}
|
||||
|
||||
func (w *weaviateClient) Search(ctx context.Context, repoName api.RepoName, repoID api.RepoID, query string, codeResultsCount, textResultsCount int) (codeResults, textResults []embeddings.EmbeddingSearchResult, _ error) {
|
||||
func (w *weaviateClient) Search(ctx context.Context, params embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error) {
|
||||
if w.clientErr != nil {
|
||||
return nil, nil, w.clientErr
|
||||
return nil, w.clientErr
|
||||
}
|
||||
|
||||
embeddedQuery, err := w.getQueryEmbedding(ctx, query)
|
||||
embeddedQuery, err := w.getQueryEmbedding(ctx, params.Query)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "getting query embedding")
|
||||
return nil, errors.Wrap(err, "getting query embedding")
|
||||
}
|
||||
|
||||
queryBuilder := func(klass string, limit int) *graphql.GetBuilder {
|
||||
@ -72,9 +75,6 @@ func (w *weaviateClient) Search(ctx context.Context, repoName api.RepoName, repo
|
||||
{Name: "start_line"},
|
||||
{Name: "end_line"},
|
||||
{Name: "revision"},
|
||||
{Name: "_additional", Fields: []graphql.Field{
|
||||
{Name: "distance"},
|
||||
}},
|
||||
}...).
|
||||
WithLimit(limit)
|
||||
}
|
||||
@ -86,7 +86,7 @@ func (w *weaviateClient) Search(ctx context.Context, repoName api.RepoName, repo
|
||||
return nil
|
||||
}
|
||||
|
||||
srs := make([]embeddings.EmbeddingSearchResult, 0, len(code))
|
||||
srs := make([]embeddings.SimilaritySearchResult, 0, len(code))
|
||||
revision := ""
|
||||
for _, c := range code {
|
||||
cMap := c.(map[string]any)
|
||||
@ -96,48 +96,50 @@ func (w *weaviateClient) Search(ctx context.Context, repoName api.RepoName, repo
|
||||
if revision == "" {
|
||||
revision = rev
|
||||
} else {
|
||||
w.logger.Warn("inconsistent revisions returned for an embedded repository", log.Int("repoid", int(repoID)), log.String("filename", fileName), log.String("revision1", revision), log.String("revision2", rev))
|
||||
w.logger.Warn("inconsistent revisions returned for an embedded repository", log.Int("repoid", int(params.RepoID)), log.String("filename", fileName), log.String("revision1", revision), log.String("revision2", rev))
|
||||
}
|
||||
}
|
||||
|
||||
// multiply by half max int32 since distance will always be between 0 and 2
|
||||
similarity := int32(cMap["_additional"].(map[string]any)["distance"].(float64) * (1073741823))
|
||||
|
||||
srs = append(srs, embeddings.EmbeddingSearchResult{
|
||||
RepoName: repoName,
|
||||
Revision: api.CommitID(revision),
|
||||
FileName: fileName,
|
||||
StartLine: int(cMap["start_line"].(float64)),
|
||||
EndLine: int(cMap["end_line"].(float64)),
|
||||
ScoreDetails: embeddings.SearchScoreDetails{
|
||||
Score: similarity,
|
||||
SimilarityScore: similarity,
|
||||
srs = append(srs, embeddings.SimilaritySearchResult{
|
||||
RepoEmbeddingRowMetadata: embeddings.RepoEmbeddingRowMetadata{
|
||||
FileName: fileName,
|
||||
StartLine: int(cMap["start_line"].(float64)),
|
||||
EndLine: int(cMap["end_line"].(float64)),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return srs
|
||||
commit := api.CommitID(revision)
|
||||
if commit == "" {
|
||||
w.logger.Warn("no revision set for an embedded repository", log.Int("repoid", int(params.RepoID)))
|
||||
commit = api.CommitID("HEAD")
|
||||
}
|
||||
|
||||
return filterAndHydrateContent(ctx, w.logger, params.RepoName, commit, w.readFile, false, srs)
|
||||
}
|
||||
|
||||
// We partition the indexes by type and repository. Each class in
|
||||
// weaviate is its own index, so we achieve partitioning by a class
|
||||
// per repo and type.
|
||||
codeClass := fmt.Sprintf("Code_%d", repoID)
|
||||
textClass := fmt.Sprintf("Text_%d", repoID)
|
||||
codeClass := fmt.Sprintf("Code_%d", params.RepoID)
|
||||
textClass := fmt.Sprintf("Text_%d", params.RepoID)
|
||||
|
||||
res, err := w.client.GraphQL().MultiClassGet().
|
||||
AddQueryClass(queryBuilder(codeClass, codeResultsCount)).
|
||||
AddQueryClass(queryBuilder(textClass, textResultsCount)).
|
||||
AddQueryClass(queryBuilder(codeClass, params.CodeResultsCount)).
|
||||
AddQueryClass(queryBuilder(textClass, params.TextResultsCount)).
|
||||
Do(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "doing weaviate request")
|
||||
return nil, errors.Wrap(err, "doing weaviate request")
|
||||
}
|
||||
|
||||
if len(res.Errors) > 0 {
|
||||
return nil, nil, weaviateGraphQLError(res.Errors)
|
||||
return nil, weaviateGraphQLError(res.Errors)
|
||||
}
|
||||
|
||||
return extractResults(res, codeClass), extractResults(res, textClass), nil
|
||||
return &embeddings.EmbeddingSearchResults{
|
||||
CodeResults: extractResults(res, codeClass),
|
||||
TextResults: extractResults(res, textClass),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type weaviateGraphQLError []*models.GraphQLError
|
||||
|
||||
@ -26,7 +26,8 @@ func Init(
|
||||
repoEmbeddingsStore := repo.NewRepoEmbeddingJobsStore(db)
|
||||
contextDetectionEmbeddingsStore := contextdetection.NewContextDetectionEmbeddingJobsStore(db)
|
||||
gitserverClient := gitserver.NewClient()
|
||||
embeddingsClient := embeddings.NewDefaultClient()
|
||||
embeddingsClient := embeddings.NewClient()
|
||||
|
||||
enterpriseServices.EmbeddingsResolver = resolvers.NewResolver(
|
||||
db,
|
||||
observationCtx.Logger,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library")
|
||||
|
||||
go_library(
|
||||
name = "resolvers",
|
||||
@ -17,7 +17,6 @@ go_library(
|
||||
"//enterprise/internal/embeddings/background/repo",
|
||||
"//internal/api",
|
||||
"//internal/auth",
|
||||
"//internal/authz",
|
||||
"//internal/cody",
|
||||
"//internal/conf",
|
||||
"//internal/database",
|
||||
@ -27,31 +26,6 @@ go_library(
|
||||
"//lib/errors",
|
||||
"@com_github_graph_gophers_graphql_go//:graphql-go",
|
||||
"@com_github_graph_gophers_graphql_go//relay",
|
||||
"@com_github_sourcegraph_conc//pool",
|
||||
"@com_github_sourcegraph_log//:log",
|
||||
],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "resolvers_test",
|
||||
srcs = ["resolvers_test.go"],
|
||||
embed = [":resolvers"],
|
||||
deps = [
|
||||
"//cmd/frontend/graphqlbackend",
|
||||
"//enterprise/internal/embeddings",
|
||||
"//enterprise/internal/embeddings/background/contextdetection",
|
||||
"//enterprise/internal/embeddings/background/repo",
|
||||
"//internal/actor",
|
||||
"//internal/api",
|
||||
"//internal/authz",
|
||||
"//internal/conf",
|
||||
"//internal/database",
|
||||
"//internal/featureflag",
|
||||
"//internal/gitserver",
|
||||
"//internal/types",
|
||||
"//schema",
|
||||
"@com_github_graph_gophers_graphql_go//:graphql-go",
|
||||
"@com_github_sourcegraph_log//logtest",
|
||||
"@com_github_stretchr_testify//require",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1,12 +1,8 @@
|
||||
package resolvers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"os"
|
||||
|
||||
"github.com/graph-gophers/graphql-go"
|
||||
"github.com/sourcegraph/conc/pool"
|
||||
"github.com/sourcegraph/log"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
@ -19,7 +15,6 @@ import (
|
||||
repobg "github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/repo"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
"github.com/sourcegraph/sourcegraph/internal/auth"
|
||||
"github.com/sourcegraph/sourcegraph/internal/authz"
|
||||
"github.com/sourcegraph/sourcegraph/internal/cody"
|
||||
"github.com/sourcegraph/sourcegraph/internal/conf"
|
||||
"github.com/sourcegraph/sourcegraph/internal/database"
|
||||
@ -30,7 +25,7 @@ func NewResolver(
|
||||
db database.DB,
|
||||
logger log.Logger,
|
||||
gitserverClient gitserver.Client,
|
||||
embeddingsClient embeddings.Client,
|
||||
embeddingsClient *embeddings.Client,
|
||||
repoStore repobg.RepoEmbeddingJobsStore,
|
||||
contextDetectionStore contextdetectionbg.ContextDetectionEmbeddingJobsStore,
|
||||
) graphqlbackend.EmbeddingsResolver {
|
||||
@ -48,22 +43,13 @@ type Resolver struct {
|
||||
db database.DB
|
||||
logger log.Logger
|
||||
gitserverClient gitserver.Client
|
||||
embeddingsClient embeddings.Client
|
||||
embeddingsClient *embeddings.Client
|
||||
repoEmbeddingJobsStore repobg.RepoEmbeddingJobsStore
|
||||
contextDetectionJobsStore contextdetectionbg.ContextDetectionEmbeddingJobsStore
|
||||
emails backend.UserEmailsService
|
||||
}
|
||||
|
||||
func (r *Resolver) EmbeddingsSearch(ctx context.Context, args graphqlbackend.EmbeddingsSearchInputArgs) (graphqlbackend.EmbeddingsSearchResultsResolver, error) {
|
||||
return r.EmbeddingsMultiSearch(ctx, graphqlbackend.EmbeddingsMultiSearchInputArgs{
|
||||
Repos: []graphql.ID{args.Repo},
|
||||
Query: args.Query,
|
||||
CodeResultsCount: args.CodeResultsCount,
|
||||
TextResultsCount: args.TextResultsCount,
|
||||
})
|
||||
}
|
||||
|
||||
func (r *Resolver) EmbeddingsMultiSearch(ctx context.Context, args graphqlbackend.EmbeddingsMultiSearchInputArgs) (graphqlbackend.EmbeddingsSearchResultsResolver, error) {
|
||||
if !conf.EmbeddingsEnabled() {
|
||||
return nil, errors.New("embeddings are not configured or disabled")
|
||||
}
|
||||
@ -76,28 +62,19 @@ func (r *Resolver) EmbeddingsMultiSearch(ctx context.Context, args graphqlbacken
|
||||
return nil, err
|
||||
}
|
||||
|
||||
repoIDs := make([]api.RepoID, len(args.Repos))
|
||||
for i, repo := range args.Repos {
|
||||
repoID, err := graphqlbackend.UnmarshalRepositoryID(repo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
repoIDs[i] = repoID
|
||||
}
|
||||
|
||||
repos, err := r.db.Repos().GetByIDs(ctx, repoIDs...)
|
||||
repoID, err := graphqlbackend.UnmarshalRepositoryID(args.Repo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
repoNames := make([]api.RepoName, len(repos))
|
||||
for i, repo := range repos {
|
||||
repoNames[i] = repo.Name
|
||||
repo, err := r.db.Repos().Get(ctx, repoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
results, err := r.embeddingsClient.Search(ctx, embeddings.EmbeddingsSearchParameters{
|
||||
RepoNames: repoNames,
|
||||
RepoIDs: repoIDs,
|
||||
RepoName: repo.Name,
|
||||
RepoID: repoID,
|
||||
Query: args.Query,
|
||||
CodeResultsCount: int(args.CodeResultsCount),
|
||||
TextResultsCount: int(args.TextResultsCount),
|
||||
@ -106,11 +83,7 @@ func (r *Resolver) EmbeddingsMultiSearch(ctx context.Context, args graphqlbacken
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &embeddingsSearchResultsResolver{
|
||||
results: results,
|
||||
gitserver: r.gitserverClient,
|
||||
logger: r.logger,
|
||||
}, nil
|
||||
return &embeddingsSearchResultsResolver{results}, nil
|
||||
}
|
||||
|
||||
func (r *Resolver) IsContextRequiredForChatQuery(ctx context.Context, args graphqlbackend.IsContextRequiredForChatQueryInputArgs) (bool, error) {
|
||||
@ -210,91 +183,27 @@ func (r *Resolver) ScheduleContextDetectionForEmbedding(ctx context.Context) (*g
|
||||
}
|
||||
|
||||
type embeddingsSearchResultsResolver struct {
|
||||
results *embeddings.EmbeddingCombinedSearchResults
|
||||
gitserver gitserver.Client
|
||||
logger log.Logger
|
||||
results *embeddings.EmbeddingSearchResults
|
||||
}
|
||||
|
||||
func (r *embeddingsSearchResultsResolver) CodeResults(ctx context.Context) ([]graphqlbackend.EmbeddingsSearchResultResolver, error) {
|
||||
return embeddingsSearchResultsToResolvers(ctx, r.logger, r.gitserver, r.results.CodeResults)
|
||||
}
|
||||
|
||||
func (r *embeddingsSearchResultsResolver) TextResults(ctx context.Context) ([]graphqlbackend.EmbeddingsSearchResultResolver, error) {
|
||||
return embeddingsSearchResultsToResolvers(ctx, r.logger, r.gitserver, r.results.TextResults)
|
||||
}
|
||||
|
||||
func embeddingsSearchResultsToResolvers(
|
||||
ctx context.Context,
|
||||
logger log.Logger,
|
||||
gs gitserver.Client,
|
||||
results []embeddings.EmbeddingSearchResult,
|
||||
) ([]graphqlbackend.EmbeddingsSearchResultResolver, error) {
|
||||
|
||||
allContents := make([][]byte, len(results))
|
||||
allErrors := make([]error, len(results))
|
||||
{ // Fetch contents in parallel because fetching them serially can be slow.
|
||||
p := pool.New().WithMaxGoroutines(8)
|
||||
for i, result := range results {
|
||||
i, result := i, result
|
||||
p.Go(func() {
|
||||
content, err := gs.ReadFile(ctx, authz.DefaultSubRepoPermsChecker, result.RepoName, result.Revision, result.FileName)
|
||||
allContents[i] = content
|
||||
allErrors[i] = err
|
||||
})
|
||||
}
|
||||
p.Wait()
|
||||
func (r *embeddingsSearchResultsResolver) CodeResults(ctx context.Context) []graphqlbackend.EmbeddingsSearchResultResolver {
|
||||
codeResults := make([]graphqlbackend.EmbeddingsSearchResultResolver, len(r.results.CodeResults))
|
||||
for idx, result := range r.results.CodeResults {
|
||||
codeResults[idx] = &embeddingsSearchResultResolver{result}
|
||||
}
|
||||
|
||||
resolvers := make([]graphqlbackend.EmbeddingsSearchResultResolver, 0, len(results))
|
||||
{ // Merge the results with their contents, skipping any that errored when fetching the context.
|
||||
for i, result := range results {
|
||||
contents := allContents[i]
|
||||
err := allErrors[i]
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
logger.Error(
|
||||
"error reading file",
|
||||
log.String("repoName", string(result.RepoName)),
|
||||
log.String("revision", string(result.Revision)),
|
||||
log.String("fileName", result.FileName),
|
||||
log.Error(err),
|
||||
)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
resolvers = append(resolvers, &embeddingsSearchResultResolver{
|
||||
result: result,
|
||||
content: string(extractLineRange(contents, result.StartLine, result.EndLine)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return resolvers, nil
|
||||
return codeResults
|
||||
}
|
||||
|
||||
// extractLineRange returns the lines of content in the half-open range
// [startLine, endLine), joined with "\n". Line numbers are zero-based indexes
// into content split on "\n".
//
// Both bounds are clamped into [0, number of lines], and an inverted range
// (startLine > endLine after clamping) is treated as empty, so the function
// never panics on out-of-range or inconsistent line metadata. An empty range
// yields an empty, non-nil slice.
func extractLineRange(content []byte, startLine, endLine int) []byte {
	lines := bytes.Split(content, []byte("\n"))

	// Sanity check: force a bound into [0, len(lines)].
	bound := func(n int) int {
		if n < 0 {
			return 0
		}
		if n > len(lines) {
			return len(lines)
		}
		return n
	}
	startLine, endLine = bound(startLine), bound(endLine)

	// lines[startLine:endLine] panics when startLine > endLine; collapse an
	// inverted range to an empty one instead.
	if startLine > endLine {
		startLine = endLine
	}

	return bytes.Join(lines[startLine:endLine], []byte("\n"))
}
|
||||
|
||||
func clamp(input, min, max int) int {
|
||||
if input > max {
|
||||
return max
|
||||
} else if input < min {
|
||||
return min
|
||||
func (r *embeddingsSearchResultsResolver) TextResults(ctx context.Context) []graphqlbackend.EmbeddingsSearchResultResolver {
|
||||
textResults := make([]graphqlbackend.EmbeddingsSearchResultResolver, len(r.results.TextResults))
|
||||
for idx, result := range r.results.TextResults {
|
||||
textResults[idx] = &embeddingsSearchResultResolver{result}
|
||||
}
|
||||
return input
|
||||
return textResults
|
||||
}
|
||||
|
||||
type embeddingsSearchResultResolver struct {
|
||||
result embeddings.EmbeddingSearchResult
|
||||
content string
|
||||
result embeddings.EmbeddingSearchResult
|
||||
}
|
||||
|
||||
func (r *embeddingsSearchResultResolver) FileName(ctx context.Context) string {
|
||||
@ -310,5 +219,5 @@ func (r *embeddingsSearchResultResolver) EndLine(ctx context.Context) int32 {
|
||||
}
|
||||
|
||||
func (r *embeddingsSearchResultResolver) Content(ctx context.Context) string {
|
||||
return r.content
|
||||
return r.result.Content
|
||||
}
|
||||
|
||||
@ -1,122 +0,0 @@
|
||||
package resolvers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/graph-gophers/graphql-go"
|
||||
"github.com/sourcegraph/log/logtest"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/contextdetection"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/repo"
|
||||
"github.com/sourcegraph/sourcegraph/internal/actor"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
"github.com/sourcegraph/sourcegraph/internal/authz"
|
||||
"github.com/sourcegraph/sourcegraph/internal/conf"
|
||||
"github.com/sourcegraph/sourcegraph/internal/database"
|
||||
"github.com/sourcegraph/sourcegraph/internal/featureflag"
|
||||
"github.com/sourcegraph/sourcegraph/internal/gitserver"
|
||||
"github.com/sourcegraph/sourcegraph/internal/types"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEmbeddingSearchResolver(t *testing.T) {
|
||||
logger := logtest.Scoped(t)
|
||||
|
||||
mockDB := database.NewMockDB()
|
||||
mockRepos := database.NewMockRepoStore()
|
||||
mockRepos.GetByIDsFunc.SetDefaultReturn([]*types.Repo{{ID: 1, Name: "repo1"}}, nil)
|
||||
mockDB.ReposFunc.SetDefaultReturn(mockRepos)
|
||||
|
||||
mockGitserver := gitserver.NewMockClient()
|
||||
mockGitserver.ReadFileFunc.SetDefaultHook(func(_ context.Context, _ authz.SubRepoPermissionChecker, _ api.RepoName, _ api.CommitID, fileName string) ([]byte, error) {
|
||||
if fileName == "testfile" {
|
||||
return []byte("test\nfirst\nfour\nlines\nplus\nsome\nmore"), nil
|
||||
}
|
||||
return nil, os.ErrNotExist
|
||||
})
|
||||
|
||||
mockEmbeddingsClient := embeddings.NewMockClient()
|
||||
mockEmbeddingsClient.SearchFunc.SetDefaultReturn(&embeddings.EmbeddingCombinedSearchResults{
|
||||
CodeResults: embeddings.EmbeddingSearchResults{{
|
||||
FileName: "testfile",
|
||||
StartLine: 0,
|
||||
EndLine: 4,
|
||||
}, {
|
||||
FileName: "censored",
|
||||
StartLine: 0,
|
||||
EndLine: 4,
|
||||
}},
|
||||
}, nil)
|
||||
|
||||
repoEmbeddingJobsStore := repo.NewMockRepoEmbeddingJobsStore()
|
||||
contextDetectionJobsStore := contextdetection.NewMockContextDetectionEmbeddingJobsStore()
|
||||
|
||||
resolver := NewResolver(
|
||||
mockDB,
|
||||
logger,
|
||||
mockGitserver,
|
||||
mockEmbeddingsClient,
|
||||
repoEmbeddingJobsStore,
|
||||
contextDetectionJobsStore,
|
||||
)
|
||||
|
||||
conf.Mock(&conf.Unified{
|
||||
SiteConfiguration: schema.SiteConfiguration{
|
||||
Embeddings: &schema.Embeddings{Enabled: true},
|
||||
Completions: &schema.Completions{Enabled: true},
|
||||
},
|
||||
})
|
||||
|
||||
ctx := actor.WithActor(context.Background(), actor.FromMockUser(1))
|
||||
ffs := featureflag.NewMemoryStore(map[string]bool{"cody-experimental": true}, nil, nil)
|
||||
ctx = featureflag.WithFlags(ctx, ffs)
|
||||
|
||||
results, err := resolver.EmbeddingsMultiSearch(ctx, graphqlbackend.EmbeddingsMultiSearchInputArgs{
|
||||
Repos: []graphql.ID{graphqlbackend.MarshalRepositoryID(3)},
|
||||
Query: "test",
|
||||
CodeResultsCount: 1,
|
||||
TextResultsCount: 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
codeResults, err := results.CodeResults(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, codeResults, 1)
|
||||
require.Equal(t, "test\nfirst\nfour\nlines", codeResults[0].Content(ctx))
|
||||
|
||||
}
|
||||
|
||||
func Test_extractLineRange(t *testing.T) {
|
||||
cases := []struct {
|
||||
input []byte
|
||||
start, end int
|
||||
output []byte
|
||||
}{{
|
||||
[]byte("zero\none\ntwo\nthree\n"),
|
||||
0, 2,
|
||||
[]byte("zero\none"),
|
||||
}, {
|
||||
[]byte("zero\none\ntwo\nthree\n"),
|
||||
1, 2,
|
||||
[]byte("one"),
|
||||
}, {
|
||||
[]byte("zero\none\ntwo\nthree\n"),
|
||||
1, 2,
|
||||
[]byte("one"),
|
||||
}, {
|
||||
[]byte(""),
|
||||
1, 2,
|
||||
[]byte(""),
|
||||
}}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
got := extractLineRange(tc.input, tc.start, tc.end)
|
||||
require.Equal(t, tc.output, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
2
enterprise/internal/embeddings/BUILD.bazel
generated
2
enterprise/internal/embeddings/BUILD.bazel
generated
@ -12,7 +12,6 @@ go_library(
|
||||
"dot_portable.go",
|
||||
"index_name.go",
|
||||
"index_storage.go",
|
||||
"mocks_temp.go",
|
||||
"quantize.go",
|
||||
"similarity_search.go",
|
||||
"tokens.go",
|
||||
@ -32,7 +31,6 @@ go_library(
|
||||
"//internal/uploadstore",
|
||||
"//lib/errors",
|
||||
"@com_github_sourcegraph_conc//:conc",
|
||||
"@com_github_sourcegraph_conc//pool",
|
||||
"@com_github_sourcegraph_log//:log",
|
||||
"@org_golang_x_sync//errgroup",
|
||||
] + select({
|
||||
|
||||
@ -3,7 +3,6 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
|
||||
go_library(
|
||||
name = "contextdetection",
|
||||
srcs = [
|
||||
"mocks_temp.go",
|
||||
"store.go",
|
||||
"types.go",
|
||||
],
|
||||
|
||||
@ -1,292 +0,0 @@
|
||||
// Code generated by go-mockgen 1.3.7; DO NOT EDIT.
|
||||
//
|
||||
// This file was generated by running `sg generate` (or `go-mockgen`) at the root of
|
||||
// this repository. To add additional mocks to this or another package, add a new entry
|
||||
// to the mockgen.yaml file in the root of this repository.
|
||||
|
||||
package contextdetection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
basestore "github.com/sourcegraph/sourcegraph/internal/database/basestore"
|
||||
)
|
||||
|
||||
// MockContextDetectionEmbeddingJobsStore is a mock implementation of the
|
||||
// ContextDetectionEmbeddingJobsStore interface (from the package
|
||||
// github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/contextdetection)
|
||||
// used for unit testing.
|
||||
type MockContextDetectionEmbeddingJobsStore struct {
|
||||
// CreateContextDetectionEmbeddingJobFunc is an instance of a mock
|
||||
// function object controlling the behavior of the method
|
||||
// CreateContextDetectionEmbeddingJob.
|
||||
CreateContextDetectionEmbeddingJobFunc *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc
|
||||
// HandleFunc is an instance of a mock function object controlling the
|
||||
// behavior of the method Handle.
|
||||
HandleFunc *ContextDetectionEmbeddingJobsStoreHandleFunc
|
||||
}
|
||||
|
||||
// NewMockContextDetectionEmbeddingJobsStore creates a new mock of the
|
||||
// ContextDetectionEmbeddingJobsStore interface. All methods return zero
|
||||
// values for all results, unless overwritten.
|
||||
func NewMockContextDetectionEmbeddingJobsStore() *MockContextDetectionEmbeddingJobsStore {
|
||||
return &MockContextDetectionEmbeddingJobsStore{
|
||||
CreateContextDetectionEmbeddingJobFunc: &ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc{
|
||||
defaultHook: func(context.Context) (r0 int, r1 error) {
|
||||
return
|
||||
},
|
||||
},
|
||||
HandleFunc: &ContextDetectionEmbeddingJobsStoreHandleFunc{
|
||||
defaultHook: func() (r0 basestore.TransactableHandle) {
|
||||
return
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewStrictMockContextDetectionEmbeddingJobsStore creates a new mock of the
|
||||
// ContextDetectionEmbeddingJobsStore interface. All methods panic on
|
||||
// invocation, unless overwritten.
|
||||
func NewStrictMockContextDetectionEmbeddingJobsStore() *MockContextDetectionEmbeddingJobsStore {
|
||||
return &MockContextDetectionEmbeddingJobsStore{
|
||||
CreateContextDetectionEmbeddingJobFunc: &ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc{
|
||||
defaultHook: func(context.Context) (int, error) {
|
||||
panic("unexpected invocation of MockContextDetectionEmbeddingJobsStore.CreateContextDetectionEmbeddingJob")
|
||||
},
|
||||
},
|
||||
HandleFunc: &ContextDetectionEmbeddingJobsStoreHandleFunc{
|
||||
defaultHook: func() basestore.TransactableHandle {
|
||||
panic("unexpected invocation of MockContextDetectionEmbeddingJobsStore.Handle")
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewMockContextDetectionEmbeddingJobsStoreFrom creates a new mock of the
|
||||
// MockContextDetectionEmbeddingJobsStore interface. All methods delegate to
|
||||
// the given implementation, unless overwritten.
|
||||
func NewMockContextDetectionEmbeddingJobsStoreFrom(i ContextDetectionEmbeddingJobsStore) *MockContextDetectionEmbeddingJobsStore {
|
||||
return &MockContextDetectionEmbeddingJobsStore{
|
||||
CreateContextDetectionEmbeddingJobFunc: &ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc{
|
||||
defaultHook: i.CreateContextDetectionEmbeddingJob,
|
||||
},
|
||||
HandleFunc: &ContextDetectionEmbeddingJobsStoreHandleFunc{
|
||||
defaultHook: i.Handle,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc
|
||||
// describes the behavior when the CreateContextDetectionEmbeddingJob method
|
||||
// of the parent MockContextDetectionEmbeddingJobsStore instance is invoked.
|
||||
type ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc struct {
|
||||
defaultHook func(context.Context) (int, error)
|
||||
hooks []func(context.Context) (int, error)
|
||||
history []ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// CreateContextDetectionEmbeddingJob delegates to the next hook function in
|
||||
// the queue and stores the parameter and result values of this invocation.
|
||||
func (m *MockContextDetectionEmbeddingJobsStore) CreateContextDetectionEmbeddingJob(v0 context.Context) (int, error) {
|
||||
r0, r1 := m.CreateContextDetectionEmbeddingJobFunc.nextHook()(v0)
|
||||
m.CreateContextDetectionEmbeddingJobFunc.appendCall(ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall{v0, r0, r1})
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// SetDefaultHook sets function that is called when the
|
||||
// CreateContextDetectionEmbeddingJob method of the parent
|
||||
// MockContextDetectionEmbeddingJobsStore instance is invoked and the hook
|
||||
// queue is empty.
|
||||
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) SetDefaultHook(hook func(context.Context) (int, error)) {
|
||||
f.defaultHook = hook
|
||||
}
|
||||
|
||||
// PushHook adds a function to the end of hook queue. Each invocation of the
|
||||
// CreateContextDetectionEmbeddingJob method of the parent
|
||||
// MockContextDetectionEmbeddingJobsStore instance invokes the hook at the
|
||||
// front of the queue and discards it. After the queue is empty, the default
|
||||
// hook function is invoked for any future action.
|
||||
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) PushHook(hook func(context.Context) (int, error)) {
|
||||
f.mutex.Lock()
|
||||
f.hooks = append(f.hooks, hook)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// SetDefaultReturn calls SetDefaultHook with a function that returns the
|
||||
// given values.
|
||||
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) SetDefaultReturn(r0 int, r1 error) {
|
||||
f.SetDefaultHook(func(context.Context) (int, error) {
|
||||
return r0, r1
|
||||
})
|
||||
}
|
||||
|
||||
// PushReturn calls PushHook with a function that returns the given values.
|
||||
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) PushReturn(r0 int, r1 error) {
|
||||
f.PushHook(func(context.Context) (int, error) {
|
||||
return r0, r1
|
||||
})
|
||||
}
|
||||
|
||||
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) nextHook() func(context.Context) (int, error) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
if len(f.hooks) == 0 {
|
||||
return f.defaultHook
|
||||
}
|
||||
|
||||
hook := f.hooks[0]
|
||||
f.hooks = f.hooks[1:]
|
||||
return hook
|
||||
}
|
||||
|
||||
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) appendCall(r0 ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall) {
|
||||
f.mutex.Lock()
|
||||
f.history = append(f.history, r0)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// History returns a sequence of
|
||||
// ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall
|
||||
// objects describing the invocations of this function.
|
||||
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) History() []ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall {
|
||||
f.mutex.Lock()
|
||||
history := make([]ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall, len(f.history))
|
||||
copy(history, f.history)
|
||||
f.mutex.Unlock()
|
||||
|
||||
return history
|
||||
}
|
||||
|
||||
// ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall
|
||||
// is an object that describes an invocation of method
|
||||
// CreateContextDetectionEmbeddingJob on an instance of
|
||||
// MockContextDetectionEmbeddingJobsStore.
|
||||
type ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall struct {
|
||||
// Arg0 is the value of the 1st argument passed to this method
|
||||
// invocation.
|
||||
Arg0 context.Context
|
||||
// Result0 is the value of the 1st result returned from this method
|
||||
// invocation.
|
||||
Result0 int
|
||||
// Result1 is the value of the 2nd result returned from this method
|
||||
// invocation.
|
||||
Result1 error
|
||||
}
|
||||
|
||||
// Args returns an interface slice containing the arguments of this
|
||||
// invocation.
|
||||
func (c ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall) Args() []interface{} {
|
||||
return []interface{}{c.Arg0}
|
||||
}
|
||||
|
||||
// Results returns an interface slice containing the results of this
|
||||
// invocation.
|
||||
func (c ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall) Results() []interface{} {
|
||||
return []interface{}{c.Result0, c.Result1}
|
||||
}
|
||||
|
||||
// ContextDetectionEmbeddingJobsStoreHandleFunc describes the behavior when
|
||||
// the Handle method of the parent MockContextDetectionEmbeddingJobsStore
|
||||
// instance is invoked.
|
||||
type ContextDetectionEmbeddingJobsStoreHandleFunc struct {
|
||||
defaultHook func() basestore.TransactableHandle
|
||||
hooks []func() basestore.TransactableHandle
|
||||
history []ContextDetectionEmbeddingJobsStoreHandleFuncCall
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// Handle delegates to the next hook function in the queue and stores the
|
||||
// parameter and result values of this invocation.
|
||||
func (m *MockContextDetectionEmbeddingJobsStore) Handle() basestore.TransactableHandle {
|
||||
r0 := m.HandleFunc.nextHook()()
|
||||
m.HandleFunc.appendCall(ContextDetectionEmbeddingJobsStoreHandleFuncCall{r0})
|
||||
return r0
|
||||
}
|
||||
|
||||
// SetDefaultHook sets function that is called when the Handle method of the
|
||||
// parent MockContextDetectionEmbeddingJobsStore instance is invoked and the
|
||||
// hook queue is empty.
|
||||
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) SetDefaultHook(hook func() basestore.TransactableHandle) {
|
||||
f.defaultHook = hook
|
||||
}
|
||||
|
||||
// PushHook adds a function to the end of hook queue. Each invocation of the
|
||||
// Handle method of the parent MockContextDetectionEmbeddingJobsStore
|
||||
// instance invokes the hook at the front of the queue and discards it.
|
||||
// After the queue is empty, the default hook function is invoked for any
|
||||
// future action.
|
||||
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) PushHook(hook func() basestore.TransactableHandle) {
|
||||
f.mutex.Lock()
|
||||
f.hooks = append(f.hooks, hook)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// SetDefaultReturn calls SetDefaultHook with a function that returns the
|
||||
// given values.
|
||||
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) SetDefaultReturn(r0 basestore.TransactableHandle) {
|
||||
f.SetDefaultHook(func() basestore.TransactableHandle {
|
||||
return r0
|
||||
})
|
||||
}
|
||||
|
||||
// PushReturn calls PushHook with a function that returns the given values.
|
||||
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) PushReturn(r0 basestore.TransactableHandle) {
|
||||
f.PushHook(func() basestore.TransactableHandle {
|
||||
return r0
|
||||
})
|
||||
}
|
||||
|
||||
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) nextHook() func() basestore.TransactableHandle {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
if len(f.hooks) == 0 {
|
||||
return f.defaultHook
|
||||
}
|
||||
|
||||
hook := f.hooks[0]
|
||||
f.hooks = f.hooks[1:]
|
||||
return hook
|
||||
}
|
||||
|
||||
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) appendCall(r0 ContextDetectionEmbeddingJobsStoreHandleFuncCall) {
|
||||
f.mutex.Lock()
|
||||
f.history = append(f.history, r0)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// History returns a sequence of
|
||||
// ContextDetectionEmbeddingJobsStoreHandleFuncCall objects describing the
|
||||
// invocations of this function.
|
||||
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) History() []ContextDetectionEmbeddingJobsStoreHandleFuncCall {
|
||||
f.mutex.Lock()
|
||||
history := make([]ContextDetectionEmbeddingJobsStoreHandleFuncCall, len(f.history))
|
||||
copy(history, f.history)
|
||||
f.mutex.Unlock()
|
||||
|
||||
return history
|
||||
}
|
||||
|
||||
// ContextDetectionEmbeddingJobsStoreHandleFuncCall is an object that
|
||||
// describes an invocation of method Handle on an instance of
|
||||
// MockContextDetectionEmbeddingJobsStore.
|
||||
type ContextDetectionEmbeddingJobsStoreHandleFuncCall struct {
|
||||
// Result0 is the value of the 1st result returned from this method
|
||||
// invocation.
|
||||
Result0 basestore.TransactableHandle
|
||||
}
|
||||
|
||||
// Args returns an interface slice containing the arguments of this
|
||||
// invocation.
|
||||
func (c ContextDetectionEmbeddingJobsStoreHandleFuncCall) Args() []interface{} {
|
||||
return []interface{}{}
|
||||
}
|
||||
|
||||
// Results returns an interface slice containing the results of this
|
||||
// invocation.
|
||||
func (c ContextDetectionEmbeddingJobsStoreHandleFuncCall) Results() []interface{} {
|
||||
return []interface{}{c.Result0}
|
||||
}
|
||||
@ -4,12 +4,10 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/sourcegraph/conc/pool"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
@ -32,23 +30,14 @@ var defaultDoer = func() httpcli.Doer {
|
||||
return d
|
||||
}()
|
||||
|
||||
func NewDefaultClient() Client {
|
||||
return NewClient(defaultEndpoints(), defaultDoer)
|
||||
}
|
||||
|
||||
func NewClient(endpoints *endpoint.Map, doer httpcli.Doer) Client {
|
||||
return &client{
|
||||
Endpoints: endpoints,
|
||||
HTTPClient: doer,
|
||||
func NewClient() *Client {
|
||||
return &Client{
|
||||
Endpoints: defaultEndpoints(),
|
||||
HTTPClient: defaultDoer,
|
||||
}
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
Search(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error)
|
||||
IsContextRequiredForChatQuery(context.Context, IsContextRequiredForChatQueryParameters) (bool, error)
|
||||
}
|
||||
|
||||
type client struct {
|
||||
type Client struct {
|
||||
// Endpoints to embeddings service.
|
||||
Endpoints *endpoint.Map
|
||||
|
||||
@ -57,13 +46,15 @@ type client struct {
|
||||
}
|
||||
|
||||
type EmbeddingsSearchParameters struct {
|
||||
RepoNames []api.RepoName `json:"repoNames"`
|
||||
RepoIDs []api.RepoID `json:"repoIDs"`
|
||||
Query string `json:"query"`
|
||||
CodeResultsCount int `json:"codeResultsCount"`
|
||||
TextResultsCount int `json:"textResultsCount"`
|
||||
RepoName api.RepoName `json:"repoName"`
|
||||
RepoID api.RepoID `json:"repoID"`
|
||||
Query string `json:"query"`
|
||||
CodeResultsCount int `json:"codeResultsCount"`
|
||||
TextResultsCount int `json:"textResultsCount"`
|
||||
|
||||
UseDocumentRanks bool `json:"useDocumentRanks"`
|
||||
// If set to "True", EmbeddingSearchResult.Debug will contain useful information about scoring.
|
||||
Debug bool `json:"debug"`
|
||||
}
|
||||
|
||||
type IsContextRequiredForChatQueryParameters struct {
|
||||
@ -74,43 +65,8 @@ type IsContextRequiredForChatQueryResult struct {
|
||||
IsRequired bool `json:"isRequired"`
|
||||
}
|
||||
|
||||
func (c *client) Search(ctx context.Context, args EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
|
||||
partitions, err := c.partition(args.RepoNames, args.RepoIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p := pool.NewWithResults[*EmbeddingCombinedSearchResults]().WithContext(ctx)
|
||||
|
||||
for endpoint, partition := range partitions {
|
||||
endpoint := endpoint
|
||||
|
||||
// make a copy for this request
|
||||
args := args
|
||||
args.RepoNames = partition.repoNames
|
||||
args.RepoIDs = partition.repoIDs
|
||||
|
||||
p.Go(func(ctx context.Context) (*EmbeddingCombinedSearchResults, error) {
|
||||
return c.searchPartition(ctx, endpoint, args)
|
||||
})
|
||||
}
|
||||
|
||||
allResults, err := p.Wait()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var combinedResult EmbeddingCombinedSearchResults
|
||||
for _, result := range allResults {
|
||||
combinedResult.CodeResults.MergeTruncate(result.CodeResults, args.CodeResultsCount)
|
||||
combinedResult.TextResults.MergeTruncate(result.TextResults, args.TextResultsCount)
|
||||
}
|
||||
|
||||
return &combinedResult, nil
|
||||
}
|
||||
|
||||
func (c *client) searchPartition(ctx context.Context, endpoint string, args EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
|
||||
resp, err := c.httpPost(ctx, "search", endpoint, args)
|
||||
func (c *Client) Search(ctx context.Context, args EmbeddingsSearchParameters) (*EmbeddingSearchResults, error) {
|
||||
resp, err := c.httpPost(ctx, "search", args.RepoName, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -126,16 +82,15 @@ func (c *client) searchPartition(ctx context.Context, endpoint string, args Embe
|
||||
)
|
||||
}
|
||||
|
||||
var response EmbeddingCombinedSearchResults
|
||||
var response EmbeddingSearchResults
|
||||
err = json.NewDecoder(resp.Body).Decode(&response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Printf("%s, %#v\n", endpoint, response)
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func (c *client) IsContextRequiredForChatQuery(ctx context.Context, args IsContextRequiredForChatQueryParameters) (bool, error) {
|
||||
func (c *Client) IsContextRequiredForChatQuery(ctx context.Context, args IsContextRequiredForChatQueryParameters) (bool, error) {
|
||||
resp, err := c.httpPost(ctx, "isContextRequiredForChatQuery", "", args)
|
||||
if err != nil {
|
||||
return false, err
|
||||
@ -160,50 +115,24 @@ func (c *client) IsContextRequiredForChatQuery(ctx context.Context, args IsConte
|
||||
return response.IsRequired, nil
|
||||
}
|
||||
|
||||
func (c *client) url(repo api.RepoName) (string, error) {
|
||||
func (c *Client) url(repo api.RepoName) (string, error) {
|
||||
if c.Endpoints == nil {
|
||||
return "", errors.New("an embeddings service has not been configured")
|
||||
}
|
||||
return c.Endpoints.Get(string(repo))
|
||||
}
|
||||
|
||||
type repoPartition struct {
|
||||
repoNames []api.RepoName
|
||||
repoIDs []api.RepoID
|
||||
}
|
||||
|
||||
// returns a partition of the input repos by the endpoint their requests should be routed to
|
||||
func (c *client) partition(repos []api.RepoName, repoIDs []api.RepoID) (map[string]repoPartition, error) {
|
||||
if c.Endpoints == nil {
|
||||
return nil, errors.New("an embeddings service has not been configured")
|
||||
}
|
||||
|
||||
repoStrings := make([]string, len(repos))
|
||||
for i, repo := range repos {
|
||||
repoStrings[i] = string(repo)
|
||||
}
|
||||
|
||||
endpoints, err := c.Endpoints.GetMany(repoStrings...)
|
||||
func (c *Client) httpPost(
|
||||
ctx context.Context,
|
||||
method string,
|
||||
repo api.RepoName,
|
||||
payload any,
|
||||
) (resp *http.Response, err error) {
|
||||
url, err := c.url(repo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := make(map[string]repoPartition)
|
||||
for i, endpoint := range endpoints {
|
||||
res[endpoint] = repoPartition{
|
||||
repoNames: append(res[endpoint].repoNames, repos[i]),
|
||||
repoIDs: append(res[endpoint].repoIDs, repoIDs[i]),
|
||||
}
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (c *client) httpPost(
|
||||
ctx context.Context,
|
||||
method string,
|
||||
url string,
|
||||
payload any,
|
||||
) (resp *http.Response, err error) {
|
||||
reqBody, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@ -1,291 +0,0 @@
|
||||
// Code generated by go-mockgen 1.3.7; DO NOT EDIT.
|
||||
//
|
||||
// This file was generated by running `sg generate` (or `go-mockgen`) at the root of
|
||||
// this repository. To add additional mocks to this or another package, add a new entry
|
||||
// to the mockgen.yaml file in the root of this repository.
|
||||
|
||||
package embeddings
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// MockClient is a mock implementation of the Client interface (from the
|
||||
// package
|
||||
// github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings) used
|
||||
// for unit testing.
|
||||
type MockClient struct {
|
||||
// IsContextRequiredForChatQueryFunc is an instance of a mock function
|
||||
// object controlling the behavior of the method
|
||||
// IsContextRequiredForChatQuery.
|
||||
IsContextRequiredForChatQueryFunc *ClientIsContextRequiredForChatQueryFunc
|
||||
// SearchFunc is an instance of a mock function object controlling the
|
||||
// behavior of the method Search.
|
||||
SearchFunc *ClientSearchFunc
|
||||
}
|
||||
|
||||
// NewMockClient creates a new mock of the Client interface. All methods
|
||||
// return zero values for all results, unless overwritten.
|
||||
func NewMockClient() *MockClient {
|
||||
return &MockClient{
|
||||
IsContextRequiredForChatQueryFunc: &ClientIsContextRequiredForChatQueryFunc{
|
||||
defaultHook: func(context.Context, IsContextRequiredForChatQueryParameters) (r0 bool, r1 error) {
|
||||
return
|
||||
},
|
||||
},
|
||||
SearchFunc: &ClientSearchFunc{
|
||||
defaultHook: func(context.Context, EmbeddingsSearchParameters) (r0 *EmbeddingCombinedSearchResults, r1 error) {
|
||||
return
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewStrictMockClient creates a new mock of the Client interface. All
|
||||
// methods panic on invocation, unless overwritten.
|
||||
func NewStrictMockClient() *MockClient {
|
||||
return &MockClient{
|
||||
IsContextRequiredForChatQueryFunc: &ClientIsContextRequiredForChatQueryFunc{
|
||||
defaultHook: func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error) {
|
||||
panic("unexpected invocation of MockClient.IsContextRequiredForChatQuery")
|
||||
},
|
||||
},
|
||||
SearchFunc: &ClientSearchFunc{
|
||||
defaultHook: func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
|
||||
panic("unexpected invocation of MockClient.Search")
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewMockClientFrom creates a new mock of the MockClient interface. All
|
||||
// methods delegate to the given implementation, unless overwritten.
|
||||
func NewMockClientFrom(i Client) *MockClient {
|
||||
return &MockClient{
|
||||
IsContextRequiredForChatQueryFunc: &ClientIsContextRequiredForChatQueryFunc{
|
||||
defaultHook: i.IsContextRequiredForChatQuery,
|
||||
},
|
||||
SearchFunc: &ClientSearchFunc{
|
||||
defaultHook: i.Search,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ClientIsContextRequiredForChatQueryFunc describes the behavior when the
|
||||
// IsContextRequiredForChatQuery method of the parent MockClient instance is
|
||||
// invoked.
|
||||
type ClientIsContextRequiredForChatQueryFunc struct {
|
||||
defaultHook func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error)
|
||||
hooks []func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error)
|
||||
history []ClientIsContextRequiredForChatQueryFuncCall
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// IsContextRequiredForChatQuery delegates to the next hook function in the
|
||||
// queue and stores the parameter and result values of this invocation.
|
||||
func (m *MockClient) IsContextRequiredForChatQuery(v0 context.Context, v1 IsContextRequiredForChatQueryParameters) (bool, error) {
|
||||
r0, r1 := m.IsContextRequiredForChatQueryFunc.nextHook()(v0, v1)
|
||||
m.IsContextRequiredForChatQueryFunc.appendCall(ClientIsContextRequiredForChatQueryFuncCall{v0, v1, r0, r1})
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// SetDefaultHook sets function that is called when the
|
||||
// IsContextRequiredForChatQuery method of the parent MockClient instance is
|
||||
// invoked and the hook queue is empty.
|
||||
func (f *ClientIsContextRequiredForChatQueryFunc) SetDefaultHook(hook func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error)) {
|
||||
f.defaultHook = hook
|
||||
}
|
||||
|
||||
// PushHook adds a function to the end of hook queue. Each invocation of the
|
||||
// IsContextRequiredForChatQuery method of the parent MockClient instance
|
||||
// invokes the hook at the front of the queue and discards it. After the
|
||||
// queue is empty, the default hook function is invoked for any future
|
||||
// action.
|
||||
func (f *ClientIsContextRequiredForChatQueryFunc) PushHook(hook func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error)) {
|
||||
f.mutex.Lock()
|
||||
f.hooks = append(f.hooks, hook)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// SetDefaultReturn calls SetDefaultHook with a function that returns the
|
||||
// given values.
|
||||
func (f *ClientIsContextRequiredForChatQueryFunc) SetDefaultReturn(r0 bool, r1 error) {
|
||||
f.SetDefaultHook(func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error) {
|
||||
return r0, r1
|
||||
})
|
||||
}
|
||||
|
||||
// PushReturn calls PushHook with a function that returns the given values.
|
||||
func (f *ClientIsContextRequiredForChatQueryFunc) PushReturn(r0 bool, r1 error) {
|
||||
f.PushHook(func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error) {
|
||||
return r0, r1
|
||||
})
|
||||
}
|
||||
|
||||
func (f *ClientIsContextRequiredForChatQueryFunc) nextHook() func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
if len(f.hooks) == 0 {
|
||||
return f.defaultHook
|
||||
}
|
||||
|
||||
hook := f.hooks[0]
|
||||
f.hooks = f.hooks[1:]
|
||||
return hook
|
||||
}
|
||||
|
||||
func (f *ClientIsContextRequiredForChatQueryFunc) appendCall(r0 ClientIsContextRequiredForChatQueryFuncCall) {
|
||||
f.mutex.Lock()
|
||||
f.history = append(f.history, r0)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// History returns a sequence of ClientIsContextRequiredForChatQueryFuncCall
|
||||
// objects describing the invocations of this function.
|
||||
func (f *ClientIsContextRequiredForChatQueryFunc) History() []ClientIsContextRequiredForChatQueryFuncCall {
|
||||
f.mutex.Lock()
|
||||
history := make([]ClientIsContextRequiredForChatQueryFuncCall, len(f.history))
|
||||
copy(history, f.history)
|
||||
f.mutex.Unlock()
|
||||
|
||||
return history
|
||||
}
|
||||
|
||||
// ClientIsContextRequiredForChatQueryFuncCall is an object that describes
|
||||
// an invocation of method IsContextRequiredForChatQuery on an instance of
|
||||
// MockClient.
|
||||
type ClientIsContextRequiredForChatQueryFuncCall struct {
|
||||
// Arg0 is the value of the 1st argument passed to this method
|
||||
// invocation.
|
||||
Arg0 context.Context
|
||||
// Arg1 is the value of the 2nd argument passed to this method
|
||||
// invocation.
|
||||
Arg1 IsContextRequiredForChatQueryParameters
|
||||
// Result0 is the value of the 1st result returned from this method
|
||||
// invocation.
|
||||
Result0 bool
|
||||
// Result1 is the value of the 2nd result returned from this method
|
||||
// invocation.
|
||||
Result1 error
|
||||
}
|
||||
|
||||
// Args returns an interface slice containing the arguments of this
|
||||
// invocation.
|
||||
func (c ClientIsContextRequiredForChatQueryFuncCall) Args() []interface{} {
|
||||
return []interface{}{c.Arg0, c.Arg1}
|
||||
}
|
||||
|
||||
// Results returns an interface slice containing the results of this
|
||||
// invocation.
|
||||
func (c ClientIsContextRequiredForChatQueryFuncCall) Results() []interface{} {
|
||||
return []interface{}{c.Result0, c.Result1}
|
||||
}
|
||||
|
||||
// ClientSearchFunc describes the behavior when the Search method of the
|
||||
// parent MockClient instance is invoked.
|
||||
type ClientSearchFunc struct {
|
||||
defaultHook func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error)
|
||||
hooks []func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error)
|
||||
history []ClientSearchFuncCall
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// Search delegates to the next hook function in the queue and stores the
|
||||
// parameter and result values of this invocation.
|
||||
func (m *MockClient) Search(v0 context.Context, v1 EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
|
||||
r0, r1 := m.SearchFunc.nextHook()(v0, v1)
|
||||
m.SearchFunc.appendCall(ClientSearchFuncCall{v0, v1, r0, r1})
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// SetDefaultHook sets function that is called when the Search method of the
|
||||
// parent MockClient instance is invoked and the hook queue is empty.
|
||||
func (f *ClientSearchFunc) SetDefaultHook(hook func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error)) {
|
||||
f.defaultHook = hook
|
||||
}
|
||||
|
||||
// PushHook adds a function to the end of hook queue. Each invocation of the
|
||||
// Search method of the parent MockClient instance invokes the hook at the
|
||||
// front of the queue and discards it. After the queue is empty, the default
|
||||
// hook function is invoked for any future action.
|
||||
func (f *ClientSearchFunc) PushHook(hook func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error)) {
|
||||
f.mutex.Lock()
|
||||
f.hooks = append(f.hooks, hook)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// SetDefaultReturn calls SetDefaultHook with a function that returns the
|
||||
// given values.
|
||||
func (f *ClientSearchFunc) SetDefaultReturn(r0 *EmbeddingCombinedSearchResults, r1 error) {
|
||||
f.SetDefaultHook(func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
|
||||
return r0, r1
|
||||
})
|
||||
}
|
||||
|
||||
// PushReturn calls PushHook with a function that returns the given values.
|
||||
func (f *ClientSearchFunc) PushReturn(r0 *EmbeddingCombinedSearchResults, r1 error) {
|
||||
f.PushHook(func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
|
||||
return r0, r1
|
||||
})
|
||||
}
|
||||
|
||||
func (f *ClientSearchFunc) nextHook() func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
if len(f.hooks) == 0 {
|
||||
return f.defaultHook
|
||||
}
|
||||
|
||||
hook := f.hooks[0]
|
||||
f.hooks = f.hooks[1:]
|
||||
return hook
|
||||
}
|
||||
|
||||
func (f *ClientSearchFunc) appendCall(r0 ClientSearchFuncCall) {
|
||||
f.mutex.Lock()
|
||||
f.history = append(f.history, r0)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// History returns a sequence of ClientSearchFuncCall objects describing the
|
||||
// invocations of this function.
|
||||
func (f *ClientSearchFunc) History() []ClientSearchFuncCall {
|
||||
f.mutex.Lock()
|
||||
history := make([]ClientSearchFuncCall, len(f.history))
|
||||
copy(history, f.history)
|
||||
f.mutex.Unlock()
|
||||
|
||||
return history
|
||||
}
|
||||
|
||||
// ClientSearchFuncCall is an object that describes an invocation of method
|
||||
// Search on an instance of MockClient.
|
||||
type ClientSearchFuncCall struct {
|
||||
// Arg0 is the value of the 1st argument passed to this method
|
||||
// invocation.
|
||||
Arg0 context.Context
|
||||
// Arg1 is the value of the 2nd argument passed to this method
|
||||
// invocation.
|
||||
Arg1 EmbeddingsSearchParameters
|
||||
// Result0 is the value of the 1st result returned from this method
|
||||
// invocation.
|
||||
Result0 *EmbeddingCombinedSearchResults
|
||||
// Result1 is the value of the 2nd result returned from this method
|
||||
// invocation.
|
||||
Result1 error
|
||||
}
|
||||
|
||||
// Args returns an interface slice containing the arguments of this
|
||||
// invocation.
|
||||
func (c ClientSearchFuncCall) Args() []interface{} {
|
||||
return []interface{}{c.Arg0, c.Arg1}
|
||||
}
|
||||
|
||||
// Results returns an interface slice containing the results of this
|
||||
// invocation.
|
||||
func (c ClientSearchFuncCall) Results() []interface{} {
|
||||
return []interface{}{c.Result0, c.Result1}
|
||||
}
|
||||
@ -2,16 +2,17 @@ package embeddings
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
|
||||
"github.com/sourcegraph/conc"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
)
|
||||
|
||||
type nearestNeighbor struct {
|
||||
index int
|
||||
scoreDetails SearchScoreDetails
|
||||
index int
|
||||
score int32
|
||||
debug searchDebugInfo
|
||||
}
|
||||
|
||||
type nearestNeighborsHeap struct {
|
||||
@ -21,7 +22,7 @@ type nearestNeighborsHeap struct {
|
||||
func (nn *nearestNeighborsHeap) Len() int { return len(nn.neighbors) }
|
||||
|
||||
func (nn *nearestNeighborsHeap) Less(i, j int) bool {
|
||||
return nn.neighbors[i].scoreDetails.Score < nn.neighbors[j].scoreDetails.Score
|
||||
return nn.neighbors[i].score < nn.neighbors[j].score
|
||||
}
|
||||
|
||||
func (nn *nearestNeighborsHeap) Swap(i, j int) {
|
||||
@ -80,16 +81,19 @@ type WorkerOptions struct {
|
||||
MinRowsToSplit int
|
||||
}
|
||||
|
||||
type SimilaritySearchResult struct {
|
||||
RepoEmbeddingRowMetadata
|
||||
SimilarityScore int32
|
||||
RankScore int32
|
||||
}
|
||||
|
||||
func (r *SimilaritySearchResult) Score() int32 {
|
||||
return r.SimilarityScore + r.RankScore
|
||||
}
|
||||
|
||||
// SimilaritySearch finds the `nResults` most similar rows to a query vector. It uses the cosine similarity metric.
|
||||
// IMPORTANT: The vectors in the embedding index have to be normalized for similarity search to work correctly.
|
||||
func (index *EmbeddingIndex) SimilaritySearch(
|
||||
query []int8,
|
||||
numResults int,
|
||||
workerOptions WorkerOptions,
|
||||
opts SearchOptions,
|
||||
repoName api.RepoName,
|
||||
revision api.CommitID,
|
||||
) []EmbeddingSearchResult {
|
||||
func (index *EmbeddingIndex) SimilaritySearch(query []int8, numResults int, workerOptions WorkerOptions, opts SearchOptions) []SimilaritySearchResult {
|
||||
if numResults == 0 || len(index.Embeddings) == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -127,20 +131,16 @@ func (index *EmbeddingIndex) SimilaritySearch(
|
||||
}
|
||||
}
|
||||
// And re-sort it according to the score (descending).
|
||||
sort.Slice(neighbors, func(i, j int) bool { return neighbors[i].scoreDetails.Score > neighbors[j].scoreDetails.Score })
|
||||
sort.Slice(neighbors, func(i, j int) bool { return neighbors[i].score > neighbors[j].score })
|
||||
|
||||
// Take top neighbors and return them as results.
|
||||
results := make([]EmbeddingSearchResult, numResults)
|
||||
results := make([]SimilaritySearchResult, numResults)
|
||||
|
||||
for idx := 0; idx < min(numResults, len(neighbors)); idx++ {
|
||||
metadata := index.RowMetadata[neighbors[idx].index]
|
||||
results[idx] = EmbeddingSearchResult{
|
||||
RepoName: repoName,
|
||||
Revision: revision,
|
||||
FileName: metadata.FileName,
|
||||
StartLine: metadata.StartLine,
|
||||
EndLine: metadata.EndLine,
|
||||
ScoreDetails: neighbors[idx].scoreDetails,
|
||||
results[idx] = SimilaritySearchResult{
|
||||
RepoEmbeddingRowMetadata: index.RowMetadata[neighbors[idx].index],
|
||||
SimilarityScore: neighbors[idx].debug.similarity,
|
||||
RankScore: neighbors[idx].debug.rank,
|
||||
}
|
||||
}
|
||||
|
||||
@ -156,17 +156,17 @@ func (index *EmbeddingIndex) partialSimilaritySearch(query []int8, numResults in
|
||||
|
||||
nnHeap := newNearestNeighborsHeap()
|
||||
for i := partialRows.start; i < partialRows.start+numResults; i++ {
|
||||
scoreDetails := index.score(query, i, opts)
|
||||
heap.Push(nnHeap, nearestNeighbor{index: i, scoreDetails: scoreDetails})
|
||||
score, debugInfo := index.score(query, i, opts)
|
||||
heap.Push(nnHeap, nearestNeighbor{index: i, score: score, debug: debugInfo})
|
||||
}
|
||||
|
||||
for i := partialRows.start + numResults; i < partialRows.end; i++ {
|
||||
scoreDetails := index.score(query, i, opts)
|
||||
score, debugInfo := index.score(query, i, opts)
|
||||
// Add row if it has greater similarity than the smallest similarity in the heap.
|
||||
// This way we ensure keep a set of the highest similarities in the heap.
|
||||
if scoreDetails.Score > nnHeap.Peek().scoreDetails.Score {
|
||||
if score > nnHeap.Peek().score {
|
||||
heap.Pop(nnHeap)
|
||||
heap.Push(nnHeap, nearestNeighbor{index: i, scoreDetails: scoreDetails})
|
||||
heap.Push(nnHeap, nearestNeighbor{index: i, score: score, debug: debugInfo})
|
||||
}
|
||||
}
|
||||
|
||||
@ -178,7 +178,7 @@ const (
|
||||
scoreSimilarityWeight int32 = 2
|
||||
)
|
||||
|
||||
func (index *EmbeddingIndex) score(query []int8, i int, opts SearchOptions) SearchScoreDetails {
|
||||
func (index *EmbeddingIndex) score(query []int8, i int, opts SearchOptions) (score int32, debugInfo searchDebugInfo) {
|
||||
similarityScore := scoreSimilarityWeight * Dot(index.Row(i), query)
|
||||
|
||||
// handle missing ranks
|
||||
@ -195,11 +195,20 @@ func (index *EmbeddingIndex) score(query []int8, i int, opts SearchOptions) Sear
|
||||
rankScore = int32(float32(scoreFileRankWeight) * normalizedRank)
|
||||
}
|
||||
|
||||
return SearchScoreDetails{
|
||||
Score: similarityScore + rankScore,
|
||||
SimilarityScore: similarityScore,
|
||||
RankScore: rankScore,
|
||||
return similarityScore + rankScore, searchDebugInfo{similarity: similarityScore, rank: rankScore, enabled: opts.Debug}
|
||||
}
|
||||
|
||||
type searchDebugInfo struct {
|
||||
similarity int32
|
||||
rank int32
|
||||
enabled bool
|
||||
}
|
||||
|
||||
func (i *searchDebugInfo) String() string {
|
||||
if !i.enabled {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("score:%d, similarity:%d, rank:%d", i.similarity+i.rank, i.similarity, i.rank)
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
@ -217,5 +226,6 @@ func max(a, b int) int {
|
||||
}
|
||||
|
||||
type SearchOptions struct {
|
||||
Debug bool
|
||||
UseDocumentRanks bool
|
||||
}
|
||||
|
||||
@ -64,10 +64,10 @@ func TestSimilaritySearch(t *testing.T) {
|
||||
for q := 0; q < numQueries; q++ {
|
||||
t.Run(fmt.Sprintf("find nearest neighbors query=%d numResults=%d numWorkers=%d", q, numResults, numWorkers), func(t *testing.T) {
|
||||
query := queries[q*columnDimension : (q+1)*columnDimension]
|
||||
results := index.SimilaritySearch(query, numResults, WorkerOptions{NumWorkers: numWorkers, MinRowsToSplit: 0}, SearchOptions{}, "", "")
|
||||
results := index.SimilaritySearch(query, numResults, WorkerOptions{NumWorkers: numWorkers, MinRowsToSplit: 0}, SearchOptions{})
|
||||
resultRowNums := make([]int, len(results))
|
||||
for i, r := range results {
|
||||
resultRowNums[i], _ = strconv.Atoi(r.FileName)
|
||||
resultRowNums[i], _ = strconv.Atoi(r.RepoEmbeddingRowMetadata.FileName)
|
||||
}
|
||||
expectedResults := ranks[q]
|
||||
require.Equal(t, expectedResults[:min(numResults, len(expectedResults))], resultRowNums)
|
||||
@ -152,7 +152,7 @@ func BenchmarkSimilaritySearch(b *testing.B) {
|
||||
b.Run(fmt.Sprintf("numWorkers=%d", numWorkers), func(b *testing.B) {
|
||||
start := time.Now()
|
||||
for n := 0; n < b.N; n++ {
|
||||
_ = index.SimilaritySearch(query, numResults, WorkerOptions{NumWorkers: numWorkers}, SearchOptions{}, "", "")
|
||||
_ = index.SimilaritySearch(query, numResults, WorkerOptions{NumWorkers: numWorkers}, SearchOptions{})
|
||||
}
|
||||
m := float64(numRows) * float64(b.N) / time.Since(start).Seconds()
|
||||
b.ReportMetric(m, "embeddings/s")
|
||||
@ -175,11 +175,23 @@ func TestScore(t *testing.T) {
|
||||
}
|
||||
// embeddings[0] = 64, 83, 70,
|
||||
// queries[0:3] = 53, 61, 97,
|
||||
scoreDetails := index.score(queries[0:columnDimension], 0, SearchOptions{UseDocumentRanks: true})
|
||||
score, debugInfo := index.score(queries[0:columnDimension], 0, SearchOptions{Debug: true, UseDocumentRanks: true})
|
||||
|
||||
// Check that the score is correct
|
||||
expectedScore := scoreSimilarityWeight * ((64 * 53) + (83 * 61) + (70 * 97))
|
||||
if math.Abs(float64(scoreDetails.Score-expectedScore)) > 0.0001 {
|
||||
t.Fatalf("Expected score %d, but got %d", expectedScore, scoreDetails.Score)
|
||||
if math.Abs(float64(score-expectedScore)) > 0.0001 {
|
||||
t.Fatalf("Expected score %d, but got %d", expectedScore, score)
|
||||
}
|
||||
|
||||
if debugInfo.String() == "" {
|
||||
t.Fatal("Expected a non-empty debug")
|
||||
}
|
||||
}
|
||||
|
||||
func simpleCosineSimilarity(a, b []int8) int32 {
|
||||
similarity := int32(0)
|
||||
for i := 0; i < len(a); i++ {
|
||||
similarity += int32(a[i]) * int32(b[i])
|
||||
}
|
||||
return similarity
|
||||
}
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
package embeddings
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/sourcegraph/log"
|
||||
@ -47,46 +45,18 @@ type ContextDetectionEmbeddingIndex struct {
|
||||
MessagesWithoutAdditionalContextMeanEmbedding []float32
|
||||
}
|
||||
|
||||
type EmbeddingCombinedSearchResults struct {
|
||||
CodeResults EmbeddingSearchResults `json:"codeResults"`
|
||||
TextResults EmbeddingSearchResults `json:"textResults"`
|
||||
}
|
||||
|
||||
type EmbeddingSearchResults []EmbeddingSearchResult
|
||||
|
||||
// MergeTruncate merges other into the search results, keeping only max results with the highest scores
|
||||
func (esrs *EmbeddingSearchResults) MergeTruncate(other EmbeddingSearchResults, max int) {
|
||||
self := *esrs
|
||||
self = append(self, other...)
|
||||
sort.Slice(self, func(i, j int) bool { return self[i].Score() > self[j].Score() })
|
||||
*esrs = self[:max]
|
||||
type EmbeddingSearchResults struct {
|
||||
CodeResults []EmbeddingSearchResult `json:"codeResults"`
|
||||
TextResults []EmbeddingSearchResult `json:"textResults"`
|
||||
}
|
||||
|
||||
type EmbeddingSearchResult struct {
|
||||
RepoName api.RepoName `json:"repoName"`
|
||||
Revision api.CommitID `json:"revision"`
|
||||
|
||||
FileName string `json:"fileName"`
|
||||
StartLine int `json:"startLine"`
|
||||
EndLine int `json:"endLine"`
|
||||
|
||||
ScoreDetails SearchScoreDetails `json:"scoreDetails"`
|
||||
}
|
||||
|
||||
func (esr *EmbeddingSearchResult) Score() int32 {
|
||||
return esr.ScoreDetails.RankScore + esr.ScoreDetails.SimilarityScore
|
||||
}
|
||||
|
||||
type SearchScoreDetails struct {
|
||||
Score int32 `json:"score"`
|
||||
|
||||
// Breakdown
|
||||
SimilarityScore int32 `json:"similarityScore"`
|
||||
RankScore int32 `json:"rankScore"`
|
||||
}
|
||||
|
||||
func (s *SearchScoreDetails) String() string {
|
||||
return fmt.Sprintf("score:%d, similarity:%d, rank:%d", s.Score, s.SimilarityScore, s.RankScore)
|
||||
RepoName api.RepoName
|
||||
Revision api.CommitID
|
||||
RepoEmbeddingRowMetadata
|
||||
Content string `json:"content"`
|
||||
// Experimental: Clients should not rely on any particular format of debug
|
||||
Debug string `json:"debug,omitempty"`
|
||||
}
|
||||
|
||||
// DEPRECATED: to support decoding old indexes, we need a struct
|
||||
|
||||
@ -123,15 +123,3 @@
|
||||
path: github.com/sourcegraph/sourcegraph/enterprise/internal/github_apps/store
|
||||
interfaces:
|
||||
- GitHubAppsStore
|
||||
- filename: enterprise/internal/embeddings/mocks_temp.go
|
||||
path: github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings
|
||||
interfaces:
|
||||
- Client
|
||||
- filename: enterprise/internal/embeddings/background/repo/mocks_temp.go
|
||||
path: github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/repo
|
||||
interfaces:
|
||||
- RepoEmbeddingJobsStore
|
||||
- filename: enterprise/internal/embeddings/background/contextdetection/mocks_temp.go
|
||||
path: github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/contextdetection
|
||||
interfaces:
|
||||
- ContextDetectionEmbeddingJobsStore
|
||||
|
||||
Loading…
Reference in New Issue
Block a user