mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 16:31:47 +00:00
Embeddings: multi-repo search take 2 (#52019)
This re-opens the reverted #51662. There were three independent issues with the deployed version: 1) There was a stray printf from debugging. It wasn't breaking anything, but it was pretty spammy. 2) I made a change to `httpPost` that made it take an endpoint rather than a repo name, but forgot to update the callsite in `IsContextRequiredForChatQuery`, so it was using an empty string as the URL. This was not caught by tests because the API was not covered by tests before this PR, and I only added tests for the `EmbeddingsSearch` endpoint, not the `IsContextRequiredForChatQuery` endpoint. I've added a new test to this endpoint as part of this un-revert 3) Followup changes in the client to use the new endpoints were using types that did not match the actual response shape from GraphQL and this was causing undefined field accesses. This was not caught by the type system because the problem was at the JSON -> typescript type layer, which is an unchecked type assertion. I squashed the original PR into the first commit, so the second and third commits are the only changes ## Test plan Added a test for `IsContextRequiredForChatQuery` API
This commit is contained in:
parent
9578c413db
commit
78d95e61d0
@ -11,7 +11,7 @@ export class SourcegraphEmbeddingsSearchClient implements EmbeddingsSearch {
|
||||
textResultsCount: number
|
||||
): Promise<EmbeddingsSearchResults | Error> {
|
||||
if (this.web) {
|
||||
return this.client.searchEmbeddings(this.repoId, query, codeResultsCount, textResultsCount)
|
||||
return this.client.searchEmbeddings([this.repoId], query, codeResultsCount, textResultsCount)
|
||||
}
|
||||
|
||||
return this.client.legacySearchEmbeddings(this.repoId, query, codeResultsCount, textResultsCount)
|
||||
|
||||
@ -36,6 +36,10 @@ interface EmbeddingsSearchResponse {
|
||||
embeddingsSearch: EmbeddingsSearchResults
|
||||
}
|
||||
|
||||
interface EmbeddingsMultiSearchResponse {
|
||||
embeddingsMultiSearch: EmbeddingsSearchResults
|
||||
}
|
||||
|
||||
interface LogEventResponse {}
|
||||
|
||||
export interface EmbeddingsSearchResult {
|
||||
@ -144,17 +148,17 @@ export class SourcegraphGraphQLAPIClient {
|
||||
}
|
||||
|
||||
public async searchEmbeddings(
|
||||
repo: string,
|
||||
repos: string[],
|
||||
query: string,
|
||||
codeResultsCount: number,
|
||||
textResultsCount: number
|
||||
): Promise<EmbeddingsSearchResults | Error> {
|
||||
return this.fetchSourcegraphAPI<APIResponse<EmbeddingsSearchResponse>>(SEARCH_EMBEDDINGS_QUERY, {
|
||||
repo,
|
||||
return this.fetchSourcegraphAPI<APIResponse<EmbeddingsMultiSearchResponse>>(SEARCH_EMBEDDINGS_QUERY, {
|
||||
repos,
|
||||
query,
|
||||
codeResultsCount,
|
||||
textResultsCount,
|
||||
}).then(response => extractDataOrError(response, data => data.embeddingsSearch))
|
||||
}).then(response => extractDataOrError(response, data => data.embeddingsMultiSearch))
|
||||
}
|
||||
|
||||
// (Naman): This is a temporary workaround for supporting vscode cody integrated with older version of sourcegraph which do not support the latest searchEmbeddings query.
|
||||
|
||||
@ -21,19 +21,19 @@ query Repository($name: String!) {
|
||||
}`
|
||||
|
||||
export const SEARCH_EMBEDDINGS_QUERY = `
|
||||
query EmbeddingsSearch($repo: ID!, $query: String!, $codeResultsCount: Int!, $textResultsCount: Int!) {
|
||||
embeddingsSearch(repo: $repo, query: $query, codeResultsCount: $codeResultsCount, textResultsCount: $textResultsCount) {
|
||||
query EmbeddingsSearch($repos: [ID!]!, $query: String!, $codeResultsCount: Int!, $textResultsCount: Int!) {
|
||||
embeddingsMultiSearch(repos: $repos, query: $query, codeResultsCount: $codeResultsCount, textResultsCount: $textResultsCount) {
|
||||
codeResults {
|
||||
repoName
|
||||
revision
|
||||
repoName
|
||||
revision
|
||||
fileName
|
||||
startLine
|
||||
endLine
|
||||
content
|
||||
}
|
||||
textResults {
|
||||
repoName
|
||||
revision
|
||||
repoName
|
||||
revision
|
||||
fileName
|
||||
startLine
|
||||
endLine
|
||||
|
||||
@ -11,6 +11,7 @@ import (
|
||||
|
||||
type EmbeddingsResolver interface {
|
||||
EmbeddingsSearch(ctx context.Context, args EmbeddingsSearchInputArgs) (EmbeddingsSearchResultsResolver, error)
|
||||
EmbeddingsMultiSearch(ctx context.Context, args EmbeddingsMultiSearchInputArgs) (EmbeddingsSearchResultsResolver, error)
|
||||
IsContextRequiredForChatQuery(ctx context.Context, args IsContextRequiredForChatQueryInputArgs) (bool, error)
|
||||
RepoEmbeddingJobs(ctx context.Context, args ListRepoEmbeddingJobsArgs) (*graphqlutil.ConnectionResolver[RepoEmbeddingJobResolver], error)
|
||||
|
||||
@ -33,9 +34,16 @@ type EmbeddingsSearchInputArgs struct {
|
||||
TextResultsCount int32
|
||||
}
|
||||
|
||||
type EmbeddingsMultiSearchInputArgs struct {
|
||||
Repos []graphql.ID
|
||||
Query string
|
||||
CodeResultsCount int32
|
||||
TextResultsCount int32
|
||||
}
|
||||
|
||||
type EmbeddingsSearchResultsResolver interface {
|
||||
CodeResults(ctx context.Context) []EmbeddingsSearchResultResolver
|
||||
TextResults(ctx context.Context) []EmbeddingsSearchResultResolver
|
||||
CodeResults(ctx context.Context) ([]EmbeddingsSearchResultResolver, error)
|
||||
TextResults(ctx context.Context) ([]EmbeddingsSearchResultResolver, error)
|
||||
}
|
||||
|
||||
type EmbeddingsSearchResultResolver interface {
|
||||
|
||||
@ -22,6 +22,31 @@ extend type Query {
|
||||
"""
|
||||
textResultsCount: Int!
|
||||
): EmbeddingsSearchResults!
|
||||
|
||||
"""
|
||||
Experimental: Searches a set of repositories for similar code and text results using embeddings.
|
||||
We separated code and text results because text results tended to always feature at the top of the combined results,
|
||||
and didn't leave room for the code.
|
||||
"""
|
||||
embeddingsMultiSearch(
|
||||
"""
|
||||
The repository to search.
|
||||
"""
|
||||
repos: [ID!]!
|
||||
"""
|
||||
The query used for embeddings search.
|
||||
"""
|
||||
query: String!
|
||||
"""
|
||||
The number of code results to return.
|
||||
"""
|
||||
codeResultsCount: Int!
|
||||
"""
|
||||
The number of text results to return. Text results contain Markdown files and similar file types primarily used for writing documentation.
|
||||
"""
|
||||
textResultsCount: Int!
|
||||
): EmbeddingsSearchResults!
|
||||
|
||||
"""
|
||||
Experimental: Determines whether the given query requires further context before it can be answered.
|
||||
For example:
|
||||
|
||||
1
enterprise/cmd/embeddings/qa/BUILD.bazel
generated
1
enterprise/cmd/embeddings/qa/BUILD.bazel
generated
@ -8,6 +8,7 @@ go_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//enterprise/internal/embeddings",
|
||||
"//internal/api",
|
||||
"//lib/errors",
|
||||
],
|
||||
)
|
||||
|
||||
@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
)
|
||||
|
||||
@ -21,7 +22,7 @@ import (
|
||||
var fs embed.FS
|
||||
|
||||
type embeddingsSearcher interface {
|
||||
Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error)
|
||||
Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingCombinedSearchResults, error)
|
||||
}
|
||||
|
||||
// Run runs the evaluation and returns recall for the test data.
|
||||
@ -45,11 +46,10 @@ func Run(searcher embeddingsSearcher) (float64, error) {
|
||||
relevantFile := fields[1]
|
||||
|
||||
args := embeddings.EmbeddingsSearchParameters{
|
||||
RepoName: "github.com/sourcegraph/sourcegraph",
|
||||
RepoNames: []api.RepoName{"github.com/sourcegraph/sourcegraph"},
|
||||
Query: query,
|
||||
CodeResultsCount: 20,
|
||||
TextResultsCount: 2,
|
||||
Debug: true,
|
||||
}
|
||||
|
||||
results, err := searcher.Search(args)
|
||||
@ -70,11 +70,7 @@ func Run(searcher embeddingsSearcher) (float64, error) {
|
||||
fmt.Printf(" ")
|
||||
}
|
||||
fmt.Printf("%d. %s", i+1, result.FileName)
|
||||
if result.Debug != "" {
|
||||
fmt.Printf(" (%s)\n", result.Debug)
|
||||
} else {
|
||||
fmt.Print("\n")
|
||||
}
|
||||
fmt.Printf(" (%s)\n", result.ScoreDetails.String())
|
||||
}
|
||||
fmt.Println()
|
||||
if fileFound {
|
||||
@ -103,7 +99,7 @@ func NewClient(url string) *client {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error) {
|
||||
func (c *client) Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingCombinedSearchResults, error) {
|
||||
b, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -130,7 +126,7 @@ func (c *client) Search(args embeddings.EmbeddingsSearchParameters) (*embeddings
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := embeddings.EmbeddingSearchResults{}
|
||||
res := embeddings.EmbeddingCombinedSearchResults{}
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal response")
|
||||
|
||||
6
enterprise/cmd/embeddings/shared/BUILD.bazel
generated
6
enterprise/cmd/embeddings/shared/BUILD.bazel
generated
@ -35,7 +35,6 @@ go_library(
|
||||
"//internal/env",
|
||||
"//internal/errcode",
|
||||
"//internal/featureflag",
|
||||
"//internal/gitserver",
|
||||
"//internal/goroutine",
|
||||
"//internal/honey",
|
||||
"//internal/httpserver",
|
||||
@ -79,8 +78,8 @@ go_test(
|
||||
srcs = [
|
||||
"context_detection_test.go",
|
||||
"context_qa_test.go",
|
||||
"main_test.go",
|
||||
"repo_embedding_index_cache_test.go",
|
||||
"search_test.go",
|
||||
],
|
||||
embed = [":shared"],
|
||||
embedsrcs = [
|
||||
@ -93,10 +92,11 @@ go_test(
|
||||
"//enterprise/internal/embeddings/background/repo",
|
||||
"//internal/api",
|
||||
"//internal/database",
|
||||
"//internal/endpoint",
|
||||
"//internal/types",
|
||||
"//internal/uploadstore/mocks",
|
||||
"//lib/errors",
|
||||
"@com_github_sourcegraph_log//:log",
|
||||
"@com_github_sourcegraph_log//logtest",
|
||||
"@com_github_stretchr_testify//require",
|
||||
],
|
||||
)
|
||||
|
||||
@ -13,8 +13,6 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/log"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/cmd/embeddings/qa"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
@ -34,7 +32,6 @@ func TestRecall(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
logger := log.NoOp()
|
||||
|
||||
// Set up mock functions
|
||||
queryEmbeddings, err := loadQueryEmbeddings(t)
|
||||
@ -59,20 +56,13 @@ func TestRecall(t *testing.T) {
|
||||
return embeddings.DownloadRepoEmbeddingIndex(context.Background(), mockStore, string(key))
|
||||
}
|
||||
|
||||
// We only care about the file names in this test.
|
||||
mockReadFile := func(ctx context.Context, repoName api.RepoName, revision api.CommitID, fileName string) ([]byte, error) {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
// Weaviate is disabled per default. We don't need it for this test.
|
||||
weaviate := &weaviateClient{}
|
||||
|
||||
searcher := func(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error) {
|
||||
return searchRepoEmbeddingIndex(
|
||||
searcher := func(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingCombinedSearchResults, error) {
|
||||
return searchRepoEmbeddingIndexes(
|
||||
ctx,
|
||||
logger,
|
||||
args,
|
||||
mockReadFile,
|
||||
getRepoEmbeddingIndex,
|
||||
getQueryEmbedding,
|
||||
weaviate,
|
||||
@ -120,8 +110,8 @@ func loadQueryEmbeddings(t *testing.T) (map[string][]float32, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
type embeddingsSearcherFunc func(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error)
|
||||
type embeddingsSearcherFunc func(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingCombinedSearchResults, error)
|
||||
|
||||
func (f embeddingsSearcherFunc) Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error) {
|
||||
func (f embeddingsSearcherFunc) Search(args embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingCombinedSearchResults, error) {
|
||||
return f(args)
|
||||
}
|
||||
|
||||
@ -19,7 +19,6 @@ import (
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/repo"
|
||||
"github.com/sourcegraph/sourcegraph/internal/actor"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
"github.com/sourcegraph/sourcegraph/internal/authz"
|
||||
"github.com/sourcegraph/sourcegraph/internal/conf"
|
||||
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
|
||||
@ -27,7 +26,6 @@ import (
|
||||
connections "github.com/sourcegraph/sourcegraph/internal/database/connections/live"
|
||||
"github.com/sourcegraph/sourcegraph/internal/errcode"
|
||||
"github.com/sourcegraph/sourcegraph/internal/featureflag"
|
||||
"github.com/sourcegraph/sourcegraph/internal/gitserver"
|
||||
"github.com/sourcegraph/sourcegraph/internal/goroutine"
|
||||
"github.com/sourcegraph/sourcegraph/internal/honey"
|
||||
"github.com/sourcegraph/sourcegraph/internal/httpserver"
|
||||
@ -58,7 +56,6 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
|
||||
repoEmbeddingJobsStore := repo.NewRepoEmbeddingJobsStore(db)
|
||||
|
||||
// Run setup
|
||||
gitserverClient := gitserver.NewClient()
|
||||
uploadStore, err := embeddings.NewEmbeddingsUploadStore(ctx, observationCtx, config.EmbeddingsUploadStoreConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -69,10 +66,6 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
|
||||
return errors.Wrap(err, "creating sub-repo client")
|
||||
}
|
||||
|
||||
readFile := func(ctx context.Context, repoName api.RepoName, revision api.CommitID, fileName string) ([]byte, error) {
|
||||
return gitserverClient.ReadFile(ctx, authz.DefaultSubRepoPermsChecker, repoName, revision, fileName)
|
||||
}
|
||||
|
||||
indexGetter, err := NewCachedEmbeddingIndexGetter(
|
||||
repoStore,
|
||||
repoEmbeddingJobsStore,
|
||||
@ -92,7 +85,6 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
|
||||
|
||||
weaviate := newWeaviateClient(
|
||||
logger,
|
||||
readFile,
|
||||
getQueryEmbedding,
|
||||
config.WeaviateURL,
|
||||
)
|
||||
@ -100,7 +92,7 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
|
||||
getContextDetectionEmbeddingIndex := getCachedContextDetectionEmbeddingIndex(uploadStore)
|
||||
|
||||
// Create HTTP server
|
||||
handler := NewHandler(logger, readFile, indexGetter.Get, getQueryEmbedding, weaviate, getContextDetectionEmbeddingIndex)
|
||||
handler := NewHandler(logger, indexGetter.Get, getQueryEmbedding, weaviate, getContextDetectionEmbeddingIndex)
|
||||
handler = handlePanic(logger, handler)
|
||||
handler = featureflag.Middleware(db.FeatureFlags(), handler)
|
||||
handler = trace.HTTPMiddleware(logger, handler, conf.DefaultClient())
|
||||
@ -122,7 +114,6 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
|
||||
|
||||
func NewHandler(
|
||||
logger log.Logger,
|
||||
readFile readFileFn,
|
||||
getRepoEmbeddingIndex getRepoEmbeddingIndexFn,
|
||||
getQueryEmbedding getQueryEmbeddingFn,
|
||||
weaviate *weaviateClient,
|
||||
@ -143,7 +134,7 @@ func NewHandler(
|
||||
return
|
||||
}
|
||||
|
||||
res, err := searchRepoEmbeddingIndex(r.Context(), logger, args, readFile, getRepoEmbeddingIndex, getQueryEmbedding, weaviate)
|
||||
res, err := searchRepoEmbeddingIndexes(r.Context(), args, getRepoEmbeddingIndex, getQueryEmbedding, weaviate)
|
||||
if errcode.IsNotFound(err) {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
|
||||
248
enterprise/cmd/embeddings/shared/main_test.go
Normal file
248
enterprise/cmd/embeddings/shared/main_test.go
Normal file
@ -0,0 +1,248 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/log/logtest"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
"github.com/sourcegraph/sourcegraph/internal/endpoint"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEmbeddingsSearch(t *testing.T) {
|
||||
logger := logtest.Scoped(t)
|
||||
|
||||
makeIndex := func(name api.RepoName, w int8) *embeddings.RepoEmbeddingIndex {
|
||||
return &embeddings.RepoEmbeddingIndex{
|
||||
RepoName: name,
|
||||
Revision: "",
|
||||
CodeIndex: embeddings.EmbeddingIndex{
|
||||
Embeddings: []int8{
|
||||
w, 0, 0, 0,
|
||||
0, w, 0, 0,
|
||||
0, 0, w, 0,
|
||||
0, 0, 0, w,
|
||||
},
|
||||
ColumnDimension: 4,
|
||||
RowMetadata: []embeddings.RepoEmbeddingRowMetadata{
|
||||
{FileName: "codefile1", StartLine: 0, EndLine: 1},
|
||||
{FileName: "codefile2", StartLine: 0, EndLine: 1},
|
||||
{FileName: "codefile3", StartLine: 0, EndLine: 1},
|
||||
{FileName: "codefile4", StartLine: 0, EndLine: 1},
|
||||
},
|
||||
},
|
||||
TextIndex: embeddings.EmbeddingIndex{
|
||||
Embeddings: []int8{
|
||||
w, 0, 0, 0,
|
||||
0, w, 0, 0,
|
||||
0, 0, w, 0,
|
||||
0, 0, 0, w,
|
||||
},
|
||||
ColumnDimension: 4,
|
||||
RowMetadata: []embeddings.RepoEmbeddingRowMetadata{
|
||||
{FileName: "textfile1", StartLine: 0, EndLine: 1},
|
||||
{FileName: "textfile2", StartLine: 0, EndLine: 1},
|
||||
{FileName: "textfile3", StartLine: 0, EndLine: 1},
|
||||
{FileName: "textfile4", StartLine: 0, EndLine: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
indexes := map[api.RepoName]*embeddings.RepoEmbeddingIndex{
|
||||
"repo1": makeIndex("repo1", 1),
|
||||
"repo2": makeIndex("repo2", 2),
|
||||
"repo3": makeIndex("repo3", 3),
|
||||
"repo4": makeIndex("repo4", 4),
|
||||
}
|
||||
|
||||
getRepoEmbeddingIndex := func(_ context.Context, repoName api.RepoName) (*embeddings.RepoEmbeddingIndex, error) {
|
||||
return indexes[repoName], nil
|
||||
}
|
||||
getQueryEmbedding := func(_ context.Context, query string) ([]float32, error) {
|
||||
switch query {
|
||||
case "one":
|
||||
return []float32{1, 0, 0, 0}, nil
|
||||
case "two":
|
||||
return []float32{0, 1, 0, 0}, nil
|
||||
case "three":
|
||||
return []float32{0, 0, 1, 0}, nil
|
||||
case "four":
|
||||
return []float32{0, 0, 1, 1}, nil
|
||||
case "context detection":
|
||||
return []float32{2, 4, 6, 8}, nil
|
||||
default:
|
||||
panic("unknown")
|
||||
}
|
||||
}
|
||||
getContextDetectionEmbeddingIndex := func(context.Context) (*embeddings.ContextDetectionEmbeddingIndex, error) {
|
||||
return &embeddings.ContextDetectionEmbeddingIndex{
|
||||
MessagesWithAdditionalContextMeanEmbedding: []float32{1, 2, 3, 4},
|
||||
MessagesWithoutAdditionalContextMeanEmbedding: []float32{4, 3, 2, 1},
|
||||
}, nil
|
||||
}
|
||||
|
||||
server1 := httptest.NewServer(NewHandler(
|
||||
logger,
|
||||
getRepoEmbeddingIndex,
|
||||
getQueryEmbedding,
|
||||
nil,
|
||||
getContextDetectionEmbeddingIndex,
|
||||
))
|
||||
|
||||
server2 := httptest.NewServer(NewHandler(
|
||||
logger,
|
||||
getRepoEmbeddingIndex,
|
||||
getQueryEmbedding,
|
||||
nil,
|
||||
getContextDetectionEmbeddingIndex,
|
||||
))
|
||||
|
||||
client := embeddings.NewClient(endpoint.Static(server1.URL, server2.URL), http.DefaultClient)
|
||||
|
||||
{
|
||||
// First test: we should return results for file1 based on the query.
|
||||
// The rankings should have repo4 highest because it has the largest weighted
|
||||
// embeddings.
|
||||
params := embeddings.EmbeddingsSearchParameters{
|
||||
RepoNames: []api.RepoName{"repo1", "repo2", "repo3", "repo4"},
|
||||
RepoIDs: []api.RepoID{1, 2, 3, 4},
|
||||
Query: "one",
|
||||
CodeResultsCount: 2,
|
||||
TextResultsCount: 2,
|
||||
UseDocumentRanks: false,
|
||||
}
|
||||
|
||||
results, err := client.Search(context.Background(), params)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, &embeddings.EmbeddingCombinedSearchResults{
|
||||
CodeResults: embeddings.EmbeddingSearchResults{{
|
||||
RepoName: "repo4",
|
||||
FileName: "codefile1",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 1016, SimilarityScore: 1016},
|
||||
}, {
|
||||
RepoName: "repo3",
|
||||
FileName: "codefile1",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
|
||||
}},
|
||||
TextResults: embeddings.EmbeddingSearchResults{{
|
||||
RepoName: "repo4",
|
||||
FileName: "textfile1",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 1016, SimilarityScore: 1016},
|
||||
}, {
|
||||
RepoName: "repo3",
|
||||
FileName: "textfile1",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
|
||||
}},
|
||||
}, results)
|
||||
}
|
||||
|
||||
{
|
||||
// Second test: providing a subset of repos should only search those repos
|
||||
params := embeddings.EmbeddingsSearchParameters{
|
||||
RepoNames: []api.RepoName{"repo1", "repo3"},
|
||||
RepoIDs: []api.RepoID{1, 2, 3, 4},
|
||||
Query: "one",
|
||||
CodeResultsCount: 2,
|
||||
TextResultsCount: 2,
|
||||
UseDocumentRanks: false,
|
||||
}
|
||||
|
||||
results, err := client.Search(context.Background(), params)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, &embeddings.EmbeddingCombinedSearchResults{
|
||||
CodeResults: embeddings.EmbeddingSearchResults{{
|
||||
RepoName: "repo3",
|
||||
FileName: "codefile1",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
|
||||
}, {
|
||||
RepoName: "repo1",
|
||||
FileName: "codefile1",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 254, SimilarityScore: 254},
|
||||
}},
|
||||
TextResults: embeddings.EmbeddingSearchResults{{
|
||||
RepoName: "repo3",
|
||||
FileName: "textfile1",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
|
||||
}, {
|
||||
RepoName: "repo1",
|
||||
FileName: "textfile1",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 254, SimilarityScore: 254},
|
||||
}},
|
||||
}, results)
|
||||
}
|
||||
|
||||
{
|
||||
// Third test: try a different file just to be safe
|
||||
params := embeddings.EmbeddingsSearchParameters{
|
||||
RepoNames: []api.RepoName{"repo1", "repo2", "repo3", "repo4"},
|
||||
RepoIDs: []api.RepoID{1, 2, 3, 4},
|
||||
Query: "three",
|
||||
CodeResultsCount: 2,
|
||||
TextResultsCount: 2,
|
||||
UseDocumentRanks: false,
|
||||
}
|
||||
|
||||
results, err := client.Search(context.Background(), params)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, &embeddings.EmbeddingCombinedSearchResults{
|
||||
CodeResults: embeddings.EmbeddingSearchResults{{
|
||||
RepoName: "repo4",
|
||||
FileName: "codefile3",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 1016, SimilarityScore: 1016},
|
||||
}, {
|
||||
RepoName: "repo3",
|
||||
FileName: "codefile3",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
|
||||
}},
|
||||
TextResults: embeddings.EmbeddingSearchResults{{
|
||||
RepoName: "repo4",
|
||||
FileName: "textfile3",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 1016, SimilarityScore: 1016},
|
||||
}, {
|
||||
RepoName: "repo3",
|
||||
FileName: "textfile3",
|
||||
StartLine: 0,
|
||||
EndLine: 1,
|
||||
ScoreDetails: embeddings.SearchScoreDetails{Score: 762, SimilarityScore: 762},
|
||||
}},
|
||||
}, results)
|
||||
}
|
||||
|
||||
t.Run("IsContextRequiredForChatQuery", func(t *testing.T) {
|
||||
res, err := client.IsContextRequiredForChatQuery(context.Background(), embeddings.IsContextRequiredForChatQueryParameters{
|
||||
Query: "context detection",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, res)
|
||||
})
|
||||
}
|
||||
@ -2,145 +2,66 @@ package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/sourcegraph/log"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
)
|
||||
|
||||
type readFileFn func(ctx context.Context, repoName api.RepoName, revision api.CommitID, fileName string) ([]byte, error)
|
||||
const SIMILARITY_SEARCH_MIN_ROWS_TO_SPLIT = 1000
|
||||
|
||||
type getRepoEmbeddingIndexFn func(ctx context.Context, repoName api.RepoName) (*embeddings.RepoEmbeddingIndex, error)
|
||||
type getQueryEmbeddingFn func(ctx context.Context, query string) ([]float32, error)
|
||||
|
||||
func searchRepoEmbeddingIndex(
|
||||
func searchRepoEmbeddingIndexes(
|
||||
ctx context.Context,
|
||||
logger log.Logger,
|
||||
params embeddings.EmbeddingsSearchParameters,
|
||||
readFile readFileFn,
|
||||
getRepoEmbeddingIndex getRepoEmbeddingIndexFn,
|
||||
getQueryEmbedding getQueryEmbeddingFn,
|
||||
weaviate *weaviateClient,
|
||||
) (*embeddings.EmbeddingSearchResults, error) {
|
||||
if weaviate.Use(ctx) {
|
||||
return weaviate.Search(ctx, params)
|
||||
}
|
||||
|
||||
embeddingIndex, err := getRepoEmbeddingIndex(ctx, params.RepoName)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "getting repo embedding index for repo %q", params.RepoName)
|
||||
}
|
||||
|
||||
) (*embeddings.EmbeddingCombinedSearchResults, error) {
|
||||
floatQuery, err := getQueryEmbedding(ctx, params.Query)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getting query embedding")
|
||||
}
|
||||
embeddedQuery := embeddings.Quantize(floatQuery)
|
||||
|
||||
opts := embeddings.SearchOptions{
|
||||
Debug: params.Debug,
|
||||
workerOpts := embeddings.WorkerOptions{
|
||||
NumWorkers: runtime.GOMAXPROCS(0),
|
||||
MinRowsToSplit: SIMILARITY_SEARCH_MIN_ROWS_TO_SPLIT,
|
||||
}
|
||||
|
||||
searchOpts := embeddings.SearchOptions{
|
||||
UseDocumentRanks: params.UseDocumentRanks,
|
||||
}
|
||||
|
||||
codeResults := searchEmbeddingIndex(ctx, logger, embeddingIndex.RepoName, embeddingIndex.Revision, &embeddingIndex.CodeIndex, readFile, embeddedQuery, params.CodeResultsCount, opts)
|
||||
textResults := searchEmbeddingIndex(ctx, logger, embeddingIndex.RepoName, embeddingIndex.Revision, &embeddingIndex.TextIndex, readFile, embeddedQuery, params.TextResultsCount, opts)
|
||||
var result embeddings.EmbeddingCombinedSearchResults
|
||||
|
||||
return &embeddings.EmbeddingSearchResults{CodeResults: codeResults, TextResults: textResults}, nil
|
||||
}
|
||||
|
||||
const SIMILARITY_SEARCH_MIN_ROWS_TO_SPLIT = 1000
|
||||
|
||||
func searchEmbeddingIndex(
|
||||
ctx context.Context,
|
||||
logger log.Logger,
|
||||
repoName api.RepoName,
|
||||
revision api.CommitID,
|
||||
index *embeddings.EmbeddingIndex,
|
||||
readFile readFileFn,
|
||||
query []int8,
|
||||
nResults int,
|
||||
opts embeddings.SearchOptions,
|
||||
) []embeddings.EmbeddingSearchResult {
|
||||
numWorkers := runtime.GOMAXPROCS(0)
|
||||
results := index.SimilaritySearch(query, nResults, embeddings.WorkerOptions{NumWorkers: numWorkers, MinRowsToSplit: SIMILARITY_SEARCH_MIN_ROWS_TO_SPLIT}, opts)
|
||||
|
||||
return filterAndHydrateContent(
|
||||
ctx,
|
||||
logger,
|
||||
repoName,
|
||||
revision,
|
||||
readFile,
|
||||
opts.Debug,
|
||||
results,
|
||||
)
|
||||
}
|
||||
|
||||
// filterAndHydrateContent will mutate unfiltered to populate the Content
|
||||
// field. If we fail to read a file (eg permission issues) we will remove the
|
||||
// result. As such the returned slice should be used.
|
||||
func filterAndHydrateContent(
|
||||
ctx context.Context,
|
||||
logger log.Logger,
|
||||
repoName api.RepoName,
|
||||
revision api.CommitID,
|
||||
readFile readFileFn,
|
||||
debug bool,
|
||||
unfiltered []embeddings.SimilaritySearchResult,
|
||||
) []embeddings.EmbeddingSearchResult {
|
||||
filtered := make([]embeddings.EmbeddingSearchResult, 0, len(unfiltered))
|
||||
|
||||
for _, result := range unfiltered {
|
||||
fileContent, err := readFile(ctx, repoName, revision, result.FileName)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
logger.Error("error reading file", log.String("repoName", string(repoName)), log.String("revision", string(revision)), log.String("fileName", result.FileName), log.Error(err))
|
||||
for i, repoName := range params.RepoNames {
|
||||
if weaviate.Use(ctx) {
|
||||
codeResults, textResults, err := weaviate.Search(ctx, repoName, params.RepoIDs[i], params.Query, params.CodeResultsCount, params.TextResultsCount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result.CodeResults.MergeTruncate(codeResults, params.CodeResultsCount)
|
||||
result.TextResults.MergeTruncate(textResults, params.TextResultsCount)
|
||||
continue
|
||||
}
|
||||
lines := strings.Split(string(fileContent), "\n")
|
||||
|
||||
// Sanity check: check that startLine and endLine are within 0 and len(lines).
|
||||
startLine := max(0, min(len(lines), result.StartLine))
|
||||
endLine := max(0, min(len(lines), result.EndLine))
|
||||
|
||||
content := strings.Join(lines[startLine:endLine], "\n")
|
||||
|
||||
var debugString string
|
||||
if debug {
|
||||
debugString = fmt.Sprintf("score:%d, similarity:%d, rank:%d", result.Score(), result.SimilarityScore, result.RankScore)
|
||||
embeddingIndex, err := getRepoEmbeddingIndex(ctx, repoName)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "getting repo embedding index for repo %q", repoName)
|
||||
}
|
||||
|
||||
filtered = append(filtered, embeddings.EmbeddingSearchResult{
|
||||
RepoName: repoName,
|
||||
Revision: revision,
|
||||
RepoEmbeddingRowMetadata: embeddings.RepoEmbeddingRowMetadata{
|
||||
FileName: result.FileName,
|
||||
StartLine: startLine,
|
||||
EndLine: endLine,
|
||||
},
|
||||
Debug: debugString,
|
||||
Content: content,
|
||||
})
|
||||
codeResults := embeddingIndex.CodeIndex.SimilaritySearch(embeddedQuery, params.CodeResultsCount, workerOpts, searchOpts, embeddingIndex.RepoName, embeddingIndex.Revision)
|
||||
textResults := embeddingIndex.TextIndex.SimilaritySearch(embeddedQuery, params.TextResultsCount, workerOpts, searchOpts, embeddingIndex.RepoName, embeddingIndex.Revision)
|
||||
|
||||
result.CodeResults.MergeTruncate(codeResults, params.CodeResultsCount)
|
||||
result.TextResults.MergeTruncate(textResults, params.TextResultsCount)
|
||||
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
@ -1,38 +0,0 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/log"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
)
|
||||
|
||||
func TestFilterAndHydrateContent_emptyFile(t *testing.T) {
|
||||
// Set up test data
|
||||
repoName := api.RepoName("example/repo")
|
||||
revision := api.CommitID("abc123")
|
||||
debug := false
|
||||
unfiltered := []embeddings.SimilaritySearchResult{
|
||||
{
|
||||
RepoEmbeddingRowMetadata: embeddings.RepoEmbeddingRowMetadata{
|
||||
FileName: "file.txt",
|
||||
StartLine: 5,
|
||||
EndLine: 20,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Define a mock readFile function that returns an empty string
|
||||
readFile := func(ctx context.Context, repoName api.RepoName, revision api.CommitID, fileName string) ([]byte, error) {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
filtered := filterAndHydrateContent(context.Background(), log.NoOp(), repoName, revision, readFile, debug, unfiltered)
|
||||
|
||||
if len(filtered) != 1 {
|
||||
t.Errorf("Expected 1 filtered result, but got %d elements", len(filtered))
|
||||
}
|
||||
}
|
||||
@ -19,7 +19,6 @@ import (
|
||||
|
||||
type weaviateClient struct {
|
||||
logger log.Logger
|
||||
readFile readFileFn
|
||||
getQueryEmbedding getQueryEmbeddingFn
|
||||
|
||||
client *weaviate.Client
|
||||
@ -28,7 +27,6 @@ type weaviateClient struct {
|
||||
|
||||
func newWeaviateClient(
|
||||
logger log.Logger,
|
||||
readFile readFileFn,
|
||||
getQueryEmbedding getQueryEmbeddingFn,
|
||||
url *url.URL,
|
||||
) *weaviateClient {
|
||||
@ -45,7 +43,6 @@ func newWeaviateClient(
|
||||
|
||||
return &weaviateClient{
|
||||
logger: logger.Scoped("weaviate", "client for weaviate embedding index"),
|
||||
readFile: readFile,
|
||||
getQueryEmbedding: getQueryEmbedding,
|
||||
client: client,
|
||||
clientErr: err,
|
||||
@ -56,14 +53,14 @@ func (w *weaviateClient) Use(ctx context.Context) bool {
|
||||
return featureflag.FromContext(ctx).GetBoolOr("search-weaviate", false)
|
||||
}
|
||||
|
||||
func (w *weaviateClient) Search(ctx context.Context, params embeddings.EmbeddingsSearchParameters) (*embeddings.EmbeddingSearchResults, error) {
|
||||
func (w *weaviateClient) Search(ctx context.Context, repoName api.RepoName, repoID api.RepoID, query string, codeResultsCount, textResultsCount int) (codeResults, textResults []embeddings.EmbeddingSearchResult, _ error) {
|
||||
if w.clientErr != nil {
|
||||
return nil, w.clientErr
|
||||
return nil, nil, w.clientErr
|
||||
}
|
||||
|
||||
embeddedQuery, err := w.getQueryEmbedding(ctx, params.Query)
|
||||
embeddedQuery, err := w.getQueryEmbedding(ctx, query)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getting query embedding")
|
||||
return nil, nil, errors.Wrap(err, "getting query embedding")
|
||||
}
|
||||
|
||||
queryBuilder := func(klass string, limit int) *graphql.GetBuilder {
|
||||
@ -75,6 +72,9 @@ func (w *weaviateClient) Search(ctx context.Context, params embeddings.Embedding
|
||||
{Name: "start_line"},
|
||||
{Name: "end_line"},
|
||||
{Name: "revision"},
|
||||
{Name: "_additional", Fields: []graphql.Field{
|
||||
{Name: "distance"},
|
||||
}},
|
||||
}...).
|
||||
WithLimit(limit)
|
||||
}
|
||||
@ -86,7 +86,7 @@ func (w *weaviateClient) Search(ctx context.Context, params embeddings.Embedding
|
||||
return nil
|
||||
}
|
||||
|
||||
srs := make([]embeddings.SimilaritySearchResult, 0, len(code))
|
||||
srs := make([]embeddings.EmbeddingSearchResult, 0, len(code))
|
||||
revision := ""
|
||||
for _, c := range code {
|
||||
cMap := c.(map[string]any)
|
||||
@ -96,50 +96,48 @@ func (w *weaviateClient) Search(ctx context.Context, params embeddings.Embedding
|
||||
if revision == "" {
|
||||
revision = rev
|
||||
} else {
|
||||
w.logger.Warn("inconsistent revisions returned for an embedded repository", log.Int("repoid", int(params.RepoID)), log.String("filename", fileName), log.String("revision1", revision), log.String("revision2", rev))
|
||||
w.logger.Warn("inconsistent revisions returned for an embedded repository", log.Int("repoid", int(repoID)), log.String("filename", fileName), log.String("revision1", revision), log.String("revision2", rev))
|
||||
}
|
||||
}
|
||||
|
||||
srs = append(srs, embeddings.SimilaritySearchResult{
|
||||
RepoEmbeddingRowMetadata: embeddings.RepoEmbeddingRowMetadata{
|
||||
FileName: fileName,
|
||||
StartLine: int(cMap["start_line"].(float64)),
|
||||
EndLine: int(cMap["end_line"].(float64)),
|
||||
// multiply by half max int32 since distance will always be between 0 and 2
|
||||
similarity := int32(cMap["_additional"].(map[string]any)["distance"].(float64) * (1073741823))
|
||||
|
||||
srs = append(srs, embeddings.EmbeddingSearchResult{
|
||||
RepoName: repoName,
|
||||
Revision: api.CommitID(revision),
|
||||
FileName: fileName,
|
||||
StartLine: int(cMap["start_line"].(float64)),
|
||||
EndLine: int(cMap["end_line"].(float64)),
|
||||
ScoreDetails: embeddings.SearchScoreDetails{
|
||||
Score: similarity,
|
||||
SimilarityScore: similarity,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
commit := api.CommitID(revision)
|
||||
if commit == "" {
|
||||
w.logger.Warn("no revision set for an embedded repository", log.Int("repoid", int(params.RepoID)))
|
||||
commit = api.CommitID("HEAD")
|
||||
}
|
||||
|
||||
return filterAndHydrateContent(ctx, w.logger, params.RepoName, commit, w.readFile, false, srs)
|
||||
return srs
|
||||
}
|
||||
|
||||
// We partition the indexes by type and repository. Each class in
|
||||
// weaviate is its own index, so we achieve partitioning by a class
|
||||
// per repo and type.
|
||||
codeClass := fmt.Sprintf("Code_%d", params.RepoID)
|
||||
textClass := fmt.Sprintf("Text_%d", params.RepoID)
|
||||
codeClass := fmt.Sprintf("Code_%d", repoID)
|
||||
textClass := fmt.Sprintf("Text_%d", repoID)
|
||||
|
||||
res, err := w.client.GraphQL().MultiClassGet().
|
||||
AddQueryClass(queryBuilder(codeClass, params.CodeResultsCount)).
|
||||
AddQueryClass(queryBuilder(textClass, params.TextResultsCount)).
|
||||
AddQueryClass(queryBuilder(codeClass, codeResultsCount)).
|
||||
AddQueryClass(queryBuilder(textClass, textResultsCount)).
|
||||
Do(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "doing weaviate request")
|
||||
return nil, nil, errors.Wrap(err, "doing weaviate request")
|
||||
}
|
||||
|
||||
if len(res.Errors) > 0 {
|
||||
return nil, weaviateGraphQLError(res.Errors)
|
||||
return nil, nil, weaviateGraphQLError(res.Errors)
|
||||
}
|
||||
|
||||
return &embeddings.EmbeddingSearchResults{
|
||||
CodeResults: extractResults(res, codeClass),
|
||||
TextResults: extractResults(res, textClass),
|
||||
}, nil
|
||||
return extractResults(res, codeClass), extractResults(res, textClass), nil
|
||||
}
|
||||
|
||||
type weaviateGraphQLError []*models.GraphQLError
|
||||
|
||||
@ -26,8 +26,7 @@ func Init(
|
||||
repoEmbeddingsStore := repo.NewRepoEmbeddingJobsStore(db)
|
||||
contextDetectionEmbeddingsStore := contextdetection.NewContextDetectionEmbeddingJobsStore(db)
|
||||
gitserverClient := gitserver.NewClient()
|
||||
embeddingsClient := embeddings.NewClient()
|
||||
|
||||
embeddingsClient := embeddings.NewDefaultClient()
|
||||
enterpriseServices.EmbeddingsResolver = resolvers.NewResolver(
|
||||
db,
|
||||
observationCtx.Logger,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library")
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
|
||||
|
||||
go_library(
|
||||
name = "resolvers",
|
||||
@ -17,6 +17,7 @@ go_library(
|
||||
"//enterprise/internal/embeddings/background/repo",
|
||||
"//internal/api",
|
||||
"//internal/auth",
|
||||
"//internal/authz",
|
||||
"//internal/cody",
|
||||
"//internal/conf",
|
||||
"//internal/database",
|
||||
@ -26,6 +27,31 @@ go_library(
|
||||
"//lib/errors",
|
||||
"@com_github_graph_gophers_graphql_go//:graphql-go",
|
||||
"@com_github_graph_gophers_graphql_go//relay",
|
||||
"@com_github_sourcegraph_conc//pool",
|
||||
"@com_github_sourcegraph_log//:log",
|
||||
],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "resolvers_test",
|
||||
srcs = ["resolvers_test.go"],
|
||||
embed = [":resolvers"],
|
||||
deps = [
|
||||
"//cmd/frontend/graphqlbackend",
|
||||
"//enterprise/internal/embeddings",
|
||||
"//enterprise/internal/embeddings/background/contextdetection",
|
||||
"//enterprise/internal/embeddings/background/repo",
|
||||
"//internal/actor",
|
||||
"//internal/api",
|
||||
"//internal/authz",
|
||||
"//internal/conf",
|
||||
"//internal/database",
|
||||
"//internal/featureflag",
|
||||
"//internal/gitserver",
|
||||
"//internal/types",
|
||||
"//schema",
|
||||
"@com_github_graph_gophers_graphql_go//:graphql-go",
|
||||
"@com_github_sourcegraph_log//logtest",
|
||||
"@com_github_stretchr_testify//require",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1,8 +1,12 @@
|
||||
package resolvers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"os"
|
||||
|
||||
"github.com/graph-gophers/graphql-go"
|
||||
"github.com/sourcegraph/conc/pool"
|
||||
"github.com/sourcegraph/log"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
@ -15,6 +19,7 @@ import (
|
||||
repobg "github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/repo"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
"github.com/sourcegraph/sourcegraph/internal/auth"
|
||||
"github.com/sourcegraph/sourcegraph/internal/authz"
|
||||
"github.com/sourcegraph/sourcegraph/internal/cody"
|
||||
"github.com/sourcegraph/sourcegraph/internal/conf"
|
||||
"github.com/sourcegraph/sourcegraph/internal/database"
|
||||
@ -25,7 +30,7 @@ func NewResolver(
|
||||
db database.DB,
|
||||
logger log.Logger,
|
||||
gitserverClient gitserver.Client,
|
||||
embeddingsClient *embeddings.Client,
|
||||
embeddingsClient embeddings.Client,
|
||||
repoStore repobg.RepoEmbeddingJobsStore,
|
||||
contextDetectionStore contextdetectionbg.ContextDetectionEmbeddingJobsStore,
|
||||
) graphqlbackend.EmbeddingsResolver {
|
||||
@ -43,13 +48,22 @@ type Resolver struct {
|
||||
db database.DB
|
||||
logger log.Logger
|
||||
gitserverClient gitserver.Client
|
||||
embeddingsClient *embeddings.Client
|
||||
embeddingsClient embeddings.Client
|
||||
repoEmbeddingJobsStore repobg.RepoEmbeddingJobsStore
|
||||
contextDetectionJobsStore contextdetectionbg.ContextDetectionEmbeddingJobsStore
|
||||
emails backend.UserEmailsService
|
||||
}
|
||||
|
||||
func (r *Resolver) EmbeddingsSearch(ctx context.Context, args graphqlbackend.EmbeddingsSearchInputArgs) (graphqlbackend.EmbeddingsSearchResultsResolver, error) {
|
||||
return r.EmbeddingsMultiSearch(ctx, graphqlbackend.EmbeddingsMultiSearchInputArgs{
|
||||
Repos: []graphql.ID{args.Repo},
|
||||
Query: args.Query,
|
||||
CodeResultsCount: args.CodeResultsCount,
|
||||
TextResultsCount: args.TextResultsCount,
|
||||
})
|
||||
}
|
||||
|
||||
func (r *Resolver) EmbeddingsMultiSearch(ctx context.Context, args graphqlbackend.EmbeddingsMultiSearchInputArgs) (graphqlbackend.EmbeddingsSearchResultsResolver, error) {
|
||||
if !conf.EmbeddingsEnabled() {
|
||||
return nil, errors.New("embeddings are not configured or disabled")
|
||||
}
|
||||
@ -62,19 +76,28 @@ func (r *Resolver) EmbeddingsSearch(ctx context.Context, args graphqlbackend.Emb
|
||||
return nil, err
|
||||
}
|
||||
|
||||
repoID, err := graphqlbackend.UnmarshalRepositoryID(args.Repo)
|
||||
repoIDs := make([]api.RepoID, len(args.Repos))
|
||||
for i, repo := range args.Repos {
|
||||
repoID, err := graphqlbackend.UnmarshalRepositoryID(repo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
repoIDs[i] = repoID
|
||||
}
|
||||
|
||||
repos, err := r.db.Repos().GetByIDs(ctx, repoIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
repo, err := r.db.Repos().Get(ctx, repoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
repoNames := make([]api.RepoName, len(repos))
|
||||
for i, repo := range repos {
|
||||
repoNames[i] = repo.Name
|
||||
}
|
||||
|
||||
results, err := r.embeddingsClient.Search(ctx, embeddings.EmbeddingsSearchParameters{
|
||||
RepoName: repo.Name,
|
||||
RepoID: repoID,
|
||||
RepoNames: repoNames,
|
||||
RepoIDs: repoIDs,
|
||||
Query: args.Query,
|
||||
CodeResultsCount: int(args.CodeResultsCount),
|
||||
TextResultsCount: int(args.TextResultsCount),
|
||||
@ -83,7 +106,11 @@ func (r *Resolver) EmbeddingsSearch(ctx context.Context, args graphqlbackend.Emb
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &embeddingsSearchResultsResolver{results}, nil
|
||||
return &embeddingsSearchResultsResolver{
|
||||
results: results,
|
||||
gitserver: r.gitserverClient,
|
||||
logger: r.logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *Resolver) IsContextRequiredForChatQuery(ctx context.Context, args graphqlbackend.IsContextRequiredForChatQueryInputArgs) (bool, error) {
|
||||
@ -183,27 +210,91 @@ func (r *Resolver) ScheduleContextDetectionForEmbedding(ctx context.Context) (*g
|
||||
}
|
||||
|
||||
type embeddingsSearchResultsResolver struct {
|
||||
results *embeddings.EmbeddingSearchResults
|
||||
results *embeddings.EmbeddingCombinedSearchResults
|
||||
gitserver gitserver.Client
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
func (r *embeddingsSearchResultsResolver) CodeResults(ctx context.Context) []graphqlbackend.EmbeddingsSearchResultResolver {
|
||||
codeResults := make([]graphqlbackend.EmbeddingsSearchResultResolver, len(r.results.CodeResults))
|
||||
for idx, result := range r.results.CodeResults {
|
||||
codeResults[idx] = &embeddingsSearchResultResolver{result}
|
||||
}
|
||||
return codeResults
|
||||
func (r *embeddingsSearchResultsResolver) CodeResults(ctx context.Context) ([]graphqlbackend.EmbeddingsSearchResultResolver, error) {
|
||||
return embeddingsSearchResultsToResolvers(ctx, r.logger, r.gitserver, r.results.CodeResults)
|
||||
}
|
||||
|
||||
func (r *embeddingsSearchResultsResolver) TextResults(ctx context.Context) []graphqlbackend.EmbeddingsSearchResultResolver {
|
||||
textResults := make([]graphqlbackend.EmbeddingsSearchResultResolver, len(r.results.TextResults))
|
||||
for idx, result := range r.results.TextResults {
|
||||
textResults[idx] = &embeddingsSearchResultResolver{result}
|
||||
func (r *embeddingsSearchResultsResolver) TextResults(ctx context.Context) ([]graphqlbackend.EmbeddingsSearchResultResolver, error) {
|
||||
return embeddingsSearchResultsToResolvers(ctx, r.logger, r.gitserver, r.results.TextResults)
|
||||
}
|
||||
|
||||
func embeddingsSearchResultsToResolvers(
|
||||
ctx context.Context,
|
||||
logger log.Logger,
|
||||
gs gitserver.Client,
|
||||
results []embeddings.EmbeddingSearchResult,
|
||||
) ([]graphqlbackend.EmbeddingsSearchResultResolver, error) {
|
||||
|
||||
allContents := make([][]byte, len(results))
|
||||
allErrors := make([]error, len(results))
|
||||
{ // Fetch contents in parallel because fetching them serially can be slow.
|
||||
p := pool.New().WithMaxGoroutines(8)
|
||||
for i, result := range results {
|
||||
i, result := i, result
|
||||
p.Go(func() {
|
||||
content, err := gs.ReadFile(ctx, authz.DefaultSubRepoPermsChecker, result.RepoName, result.Revision, result.FileName)
|
||||
allContents[i] = content
|
||||
allErrors[i] = err
|
||||
})
|
||||
}
|
||||
p.Wait()
|
||||
}
|
||||
return textResults
|
||||
|
||||
resolvers := make([]graphqlbackend.EmbeddingsSearchResultResolver, 0, len(results))
|
||||
{ // Merge the results with their contents, skipping any that errored when fetching the context.
|
||||
for i, result := range results {
|
||||
contents := allContents[i]
|
||||
err := allErrors[i]
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
logger.Error(
|
||||
"error reading file",
|
||||
log.String("repoName", string(result.RepoName)),
|
||||
log.String("revision", string(result.Revision)),
|
||||
log.String("fileName", result.FileName),
|
||||
log.Error(err),
|
||||
)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
resolvers = append(resolvers, &embeddingsSearchResultResolver{
|
||||
result: result,
|
||||
content: string(extractLineRange(contents, result.StartLine, result.EndLine)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return resolvers, nil
|
||||
}
|
||||
|
||||
func extractLineRange(content []byte, startLine, endLine int) []byte {
|
||||
lines := bytes.Split(content, []byte("\n"))
|
||||
|
||||
// Sanity check: check that startLine and endLine are within 0 and len(lines).
|
||||
startLine = clamp(startLine, 0, len(lines))
|
||||
endLine = clamp(endLine, 0, len(lines))
|
||||
|
||||
return bytes.Join(lines[startLine:endLine], []byte("\n"))
|
||||
}
|
||||
|
||||
func clamp(input, min, max int) int {
|
||||
if input > max {
|
||||
return max
|
||||
} else if input < min {
|
||||
return min
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
type embeddingsSearchResultResolver struct {
|
||||
result embeddings.EmbeddingSearchResult
|
||||
result embeddings.EmbeddingSearchResult
|
||||
content string
|
||||
}
|
||||
|
||||
func (r *embeddingsSearchResultResolver) RepoName(ctx context.Context) string {
|
||||
@ -227,5 +318,5 @@ func (r *embeddingsSearchResultResolver) EndLine(ctx context.Context) int32 {
|
||||
}
|
||||
|
||||
func (r *embeddingsSearchResultResolver) Content(ctx context.Context) string {
|
||||
return r.result.Content
|
||||
return r.content
|
||||
}
|
||||
|
||||
@ -0,0 +1,122 @@
|
||||
package resolvers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/graph-gophers/graphql-go"
|
||||
"github.com/sourcegraph/log/logtest"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/contextdetection"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/repo"
|
||||
"github.com/sourcegraph/sourcegraph/internal/actor"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
"github.com/sourcegraph/sourcegraph/internal/authz"
|
||||
"github.com/sourcegraph/sourcegraph/internal/conf"
|
||||
"github.com/sourcegraph/sourcegraph/internal/database"
|
||||
"github.com/sourcegraph/sourcegraph/internal/featureflag"
|
||||
"github.com/sourcegraph/sourcegraph/internal/gitserver"
|
||||
"github.com/sourcegraph/sourcegraph/internal/types"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEmbeddingSearchResolver(t *testing.T) {
|
||||
logger := logtest.Scoped(t)
|
||||
|
||||
mockDB := database.NewMockDB()
|
||||
mockRepos := database.NewMockRepoStore()
|
||||
mockRepos.GetByIDsFunc.SetDefaultReturn([]*types.Repo{{ID: 1, Name: "repo1"}}, nil)
|
||||
mockDB.ReposFunc.SetDefaultReturn(mockRepos)
|
||||
|
||||
mockGitserver := gitserver.NewMockClient()
|
||||
mockGitserver.ReadFileFunc.SetDefaultHook(func(_ context.Context, _ authz.SubRepoPermissionChecker, _ api.RepoName, _ api.CommitID, fileName string) ([]byte, error) {
|
||||
if fileName == "testfile" {
|
||||
return []byte("test\nfirst\nfour\nlines\nplus\nsome\nmore"), nil
|
||||
}
|
||||
return nil, os.ErrNotExist
|
||||
})
|
||||
|
||||
mockEmbeddingsClient := embeddings.NewMockClient()
|
||||
mockEmbeddingsClient.SearchFunc.SetDefaultReturn(&embeddings.EmbeddingCombinedSearchResults{
|
||||
CodeResults: embeddings.EmbeddingSearchResults{{
|
||||
FileName: "testfile",
|
||||
StartLine: 0,
|
||||
EndLine: 4,
|
||||
}, {
|
||||
FileName: "censored",
|
||||
StartLine: 0,
|
||||
EndLine: 4,
|
||||
}},
|
||||
}, nil)
|
||||
|
||||
repoEmbeddingJobsStore := repo.NewMockRepoEmbeddingJobsStore()
|
||||
contextDetectionJobsStore := contextdetection.NewMockContextDetectionEmbeddingJobsStore()
|
||||
|
||||
resolver := NewResolver(
|
||||
mockDB,
|
||||
logger,
|
||||
mockGitserver,
|
||||
mockEmbeddingsClient,
|
||||
repoEmbeddingJobsStore,
|
||||
contextDetectionJobsStore,
|
||||
)
|
||||
|
||||
conf.Mock(&conf.Unified{
|
||||
SiteConfiguration: schema.SiteConfiguration{
|
||||
Embeddings: &schema.Embeddings{Enabled: true},
|
||||
Completions: &schema.Completions{Enabled: true},
|
||||
},
|
||||
})
|
||||
|
||||
ctx := actor.WithActor(context.Background(), actor.FromMockUser(1))
|
||||
ffs := featureflag.NewMemoryStore(map[string]bool{"cody-experimental": true}, nil, nil)
|
||||
ctx = featureflag.WithFlags(ctx, ffs)
|
||||
|
||||
results, err := resolver.EmbeddingsMultiSearch(ctx, graphqlbackend.EmbeddingsMultiSearchInputArgs{
|
||||
Repos: []graphql.ID{graphqlbackend.MarshalRepositoryID(3)},
|
||||
Query: "test",
|
||||
CodeResultsCount: 1,
|
||||
TextResultsCount: 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
codeResults, err := results.CodeResults(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, codeResults, 1)
|
||||
require.Equal(t, "test\nfirst\nfour\nlines", codeResults[0].Content(ctx))
|
||||
|
||||
}
|
||||
|
||||
func Test_extractLineRange(t *testing.T) {
|
||||
cases := []struct {
|
||||
input []byte
|
||||
start, end int
|
||||
output []byte
|
||||
}{{
|
||||
[]byte("zero\none\ntwo\nthree\n"),
|
||||
0, 2,
|
||||
[]byte("zero\none"),
|
||||
}, {
|
||||
[]byte("zero\none\ntwo\nthree\n"),
|
||||
1, 2,
|
||||
[]byte("one"),
|
||||
}, {
|
||||
[]byte("zero\none\ntwo\nthree\n"),
|
||||
1, 2,
|
||||
[]byte("one"),
|
||||
}, {
|
||||
[]byte(""),
|
||||
1, 2,
|
||||
[]byte(""),
|
||||
}}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
got := extractLineRange(tc.input, tc.start, tc.end)
|
||||
require.Equal(t, tc.output, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
2
enterprise/internal/embeddings/BUILD.bazel
generated
2
enterprise/internal/embeddings/BUILD.bazel
generated
@ -12,6 +12,7 @@ go_library(
|
||||
"dot_portable.go",
|
||||
"index_name.go",
|
||||
"index_storage.go",
|
||||
"mocks_temp.go",
|
||||
"quantize.go",
|
||||
"similarity_search.go",
|
||||
"tokens.go",
|
||||
@ -31,6 +32,7 @@ go_library(
|
||||
"//internal/uploadstore",
|
||||
"//lib/errors",
|
||||
"@com_github_sourcegraph_conc//:conc",
|
||||
"@com_github_sourcegraph_conc//pool",
|
||||
"@com_github_sourcegraph_log//:log",
|
||||
"@org_golang_x_sync//errgroup",
|
||||
] + select({
|
||||
|
||||
@ -3,6 +3,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
|
||||
go_library(
|
||||
name = "contextdetection",
|
||||
srcs = [
|
||||
"mocks_temp.go",
|
||||
"store.go",
|
||||
"types.go",
|
||||
],
|
||||
|
||||
@ -0,0 +1,292 @@
|
||||
// Code generated by go-mockgen 1.3.7; DO NOT EDIT.
|
||||
//
|
||||
// This file was generated by running `sg generate` (or `go-mockgen`) at the root of
|
||||
// this repository. To add additional mocks to this or another package, add a new entry
|
||||
// to the mockgen.yaml file in the root of this repository.
|
||||
|
||||
package contextdetection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
basestore "github.com/sourcegraph/sourcegraph/internal/database/basestore"
|
||||
)
|
||||
|
||||
// MockContextDetectionEmbeddingJobsStore is a mock implementation of the
|
||||
// ContextDetectionEmbeddingJobsStore interface (from the package
|
||||
// github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/contextdetection)
|
||||
// used for unit testing.
|
||||
type MockContextDetectionEmbeddingJobsStore struct {
|
||||
// CreateContextDetectionEmbeddingJobFunc is an instance of a mock
|
||||
// function object controlling the behavior of the method
|
||||
// CreateContextDetectionEmbeddingJob.
|
||||
CreateContextDetectionEmbeddingJobFunc *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc
|
||||
// HandleFunc is an instance of a mock function object controlling the
|
||||
// behavior of the method Handle.
|
||||
HandleFunc *ContextDetectionEmbeddingJobsStoreHandleFunc
|
||||
}
|
||||
|
||||
// NewMockContextDetectionEmbeddingJobsStore creates a new mock of the
|
||||
// ContextDetectionEmbeddingJobsStore interface. All methods return zero
|
||||
// values for all results, unless overwritten.
|
||||
func NewMockContextDetectionEmbeddingJobsStore() *MockContextDetectionEmbeddingJobsStore {
|
||||
return &MockContextDetectionEmbeddingJobsStore{
|
||||
CreateContextDetectionEmbeddingJobFunc: &ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc{
|
||||
defaultHook: func(context.Context) (r0 int, r1 error) {
|
||||
return
|
||||
},
|
||||
},
|
||||
HandleFunc: &ContextDetectionEmbeddingJobsStoreHandleFunc{
|
||||
defaultHook: func() (r0 basestore.TransactableHandle) {
|
||||
return
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewStrictMockContextDetectionEmbeddingJobsStore creates a new mock of the
|
||||
// ContextDetectionEmbeddingJobsStore interface. All methods panic on
|
||||
// invocation, unless overwritten.
|
||||
func NewStrictMockContextDetectionEmbeddingJobsStore() *MockContextDetectionEmbeddingJobsStore {
|
||||
return &MockContextDetectionEmbeddingJobsStore{
|
||||
CreateContextDetectionEmbeddingJobFunc: &ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc{
|
||||
defaultHook: func(context.Context) (int, error) {
|
||||
panic("unexpected invocation of MockContextDetectionEmbeddingJobsStore.CreateContextDetectionEmbeddingJob")
|
||||
},
|
||||
},
|
||||
HandleFunc: &ContextDetectionEmbeddingJobsStoreHandleFunc{
|
||||
defaultHook: func() basestore.TransactableHandle {
|
||||
panic("unexpected invocation of MockContextDetectionEmbeddingJobsStore.Handle")
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewMockContextDetectionEmbeddingJobsStoreFrom creates a new mock of the
|
||||
// MockContextDetectionEmbeddingJobsStore interface. All methods delegate to
|
||||
// the given implementation, unless overwritten.
|
||||
func NewMockContextDetectionEmbeddingJobsStoreFrom(i ContextDetectionEmbeddingJobsStore) *MockContextDetectionEmbeddingJobsStore {
|
||||
return &MockContextDetectionEmbeddingJobsStore{
|
||||
CreateContextDetectionEmbeddingJobFunc: &ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc{
|
||||
defaultHook: i.CreateContextDetectionEmbeddingJob,
|
||||
},
|
||||
HandleFunc: &ContextDetectionEmbeddingJobsStoreHandleFunc{
|
||||
defaultHook: i.Handle,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc
|
||||
// describes the behavior when the CreateContextDetectionEmbeddingJob method
|
||||
// of the parent MockContextDetectionEmbeddingJobsStore instance is invoked.
|
||||
type ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc struct {
|
||||
defaultHook func(context.Context) (int, error)
|
||||
hooks []func(context.Context) (int, error)
|
||||
history []ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// CreateContextDetectionEmbeddingJob delegates to the next hook function in
|
||||
// the queue and stores the parameter and result values of this invocation.
|
||||
func (m *MockContextDetectionEmbeddingJobsStore) CreateContextDetectionEmbeddingJob(v0 context.Context) (int, error) {
|
||||
r0, r1 := m.CreateContextDetectionEmbeddingJobFunc.nextHook()(v0)
|
||||
m.CreateContextDetectionEmbeddingJobFunc.appendCall(ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall{v0, r0, r1})
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// SetDefaultHook sets function that is called when the
|
||||
// CreateContextDetectionEmbeddingJob method of the parent
|
||||
// MockContextDetectionEmbeddingJobsStore instance is invoked and the hook
|
||||
// queue is empty.
|
||||
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) SetDefaultHook(hook func(context.Context) (int, error)) {
|
||||
f.defaultHook = hook
|
||||
}
|
||||
|
||||
// PushHook adds a function to the end of hook queue. Each invocation of the
|
||||
// CreateContextDetectionEmbeddingJob method of the parent
|
||||
// MockContextDetectionEmbeddingJobsStore instance invokes the hook at the
|
||||
// front of the queue and discards it. After the queue is empty, the default
|
||||
// hook function is invoked for any future action.
|
||||
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) PushHook(hook func(context.Context) (int, error)) {
|
||||
f.mutex.Lock()
|
||||
f.hooks = append(f.hooks, hook)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// SetDefaultReturn calls SetDefaultHook with a function that returns the
|
||||
// given values.
|
||||
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) SetDefaultReturn(r0 int, r1 error) {
|
||||
f.SetDefaultHook(func(context.Context) (int, error) {
|
||||
return r0, r1
|
||||
})
|
||||
}
|
||||
|
||||
// PushReturn calls PushHook with a function that returns the given values.
|
||||
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) PushReturn(r0 int, r1 error) {
|
||||
f.PushHook(func(context.Context) (int, error) {
|
||||
return r0, r1
|
||||
})
|
||||
}
|
||||
|
||||
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) nextHook() func(context.Context) (int, error) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
if len(f.hooks) == 0 {
|
||||
return f.defaultHook
|
||||
}
|
||||
|
||||
hook := f.hooks[0]
|
||||
f.hooks = f.hooks[1:]
|
||||
return hook
|
||||
}
|
||||
|
||||
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) appendCall(r0 ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall) {
|
||||
f.mutex.Lock()
|
||||
f.history = append(f.history, r0)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// History returns a sequence of
|
||||
// ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall
|
||||
// objects describing the invocations of this function.
|
||||
func (f *ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFunc) History() []ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall {
|
||||
f.mutex.Lock()
|
||||
history := make([]ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall, len(f.history))
|
||||
copy(history, f.history)
|
||||
f.mutex.Unlock()
|
||||
|
||||
return history
|
||||
}
|
||||
|
||||
// ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall
|
||||
// is an object that describes an invocation of method
|
||||
// CreateContextDetectionEmbeddingJob on an instance of
|
||||
// MockContextDetectionEmbeddingJobsStore.
|
||||
type ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall struct {
|
||||
// Arg0 is the value of the 1st argument passed to this method
|
||||
// invocation.
|
||||
Arg0 context.Context
|
||||
// Result0 is the value of the 1st result returned from this method
|
||||
// invocation.
|
||||
Result0 int
|
||||
// Result1 is the value of the 2nd result returned from this method
|
||||
// invocation.
|
||||
Result1 error
|
||||
}
|
||||
|
||||
// Args returns an interface slice containing the arguments of this
|
||||
// invocation.
|
||||
func (c ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall) Args() []interface{} {
|
||||
return []interface{}{c.Arg0}
|
||||
}
|
||||
|
||||
// Results returns an interface slice containing the results of this
|
||||
// invocation.
|
||||
func (c ContextDetectionEmbeddingJobsStoreCreateContextDetectionEmbeddingJobFuncCall) Results() []interface{} {
|
||||
return []interface{}{c.Result0, c.Result1}
|
||||
}
|
||||
|
||||
// ContextDetectionEmbeddingJobsStoreHandleFunc describes the behavior when
|
||||
// the Handle method of the parent MockContextDetectionEmbeddingJobsStore
|
||||
// instance is invoked.
|
||||
type ContextDetectionEmbeddingJobsStoreHandleFunc struct {
|
||||
defaultHook func() basestore.TransactableHandle
|
||||
hooks []func() basestore.TransactableHandle
|
||||
history []ContextDetectionEmbeddingJobsStoreHandleFuncCall
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// Handle delegates to the next hook function in the queue and stores the
|
||||
// parameter and result values of this invocation.
|
||||
func (m *MockContextDetectionEmbeddingJobsStore) Handle() basestore.TransactableHandle {
|
||||
r0 := m.HandleFunc.nextHook()()
|
||||
m.HandleFunc.appendCall(ContextDetectionEmbeddingJobsStoreHandleFuncCall{r0})
|
||||
return r0
|
||||
}
|
||||
|
||||
// SetDefaultHook sets function that is called when the Handle method of the
|
||||
// parent MockContextDetectionEmbeddingJobsStore instance is invoked and the
|
||||
// hook queue is empty.
|
||||
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) SetDefaultHook(hook func() basestore.TransactableHandle) {
|
||||
f.defaultHook = hook
|
||||
}
|
||||
|
||||
// PushHook adds a function to the end of hook queue. Each invocation of the
|
||||
// Handle method of the parent MockContextDetectionEmbeddingJobsStore
|
||||
// instance invokes the hook at the front of the queue and discards it.
|
||||
// After the queue is empty, the default hook function is invoked for any
|
||||
// future action.
|
||||
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) PushHook(hook func() basestore.TransactableHandle) {
|
||||
f.mutex.Lock()
|
||||
f.hooks = append(f.hooks, hook)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// SetDefaultReturn calls SetDefaultHook with a function that returns the
|
||||
// given values.
|
||||
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) SetDefaultReturn(r0 basestore.TransactableHandle) {
|
||||
f.SetDefaultHook(func() basestore.TransactableHandle {
|
||||
return r0
|
||||
})
|
||||
}
|
||||
|
||||
// PushReturn calls PushHook with a function that returns the given values.
|
||||
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) PushReturn(r0 basestore.TransactableHandle) {
|
||||
f.PushHook(func() basestore.TransactableHandle {
|
||||
return r0
|
||||
})
|
||||
}
|
||||
|
||||
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) nextHook() func() basestore.TransactableHandle {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
if len(f.hooks) == 0 {
|
||||
return f.defaultHook
|
||||
}
|
||||
|
||||
hook := f.hooks[0]
|
||||
f.hooks = f.hooks[1:]
|
||||
return hook
|
||||
}
|
||||
|
||||
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) appendCall(r0 ContextDetectionEmbeddingJobsStoreHandleFuncCall) {
|
||||
f.mutex.Lock()
|
||||
f.history = append(f.history, r0)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// History returns a sequence of
|
||||
// ContextDetectionEmbeddingJobsStoreHandleFuncCall objects describing the
|
||||
// invocations of this function.
|
||||
func (f *ContextDetectionEmbeddingJobsStoreHandleFunc) History() []ContextDetectionEmbeddingJobsStoreHandleFuncCall {
|
||||
f.mutex.Lock()
|
||||
history := make([]ContextDetectionEmbeddingJobsStoreHandleFuncCall, len(f.history))
|
||||
copy(history, f.history)
|
||||
f.mutex.Unlock()
|
||||
|
||||
return history
|
||||
}
|
||||
|
||||
// ContextDetectionEmbeddingJobsStoreHandleFuncCall is an object that
|
||||
// describes an invocation of method Handle on an instance of
|
||||
// MockContextDetectionEmbeddingJobsStore.
|
||||
type ContextDetectionEmbeddingJobsStoreHandleFuncCall struct {
|
||||
// Result0 is the value of the 1st result returned from this method
|
||||
// invocation.
|
||||
Result0 basestore.TransactableHandle
|
||||
}
|
||||
|
||||
// Args returns an interface slice containing the arguments of this
|
||||
// invocation.
|
||||
func (c ContextDetectionEmbeddingJobsStoreHandleFuncCall) Args() []interface{} {
|
||||
return []interface{}{}
|
||||
}
|
||||
|
||||
// Results returns an interface slice containing the results of this
|
||||
// invocation.
|
||||
func (c ContextDetectionEmbeddingJobsStoreHandleFuncCall) Results() []interface{} {
|
||||
return []interface{}{c.Result0}
|
||||
}
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/sourcegraph/conc/pool"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
@ -30,14 +31,23 @@ var defaultDoer = func() httpcli.Doer {
|
||||
return d
|
||||
}()
|
||||
|
||||
func NewClient() *Client {
|
||||
return &Client{
|
||||
Endpoints: defaultEndpoints(),
|
||||
HTTPClient: defaultDoer,
|
||||
func NewDefaultClient() Client {
|
||||
return NewClient(defaultEndpoints(), defaultDoer)
|
||||
}
|
||||
|
||||
func NewClient(endpoints *endpoint.Map, doer httpcli.Doer) Client {
|
||||
return &client{
|
||||
Endpoints: endpoints,
|
||||
HTTPClient: doer,
|
||||
}
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
type Client interface {
|
||||
Search(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error)
|
||||
IsContextRequiredForChatQuery(context.Context, IsContextRequiredForChatQueryParameters) (bool, error)
|
||||
}
|
||||
|
||||
type client struct {
|
||||
// Endpoints to embeddings service.
|
||||
Endpoints *endpoint.Map
|
||||
|
||||
@ -46,15 +56,13 @@ type Client struct {
|
||||
}
|
||||
|
||||
type EmbeddingsSearchParameters struct {
|
||||
RepoName api.RepoName `json:"repoName"`
|
||||
RepoID api.RepoID `json:"repoID"`
|
||||
Query string `json:"query"`
|
||||
CodeResultsCount int `json:"codeResultsCount"`
|
||||
TextResultsCount int `json:"textResultsCount"`
|
||||
RepoNames []api.RepoName `json:"repoNames"`
|
||||
RepoIDs []api.RepoID `json:"repoIDs"`
|
||||
Query string `json:"query"`
|
||||
CodeResultsCount int `json:"codeResultsCount"`
|
||||
TextResultsCount int `json:"textResultsCount"`
|
||||
|
||||
UseDocumentRanks bool `json:"useDocumentRanks"`
|
||||
// If set to "True", EmbeddingSearchResult.Debug will contain useful information about scoring.
|
||||
Debug bool `json:"debug"`
|
||||
}
|
||||
|
||||
type IsContextRequiredForChatQueryParameters struct {
|
||||
@ -65,8 +73,43 @@ type IsContextRequiredForChatQueryResult struct {
|
||||
IsRequired bool `json:"isRequired"`
|
||||
}
|
||||
|
||||
func (c *Client) Search(ctx context.Context, args EmbeddingsSearchParameters) (*EmbeddingSearchResults, error) {
|
||||
resp, err := c.httpPost(ctx, "search", args.RepoName, args)
|
||||
func (c *client) Search(ctx context.Context, args EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
|
||||
partitions, err := c.partition(args.RepoNames, args.RepoIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p := pool.NewWithResults[*EmbeddingCombinedSearchResults]().WithContext(ctx)
|
||||
|
||||
for endpoint, partition := range partitions {
|
||||
endpoint := endpoint
|
||||
|
||||
// make a copy for this request
|
||||
args := args
|
||||
args.RepoNames = partition.repoNames
|
||||
args.RepoIDs = partition.repoIDs
|
||||
|
||||
p.Go(func(ctx context.Context) (*EmbeddingCombinedSearchResults, error) {
|
||||
return c.searchPartition(ctx, endpoint, args)
|
||||
})
|
||||
}
|
||||
|
||||
allResults, err := p.Wait()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var combinedResult EmbeddingCombinedSearchResults
|
||||
for _, result := range allResults {
|
||||
combinedResult.CodeResults.MergeTruncate(result.CodeResults, args.CodeResultsCount)
|
||||
combinedResult.TextResults.MergeTruncate(result.TextResults, args.TextResultsCount)
|
||||
}
|
||||
|
||||
return &combinedResult, nil
|
||||
}
|
||||
|
||||
func (c *client) searchPartition(ctx context.Context, endpoint string, args EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
|
||||
resp, err := c.httpPost(ctx, "search", endpoint, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -82,7 +125,7 @@ func (c *Client) Search(ctx context.Context, args EmbeddingsSearchParameters) (*
|
||||
)
|
||||
}
|
||||
|
||||
var response EmbeddingSearchResults
|
||||
var response EmbeddingCombinedSearchResults
|
||||
err = json.NewDecoder(resp.Body).Decode(&response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -90,8 +133,13 @@ func (c *Client) Search(ctx context.Context, args EmbeddingsSearchParameters) (*
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func (c *Client) IsContextRequiredForChatQuery(ctx context.Context, args IsContextRequiredForChatQueryParameters) (bool, error) {
|
||||
resp, err := c.httpPost(ctx, "isContextRequiredForChatQuery", "", args)
|
||||
func (c *client) IsContextRequiredForChatQuery(ctx context.Context, args IsContextRequiredForChatQueryParameters) (bool, error) {
|
||||
endpoint, err := c.url("")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
resp, err := c.httpPost(ctx, "isContextRequiredForChatQuery", endpoint, args)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -115,24 +163,50 @@ func (c *Client) IsContextRequiredForChatQuery(ctx context.Context, args IsConte
|
||||
return response.IsRequired, nil
|
||||
}
|
||||
|
||||
func (c *Client) url(repo api.RepoName) (string, error) {
|
||||
func (c *client) url(repo api.RepoName) (string, error) {
|
||||
if c.Endpoints == nil {
|
||||
return "", errors.New("an embeddings service has not been configured")
|
||||
}
|
||||
return c.Endpoints.Get(string(repo))
|
||||
}
|
||||
|
||||
func (c *Client) httpPost(
|
||||
ctx context.Context,
|
||||
method string,
|
||||
repo api.RepoName,
|
||||
payload any,
|
||||
) (resp *http.Response, err error) {
|
||||
url, err := c.url(repo)
|
||||
type repoPartition struct {
|
||||
repoNames []api.RepoName
|
||||
repoIDs []api.RepoID
|
||||
}
|
||||
|
||||
// returns a partition of the input repos by the endpoint their requests should be routed to
|
||||
func (c *client) partition(repos []api.RepoName, repoIDs []api.RepoID) (map[string]repoPartition, error) {
|
||||
if c.Endpoints == nil {
|
||||
return nil, errors.New("an embeddings service has not been configured")
|
||||
}
|
||||
|
||||
repoStrings := make([]string, len(repos))
|
||||
for i, repo := range repos {
|
||||
repoStrings[i] = string(repo)
|
||||
}
|
||||
|
||||
endpoints, err := c.Endpoints.GetMany(repoStrings...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := make(map[string]repoPartition)
|
||||
for i, endpoint := range endpoints {
|
||||
res[endpoint] = repoPartition{
|
||||
repoNames: append(res[endpoint].repoNames, repos[i]),
|
||||
repoIDs: append(res[endpoint].repoIDs, repoIDs[i]),
|
||||
}
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (c *client) httpPost(
|
||||
ctx context.Context,
|
||||
method string,
|
||||
url string,
|
||||
payload any,
|
||||
) (resp *http.Response, err error) {
|
||||
reqBody, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
291
enterprise/internal/embeddings/mocks_temp.go
Normal file
291
enterprise/internal/embeddings/mocks_temp.go
Normal file
@ -0,0 +1,291 @@
|
||||
// Code generated by go-mockgen 1.3.7; DO NOT EDIT.
|
||||
//
|
||||
// This file was generated by running `sg generate` (or `go-mockgen`) at the root of
|
||||
// this repository. To add additional mocks to this or another package, add a new entry
|
||||
// to the mockgen.yaml file in the root of this repository.
|
||||
|
||||
package embeddings
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// MockClient is a mock implementation of the Client interface (from the
|
||||
// package
|
||||
// github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings) used
|
||||
// for unit testing.
|
||||
type MockClient struct {
|
||||
// IsContextRequiredForChatQueryFunc is an instance of a mock function
|
||||
// object controlling the behavior of the method
|
||||
// IsContextRequiredForChatQuery.
|
||||
IsContextRequiredForChatQueryFunc *ClientIsContextRequiredForChatQueryFunc
|
||||
// SearchFunc is an instance of a mock function object controlling the
|
||||
// behavior of the method Search.
|
||||
SearchFunc *ClientSearchFunc
|
||||
}
|
||||
|
||||
// NewMockClient creates a new mock of the Client interface. All methods
|
||||
// return zero values for all results, unless overwritten.
|
||||
func NewMockClient() *MockClient {
|
||||
return &MockClient{
|
||||
IsContextRequiredForChatQueryFunc: &ClientIsContextRequiredForChatQueryFunc{
|
||||
defaultHook: func(context.Context, IsContextRequiredForChatQueryParameters) (r0 bool, r1 error) {
|
||||
return
|
||||
},
|
||||
},
|
||||
SearchFunc: &ClientSearchFunc{
|
||||
defaultHook: func(context.Context, EmbeddingsSearchParameters) (r0 *EmbeddingCombinedSearchResults, r1 error) {
|
||||
return
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewStrictMockClient creates a new mock of the Client interface. All
|
||||
// methods panic on invocation, unless overwritten.
|
||||
func NewStrictMockClient() *MockClient {
|
||||
return &MockClient{
|
||||
IsContextRequiredForChatQueryFunc: &ClientIsContextRequiredForChatQueryFunc{
|
||||
defaultHook: func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error) {
|
||||
panic("unexpected invocation of MockClient.IsContextRequiredForChatQuery")
|
||||
},
|
||||
},
|
||||
SearchFunc: &ClientSearchFunc{
|
||||
defaultHook: func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
|
||||
panic("unexpected invocation of MockClient.Search")
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewMockClientFrom creates a new mock of the MockClient interface. All
|
||||
// methods delegate to the given implementation, unless overwritten.
|
||||
func NewMockClientFrom(i Client) *MockClient {
|
||||
return &MockClient{
|
||||
IsContextRequiredForChatQueryFunc: &ClientIsContextRequiredForChatQueryFunc{
|
||||
defaultHook: i.IsContextRequiredForChatQuery,
|
||||
},
|
||||
SearchFunc: &ClientSearchFunc{
|
||||
defaultHook: i.Search,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ClientIsContextRequiredForChatQueryFunc describes the behavior when the
|
||||
// IsContextRequiredForChatQuery method of the parent MockClient instance is
|
||||
// invoked.
|
||||
type ClientIsContextRequiredForChatQueryFunc struct {
|
||||
defaultHook func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error)
|
||||
hooks []func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error)
|
||||
history []ClientIsContextRequiredForChatQueryFuncCall
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// IsContextRequiredForChatQuery delegates to the next hook function in the
|
||||
// queue and stores the parameter and result values of this invocation.
|
||||
func (m *MockClient) IsContextRequiredForChatQuery(v0 context.Context, v1 IsContextRequiredForChatQueryParameters) (bool, error) {
|
||||
r0, r1 := m.IsContextRequiredForChatQueryFunc.nextHook()(v0, v1)
|
||||
m.IsContextRequiredForChatQueryFunc.appendCall(ClientIsContextRequiredForChatQueryFuncCall{v0, v1, r0, r1})
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// SetDefaultHook sets function that is called when the
|
||||
// IsContextRequiredForChatQuery method of the parent MockClient instance is
|
||||
// invoked and the hook queue is empty.
|
||||
func (f *ClientIsContextRequiredForChatQueryFunc) SetDefaultHook(hook func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error)) {
|
||||
f.defaultHook = hook
|
||||
}
|
||||
|
||||
// PushHook adds a function to the end of hook queue. Each invocation of the
|
||||
// IsContextRequiredForChatQuery method of the parent MockClient instance
|
||||
// invokes the hook at the front of the queue and discards it. After the
|
||||
// queue is empty, the default hook function is invoked for any future
|
||||
// action.
|
||||
func (f *ClientIsContextRequiredForChatQueryFunc) PushHook(hook func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error)) {
|
||||
f.mutex.Lock()
|
||||
f.hooks = append(f.hooks, hook)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// SetDefaultReturn calls SetDefaultHook with a function that returns the
|
||||
// given values.
|
||||
func (f *ClientIsContextRequiredForChatQueryFunc) SetDefaultReturn(r0 bool, r1 error) {
|
||||
f.SetDefaultHook(func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error) {
|
||||
return r0, r1
|
||||
})
|
||||
}
|
||||
|
||||
// PushReturn calls PushHook with a function that returns the given values.
|
||||
func (f *ClientIsContextRequiredForChatQueryFunc) PushReturn(r0 bool, r1 error) {
|
||||
f.PushHook(func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error) {
|
||||
return r0, r1
|
||||
})
|
||||
}
|
||||
|
||||
func (f *ClientIsContextRequiredForChatQueryFunc) nextHook() func(context.Context, IsContextRequiredForChatQueryParameters) (bool, error) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
if len(f.hooks) == 0 {
|
||||
return f.defaultHook
|
||||
}
|
||||
|
||||
hook := f.hooks[0]
|
||||
f.hooks = f.hooks[1:]
|
||||
return hook
|
||||
}
|
||||
|
||||
func (f *ClientIsContextRequiredForChatQueryFunc) appendCall(r0 ClientIsContextRequiredForChatQueryFuncCall) {
|
||||
f.mutex.Lock()
|
||||
f.history = append(f.history, r0)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// History returns a sequence of ClientIsContextRequiredForChatQueryFuncCall
|
||||
// objects describing the invocations of this function.
|
||||
func (f *ClientIsContextRequiredForChatQueryFunc) History() []ClientIsContextRequiredForChatQueryFuncCall {
|
||||
f.mutex.Lock()
|
||||
history := make([]ClientIsContextRequiredForChatQueryFuncCall, len(f.history))
|
||||
copy(history, f.history)
|
||||
f.mutex.Unlock()
|
||||
|
||||
return history
|
||||
}
|
||||
|
||||
// ClientIsContextRequiredForChatQueryFuncCall is an object that describes
|
||||
// an invocation of method IsContextRequiredForChatQuery on an instance of
|
||||
// MockClient.
|
||||
type ClientIsContextRequiredForChatQueryFuncCall struct {
|
||||
// Arg0 is the value of the 1st argument passed to this method
|
||||
// invocation.
|
||||
Arg0 context.Context
|
||||
// Arg1 is the value of the 2nd argument passed to this method
|
||||
// invocation.
|
||||
Arg1 IsContextRequiredForChatQueryParameters
|
||||
// Result0 is the value of the 1st result returned from this method
|
||||
// invocation.
|
||||
Result0 bool
|
||||
// Result1 is the value of the 2nd result returned from this method
|
||||
// invocation.
|
||||
Result1 error
|
||||
}
|
||||
|
||||
// Args returns an interface slice containing the arguments of this
|
||||
// invocation.
|
||||
func (c ClientIsContextRequiredForChatQueryFuncCall) Args() []interface{} {
|
||||
return []interface{}{c.Arg0, c.Arg1}
|
||||
}
|
||||
|
||||
// Results returns an interface slice containing the results of this
|
||||
// invocation.
|
||||
func (c ClientIsContextRequiredForChatQueryFuncCall) Results() []interface{} {
|
||||
return []interface{}{c.Result0, c.Result1}
|
||||
}
|
||||
|
||||
// ClientSearchFunc describes the behavior when the Search method of the
|
||||
// parent MockClient instance is invoked.
|
||||
type ClientSearchFunc struct {
|
||||
defaultHook func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error)
|
||||
hooks []func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error)
|
||||
history []ClientSearchFuncCall
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// Search delegates to the next hook function in the queue and stores the
|
||||
// parameter and result values of this invocation.
|
||||
func (m *MockClient) Search(v0 context.Context, v1 EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
|
||||
r0, r1 := m.SearchFunc.nextHook()(v0, v1)
|
||||
m.SearchFunc.appendCall(ClientSearchFuncCall{v0, v1, r0, r1})
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// SetDefaultHook sets function that is called when the Search method of the
|
||||
// parent MockClient instance is invoked and the hook queue is empty.
|
||||
func (f *ClientSearchFunc) SetDefaultHook(hook func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error)) {
|
||||
f.defaultHook = hook
|
||||
}
|
||||
|
||||
// PushHook adds a function to the end of hook queue. Each invocation of the
|
||||
// Search method of the parent MockClient instance invokes the hook at the
|
||||
// front of the queue and discards it. After the queue is empty, the default
|
||||
// hook function is invoked for any future action.
|
||||
func (f *ClientSearchFunc) PushHook(hook func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error)) {
|
||||
f.mutex.Lock()
|
||||
f.hooks = append(f.hooks, hook)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// SetDefaultReturn calls SetDefaultHook with a function that returns the
|
||||
// given values.
|
||||
func (f *ClientSearchFunc) SetDefaultReturn(r0 *EmbeddingCombinedSearchResults, r1 error) {
|
||||
f.SetDefaultHook(func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
|
||||
return r0, r1
|
||||
})
|
||||
}
|
||||
|
||||
// PushReturn calls PushHook with a function that returns the given values.
|
||||
func (f *ClientSearchFunc) PushReturn(r0 *EmbeddingCombinedSearchResults, r1 error) {
|
||||
f.PushHook(func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
|
||||
return r0, r1
|
||||
})
|
||||
}
|
||||
|
||||
func (f *ClientSearchFunc) nextHook() func(context.Context, EmbeddingsSearchParameters) (*EmbeddingCombinedSearchResults, error) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
if len(f.hooks) == 0 {
|
||||
return f.defaultHook
|
||||
}
|
||||
|
||||
hook := f.hooks[0]
|
||||
f.hooks = f.hooks[1:]
|
||||
return hook
|
||||
}
|
||||
|
||||
func (f *ClientSearchFunc) appendCall(r0 ClientSearchFuncCall) {
|
||||
f.mutex.Lock()
|
||||
f.history = append(f.history, r0)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// History returns a sequence of ClientSearchFuncCall objects describing the
|
||||
// invocations of this function.
|
||||
func (f *ClientSearchFunc) History() []ClientSearchFuncCall {
|
||||
f.mutex.Lock()
|
||||
history := make([]ClientSearchFuncCall, len(f.history))
|
||||
copy(history, f.history)
|
||||
f.mutex.Unlock()
|
||||
|
||||
return history
|
||||
}
|
||||
|
||||
// ClientSearchFuncCall is an object that describes an invocation of method
|
||||
// Search on an instance of MockClient.
|
||||
type ClientSearchFuncCall struct {
|
||||
// Arg0 is the value of the 1st argument passed to this method
|
||||
// invocation.
|
||||
Arg0 context.Context
|
||||
// Arg1 is the value of the 2nd argument passed to this method
|
||||
// invocation.
|
||||
Arg1 EmbeddingsSearchParameters
|
||||
// Result0 is the value of the 1st result returned from this method
|
||||
// invocation.
|
||||
Result0 *EmbeddingCombinedSearchResults
|
||||
// Result1 is the value of the 2nd result returned from this method
|
||||
// invocation.
|
||||
Result1 error
|
||||
}
|
||||
|
||||
// Args returns an interface slice containing the arguments of this
|
||||
// invocation.
|
||||
func (c ClientSearchFuncCall) Args() []interface{} {
|
||||
return []interface{}{c.Arg0, c.Arg1}
|
||||
}
|
||||
|
||||
// Results returns an interface slice containing the results of this
|
||||
// invocation.
|
||||
func (c ClientSearchFuncCall) Results() []interface{} {
|
||||
return []interface{}{c.Result0, c.Result1}
|
||||
}
|
||||
@ -2,17 +2,16 @@ package embeddings
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
|
||||
"github.com/sourcegraph/conc"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
)
|
||||
|
||||
type nearestNeighbor struct {
|
||||
index int
|
||||
score int32
|
||||
debug searchDebugInfo
|
||||
index int
|
||||
scoreDetails SearchScoreDetails
|
||||
}
|
||||
|
||||
type nearestNeighborsHeap struct {
|
||||
@ -22,7 +21,7 @@ type nearestNeighborsHeap struct {
|
||||
func (nn *nearestNeighborsHeap) Len() int { return len(nn.neighbors) }
|
||||
|
||||
func (nn *nearestNeighborsHeap) Less(i, j int) bool {
|
||||
return nn.neighbors[i].score < nn.neighbors[j].score
|
||||
return nn.neighbors[i].scoreDetails.Score < nn.neighbors[j].scoreDetails.Score
|
||||
}
|
||||
|
||||
func (nn *nearestNeighborsHeap) Swap(i, j int) {
|
||||
@ -81,19 +80,16 @@ type WorkerOptions struct {
|
||||
MinRowsToSplit int
|
||||
}
|
||||
|
||||
type SimilaritySearchResult struct {
|
||||
RepoEmbeddingRowMetadata
|
||||
SimilarityScore int32
|
||||
RankScore int32
|
||||
}
|
||||
|
||||
func (r *SimilaritySearchResult) Score() int32 {
|
||||
return r.SimilarityScore + r.RankScore
|
||||
}
|
||||
|
||||
// SimilaritySearch finds the `nResults` most similar rows to a query vector. It uses the cosine similarity metric.
|
||||
// IMPORTANT: The vectors in the embedding index have to be normalized for similarity search to work correctly.
|
||||
func (index *EmbeddingIndex) SimilaritySearch(query []int8, numResults int, workerOptions WorkerOptions, opts SearchOptions) []SimilaritySearchResult {
|
||||
func (index *EmbeddingIndex) SimilaritySearch(
|
||||
query []int8,
|
||||
numResults int,
|
||||
workerOptions WorkerOptions,
|
||||
opts SearchOptions,
|
||||
repoName api.RepoName,
|
||||
revision api.CommitID,
|
||||
) []EmbeddingSearchResult {
|
||||
if numResults == 0 || len(index.Embeddings) == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -131,16 +127,20 @@ func (index *EmbeddingIndex) SimilaritySearch(query []int8, numResults int, work
|
||||
}
|
||||
}
|
||||
// And re-sort it according to the score (descending).
|
||||
sort.Slice(neighbors, func(i, j int) bool { return neighbors[i].score > neighbors[j].score })
|
||||
sort.Slice(neighbors, func(i, j int) bool { return neighbors[i].scoreDetails.Score > neighbors[j].scoreDetails.Score })
|
||||
|
||||
// Take top neighbors and return them as results.
|
||||
results := make([]SimilaritySearchResult, numResults)
|
||||
results := make([]EmbeddingSearchResult, numResults)
|
||||
|
||||
for idx := 0; idx < min(numResults, len(neighbors)); idx++ {
|
||||
results[idx] = SimilaritySearchResult{
|
||||
RepoEmbeddingRowMetadata: index.RowMetadata[neighbors[idx].index],
|
||||
SimilarityScore: neighbors[idx].debug.similarity,
|
||||
RankScore: neighbors[idx].debug.rank,
|
||||
metadata := index.RowMetadata[neighbors[idx].index]
|
||||
results[idx] = EmbeddingSearchResult{
|
||||
RepoName: repoName,
|
||||
Revision: revision,
|
||||
FileName: metadata.FileName,
|
||||
StartLine: metadata.StartLine,
|
||||
EndLine: metadata.EndLine,
|
||||
ScoreDetails: neighbors[idx].scoreDetails,
|
||||
}
|
||||
}
|
||||
|
||||
@ -156,17 +156,17 @@ func (index *EmbeddingIndex) partialSimilaritySearch(query []int8, numResults in
|
||||
|
||||
nnHeap := newNearestNeighborsHeap()
|
||||
for i := partialRows.start; i < partialRows.start+numResults; i++ {
|
||||
score, debugInfo := index.score(query, i, opts)
|
||||
heap.Push(nnHeap, nearestNeighbor{index: i, score: score, debug: debugInfo})
|
||||
scoreDetails := index.score(query, i, opts)
|
||||
heap.Push(nnHeap, nearestNeighbor{index: i, scoreDetails: scoreDetails})
|
||||
}
|
||||
|
||||
for i := partialRows.start + numResults; i < partialRows.end; i++ {
|
||||
score, debugInfo := index.score(query, i, opts)
|
||||
scoreDetails := index.score(query, i, opts)
|
||||
// Add row if it has greater similarity than the smallest similarity in the heap.
|
||||
// This way we ensure keep a set of the highest similarities in the heap.
|
||||
if score > nnHeap.Peek().score {
|
||||
if scoreDetails.Score > nnHeap.Peek().scoreDetails.Score {
|
||||
heap.Pop(nnHeap)
|
||||
heap.Push(nnHeap, nearestNeighbor{index: i, score: score, debug: debugInfo})
|
||||
heap.Push(nnHeap, nearestNeighbor{index: i, scoreDetails: scoreDetails})
|
||||
}
|
||||
}
|
||||
|
||||
@ -178,7 +178,7 @@ const (
|
||||
scoreSimilarityWeight int32 = 2
|
||||
)
|
||||
|
||||
func (index *EmbeddingIndex) score(query []int8, i int, opts SearchOptions) (score int32, debugInfo searchDebugInfo) {
|
||||
func (index *EmbeddingIndex) score(query []int8, i int, opts SearchOptions) SearchScoreDetails {
|
||||
similarityScore := scoreSimilarityWeight * Dot(index.Row(i), query)
|
||||
|
||||
// handle missing ranks
|
||||
@ -195,20 +195,11 @@ func (index *EmbeddingIndex) score(query []int8, i int, opts SearchOptions) (sco
|
||||
rankScore = int32(float32(scoreFileRankWeight) * normalizedRank)
|
||||
}
|
||||
|
||||
return similarityScore + rankScore, searchDebugInfo{similarity: similarityScore, rank: rankScore, enabled: opts.Debug}
|
||||
}
|
||||
|
||||
type searchDebugInfo struct {
|
||||
similarity int32
|
||||
rank int32
|
||||
enabled bool
|
||||
}
|
||||
|
||||
func (i *searchDebugInfo) String() string {
|
||||
if !i.enabled {
|
||||
return ""
|
||||
return SearchScoreDetails{
|
||||
Score: similarityScore + rankScore,
|
||||
SimilarityScore: similarityScore,
|
||||
RankScore: rankScore,
|
||||
}
|
||||
return fmt.Sprintf("score:%d, similarity:%d, rank:%d", i.similarity+i.rank, i.similarity, i.rank)
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
@ -226,6 +217,5 @@ func max(a, b int) int {
|
||||
}
|
||||
|
||||
type SearchOptions struct {
|
||||
Debug bool
|
||||
UseDocumentRanks bool
|
||||
}
|
||||
|
||||
@ -64,10 +64,10 @@ func TestSimilaritySearch(t *testing.T) {
|
||||
for q := 0; q < numQueries; q++ {
|
||||
t.Run(fmt.Sprintf("find nearest neighbors query=%d numResults=%d numWorkers=%d", q, numResults, numWorkers), func(t *testing.T) {
|
||||
query := queries[q*columnDimension : (q+1)*columnDimension]
|
||||
results := index.SimilaritySearch(query, numResults, WorkerOptions{NumWorkers: numWorkers, MinRowsToSplit: 0}, SearchOptions{})
|
||||
results := index.SimilaritySearch(query, numResults, WorkerOptions{NumWorkers: numWorkers, MinRowsToSplit: 0}, SearchOptions{}, "", "")
|
||||
resultRowNums := make([]int, len(results))
|
||||
for i, r := range results {
|
||||
resultRowNums[i], _ = strconv.Atoi(r.RepoEmbeddingRowMetadata.FileName)
|
||||
resultRowNums[i], _ = strconv.Atoi(r.FileName)
|
||||
}
|
||||
expectedResults := ranks[q]
|
||||
require.Equal(t, expectedResults[:min(numResults, len(expectedResults))], resultRowNums)
|
||||
@ -152,7 +152,7 @@ func BenchmarkSimilaritySearch(b *testing.B) {
|
||||
b.Run(fmt.Sprintf("numWorkers=%d", numWorkers), func(b *testing.B) {
|
||||
start := time.Now()
|
||||
for n := 0; n < b.N; n++ {
|
||||
_ = index.SimilaritySearch(query, numResults, WorkerOptions{NumWorkers: numWorkers}, SearchOptions{})
|
||||
_ = index.SimilaritySearch(query, numResults, WorkerOptions{NumWorkers: numWorkers}, SearchOptions{}, "", "")
|
||||
}
|
||||
m := float64(numRows) * float64(b.N) / time.Since(start).Seconds()
|
||||
b.ReportMetric(m, "embeddings/s")
|
||||
@ -175,23 +175,11 @@ func TestScore(t *testing.T) {
|
||||
}
|
||||
// embeddings[0] = 64, 83, 70,
|
||||
// queries[0:3] = 53, 61, 97,
|
||||
score, debugInfo := index.score(queries[0:columnDimension], 0, SearchOptions{Debug: true, UseDocumentRanks: true})
|
||||
scoreDetails := index.score(queries[0:columnDimension], 0, SearchOptions{UseDocumentRanks: true})
|
||||
|
||||
// Check that the score is correct
|
||||
expectedScore := scoreSimilarityWeight * ((64 * 53) + (83 * 61) + (70 * 97))
|
||||
if math.Abs(float64(score-expectedScore)) > 0.0001 {
|
||||
t.Fatalf("Expected score %d, but got %d", expectedScore, score)
|
||||
}
|
||||
|
||||
if debugInfo.String() == "" {
|
||||
t.Fatal("Expected a non-empty debug")
|
||||
if math.Abs(float64(scoreDetails.Score-expectedScore)) > 0.0001 {
|
||||
t.Fatalf("Expected score %d, but got %d", expectedScore, scoreDetails.Score)
|
||||
}
|
||||
}
|
||||
|
||||
func simpleCosineSimilarity(a, b []int8) int32 {
|
||||
similarity := int32(0)
|
||||
for i := 0; i < len(a); i++ {
|
||||
similarity += int32(a[i]) * int32(b[i])
|
||||
}
|
||||
return similarity
|
||||
}
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
package embeddings
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/sourcegraph/log"
|
||||
@ -45,18 +47,46 @@ type ContextDetectionEmbeddingIndex struct {
|
||||
MessagesWithoutAdditionalContextMeanEmbedding []float32
|
||||
}
|
||||
|
||||
type EmbeddingSearchResults struct {
|
||||
CodeResults []EmbeddingSearchResult `json:"codeResults"`
|
||||
TextResults []EmbeddingSearchResult `json:"textResults"`
|
||||
type EmbeddingCombinedSearchResults struct {
|
||||
CodeResults EmbeddingSearchResults `json:"codeResults"`
|
||||
TextResults EmbeddingSearchResults `json:"textResults"`
|
||||
}
|
||||
|
||||
type EmbeddingSearchResults []EmbeddingSearchResult
|
||||
|
||||
// MergeTruncate merges other into the search results, keeping only max results with the highest scores
|
||||
func (esrs *EmbeddingSearchResults) MergeTruncate(other EmbeddingSearchResults, max int) {
|
||||
self := *esrs
|
||||
self = append(self, other...)
|
||||
sort.Slice(self, func(i, j int) bool { return self[i].Score() > self[j].Score() })
|
||||
*esrs = self[:max]
|
||||
}
|
||||
|
||||
type EmbeddingSearchResult struct {
|
||||
RepoName api.RepoName
|
||||
Revision api.CommitID
|
||||
RepoEmbeddingRowMetadata
|
||||
Content string `json:"content"`
|
||||
// Experimental: Clients should not rely on any particular format of debug
|
||||
Debug string `json:"debug,omitempty"`
|
||||
RepoName api.RepoName `json:"repoName"`
|
||||
Revision api.CommitID `json:"revision"`
|
||||
|
||||
FileName string `json:"fileName"`
|
||||
StartLine int `json:"startLine"`
|
||||
EndLine int `json:"endLine"`
|
||||
|
||||
ScoreDetails SearchScoreDetails `json:"scoreDetails"`
|
||||
}
|
||||
|
||||
func (esr *EmbeddingSearchResult) Score() int32 {
|
||||
return esr.ScoreDetails.RankScore + esr.ScoreDetails.SimilarityScore
|
||||
}
|
||||
|
||||
type SearchScoreDetails struct {
|
||||
Score int32 `json:"score"`
|
||||
|
||||
// Breakdown
|
||||
SimilarityScore int32 `json:"similarityScore"`
|
||||
RankScore int32 `json:"rankScore"`
|
||||
}
|
||||
|
||||
func (s *SearchScoreDetails) String() string {
|
||||
return fmt.Sprintf("score:%d, similarity:%d, rank:%d", s.Score, s.SimilarityScore, s.RankScore)
|
||||
}
|
||||
|
||||
// DEPRECATED: to support decoding old indexes, we need a struct
|
||||
|
||||
@ -123,3 +123,15 @@
|
||||
path: github.com/sourcegraph/sourcegraph/enterprise/internal/github_apps/store
|
||||
interfaces:
|
||||
- GitHubAppsStore
|
||||
- filename: enterprise/internal/embeddings/mocks_temp.go
|
||||
path: github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings
|
||||
interfaces:
|
||||
- Client
|
||||
- filename: enterprise/internal/embeddings/background/repo/mocks_temp.go
|
||||
path: github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/repo
|
||||
interfaces:
|
||||
- RepoEmbeddingJobsStore
|
||||
- filename: enterprise/internal/embeddings/background/contextdetection/mocks_temp.go
|
||||
path: github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings/background/contextdetection
|
||||
interfaces:
|
||||
- ContextDetectionEmbeddingJobsStore
|
||||
|
||||
Loading…
Reference in New Issue
Block a user