Context: more precise chunk sizing (#62643)

Currently, when retrieving context chunks, we hardcode the number of lines to
20. Historically, we've limited chunks to 1024 characters, and we chose 20
lines to roughly mirror that.

In evals, I found that we're often returning fewer than 1024 characters. This
PR updates the context resolver to load an adaptive number of lines based on
the 1024 character limit. 

Addresses #61745
This commit is contained in:
Julie Tibshirani 2024-05-15 12:52:12 -07:00 committed by GitHub
parent 3b394e7954
commit 553121c5a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 277 additions and 48 deletions

View File

@ -44,9 +44,8 @@ func (f *FileChunkContextResolver) ToFileChunkContext() (*FileChunkContextResolv
}
func (f *FileChunkContextResolver) ChunkContent(ctx context.Context) (string, error) {
startLine, endLine := int32(f.startLine), int32(f.endLine)
return f.treeEntry.Content(ctx, &GitTreeContentPageArgs{
StartLine: &startLine,
EndLine: &endLine,
StartLine: &f.startLine,
EndLine: &f.endLine,
})
}

View File

@ -35,7 +35,6 @@ type FileChunkContext struct {
CommitID api.CommitID
Path string
StartLine int
EndLine int
}
func NewCodyContextClient(obsCtx *observation.Context, db database.DB, embeddingsClient embeddings.Client, searchClient client.SearchClient, gitserverClient gitserver.Client) *CodyContextClient {
@ -233,7 +232,6 @@ func (c *CodyContextClient) getEmbeddingsContext(ctx context.Context, args GetCo
CommitID: result.Revision,
Path: result.FileName,
StartLine: result.StartLine,
EndLine: result.EndLine,
})
}
@ -356,7 +354,6 @@ func fileMatchToContextMatches(fm *result.FileMatch) []FileChunkContext {
CommitID: fm.CommitID,
Path: fm.Path,
StartLine: 0,
EndLine: 20,
}}
}
@ -365,8 +362,6 @@ func fileMatchToContextMatches(fm *result.FileMatch) []FileChunkContext {
// 5 lines of leading context, clamped to zero
startLine := max(0, fm.ChunkMatches[0].ContentStart.Line-5)
// depend on content fetching to trim to the end of the file
endLine := startLine + 20
return []FileChunkContext{{
RepoName: fm.Repo.Name,
@ -374,6 +369,5 @@ func fileMatchToContextMatches(fm *result.FileMatch) []FileChunkContext {
CommitID: fm.CommitID,
Path: fm.Path,
StartLine: startLine,
EndLine: endLine,
}}
}

View File

@ -33,7 +33,6 @@ func TestFileMatchToContextMatches(t *testing.T) {
CommitID: "abc123",
Path: "main.go",
StartLine: 0,
EndLine: 20,
}},
},
{
@ -61,7 +60,6 @@ func TestFileMatchToContextMatches(t *testing.T) {
CommitID: "abc123",
Path: "main.go",
StartLine: 85,
EndLine: 105,
}},
},
}

View File

@ -14,6 +14,7 @@ go_library(
"//internal/trace",
"//internal/types",
"//lib/errors",
"//lib/pointers",
"@com_github_sourcegraph_conc//iter",
],
)
@ -40,9 +41,12 @@ go_test(
"//internal/licensing",
"//internal/observation",
"//internal/rbac/types",
"//internal/search",
"//internal/search/backend",
"//internal/search/client",
"//internal/search/job",
"//internal/search/result",
"//internal/search/streaming",
"//internal/types",
"//lib/errors",
"//lib/pointers",
@ -50,5 +54,6 @@ go_test(
"@com_github_sourcegraph_log//logtest",
"@com_github_sourcegraph_zoekt//:zoekt",
"@com_github_stretchr_testify//require",
"@io_k8s_utils//pointer",
],
)

View File

