mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 17:51:57 +00:00
Context: more precise chunk sizing (#62643)
Currently, when retrieving context chunks, we hardcode the number of lines to 20. Historically, we've limited chunks to 1024 characters, and we chose 20 lines to roughly mirror that. In evals, I found that we're often returning fewer than 1024 characters. This PR updates the context resolver to load an adaptive number of lines based on the 1024 character limit. Addresses #61745
This commit is contained in:
parent
3b394e7954
commit
553121c5a1
@ -44,9 +44,8 @@ func (f *FileChunkContextResolver) ToFileChunkContext() (*FileChunkContextResolv
|
||||
}
|
||||
|
||||
func (f *FileChunkContextResolver) ChunkContent(ctx context.Context) (string, error) {
|
||||
startLine, endLine := int32(f.startLine), int32(f.endLine)
|
||||
return f.treeEntry.Content(ctx, &GitTreeContentPageArgs{
|
||||
StartLine: &startLine,
|
||||
EndLine: &endLine,
|
||||
StartLine: &f.startLine,
|
||||
EndLine: &f.endLine,
|
||||
})
|
||||
}
|
||||
|
||||
@ -35,7 +35,6 @@ type FileChunkContext struct {
|
||||
CommitID api.CommitID
|
||||
Path string
|
||||
StartLine int
|
||||
EndLine int
|
||||
}
|
||||
|
||||
func NewCodyContextClient(obsCtx *observation.Context, db database.DB, embeddingsClient embeddings.Client, searchClient client.SearchClient, gitserverClient gitserver.Client) *CodyContextClient {
|
||||
@ -233,7 +232,6 @@ func (c *CodyContextClient) getEmbeddingsContext(ctx context.Context, args GetCo
|
||||
CommitID: result.Revision,
|
||||
Path: result.FileName,
|
||||
StartLine: result.StartLine,
|
||||
EndLine: result.EndLine,
|
||||
})
|
||||
}
|
||||
|
||||
@ -356,7 +354,6 @@ func fileMatchToContextMatches(fm *result.FileMatch) []FileChunkContext {
|
||||
CommitID: fm.CommitID,
|
||||
Path: fm.Path,
|
||||
StartLine: 0,
|
||||
EndLine: 20,
|
||||
}}
|
||||
}
|
||||
|
||||
@ -365,8 +362,6 @@ func fileMatchToContextMatches(fm *result.FileMatch) []FileChunkContext {
|
||||
|
||||
// 5 lines of leading context, clamped to zero
|
||||
startLine := max(0, fm.ChunkMatches[0].ContentStart.Line-5)
|
||||
// depend on content fetching to trim to the end of the file
|
||||
endLine := startLine + 20
|
||||
|
||||
return []FileChunkContext{{
|
||||
RepoName: fm.Repo.Name,
|
||||
@ -374,6 +369,5 @@ func fileMatchToContextMatches(fm *result.FileMatch) []FileChunkContext {
|
||||
CommitID: fm.CommitID,
|
||||
Path: fm.Path,
|
||||
StartLine: startLine,
|
||||
EndLine: endLine,
|
||||
}}
|
||||
}
|
||||
|
||||
@ -33,7 +33,6 @@ func TestFileMatchToContextMatches(t *testing.T) {
|
||||
CommitID: "abc123",
|
||||
Path: "main.go",
|
||||
StartLine: 0,
|
||||
EndLine: 20,
|
||||
}},
|
||||
},
|
||||
{
|
||||
@ -61,7 +60,6 @@ func TestFileMatchToContextMatches(t *testing.T) {
|
||||
CommitID: "abc123",
|
||||
Path: "main.go",
|
||||
StartLine: 85,
|
||||
EndLine: 105,
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
@ -14,6 +14,7 @@ go_library(
|
||||
"//internal/trace",
|
||||
"//internal/types",
|
||||
"//lib/errors",
|
||||
"//lib/pointers",
|
||||
"@com_github_sourcegraph_conc//iter",
|
||||
],
|
||||
)
|
||||
@ -40,9 +41,12 @@ go_test(
|
||||
"//internal/licensing",
|
||||
"//internal/observation",
|
||||
"//internal/rbac/types",
|
||||
"//internal/search",
|
||||
"//internal/search/backend",
|
||||
"//internal/search/client",
|
||||
"//internal/search/job",
|
||||
"//internal/search/result",
|
||||
"//internal/search/streaming",
|
||||
"//internal/types",
|
||||
"//lib/errors",
|
||||
"//lib/pointers",
|
||||
@ -50,5 +54,6 @@ go_test(
|
||||
"@com_github_sourcegraph_log//logtest",
|
||||
"@com_github_sourcegraph_zoekt//:zoekt",
|
||||
"@com_github_stretchr_testify//require",
|
||||
"@io_k8s_utils//pointer",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package resolvers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
|
||||
"github.com/sourcegraph/conc/iter"
|
||||
@ -12,6 +13,7 @@ import (
|
||||
"github.com/sourcegraph/sourcegraph/internal/trace"
|
||||
"github.com/sourcegraph/sourcegraph/internal/types"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
"github.com/sourcegraph/sourcegraph/lib/pointers"
|
||||
)
|
||||
|
||||
func NewResolver(db database.DB, gitserverClient gitserver.Client, contextClient *codycontext.CodyContextClient) graphqlbackend.CodyContextResolver {
|
||||
@ -68,6 +70,13 @@ func (r *Resolver) GetCodyContext(ctx context.Context, args graphqlbackend.GetCo
|
||||
})
|
||||
}
|
||||
|
||||
// The rough size of a file chunk in runes. The value 1024 is due to historical reasons -- Cody context was once based
|
||||
// on embeddings, and we chunked files into ~1024 characters (aiming for 256 tokens, assuming each token takes 4
|
||||
// characters on average).
|
||||
//
|
||||
// Ideally, the caller would pass a token 'budget' and we'd use a tokenizer and attempt to exactly match this budget.
|
||||
const chunkSizeRunes = 1024
|
||||
|
||||
func (r *Resolver) fileChunkToResolver(ctx context.Context, chunk *codycontext.FileChunkContext) (graphqlbackend.ContextResultResolver, error) {
|
||||
repoResolver := graphqlbackend.NewMinimalRepositoryResolver(r.db, r.gitserverClient, chunk.RepoID, chunk.RepoName)
|
||||
|
||||
@ -83,6 +92,31 @@ func (r *Resolver) fileChunkToResolver(ctx context.Context, chunk *codycontext.F
|
||||
})
|
||||
|
||||
// Populate content ahead of time so we can do it concurrently
|
||||
gitTreeEntryResolver.Content(ctx, &graphqlbackend.GitTreeContentPageArgs{})
|
||||
return graphqlbackend.NewFileChunkContextResolver(gitTreeEntryResolver, chunk.StartLine, chunk.EndLine), nil
|
||||
content, err := gitTreeEntryResolver.Content(ctx, &graphqlbackend.GitTreeContentPageArgs{
|
||||
StartLine: pointers.Ptr(int32(chunk.StartLine)),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
numLines := countLines(content, chunkSizeRunes)
|
||||
endLine := chunk.StartLine + numLines - 1 // subtract 1 because endLine is inclusive
|
||||
return graphqlbackend.NewFileChunkContextResolver(gitTreeEntryResolver, chunk.StartLine, endLine), nil
|
||||
}
|
||||
|
||||
// countLines reports how many complete lines of content fit within a budget of
// numRunes runes. A final line that lacks a terminating newline is treated as
// if it ended with one. The result is rounded down: a line that only partially
// fits in the budget is not counted, so the counted lines never hold more runes
// than the budget. Empty content or a non-positive budget yields 0.
func countLines(content string, numRunes int) int {
	if len(content) == 0 || numRunes <= 0 {
		return 0
	}

	// Walk the string rune by rune to find the byte offset at which the budget
	// runs out. Only the first numRunes runes matter, so there is no need to
	// convert the whole (possibly very large) content to []rune and back.
	end := len(content)
	seen := 0
	for i := range content {
		if seen == numRunes {
			end = i
			break
		}
		seen++
	}

	lines := bytes.Count([]byte(content[:end]), []byte("\n"))

	// seen < numRunes means the loop consumed all of content with budget to
	// spare; in that case the implicit newline that terminates an unterminated
	// final line also fits and completes one more line.
	if seen < numRunes && content[len(content)-1] != '\n' {
		lines++
	}
	return lines
}
|
||||
|
||||
@ -8,10 +8,12 @@ import (
|
||||
"os"
|
||||
"sort"
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sourcegraph/log/logtest"
|
||||
"github.com/sourcegraph/zoekt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"k8s.io/utils/pointer"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/codycontext"
|
||||
@ -26,16 +28,19 @@ import (
|
||||
"github.com/sourcegraph/sourcegraph/internal/licensing"
|
||||
"github.com/sourcegraph/sourcegraph/internal/observation"
|
||||
rtypes "github.com/sourcegraph/sourcegraph/internal/rbac/types"
|
||||
"github.com/sourcegraph/sourcegraph/internal/search"
|
||||
"github.com/sourcegraph/sourcegraph/internal/search/backend"
|
||||
"github.com/sourcegraph/sourcegraph/internal/search/client"
|
||||
"github.com/sourcegraph/sourcegraph/internal/search/job"
|
||||
"github.com/sourcegraph/sourcegraph/internal/search/result"
|
||||
"github.com/sourcegraph/sourcegraph/internal/search/streaming"
|
||||
"github.com/sourcegraph/sourcegraph/internal/types"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
"github.com/sourcegraph/sourcegraph/lib/pointers"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
)
|
||||
|
||||
func TestContextResolver(t *testing.T) {
|
||||
func TestCodyIgnore(t *testing.T) {
|
||||
logger := logtest.Scoped(t)
|
||||
ctx := context.Background()
|
||||
db := database.NewDB(logger, dbtest.NewDB(t))
|
||||
@ -70,41 +75,10 @@ func TestContextResolver(t *testing.T) {
|
||||
return errors.New("error")
|
||||
}
|
||||
|
||||
// Create a normal user role with Cody access permission
|
||||
normalUserRole, err := db.Roles().Create(ctx, "normal user role", false)
|
||||
require.NoError(t, err)
|
||||
codyAccessPermission, err := db.Permissions().Create(ctx, database.CreatePermissionOpts{
|
||||
Namespace: rtypes.CodyNamespace,
|
||||
Action: rtypes.CodyAccessAction,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = db.RolePermissions().Assign(ctx, database.AssignRolePermissionOpts{
|
||||
PermissionID: codyAccessPermission.ID,
|
||||
RoleID: normalUserRole.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create an admin user, give them the normal user role, and authenticate our actor.
|
||||
newAdminUser, err := db.Users().Create(ctx, database.NewUser{
|
||||
Email: "test@example.com",
|
||||
Username: "test",
|
||||
DisplayName: "Test User",
|
||||
Password: "hunter123",
|
||||
EmailIsVerified: true,
|
||||
FailIfNotInitialUser: true, // initial site admin account
|
||||
EnforcePasswordLength: false,
|
||||
TosAccepted: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
db.UserRoles().SetRolesForUser(ctx, database.SetRolesForUserOpts{
|
||||
UserID: newAdminUser.ID,
|
||||
Roles: []int32{normalUserRole.ID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
ctx = actor.WithActor(ctx, actor.FromMockUser(newAdminUser.ID))
|
||||
ctx = authenticateUser(t, db, ctx)
|
||||
|
||||
// Create populates the IDs in the passed in types.Repo
|
||||
err = db.Repos().Create(ctx, &repo1, &repo2)
|
||||
err := db.Repos().Create(ctx, &repo1, &repo2)
|
||||
require.NoError(t, err)
|
||||
|
||||
files := map[api.RepoName]map[string][]byte{
|
||||
@ -121,7 +95,7 @@ func TestContextResolver(t *testing.T) {
|
||||
}
|
||||
|
||||
mockGitserver := gitserver.NewMockClient()
|
||||
mockGitserver.GetDefaultBranchFunc.SetDefaultReturn("main", api.CommitID("abc123"), nil)
|
||||
mockGitserver.GetDefaultBranchFunc.SetDefaultReturn("main", "abc123", nil)
|
||||
mockGitserver.StatFunc.SetDefaultHook(func(_ context.Context, repo api.RepoName, _ api.CommitID, fileName string) (fs.FileInfo, error) {
|
||||
return fakeFileInfo{path: fileName}, nil
|
||||
})
|
||||
@ -252,6 +226,42 @@ func TestContextResolver(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func authenticateUser(t *testing.T, db database.DB, ctx context.Context) context.Context {
|
||||
// Create a normal user role with Cody access permission
|
||||
normalUserRole, err := db.Roles().Create(ctx, "normal user role", false)
|
||||
require.NoError(t, err)
|
||||
codyAccessPermission, err := db.Permissions().Create(ctx, database.CreatePermissionOpts{
|
||||
Namespace: rtypes.CodyNamespace,
|
||||
Action: rtypes.CodyAccessAction,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = db.RolePermissions().Assign(ctx, database.AssignRolePermissionOpts{
|
||||
PermissionID: codyAccessPermission.ID,
|
||||
RoleID: normalUserRole.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create an admin user, give them the normal user role, and authenticate our actor.
|
||||
newAdminUser, err := db.Users().Create(ctx, database.NewUser{
|
||||
Email: "test@example.com",
|
||||
Username: "test",
|
||||
DisplayName: "Test User",
|
||||
Password: "hunter123",
|
||||
EmailIsVerified: true,
|
||||
FailIfNotInitialUser: true, // initial site admin account
|
||||
EnforcePasswordLength: false,
|
||||
TosAccepted: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
db.UserRoles().SetRolesForUser(ctx, database.SetRolesForUserOpts{
|
||||
UserID: newAdminUser.ID,
|
||||
Roles: []int32{normalUserRole.ID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
ctx = actor.WithActor(ctx, actor.FromMockUser(newAdminUser.ID))
|
||||
return ctx
|
||||
}
|
||||
|
||||
type fakeFileInfo struct {
|
||||
path string
|
||||
fs.FileInfo
|
||||
@ -260,3 +270,192 @@ type fakeFileInfo struct {
|
||||
func (f fakeFileInfo) Name() string {
|
||||
return f.path
|
||||
}
|
||||
|
||||
func TestChunkSize(t *testing.T) {
|
||||
logger := logtest.Scoped(t)
|
||||
ctx := context.Background()
|
||||
db := database.NewDB(logger, dbtest.NewDB(t))
|
||||
repo := types.Repo{Name: "repo"}
|
||||
|
||||
conf.Mock(&conf.Unified{
|
||||
SiteConfiguration: schema.SiteConfiguration{
|
||||
CodyEnabled: pointer.Bool(true),
|
||||
LicenseKey: "asdf",
|
||||
},
|
||||
})
|
||||
t.Cleanup(func() { conf.Mock(nil) })
|
||||
|
||||
oldMock := licensing.MockCheckFeature
|
||||
defer func() {
|
||||
licensing.MockCheckFeature = oldMock
|
||||
}()
|
||||
|
||||
licensing.MockCheckFeature = func(feature licensing.Feature) error {
|
||||
if feature == licensing.FeatureCody {
|
||||
return nil
|
||||
}
|
||||
return errors.New("error")
|
||||
}
|
||||
|
||||
ctx = authenticateUser(t, db, ctx)
|
||||
|
||||
// Create populates the IDs in the passed in types.Repo
|
||||
err := db.Repos().Create(ctx, &repo)
|
||||
require.NoError(t, err)
|
||||
|
||||
var content string
|
||||
for i := 0; i < chunkSizeRunes; i++ {
|
||||
if i != 0 && i%10 == 0 {
|
||||
content += "\n"
|
||||
} else {
|
||||
content += "a"
|
||||
}
|
||||
}
|
||||
|
||||
wantLines := bytes.Count([]byte(content), []byte("\n"))
|
||||
|
||||
files := map[api.RepoName]map[string][]byte{
|
||||
"repo": {
|
||||
"testcode1.go": []byte(content),
|
||||
"testcode2.go": []byte(content + "\n"),
|
||||
"testcode3.go": []byte(content + "extra info"),
|
||||
},
|
||||
}
|
||||
|
||||
mockGitserver := gitserver.NewMockClient()
|
||||
mockGitserver.GetDefaultBranchFunc.SetDefaultReturn("main", "abc123", nil)
|
||||
mockGitserver.StatFunc.SetDefaultHook(func(_ context.Context, repo api.RepoName, _ api.CommitID, fileName string) (fs.FileInfo, error) {
|
||||
return fakeFileInfo{path: fileName}, nil
|
||||
})
|
||||
mockGitserver.NewFileReaderFunc.SetDefaultHook(func(ctx context.Context, repo api.RepoName, ci api.CommitID, fileName string) (io.ReadCloser, error) {
|
||||
if content, ok := files[repo][fileName]; ok {
|
||||
return io.NopCloser(bytes.NewReader(content)), nil
|
||||
}
|
||||
return nil, os.ErrNotExist
|
||||
})
|
||||
|
||||
mockEmbeddingsClient := embeddings.NewMockClient()
|
||||
mockEmbeddingsClient.SearchFunc.SetDefaultReturn(nil, errors.New("embeddings should be disabled"))
|
||||
|
||||
lineRange := result.ChunkMatches{{
|
||||
Ranges: result.Ranges{{
|
||||
Start: result.Location{Line: 2},
|
||||
End: result.Location{Line: 6},
|
||||
}},
|
||||
}}
|
||||
|
||||
mockSearchClient := client.NewMockSearchClient()
|
||||
mockSearchClient.PlanFunc.SetDefaultHook(func(_ context.Context, _ string, _ *string, query string, _ search.Mode, _ search.Protocol, _ *int32) (*search.Inputs, error) {
|
||||
return &search.Inputs{OriginalQuery: query}, nil
|
||||
})
|
||||
mockSearchClient.ExecuteFunc.SetDefaultHook(func(_ context.Context, stream streaming.Sender, inputs *search.Inputs) (*search.Alert, error) {
|
||||
stream.Send(streaming.SearchEvent{
|
||||
Results: result.Matches{&result.FileMatch{
|
||||
File: result.File{
|
||||
Path: "testcode1.go",
|
||||
Repo: types.MinimalRepo{ID: repo.ID, Name: repo.Name},
|
||||
},
|
||||
ChunkMatches: lineRange,
|
||||
}, &result.FileMatch{
|
||||
File: result.File{
|
||||
Path: "testcode2.go",
|
||||
Repo: types.MinimalRepo{ID: repo.ID, Name: repo.Name},
|
||||
},
|
||||
ChunkMatches: lineRange,
|
||||
}, &result.FileMatch{
|
||||
File: result.File{
|
||||
Path: "testcode3.go",
|
||||
Repo: types.MinimalRepo{ID: repo.ID, Name: repo.Name},
|
||||
},
|
||||
ChunkMatches: lineRange,
|
||||
}}})
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
observationCtx := observation.TestContextTB(t)
|
||||
contextClient := codycontext.NewCodyContextClient(
|
||||
observationCtx,
|
||||
db,
|
||||
mockEmbeddingsClient,
|
||||
mockSearchClient,
|
||||
mockGitserver,
|
||||
)
|
||||
|
||||
resolver := NewResolver(
|
||||
db,
|
||||
mockGitserver,
|
||||
contextClient,
|
||||
)
|
||||
|
||||
results, err := resolver.GetCodyContext(ctx, graphqlbackend.GetContextArgs{
|
||||
Repos: graphqlbackend.MarshalRepositoryIDs([]api.RepoID{1}),
|
||||
Query: "my test query",
|
||||
TextResultsCount: 0,
|
||||
CodeResultsCount: 5,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, len(results), "expected 3 results but got %d", len(results))
|
||||
|
||||
for _, r := range results {
|
||||
f := r.(*graphqlbackend.FileChunkContextResolver)
|
||||
gotLines := int(f.EndLine() - f.StartLine() + 1)
|
||||
require.Equal(t, wantLines, gotLines, "expected %d lines but got %d", wantLines, gotLines)
|
||||
|
||||
c, err := f.ChunkContent(ctx)
|
||||
require.NoError(t, err)
|
||||
require.LessOrEqual(t, utf8.RuneCountInString(c), chunkSizeRunes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountLines(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
numRunes int
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
content: "",
|
||||
numRunes: 1024,
|
||||
want: 0,
|
||||
}, {
|
||||
name: "single line",
|
||||
content: "Hello world\n",
|
||||
numRunes: 1024,
|
||||
want: 1,
|
||||
}, {}, {
|
||||
name: "single line, no trailing newline",
|
||||
content: "Hello world",
|
||||
numRunes: 1024,
|
||||
want: 1,
|
||||
}, {
|
||||
name: "should truncate",
|
||||
content: "Hello\nWorld\nGo is awesome\n",
|
||||
numRunes: 24,
|
||||
want: 2,
|
||||
}, {
|
||||
name: "should truncate, no trailing newline",
|
||||
content: "Hello\nWorld\nGo is awesome",
|
||||
numRunes: 24,
|
||||
want: 2,
|
||||
}, {
|
||||
name: "all lines",
|
||||
content: "Hello\nWorld\nGo is awesome\n",
|
||||
numRunes: 1024,
|
||||
want: 3,
|
||||
}, {
|
||||
name: "all lines, no trailing newline",
|
||||
content: "Hello\nWorld\nGo is awesome",
|
||||
numRunes: 1024,
|
||||
want: 3,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := countLines(tt.content, tt.numRunes); got != tt.want {
|
||||
t.Errorf("countLines() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user