diff --git a/cmd/frontend/graphqlbackend/cody_context.go b/cmd/frontend/graphqlbackend/cody_context.go index 29d7e143225..a3feaf30556 100644 --- a/cmd/frontend/graphqlbackend/cody_context.go +++ b/cmd/frontend/graphqlbackend/cody_context.go @@ -44,9 +44,8 @@ func (f *FileChunkContextResolver) ToFileChunkContext() (*FileChunkContextResolv } func (f *FileChunkContextResolver) ChunkContent(ctx context.Context) (string, error) { - startLine, endLine := int32(f.startLine), int32(f.endLine) return f.treeEntry.Content(ctx, &GitTreeContentPageArgs{ - StartLine: &startLine, - EndLine: &endLine, + StartLine: &f.startLine, + EndLine: &f.endLine, }) } diff --git a/cmd/frontend/internal/codycontext/context.go b/cmd/frontend/internal/codycontext/context.go index 1423895a604..0bbb593374e 100644 --- a/cmd/frontend/internal/codycontext/context.go +++ b/cmd/frontend/internal/codycontext/context.go @@ -35,7 +35,6 @@ type FileChunkContext struct { CommitID api.CommitID Path string StartLine int - EndLine int } func NewCodyContextClient(obsCtx *observation.Context, db database.DB, embeddingsClient embeddings.Client, searchClient client.SearchClient, gitserverClient gitserver.Client) *CodyContextClient { @@ -233,7 +232,6 @@ func (c *CodyContextClient) getEmbeddingsContext(ctx context.Context, args GetCo CommitID: result.Revision, Path: result.FileName, StartLine: result.StartLine, - EndLine: result.EndLine, }) } @@ -356,7 +354,6 @@ func fileMatchToContextMatches(fm *result.FileMatch) []FileChunkContext { CommitID: fm.CommitID, Path: fm.Path, StartLine: 0, - EndLine: 20, }} } @@ -365,8 +362,6 @@ func fileMatchToContextMatches(fm *result.FileMatch) []FileChunkContext { // 5 lines of leading context, clamped to zero startLine := max(0, fm.ChunkMatches[0].ContentStart.Line-5) - // depend on content fetching to trim to the end of the file - endLine := startLine + 20 return []FileChunkContext{{ RepoName: fm.Repo.Name, @@ -374,6 +369,5 @@ func fileMatchToContextMatches(fm *result.FileMatch) 
[]FileChunkContext { CommitID: fm.CommitID, Path: fm.Path, StartLine: startLine, - EndLine: endLine, }} } diff --git a/cmd/frontend/internal/codycontext/context_test.go b/cmd/frontend/internal/codycontext/context_test.go index a4842dc77a1..0301b7dba51 100644 --- a/cmd/frontend/internal/codycontext/context_test.go +++ b/cmd/frontend/internal/codycontext/context_test.go @@ -33,7 +33,6 @@ func TestFileMatchToContextMatches(t *testing.T) { CommitID: "abc123", Path: "main.go", StartLine: 0, - EndLine: 20, }}, }, { @@ -61,7 +60,6 @@ func TestFileMatchToContextMatches(t *testing.T) { CommitID: "abc123", Path: "main.go", StartLine: 85, - EndLine: 105, }}, }, } diff --git a/cmd/frontend/internal/context/resolvers/BUILD.bazel b/cmd/frontend/internal/context/resolvers/BUILD.bazel index 149e20ed0f8..ba20af0d288 100644 --- a/cmd/frontend/internal/context/resolvers/BUILD.bazel +++ b/cmd/frontend/internal/context/resolvers/BUILD.bazel @@ -14,6 +14,7 @@ go_library( "//internal/trace", "//internal/types", "//lib/errors", + "//lib/pointers", "@com_github_sourcegraph_conc//iter", ], ) @@ -40,9 +41,12 @@ go_test( "//internal/licensing", "//internal/observation", "//internal/rbac/types", + "//internal/search", "//internal/search/backend", "//internal/search/client", "//internal/search/job", + "//internal/search/result", + "//internal/search/streaming", "//internal/types", "//lib/errors", "//lib/pointers", @@ -50,5 +54,6 @@ go_test( "@com_github_sourcegraph_log//logtest", "@com_github_sourcegraph_zoekt//:zoekt", "@com_github_stretchr_testify//require", + "@io_k8s_utils//pointer", ], ) diff --git a/cmd/frontend/internal/context/resolvers/context.go b/cmd/frontend/internal/context/resolvers/context.go index a7a9cd08584..1cdd3c41c79 100644 --- a/cmd/frontend/internal/context/resolvers/context.go +++ b/cmd/frontend/internal/context/resolvers/context.go @@ -1,6 +1,7 @@ package resolvers import ( + "bytes" "context" "github.com/sourcegraph/conc/iter" @@ -12,6 +13,7 @@ import ( 
"github.com/sourcegraph/sourcegraph/internal/trace" "github.com/sourcegraph/sourcegraph/internal/types" "github.com/sourcegraph/sourcegraph/lib/errors" + "github.com/sourcegraph/sourcegraph/lib/pointers" ) func NewResolver(db database.DB, gitserverClient gitserver.Client, contextClient *codycontext.CodyContextClient) graphqlbackend.CodyContextResolver { @@ -68,6 +70,13 @@ func (r *Resolver) GetCodyContext(ctx context.Context, args graphqlbackend.GetCo }) } +// The rough size of a file chunk in runes. The value 1024 is due to historical reasons -- Cody context was once based +// on embeddings, and we chunked files into ~1024 characters (aiming for 256 tokens, assuming each token takes 4 +// characters on average). +// +// Ideally, the caller would pass a token 'budget' and we'd use a tokenizer and attempt to exactly match this budget. +const chunkSizeRunes = 1024 + func (r *Resolver) fileChunkToResolver(ctx context.Context, chunk *codycontext.FileChunkContext) (graphqlbackend.ContextResultResolver, error) { repoResolver := graphqlbackend.NewMinimalRepositoryResolver(r.db, r.gitserverClient, chunk.RepoID, chunk.RepoName) @@ -83,6 +92,31 @@ func (r *Resolver) fileChunkToResolver(ctx context.Context, chunk *codycontext.F }) // Populate content ahead of time so we can do it concurrently - gitTreeEntryResolver.Content(ctx, &graphqlbackend.GitTreeContentPageArgs{}) - return graphqlbackend.NewFileChunkContextResolver(gitTreeEntryResolver, chunk.StartLine, chunk.EndLine), nil + content, err := gitTreeEntryResolver.Content(ctx, &graphqlbackend.GitTreeContentPageArgs{ + StartLine: pointers.Ptr(int32(chunk.StartLine)), + }) + if err != nil { + return nil, err + } + + numLines := countLines(content, chunkSizeRunes) + endLine := chunk.StartLine + numLines - 1 // subtract 1 because endLine is inclusive + return graphqlbackend.NewFileChunkContextResolver(gitTreeEntryResolver, chunk.StartLine, endLine), nil +} + +// countLines finds the number of lines corresponding to the number 
// countLines returns how many complete lines of content fit within a budget of
// numRunes runes. We 'round down': a line whose terminating newline falls
// outside the budget is not counted, so the caller never returns more
// characters than the budget allows.
//
// A final line without a trailing newline is treated as if it were terminated
// by one, and that implicit newline also counts against the budget.
//
// Empty content and non-positive budgets yield 0.
func countLines(content string, numRunes int) int {
	if len(content) == 0 || numRunes <= 0 {
		return 0
	}

	// Find the byte offset just past the last rune inside the budget. Ranging
	// over a string decodes one rune per iteration (an invalid byte counts as a
	// single rune), so no []rune copy of the content is needed.
	end := len(content)
	seen := 0
	for i := range content {
		if seen == numRunes {
			end = i
			break
		}
		seen++
	}

	lines := bytes.Count([]byte(content[:end]), []byte("\n"))

	// The whole content fit with budget to spare: the implicit newline that
	// terminates an unterminated last line is inside the budget too.
	if seen < numRunes && content[len(content)-1] != '\n' {
		lines++
	}
	return lines
}
TestContextResolver(t *testing.T) { return errors.New("error") } - // Create a normal user role with Cody access permission - normalUserRole, err := db.Roles().Create(ctx, "normal user role", false) - require.NoError(t, err) - codyAccessPermission, err := db.Permissions().Create(ctx, database.CreatePermissionOpts{ - Namespace: rtypes.CodyNamespace, - Action: rtypes.CodyAccessAction, - }) - require.NoError(t, err) - err = db.RolePermissions().Assign(ctx, database.AssignRolePermissionOpts{ - PermissionID: codyAccessPermission.ID, - RoleID: normalUserRole.ID, - }) - require.NoError(t, err) - - // Create an admin user, give them the normal user role, and authenticate our actor. - newAdminUser, err := db.Users().Create(ctx, database.NewUser{ - Email: "test@example.com", - Username: "test", - DisplayName: "Test User", - Password: "hunter123", - EmailIsVerified: true, - FailIfNotInitialUser: true, // initial site admin account - EnforcePasswordLength: false, - TosAccepted: true, - }) - require.NoError(t, err) - db.UserRoles().SetRolesForUser(ctx, database.SetRolesForUserOpts{ - UserID: newAdminUser.ID, - Roles: []int32{normalUserRole.ID}, - }) - require.NoError(t, err) - ctx = actor.WithActor(ctx, actor.FromMockUser(newAdminUser.ID)) + ctx = authenticateUser(t, db, ctx) // Create populates the IDs in the passed in types.Repo - err = db.Repos().Create(ctx, &repo1, &repo2) + err := db.Repos().Create(ctx, &repo1, &repo2) require.NoError(t, err) files := map[api.RepoName]map[string][]byte{ @@ -121,7 +95,7 @@ func TestContextResolver(t *testing.T) { } mockGitserver := gitserver.NewMockClient() - mockGitserver.GetDefaultBranchFunc.SetDefaultReturn("main", api.CommitID("abc123"), nil) + mockGitserver.GetDefaultBranchFunc.SetDefaultReturn("main", "abc123", nil) mockGitserver.StatFunc.SetDefaultHook(func(_ context.Context, repo api.RepoName, _ api.CommitID, fileName string) (fs.FileInfo, error) { return fakeFileInfo{path: fileName}, nil }) @@ -252,6 +226,42 @@ func 
TestContextResolver(t *testing.T) { } } +func authenticateUser(t *testing.T, db database.DB, ctx context.Context) context.Context { + // Create a normal user role with Cody access permission + normalUserRole, err := db.Roles().Create(ctx, "normal user role", false) + require.NoError(t, err) + codyAccessPermission, err := db.Permissions().Create(ctx, database.CreatePermissionOpts{ + Namespace: rtypes.CodyNamespace, + Action: rtypes.CodyAccessAction, + }) + require.NoError(t, err) + err = db.RolePermissions().Assign(ctx, database.AssignRolePermissionOpts{ + PermissionID: codyAccessPermission.ID, + RoleID: normalUserRole.ID, + }) + require.NoError(t, err) + + // Create an admin user, give them the normal user role, and authenticate our actor. + newAdminUser, err := db.Users().Create(ctx, database.NewUser{ + Email: "test@example.com", + Username: "test", + DisplayName: "Test User", + Password: "hunter123", + EmailIsVerified: true, + FailIfNotInitialUser: true, // initial site admin account + EnforcePasswordLength: false, + TosAccepted: true, + }) + require.NoError(t, err) + db.UserRoles().SetRolesForUser(ctx, database.SetRolesForUserOpts{ + UserID: newAdminUser.ID, + Roles: []int32{normalUserRole.ID}, + }) + require.NoError(t, err) + ctx = actor.WithActor(ctx, actor.FromMockUser(newAdminUser.ID)) + return ctx +} + type fakeFileInfo struct { path string fs.FileInfo @@ -260,3 +270,192 @@ type fakeFileInfo struct { func (f fakeFileInfo) Name() string { return f.path } + +func TestChunkSize(t *testing.T) { + logger := logtest.Scoped(t) + ctx := context.Background() + db := database.NewDB(logger, dbtest.NewDB(t)) + repo := types.Repo{Name: "repo"} + + conf.Mock(&conf.Unified{ + SiteConfiguration: schema.SiteConfiguration{ + CodyEnabled: pointer.Bool(true), + LicenseKey: "asdf", + }, + }) + t.Cleanup(func() { conf.Mock(nil) }) + + oldMock := licensing.MockCheckFeature + defer func() { + licensing.MockCheckFeature = oldMock + }() + + licensing.MockCheckFeature = func(feature 
licensing.Feature) error { + if feature == licensing.FeatureCody { + return nil + } + return errors.New("error") + } + + ctx = authenticateUser(t, db, ctx) + + // Create populates the IDs in the passed in types.Repo + err := db.Repos().Create(ctx, &repo) + require.NoError(t, err) + + var content string + for i := 0; i < chunkSizeRunes; i++ { + if i != 0 && i%10 == 0 { + content += "\n" + } else { + content += "a" + } + } + + wantLines := bytes.Count([]byte(content), []byte("\n")) + + files := map[api.RepoName]map[string][]byte{ + "repo": { + "testcode1.go": []byte(content), + "testcode2.go": []byte(content + "\n"), + "testcode3.go": []byte(content + "extra info"), + }, + } + + mockGitserver := gitserver.NewMockClient() + mockGitserver.GetDefaultBranchFunc.SetDefaultReturn("main", "abc123", nil) + mockGitserver.StatFunc.SetDefaultHook(func(_ context.Context, repo api.RepoName, _ api.CommitID, fileName string) (fs.FileInfo, error) { + return fakeFileInfo{path: fileName}, nil + }) + mockGitserver.NewFileReaderFunc.SetDefaultHook(func(ctx context.Context, repo api.RepoName, ci api.CommitID, fileName string) (io.ReadCloser, error) { + if content, ok := files[repo][fileName]; ok { + return io.NopCloser(bytes.NewReader(content)), nil + } + return nil, os.ErrNotExist + }) + + mockEmbeddingsClient := embeddings.NewMockClient() + mockEmbeddingsClient.SearchFunc.SetDefaultReturn(nil, errors.New("embeddings should be disabled")) + + lineRange := result.ChunkMatches{{ + Ranges: result.Ranges{{ + Start: result.Location{Line: 2}, + End: result.Location{Line: 6}, + }}, + }} + + mockSearchClient := client.NewMockSearchClient() + mockSearchClient.PlanFunc.SetDefaultHook(func(_ context.Context, _ string, _ *string, query string, _ search.Mode, _ search.Protocol, _ *int32) (*search.Inputs, error) { + return &search.Inputs{OriginalQuery: query}, nil + }) + mockSearchClient.ExecuteFunc.SetDefaultHook(func(_ context.Context, stream streaming.Sender, inputs *search.Inputs) (*search.Alert, 
error) { + stream.Send(streaming.SearchEvent{ + Results: result.Matches{&result.FileMatch{ + File: result.File{ + Path: "testcode1.go", + Repo: types.MinimalRepo{ID: repo.ID, Name: repo.Name}, + }, + ChunkMatches: lineRange, + }, &result.FileMatch{ + File: result.File{ + Path: "testcode2.go", + Repo: types.MinimalRepo{ID: repo.ID, Name: repo.Name}, + }, + ChunkMatches: lineRange, + }, &result.FileMatch{ + File: result.File{ + Path: "testcode3.go", + Repo: types.MinimalRepo{ID: repo.ID, Name: repo.Name}, + }, + ChunkMatches: lineRange, + }}}) + return nil, nil + }) + + observationCtx := observation.TestContextTB(t) + contextClient := codycontext.NewCodyContextClient( + observationCtx, + db, + mockEmbeddingsClient, + mockSearchClient, + mockGitserver, + ) + + resolver := NewResolver( + db, + mockGitserver, + contextClient, + ) + + results, err := resolver.GetCodyContext(ctx, graphqlbackend.GetContextArgs{ + Repos: graphqlbackend.MarshalRepositoryIDs([]api.RepoID{1}), + Query: "my test query", + TextResultsCount: 0, + CodeResultsCount: 5, + }) + require.NoError(t, err) + require.Equal(t, 3, len(results), "expected 3 results but got %d", len(results)) + + for _, r := range results { + f := r.(*graphqlbackend.FileChunkContextResolver) + gotLines := int(f.EndLine() - f.StartLine() + 1) + require.Equal(t, wantLines, gotLines, "expected %d lines but got %d", wantLines, gotLines) + + c, err := f.ChunkContent(ctx) + require.NoError(t, err) + require.LessOrEqual(t, utf8.RuneCountInString(c), chunkSizeRunes) + } +} + +func TestCountLines(t *testing.T) { + tests := []struct { + name string + content string + numRunes int + want int + }{ + { + name: "empty string", + content: "", + numRunes: 1024, + want: 0, + }, { + name: "single line", + content: "Hello world\n", + numRunes: 1024, + want: 1, + }, {}, { + name: "single line, no trailing newline", + content: "Hello world", + numRunes: 1024, + want: 1, + }, { + name: "should truncate", + content: "Hello\nWorld\nGo is awesome\n", 
+ numRunes: 24, + want: 2, + }, { + name: "should truncate, no trailing newline", + content: "Hello\nWorld\nGo is awesome", + numRunes: 24, + want: 2, + }, { + name: "all lines", + content: "Hello\nWorld\nGo is awesome\n", + numRunes: 1024, + want: 3, + }, { + name: "all lines, no trailing newline", + content: "Hello\nWorld\nGo is awesome", + numRunes: 1024, + want: 3, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := countLines(tt.content, tt.numRunes); got != tt.want { + t.Errorf("countLines() = %v, want %v", got, tt.want) + } + }) + } +}