@ -1,6 +1,7 @@
package resolvers
import (
"bytes"
"context"
"github.com/sourcegraph/conc/iter"
@ -12,6 +13,7 @@ import (
"github.com/sourcegraph/sourcegraph/internal/trace"
"github.com/sourcegraph/sourcegraph/internal/types"
"github.com/sourcegraph/sourcegraph/lib/errors"
"github.com/sourcegraph/sourcegraph/lib/pointers"
)
func NewResolver(db database.DB, gitserverClient gitserver.Client, contextClient *codycontext.CodyContextClient) graphqlbackend.CodyContextResolver {
@ -68,6 +70,13 @@ func (r *Resolver) GetCodyContext(ctx context.Context, args graphqlbackend.GetCo
})
}
// The rough size of a file chunk in runes. The value 1024 is due to historical reasons -- Cody context was once based
// on embeddings, and we chunked files into ~1024 characters (aiming for 256 tokens, assuming each token takes 4
// characters on average).
//
// Ideally, the caller would pass a token 'budget' and we'd use a tokenizer and attempt to exactly match this budget.
const chunkSizeRunes = 1024
func (r *Resolver) fileChunkToResolver(ctx context.Context, chunk *codycontext.FileChunkContext) (graphqlbackend.ContextResultResolver, error) {
repoResolver := graphqlbackend.NewMinimalRepositoryResolver(r.db, r.gitserverClient, chunk.RepoID, chunk.RepoName)
@ -83,6 +92,31 @@ func (r *Resolver) fileChunkToResolver(ctx context.Context, chunk *codycontext.F
})
// Populate content ahead of time so we can do it concurrently
gitTreeEntryResolver.Content(ctx, &graphqlbackend.GitTreeContentPageArgs{})
return graphqlbackend.NewFileChunkContextResolver(gitTreeEntryResolver, chunk.StartLine, chunk.EndLine), nil
content, err := gitTreeEntryResolver.Content(ctx, &graphqlbackend.GitTreeContentPageArgs{
StartLine: pointers.Ptr(int32(chunk.StartLine)),
})
if err != nil {
return nil, err
}
numLines := countLines(content, chunkSizeRunes)
endLine := chunk.StartLine + numLines - 1 // subtract 1 because endLine is inclusive
return graphqlbackend.NewFileChunkContextResolver(gitTreeEntryResolver, chunk.StartLine, endLine), nil
}
// countLines finds the number of lines corresponding to the number of runes. We 'round down'
// to ensure that we don't return more characters than our budget.
func countLines(content string, numRunes int) int {
if len(content) == 0 {
return 0
}
if content[len(content)-1] != '\n' {
content += "\n"
}
runes := []rune(content)
truncated := runes[:min(len(runes), numRunes)]
in := []byte(string(truncated))
return bytes.Count(in, []byte("\n"))
}

View File

