diff --git a/cmd/frontend/internal/context/resolvers/BUILD.bazel b/cmd/frontend/internal/context/resolvers/BUILD.bazel index 6585a2ef55c..149e20ed0f8 100644 --- a/cmd/frontend/internal/context/resolvers/BUILD.bazel +++ b/cmd/frontend/internal/context/resolvers/BUILD.bazel @@ -1,3 +1,4 @@ +load("//dev:go_defs.bzl", "go_test") load("@io_bazel_rules_go//go:def.bzl", "go_library") go_library( @@ -16,3 +17,38 @@ go_library( "@com_github_sourcegraph_conc//iter", ], ) + +go_test( + name = "resolvers_test", + srcs = ["context_test.go"], + embed = [":resolvers"], + tags = [ + # Test requires localhost database + "requires-network", + ], + deps = [ + "//cmd/frontend/graphqlbackend", + "//cmd/frontend/internal/codycontext", + "//internal/actor", + "//internal/api", + "//internal/conf", + "//internal/database", + "//internal/database/dbtest", + "//internal/dotcom", + "//internal/embeddings", + "//internal/gitserver", + "//internal/licensing", + "//internal/observation", + "//internal/rbac/types", + "//internal/search/backend", + "//internal/search/client", + "//internal/search/job", + "//internal/types", + "//lib/errors", + "//lib/pointers", + "//schema", + "@com_github_sourcegraph_log//logtest", + "@com_github_sourcegraph_zoekt//:zoekt", + "@com_github_stretchr_testify//require", + ], +) diff --git a/cmd/frontend/internal/context/resolvers/context_test.go b/cmd/frontend/internal/context/resolvers/context_test.go new file mode 100644 index 00000000000..e3eddd33c29 --- /dev/null +++ b/cmd/frontend/internal/context/resolvers/context_test.go @@ -0,0 +1,254 @@ +package resolvers + +import ( + "bytes" + "context" + "io" + "io/fs" + "os" + "sort" + "testing" + + "github.com/sourcegraph/log/logtest" + "github.com/sourcegraph/zoekt" + "github.com/stretchr/testify/require" + + "github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend" + "github.com/sourcegraph/sourcegraph/cmd/frontend/internal/codycontext" + "github.com/sourcegraph/sourcegraph/internal/actor" + "github.com/sourcegraph/sourcegraph/internal/api" + "github.com/sourcegraph/sourcegraph/internal/conf" + "github.com/sourcegraph/sourcegraph/internal/database" + "github.com/sourcegraph/sourcegraph/internal/database/dbtest" + "github.com/sourcegraph/sourcegraph/internal/dotcom" + "github.com/sourcegraph/sourcegraph/internal/embeddings" + "github.com/sourcegraph/sourcegraph/internal/gitserver" + "github.com/sourcegraph/sourcegraph/internal/licensing" + "github.com/sourcegraph/sourcegraph/internal/observation" + rtypes "github.com/sourcegraph/sourcegraph/internal/rbac/types" + "github.com/sourcegraph/sourcegraph/internal/search/backend" + "github.com/sourcegraph/sourcegraph/internal/search/client" + "github.com/sourcegraph/sourcegraph/internal/search/job" + "github.com/sourcegraph/sourcegraph/internal/types" + "github.com/sourcegraph/sourcegraph/lib/errors" + "github.com/sourcegraph/sourcegraph/lib/pointers" + "github.com/sourcegraph/sourcegraph/schema" +) + +func TestContextResolver(t *testing.T) { + logger := logtest.Scoped(t) + ctx := context.Background() + db := database.NewDB(logger, dbtest.NewDB(t)) + repo1 := types.Repo{Name: "repo1"} + repo2 := types.Repo{Name: "repo2"} + truePtr := true + conf.Mock(&conf.Unified{ + SiteConfiguration: schema.SiteConfiguration{ + CodyEnabled: &truePtr, + LicenseKey: "asdf", + ExperimentalFeatures: &schema.ExperimentalFeatures{ + CodyContextIgnore: pointers.Ptr(true), + }, + CodyContextFilters: &schema.CodyContextFilters{ + Exclude: []*schema.CodyContextFilterItem{ + {RepoNamePattern: "^repo2$"}, + }, + }, + }, + }) + t.Cleanup(func() { conf.Mock(nil) }) + + oldMock := licensing.MockCheckFeature + defer func() { + licensing.MockCheckFeature = oldMock + }() + + licensing.MockCheckFeature = func(feature licensing.Feature) error { + if feature == licensing.FeatureCody { + return nil + } + return errors.New("error") + } + + // Create a normal user role with Cody access permission + normalUserRole, err := db.Roles().Create(ctx, "normal user role", false) + require.NoError(t, err) + codyAccessPermission, err := db.Permissions().Create(ctx, database.CreatePermissionOpts{ + Namespace: rtypes.CodyNamespace, + Action: rtypes.CodyAccessAction, + }) + require.NoError(t, err) + err = db.RolePermissions().Assign(ctx, database.AssignRolePermissionOpts{ + PermissionID: codyAccessPermission.ID, + RoleID: normalUserRole.ID, + }) + require.NoError(t, err) + + // Create an admin user, give them the normal user role, and authenticate our actor. + newAdminUser, err := db.Users().Create(ctx, database.NewUser{ + Email: "test@example.com", + Username: "test", + DisplayName: "Test User", + Password: "hunter123", + EmailIsVerified: true, + FailIfNotInitialUser: true, // initial site admin account + EnforcePasswordLength: false, + TosAccepted: true, + }) + require.NoError(t, err) + db.UserRoles().SetRolesForUser(ctx, database.SetRolesForUserOpts{ + UserID: newAdminUser.ID, + Roles: []int32{normalUserRole.ID}, + }) + require.NoError(t, err) + ctx = actor.WithActor(ctx, actor.FromMockUser(newAdminUser.ID)) + + // Create populates the IDs in the passed in types.Repo + err = db.Repos().Create(ctx, &repo1, &repo2) + require.NoError(t, err) + + files := map[api.RepoName]map[string][]byte{ + "repo1": { + "testcode1.go": []byte("testcode1"), + "ignore_me1.go": []byte("secret"), + ".cody/ignore": []byte("ignore_me1.go"), + }, + "repo2": { + "testcode2.go": []byte("testcode2"), + "ignore_me2.go": []byte("secret"), + ".cody/ignore": []byte("ignore_me2.go"), + }, + } + + mockGitserver := gitserver.NewMockClient() + mockGitserver.GetDefaultBranchFunc.SetDefaultReturn("main", api.CommitID("abc123"), nil) + mockGitserver.StatFunc.SetDefaultHook(func(_ context.Context, repo api.RepoName, _ api.CommitID, fileName string) (fs.FileInfo, error) { + return fakeFileInfo{path: fileName}, nil + }) + mockGitserver.NewFileReaderFunc.SetDefaultHook(func(ctx context.Context, repo api.RepoName, ci api.CommitID, fileName string) (io.ReadCloser, error) { + if content, ok := files[repo][fileName]; ok { + return io.NopCloser(bytes.NewReader(content)), nil + } + return nil, os.ErrNotExist + }) + + mockEmbeddingsClient := embeddings.NewMockClient() + mockEmbeddingsClient.SearchFunc.SetDefaultReturn(nil, errors.New("embeddings should be disabled")) + + lineRange := func(start, end uint32) []zoekt.ChunkMatch { + return []zoekt.ChunkMatch{{ + Ranges: []zoekt.Range{{ + Start: zoekt.Location{LineNumber: start}, + End: zoekt.Location{LineNumber: end}, + }}, + }} + } + mockZoektStreamer := &backend.FakeStreamer{ + Results: []*zoekt.SearchResult{{ + Files: []zoekt.FileMatch{{ + FileName: "ignore_me1.go", + RepositoryID: uint32(repo1.ID), + ChunkMatches: lineRange(0, 4), + }, { + FileName: "ignore_me2.go", + RepositoryID: uint32(repo2.ID), + ChunkMatches: lineRange(0, 4), + }, { + FileName: "testcode1.go", + RepositoryID: uint32(repo1.ID), + ChunkMatches: lineRange(0, 4), + }, { + FileName: "testcode2.go", + RepositoryID: uint32(repo2.ID), + ChunkMatches: lineRange(0, 4), + }}, + }}, + Repos: []*zoekt.RepoListEntry{{ + Repository: zoekt.Repository{ + ID: uint32(repo1.ID), + Name: string(repo1.Name), + Branches: []zoekt.RepositoryBranch{{Name: "HEAD"}}, + }}, { + Repository: zoekt.Repository{ + ID: uint32(repo2.ID), + Name: string(repo2.Name), + Branches: []zoekt.RepositoryBranch{{Name: "HEAD"}}, + }}, + }, + } + + searchClient := client.Mocked(job.RuntimeClients{ + Logger: logger, + DB: db, + Zoekt: mockZoektStreamer, + }) + + tests := []struct { + name string + dotComMode bool + want []string + }{ + { + name: "dotcom mode", + dotComMode: true, + // .cody/ignore files are respected in dotcom mode + // Cody context filters in site config are not applied + // Results are duplicated because the Cody search job calls the Zoekt streamer twice and appends the results + want: []string{"testcode1.go", "testcode1.go", "testcode2.go", "testcode2.go"}, + }, + { + name: "enterprise mode", + dotComMode: false, + // "repo2" results are excluded according to the site config + // .cody/ignore files don't have any effect for enterprise + // Each result is duplicated because the Cody search job calls the Zoekt streamer twice and appends the results + want: []string{"testcode1.go", "testcode1.go", "ignore_me1.go", "ignore_me1.go"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dotcom.MockSourcegraphDotComMode(t, tt.dotComMode) + observationCtx := observation.TestContextTB(t) + contextClient := codycontext.NewCodyContextClient( + observationCtx, + db, + mockEmbeddingsClient, + searchClient, + mockGitserver, + ) + + resolver := NewResolver( + db, + mockGitserver, + contextClient, + ) + + results, err := resolver.GetCodyContext(ctx, graphqlbackend.GetContextArgs{ + Repos: graphqlbackend.MarshalRepositoryIDs([]api.RepoID{1, 2}), + Query: "my test query", + TextResultsCount: 2, + CodeResultsCount: 2, + }) + require.NoError(t, err) + + paths := make([]string, len(results)) + for i, r := range results { + paths[i] = r.(*graphqlbackend.FileChunkContextResolver).Blob().Path() + } + expected := tt.want + sort.Strings(expected) + sort.Strings(paths) + require.Equal(t, expected, paths) + }) + } +} + +type fakeFileInfo struct { + path string + fs.FileInfo +} + +func (f fakeFileInfo) Name() string { + return f.path +}