diff --git a/internal/codeintel/codenav/gittree_translator.go b/internal/codeintel/codenav/gittree_translator.go index 57f9fa89516..39a5f1ff43a 100644 --- a/internal/codeintel/codenav/gittree_translator.go +++ b/internal/codeintel/codenav/gittree_translator.go @@ -22,6 +22,271 @@ import ( "github.com/sourcegraph/sourcegraph/lib/pointers" ) +type CompactGitTreeTranslator interface { + // TranslatePosition returns None if the given position is on a line that was removed or modified + // between from and to + TranslatePosition( + ctx context.Context, from api.CommitID, to api.CommitID, path core.RepoRelPath, position scip.Position, + ) (core.Option[scip.Position], error) + + // TranslateRange returns None if its start or end positions are on a line that was removed or modified + // between from and to + TranslateRange( + ctx context.Context, from api.CommitID, to api.CommitID, path core.RepoRelPath, range_ scip.Range, + ) (core.Option[scip.Range], error) + + // Prefetch will set-up the cache and kick off a diff command for the given paths. It returns immediately + // and does not wait for the diff to complete. + Prefetch(ctx context.Context, from api.CommitID, to api.CommitID, paths []core.RepoRelPath) +} + +func NewCompactGitTreeTranslator(client gitserver.Client, repo sgtypes.Repo) CompactGitTreeTranslator { + return &newTranslator{ + client: client, + repo: repo, + hunkCache: make(map[string]func() ([]compactHunk, error)), + } +} + +type newTranslator struct { + client gitserver.Client + repo sgtypes.Repo + cacheLock sync.RWMutex + hunkCache map[string]func() ([]compactHunk, error) +} + +func (t *newTranslator) TranslatePosition( + ctx context.Context, from api.CommitID, to api.CommitID, path core.RepoRelPath, pos scip.Position, +) (core.Option[scip.Position], error) { + if from == to { + return core.Some(pos), nil + } + hunks, err := t.readCachedHunks(ctx, from, to, path) + if err != nil { + return core.None[scip.Position](), err + } + return translatePosition(hunks, pos), nil +} + +func (t *newTranslator) TranslateRange( + ctx context.Context, from api.CommitID, to api.CommitID, path core.RepoRelPath, range_ scip.Range, +) (core.Option[scip.Range], error) { + if from == to { + return core.Some(range_), nil + } + hunks, err := t.readCachedHunks(ctx, from, to, path) + if err != nil { + return core.None[scip.Range](), err + } + return translateRange(hunks, range_), nil +} + +func (t *newTranslator) readCachedHunks( + ctx context.Context, from api.CommitID, to api.CommitID, path core.RepoRelPath, +) (_ []compactHunk, err error) { + _ = t.fetchHunks(ctx, from, to, path) + t.cacheLock.RLock() + hunkFunc := t.hunkCache[makeTypedKey(from, to, path)] + t.cacheLock.RUnlock() + return hunkFunc() +} + +func (t *newTranslator) Prefetch(ctx context.Context, from api.CommitID, to api.CommitID, paths []core.RepoRelPath) { + run := t.fetchHunks(ctx, from, to, paths...) + // Kick off the actual diff command + go func() { run() }() +} + +func (t *newTranslator) fetchHunks(ctx context.Context, from api.CommitID, to api.CommitID, paths ...core.RepoRelPath) func() { + t.cacheLock.Lock() + defer t.cacheLock.Unlock() + paths = genslices.Filter(paths, func(path core.RepoRelPath) bool { + _, ok := t.hunkCache[makeTypedKey(from, to, path)] + return !ok + }) + if len(paths) == 0 { + return func() {} + } + onceHunks := sync.OnceValues(func() (map[core.RepoRelPath][]compactHunk, error) { + return t.runDiff(ctx, from, to, paths) + }) + for _, path := range paths { + key := makeTypedKey(from, to, path) + t.hunkCache[key] = sync.OnceValues(func() ([]compactHunk, error) { + hunkss, err := onceHunks() + if err != nil { + return []compactHunk{}, nil + } + hunks, ok := hunkss[path] + if !ok { + return []compactHunk{}, nil + } + return hunks, nil + }) + } + return func() { + _, _ = onceHunks() + } +} + +func (t *newTranslator) runDiff(ctx context.Context, from api.CommitID, to api.CommitID, paths []core.RepoRelPath) (map[core.RepoRelPath][]compactHunk, error) { + r, err := t.client.Diff(ctx, t.repo.Name, gitserver.DiffOptions{ + Base: string(from), + Head: string(to), + Paths: genslices.Map(paths, func(p core.RepoRelPath) string { return p.RawValue() }), + RangeType: "..", + InterHunkContext: pointers.Ptr(0), + ContextLines: pointers.Ptr(0), + }) + if err != nil { + return nil, err + } + defer func() { + closeErr := r.Close() + if err == nil { + err = closeErr + } + }() + fds := make(map[core.RepoRelPath][]compactHunk) + for { + fd, err := r.Next() + if err != nil { + if errors.Is(err, io.EOF) { + return fds, nil + } else { + return nil, err + } + } + if fd.OrigName != fd.NewName { + // We cannot handle file renames + continue + } + fds[core.NewRepoRelPathUnchecked(fd.OrigName)] = genslices.Map(fd.Hunks, newCompactHunk) + } +} + +func precedingHunk(hunks []compactHunk, line int32) core.Option[compactHunk] { + line += 1 // diff hunks are 1-based, compared to our 0-based scip ranges + precedingHunkIx, found := slices.BinarySearchFunc(hunks, line, func(h compactHunk, l int32) int { + return cmp.Compare(h.origStartLine, l) + }) + if precedingHunkIx == 0 && !found { + // No preceding hunk means the position was not affected by any hunks + return core.None[compactHunk]() + } + ix := precedingHunkIx + if !found { + ix -= 1 + } + return core.Some(hunks[ix]) +} + +func newTranslateLine( + hunks []compactHunk, + line int32, +) core.Option[int32] { + hunk, ok := precedingHunk(hunks, line).Get() + if !ok { + return core.Some(line) + } + return hunk.ShiftLine(line) +} + +func translatePosition( + hunks []compactHunk, + pos scip.Position, +) core.Option[scip.Position] { + hunk, ok := precedingHunk(hunks, pos.Line).Get() + if !ok { + return core.Some(pos) + } + return hunk.ShiftPosition(pos) +} + +func translateRange( + hunks []compactHunk, + range_ scip.Range, +) core.Option[scip.Range] { + // Fast path for single-line ranges + if range_.Start.Line == range_.End.Line { + newLine, ok := newTranslateLine(hunks, range_.Start.Line).Get() + if !ok { + return core.None[scip.Range]() + } + return core.Some(scip.Range{ + Start: scip.Position{Line: newLine, Character: range_.Start.Character}, + End: scip.Position{Line: newLine, Character: range_.End.Character}, + }) + } + + start, ok := translatePosition(hunks, range_.Start).Get() + if !ok { + return core.None[scip.Range]() + } + end, ok := translatePosition(hunks, range_.End).Get() + if !ok { + return core.None[scip.Range]() + } + return core.Some(scip.Range{Start: start, End: end}) +} + +type compactHunk struct { + // starting line number in original file + origStartLine int32 + // number of lines the hunk applies to in the original file + origLines int32 + // starting line number in new file + newStartLine int32 + // number of lines the hunk applies to in the new file + newLines int32 +} + +func newCompactHunk(h *diff.Hunk) compactHunk { + // If either origLines or newLines are 0, their corresponding line is shifted by an additional -1 + // in the `git diff` output, to make it clear to the user that the line is not included in the + // displayed hunk. + // For our purposes we need the actual start line of the hunk though + origStartLine := h.OrigStartLine + if h.OrigLines == 0 { + origStartLine += 1 + } + newStartLine := h.NewStartLine + if h.NewLines == 0 { + newStartLine += 1 + } + return compactHunk{ + origStartLine: origStartLine, + origLines: h.OrigLines, + newStartLine: newStartLine, + newLines: h.NewLines, + } +} + +func (h *compactHunk) OverlapsLine(line int32) bool { + // git diff hunks are 1-based, vs our 0-based scip ranges + return h.origStartLine <= line+1 && h.origStartLine+h.origLines > line+1 +} + +func (h *compactHunk) ShiftLine(line int32) core.Option[int32] { + if h.OverlapsLine(line) { + return core.None[int32]() + } + originalSpan := h.origStartLine + h.origLines + newSpan := h.newStartLine + h.newLines + return core.Some(line + newSpan - originalSpan) +} + +func (h *compactHunk) ShiftPosition(position scip.Position) core.Option[scip.Position] { + newLine, ok := h.ShiftLine(position.Line).Get() + if !ok { + return core.None[scip.Position]() + } + if newLine == position.Line { + return core.Some(position) + } + return core.Some(scip.Position{Line: newLine, Character: position.Character}) +} + // GitTreeTranslator translates a position within a git tree at a source commit into the // equivalent position in a target commit. The git tree translator instance carries // along with it the source commit. @@ -43,6 +308,7 @@ import ( // inconsistency when modifying the APIs below, as they take different values for 'reverse' // in production). type GitTreeTranslator interface { + Prefetch(ctx context.Context, from api.CommitID, to api.CommitID, paths []core.RepoRelPath) // GetTargetCommitPositionFromSourcePosition translates the given position from the source commit into the given // target commit. The target commit's position is returned, along with a boolean flag // indicating that the translation was successful. If reverse is true, then the source and @@ -123,177 +389,6 @@ func NewGitTreeTranslator(client gitserver.Client, base *TranslationBase, hunkCa } } -type CompactGitTreeTranslator interface { - // TranslatePosition returns None if the given position is on a line that was removed or modified - // between from and to - TranslatePosition( - ctx context.Context, from api.CommitID, to api.CommitID, path core.RepoRelPath, position scip.Position, - ) (core.Option[scip.Position], error) - - // TranslateRange returns None if its start or end positions are on a line that was removed or modified - // between from and to - TranslateRange( - ctx context.Context, from api.CommitID, to api.CommitID, path core.RepoRelPath, range_ scip.Range, - ) (core.Option[scip.Range], error) - - // TODO: Batch APIs/pre-fetching data from gitserver? -} - -func NewCompactGitTreeTranslator(client gitserver.Client, repo sgtypes.Repo) CompactGitTreeTranslator { - return &newTranslator{ - client: client, - repo: repo, - hunkCache: make(map[string]func() ([]compactHunk, error)), - } -} - -type newTranslator struct { - client gitserver.Client - repo sgtypes.Repo - cacheLock sync.RWMutex - hunkCache map[string]func() ([]compactHunk, error) -} - -func (t *newTranslator) TranslatePosition( - ctx context.Context, from api.CommitID, to api.CommitID, path core.RepoRelPath, pos scip.Position, -) (core.Option[scip.Position], error) { - if from == to { - return core.Some(pos), nil - } - hunks, err := t.readCachedHunks(ctx, from, to, path) - if err != nil { - return core.None[scip.Position](), err - } - return translatePosition(hunks, pos), nil -} - -func (t *newTranslator) TranslateRange( - ctx context.Context, from api.CommitID, to api.CommitID, path core.RepoRelPath, range_ scip.Range, -) (core.Option[scip.Range], error) { - if from == to { - return core.Some(range_), nil - } - hunks, err := t.readCachedHunks(ctx, from, to, path) - if err != nil { - return core.None[scip.Range](), err - } - return translateRange(hunks, range_), nil -} - -func (t *newTranslator) readCachedHunks( - ctx context.Context, from api.CommitID, to api.CommitID, path core.RepoRelPath, -) (_ []compactHunk, err error) { - key := makeTypedKey(from, to, path) - t.cacheLock.RLock() - hunksFunc, ok := t.hunkCache[key] - t.cacheLock.RUnlock() - if !ok { - t.cacheLock.Lock() - hunksFunc = sync.OnceValues(func() ([]compactHunk, error) { - return t.readHunks(ctx, from, to, path) - }) - t.hunkCache[key] = hunksFunc - t.cacheLock.Unlock() - } - return hunksFunc() -} - -func (t *newTranslator) readHunks( - ctx context.Context, from api.CommitID, to api.CommitID, path core.RepoRelPath, -) (_ []compactHunk, err error) { - r, err := t.client.Diff(ctx, t.repo.Name, gitserver.DiffOptions{ - Base: string(from), - Head: string(to), - Paths: []string{path.RawValue()}, - RangeType: "..", - InterHunkContext: pointers.Ptr(0), - ContextLines: pointers.Ptr(0), - }) - if err != nil { - return nil, err - } - defer func() { - closeErr := r.Close() - if err == nil { - err = closeErr - } - }() - - fd, err := r.Next() - if err != nil { - if errors.Is(err, io.EOF) { - return nil, nil - } - return nil, err - } - return genslices.Map(fd.Hunks, newCompactHunk), nil -} - -func precedingHunk(hunks []compactHunk, line int32) core.Option[compactHunk] { - line += 1 // diff hunks are 1-based, compared to our 0-based scip ranges - precedingHunkIx, found := slices.BinarySearchFunc(hunks, line, func(h compactHunk, l int32) int { - return cmp.Compare(h.origStartLine, l) - }) - if precedingHunkIx == 0 && !found { - // No preceding hunk means the position was not affected by any hunks - return core.None[compactHunk]() - } - ix := precedingHunkIx - if !found { - ix -= 1 - } - return core.Some(hunks[ix]) -} - -func newTranslateLine( - hunks []compactHunk, - line int32, -) core.Option[int32] { - hunk, ok := precedingHunk(hunks, line).Get() - if !ok { - return core.Some(line) - } - return hunk.ShiftLine(line) -} - -func translatePosition( - hunks []compactHunk, - pos scip.Position, -) core.Option[scip.Position] { - hunk, ok := precedingHunk(hunks, pos.Line).Get() - if !ok { - return core.Some(pos) - } - return hunk.ShiftPosition(pos) -} - -func translateRange( - hunks []compactHunk, - range_ scip.Range, -) core.Option[scip.Range] { - // Fast path for single-line ranges - if range_.Start.Line == range_.End.Line { - newLine, ok := newTranslateLine(hunks, range_.Start.Line).Get() - if !ok { - return core.None[scip.Range]() - } - return core.Some(scip.Range{ - Start: scip.Position{Line: newLine, Character: range_.Start.Character}, - End: scip.Position{Line: newLine, Character: range_.End.Character}, - }) - } - - start, ok := translatePosition(hunks, range_.Start).Get() - if !ok { - return core.None[scip.Range]() - } - end, ok := translatePosition(hunks, range_.End).Get() - if !ok { - return core.None[scip.Range]() - } - return core.Some(scip.Range{Start: start, End: end}) -} - // GetTargetCommitPositionFromSourcePosition translates the given position from the source commit into the given // target commit. The target commit position is returned, along with a boolean flag // indicating that the translation was successful. If reverse is true, then the source and @@ -332,66 +427,13 @@ func (g *gitTreeTranslator) GetSourceCommit() api.CommitID { return g.base.Commit } +func (g *gitTreeTranslator) Prefetch(ctx context.Context, from api.CommitID, to api.CommitID, paths []core.RepoRelPath) { + g.compact.Prefetch(ctx, from, to, paths) +} + func makeTypedKey(from api.CommitID, to api.CommitID, path core.RepoRelPath) string { return makeKey(string(from), string(to), path.RawValue()) } func makeKey(parts ...string) string { return strings.Join(parts, ":") } - -type compactHunk struct { - // starting line number in original file - origStartLine int32 - // number of lines the hunk applies to in the original file - origLines int32 - // starting line number in new file - newStartLine int32 - // number of lines the hunk applies to in the new file - newLines int32 -} - -func newCompactHunk(h *diff.Hunk) compactHunk { - // If either origLines or newLines are 0, their corresponding line is shifted by an additional -1 - // in the `git diff` output, to make it clear to the user that the line is not included in the - // displayed hunk. - // For our purposes we need the actual start line of the hunk though - origStartLine := h.OrigStartLine - if h.OrigLines == 0 { - origStartLine += 1 - } - newStartLine := h.NewStartLine - if h.NewLines == 0 { - newStartLine += 1 - } - return compactHunk{ - origStartLine: origStartLine, - origLines: h.OrigLines, - newStartLine: newStartLine, - newLines: h.NewLines, - } -} - -func (h *compactHunk) OverlapsLine(line int32) bool { - // git diff hunks are 1-based, vs our 0-based scip ranges - return h.origStartLine <= line+1 && h.origStartLine+h.origLines > line+1 -} - -func (h *compactHunk) ShiftLine(line int32) core.Option[int32] { - if h.OverlapsLine(line) { - return core.None[int32]() - } - originalSpan := h.origStartLine + h.origLines - newSpan := h.newStartLine + h.newLines - return core.Some(line + newSpan - originalSpan) -} - -func (h *compactHunk) ShiftPosition(position scip.Position) core.Option[scip.Position] { - newLine, ok := h.ShiftLine(position.Line).Get() - if !ok { - return core.None[scip.Position]() - } - if newLine == position.Line { - return core.Some(position) - } - return core.Some(scip.Position{Line: newLine, Character: position.Character}) -} diff --git a/internal/codeintel/codenav/gittree_translator_test.go b/internal/codeintel/codenav/gittree_translator_test.go index 859aff3cbd4..ca6fa3e1342 100644 --- a/internal/codeintel/codenav/gittree_translator_test.go +++ b/internal/codeintel/codenav/gittree_translator_test.go @@ -24,15 +24,16 @@ var mockTranslationBase = TranslationBase{ } func TestGetTargetCommitPositionFromSourcePosition(t *testing.T) { - client := gitserver.NewMockClientWithExecReader(nil, func(_ context.Context, _ api.RepoName, args []string) (reader io.ReadCloser, err error) { - return io.NopCloser(bytes.NewReader([]byte(hugoDiff))), nil + client := gitserver.NewMockClient() + client.DiffFunc.SetDefaultHook(func(ctx context.Context, rn api.RepoName, do gitserver.DiffOptions) (*gitserver.DiffFileIterator, error) { + return gitserver.NewDiffFileIterator(io.NopCloser(bytes.NewReader([]byte(hugoDiff)))), nil }) posIn := shared.Position{Line: 302, Character: 15} args := &mockTranslationBase adjuster := NewGitTreeTranslator(client, args, nil) - posOut, ok, err := adjuster.GetTargetCommitPositionFromSourcePosition(context.Background(), "deadbeef2", "foo/bar.go", posIn, false) + posOut, ok, err := adjuster.GetTargetCommitPositionFromSourcePosition(context.Background(), "deadbeef2", "resources/image.go", posIn, false) if err != nil { t.Fatalf("unexpected error: %s", err) } @@ -48,15 +49,16 @@ func TestGetTargetCommitPositionFromSourcePosition(t *testing.T) { } func TestGetTargetCommitPositionFromSourcePositionEmptyDiff(t *testing.T) { - client := gitserver.NewMockClientWithExecReader(nil, func(_ context.Context, _ api.RepoName, args []string) (reader io.ReadCloser, err error) { - return io.NopCloser(bytes.NewReader(nil)), nil + client := gitserver.NewMockClient() + client.DiffFunc.SetDefaultHook(func(ctx context.Context, rn api.RepoName, do gitserver.DiffOptions) (*gitserver.DiffFileIterator, error) { + return gitserver.NewDiffFileIterator(io.NopCloser(bytes.NewReader(nil))), nil }) posIn := shared.Position{Line: 10, Character: 15} args := &mockTranslationBase adjuster := NewGitTreeTranslator(client, args, nil) - posOut, ok, err := adjuster.GetTargetCommitPositionFromSourcePosition(context.Background(), "deadbeef2", "foo/bar.go", posIn, false) + posOut, ok, err := adjuster.GetTargetCommitPositionFromSourcePosition(context.Background(), "deadbeef2", "resources/image.go", posIn, false) if err != nil { t.Fatalf("unexpected error: %s", err) } @@ -70,15 +72,16 @@ func TestGetTargetCommitPositionFromSourcePositionEmptyDiff(t *testing.T) { } func TestGetTargetCommitPositionFromSourcePositionReverse(t *testing.T) { - client := gitserver.NewMockClientWithExecReader(nil, func(_ context.Context, _ api.RepoName, args []string) (reader io.ReadCloser, err error) { - return io.NopCloser(bytes.NewReader([]byte(hugoDiff))), nil + client := gitserver.NewMockClient() + client.DiffFunc.SetDefaultHook(func(ctx context.Context, rn api.RepoName, do gitserver.DiffOptions) (*gitserver.DiffFileIterator, error) { + return gitserver.NewDiffFileIterator(io.NopCloser(bytes.NewReader([]byte(hugoDiff)))), nil }) posIn := shared.Position{Line: 302, Character: 15} args := &mockTranslationBase adjuster := NewGitTreeTranslator(client, args, nil) - posOut, ok, err := adjuster.GetTargetCommitPositionFromSourcePosition(context.Background(), "deadbeef2", "foo/bar.go", posIn, true) + posOut, ok, err := adjuster.GetTargetCommitPositionFromSourcePosition(context.Background(), "deadbeef2", "resources/image.go", posIn, true) if err != nil { t.Fatalf("unexpected error: %s", err) } @@ -94,8 +97,9 @@ func TestGetTargetCommitPositionFromSourcePositionReverse(t *testing.T) { } func TestGetTargetCommitRangeFromSourceRange(t *testing.T) { - client := gitserver.NewMockClientWithExecReader(nil, func(_ context.Context, _ api.RepoName, args []string) (reader io.ReadCloser, err error) { - return io.NopCloser(bytes.NewReader([]byte(hugoDiff))), nil + client := gitserver.NewMockClient() + client.DiffFunc.SetDefaultHook(func(ctx context.Context, rn api.RepoName, do gitserver.DiffOptions) (*gitserver.DiffFileIterator, error) { + return gitserver.NewDiffFileIterator(io.NopCloser(bytes.NewReader([]byte(hugoDiff)))), nil }) rIn := shared.Range{ @@ -105,7 +109,7 @@ func TestGetTargetCommitRangeFromSourceRange(t *testing.T) { args := &mockTranslationBase adjuster := NewGitTreeTranslator(client, args, nil) - rOut, ok, err := adjuster.GetTargetCommitRangeFromSourceRange(context.Background(), "deadbeef2", "foo/bar.go", rIn, false) + rOut, ok, err := adjuster.GetTargetCommitRangeFromSourceRange(context.Background(), "deadbeef2", "resources/image.go", rIn, false) if err != nil { t.Fatalf("unexpected error: %s", err) } @@ -124,10 +128,10 @@ func TestGetTargetCommitRangeFromSourceRange(t *testing.T) { } func TestGetTargetCommitRangeFromSourceRangeEmptyDiff(t *testing.T) { - client := gitserver.NewMockClientWithExecReader(nil, func(_ context.Context, _ api.RepoName, args []string) (reader io.ReadCloser, err error) { - return io.NopCloser(bytes.NewReader([]byte(nil))), nil + client := gitserver.NewMockClient() + client.DiffFunc.SetDefaultHook(func(ctx context.Context, rn api.RepoName, do gitserver.DiffOptions) (*gitserver.DiffFileIterator, error) { + return gitserver.NewDiffFileIterator(io.NopCloser(bytes.NewReader([]byte(nil)))), nil }) - rIn := shared.Range{ Start: shared.Position{Line: 302, Character: 15}, End: shared.Position{Line: 305, Character: 20}, @@ -135,7 +139,7 @@ func TestGetTargetCommitRangeFromSourceRangeEmptyDiff(t *testing.T) { args := &mockTranslationBase adjuster := NewGitTreeTranslator(client, args, nil) - rOut, ok, err := adjuster.GetTargetCommitRangeFromSourceRange(context.Background(), "deadbeef2", "foo/bar.go", rIn, false) + rOut, ok, err := adjuster.GetTargetCommitRangeFromSourceRange(context.Background(), "deadbeef2", "resources/image.go", rIn, false) if err != nil { t.Fatalf("unexpected error: %s", err) } @@ -149,8 +153,9 @@ func TestGetTargetCommitRangeFromSourceRangeEmptyDiff(t *testing.T) { } func TestGetTargetCommitRangeFromSourceRangeReverse(t *testing.T) { - client := gitserver.NewMockClientWithExecReader(nil, func(_ context.Context, _ api.RepoName, args []string) (reader io.ReadCloser, err error) { - return io.NopCloser(bytes.NewReader([]byte(hugoDiff))), nil + client := gitserver.NewMockClient() + client.DiffFunc.SetDefaultHook(func(ctx context.Context, rn api.RepoName, do gitserver.DiffOptions) (*gitserver.DiffFileIterator, error) { + return gitserver.NewDiffFileIterator(io.NopCloser(bytes.NewReader([]byte(hugoDiff)))), nil }) rIn := shared.Range{ @@ -160,7 +165,7 @@ func TestGetTargetCommitRangeFromSourceRangeReverse(t *testing.T) { args := &mockTranslationBase adjuster := NewGitTreeTranslator(client, args, nil) - rOut, ok, err := adjuster.GetTargetCommitRangeFromSourceRange(context.Background(), "deadbeef2", "foo/bar.go", rIn, true) + rOut, ok, err := adjuster.GetTargetCommitRangeFromSourceRange(context.Background(), "deadbeef2", "resources/image.go", rIn, true) if err != nil { t.Fatalf("unexpected error: %s", err) } @@ -187,25 +192,25 @@ type gitTreeTranslatorTestCase struct { } // hugoDiff is a diff from github.com/gohugoio/hugo generated via the following command. -// git diff -U0 8947c3fa0beec021e14b3f8040857335e1ecd473 3e9db2ad951dbb1000cd0f8f25e4a95445046679 -- resources/image.go +// git diff-tree --patch --find-renames --full-index --inter-hunk-context=0 --unified=0 --no-prefix 8947c3fa0beec021e14b3f8040857335e1ecd473 3e9db2ad951dbb1000cd0f8f25e4a95445046679 -- resources/image.go const hugoDiff = ` -diff --git a/resources/image.go b/resources/image.go -index d1d9f650d..076f2ae4d 100644 ---- a/resources/image.go -+++ b/resources/image.go +diff --git resources/image.go resources/image.go +index d1d9f650d673e35359444dc9df4f1e24e2cd4fbc..076f2ae4d63b1b6e2de1e3308f6e7bdb791d4d33 100644 +--- resources/image.go ++++ resources/image.go @@ -39 +38,0 @@ import ( -- "github.com/pkg/errors" +- "github.com/pkg/errors" @@ -238 +237 @@ func (i *imageResource) doWithImageConfig(conf images.ImageConfig, f func(src im -- img, err := i.getSpec().imageCache.getOrCreate(i, conf, func() (*imageResource, image.Image, error) { -+ return i.getSpec().imageCache.getOrCreate(i, conf, func() (*imageResource, image.Image, error) { +- img, err := i.getSpec().imageCache.getOrCreate(i, conf, func() (*imageResource, image.Image, error) { ++ return i.getSpec().imageCache.getOrCreate(i, conf, func() (*imageResource, image.Image, error) { @@ -295,7 +293,0 @@ func (i *imageResource) doWithImageConfig(conf images.ImageConfig, f func(src im - -- if err != nil { -- if i.root != nil && i.root.getFileInfo() != nil { -- return nil, errors.Wrapf(err, "image %q", i.root.getFileInfo().Meta().Filename()) -- } -- } -- return img, nil +- if err != nil { +- if i.root != nil && i.root.getFileInfo() != nil { +- return nil, errors.Wrapf(err, "image %q", i.root.getFileInfo().Meta().Filename()) +- } +- } +- return img, nil ` var hugoTestCases = []gitTreeTranslatorTestCase{ @@ -237,20 +242,20 @@ var hugoTestCases = []gitTreeTranslatorTestCase{ } // prometheusDiff is a diff from github.com/prometheus/prometheus generated via the following command. -// git diff -U0 52025bd7a9446c3178bf01dd2949d4874dd45f24 45fbed94d6ee17840254e78cfc421ab1db78f734 -- discovery/manager.go +// git diff-tree --patch --find-renames --full-index --inter-hunk-context=0 --unified=0 --no-prefix 52025bd7a9446c3178bf01dd2949d4874dd45f24 45fbed94d6ee17840254e78cfc421ab1db78f734 -- discovery/manager.go const prometheusDiff = ` -diff --git a/discovery/manager.go b/discovery/manager.go -index 49bcbf86b..d135cd54e 100644 ---- a/discovery/manager.go -+++ b/discovery/manager.go +diff --git discovery/manager.go discovery/manager.go +index 49bcbf86b7baa70bff34b0fa306ca20877f5640e..d135cd54e700ea67963a186ca370d59466f9eb78 100644 +--- discovery/manager.go ++++ discovery/manager.go @@ -296,3 +295,0 @@ func (m *Manager) updateGroup(poolKey poolKey, tgs []*targetgroup.Group) { -- if _, ok := m.targets[poolKey]; !ok { -- m.targets[poolKey] = make(map[string]*targetgroup.Group) -- } +- if _, ok := m.targets[poolKey]; !ok { +- m.targets[poolKey] = make(map[string]*targetgroup.Group) +- } @@ -300,0 +298,3 @@ func (m *Manager) updateGroup(poolKey poolKey, tgs []*targetgroup.Group) { -+ if _, ok := m.targets[poolKey]; !ok { -+ m.targets[poolKey] = make(map[string]*targetgroup.Group) -+ } ++ if _, ok := m.targets[poolKey]; !ok { ++ m.targets[poolKey] = make(map[string]*targetgroup.Group) ++ } ` var prometheusTestCases = []gitTreeTranslatorTestCase{ diff --git a/internal/codeintel/codenav/internal/lsifstore/BUILD.bazel b/internal/codeintel/codenav/internal/lsifstore/BUILD.bazel index 71c90e22d52..7af9b1da2b8 100644 --- a/internal/codeintel/codenav/internal/lsifstore/BUILD.bazel +++ b/internal/codeintel/codenav/internal/lsifstore/BUILD.bazel @@ -31,6 +31,7 @@ go_library( "//lib/errors", "@com_github_keegancsmith_sqlf//:sqlf", "@com_github_lib_pq//:pq", + "@com_github_life4_genesis//slices", "@com_github_sourcegraph_scip//bindings/go/scip", "@io_opentelemetry_go_otel//attribute", "@org_golang_google_protobuf//proto", diff --git a/internal/codeintel/codenav/internal/lsifstore/lsifstore_documents.go b/internal/codeintel/codenav/internal/lsifstore/lsifstore_documents.go index 6e05e748e5f..1f835e5ac62 100644 --- a/internal/codeintel/codenav/internal/lsifstore/lsifstore_documents.go +++ b/internal/codeintel/codenav/internal/lsifstore/lsifstore_documents.go @@ -5,6 +5,7 @@ import ( "context" "github.com/keegancsmith/sqlf" + genslices "github.com/life4/genesis/slices" "github.com/sourcegraph/scip/bindings/go/scip" "go.opentelemetry.io/otel/attribute" "google.golang.org/protobuf/proto" @@ -16,6 +17,55 @@ import ( "github.com/sourcegraph/sourcegraph/internal/observation" ) +func (s *store) SCIPDocuments(ctx context.Context, uploadID int, paths []core.UploadRelPath) (_ map[core.UploadRelPath]*scip.Document, err error) { + stringPaths := genslices.Map(paths, func(p core.UploadRelPath) string { return p.RawValue() }) + ctx, _, endObservation := s.operations.scipDocuments.With(ctx, &err, observation.Args{Attrs: []attribute.KeyValue{ + attribute.Int("uploadID", uploadID), + attribute.StringSlice("paths", stringPaths), + }}) + defer endObservation(1, observation.Args{}) + + scanner := basestore.NewMapScanner(func(dbs dbutil.Scanner) (core.UploadRelPath, *scip.Document, error) { + var compressedSCIPPayload []byte + var path string + emptyPath := core.NewUploadRelPathUnchecked("") + if err := dbs.Scan(&path, &compressedSCIPPayload); err != nil { + return emptyPath, nil, err + } + + scipPayload, err := shared.Decompressor.Decompress(bytes.NewReader(compressedSCIPPayload)) + if err != nil { + return emptyPath, nil, err + } + + var document scip.Document + if err := proto.Unmarshal(scipPayload, &document); err != nil { + return emptyPath, nil, err + } + return core.NewUploadRelPathUnchecked(path), &document, nil + }) + searchPaths := make([]*sqlf.Query, 0, len(paths)) + for _, path := range stringPaths { + searchPaths = append(searchPaths, sqlf.Sprintf("%s", path)) + } + doc, err := scanner(s.db.Query(ctx, sqlf.Sprintf(fetchSCIPDocumentsQuery, uploadID, sqlf.Join(searchPaths, ",")))) + if err != nil { + return nil, nil + } + return doc, nil +} + +const fetchSCIPDocumentsQuery = ` +SELECT + sid.document_path, + sd.raw_scip_payload +FROM codeintel_scip_document_lookup sid +JOIN codeintel_scip_documents sd ON sd.id = sid.document_id +WHERE + sid.upload_id = %s AND + sid.document_path IN (%s) +` + func (s *store) SCIPDocument(ctx context.Context, uploadID int, path core.UploadRelPath) (_ core.Option[*scip.Document], err error) { ctx, _, endObservation := s.operations.scipDocument.With(ctx, &err, observation.Args{Attrs: []attribute.KeyValue{ attribute.String("path", path.RawValue()), diff --git a/internal/codeintel/codenav/internal/lsifstore/observability.go b/internal/codeintel/codenav/internal/lsifstore/observability.go index a27a4a94f0b..65b394c2d75 100644 --- a/internal/codeintel/codenav/internal/lsifstore/observability.go +++ b/internal/codeintel/codenav/internal/lsifstore/observability.go @@ -21,6 +21,7 @@ type operations struct { getHover *observation.Operation getDiagnostics *observation.Operation scipDocument *observation.Operation + scipDocuments *observation.Operation findDocumentIDs *observation.Operation } @@ -58,6 +59,7 @@ func newOperations(observationCtx *observation.Context) *operations { getHover: op("GetHover"), getDiagnostics: op("GetDiagnostics"), scipDocument: op("SCIPDocument"), + scipDocuments: op("SCIPDocuments"), findDocumentIDs: op("FindDocumentIDs"), } } diff --git a/internal/codeintel/codenav/internal/lsifstore/store.go b/internal/codeintel/codenav/internal/lsifstore/store.go index b6e271db30f..63927016a7c 100644 --- a/internal/codeintel/codenav/internal/lsifstore/store.go +++ b/internal/codeintel/codenav/internal/lsifstore/store.go @@ -25,6 +25,7 @@ type LsifStore interface { GetStencil(ctx context.Context, bundleID int, path core.UploadRelPath) ([]shared.Range, error) GetRanges(ctx context.Context, bundleID int, path core.UploadRelPath, startLine, endLine int) ([]shared.CodeIntelligenceRange, error) SCIPDocument(ctx context.Context, uploadID int, path core.UploadRelPath) (core.Option[*scip.Document], error) + SCIPDocuments(ctx context.Context, uploadID int, paths []core.UploadRelPath) (map[core.UploadRelPath]*scip.Document, error) // Fetch symbol names by position GetMonikersByPosition(ctx context.Context, uploadID int, path core.UploadRelPath, line, character int) ([][]precise.MonikerData, error) diff --git a/internal/codeintel/codenav/mapped_index.go b/internal/codeintel/codenav/mapped_index.go index 36417eaeba3..1983f8c3a11 100644 --- a/internal/codeintel/codenav/mapped_index.go +++ b/internal/codeintel/codenav/mapped_index.go @@ -18,8 +18,9 @@ type MappedIndex interface { // GetDocument returns None if the index does not contain a document at the given path. // There is no caching here, every call to GetDocument re-fetches the full document from the database. GetDocument(context.Context, core.RepoRelPath) (core.Option[MappedDocument], error) + // GetDocuments uses batch APIs to fetch multiple documents and pre-fetches diff contents + GetDocuments(context.Context, []core.RepoRelPath) ([]core.Option[MappedDocument], error) GetUploadSummary() core.UploadSummary - // TODO: Should there be a bulk-API for getting multiple documents? } var _ MappedIndex = mappedIndex{} @@ -31,6 +32,7 @@ type MappedDocument interface { GetOccurrences(context.Context) ([]*scip.Occurrence, error) // GetOccurrencesAtRange returns shared slices. Do not modify the returned slice or Occurrences without copying them first GetOccurrencesAtRange(context.Context, scip.Range) ([]*scip.Occurrence, error) + GetPath() core.RepoRelPath } var _ MappedDocument = &mappedDocument{} @@ -65,30 +67,51 @@ func (i mappedIndex) GetUploadSummary() core.UploadSummary { } } +func (i mappedIndex) makeMappedDocument(path core.RepoRelPath, scipDocument *scip.Document) MappedDocument { + return &mappedDocument{ + gitTreeTranslator: i.gitTreeTranslator, + indexCommit: i.upload.GetCommit(), + targetCommit: i.targetCommit, + path: path, + document: &lockedDocument{ + inner: scipDocument, + isMapped: false, + mapErrored: nil, + lock: sync.RWMutex{}, + }, + mapOnce: sync.Once{}, + } +} + func (i mappedIndex) GetDocument(ctx context.Context, path core.RepoRelPath) (core.Option[MappedDocument], error) { optDocument, err := i.lsifStore.SCIPDocument(ctx, i.upload.GetID(), core.NewUploadRelPath(i.upload, path)) if err != nil { return core.None[MappedDocument](), err } if document, ok := optDocument.Get(); ok { - return core.Some[MappedDocument](&mappedDocument{ - gitTreeTranslator: i.gitTreeTranslator, - indexCommit: i.upload.GetCommit(), - targetCommit: i.targetCommit, - path: path, - document: &lockedDocument{ - inner: document, - isMapped: false, - mapErrored: nil, - lock: sync.RWMutex{}, - }, - mapOnce: sync.Once{}, - }), nil + return core.Some[MappedDocument](i.makeMappedDocument(path, document)), nil } else { return core.None[MappedDocument](), nil } } +func (i mappedIndex) GetDocuments(ctx context.Context, paths []core.RepoRelPath) ([]core.Option[MappedDocument], error) { + i.gitTreeTranslator.Prefetch(ctx, i.gitTreeTranslator.GetSourceCommit(), i.upload.GetCommit(), paths) + documentMap, err := i.lsifStore.SCIPDocuments(ctx, i.upload.GetID(), genslices.Map(paths, func(p core.RepoRelPath) core.UploadRelPath { + return core.NewUploadRelPath(i.upload, p) + })) + if err != nil { + return nil, err + } + return genslices.Map(paths, func(path core.RepoRelPath) core.Option[MappedDocument] { + if document, ok := documentMap[core.NewUploadRelPath(i.upload, path)]; ok { + return core.Some[MappedDocument](i.makeMappedDocument(path, document)) + } else { + return core.None[MappedDocument]() + } + }), nil +} + type mappedDocument struct { gitTreeTranslator GitTreeTranslator indexCommit api.CommitID @@ -106,6 +129,10 @@ type lockedDocument struct { lock sync.RWMutex } +func (d *mappedDocument) GetPath() core.RepoRelPath { + return d.path +} + func cloneOccurrence(occ *scip.Occurrence) *scip.Occurrence { occCopy, ok := proto.Clone(occ).(*scip.Occurrence) if !ok { diff --git a/internal/codeintel/codenav/mocks_temp.go b/internal/codeintel/codenav/mocks_temp.go index f7cf5275c91..7e7d137062c 100644 --- a/internal/codeintel/codenav/mocks_temp.go +++ b/internal/codeintel/codenav/mocks_temp.go @@ -71,6 +71,9 @@ type MockLsifStore struct { // SCIPDocumentFunc is an instance of a mock function object controlling // the behavior of the method SCIPDocument. SCIPDocumentFunc *LsifStoreSCIPDocumentFunc + // SCIPDocumentsFunc is an instance of a mock function object + // controlling the behavior of the method SCIPDocuments. + SCIPDocumentsFunc *LsifStoreSCIPDocumentsFunc } // NewMockLsifStore creates a new mock of the LsifStore interface. All @@ -147,6 +150,11 @@ func NewMockLsifStore() *MockLsifStore { return }, }, + SCIPDocumentsFunc: &LsifStoreSCIPDocumentsFunc{ + defaultHook: func(context.Context, int, []core.UploadRelPath) (r0 map[core.UploadRelPath]*scip.Document, r1 error) { + return + }, + }, } } @@ -224,6 +232,11 @@ func NewStrictMockLsifStore() *MockLsifStore { panic("unexpected invocation of MockLsifStore.SCIPDocument") }, }, + SCIPDocumentsFunc: &LsifStoreSCIPDocumentsFunc{ + defaultHook: func(context.Context, int, []core.UploadRelPath) (map[core.UploadRelPath]*scip.Document, error) { + panic("unexpected invocation of MockLsifStore.SCIPDocuments") + }, + }, } } @@ -273,6 +286,9 @@ func NewMockLsifStoreFrom(i lsifstore.LsifStore) *MockLsifStore { SCIPDocumentFunc: &LsifStoreSCIPDocumentFunc{ defaultHook: i.SCIPDocument, }, + SCIPDocumentsFunc: &LsifStoreSCIPDocumentsFunc{ + defaultHook: i.SCIPDocuments, + }, } } @@ -1919,6 +1935,117 @@ func (c LsifStoreSCIPDocumentFuncCall) Results() []interface{} { return []interface{}{c.Result0, c.Result1} } +// LsifStoreSCIPDocumentsFunc describes the behavior when the SCIPDocuments +// method of the parent MockLsifStore instance is invoked. +type LsifStoreSCIPDocumentsFunc struct { + defaultHook func(context.Context, int, []core.UploadRelPath) (map[core.UploadRelPath]*scip.Document, error) + hooks []func(context.Context, int, []core.UploadRelPath) (map[core.UploadRelPath]*scip.Document, error) + history []LsifStoreSCIPDocumentsFuncCall + mutex sync.Mutex +} + +// SCIPDocuments delegates to the next hook function in the queue and stores +// the parameter and result values of this invocation. +func (m *MockLsifStore) SCIPDocuments(v0 context.Context, v1 int, v2 []core.UploadRelPath) (map[core.UploadRelPath]*scip.Document, error) { + r0, r1 := m.SCIPDocumentsFunc.nextHook()(v0, v1, v2) + m.SCIPDocumentsFunc.appendCall(LsifStoreSCIPDocumentsFuncCall{v0, v1, v2, r0, r1}) + return r0, r1 +} + +// SetDefaultHook sets function that is called when the SCIPDocuments method +// of the parent MockLsifStore instance is invoked and the hook queue is +// empty. +func (f *LsifStoreSCIPDocumentsFunc) SetDefaultHook(hook func(context.Context, int, []core.UploadRelPath) (map[core.UploadRelPath]*scip.Document, error)) { + f.defaultHook = hook +} + +// PushHook adds a function to the end of hook queue. Each invocation of the +// SCIPDocuments method of the parent MockLsifStore instance invokes the +// hook at the front of the queue and discards it. After the queue is empty, +// the default hook function is invoked for any future action. +func (f *LsifStoreSCIPDocumentsFunc) PushHook(hook func(context.Context, int, []core.UploadRelPath) (map[core.UploadRelPath]*scip.Document, error)) { + f.mutex.Lock() + f.hooks = append(f.hooks, hook) + f.mutex.Unlock() +} + +// SetDefaultReturn calls SetDefaultHook with a function that returns the +// given values. +func (f *LsifStoreSCIPDocumentsFunc) SetDefaultReturn(r0 map[core.UploadRelPath]*scip.Document, r1 error) { + f.SetDefaultHook(func(context.Context, int, []core.UploadRelPath) (map[core.UploadRelPath]*scip.Document, error) { + return r0, r1 + }) +} + +// PushReturn calls PushHook with a function that returns the given values. +func (f *LsifStoreSCIPDocumentsFunc) PushReturn(r0 map[core.UploadRelPath]*scip.Document, r1 error) { + f.PushHook(func(context.Context, int, []core.UploadRelPath) (map[core.UploadRelPath]*scip.Document, error) { + return r0, r1 + }) +} + +func (f *LsifStoreSCIPDocumentsFunc) nextHook() func(context.Context, int, []core.UploadRelPath) (map[core.UploadRelPath]*scip.Document, error) { + f.mutex.Lock() + defer f.mutex.Unlock() + + if len(f.hooks) == 0 { + return f.defaultHook + } + + hook := f.hooks[0] + f.hooks = f.hooks[1:] + return hook +} + +func (f *LsifStoreSCIPDocumentsFunc) appendCall(r0 LsifStoreSCIPDocumentsFuncCall) { + f.mutex.Lock() + f.history = append(f.history, r0) + f.mutex.Unlock() +} + +// History returns a sequence of LsifStoreSCIPDocumentsFuncCall objects +// describing the invocations of this function. +func (f *LsifStoreSCIPDocumentsFunc) History() []LsifStoreSCIPDocumentsFuncCall { + f.mutex.Lock() + history := make([]LsifStoreSCIPDocumentsFuncCall, len(f.history)) + copy(history, f.history) + f.mutex.Unlock() + + return history +} + +// LsifStoreSCIPDocumentsFuncCall is an object that describes an invocation +// of method SCIPDocuments on an instance of MockLsifStore. +type LsifStoreSCIPDocumentsFuncCall struct { + // Arg0 is the value of the 1st argument passed to this method + // invocation. + Arg0 context.Context + // Arg1 is the value of the 2nd argument passed to this method + // invocation. + Arg1 int + // Arg2 is the value of the 3rd argument passed to this method + // invocation. + Arg2 []core.UploadRelPath + // Result0 is the value of the 1st result returned from this method + // invocation. + Result0 map[core.UploadRelPath]*scip.Document + // Result1 is the value of the 2nd result returned from this method + // invocation. + Result1 error +} + +// Args returns an interface slice containing the arguments of this +// invocation. +func (c LsifStoreSCIPDocumentsFuncCall) Args() []interface{} { + return []interface{}{c.Arg0, c.Arg1, c.Arg2} +} + +// Results returns an interface slice containing the results of this +// invocation. +func (c LsifStoreSCIPDocumentsFuncCall) Results() []interface{} { + return []interface{}{c.Result0, c.Result1} +} + // MockGitTreeTranslator is a mock implementation of the GitTreeTranslator // interface (from the package // github.com/sourcegraph/sourcegraph/internal/codeintel/codenav) used for @@ -1935,6 +2062,9 @@ type MockGitTreeTranslator struct { // function object controlling the behavior of the method // GetTargetCommitRangeFromSourceRange. GetTargetCommitRangeFromSourceRangeFunc *GitTreeTranslatorGetTargetCommitRangeFromSourceRangeFunc + // PrefetchFunc is an instance of a mock function object controlling the + // behavior of the method Prefetch. + PrefetchFunc *GitTreeTranslatorPrefetchFunc } // NewMockGitTreeTranslator creates a new mock of the GitTreeTranslator @@ -1957,6 +2087,11 @@ func NewMockGitTreeTranslator() *MockGitTreeTranslator { return }, }, + PrefetchFunc: &GitTreeTranslatorPrefetchFunc{ + defaultHook: func(context.Context, api.CommitID, api.CommitID, []core.RepoRelPath) { + return + }, + }, } } @@ -1980,6 +2115,11 @@ func NewStrictMockGitTreeTranslator() *MockGitTreeTranslator { panic("unexpected invocation of MockGitTreeTranslator.GetTargetCommitRangeFromSourceRange") }, }, + PrefetchFunc: &GitTreeTranslatorPrefetchFunc{ + defaultHook: func(context.Context, api.CommitID, api.CommitID, []core.RepoRelPath) { + panic("unexpected invocation of MockGitTreeTranslator.Prefetch") + }, + }, } } @@ -1997,6 +2137,9 @@ func NewMockGitTreeTranslatorFrom(i GitTreeTranslator) *MockGitTreeTranslator { GetTargetCommitRangeFromSourceRangeFunc: &GitTreeTranslatorGetTargetCommitRangeFromSourceRangeFunc{ defaultHook: i.GetTargetCommitRangeFromSourceRange, }, + PrefetchFunc: &GitTreeTranslatorPrefetchFunc{ + defaultHook: i.Prefetch, + }, } } @@ -2354,6 +2497,114 @@ func (c GitTreeTranslatorGetTargetCommitRangeFromSourceRangeFuncCall) Results() return []interface{}{c.Result0, c.Result1, c.Result2} } +// GitTreeTranslatorPrefetchFunc describes the behavior when the Prefetch +// method of the parent MockGitTreeTranslator instance is invoked. +type GitTreeTranslatorPrefetchFunc struct { + defaultHook func(context.Context, api.CommitID, api.CommitID, []core.RepoRelPath) + hooks []func(context.Context, api.CommitID, api.CommitID, []core.RepoRelPath) + history []GitTreeTranslatorPrefetchFuncCall + mutex sync.Mutex +} + +// Prefetch delegates to the next hook function in the queue and stores the +// parameter and result values of this invocation. +func (m *MockGitTreeTranslator) Prefetch(v0 context.Context, v1 api.CommitID, v2 api.CommitID, v3 []core.RepoRelPath) { + m.PrefetchFunc.nextHook()(v0, v1, v2, v3) + m.PrefetchFunc.appendCall(GitTreeTranslatorPrefetchFuncCall{v0, v1, v2, v3}) + return +} + +// SetDefaultHook sets function that is called when the Prefetch method of +// the parent MockGitTreeTranslator instance is invoked and the hook queue +// is empty. +func (f *GitTreeTranslatorPrefetchFunc) SetDefaultHook(hook func(context.Context, api.CommitID, api.CommitID, []core.RepoRelPath)) { + f.defaultHook = hook +} + +// PushHook adds a function to the end of hook queue. Each invocation of the +// Prefetch method of the parent MockGitTreeTranslator instance invokes the +// hook at the front of the queue and discards it. After the queue is empty, +// the default hook function is invoked for any future action. +func (f *GitTreeTranslatorPrefetchFunc) PushHook(hook func(context.Context, api.CommitID, api.CommitID, []core.RepoRelPath)) { + f.mutex.Lock() + f.hooks = append(f.hooks, hook) + f.mutex.Unlock() +} + +// SetDefaultReturn calls SetDefaultHook with a function that returns the +// given values. +func (f *GitTreeTranslatorPrefetchFunc) SetDefaultReturn() { + f.SetDefaultHook(func(context.Context, api.CommitID, api.CommitID, []core.RepoRelPath) { + return + }) +} + +// PushReturn calls PushHook with a function that returns the given values. +func (f *GitTreeTranslatorPrefetchFunc) PushReturn() { + f.PushHook(func(context.Context, api.CommitID, api.CommitID, []core.RepoRelPath) { + return + }) +} + +func (f *GitTreeTranslatorPrefetchFunc) nextHook() func(context.Context, api.CommitID, api.CommitID, []core.RepoRelPath) { + f.mutex.Lock() + defer f.mutex.Unlock() + + if len(f.hooks) == 0 { + return f.defaultHook + } + + hook := f.hooks[0] + f.hooks = f.hooks[1:] + return hook +} + +func (f *GitTreeTranslatorPrefetchFunc) appendCall(r0 GitTreeTranslatorPrefetchFuncCall) { + f.mutex.Lock() + f.history = append(f.history, r0) + f.mutex.Unlock() +} + +// History returns a sequence of GitTreeTranslatorPrefetchFuncCall objects +// describing the invocations of this function. +func (f *GitTreeTranslatorPrefetchFunc) History() []GitTreeTranslatorPrefetchFuncCall { + f.mutex.Lock() + history := make([]GitTreeTranslatorPrefetchFuncCall, len(f.history)) + copy(history, f.history) + f.mutex.Unlock() + + return history +} + +// GitTreeTranslatorPrefetchFuncCall is an object that describes an +// invocation of method Prefetch on an instance of MockGitTreeTranslator. +type GitTreeTranslatorPrefetchFuncCall struct { + // Arg0 is the value of the 1st argument passed to this method + // invocation. + Arg0 context.Context + // Arg1 is the value of the 2nd argument passed to this method + // invocation. + Arg1 api.CommitID + // Arg2 is the value of the 3rd argument passed to this method + // invocation. + Arg2 api.CommitID + // Arg3 is the value of the 4th argument passed to this method + // invocation. + Arg3 []core.RepoRelPath +} + +// Args returns an interface slice containing the arguments of this +// invocation. +func (c GitTreeTranslatorPrefetchFuncCall) Args() []interface{} { + return []interface{}{c.Arg0, c.Arg1, c.Arg2, c.Arg3} +} + +// Results returns an interface slice containing the results of this +// invocation. +func (c GitTreeTranslatorPrefetchFuncCall) Results() []interface{} { + return []interface{}{} +} + // MockUploadService is a mock implementation of the UploadService interface // (from the package // github.com/sourcegraph/sourcegraph/internal/codeintel/codenav) used for diff --git a/internal/codeintel/codenav/service.go b/internal/codeintel/codenav/service.go index 92baf8d2f8f..7792b500fb5 100644 --- a/internal/codeintel/codenav/service.go +++ b/internal/codeintel/codenav/service.go @@ -1102,6 +1102,7 @@ func (s *Service) getSyntacticUpload(ctx context.Context, trace observation.Trac type SearchBasedMatch struct { Path core.RepoRelPath Range scip.Range + LineContent string IsDefinition bool } @@ -1109,6 +1110,7 @@ type SyntacticMatch struct { Path core.RepoRelPath Range scip.Range IsDefinition bool + LineContent string Symbol string } diff --git a/internal/codeintel/codenav/service_ranges_test.go b/internal/codeintel/codenav/service_ranges_test.go index 0e9aeb58ca3..75edd98263e 100644 --- a/internal/codeintel/codenav/service_ranges_test.go +++ b/internal/codeintel/codenav/service_ranges_test.go @@ -21,10 +21,10 @@ import ( ) const rangesDiff = ` -diff --git a/changed.go b/changed.go +diff --git sub3/changed.go sub3/changed.go index deadbeef1..deadbeef2 100644 ---- a/changed.go -+++ b/changed.go +--- sub3/changed.go ++++ sub3/changed.go @@ -12,7 +12,7 @@ const imageProcWorkers = 1 var imageProcSem = make(chan bool, imageProcWorkers) var random = "banana" diff --git a/internal/codeintel/codenav/syntactic.go b/internal/codeintel/codenav/syntactic.go index 2468f866a3a..120eb0e81b8 100644 --- a/internal/codeintel/codenav/syntactic.go +++ b/internal/codeintel/codenav/syntactic.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "slices" + "strings" genslices "github.com/life4/genesis/slices" conciter "github.com/sourcegraph/conc/iter" @@ -25,9 +26,15 @@ import ( "github.com/sourcegraph/sourcegraph/lib/errors" ) +type candidateMatch struct { + range_ scip.Range + lineContent string +} + type candidateFile struct { - matches []scip.Range // Guaranteed to be sorted - didSearchEntireFile bool // Or did we hit the search count limit? + path core.RepoRelPath + matches []candidateMatch // Guaranteed to be sorted + didSearchEntireFile bool // Or did we hit the search count limit? } type searchArgs struct { @@ -37,6 +44,16 @@ type searchArgs struct { language string } +func lineForRange(match result.ChunkMatch, range_ result.Range) string { + lines := strings.Split(match.Content, "\n") + index := range_.Start.Line - match.ContentStart.Line + // TODO: log? + if len(lines) <= index { + return "" + } + return lines[index] +} + // findCandidateOccurrencesViaSearch calls out to Searcher/Zoekt to find candidate occurrences of the given symbol. // It returns a map of file paths to candidate ranges. func findCandidateOccurrencesViaSearch( @@ -44,16 +61,16 @@ func findCandidateOccurrencesViaSearch( trace observation.TraceLogger, client searchclient.SearchClient, args searchArgs, -) (orderedmap.OrderedMap[core.RepoRelPath, candidateFile], error) { +) ([]candidateFile, error) { if args.identifier == "" { - return *orderedmap.New[core.RepoRelPath, candidateFile](), nil + return []candidateFile{}, nil } resultMap := *orderedmap.New[core.RepoRelPath, candidateFile]() // TODO: countLimit should be dependent on the number of requested usages, with a configured global limit // For now we're matching the current web app with 500 searchResults, err := executeQuery(ctx, client, trace, args, "file", 500, 0) if err != nil { - return resultMap, err + return []candidateFile{}, err } nonFileMatches := 0 @@ -67,7 +84,7 @@ func findCandidateOccurrencesViaSearch( continue } path := fileMatch.Path - matches := []scip.Range{} + matches := []candidateMatch{} for _, chunkMatch := range fileMatch.ChunkMatches { for _, matchRange := range chunkMatch.Ranges { if path != streamResult.Key().Path { @@ -83,12 +100,18 @@ func findCandidateOccurrencesViaSearch( continue } matchCount += 1 - matches = append(matches, scipRange) + matches = append(matches, candidateMatch{ + range_: scipRange, + lineContent: lineForRange(chunkMatch, matchRange), + }) } } + slices.SortFunc(matches, func(m1 candidateMatch, m2 candidateMatch) int { return m1.range_.CompareStrict(m2.range_) }) // OK to use Unchecked method here as search API only returns repo-root relative paths - _, alreadyPresent := resultMap.Set(core.NewRepoRelPathUnchecked(path), candidateFile{ - matches: scip.SortRanges(matches), + repoRelPath := core.NewRepoRelPathUnchecked(path) + _, alreadyPresent := resultMap.Set(repoRelPath, candidateFile{ + path: repoRelPath, + matches: matches, didSearchEntireFile: !fileMatch.LimitHit, }) if alreadyPresent { @@ -107,7 +130,11 @@ func findCandidateOccurrencesViaSearch( trace.Warn("Saw mismatched file paths between chunk matches in the same FileMatch. Report this to the search-platform") } - return resultMap, nil + results := make([]candidateFile, 0, resultMap.Len()) + for pair := resultMap.Oldest(); pair != nil; pair = pair.Next() { + results = append(results, pair.Value) + } + return results, nil } type symbolData struct { @@ -287,30 +314,16 @@ func symbolAtRange( func findSyntacticMatchesForCandidateFile( ctx context.Context, trace observation.TraceLogger, - mappedIndex MappedIndex, - filePath core.RepoRelPath, + document MappedDocument, candidateFile candidateFile, -) ([]SyntacticMatch, []SearchBasedMatch, *SyntacticUsagesError) { - documentOpt, docErr := mappedIndex.GetDocument(ctx, filePath) - if docErr != nil { - return nil, nil, &SyntacticUsagesError{ - Code: SU_Fatal, - UnderlyingError: docErr, - } - } - document, isSome := documentOpt.Get() - if !isSome { - return nil, nil, &SyntacticUsagesError{ - Code: SU_NoSyntacticIndex, - } - } - syntacticMatches := []SyntacticMatch{} - searchBasedMatches := []SearchBasedMatch{} +) (syntacticMatches []SyntacticMatch, searchBasedMatches []SearchBasedMatch) { + syntacticMatches = []SyntacticMatch{} + searchBasedMatches = []SearchBasedMatch{} failedTranslationCount := 0 for _, sourceCandidateRange := range candidateFile.matches { foundSyntacticMatch := false - occurrences, occErr := document.GetOccurrencesAtRange(ctx, sourceCandidateRange) + occurrences, occErr := document.GetOccurrencesAtRange(ctx, sourceCandidateRange.range_) if occErr != nil { failedTranslationCount += 1 continue @@ -319,8 +332,9 @@ func findSyntacticMatchesForCandidateFile( if !scip.IsLocalSymbol(occ.Symbol) { foundSyntacticMatch = true syntacticMatches = append(syntacticMatches, SyntacticMatch{ - Path: filePath, - Range: sourceCandidateRange, + Path: document.GetPath(), + Range: sourceCandidateRange.range_, + LineContent: sourceCandidateRange.lineContent, Symbol: occ.Symbol, IsDefinition: scip.SymbolRole_Definition.Matches(occ), }) @@ -328,15 +342,16 @@ func findSyntacticMatchesForCandidateFile( } if !foundSyntacticMatch { searchBasedMatches = append(searchBasedMatches, SearchBasedMatch{ - Path: filePath, - Range: sourceCandidateRange, + Path: document.GetPath(), + Range: sourceCandidateRange.range_, + LineContent: sourceCandidateRange.lineContent, }) } } if failedTranslationCount != 0 { trace.Info("findSyntacticMatchesForCandidateFile", log.Int("failedTranslationCount", failedTranslationCount)) } - return syntacticMatches, searchBasedMatches, nil + return syntacticMatches, searchBasedMatches } func syntacticUsagesImpl( @@ -373,7 +388,7 @@ func syntacticUsagesImpl( identifier: symbolName, language: language, } - candidateMatches, searchErr := findCandidateOccurrencesViaSearch(ctx, trace, searchClient, searchCoords) + candidateFiles, searchErr := findCandidateOccurrencesViaSearch(ctx, trace, searchClient, searchCoords) if searchErr != nil { return SyntacticUsagesResult{}, PreviousSyntacticSearch{}, &SyntacticUsagesError{ Code: SU_FailedToSearch, @@ -381,18 +396,27 @@ func syntacticUsagesImpl( } } - tasks := make([]orderedmap.Pair[core.RepoRelPath, candidateFile], 0, candidateMatches.Len()) - for pair := candidateMatches.Oldest(); pair != nil; pair = pair.Next() { - tasks = append(tasks, *pair) - } - results := conciter.Map(tasks, func(pair *orderedmap.Pair[core.RepoRelPath, candidateFile]) []SyntacticMatch { + tasks, _ := genslices.ChunkEvery(candidateFiles, 20) + + results := conciter.Map(tasks, func(files *[]candidateFile) []SyntacticMatch { // We're assuming the index we found earlier contains the relevant SCIP document // see NOTE(id: single-syntactic-upload) - syntacticMatches, _, err := findSyntacticMatchesForCandidateFile(ctx, trace, mappedIndex, (*pair).Key, (*pair).Value) + mappedDocuments, err := mappedIndex.GetDocuments(ctx, genslices.Map(*files, func(cf candidateFile) core.RepoRelPath { + return cf.path + })) if err != nil { - // TODO: Errors that are not "no index found in the DB" should be reported - // TODO: Track metrics about how often this happens (GRAPH-693) - return []SyntacticMatch{} + // TODO: Errors here should be reported + return nil + } + syntacticMatches := []SyntacticMatch{} + for i, optDocument := range mappedDocuments { + document, isSome := optDocument.Get() + if !isSome { + continue + } + synMatches, _ := findSyntacticMatchesForCandidateFile(ctx, trace, document, (*files)[i]) + syntacticMatches = append(syntacticMatches, synMatches...) + } return syntacticMatches }) @@ -420,34 +444,37 @@ func searchBasedUsagesImpl( identifier: symbolName, language: language, } - candidateMatches, err := findCandidateOccurrencesViaSearch(ctx, trace, searchClient, searchCoords) + candidateMatches, candidateSymbols, err := core.Join( + func() ([]candidateFile, error) { + return findCandidateOccurrencesViaSearch(ctx, trace, searchClient, searchCoords) + }, + func() (symbolSearchResult, error) { + symbolResult, err := symbolSearch(ctx, trace, searchClient, searchCoords) + if err != nil { + trace.Warn("Failed to run symbol search, will not mark any search-based usages as definitions", log.Error(err)) + } + return symbolResult, nil + }) if err != nil { return nil, err } - candidateSymbols, err := symbolSearch(ctx, trace, searchClient, searchCoords) - if err != nil { - trace.Warn("Failed to run symbol search, will not mark any search-based usages as definitions", log.Error(err)) - } - tasks := make([]orderedmap.Pair[core.RepoRelPath, candidateFile], 0, candidateMatches.Len()) - for pair := candidateMatches.Oldest(); pair != nil; pair = pair.Next() { - tasks = append(tasks, *pair) - } - - results := conciter.Map(tasks, func(pair *orderedmap.Pair[core.RepoRelPath, candidateFile]) []SearchBasedMatch { + results := conciter.Map(candidateMatches, func(file *candidateFile) []SearchBasedMatch { if index, ok := syntacticIndex.Get(); ok { - _, searchBasedMatches, err := findSyntacticMatchesForCandidateFile(ctx, trace, index, pair.Key, pair.Value) - if err == nil { - return searchBasedMatches - } else { - trace.Info("findSyntacticMatches failed, skipping filtering search-based results", log.Error(err)) + optDocument, err := index.GetDocument(ctx, file.path) + if err != nil { + if document, isSome := optDocument.Get(); !isSome { + _, searchBasedMatches := findSyntacticMatchesForCandidateFile(ctx, trace, document, *file) + return searchBasedMatches + } } } matches := []SearchBasedMatch{} - for _, rg := range pair.Value.matches { + for _, rg := range file.matches { matches = append(matches, SearchBasedMatch{ - Path: pair.Key, - Range: rg, - IsDefinition: candidateSymbols.Contains(pair.Key, rg), + Path: file.path, + Range: rg.range_, + LineContent: rg.lineContent, + IsDefinition: candidateSymbols.Contains(file.path, rg.range_), }) } return matches diff --git a/internal/codeintel/codenav/transport/graphql/BUILD.bazel b/internal/codeintel/codenav/transport/graphql/BUILD.bazel index fe578fd3983..b8579b18366 100644 --- a/internal/codeintel/codenav/transport/graphql/BUILD.bazel +++ b/internal/codeintel/codenav/transport/graphql/BUILD.bazel @@ -19,7 +19,6 @@ go_library( "root_resolver_stencil.go", "root_resolver_usages.go", "util_cursor.go", - "util_lines.go", "util_locations.go", ], importpath = "github.com/sourcegraph/sourcegraph/internal/codeintel/codenav/transport/graphql", @@ -29,7 +28,6 @@ go_library( "//cmd/frontend/graphqlbackend/graphqlutil", "//internal/api", "//internal/authz", - "//internal/byteutils", "//internal/codeintel/codenav", "//internal/codeintel/codenav/shared", "//internal/codeintel/core", diff --git a/internal/codeintel/codenav/transport/graphql/root_resolver.go b/internal/codeintel/codenav/transport/graphql/root_resolver.go index 33c3d204aa7..9e22860f326 100644 --- a/internal/codeintel/codenav/transport/graphql/root_resolver.go +++ b/internal/codeintel/codenav/transport/graphql/root_resolver.go @@ -251,7 +251,6 @@ func (r *rootResolver) UsagesForSymbol(ctx context.Context, unresolvedArgs *reso } remainingCount := int(args.RemainingCount) provsForSCIPData := args.Symbol.ProvenancesForSCIPData() - linesGetter := newCachedLinesGetter(r.gitserverClient, 5*1024*1024 /* 5MB */) usageResolvers := []resolverstubs.UsageResolver{} if provsForSCIPData.Precise { @@ -284,7 +283,7 @@ func (r *rootResolver) UsagesForSymbol(ctx context.Context, unresolvedArgs *reso } } else { for _, result := range syntacticResult.Matches { - usageResolvers = append(usageResolvers, NewSyntacticUsageResolver(result, args.Repo, args.CommitID, linesGetter)) + usageResolvers = append(usageResolvers, NewSyntacticUsageResolver(result, args.Repo, args.CommitID)) } numSyntacticResults = len(syntacticResult.Matches) remainingCount = remainingCount - numSyntacticResults @@ -304,7 +303,7 @@ func (r *rootResolver) UsagesForSymbol(ctx context.Context, unresolvedArgs *reso } } else { for _, result := range results { - usageResolvers = append(usageResolvers, NewSearchBasedUsageResolver(result, args.Repo, args.CommitID, linesGetter)) + usageResolvers = append(usageResolvers, NewSearchBasedUsageResolver(result, args.Repo, args.CommitID)) } } } diff --git a/internal/codeintel/codenav/transport/graphql/root_resolver_usages.go b/internal/codeintel/codenav/transport/graphql/root_resolver_usages.go index 7f645845941..619ea0aef7a 100644 --- a/internal/codeintel/codenav/transport/graphql/root_resolver_usages.go +++ b/internal/codeintel/codenav/transport/graphql/root_resolver_usages.go @@ -32,13 +32,13 @@ type usageResolver struct { symbol *symbolInformationResolver provenance resolverstubs.CodeGraphDataProvenance kind resolverstubs.SymbolUsageKind - linesGetter LinesGetter + lineContent string usageRange *usageRangeResolver } var _ resolverstubs.UsageResolver = &usageResolver{} -func NewSyntacticUsageResolver(usage codenav.SyntacticMatch, repository types.Repo, revision api.CommitID, linesGetter LinesGetter) resolverstubs.UsageResolver { +func NewSyntacticUsageResolver(usage codenav.SyntacticMatch, repository types.Repo, revision api.CommitID) resolverstubs.UsageResolver { var kind resolverstubs.SymbolUsageKind if usage.IsDefinition { kind = resolverstubs.UsageKindDefinition @@ -51,7 +51,7 @@ func NewSyntacticUsageResolver(usage codenav.SyntacticMatch, repository types.Re }, provenance: resolverstubs.ProvenanceSyntactic, kind: kind, - linesGetter: linesGetter, + lineContent: usage.LineContent, usageRange: &usageRangeResolver{ repository: repository, revision: revision, @@ -60,7 +60,7 @@ func NewSyntacticUsageResolver(usage codenav.SyntacticMatch, repository types.Re }, } } -func NewSearchBasedUsageResolver(usage codenav.SearchBasedMatch, repository types.Repo, revision api.CommitID, linesGetter LinesGetter) resolverstubs.UsageResolver { +func NewSearchBasedUsageResolver(usage codenav.SearchBasedMatch, repository types.Repo, revision api.CommitID) resolverstubs.UsageResolver { var kind resolverstubs.SymbolUsageKind if usage.IsDefinition { kind = resolverstubs.UsageKindDefinition @@ -71,7 +71,7 @@ func NewSearchBasedUsageResolver(usage codenav.SearchBasedMatch, repository type symbol: nil, provenance: resolverstubs.ProvenanceSearchBased, kind: kind, - linesGetter: linesGetter, + lineContent: usage.LineContent, usageRange: &usageRangeResolver{ repository: repository, revision: revision, @@ -106,18 +106,7 @@ func (u *usageResolver) UsageRange(ctx context.Context) (resolverstubs.UsageRang func (u *usageResolver) SurroundingContent(ctx context.Context, args *struct { *resolverstubs.SurroundingLines `json:"surroundingLines"` }) (string, error) { - lines, err := u.linesGetter.Get( - ctx, - u.usageRange.repository.Name, - u.usageRange.revision, - u.usageRange.path.RawValue(), - int(u.usageRange.range_.Start.Line-*args.LinesBefore), - int(u.usageRange.range_.End.Line+*args.LinesAfter+1), - ) - if err != nil { - return "", err - } - return string(lines), nil + return u.lineContent, nil } func (u *usageResolver) UsageKind() resolverstubs.SymbolUsageKind { diff --git a/internal/codeintel/codenav/transport/graphql/util_lines.go b/internal/codeintel/codenav/transport/graphql/util_lines.go deleted file mode 100644 index 9e7f7fdecff..00000000000 --- a/internal/codeintel/codenav/transport/graphql/util_lines.go +++ /dev/null @@ -1,94 +0,0 @@ -package graphql - -import ( - "context" - "io" - "sync" - - "github.com/sourcegraph/sourcegraph/internal/api" - "github.com/sourcegraph/sourcegraph/internal/byteutils" - "github.com/sourcegraph/sourcegraph/internal/gitserver" -) - -type LinesGetter interface { - Get(ctx context.Context, repo api.RepoName, commit api.CommitID, path string, startLine, endLine int) ([]byte, error) -} - -type cacheKey struct { - repo api.RepoName - revision api.CommitID - path string -} - -type cacheValue struct { - contents []byte - index byteutils.LineIndex -} - -type cachedLinesGetter struct { - mu sync.RWMutex - cache map[cacheKey]cacheValue - maxCachedBytes int - freeBytes int - gitserver gitserver.Client -} - -var _ LinesGetter = (*cachedLinesGetter)(nil) - -func newCachedLinesGetter(gitserver gitserver.Client, size int) *cachedLinesGetter { - return &cachedLinesGetter{ - cache: make(map[cacheKey]cacheValue), - maxCachedBytes: size, - freeBytes: size, - gitserver: gitserver, - } -} - -func (c *cachedLinesGetter) Get(ctx context.Context, repo api.RepoName, commit api.CommitID, path string, startLine, endLine int) ([]byte, error) { - key := cacheKey{repo, commit, path} - - c.mu.RLock() - if value, ok := c.cache[key]; ok { - c.mu.RUnlock() - start, end := value.index.LinesRange(startLine, endLine) - return value.contents[start:end], nil - } - c.mu.RUnlock() - - r, err := c.gitserver.NewFileReader(ctx, repo, commit, path) - if err != nil { - return nil, err - } - defer r.Close() - - contents, err := io.ReadAll(r) - if err != nil { - return nil, err - } - index := byteutils.NewLineIndex(contents) - start, end := index.LinesRange(startLine, endLine) - lines := contents[start:end] - - if len(contents) > c.maxCachedBytes { - // Don't both trying to fit it in the cache - return lines, nil - } - - c.mu.Lock() - defer c.mu.Unlock() - - // Make room for the file in the cache. This cache doesn't need to be high - // performance -- just randomly delete things until we have room. - for k, v := range c.cache { - if c.freeBytes >= len(contents) { - break - } - delete(c.cache, k) - c.freeBytes += len(v.contents) - } - - c.cache[key] = cacheValue{contents, index} - c.freeBytes -= len(contents) - - return lines, nil -} diff --git a/internal/codeintel/core/BUILD.bazel b/internal/codeintel/core/BUILD.bazel index c0c673dd9a8..c75297b669b 100644 --- a/internal/codeintel/core/BUILD.bazel +++ b/internal/codeintel/core/BUILD.bazel @@ -4,6 +4,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library") go_library( name = "core", srcs = [ + "async.go", "option.go", "paths.go", "upload.go", diff --git a/internal/codeintel/core/async.go b/internal/codeintel/core/async.go new file mode 100644 index 00000000000..d80fe2caed7 --- /dev/null +++ b/internal/codeintel/core/async.go @@ -0,0 +1,89 @@ +package core + +func Join[A any, B any](fa func() (A, error), fb func() (B, error)) (A, B, error) { + ca := make(chan A) + cb := make(chan B) + cerrA := make(chan error) + cerrB := make(chan error) + go func() { + a, err := fa() + if err != nil { + cerrA <- err + return + } + ca <- a + }() + go func() { + b, err := fb() + if err != nil { + cerrB <- err + return + } + cb <- b + }() + select { + // TODO: combine potential multiple errors? + case err := <-cerrA: + return *new(A), *new(B), err + case err := <-cerrB: + return *new(A), *new(B), err + case a := <-ca: + select { + case b := <-cb: + return a, b, nil + case err := <-cerrB: + return *new(A), *new(B), err + } + case b := <-cb: + select { + case a := <-ca: + return a, b, nil + case err := <-cerrA: + return *new(A), *new(B), err + } + } +} + +func Race[A any, B any](fa func(stop chan struct{}) (A, error), fb func(stop chan struct{}) (B, error)) (A, B, bool, error) { + ca := make(chan A) + cb := make(chan B) + cerr := make(chan error) + stop := make(chan struct{}) + + go func() { + a, err := fa(stop) + if err != nil { + cerr <- err + return + } + ca <- a + close(ca) + }() + go func() { + b, err := fb(stop) + if err != nil { + cerr <- err + return + } + cb <- b + close(cb) + }() + + select { + case a := <-ca: + stop <- struct{}{} + return a, *new(B), true, nil + case b := <-cb: + stop <- struct{}{} + return *new(A), b, false, nil + case <-cerr: + select { + case a := <-ca: + return a, *new(B), true, nil + case b := <-cb: + return *new(A), b, false, nil + case err := <-cerr: + return *new(A), *new(B), false, err + } + } +}