@ -8,10 +8,12 @@ import (
"os"
"sort"
"testing"
"unicode/utf8"
"github.com/sourcegraph/log/logtest"
"github.com/sourcegraph/zoekt"
"github.com/stretchr/testify/require"
"k8s.io/utils/pointer"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/codycontext"
@ -26,16 +28,19 @@ import (
"github.com/sourcegraph/sourcegraph/internal/licensing"
"github.com/sourcegraph/sourcegraph/internal/observation"
rtypes "github.com/sourcegraph/sourcegraph/internal/rbac/types"
"github.com/sourcegraph/sourcegraph/internal/search"
"github.com/sourcegraph/sourcegraph/internal/search/backend"
"github.com/sourcegraph/sourcegraph/internal/search/client"
"github.com/sourcegraph/sourcegraph/internal/search/job"
"github.com/sourcegraph/sourcegraph/internal/search/result"
"github.com/sourcegraph/sourcegraph/internal/search/streaming"
"github.com/sourcegraph/sourcegraph/internal/types"
"github.com/sourcegraph/sourcegraph/lib/errors"
"github.com/sourcegraph/sourcegraph/lib/pointers"
"github.com/sourcegraph/sourcegraph/schema"
)
func TestContextResolver(t *testing.T) {
func TestCodyIgnore(t *testing.T) {
logger := logtest.Scoped(t)
ctx := context.Background()
db := database.NewDB(logger, dbtest.NewDB(t))
@ -70,41 +75,10 @@ func TestContextResolver(t *testing.T) {
return errors.New("error")
}
// Create a normal user role with Cody access permission
normalUserRole, err := db.Roles().Create(ctx, "normal user role", false)
require.NoError(t, err)
codyAccessPermission, err := db.Permissions().Create(ctx, database.CreatePermissionOpts{
Namespace: rtypes.CodyNamespace,
Action: rtypes.CodyAccessAction,
})
require.NoError(t, err)
err = db.RolePermissions().Assign(ctx, database.AssignRolePermissionOpts{
PermissionID: codyAccessPermission.ID,
RoleID: normalUserRole.ID,
})
require.NoError(t, err)
// Create an admin user, give them the normal user role, and authenticate our actor.
newAdminUser, err := db.Users().Create(ctx, database.NewUser{
Email: "test@example.com",
Username: "test",
DisplayName: "Test User",
Password: "hunter123",
EmailIsVerified: true,
FailIfNotInitialUser: true, // initial site admin account
EnforcePasswordLength: false,
TosAccepted: true,
})
require.NoError(t, err)
db.UserRoles().SetRolesForUser(ctx, database.SetRolesForUserOpts{
UserID: newAdminUser.ID,
Roles: []int32{normalUserRole.ID},
})
require.NoError(t, err)
ctx = actor.WithActor(ctx, actor.FromMockUser(newAdminUser.ID))
ctx = authenticateUser(t, db, ctx)
// Create populates the IDs in the passed in types.Repo
err = db.Repos().Create(ctx, &repo1, &repo2)
err := db.Repos().Create(ctx, &repo1, &repo2)
require.NoError(t, err)
files := map[api.RepoName]map[string][]byte{
@ -121,7 +95,7 @@ func TestContextResolver(t *testing.T) {
}
mockGitserver := gitserver.NewMockClient()
mockGitserver.GetDefaultBranchFunc.SetDefaultReturn("main", api.CommitID("abc123"), nil)
mockGitserver.GetDefaultBranchFunc.SetDefaultReturn("main", "abc123", nil)
mockGitserver.StatFunc.SetDefaultHook(func(_ context.Context, repo api.RepoName, _ api.CommitID, fileName string) (fs.FileInfo, error) {
return fakeFileInfo{path: fileName}, nil
})
@ -252,6 +226,42 @@ func TestContextResolver(t *testing.T) {
}
}
func authenticateUser(t *testing.T, db database.DB, ctx context.Context) context.Context {
// Create a normal user role with Cody access permission
normalUserRole, err := db.Roles().Create(ctx, "normal user role", false)
require.NoError(t, err)
codyAccessPermission, err := db.Permissions().Create(ctx, database.CreatePermissionOpts{
Namespace: rtypes.CodyNamespace,
Action: rtypes.CodyAccessAction,
})
require.NoError(t, err)
err = db.RolePermissions().Assign(ctx, database.AssignRolePermissionOpts{
PermissionID: codyAccessPermission.ID,
RoleID: normalUserRole.ID,
})
require.NoError(t, err)
// Create an admin user, give them the normal user role, and authenticate our actor.
newAdminUser, err := db.Users().Create(ctx, database.NewUser{
Email: "test@example.com",
Username: "test",
DisplayName: "Test User",
Password: "hunter123",
EmailIsVerified: true,
FailIfNotInitialUser: true, // initial site admin account
EnforcePasswordLength: false,
TosAccepted: true,
})
require.NoError(t, err)
db.UserRoles().SetRolesForUser(ctx, database.SetRolesForUserOpts{
UserID: newAdminUser.ID,
Roles: []int32{normalUserRole.ID},
})
require.NoError(t, err)
ctx = actor.WithActor(ctx, actor.FromMockUser(newAdminUser.ID))
return ctx
}
type fakeFileInfo struct {
path string
fs.FileInfo
@ -260,3 +270,192 @@ type fakeFileInfo struct {
func (f fakeFileInfo) Name() string {
return f.path
}
func TestChunkSize(t *testing.T) {
logger := logtest.Scoped(t)
ctx := context.Background()
db := database.NewDB(logger, dbtest.NewDB(t))
repo := types.Repo{Name: "repo"}
conf.Mock(&conf.Unified{
SiteConfiguration: schema.SiteConfiguration{
CodyEnabled: pointer.Bool(true),
LicenseKey: "asdf",
},
})
t.Cleanup(func() { conf.Mock(nil) })
oldMock := licensing.MockCheckFeature
defer func() {
licensing.MockCheckFeature = oldMock
}()
licensing.MockCheckFeature = func(feature licensing.Feature) error {
if feature == licensing.FeatureCody {
return nil
}
return errors.New("error")
}
ctx = authenticateUser(t, db, ctx)
// Create populates the IDs in the passed in types.Repo
err := db.Repos().Create(ctx, &repo)
require.NoError(t, err)
var content string
for i := 0; i < chunkSizeRunes; i++ {
if i != 0 && i%10 == 0 {
content += "\n"
} else {
content += "a"
}
}
wantLines := bytes.Count([]byte(content), []byte("\n"))
files := map[api.RepoName]map[string][]byte{
"repo": {
"testcode1.go": []byte(content),
"testcode2.go": []byte(content + "\n"),
"testcode3.go": []byte(content + "extra info"),
},
}
mockGitserver := gitserver.NewMockClient()
mockGitserver.GetDefaultBranchFunc.SetDefaultReturn("main", "abc123", nil)
mockGitserver.StatFunc.SetDefaultHook(func(_ context.Context, repo api.RepoName, _ api.CommitID, fileName string) (fs.FileInfo, error) {
return fakeFileInfo{path: fileName}, nil
})
mockGitserver.NewFileReaderFunc.SetDefaultHook(func(ctx context.Context, repo api.RepoName, ci api.CommitID, fileName string) (io.ReadCloser, error) {
if content, ok := files[repo][fileName]; ok {
return io.NopCloser(bytes.NewReader(content)), nil
}
return nil, os.ErrNotExist
})
mockEmbeddingsClient := embeddings.NewMockClient()
mockEmbeddingsClient.SearchFunc.SetDefaultReturn(nil, errors.New("embeddings should be disabled"))
lineRange := result.ChunkMatches{{
Ranges: result.Ranges{{
Start: result.Location{Line: 2},
End: result.Location{Line: 6},
}},
}}
mockSearchClient := client.NewMockSearchClient()
mockSearchClient.PlanFunc.SetDefaultHook(func(_ context.Context, _ string, _ *string, query string, _ search.Mode, _ search.Protocol, _ *int32) (*search.Inputs, error) {
return &search.Inputs{OriginalQuery: query}, nil
})
mockSearchClient.ExecuteFunc.SetDefaultHook(func(_ context.Context, stream streaming.Sender, inputs *search.Inputs) (*search.Alert, error) {
stream.Send(streaming.SearchEvent{
Results: result.Matches{&result.FileMatch{
File: result.File{
Path: "testcode1.go",
Repo: types.MinimalRepo{ID: repo.ID, Name: repo.Name},
},
ChunkMatches: lineRange,
}, &result.FileMatch{
File: result.File{
Path: "testcode2.go",
Repo: types.MinimalRepo{ID: repo.ID, Name: repo.Name},
},
ChunkMatches: lineRange,
}, &result.FileMatch{
File: result.File{
Path: "testcode3.go",
Repo: types.MinimalRepo{ID: repo.ID, Name: repo.Name},
},
ChunkMatches: lineRange,
}}})
return nil, nil
})
observationCtx := observation.TestContextTB(t)
contextClient := codycontext.NewCodyContextClient(
observationCtx,
db,
mockEmbeddingsClient,
mockSearchClient,
mockGitserver,
)
resolver := NewResolver(
db,
mockGitserver,
contextClient,
)
results, err := resolver.GetCodyContext(ctx, graphqlbackend.GetContextArgs{
Repos: graphqlbackend.MarshalRepositoryIDs([]api.RepoID{1}),
Query: "my test query",
TextResultsCount: 0,
CodeResultsCount: 5,
})
require.NoError(t, err)
require.Equal(t, 3, len(results), "expected 3 results but got %d", len(results))
for _, r := range results {
f := r.(*graphqlbackend.FileChunkContextResolver)
gotLines := int(f.EndLine() - f.StartLine() + 1)
require.Equal(t, wantLines, gotLines, "expected %d lines but got %d", wantLines, gotLines)
c, err := f.ChunkContent(ctx)
require.NoError(t, err)
require.LessOrEqual(t, utf8.RuneCountInString(c), chunkSizeRunes)
}
}
func TestCountLines(t *testing.T) {
tests := []struct {
name string
content string
numRunes int
want int
}{
{
name: "empty string",
content: "",
numRunes: 1024,
want: 0,
}, {
name: "single line",
content: "Hello world\n",
numRunes: 1024,
want: 1,
}, {}, {
name: "single line, no trailing newline",
content: "Hello world",
numRunes: 1024,
want: 1,
}, {
name: "should truncate",
content: "Hello\nWorld\nGo is awesome\n",
numRunes: 24,
want: 2,
}, {
name: "should truncate, no trailing newline",
content: "Hello\nWorld\nGo is awesome",
numRunes: 24,
want: 2,
}, {
name: "all lines",
content: "Hello\nWorld\nGo is awesome\n",
numRunes: 1024,
want: 3,
}, {
name: "all lines, no trailing newline",
content: "Hello\nWorld\nGo is awesome",
numRunes: 1024,
want: 3,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := countLines(tt.content, tt.numRunes); got != tt.want {
t.Errorf("countLines() = %v, want %v", got, tt.want)
}
})
}
}